"""
Zep entity reading and filtering service.

Reads nodes from a Zep graph and keeps the nodes that match a predefined
entity type, i.e. nodes carrying a label other than the default
"Entity"/"Node" labels.
"""

import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar, Tuple
from dataclasses import dataclass, field

from .graphiti_adapter import GraphitiAdapter
from ..config import Config
from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from ..utils.locale import t

logger = get_logger('mirofish.zep_entity_reader')

# Used for generic return types.
T = TypeVar('T')

# Labels a node may carry that never identify a specific entity type.
_DEFAULT_LABELS = ("Entity", "Node")


@dataclass
class EntityNode:
    """An entity node plus optional relationship context."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges touching this node; each entry has "direction" ("outgoing" or
    # "incoming"), "edge_name", "fact" and the UUID of the other endpoint.
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Basic info (uuid, name, labels, summary) about nodes at the other end
    # of related_edges.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this entity into a plain dict."""
        return {
            "uuid": self.uuid,
            "name": self.name,
            "labels": self.labels,
            "summary": self.summary,
            "attributes": self.attributes,
            "related_edges": self.related_edges,
            "related_nodes": self.related_nodes,
        }

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is not a default ("Entity"/"Node") label, or None."""
        for label in self.labels:
            if label not in _DEFAULT_LABELS:
                return label
        return None


@dataclass
class FilteredEntities:
    """Result of filtering a graph's nodes down to typed entities."""

    entities: List[EntityNode]
    entity_types: Set[str]
    total_count: int
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result into a plain dict."""
        return {
            "entities": [e.to_dict() for e in self.entities],
            "entity_types": list(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }


class ZepEntityReader:
    """
    Zep entity reading and filtering service.

    Main features:
    1. Read all nodes from a Zep graph.
    2. Keep nodes that match a predefined entity type, i.e. nodes whose labels
       are not only the default "Entity" label.
    3. Collect each entity's related edges and the nodes on the other end.
    """

    def __init__(self, api_key: Optional[str] = None):
        # api_key is accepted for backward compatibility only; the adapter
        # handles its own configuration.
        self.client = GraphitiAdapter()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _node_to_dict(node: Any) -> Dict[str, Any]:
        """Convert a graph node object into a plain dict with safe defaults."""
        return {
            "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
            "name": node.name or "",
            "labels": node.labels or [],
            "summary": node.summary or "",
            "attributes": node.attributes or {},
        }

    @staticmethod
    def _edge_to_dict(edge: Any) -> Dict[str, Any]:
        """Convert a graph edge object into a plain dict with safe defaults."""
        return {
            "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
            "name": edge.name or "",
            "fact": edge.fact or "",
            "source_node_uuid": edge.source_node_uuid,
            "target_node_uuid": edge.target_node_uuid,
            "attributes": edge.attributes or {},
        }

    @staticmethod
    def _relation_entry(edge: Dict[str, Any], outgoing: bool) -> Dict[str, Any]:
        """Build a related-edge entry for an edge dict, seen from one endpoint."""
        if outgoing:
            return {
                "direction": "outgoing",
                "edge_name": edge["name"],
                "fact": edge["fact"],
                "target_node_uuid": edge["target_node_uuid"],
            }
        return {
            "direction": "incoming",
            "edge_name": edge["name"],
            "fact": edge["fact"],
            "source_node_uuid": edge["source_node_uuid"],
        }

    @classmethod
    def _index_edges_by_node(
        cls, edges: List[Dict[str, Any]]
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Group related-edge entries by node UUID, preserving edge order.

        Each edge is listed as "outgoing" under its source node and as
        "incoming" under its target node. A self-loop is recorded only once,
        as outgoing.
        """
        index: Dict[str, List[Dict[str, Any]]] = {}
        for edge in edges:
            source = edge["source_node_uuid"]
            target = edge["target_node_uuid"]
            index.setdefault(source, []).append(cls._relation_entry(edge, outgoing=True))
            if target != source:
                index.setdefault(target, []).append(cls._relation_entry(edge, outgoing=False))
        return index

    @staticmethod
    def _summarize_related_nodes(
        related_edges: List[Dict[str, Any]],
        node_map: Dict[str, Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Return basic info for the other endpoint of each related edge.

        UUIDs are de-duplicated in first-seen edge order (dict.fromkeys), so
        the output order is deterministic. UUIDs missing from node_map are skipped.
        """
        other_uuids = dict.fromkeys(
            entry["target_node_uuid"] if entry["direction"] == "outgoing"
            else entry["source_node_uuid"]
            for entry in related_edges
        )
        return [
            {
                "uuid": node_map[uuid]["uuid"],
                "name": node_map[uuid]["name"],
                "labels": node_map[uuid]["labels"],
                "summary": node_map[uuid].get("summary", ""),
            }
            for uuid in other_uuids
            if uuid in node_map
        ]

    @staticmethod
    def _find_ontology_classifier(
        graph_id: str,
    ) -> Tuple[Any, Optional[Callable[..., str]]]:
        """Best-effort lookup of the ontology of the project that owns graph_id.

        Returns (ontology, classifier), or (None, None) when no project with
        an ontology matches or the lookup fails.
        """
        try:
            from ..models.project import ProjectManager
            from .graph_builder import _classify_entity_type
            for project in ProjectManager.list_projects():
                if project.graph_id == graph_id and project.ontology:
                    return project.ontology, _classify_entity_type
        except Exception as e:
            # Classification is optional; keep going without it but leave a trace.
            logger.debug("Ontology lookup failed for graph %s: %s", graph_id, e)
        return None, None

    @staticmethod
    def _apply_ontology_labels(
        nodes: List[Dict[str, Any]],
        ontology: Any,
        classify: Callable[..., str],
    ) -> None:
        """Prepend a classified entity type to nodes that carry only default labels.

        Mutates the node dicts in place. A result of "Entity" from the
        classifier leaves the node unchanged.
        """
        for node in nodes:
            labels = node.get("labels", [])
            if any(label not in _DEFAULT_LABELS for label in labels):
                continue
            entity_type = classify(node.get("name", ""), node.get("summary", ""), ontology)
            if entity_type != "Entity":
                node["labels"] = [entity_type] + labels

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def _call_with_retry(
        self,
        func: Callable[[], T],
        operation_name: str,
        max_retries: int = 3,
        initial_delay: float = 2.0
    ) -> T:
        """
        Call a Zep API function with retries and exponential backoff.

        Args:
            func: Zero-argument callable to execute.
            operation_name: Operation name used in log messages.
            max_retries: Maximum number of attempts (default 3). Values below 1
                are treated as 1, so the call is always attempted at least once.
            initial_delay: Delay in seconds before the first retry; doubled
                after each failed attempt.

        Returns:
            The result of func().

        Raises:
            Exception: Whatever func raised on the final attempt.
        """
        attempts = max(1, max_retries)
        delay = initial_delay
        attempt = 0
        while True:
            attempt += 1
            try:
                return func()
            except Exception as e:
                if attempt >= attempts:
                    logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=attempts, str=str(e)))
                    raise
                logger.warning(
                    t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt, str=str(e)[:100], delay=delay)
                )
                time.sleep(delay)
                delay *= 2  # exponential backoff

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Fetch all nodes of a graph, using paged fetching.

        Args:
            graph_id: Graph ID.

        Returns:
            List of node dicts (uuid, name, labels, summary, attributes).
        """
        logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id))

        nodes_data = [self._node_to_dict(node) for node in fetch_all_nodes(self.client, graph_id)]

        logger.info(t("log.zep_entity_reader.m004", len=len(nodes_data)))
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Fetch all edges of a graph, using paged fetching.

        Args:
            graph_id: Graph ID.

        Returns:
            List of edge dicts (uuid, name, fact, source/target node UUIDs, attributes).
        """
        logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id))

        edges_data = [self._edge_to_dict(edge) for edge in fetch_all_edges(self.client, graph_id)]

        logger.info(t("log.zep_entity_reader.m006", len=len(edges_data)))
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """
        Fetch all edges attached to one node, with retries.

        Args:
            node_uuid: Node UUID.

        Returns:
            List of edge dicts. An empty list if the call ultimately fails; the
            failure is logged as a warning.
        """
        try:
            edges = self._call_with_retry(
                func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
                operation_name=f"获取节点边(node={node_uuid[:8]}...)"
            )
            return [self._edge_to_dict(edge) for edge in edges]
        except Exception as e:
            logger.warning(t("log.zep_entity_reader.m007", node_uuid=node_uuid, str=str(e)))
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """
        Keep only the nodes that match a predefined entity type.

        Filtering rules:
        - A node whose labels are only the defaults ("Entity"/"Node") has no
          predefined type and is skipped.
        - A node with any other label matches a predefined type and is kept.
        If the owning project has an ontology, nodes with only default labels
        are first classified by it (best effort).

        Args:
            graph_id: Graph ID.
            defined_entity_types: Optional list of allowed entity types. If
                non-empty, only nodes with one of these labels are kept, and the
                first matching label is used as the entity type.
            enrich_with_edges: Whether to attach related edges and nodes to each entity.

        Returns:
            FilteredEntities: The filtered entity collection.
        """
        logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id))

        ontology, classify = self._find_ontology_classifier(graph_id)

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        # Classify nodes by ontology so that they get proper type labels.
        if ontology:
            self._apply_ontology_labels(all_nodes, ontology, classify)

        # Index edges once by node UUID rather than rescanning every edge for
        # every entity (O(N + E) instead of O(N * E)).
        edges_by_node: Dict[str, List[Dict[str, Any]]] = {}
        if enrich_with_edges:
            edges_by_node = self._index_edges_by_node(self.get_all_edges(graph_id))

        node_map = {n["uuid"]: n for n in all_nodes}
        allowed_types = set(defined_entity_types) if defined_entity_types else None

        filtered_entities: List[EntityNode] = []
        entity_types_found: Set[str] = set()

        for node in all_nodes:
            labels = node.get("labels", [])

            custom_labels = [label for label in labels if label not in _DEFAULT_LABELS]
            if not custom_labels:
                # Default labels only: no predefined type.
                continue

            if allowed_types is not None:
                entity_type = next((label for label in custom_labels if label in allowed_types), None)
                if entity_type is None:
                    continue
            else:
                entity_type = custom_labels[0]

            entity_types_found.add(entity_type)

            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )

            if enrich_with_edges:
                # Copy so that each entity owns its own list.
                related_edges = list(edges_by_node.get(node["uuid"], ()))
                entity.related_edges = related_edges
                entity.related_nodes = self._summarize_related_nodes(related_edges, node_map)

            filtered_entities.append(entity)

        logger.info(t("log.zep_entity_reader.m009", total_count=total_count, len=len(filtered_entities), entity_types_found=entity_types_found))

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """
        Fetch one entity with its full context (related edges and nodes), with retries.

        Args:
            graph_id: Graph ID (used to look up related node details).
            entity_uuid: Entity UUID.

        Returns:
            The EntityNode, or None if the node is missing or any step fails
            (the failure is logged as an error).
        """
        try:
            node = self._call_with_retry(
                func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
                operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
            )

            if not node:
                return None

            edges = self.get_node_edges(entity_uuid)

            # All nodes are fetched so that related nodes can be looked up by UUID.
            node_map = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}

            # An edge is outgoing when this entity is its source; otherwise it
            # is treated as incoming.
            related_edges = [
                self._relation_entry(edge, outgoing=edge["source_node_uuid"] == entity_uuid)
                for edge in edges
            ]

            return EntityNode(
                **self._node_to_dict(node),
                related_edges=related_edges,
                related_nodes=self._summarize_related_nodes(related_edges, node_map),
            )

        except Exception as e:
            logger.error(t("log.zep_entity_reader.m010", entity_uuid=entity_uuid, str=str(e)))
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """
        Fetch all entities of one type.

        Args:
            graph_id: Graph ID.
            entity_type: Entity type label (e.g. "Student", "PublicFigure").
            enrich_with_edges: Whether to attach related edge information.

        Returns:
            List of matching entities.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities