MicroFish/backend/app/services/zep_entity_reader.py

277 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
实体读取与过滤服务
从本地JSON图谱中读取节点，筛选出符合预定义实体类型的节点
"""
from typing import Dict, Any, List, Optional, Set
from dataclasses import dataclass, field
from ..config import Config
from ..utils.local_graph_store import LocalGraphStore
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_entity_reader')
@dataclass
class EntityNode:
    """
    A graph node that carries a concrete entity type, plus optional context.

    The first five fields mirror the stored node record. ``related_edges`` and
    ``related_nodes`` start empty and are populated only when edge context is
    requested.
    """

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges touching this node; each dict carries a "direction" key.
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Compact summaries of the nodes at the other end of those edges.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (containers are returned by reference, not copied)."""
        ordered_keys = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {key: getattr(self, key) for key in ordered_keys}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is not the generic "Entity"/"Node", or None."""
        generic_labels = ("Entity", "Node")
        return next(
            (label for label in self.labels if label not in generic_labels),
            None,
        )
@dataclass
class FilteredEntities:
    """Result of entity filtering: the kept entities plus summary counts."""

    # Entities that passed the filter. The forward reference keeps this class
    # independent of definition order.
    entities: List["EntityNode"]
    # Distinct entity types found among the kept entities.
    entity_types: Set[str]
    # Number of nodes in the graph before filtering.
    total_count: int
    # Number of entities kept.
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; each entity is serialized via its own ``to_dict``."""
        return {
            "entities": [e.to_dict() for e in self.entities],
            # Sorted so the output is reproducible: the iteration order of a set
            # of strings changes between processes due to hash randomization.
            # key=str keeps the sort total even if a non-string label slipped in.
            "entity_types": sorted(self.entity_types, key=str),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
class ZepEntityReader:
    """
    Entity reading and filtering service backed by the local JSON graph store.

    Main responsibilities:
    1. Read all nodes of a graph from the local graph store.
    2. Keep only nodes that carry a predefined entity type, i.e. nodes whose
       labels contain something other than the generic "Entity"/"Node".
    3. Attach each entity's incident edges and a compact summary of the nodes
       at the other end of those edges.
    """

    # Labels any node may carry; they never identify a concrete entity type.
    _GENERIC_LABELS = ("Entity", "Node")

    def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
        """
        Args:
            storage_dir: Directory of the local graph store; falls back to
                ``Config.GRAPH_STORAGE_DIR`` when not given (or empty).
            api_key: Accepted only for compatibility with older callers; unused.
        """
        storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
        self.store = LocalGraphStore(storage_dir)

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every node record of ``graph_id`` from the store."""
        logger.info(f"获取图谱 {graph_id} 的所有节点...")
        nodes = self.store.get_nodes(graph_id)
        logger.info(f"共获取 {len(nodes)} 个节点")
        return nodes

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every edge record of ``graph_id`` from the store."""
        logger.info(f"获取图谱 {graph_id} 的所有边...")
        edges = self.store.get_edges(graph_id)
        logger.info(f"共获取 {len(edges)} 条边")
        return edges

    def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
        """
        Return the edges the store reports for one node.

        Best-effort: any store error is logged as a warning and yields ``[]``.
        """
        try:
            return self.store.get_node_edges(graph_id, node_uuid)
        except Exception as e:
            logger.warning(f"获取节点 {node_uuid} 的边失败: {e}")
            return []

    # ------------------------------------------------------------------ #
    # Internal helpers shared by the filtering and single-entity paths
    # ------------------------------------------------------------------ #

    @staticmethod
    def _index_edges_by_node(edges: List[Dict[str, Any]]) -> Dict[Any, List[Dict[str, Any]]]:
        """Group edges under each endpoint uuid, preserving the input edge order."""
        index: Dict[Any, List[Dict[str, Any]]] = {}
        for edge in edges:
            source = edge.get("source_node_uuid")
            target = edge.get("target_node_uuid")
            index.setdefault(source, []).append(edge)
            if target != source:
                # A self-loop is indexed once, so it is reported only once.
                index.setdefault(target, []).append(edge)
        return index

    @staticmethod
    def _summarize_node(node: Dict[str, Any]) -> Dict[str, Any]:
        """Return the compact node view stored in ``EntityNode.related_nodes``."""
        return {
            "uuid": node["uuid"],
            "name": node.get("name", ""),
            "labels": node.get("labels", []),
            "summary": node.get("summary", ""),
        }

    @classmethod
    def _build_context(
        cls,
        entity_uuid: str,
        edges: List[Dict[str, Any]],
        node_map: Dict[str, Dict[str, Any]],
    ) -> tuple:
        """
        Classify ``edges`` relative to ``entity_uuid`` and summarize its neighbours.

        Returns:
            ``(related_edges, related_nodes)``.
            - An edge whose source is the entity is "outgoing".
            - An edge whose target is the entity is "incoming".
            - An edge touching neither endpoint is ignored.
            Neighbours appear once each, in first-seen order, and only when they
            exist in ``node_map``.
        """
        related_edges: List[Dict[str, Any]] = []
        # A dict used as an insertion-ordered set, so the output order is deterministic.
        neighbour_uuids: Dict[Any, None] = {}
        for edge in edges:
            if edge.get("source_node_uuid") == entity_uuid:
                other = edge.get("target_node_uuid", "")
                related_edges.append({
                    "direction": "outgoing",
                    "edge_name": edge.get("name", ""),
                    "fact": edge.get("fact", ""),
                    "target_node_uuid": other,
                })
            elif edge.get("target_node_uuid") == entity_uuid:
                other = edge.get("source_node_uuid", "")
                related_edges.append({
                    "direction": "incoming",
                    "edge_name": edge.get("name", ""),
                    "fact": edge.get("fact", ""),
                    "source_node_uuid": other,
                })
            else:
                continue
            neighbour_uuids[other] = None

        related_nodes = [
            cls._summarize_node(node_map[uuid])
            for uuid in neighbour_uuids
            if uuid and uuid in node_map
        ]
        return related_edges, related_nodes

    @staticmethod
    def _make_entity(node: Dict[str, Any]) -> EntityNode:
        """Build an ``EntityNode`` from a node record, substituting empty values for nulls."""
        return EntityNode(
            uuid=node["uuid"],
            name=node.get("name") or "",
            # Copy so later mutation of the entity does not alias the store's record.
            labels=list(node.get("labels") or []),
            summary=node.get("summary") or "",
            attributes=node.get("attributes") or {},
        )

    # ------------------------------------------------------------------ #
    # Public query API
    # ------------------------------------------------------------------ #

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """
        Keep only the nodes that carry a predefined entity type.

        Filtering rules:
        - A node whose labels include anything besides "Entity"/"Node" is kept.
        - A node labelled only "Entity"/"Node" (or unlabelled) is skipped.
        - If ``defined_entity_types`` is non-empty, only nodes with at least one
          label in that list are kept.
        The entity type recorded for a node is its first qualifying label.

        Args:
            graph_id: Graph identifier.
            defined_entity_types: Optional allow-list of entity types.
            enrich_with_edges: Whether to attach incident edges and neighbour summaries.

        Returns:
            A ``FilteredEntities`` with the kept entities, the set of types found,
            the total node count and the kept count.
        """
        logger.info(f"开始筛选图谱 {graph_id} 的实体...")
        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)
        node_map = {n["uuid"]: n for n in all_nodes}
        # Index edges once: O(nodes + edges) instead of rescanning every edge per entity.
        edges_by_node = (
            self._index_edges_by_node(self.get_all_edges(graph_id))
            if enrich_with_edges else {}
        )
        # An empty or missing allow-list means "no type restriction".
        allowed = set(defined_entity_types) if defined_entity_types else None

        filtered_entities: List[EntityNode] = []
        entity_types_found: Set[str] = set()
        for node in all_nodes:
            labels = node.get("labels") or []
            candidate_types = [lbl for lbl in labels if lbl not in self._GENERIC_LABELS]
            if allowed is not None:
                candidate_types = [lbl for lbl in candidate_types if lbl in allowed]
            if not candidate_types:
                continue
            entity_types_found.add(candidate_types[0])

            entity = self._make_entity(node)
            if enrich_with_edges:
                entity.related_edges, entity.related_nodes = self._build_context(
                    node["uuid"], edges_by_node.get(node["uuid"], []), node_map
                )
            filtered_entities.append(entity)

        logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                    f"实体类型: {entity_types_found}")
        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """
        Return one entity together with its incident edges and neighbour summaries.

        Returns None when the node does not exist. Any other error is logged
        and also yields None.
        """
        try:
            node = self.store.get_node(graph_id, entity_uuid)
            if not node:
                return None
            edges = self.get_node_edges(graph_id, entity_uuid)
            node_map = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}
            entity = self._make_entity(node)
            entity.related_edges, entity.related_nodes = self._build_context(
                entity_uuid, edges, node_map
            )
            return entity
        except Exception as e:
            logger.error(f"获取实体 {entity_uuid} 失败: {e}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """Return every entity of ``entity_type`` (thin wrapper over ``filter_defined_entities``)."""
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities