"""
Graph building service.

Endpoint 2: build a standalone knowledge graph from text using the configured
graph backend (Zep, or an alternative selected via ``Config.GRAPH_BACKEND``).
"""

import logging
import os
import threading
import time
import traceback
import uuid
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from ..config import Config
from ..graph import get_graph_backend
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale

logger = logging.getLogger(__name__)


@dataclass
class GraphInfo:
    """Summary of a built graph: its id, node/edge counts and entity labels."""

    graph_id: str
    node_count: int
    edge_count: int
    entity_types: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the summary as a plain dict."""
        return {
            "graph_id": self.graph_id,
            "node_count": self.node_count,
            "edge_count": self.edge_count,
            "entity_types": self.entity_types,
        }


class GraphBuilderService:
    """
    Graph building service.

    Creates a graph on the configured backend, applies an ontology, uploads
    text in chunked batches, waits for the backend to process the resulting
    episodes and reports progress through a ``TaskManager`` task.
    """

    def __init__(self):
        self._graph = get_graph_backend()
        self.task_manager = TaskManager()

    def build_graph_async(
        self,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str = "MiroFish Graph",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 3
    ) -> str:
        """
        Build the graph asynchronously.

        Args:
            text: input text
            ontology: ontology definition (output from endpoint 1)
            graph_name: graph name
            chunk_size: text chunk size
            chunk_overlap: chunk overlap size
            batch_size: number of chunks per batch

        Returns:
            task ID; poll the task manager for progress and the final result
        """
        # Create task
        task_id = self.task_manager.create_task(
            task_type="graph_build",
            metadata={
                "graph_name": graph_name,
                "chunk_size": chunk_size,
                "text_length": len(text),
            }
        )

        # Capture locale before spawning the background thread: the worker
        # re-applies it so translated progress messages match the caller.
        current_locale = get_locale()

        # Run build in background thread
        thread = threading.Thread(
            target=self._build_graph_worker,
            args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
        )
        thread.daemon = True
        thread.start()

        return task_id

    def _build_graph_worker(
        self,
        task_id: str,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str,
        chunk_size: int,
        chunk_overlap: int,
        batch_size: int,
        locale: str = 'zh'
    ):
        """
        Graph build worker thread.

        Progress layout: 5-20% setup (create graph, ontology, split text),
        20-60% uploading batches, 60-90% waiting for episode processing,
        90-100% fetching graph info. Any exception fails the task with the
        error message and traceback.
        """
        set_locale(locale)

        try:
            self.task_manager.update_task(
                task_id,
                status=TaskStatus.PROCESSING,
                progress=5,
                message=t('progress.startBuildingGraph')
            )

            # 1. Create graph
            graph_id = self.create_graph(graph_name)
            self.task_manager.update_task(
                task_id,
                progress=10,
                message=t('progress.graphCreated', graphId=graph_id)
            )

            # 2. Set ontology
            self.set_ontology(graph_id, ontology)
            self.task_manager.update_task(
                task_id,
                progress=15,
                message=t('progress.ontologySet')
            )

            # 3. Split text into chunks
            chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
            total_chunks = len(chunks)
            self.task_manager.update_task(
                task_id,
                progress=20,
                message=t('progress.textSplit', count=total_chunks)
            )

            # 4. Send data in batches
            episode_uuids = self.add_text_batches(
                graph_id, chunks, batch_size,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=20 + int(prog * 0.4),  # 20-60%
                    message=msg
                )
            )

            # 5. Wait for backend processing to complete
            self.task_manager.update_task(
                task_id,
                progress=60,
                message=t('progress.waitingZepProcess')
            )

            self._wait_for_episodes(
                episode_uuids,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=60 + int(prog * 0.3),  # 60-90%
                    message=msg
                )
            )

            # 6. Fetch graph info
            self.task_manager.update_task(
                task_id,
                progress=90,
                message=t('progress.fetchingGraphInfo')
            )

            graph_info = self._get_graph_info(graph_id)

            # Complete
            self.task_manager.complete_task(task_id, {
                "graph_id": graph_id,
                "graph_info": graph_info.to_dict(),
                "chunks_processed": total_chunks,
            })

        except Exception as e:
            logger.exception("Graph build task %s failed", task_id)
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            self.task_manager.fail_task(task_id, error_msg)

    def create_graph(self, name: str) -> str:
        """Create a graph with a fresh ``mirofish_<hex>`` id and return that id (public method)."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"

        self._graph.create_graph(
            graph_id=graph_id,
            name=name,
            description="MiroFish Social Simulation Graph"
        )

        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None:
        """
        Set graph ontology (public method).

        ``ontology`` is read for ``entity_types`` and ``edge_types`` lists whose
        items carry ``name``, optional ``description`` and ``attributes`` (each
        attribute a dict with ``name`` / optional ``description``); edge items
        may also carry ``source_targets``.

        Non-Zep backends receive plain dicts. For Zep, Pydantic model classes
        are built dynamically from the definitions; edge types without any
        ``source_targets`` are not sent.
        """
        if Config.GRAPH_BACKEND != "zep":
            entities = {
                e["name"]: {
                    "description": e.get("description", ""),
                    "attributes": e.get("attributes", []),
                }
                for e in ontology.get("entity_types", [])
            }
            edges = {
                e["name"]: {
                    "description": e.get("description", ""),
                    "attributes": e.get("attributes", []),
                }
                for e in ontology.get("edge_types", [])
            }
            self._graph.set_ontology(graph_ids=[graph_id], entities=entities, edges=edges)
            return

        # Zep-only dependencies are imported lazily so other backends do not need them.
        from pydantic import Field
        from zep_cloud import EntityEdgeSourceTarget
        from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel

        # Suppress Pydantic v2 warnings about Field(default=None).
        # This is the usage required by the Zep SDK; warnings come from dynamic
        # class creation and can be safely ignored. NOTE: this filter is
        # process-wide and stays installed after this call.
        warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')

        # Zep reserved names that cannot be used as attribute names
        RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}

        def safe_attr_name(attr_name: str) -> str:
            """Prefix attribute names that collide (case-insensitively) with reserved names."""
            if attr_name.lower() in RESERVED_NAMES:
                return f"entity_{attr_name}"
            return attr_name

        # Dynamically create entity types
        entity_types = {}
        for entity_def in ontology.get("entity_types", []):
            name = entity_def["name"]
            description = entity_def.get("description", f"A {name} entity.")

            attrs = {"__doc__": description}
            annotations = {}

            for attr_def in entity_def.get("attributes", []):
                attr_name = safe_attr_name(attr_def["name"])
                attr_desc = attr_def.get("description", attr_name)
                attrs[attr_name] = Field(description=attr_desc, default=None)
                annotations[attr_name] = Optional[EntityText]

            attrs["__annotations__"] = annotations

            entity_class = type(name, (EntityModel,), attrs)
            entity_class.__doc__ = description
            entity_types[name] = entity_class

        # Dynamically create edge types
        edge_definitions = {}
        for edge_def in ontology.get("edge_types", []):
            name = edge_def["name"]
            description = edge_def.get("description", f"A {name} relationship.")

            attrs = {"__doc__": description}
            annotations = {}

            for attr_def in edge_def.get("attributes", []):
                attr_name = safe_attr_name(attr_def["name"])
                attr_desc = attr_def.get("description", attr_name)
                attrs[attr_name] = Field(description=attr_desc, default=None)
                annotations[attr_name] = Optional[str]

            attrs["__annotations__"] = annotations

            # snake_case edge name -> PascalCase class name
            class_name = ''.join(word.capitalize() for word in name.split('_'))
            edge_class = type(class_name, (EdgeModel,), attrs)
            edge_class.__doc__ = description

            source_targets = [
                EntityEdgeSourceTarget(
                    source=st.get("source", "Entity"),
                    target=st.get("target", "Entity")
                )
                for st in edge_def.get("source_targets", [])
            ]

            if source_targets:
                edge_definitions[name] = (edge_class, source_targets)

        if entity_types or edge_definitions:
            self._graph.set_ontology(
                graph_ids=[graph_id],
                entities=entity_types if entity_types else None,
                edges=edge_definitions if edge_definitions else None,
            )

    def add_text_batches(
        self,
        graph_id: str,
        chunks: List[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> List[str]:
        """
        Add text to the graph in batches; returns a list of all episode UUIDs.

        Args:
            graph_id: target graph id
            chunks: text chunks, each sent as one ``{"data", "type": "text"}`` episode
            batch_size: chunks per ``add_batch`` call; must be >= 1
            progress_callback: optional ``(message, fraction)`` callable

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
            Exception: whatever the backend raises for a failed batch is
                re-raised after reporting it through ``progress_callback``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        episode_uuids: List[str] = []
        total_chunks = len(chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for start in range(0, total_chunks, batch_size):
            batch_chunks = chunks[start:start + batch_size]
            batch_num = start // batch_size + 1

            if progress_callback:
                progress = (start + len(batch_chunks)) / total_chunks
                progress_callback(
                    t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
                    progress
                )

            episodes = [{"data": chunk, "type": "text"} for chunk in batch_chunks]

            try:
                returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes)
                episode_uuids.extend(returned_uuids)
            except Exception as e:
                if progress_callback:
                    progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
                raise

            # Avoid sending requests too quickly; no need to wait after the last batch.
            if batch_num < total_batches:
                time.sleep(1)

        return episode_uuids

    def _wait_for_episodes(
        self,
        episode_uuids: List[str],
        progress_callback: Optional[Callable] = None,
        timeout: int = 600
    ):
        """
        Wait for all episodes to finish processing by polling each episode's
        ``processed`` attribute every 3 seconds.

        Returns once every episode is processed (reporting completion with
        fraction 1.0) or after ``timeout`` seconds (reporting only the timeout
        and the fraction completed so far). Per-episode lookup errors are
        logged and retried on the next poll.
        """
        if not episode_uuids:
            if progress_callback:
                progress_callback(t('progress.noEpisodesWait'), 1.0)
            return

        start_time = time.time()
        pending_episodes = set(episode_uuids)
        # Count distinct episodes so completed_count can actually reach the total.
        total_episodes = len(pending_episodes)
        completed_count = 0

        if progress_callback:
            progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)

        while pending_episodes:
            if time.time() - start_time > timeout:
                if progress_callback:
                    progress_callback(
                        t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
                        completed_count / total_episodes
                    )
                # Do not follow the timeout with a "processing complete" report.
                return

            # Check processing status of each episode
            for ep_uuid in list(pending_episodes):
                try:
                    episode = self._graph.get_episode(ep_uuid)
                    is_processed = getattr(episode, 'processed', False)
                except Exception:
                    # Best-effort: a failed lookup is retried on the next poll.
                    logger.debug("Failed to query episode %s; will retry", ep_uuid, exc_info=True)
                    continue

                if is_processed:
                    pending_episodes.discard(ep_uuid)
                    completed_count += 1

            elapsed = int(time.time() - start_time)
            if progress_callback:
                progress_callback(
                    t('progress.zepProcessing', completed=completed_count, total=total_episodes,
                      pending=len(pending_episodes), elapsed=elapsed),
                    completed_count / total_episodes
                )

            if pending_episodes:
                time.sleep(3)  # Check every 3 seconds

        if progress_callback:
            progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)

    def _get_graph_info(self, graph_id: str) -> GraphInfo:
        """
        Retrieve graph info: node/edge counts plus the sorted set of node
        labels other than the generic ``Entity`` / ``Node`` labels.
        """
        nodes = self._graph.get_all_nodes(graph_id)
        edges = self._graph.get_all_edges(graph_id)

        entity_types = {
            label
            for node in nodes
            for label in node.get("labels", [])
            if label not in ("Entity", "Node")
        }

        return GraphInfo(
            graph_id=graph_id,
            node_count=len(nodes),
            edge_count=len(edges),
            # Sorted so the result is deterministic across runs.
            entity_types=sorted(entity_types)
        )

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """
        Retrieve full graph data (nodes + edges with timestamps and attributes).

        Each edge is augmented with ``source_node_name`` / ``target_node_name``
        resolved from the node list ("" when the endpoint is not found).
        """
        nodes = self._graph.get_all_nodes(graph_id)
        edges = self._graph.get_all_edges(graph_id)

        node_map = {n["uuid"]: n.get("name", "") for n in nodes}

        return {
            "graph_id": graph_id,
            "nodes": nodes,
            "edges": [
                {
                    **e,
                    "source_node_name": node_map.get(e.get("source_node_uuid", ""), ""),
                    "target_node_name": node_map.get(e.get("target_node_uuid", ""), ""),
                }
                for e in edges
            ],
            "node_count": len(nodes),
            "edge_count": len(edges),
        }

    def delete_graph(self, graph_id: str):
        """Delete graph"""
        self._graph.delete_graph(graph_id)