""" Graph building service Endpoint 2: Build a Standalone Graph using the Zep API """ import os import uuid import time import threading from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass from zep_cloud import EntityEdgeSourceTarget from ..config import Config from ..graph import get_graph_backend from ..models.task import TaskManager, TaskStatus from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale @dataclass class GraphInfo: """Graph info""" graph_id: str node_count: int edge_count: int entity_types: List[str] def to_dict(self) -> Dict[str, Any]: return { "graph_id": self.graph_id, "node_count": self.node_count, "edge_count": self.edge_count, "entity_types": self.entity_types, } class GraphBuilderService: """ Graph building service Responsible for calling the Zep API to build the knowledge graph. """ def __init__(self): self._graph = get_graph_backend() self.task_manager = TaskManager() def build_graph_async( self, text: str, ontology: Dict[str, Any], graph_name: str = "MiroFish Graph", chunk_size: int = 500, chunk_overlap: int = 50, batch_size: int = 3 ) -> str: """ Build the graph asynchronously. Args: text: input text ontology: ontology definition (output from endpoint 1) graph_name: graph name chunk_size: text chunk size chunk_overlap: chunk overlap size batch_size: number of chunks per batch Returns: task ID """ # Create task task_id = self.task_manager.create_task( task_type="graph_build", metadata={ "graph_name": graph_name, "chunk_size": chunk_size, "text_length": len(text), } ) # Capture locale before spawning background thread current_locale = get_locale() # Run build in background thread thread = threading.Thread( target=self._build_graph_worker, args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale) ) thread.daemon = True thread.start() return task_id def _build_graph_worker( self, task_id: str, text: str, ontology: Dict[str, Any], graph_name: str, chunk_size: int, chunk_overlap: int, batch_size: int, locale: str = 'zh' ): """Graph build worker thread""" set_locale(locale) try: self.task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=5, message=t('progress.startBuildingGraph') ) # 1. Create graph graph_id = self.create_graph(graph_name) self.task_manager.update_task( task_id, progress=10, message=t('progress.graphCreated', graphId=graph_id) ) # 2. Set ontology self.set_ontology(graph_id, ontology) self.task_manager.update_task( task_id, progress=15, message=t('progress.ontologySet') ) # 3. Split text into chunks chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) total_chunks = len(chunks) self.task_manager.update_task( task_id, progress=20, message=t('progress.textSplit', count=total_chunks) ) # 4. Send data in batches episode_uuids = self.add_text_batches( graph_id, chunks, batch_size, lambda msg, prog: self.task_manager.update_task( task_id, progress=20 + int(prog * 0.4), # 20-60% message=msg ) ) # 5. Wait for Zep processing to complete self.task_manager.update_task( task_id, progress=60, message=t('progress.waitingZepProcess') ) self._wait_for_episodes( episode_uuids, lambda msg, prog: self.task_manager.update_task( task_id, progress=60 + int(prog * 0.3), # 60-90% message=msg ) ) # 6. Fetch graph info self.task_manager.update_task( task_id, progress=90, message=t('progress.fetchingGraphInfo') ) graph_info = self._get_graph_info(graph_id) # Complete self.task_manager.complete_task(task_id, { "graph_id": graph_id, "graph_info": graph_info.to_dict(), "chunks_processed": total_chunks, }) except Exception as e: import traceback error_msg = f"{str(e)}\n{traceback.format_exc()}" self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: """Create a graph (public method)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" self._graph.create_graph( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" ) return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """Set graph ontology (public method)""" import warnings from typing import Optional from pydantic import Field from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel # Suppress Pydantic v2 warnings about Field(default=None) # This is the usage required by the Zep SDK; warnings come from dynamic class creation and can be safely ignored warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') # Zep reserved names that cannot be used as attribute names RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} def safe_attr_name(attr_name: str) -> str: """Convert reserved names to safe attribute names""" if attr_name.lower() in RESERVED_NAMES: return f"entity_{attr_name}" return attr_name # Dynamically create entity types entity_types = {} for entity_def in ontology.get("entity_types", []): name = entity_def["name"] description = entity_def.get("description", f"A {name} entity.") # Build attribute dict and type annotations (required by Pydantic v2) attrs = {"__doc__": description} annotations = {} for attr_def in entity_def.get("attributes", []): attr_name = safe_attr_name(attr_def["name"]) # Use safe name attr_desc = attr_def.get("description", attr_name) # Zep API requires Field description — this is mandatory attrs[attr_name] = Field(description=attr_desc, default=None) annotations[attr_name] = Optional[EntityText] # Type annotation attrs["__annotations__"] = annotations # Dynamically create class entity_class = type(name, (EntityModel,), attrs) entity_class.__doc__ = description entity_types[name] = entity_class # Dynamically create edge types edge_definitions = {} for edge_def in ontology.get("edge_types", []): name = edge_def["name"] description = edge_def.get("description", f"A {name} relationship.") # Build attribute dict and type annotations attrs = {"__doc__": description} annotations = {} for attr_def in edge_def.get("attributes", []): attr_name = safe_attr_name(attr_def["name"]) # Use safe name attr_desc = attr_def.get("description", attr_name) # Zep API requires Field description — this is mandatory attrs[attr_name] = Field(description=attr_desc, default=None) annotations[attr_name] = Optional[str] # Edge attributes use str type attrs["__annotations__"] = annotations # Dynamically create class class_name = ''.join(word.capitalize() for word in name.split('_')) edge_class = type(class_name, (EdgeModel,), attrs) edge_class.__doc__ = description # Build source_targets source_targets = [] for st in edge_def.get("source_targets", []): source_targets.append( EntityEdgeSourceTarget( source=st.get("source", "Entity"), target=st.get("target", "Entity") ) ) if source_targets: edge_definitions[name] = (edge_class, source_targets) if entity_types or edge_definitions: self._graph.set_ontology( graph_ids=[graph_id], entities=entity_types if entity_types else None, edges=edge_definitions if edge_definitions else None, ) def add_text_batches( self, graph_id: str, chunks: List[str], batch_size: int = 3, progress_callback: Optional[Callable] = None ) -> List[str]: """Add text to the graph in batches; returns a list of all episode UUIDs""" episode_uuids = [] total_chunks = len(chunks) for i in range(0, total_chunks, batch_size): batch_chunks = chunks[i:i + batch_size] batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size if progress_callback: progress = (i + len(batch_chunks)) / total_chunks progress_callback( t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)), progress ) episodes = [ {"data": chunk, "type": "text"} for chunk in batch_chunks ] try: returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes) episode_uuids.extend(returned_uuids) # Avoid sending requests too quickly time.sleep(1) except Exception as e: if progress_callback: progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0) raise return episode_uuids def _wait_for_episodes( self, episode_uuids: List[str], progress_callback: Optional[Callable] = None, timeout: int = 600 ): """Wait for all episodes to finish processing (by polling each episode's processed status)""" if not episode_uuids: if progress_callback: progress_callback(t('progress.noEpisodesWait'), 1.0) return start_time = time.time() pending_episodes = set(episode_uuids) completed_count = 0 total_episodes = len(episode_uuids) if progress_callback: progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0) while pending_episodes: if time.time() - start_time > timeout: if progress_callback: progress_callback( t('progress.episodesTimeout', completed=completed_count, total=total_episodes), completed_count / total_episodes ) break # Check processing status of each episode for ep_uuid in list(pending_episodes): try: episode = self._graph.get_episode(ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 except Exception as e: # Ignore individual query errors and continue pass elapsed = int(time.time() - start_time) if progress_callback: progress_callback( t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed), completed_count / total_episodes if total_episodes > 0 else 0 ) if pending_episodes: time.sleep(3) # Check every 3 seconds if progress_callback: progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0) def _get_graph_info(self, graph_id: str) -> GraphInfo: """Retrieve graph info""" nodes = self._graph.get_all_nodes(graph_id) edges = self._graph.get_all_edges(graph_id) entity_types = set() for node in nodes: for label in node.get("labels", []): if label not in ["Entity", "Node"]: entity_types.add(label) return GraphInfo( graph_id=graph_id, node_count=len(nodes), edge_count=len(edges), entity_types=list(entity_types) ) def get_graph_data(self, graph_id: str) -> Dict[str, Any]: """Retrieve full graph data (nodes + edges with timestamps and attributes).""" nodes = self._graph.get_all_nodes(graph_id) edges = self._graph.get_all_edges(graph_id) node_map = {n["uuid"]: n.get("name", "") for n in nodes} return { "graph_id": graph_id, "nodes": nodes, "edges": [ {**e, "source_node_name": node_map.get(e.get("source_node_uuid", ""), ""), "target_node_name": node_map.get(e.get("target_node_uuid", ""), "")} for e in edges ], "node_count": len(nodes), "edge_count": len(edges), } def delete_graph(self, graph_id: str): """Delete graph""" self._graph.delete_graph(graph_id)