MicroFish/backend/app/services/graph_builder.py

433 lines
15 KiB
Python

"""
Graph building service
Endpoint 2: Build a standalone graph through the configured graph backend (Zep or an alternative backend)
"""
import os
import uuid
import time
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from ..config import Config
from ..graph import get_graph_backend
from ..models.task import TaskManager, TaskStatus
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale
@dataclass
class GraphInfo:
    """Summary statistics describing one built graph."""

    # Identifier of the graph these statistics belong to.
    graph_id: str
    # Number of nodes in the graph.
    node_count: int
    # Number of edges in the graph.
    edge_count: int
    # Entity type labels found on the graph's nodes.
    entity_types: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the four fields as a plain dict (values are not copied)."""
        keys = ("graph_id", "node_count", "edge_count", "entity_types")
        return {key: getattr(self, key) for key in keys}
class GraphBuilderService:
    """
    Graph building service.

    Builds knowledge graphs through the backend returned by
    ``get_graph_backend()``. Ontology setup has a Zep-specific path (dynamic
    Zep SDK model classes) used when ``Config.GRAPH_BACKEND == "zep"`` and a
    plain-dict path for every other backend.
    """

    # Zep reserved names that cannot be used as attribute names; clashing
    # ontology attributes are renamed with an ``entity_`` prefix.
    _ZEP_RESERVED_NAMES = frozenset(
        {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
    )
    # Generic node labels that are not reported as entity types.
    _GENERIC_LABELS = frozenset({"Entity", "Node"})

    def __init__(self):
        self._graph = get_graph_backend()
        self.task_manager = TaskManager()

    def build_graph_async(
        self,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str = "MiroFish Graph",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 3
    ) -> str:
        """
        Build the graph asynchronously.

        Creates a ``graph_build`` task and runs the build in a daemon thread.

        Args:
            text: input text
            ontology: ontology definition (output from endpoint 1)
            graph_name: graph name
            chunk_size: text chunk size
            chunk_overlap: chunk overlap size
            batch_size: number of chunks per batch

        Returns:
            task ID (poll the task manager for progress and the result)
        """
        task_id = self.task_manager.create_task(
            task_type="graph_build",
            metadata={
                "graph_name": graph_name,
                "chunk_size": chunk_size,
                "text_length": len(text),
            }
        )
        # Capture locale before spawning the background thread so the worker
        # can produce progress messages in the requester's language.
        current_locale = get_locale()
        thread = threading.Thread(
            target=self._build_graph_worker,
            args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale),
            name=f"graph-build-{task_id}",
            daemon=True,
        )
        thread.start()
        return task_id

    @staticmethod
    def _scale_progress(start: int, span: int, fraction: float) -> int:
        """Map a sub-step completion fraction onto an overall progress window.

        The batch/episode helpers report progress as a fraction in 0.0-1.0, so
        it has to be multiplied by the window width ``span`` (not by another
        fraction) to move the overall percentage from ``start`` to
        ``start + span``. Out-of-range fractions are clamped.
        """
        fraction = min(max(float(fraction), 0.0), 1.0)
        return start + int(fraction * span)

    def _build_graph_worker(
        self,
        task_id: str,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str,
        chunk_size: int,
        chunk_overlap: int,
        batch_size: int,
        locale: str = 'zh'
    ):
        """Graph build worker thread.

        Runs every build stage for ``task_id`` and records progress on the task:
        create graph (10%), set ontology (15%), split text (20%), send batches
        (20-60%), wait for backend processing (60-90%), fetch graph info (90%),
        then completes the task. Any exception fails the task with the error
        message plus traceback.
        """
        # Re-apply the locale captured in build_graph_async before any t() call.
        set_locale(locale)
        try:
            self.task_manager.update_task(
                task_id,
                status=TaskStatus.PROCESSING,
                progress=5,
                message=t('progress.startBuildingGraph')
            )
            # 1. Create graph
            graph_id = self.create_graph(graph_name)
            self.task_manager.update_task(
                task_id,
                progress=10,
                message=t('progress.graphCreated', graphId=graph_id)
            )
            # 2. Set ontology
            self.set_ontology(graph_id, ontology)
            self.task_manager.update_task(
                task_id,
                progress=15,
                message=t('progress.ontologySet')
            )
            # 3. Split text into chunks
            chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
            total_chunks = len(chunks)
            self.task_manager.update_task(
                task_id,
                progress=20,
                message=t('progress.textSplit', count=total_chunks)
            )
            # 4. Send data in batches (callback fraction 0-1 -> 20-60%)
            episode_uuids = self.add_text_batches(
                graph_id, chunks, batch_size,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=self._scale_progress(20, 40, prog),
                    message=msg
                )
            )
            # 5. Wait for backend processing to complete (fraction 0-1 -> 60-90%)
            self.task_manager.update_task(
                task_id,
                progress=60,
                message=t('progress.waitingZepProcess')
            )
            self._wait_for_episodes(
                episode_uuids,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=self._scale_progress(60, 30, prog),
                    message=msg
                )
            )
            # 6. Fetch graph info
            self.task_manager.update_task(
                task_id,
                progress=90,
                message=t('progress.fetchingGraphInfo')
            )
            graph_info = self._get_graph_info(graph_id)
            self.task_manager.complete_task(task_id, {
                "graph_id": graph_id,
                "graph_info": graph_info.to_dict(),
                "chunks_processed": total_chunks,
            })
        except Exception as e:
            import traceback
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            self.task_manager.fail_task(task_id, error_msg)

    def create_graph(self, name: str) -> str:
        """Create a graph with a fresh ``mirofish_<16 hex>`` id and return that id."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graph.create_graph(
            graph_id=graph_id,
            name=name,
            description="MiroFish Social Simulation Graph"
        )
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Set graph ontology (public method).

        Non-Zep backends receive plain ``{name: {description, attributes}}``
        dicts for entity and edge types. The Zep backend receives dynamically
        created Zep SDK model classes (see ``_set_zep_ontology``).

        Args:
            graph_id: graph to configure
            ontology: dict with optional ``entity_types`` / ``edge_types``
                lists; each item needs a ``name`` key.
        """
        if Config.GRAPH_BACKEND != "zep":
            self._graph.set_ontology(
                graph_ids=[graph_id],
                entities=self._plain_type_defs(ontology.get("entity_types", [])),
                edges=self._plain_type_defs(ontology.get("edge_types", [])),
            )
            return
        self._set_zep_ontology(graph_id, ontology)

    @staticmethod
    def _plain_type_defs(type_defs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Convert ontology type definitions into ``{name: {description, attributes}}``."""
        return {
            type_def["name"]: {
                "description": type_def.get("description", ""),
                "attributes": type_def.get("attributes", []),
            }
            for type_def in type_defs
        }

    @classmethod
    def _safe_zep_attr_name(cls, attr_name: str) -> str:
        """Prefix attribute names that collide (case-insensitively) with Zep reserved names."""
        if attr_name.lower() in cls._ZEP_RESERVED_NAMES:
            return f"entity_{attr_name}"
        return attr_name

    def _set_zep_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None:
        """Build Zep SDK model classes from ``ontology`` and register them on the graph."""
        import warnings
        from pydantic import Field
        from zep_cloud import EntityEdgeSourceTarget
        from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel

        # Suppress Pydantic v2 warnings about Field(default=None).
        # This is the usage required by the Zep SDK; the warnings come from the
        # dynamic class creation below and can be safely ignored. Note this
        # installs a process-wide filter.
        warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')

        def build_model(class_name: str, base: type, description: str,
                        attributes: List[Dict[str, Any]], attr_type: Any) -> type:
            """Create a ``base`` subclass with one optional Field per ontology attribute."""
            attrs: Dict[str, Any] = {"__doc__": description}
            annotations: Dict[str, Any] = {}
            for attr_def in attributes:
                attr_name = self._safe_zep_attr_name(attr_def["name"])
                attr_desc = attr_def.get("description", attr_name)
                attrs[attr_name] = Field(description=attr_desc, default=None)
                annotations[attr_name] = Optional[attr_type]
            attrs["__annotations__"] = annotations
            model = type(class_name, (base,), attrs)
            model.__doc__ = description
            return model

        # Entity types: class name is the ontology name, attributes are EntityText.
        entity_types = {}
        for entity_def in ontology.get("entity_types", []):
            name = entity_def["name"]
            entity_types[name] = build_model(
                name,
                EntityModel,
                entity_def.get("description", f"A {name} entity."),
                entity_def.get("attributes", []),
                EntityText,
            )

        # Edge types: class name is CamelCase of the snake_case ontology name,
        # attributes are plain strings.
        edge_definitions = {}
        for edge_def in ontology.get("edge_types", []):
            name = edge_def["name"]
            edge_class = build_model(
                ''.join(word.capitalize() for word in name.split('_')),
                EdgeModel,
                edge_def.get("description", f"A {name} relationship."),
                edge_def.get("attributes", []),
                str,
            )
            source_targets = [
                EntityEdgeSourceTarget(
                    source=st.get("source", "Entity"),
                    target=st.get("target", "Entity"),
                )
                for st in edge_def.get("source_targets", [])
            ]
            # Edge types without any source/target pair are not registered.
            if source_targets:
                edge_definitions[name] = (edge_class, source_targets)

        if entity_types or edge_definitions:
            self._graph.set_ontology(
                graph_ids=[graph_id],
                entities=entity_types or None,
                edges=edge_definitions or None,
            )

    def add_text_batches(
        self,
        graph_id: str,
        chunks: List[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None,
        request_interval: float = 1.0,
    ) -> List[str]:
        """Add text chunks to the graph in batches.

        Args:
            graph_id: target graph
            chunks: text chunks; each becomes one ``{"data", "type": "text"}`` episode
            batch_size: chunks per backend request (must be >= 1)
            progress_callback: optional ``callback(message, fraction)`` where
                ``fraction`` is in 0.0-1.0
            request_interval: seconds to pause between requests (throttling);
                no pause follows the final batch

        Returns:
            All episode UUIDs returned by the backend, in submission order.

        Raises:
            ValueError: if ``batch_size`` < 1.
            Exception: whatever the backend raises for a failed batch (re-raised
                after reporting it through ``progress_callback``).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        episode_uuids: List[str] = []
        total_chunks = len(chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for batch_index, start in enumerate(range(0, total_chunks, batch_size)):
            batch_chunks = chunks[start:start + batch_size]
            batch_num = batch_index + 1
            if progress_callback:
                progress_callback(
                    t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
                    (start + len(batch_chunks)) / total_chunks
                )
            episodes = [{"data": chunk, "type": "text"} for chunk in batch_chunks]
            try:
                returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes)
            except Exception as e:
                if progress_callback:
                    # Report progress reached so far rather than resetting to 0.
                    progress_callback(
                        t('progress.batchFailed', batch=batch_num, error=str(e)),
                        start / total_chunks
                    )
                raise
            episode_uuids.extend(returned_uuids or [])
            # Avoid sending requests too quickly; nothing to throttle after the last batch.
            if request_interval > 0 and batch_num < total_batches:
                time.sleep(request_interval)
        return episode_uuids

    def _wait_for_episodes(
        self,
        episode_uuids: List[str],
        progress_callback: Optional[Callable] = None,
        timeout: int = 600,
        poll_interval: float = 3.0,
    ) -> bool:
        """Wait for episodes to finish processing by polling each one's ``processed`` flag.

        Args:
            episode_uuids: episodes to wait for (duplicates are counted once)
            progress_callback: optional ``callback(message, fraction)``, fraction in 0.0-1.0
            timeout: seconds to wait before giving up
            poll_interval: seconds between polling rounds

        Returns:
            True if every episode reported ``processed``; False on timeout
            (best-effort: the timeout is reported, not raised).
        """
        if not episode_uuids:
            if progress_callback:
                progress_callback(t('progress.noEpisodesWait'), 1.0)
            return True

        start_time = time.time()
        pending_episodes = set(episode_uuids)
        # Count unique episodes so completed/total can actually reach 100%.
        total_episodes = len(pending_episodes)
        if progress_callback:
            progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)

        while pending_episodes:
            if time.time() - start_time > timeout:
                if progress_callback:
                    completed = total_episodes - len(pending_episodes)
                    progress_callback(
                        t('progress.episodesTimeout', completed=completed, total=total_episodes),
                        completed / total_episodes
                    )
                # Stop waiting, but do not go on to report "processing complete".
                return False

            for ep_uuid in list(pending_episodes):
                try:
                    episode = self._graph.get_episode(ep_uuid)
                except Exception:
                    # Best-effort polling: a failed lookup leaves the episode
                    # pending so it is retried on the next round.
                    continue
                if getattr(episode, 'processed', False):
                    pending_episodes.discard(ep_uuid)

            if progress_callback:
                completed = total_episodes - len(pending_episodes)
                elapsed = int(time.time() - start_time)
                progress_callback(
                    t('progress.zepProcessing', completed=completed, total=total_episodes,
                      pending=len(pending_episodes), elapsed=elapsed),
                    completed / total_episodes
                )
            if pending_episodes:
                time.sleep(poll_interval)

        if progress_callback:
            progress_callback(
                t('progress.processingComplete', completed=total_episodes, total=total_episodes),
                1.0
            )
        return True

    def _get_graph_info(self, graph_id: str) -> "GraphInfo":
        """Return node/edge counts and the sorted, de-duplicated non-generic node labels."""
        nodes = self._graph.get_all_nodes(graph_id)
        edges = self._graph.get_all_edges(graph_id)
        entity_types = {
            label
            for node in nodes
            for label in (node.get("labels") or [])
            if label not in self._GENERIC_LABELS
        }
        return GraphInfo(
            graph_id=graph_id,
            node_count=len(nodes),
            edge_count=len(edges),
            # Sorted so repeated calls produce a stable order.
            entity_types=sorted(entity_types),
        )

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """Retrieve full graph data (nodes + edges with timestamps and attributes).

        Each edge dict is copied and extended with ``source_node_name`` and
        ``target_node_name`` ("" when the endpoint node is unknown).
        """
        nodes = self._graph.get_all_nodes(graph_id)
        edges = self._graph.get_all_edges(graph_id)
        # Nodes without a uuid cannot be referenced by edges; skip them.
        node_names = {n["uuid"]: n.get("name", "") for n in nodes if "uuid" in n}
        return {
            "graph_id": graph_id,
            "nodes": nodes,
            "edges": [
                {
                    **e,
                    "source_node_name": node_names.get(e.get("source_node_uuid", ""), ""),
                    "target_node_name": node_names.get(e.get("target_node_uuid", ""), ""),
                }
                for e in edges
            ],
            "node_count": len(nodes),
            "edge_count": len(edges),
        }

    def delete_graph(self, graph_id: str):
        """Delete graph ``graph_id`` via the backend."""
        self._graph.delete_graph(graph_id)