"""Graph build service. Pipeline step 2: build the project's standalone knowledge graph through the Zep/Graphiti API. """ import os import uuid import time import threading from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass from .graphiti_adapter import GraphitiAdapter from ..config import Config from ..models.task import TaskManager, TaskStatus from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale def _classify_entity_type(name: str, summary: str, ontology: Optional[Dict]) -> str: """ Classify an entity into an ontology type using keyword matching against entity type names, descriptions, and examples. Falls back to 'Entity' if no ontology or no match found. """ if not ontology: return "Entity" entity_types = ontology.get("entity_types", []) if not entity_types: return "Entity" name_lower = (name or "").lower() summary_lower = (summary or "").lower() search_text = f"{name_lower} {summary_lower}" best_type = "Entity" best_score = 0 for et in entity_types: score = 0 type_name = et.get("name", "") type_name_lower = type_name.lower() # Exact name match in type name if type_name_lower in name_lower: score += 10 # Check examples list for example in et.get("examples", []): if example.lower() in search_text: score += 8 elif name_lower in example.lower(): score += 6 # Check description keywords desc_words = (et.get("description", "")).lower().split() for word in desc_words: if len(word) > 4 and word in search_text: score += 1 if score > best_score: best_score = score best_type = type_name return best_type if best_score > 0 else "Entity" @dataclass class GraphInfo: """Summary information about a built graph.""" graph_id: str node_count: int edge_count: int entity_types: List[str] def to_dict(self) -> Dict[str, Any]: return { "graph_id": self.graph_id, "node_count": self.node_count, "edge_count": self.edge_count, "entity_types": self.entity_types, } class GraphBuilderService: """Drives knowledge-graph construction via the Zep/Graphiti API.""" def __init__(self, api_key: Optional[str] = None): self.client = GraphitiAdapter() self.task_manager = TaskManager() def build_graph_async( self, text: str, ontology: Dict[str, Any], graph_name: str = "MiroFish Graph", chunk_size: int = 500, chunk_overlap: int = 50, batch_size: int = 3 ) -> str: """Kick off a graph build asynchronously. Args: text: Source text to ingest. ontology: Ontology definition (the output of pipeline step 1). graph_name: Display name for the graph. chunk_size: Characters per text chunk. chunk_overlap: Overlap (in characters) between consecutive chunks. batch_size: Number of chunks pushed to Zep per batch. Returns: The id of the task tracking the build. """ # Register a task to track build progress. task_id = self.task_manager.create_task( task_type="graph_build", metadata={ "graph_name": graph_name, "chunk_size": chunk_size, "text_length": len(text), } ) # Capture locale before spawning background thread current_locale = get_locale() # Run the build on a background thread so the request returns immediately. thread = threading.Thread( target=self._build_graph_worker, args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale) ) thread.daemon = True thread.start() return task_id def _build_graph_worker( self, task_id: str, text: str, ontology: Dict[str, Any], graph_name: str, chunk_size: int, chunk_overlap: int, batch_size: int, locale: str = 'zh' ): """Background worker that performs the graph build.""" set_locale(locale) try: self.task_manager.update_task( task_id, status=TaskStatus.PROCESSING, progress=5, message=t('progress.startBuildingGraph') ) # 1. Create the graph. graph_id = self.create_graph(graph_name) self.task_manager.update_task( task_id, progress=10, message=t('progress.graphCreated', graphId=graph_id) ) # 2. Set the ontology. self.set_ontology(graph_id, ontology) self.task_manager.update_task( task_id, progress=15, message=t('progress.ontologySet') ) # 3. Split source text into chunks. chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) total_chunks = len(chunks) self.task_manager.update_task( task_id, progress=20, message=t('progress.textSplit', count=total_chunks) ) # 4. Push chunks to the graph in batches. episode_uuids = self.add_text_batches( graph_id, chunks, batch_size, lambda msg, prog: self.task_manager.update_task( task_id, progress=20 + int(prog * 0.4), # 20-60% message=msg ) ) # 5. Wait for Zep to finish processing the episodes. self.task_manager.update_task( task_id, progress=60, message=t('progress.waitingZepProcess') ) self._wait_for_episodes( episode_uuids, lambda msg, prog: self.task_manager.update_task( task_id, progress=60 + int(prog * 0.3), # 60-90% message=msg ) ) # 6. Fetch the final graph metadata. self.task_manager.update_task( task_id, progress=90, message=t('progress.fetchingGraphInfo') ) graph_info = self._get_graph_info(graph_id) self.task_manager.complete_task(task_id, { "graph_id": graph_id, "graph_info": graph_info.to_dict(), "chunks_processed": total_chunks, }) except Exception as e: import traceback error_msg = f"{str(e)}\n{traceback.format_exc()}" self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: """Create a new Zep graph and return its id (public API).""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" self.client.graph.create( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" ) return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """Register the ontology with the graph (Graphiti uses it as an extraction prompt).""" self.client.graph.set_ontology( graph_ids=[graph_id], entities=ontology.get("entity_types"), edges=ontology.get("edge_types"), ) def add_text_batches( self, graph_id: str, chunks: List[str], batch_size: int = 3, progress_callback: Optional[Callable] = None, skip_chunks: int = 0, ) -> List[str]: """Push chunks to the graph in batches; returns the uuids of all episodes added. Args: skip_chunks: Number of chunks to skip (used for resume-after-restart). """ episode_uuids = [] total_chunks = len(chunks) for i in range(skip_chunks, total_chunks, batch_size): batch_chunks = chunks[i:i + batch_size] batch_num = i // batch_size + 1 total_batches = (total_chunks + batch_size - 1) // batch_size if progress_callback: progress = (i + len(batch_chunks)) / total_chunks progress_callback( t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)), progress ) # Build the per-episode payload structures expected by the client. episodes = [ type('Episode', (), {'data': chunk, 'type': 'text'})() for chunk in batch_chunks ] try: batch_result = self.client.graph.add_batch( graph_id=graph_id, episodes=episodes ) # Collect the uuids returned for each episode. if batch_result and isinstance(batch_result, list): for ep in batch_result: ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) if ep_uuid: episode_uuids.append(ep_uuid) # Throttle to avoid overwhelming the upstream API. time.sleep(1) except Exception as e: if progress_callback: progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0) raise return episode_uuids def _wait_for_episodes( self, episode_uuids: List[str], progress_callback: Optional[Callable] = None, timeout: int = 600 ): """Poll each episode until Zep marks it processed, or the timeout expires.""" if not episode_uuids: if progress_callback: progress_callback(t('progress.noEpisodesWait'), 1.0) return start_time = time.time() pending_episodes = set(episode_uuids) completed_count = 0 total_episodes = len(episode_uuids) if progress_callback: progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0) while pending_episodes: if time.time() - start_time > timeout: if progress_callback: progress_callback( t('progress.episodesTimeout', completed=completed_count, total=total_episodes), completed_count / total_episodes ) break # Check the processing state of each pending episode. for ep_uuid in list(pending_episodes): try: episode = self.client.graph.episode.get(uuid_=ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 except Exception as e: # Tolerate a single failed query; the next loop iteration retries. pass elapsed = int(time.time() - start_time) if progress_callback: progress_callback( t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed), completed_count / total_episodes if total_episodes > 0 else 0 ) if pending_episodes: time.sleep(3) # poll every 3 seconds if progress_callback: progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0) def _get_graph_info(self, graph_id: str) -> GraphInfo: """Fetch summary info (counts and entity types) for a graph.""" nodes = fetch_all_nodes(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id) # Tally distinct entity types across all nodes. entity_types = set() for node in nodes: if node.labels: for label in node.labels: if label not in ["Entity", "Node"]: entity_types.add(label) return GraphInfo( graph_id=graph_id, node_count=len(nodes), edge_count=len(edges), entity_types=list(entity_types) ) def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]: """Return the full graph payload including timestamps, attributes, and edges. Args: graph_id: Graph identifier. Returns: Dict with ``nodes``, ``edges``, and aggregate counts. """ nodes = fetch_all_nodes(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id) # Build a uuid->name map so edge endpoints can be labeled. node_map = {} for node in nodes: node_map[node.uuid_] = node.name or "" nodes_data = [] for node in nodes: created_at = getattr(node, 'created_at', None) if created_at: created_at = str(created_at) entity_type = _classify_entity_type(node.name, node.summary or "", ontology) labels = node.labels or [] if entity_type != "Entity" and entity_type not in labels: labels = [entity_type] + [l for l in labels if l != "Entity"] nodes_data.append({ "uuid": node.uuid_, "name": node.name, "labels": labels, "summary": node.summary or "", "attributes": node.attributes or {}, "created_at": created_at, }) edges_data = [] for edge in edges: created_at = getattr(edge, 'created_at', None) valid_at = getattr(edge, 'valid_at', None) invalid_at = getattr(edge, 'invalid_at', None) expired_at = getattr(edge, 'expired_at', None) # Normalize the episode list (the field may be missing or a single id). episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) if episodes and not isinstance(episodes, list): episodes = [str(episodes)] elif episodes: episodes = [str(e) for e in episodes] fact_type = getattr(edge, 'fact_type', None) or edge.name or "" edges_data.append({ "uuid": edge.uuid_, "name": edge.name or "", "fact": edge.fact or "", "fact_type": fact_type, "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, "source_node_name": node_map.get(edge.source_node_uuid, ""), "target_node_name": node_map.get(edge.target_node_uuid, ""), "attributes": edge.attributes or {}, "created_at": str(created_at) if created_at else None, "valid_at": str(valid_at) if valid_at else None, "invalid_at": str(invalid_at) if invalid_at else None, "expired_at": str(expired_at) if expired_at else None, "episodes": episodes or [], }) return { "graph_id": graph_id, "nodes": nodes_data, "edges": edges_data, "node_count": len(nodes_data), "edge_count": len(edges_data), } def delete_graph(self, graph_id: str): """Delete a graph by id.""" self.client.graph.delete(graph_id=graph_id)