MicroFish/backend/app/services/graph_builder.py

468 lines
16 KiB
Python

"""Graph build service.
Pipeline step 2: build the project's standalone knowledge graph through the
Zep/Graphiti API.
"""
import os
import uuid
import time
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from .graphiti_adapter import GraphitiAdapter
from ..config import Config
from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale
def _classify_entity_type(name: str, summary: str, ontology: Optional[Dict]) -> str:
"""
Classify an entity into an ontology type using keyword matching
against entity type names, descriptions, and examples.
Falls back to 'Entity' if no ontology or no match found.
"""
if not ontology:
return "Entity"
entity_types = ontology.get("entity_types", [])
if not entity_types:
return "Entity"
name_lower = (name or "").lower()
summary_lower = (summary or "").lower()
search_text = f"{name_lower} {summary_lower}"
best_type = "Entity"
best_score = 0
for et in entity_types:
score = 0
type_name = et.get("name", "")
type_name_lower = type_name.lower()
# Exact name match in type name
if type_name_lower in name_lower:
score += 10
# Check examples list
for example in et.get("examples", []):
if example.lower() in search_text:
score += 8
elif name_lower in example.lower():
score += 6
# Check description keywords
desc_words = (et.get("description", "")).lower().split()
for word in desc_words:
if len(word) > 4 and word in search_text:
score += 1
if score > best_score:
best_score = score
best_type = type_name
return best_type if best_score > 0 else "Entity"
@dataclass
class GraphInfo:
    """Aggregate statistics describing a graph: its id, sizes and entity types."""

    graph_id: str
    node_count: int
    edge_count: int
    entity_types: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the four fields as a plain dict keyed by field name."""
        exported = ("graph_id", "node_count", "edge_count", "entity_types")
        return {field_name: getattr(self, field_name) for field_name in exported}
class GraphBuilderService:
    """Drives knowledge-graph construction via the Zep/Graphiti API.

    A full build runs on a background thread and reports its progress
    through ``TaskManager``. The individual steps (create graph, set
    ontology, add text, read back graph data, delete) are also available
    as separate methods.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create the service with a graph client and a task manager.

        Args:
            api_key: Accepted for interface compatibility; this class does
                not use it (the adapter is constructed without arguments).
        """
        self.client = GraphitiAdapter()
        self.task_manager = TaskManager()

    def build_graph_async(
        self,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str = "MiroFish Graph",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 3
    ) -> str:
        """Kick off a graph build asynchronously.

        Args:
            text: Source text to ingest.
            ontology: Ontology definition (the output of pipeline step 1).
            graph_name: Display name for the graph.
            chunk_size: Characters per text chunk.
            chunk_overlap: Overlap (in characters) between consecutive chunks.
            batch_size: Number of chunks pushed to Zep per batch.

        Returns:
            The id of the task tracking the build.
        """
        # Register a task to track build progress.
        task_id = self.task_manager.create_task(
            task_type="graph_build",
            metadata={
                "graph_name": graph_name,
                "chunk_size": chunk_size,
                "text_length": len(text),
            }
        )

        # Capture the locale here and hand it to the worker, which applies
        # it with set_locale() before producing any progress message.
        current_locale = get_locale()

        # Run the build on a daemon thread so the caller returns immediately.
        thread = threading.Thread(
            target=self._build_graph_worker,
            args=(
                task_id, text, ontology, graph_name,
                chunk_size, chunk_overlap, batch_size, current_locale,
            ),
            daemon=True,
        )
        thread.start()
        return task_id

    def _build_graph_worker(
        self,
        task_id: str,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str,
        chunk_size: int,
        chunk_overlap: int,
        batch_size: int,
        locale: str = 'zh'
    ):
        """Background worker that performs the graph build.

        Runs the steps in order (create graph, set ontology, split text,
        send batches, wait for processing, fetch graph info) and updates
        the task after each one. Any exception marks the task as failed
        with the error message and traceback; nothing is re-raised.

        Args:
            task_id: Task to update with progress and the final result.
            text: Source text to ingest.
            ontology: Ontology definition passed to ``set_ontology``.
            graph_name: Display name for the new graph.
            chunk_size: Characters per text chunk.
            chunk_overlap: Overlap between consecutive chunks.
            batch_size: Number of chunks per batch.
            locale: Locale applied via ``set_locale`` before any message
                is translated.
        """
        set_locale(locale)
        try:
            self.task_manager.update_task(
                task_id,
                status=TaskStatus.PROCESSING,
                progress=5,
                message=t('progress.startBuildingGraph')
            )

            # 1. Create the graph.
            graph_id = self.create_graph(graph_name)
            self.task_manager.update_task(
                task_id,
                progress=10,
                message=t('progress.graphCreated', graphId=graph_id)
            )

            # 2. Set the ontology.
            self.set_ontology(graph_id, ontology)
            self.task_manager.update_task(
                task_id,
                progress=15,
                message=t('progress.ontologySet')
            )

            # 3. Split source text into chunks.
            chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
            total_chunks = len(chunks)
            self.task_manager.update_task(
                task_id,
                progress=20,
                message=t('progress.textSplit', count=total_chunks)
            )

            # 4. Push chunks to the graph in batches.
            episode_uuids = self.add_text_batches(
                graph_id, chunks, batch_size,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=20 + int(prog * 0.4),  # 20-60%
                    message=msg
                )
            )

            # 5. Wait for the episodes to be reported as processed.
            self.task_manager.update_task(
                task_id,
                progress=60,
                message=t('progress.waitingZepProcess')
            )
            self._wait_for_episodes(
                episode_uuids,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=60 + int(prog * 0.3),  # 60-90%
                    message=msg
                )
            )

            # 6. Fetch the final graph metadata.
            self.task_manager.update_task(
                task_id,
                progress=90,
                message=t('progress.fetchingGraphInfo')
            )
            graph_info = self._get_graph_info(graph_id)

            self.task_manager.complete_task(task_id, {
                "graph_id": graph_id,
                "graph_info": graph_info.to_dict(),
                "chunks_processed": total_chunks,
            })
        except Exception as e:
            # Top-level boundary of the background thread: record the
            # failure on the task instead of letting the thread die silently.
            import traceback
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            self.task_manager.fail_task(task_id, error_msg)

    def create_graph(self, name: str) -> str:
        """Create a new graph and return its generated id (public API).

        The id has the form ``mirofish_`` followed by 16 hex characters.
        """
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="MiroFish Social Simulation Graph"
        )
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """Register the ontology with the graph (Graphiti uses it as an extraction prompt).

        Args:
            graph_id: Graph to configure.
            ontology: Dict whose ``entity_types`` and ``edge_types`` entries
                are forwarded to the client (``None`` when a key is absent).
        """
        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=ontology.get("entity_types"),
            edges=ontology.get("edge_types"),
        )

    def add_text_batches(
        self,
        graph_id: str,
        chunks: List[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None,
        skip_chunks: int = 0,
    ) -> List[str]:
        """Push chunks to the graph in batches and collect the episode uuids.

        Args:
            graph_id: Target graph.
            chunks: Text chunks to send, in order.
            batch_size: Number of chunks per batch; must be at least 1.
            progress_callback: Optional ``callback(message, fraction)``
                invoked before each batch is sent and when a batch fails.
            skip_chunks: Number of leading chunks to skip (used for
                resume-after-restart). Negative values are treated as 0.

        Returns:
            The uuids of all episodes returned by the client.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Whatever the client raises for a failed batch; the
                failure is reported through ``progress_callback`` first.
        """
        if batch_size < 1:
            # A negative step would make range() yield nothing, silently
            # "succeeding" without sending any chunk.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # A negative start would make the slices below wrap around.
        skip_chunks = max(0, skip_chunks)

        episode_uuids = []
        total_chunks = len(chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for i in range(skip_chunks, total_chunks, batch_size):
            batch_chunks = chunks[i:i + batch_size]
            batch_num = i // batch_size + 1

            if progress_callback:
                progress = (i + len(batch_chunks)) / total_chunks
                progress_callback(
                    t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
                    progress
                )

            # Build the per-episode payload objects expected by the client:
            # simple objects exposing ``data`` and ``type`` attributes.
            episodes = [
                type('Episode', (), {'data': chunk, 'type': 'text'})()
                for chunk in batch_chunks
            ]

            try:
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes
                )
            except Exception as e:
                if progress_callback:
                    progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
                raise

            # Collect the uuids returned for each episode; either attribute
            # spelling is accepted.
            if batch_result and isinstance(batch_result, list):
                for ep in batch_result:
                    ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
                    if ep_uuid:
                        episode_uuids.append(ep_uuid)

            # Throttle between batches to avoid overwhelming the upstream
            # API; no pause is needed once the last batch has been sent.
            if i + batch_size < total_chunks:
                time.sleep(1)

        return episode_uuids

    def _wait_for_episodes(
        self,
        episode_uuids: List[str],
        progress_callback: Optional[Callable] = None,
        timeout: int = 600
    ):
        """Poll each episode until it is reported processed, or the timeout expires.

        Every pending episode is queried once per pass, with a 3 second
        pause between passes. A failed query is ignored and retried on the
        next pass. On timeout the loop stops without raising; the final
        callback is still issued and its counts show how many finished.

        Args:
            episode_uuids: Episodes to wait for; duplicates are counted once.
            progress_callback: Optional ``callback(message, fraction)``.
            timeout: Maximum number of seconds to keep polling.
        """
        if not episode_uuids:
            if progress_callback:
                progress_callback(t('progress.noEpisodesWait'), 1.0)
            return

        pending_episodes = set(episode_uuids)
        # Count unique uuids so completed/total can actually reach the total.
        total_episodes = len(pending_episodes)
        completed_count = 0
        # Monotonic clock: elapsed time is unaffected by system clock changes.
        start_time = time.monotonic()

        if progress_callback:
            progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)

        while pending_episodes:
            if time.monotonic() - start_time > timeout:
                if progress_callback:
                    progress_callback(
                        t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
                        completed_count / total_episodes
                    )
                break

            # Check the processing state of each pending episode. Iterate
            # over a snapshot because the set is modified inside the loop.
            for ep_uuid in list(pending_episodes):
                try:
                    episode = self.client.graph.episode.get(uuid_=ep_uuid)
                except Exception:
                    # Deliberate best effort: tolerate a failed query and
                    # retry this episode on the next pass.
                    continue
                if getattr(episode, 'processed', False):
                    pending_episodes.discard(ep_uuid)
                    completed_count += 1

            elapsed = int(time.monotonic() - start_time)
            if progress_callback:
                progress_callback(
                    t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
                    completed_count / total_episodes
                )

            if pending_episodes:
                time.sleep(3)  # poll every 3 seconds

        if progress_callback:
            progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)

    def _get_graph_info(self, graph_id: str) -> "GraphInfo":
        """Fetch summary info (counts and entity types) for a graph.

        Returns:
            A ``GraphInfo`` with the node and edge counts and the sorted
            list of distinct node labels other than ``"Entity"``/``"Node"``.
        """
        nodes = fetch_all_nodes(self.client, graph_id)
        edges = fetch_all_edges(self.client, graph_id)

        # Tally distinct entity types across all nodes, ignoring the
        # generic labels.
        generic_labels = ("Entity", "Node")
        entity_types = {
            label
            for node in nodes
            for label in (node.labels or [])
            if label not in generic_labels
        }

        return GraphInfo(
            graph_id=graph_id,
            node_count=len(nodes),
            edge_count=len(edges),
            # Sorted so the result is deterministic across calls.
            entity_types=sorted(entity_types)
        )

    def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]:
        """Return the full graph payload including timestamps, attributes, and edges.

        Args:
            graph_id: Graph identifier.
            ontology: Optional ontology passed to ``_classify_entity_type``
                to infer a type for each node from its name and summary.
                An inferred type other than ``"Entity"`` that is not yet
                among the node's labels is put first and replaces the
                generic ``"Entity"`` label.

        Returns:
            Dict with ``graph_id``, ``nodes``, ``edges``, ``node_count``
            and ``edge_count``.
        """
        nodes = fetch_all_nodes(self.client, graph_id)
        edges = fetch_all_edges(self.client, graph_id)

        # Build a uuid->name map so edge endpoints can be labeled.
        node_map = {node.uuid_: node.name or "" for node in nodes}

        nodes_data = []
        for node in nodes:
            created_at = getattr(node, 'created_at', None)

            entity_type = _classify_entity_type(node.name, node.summary or "", ontology)
            labels = node.labels or []
            if entity_type != "Entity" and entity_type not in labels:
                labels = [entity_type] + [label for label in labels if label != "Entity"]

            nodes_data.append({
                "uuid": node.uuid_,
                "name": node.name,
                "labels": labels,
                "summary": node.summary or "",
                "attributes": node.attributes or {},
                "created_at": str(created_at) if created_at else None,
            })

        edges_data = []
        for edge in edges:
            created_at = getattr(edge, 'created_at', None)
            valid_at = getattr(edge, 'valid_at', None)
            invalid_at = getattr(edge, 'invalid_at', None)
            expired_at = getattr(edge, 'expired_at', None)

            # Normalize the episode list: the field may be missing, a
            # single id, or a collection of ids.
            episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
            if not episodes:
                episodes = []
            elif isinstance(episodes, (list, tuple, set, frozenset)):
                episodes = [str(ep) for ep in episodes]
            else:
                episodes = [str(episodes)]

            fact_type = getattr(edge, 'fact_type', None) or edge.name or ""

            edges_data.append({
                "uuid": edge.uuid_,
                "name": edge.name or "",
                "fact": edge.fact or "",
                "fact_type": fact_type,
                "source_node_uuid": edge.source_node_uuid,
                "target_node_uuid": edge.target_node_uuid,
                "source_node_name": node_map.get(edge.source_node_uuid, ""),
                "target_node_name": node_map.get(edge.target_node_uuid, ""),
                "attributes": edge.attributes or {},
                "created_at": str(created_at) if created_at else None,
                "valid_at": str(valid_at) if valid_at else None,
                "invalid_at": str(invalid_at) if invalid_at else None,
                "expired_at": str(expired_at) if expired_at else None,
                "episodes": episodes,
            })

        return {
            "graph_id": graph_id,
            "nodes": nodes_data,
            "edges": edges_data,
            "node_count": len(nodes_data),
            "edge_count": len(edges_data),
        }

    def delete_graph(self, graph_id: str):
        """Delete a graph by id."""
        self.client.graph.delete(graph_id=graph_id)