277 lines
9.9 KiB
Python
277 lines
9.9 KiB
Python
"""
|
||
Entity reading and filtering service.

Reads nodes from the local JSON graph and keeps only those that match the predefined entity types.
|
||
"""
|
||
|
||
from typing import Dict, Any, List, Optional, Set
|
||
from dataclasses import dataclass, field
|
||
|
||
from ..config import Config
|
||
from ..utils.local_graph_store import LocalGraphStore
|
||
from ..utils.logger import get_logger
|
||
|
||
# Module-level logger for this service; the log calls below go through it.
logger = get_logger('mirofish.zep_entity_reader')
|
||
|
||
|
||
@dataclass
class EntityNode:
    """Data structure describing one entity node of the graph."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edge summaries attached to this node; empty until filled in by a reader.
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Summaries of the nodes at the other end of those edges.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the node's fields as a plain dictionary.

        The values are the same objects held by the instance (no copying).
        """
        exported = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {key: getattr(self, key) for key in exported}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is not the generic "Entity"/"Node".

        Returns:
            The first specific label, or ``None`` when the node only carries
            generic labels (or no labels at all).
        """
        generic_labels = ("Entity", "Node")
        specific = (tag for tag in self.labels if tag not in generic_labels)
        return next(specific, None)
|
||
|
||
|
||
@dataclass
class FilteredEntities:
    """Result of an entity-filtering pass over a graph."""

    # Entities that passed the filter. The annotation is a quoted forward
    # reference so this class does not need the name resolved when defined.
    entities: List["EntityNode"]
    # Entity types seen among the kept entities.
    entity_types: Set[str]
    # Number of nodes examined before filtering.
    total_count: int
    # Number of entities kept after filtering.
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dictionary view of the result.

        ``entity_types`` is emitted as a sorted list: a set has no defined
        iteration order, so sorting keeps the output identical across runs.
        """
        return {
            "entities": [e.to_dict() for e in self.entities],
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
|
||
|
||
|
||
class ZepEntityReader:
    """Entity reading and filtering service backed by the local graph store.

    Main responsibilities:
        1. Read every node of a graph from the local store.
        2. Keep the nodes that match a predefined entity type, i.e. nodes
           whose labels contain something other than "Entity"/"Node".
        3. Attach the related edges and neighbouring nodes to each entity.
    """

    # Generic labels that do not identify a concrete entity type.
    _GENERIC_LABELS = ("Entity", "Node")

    def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
        """Create a reader on top of a ``LocalGraphStore``.

        Args:
            storage_dir: Directory handed to ``LocalGraphStore``; falls back
                to ``Config.GRAPH_STORAGE_DIR`` when empty or ``None``.
            api_key: Kept only so older call sites keep working; it is unused.
        """
        storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
        self.store = LocalGraphStore(storage_dir)

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every node of the graph, logging how many were read."""
        logger.info(f"获取图谱 {graph_id} 的所有节点...")
        nodes = self.store.get_nodes(graph_id)
        logger.info(f"共获取 {len(nodes)} 个节点")
        return nodes

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every edge of the graph, logging how many were read."""
        logger.info(f"获取图谱 {graph_id} 的所有边...")
        edges = self.store.get_edges(graph_id)
        logger.info(f"共获取 {len(edges)} 条边")
        return edges

    def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
        """Return the edges the store reports for one node.

        Best-effort: any store failure is logged as a warning and an empty
        list is returned instead of propagating the exception.
        """
        try:
            return self.store.get_node_edges(graph_id, node_uuid)
        except Exception as e:
            logger.warning(f"获取节点 {node_uuid} 的边失败: {e}")
            return []

    @staticmethod
    def _describe_edge(edge: Dict[str, Any], outgoing: bool) -> Dict[str, Any]:
        """Build the edge summary stored on an entity, seen from its side."""
        far_end_key = "target_node_uuid" if outgoing else "source_node_uuid"
        return {
            "direction": "outgoing" if outgoing else "incoming",
            "edge_name": edge.get("name", ""),
            "fact": edge.get("fact", ""),
            far_end_key: edge.get(far_end_key, ""),
        }

    @classmethod
    def _index_edges_by_node(cls, edges: List[Dict[str, Any]]) -> Dict[Any, List[Dict[str, Any]]]:
        """Group edge summaries by the uuid of the node they belong to.

        One pass over the edges replaces a full edge scan per entity. Edge
        order is preserved per node, and a self-loop is recorded once, as
        outgoing.
        """
        index: Dict[Any, List[Dict[str, Any]]] = {}
        for edge in edges:
            source_uuid = edge.get("source_node_uuid")
            target_uuid = edge.get("target_node_uuid")
            index.setdefault(source_uuid, []).append(cls._describe_edge(edge, outgoing=True))
            if target_uuid != source_uuid:
                index.setdefault(target_uuid, []).append(cls._describe_edge(edge, outgoing=False))
        return index

    @staticmethod
    def _summarize_related_nodes(
        related_edges: List[Dict[str, Any]],
        node_map: Dict[Any, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Summarise the nodes at the far end of the given edge summaries.

        Each neighbour appears once, in the order its first edge appears, so
        the result is deterministic. Empty uuids and uuids missing from
        ``node_map`` are skipped.
        """
        seen: Set[Any] = set()
        related_nodes: List[Dict[str, Any]] = []
        for described in related_edges:
            if described["direction"] == "outgoing":
                neighbour_uuid = described["target_node_uuid"]
            else:
                neighbour_uuid = described["source_node_uuid"]
            if not neighbour_uuid or neighbour_uuid in seen:
                continue
            seen.add(neighbour_uuid)
            if neighbour_uuid not in node_map:
                continue
            neighbour = node_map[neighbour_uuid]
            related_nodes.append({
                "uuid": neighbour["uuid"],
                # "or" also covers keys stored with an explicit null value.
                "name": neighbour.get("name") or "",
                "labels": neighbour.get("labels") or [],
                "summary": neighbour.get("summary") or "",
            })
        return related_nodes

    @staticmethod
    def _make_entity(
        node: Dict[str, Any],
        related_edges: Optional[List[Dict[str, Any]]] = None,
        related_nodes: Optional[List[Dict[str, Any]]] = None
    ) -> EntityNode:
        """Convert a raw node dict into an ``EntityNode``.

        Missing keys and explicit null values both fall back to an empty
        default, so the fields keep their annotated types.
        """
        return EntityNode(
            uuid=node["uuid"],
            name=node.get("name") or "",
            labels=node.get("labels") or [],
            summary=node.get("summary") or "",
            attributes=node.get("attributes") or {},
            related_edges=related_edges if related_edges is not None else [],
            related_nodes=related_nodes if related_nodes is not None else [],
        )

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """Select the nodes that match a predefined entity type.

        Filtering rules:
            - A node whose labels include anything other than "Entity" and
              "Node" matches a predefined type and is kept.
            - A node that only carries "Entity"/"Node" (or has no labels)
              is skipped.

        Args:
            graph_id: Graph identifier.
            defined_entity_types: Optional list of entity types; when given
                and non-empty, only nodes carrying one of them are kept.
            enrich_with_edges: Whether to attach each entity's related edges
                and neighbouring nodes.

        Returns:
            A ``FilteredEntities`` holding the kept entities, the entity
            types found, the total node count and the kept count.
        """
        logger.info(f"开始筛选图谱 {graph_id} 的实体...")

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        # The edge index and uuid lookup are only needed when enriching.
        edges_by_node: Dict[Any, List[Dict[str, Any]]] = {}
        node_map: Dict[Any, Dict[str, Any]] = {}
        if enrich_with_edges:
            edges_by_node = self._index_edges_by_node(self.get_all_edges(graph_id))
            node_map = {n["uuid"]: n for n in all_nodes}

        # Built once so each membership test is a set lookup; an empty or
        # missing list means "no type restriction".
        allowed_types = set(defined_entity_types) if defined_entity_types else None

        filtered_entities: List[EntityNode] = []
        entity_types_found: Set[str] = set()

        for node in all_nodes:
            labels = node.get("labels") or []
            custom_labels = [label for label in labels if label not in self._GENERIC_LABELS]
            if not custom_labels:
                continue

            if allowed_types is not None:
                matching = [label for label in custom_labels if label in allowed_types]
                if not matching:
                    continue
                entity_type = matching[0]
            else:
                entity_type = custom_labels[0]

            entity_types_found.add(entity_type)
            entity = self._make_entity(node)

            if enrich_with_edges:
                # Copy so the entity does not share a list with the index.
                entity.related_edges = list(edges_by_node.get(node["uuid"], ()))
                entity.related_nodes = self._summarize_related_nodes(entity.related_edges, node_map)

            filtered_entities.append(entity)

        logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                    f"实体类型: {entity_types_found}")

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """Return one entity together with its edges and neighbouring nodes.

        Returns:
            The populated ``EntityNode``, or ``None`` when the node does not
            exist or any step fails (the failure is logged as an error).
        """
        try:
            node = self.store.get_node(graph_id, entity_uuid)
            if not node:
                return None

            edges = self.get_node_edges(graph_id, entity_uuid)
            # NOTE(review): any edge whose source is not this entity is
            # treated as incoming, which assumes the store only returns
            # edges touching the node - confirm against LocalGraphStore.
            related_edges = [
                self._describe_edge(edge, outgoing=edge.get("source_node_uuid") == entity_uuid)
                for edge in edges
            ]

            related_nodes: List[Dict[str, Any]] = []
            if related_edges:
                # Load the full node list only when there are neighbours.
                node_map = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}
                related_nodes = self._summarize_related_nodes(related_edges, node_map)

            return self._make_entity(node, related_edges, related_nodes)

        except Exception as e:
            logger.error(f"获取实体 {entity_uuid} 失败: {e}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """Return every entity of the given type.

        Delegates to ``filter_defined_entities`` with a single-type filter.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities
|