""" 图谱构建服务 接口2:使用Zep API构建Standalone Graph """ import os import uuid import time import threading from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass from .graphiti_adapter import GraphitiAdapter from ..config import Config from ..models.task import TaskManager, TaskStatus from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale def _classify_entity_type(name: str, summary: str, ontology: Optional[Dict]) -> str: """ Classify an entity into an ontology type using keyword matching against entity type names, descriptions, and examples. Falls back to 'Entity' if no ontology or no match found. """ if not ontology: return "Entity" entity_types = ontology.get("entity_types", []) if not entity_types: return "Entity" name_lower = (name or "").lower() summary_lower = (summary or "").lower() search_text = f"{name_lower} {summary_lower}" best_type = "Entity" best_score = 0 for et in entity_types: score = 0 type_name = et.get("name", "") type_name_lower = type_name.lower() # Exact name match in type name if type_name_lower in name_lower: score += 10 # Check examples list for example in et.get("examples", []): if example.lower() in search_text: score += 8 elif name_lower in example.lower(): score += 6 # Check description keywords desc_words = (et.get("description", "")).lower().split() for word in desc_words: if len(word) > 4 and word in search_text: score += 1 if score > best_score: best_score = score best_type = type_name return best_type if best_score > 0 else "Entity" @dataclass class GraphInfo: """图谱信息""" graph_id: str node_count: int edge_count: int entity_types: List[str] def to_dict(self) -> Dict[str, Any]: return { "graph_id": self.graph_id, "node_count": self.node_count, "edge_count": self.edge_count, "entity_types": self.entity_types, } class GraphBuilderService: """ 图谱构建服务 负责调用Zep API构建知识图谱 """ def __init__(self, api_key: Optional[str] = None): self.client = GraphitiAdapter() self.task_manager = TaskManager() def build_graph_async( self, text: str, ontology: Dict[str, Any], graph_name: str = "MiroFish Graph", chunk_size: int = 500, chunk_overlap: int = 50, batch_size: int = 3 ) -> str: """ 异步构建图谱 Args: text: 输入文本 ontology: 本体定义(来自接口1的输出) graph_name: 图谱名称 chunk_size: 文本块大小 chunk_overlap: 块重叠大小 batch_size: 每批发送的块数量 Returns: 任务ID """ # 创建任务 task_id = self.task_manager.create_task( task_type="graph_build", metadata={ "graph_name": graph_name, "chunk_size": chunk_size, "text_length": len(text), } ) # Capture locale before spawning background thread current_locale = get_locale() # 在后台线程中执行构建 thread = threading.Thread( target=self._build_graph_worker, args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale) ) thread.daemon = True thread.start() return task_id def _build_graph_worker( self, task_id: str, text: str, ontology: Dict[str, Any], graph_name: str, chunk_size: int, chunk_overlap: int, batch_size: int, locale: str = 'zh' ): """图谱构建工作线程""" set_locale(locale) try: self.task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=5, message=t('progress.startBuildingGraph') ) # 1. 创建图谱 graph_id = self.create_graph(graph_name) self.task_manager.update_task( task_id, progress=10, message=t('progress.graphCreated', graphId=graph_id) ) # 2. 设置本体 self.set_ontology(graph_id, ontology) self.task_manager.update_task( task_id, progress=15, message=t('progress.ontologySet') ) # 3. 文本分块 chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) total_chunks = len(chunks) self.task_manager.update_task( task_id, progress=20, message=t('progress.textSplit', count=total_chunks) ) # 4. 分批发送数据 episode_uuids = self.add_text_batches( graph_id, chunks, batch_size, lambda msg, prog: self.task_manager.update_task( task_id, progress=20 + int(prog * 0.4), # 20-60% message=msg ) ) # 5. 等待Zep处理完成 self.task_manager.update_task( task_id, progress=60, message=t('progress.waitingZepProcess') ) self._wait_for_episodes( episode_uuids, lambda msg, prog: self.task_manager.update_task( task_id, progress=60 + int(prog * 0.3), # 60-90% message=msg ) ) # 6. 获取图谱信息 self.task_manager.update_task( task_id, progress=90, message=t('progress.fetchingGraphInfo') ) graph_info = self._get_graph_info(graph_id) # 完成 self.task_manager.complete_task(task_id, { "graph_id": graph_id, "graph_info": graph_info.to_dict(), "chunks_processed": total_chunks, }) except Exception as e: import traceback error_msg = f"{str(e)}\n{traceback.format_exc()}" self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: """创建Zep图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" self.client.graph.create( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" ) return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """设置图谱本体提示(Graphiti自动提取实体,本体作为提示存储)""" self.client.graph.set_ontology( graph_ids=[graph_id], entities=ontology.get("entity_types"), edges=ontology.get("edge_types"), ) def add_text_batches( self, graph_id: str, chunks: List[str], batch_size: int = 3, progress_callback: Optional[Callable] = None, skip_chunks: int = 0, ) -> List[str]: """分批添加文本到图谱,返回所有 episode 的 uuid 列表。 skip_chunks: 跳过已处理的块数(用于断点续传)。""" episode_uuids = [] total_chunks = len(chunks) for i in range(skip_chunks, total_chunks, batch_size): batch_chunks = chunks[i:i + batch_size] batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size if progress_callback: progress = (i + len(batch_chunks)) / total_chunks progress_callback( t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)), progress ) # 构建episode数据 episodes = [ type('Episode', (), {'data': chunk, 'type': 'text'})() for chunk in batch_chunks ] # 发送到Zep try: batch_result = self.client.graph.add_batch( graph_id=graph_id, episodes=episodes ) # 收集返回的 episode uuid if batch_result and isinstance(batch_result, list): for ep in batch_result: ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) if ep_uuid: episode_uuids.append(ep_uuid) # 避免请求过快 time.sleep(1) except Exception as e: if progress_callback: progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0) raise return episode_uuids def _wait_for_episodes( self, episode_uuids: List[str], progress_callback: Optional[Callable] = None, timeout: int = 600 ): """等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" if not episode_uuids: if progress_callback: progress_callback(t('progress.noEpisodesWait'), 1.0) return start_time = time.time() pending_episodes = set(episode_uuids) completed_count = 0 total_episodes = len(episode_uuids) if progress_callback: progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0) while pending_episodes: if time.time() - start_time > timeout: if progress_callback: progress_callback( t('progress.episodesTimeout', completed=completed_count, total=total_episodes), completed_count / total_episodes ) break # 检查每个 episode 的处理状态 for ep_uuid in list(pending_episodes): try: episode = self.client.graph.episode.get(uuid_=ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 except Exception as e: # 忽略单个查询错误,继续 pass elapsed = int(time.time() - start_time) if progress_callback: progress_callback( t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed), completed_count / total_episodes if total_episodes > 0 else 0 ) if pending_episodes: time.sleep(3) # 每3秒检查一次 if progress_callback: progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0) def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" # 获取节点(分页) nodes = fetch_all_nodes(self.client, graph_id) # 获取边(分页) edges = fetch_all_edges(self.client, graph_id) # 统计实体类型 entity_types = set() for node in nodes: if node.labels: for label in node.labels: if label not in ["Entity", "Node"]: entity_types.add(label) return GraphInfo( graph_id=graph_id, node_count=len(nodes), edge_count=len(edges), entity_types=list(entity_types) ) def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]: """ 获取完整图谱数据(包含详细信息) Args: graph_id: 图谱ID Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ nodes = fetch_all_nodes(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id) # 创建节点映射用于获取节点名称 node_map = {} for node in nodes: node_map[node.uuid_] = node.name or "" nodes_data = [] for node in nodes: # 获取创建时间 created_at = getattr(node, 'created_at', None) if created_at: created_at = str(created_at) entity_type = _classify_entity_type(node.name, node.summary or "", ontology) labels = node.labels or [] if entity_type != "Entity" and entity_type not in labels: labels = [entity_type] + [l for l in labels if l != "Entity"] nodes_data.append({ "uuid": node.uuid_, "name": node.name, "labels": labels, "summary": node.summary or "", "attributes": node.attributes or {}, "created_at": created_at, }) edges_data = [] for edge in edges: # 获取时间信息 created_at = getattr(edge, 'created_at', None) valid_at = getattr(edge, 'valid_at', None) invalid_at = getattr(edge, 'invalid_at', None) expired_at = getattr(edge, 'expired_at', None) # 获取 episodes episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) if episodes and not isinstance(episodes, list): episodes = [str(episodes)] elif episodes: episodes = [str(e) for e in episodes] # 获取 fact_type fact_type = getattr(edge, 'fact_type', None) or edge.name or "" edges_data.append({ "uuid": edge.uuid_, "name": edge.name or "", "fact": edge.fact or "", "fact_type": fact_type, "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, "source_node_name": node_map.get(edge.source_node_uuid, ""), "target_node_name": node_map.get(edge.target_node_uuid, ""), "attributes": edge.attributes or {}, "created_at": str(created_at) if created_at else None, "valid_at": str(valid_at) if valid_at else None, "invalid_at": str(invalid_at) if invalid_at else None, "expired_at": str(expired_at) if expired_at else None, "episodes": episodes or [], }) return { "graph_id": graph_id, "nodes": nodes_data, "edges": edges_data, "node_count": len(nodes_data), "edge_count": len(edges_data), } def delete_graph(self, graph_id: str): """删除图谱""" self.client.graph.delete(graph_id=graph_id)