444 lines
15 KiB
Python
444 lines
15 KiB
Python
"""
|
|
Graph building service
|
|
Endpoint 2: Build a Standalone Graph using the Zep API
|
|
"""
|
|
|
|
import os
|
|
import uuid
|
|
import time
|
|
import threading
|
|
from typing import Dict, Any, List, Optional, Callable
|
|
from dataclasses import dataclass
|
|
|
|
from ..config import Config
|
|
from ..graph import get_graph_backend
|
|
from ..models.task import TaskManager, TaskStatus
|
|
from .text_processor import TextProcessor
|
|
from ..utils.locale import t, get_locale, set_locale
|
|
|
|
|
|
@dataclass
class GraphInfo:
    """Summary statistics for a built graph.

    Attributes:
        graph_id: identifier of the graph.
        node_count: number of nodes in the graph.
        edge_count: number of edges in the graph.
        entity_types: distinct entity labels found on the nodes.
    """

    graph_id: str
    node_count: int
    edge_count: int
    entity_types: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return this record as a plain dict (keys match the field names)."""
        return dict(
            graph_id=self.graph_id,
            node_count=self.node_count,
            edge_count=self.edge_count,
            entity_types=self.entity_types,
        )
|
|
|
|
|
|
class GraphBuilderService:
    """
    Graph building service.

    Builds a knowledge graph through the backend returned by
    ``get_graph_backend()``. It creates the graph, installs the ontology,
    uploads text chunks as episodes, waits for the backend to process them,
    and records progress on a ``TaskManager`` task. ``set_ontology`` has a
    Zep-specific path (``Config.GRAPH_BACKEND == "zep"``) and a plain-dict
    path for any other backend.
    """

    def __init__(self):
        self._graph = get_graph_backend()
        self.task_manager = TaskManager()

    def build_graph_async(
        self,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str = "MiroFish Graph",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 3
    ) -> str:
        """
        Build the graph asynchronously in a background daemon thread.

        Args:
            text: input text
            ontology: ontology definition (output from endpoint 1)
            graph_name: graph name
            chunk_size: text chunk size
            chunk_overlap: chunk overlap size
            batch_size: number of chunks per upload batch

        Returns:
            ID of the ``graph_build`` task that tracks progress and the result.
        """
        task_id = self.task_manager.create_task(
            task_type="graph_build",
            metadata={
                "graph_name": graph_name,
                "chunk_size": chunk_size,
                "text_length": len(text),
            }
        )

        # Capture the caller's locale now; the worker re-applies it with
        # set_locale() so its progress messages use the same language
        # (presumably the locale is per-thread — confirm in utils.locale).
        current_locale = get_locale()

        thread = threading.Thread(
            target=self._build_graph_worker,
            args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
        )
        thread.daemon = True
        thread.start()

        return task_id

    def _build_graph_worker(
        self,
        task_id: str,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str,
        chunk_size: int,
        chunk_overlap: int,
        batch_size: int,
        locale: str = 'zh'
    ):
        """
        Background worker: run every build step and record progress on the task.

        Progress milestones:
        - 5: start
        - 10: graph created
        - 15: ontology set
        - 20: text split
        - 20-60: uploading batches
        - 60-90: waiting for episode processing
        - 90: fetching graph info
        - then the task is completed

        Any exception marks the task as failed, with the error message
        followed by the traceback.
        """
        import traceback

        set_locale(locale)
        try:
            self.task_manager.update_task(
                task_id,
                status=TaskStatus.PROCESSING,
                progress=5,
                message=t('progress.startBuildingGraph')
            )

            # 1. Create graph
            graph_id = self.create_graph(graph_name)
            self.task_manager.update_task(
                task_id,
                progress=10,
                message=t('progress.graphCreated', graphId=graph_id)
            )

            # 2. Set ontology
            self.set_ontology(graph_id, ontology)
            self.task_manager.update_task(
                task_id,
                progress=15,
                message=t('progress.ontologySet')
            )

            # 3. Split text into chunks
            chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
            total_chunks = len(chunks)
            self.task_manager.update_task(
                task_id,
                progress=20,
                message=t('progress.textSplit', count=total_chunks)
            )

            # 4. Send data in batches. The callback receives a 0.0-1.0
            # fraction, which is mapped onto the 20-60% range.
            episode_uuids = self.add_text_batches(
                graph_id, chunks, batch_size,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=20 + int(prog * 40),
                    message=msg
                )
            )

            # 5. Wait for the backend to finish processing the episodes
            self.task_manager.update_task(
                task_id,
                progress=60,
                message=t('progress.waitingZepProcess')
            )

            # The callback fraction (0.0-1.0) is mapped onto 60-90%. A timeout
            # is not fatal: the build continues with whatever was processed.
            self._wait_for_episodes(
                episode_uuids,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=60 + int(prog * 30),
                    message=msg
                )
            )

            # 6. Fetch graph info
            self.task_manager.update_task(
                task_id,
                progress=90,
                message=t('progress.fetchingGraphInfo')
            )

            graph_info = self._get_graph_info(graph_id)

            # Complete
            self.task_manager.complete_task(task_id, {
                "graph_id": graph_id,
                "graph_info": graph_info.to_dict(),
                "chunks_processed": total_chunks,
            })

        except Exception as e:
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            self.task_manager.fail_task(task_id, error_msg)

    def create_graph(self, name: str) -> str:
        """
        Create a graph with a freshly generated ID (public method).

        Args:
            name: human-readable graph name.

        Returns:
            The new graph ID, ``"mirofish_"`` followed by 16 hex characters.
        """
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self._graph.create_graph(
            graph_id=graph_id,
            name=name,
            description="MiroFish Social Simulation Graph"
        )
        return graph_id

    @staticmethod
    def _plain_type_defs(type_defs: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Map ontology type definitions to ``{name: {description, attributes}}``."""
        return {
            d["name"]: {
                "description": d.get("description", ""),
                "attributes": d.get("attributes", []),
            }
            for d in type_defs
        }

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """
        Install the ontology on a graph (public method).

        For non-Zep backends (``Config.GRAPH_BACKEND != "zep"``), entity and
        edge types are passed as plain dicts keyed by type name.

        For Zep, each type becomes a dynamically created subclass of the SDK's
        ``EntityModel`` / ``EdgeModel``. Every attribute is an optional field
        defaulting to None. Edge types are only sent when they declare at
        least one source/target pair.

        Args:
            graph_id: graph to update.
            ontology: dict with optional ``entity_types`` and ``edge_types``
                lists. Each item needs a ``name`` and may carry
                ``description`` and ``attributes``; edge items may also carry
                ``source_targets``.
        """
        if Config.GRAPH_BACKEND != "zep":
            self._graph.set_ontology(
                graph_ids=[graph_id],
                entities=self._plain_type_defs(ontology.get("entity_types", [])),
                edges=self._plain_type_defs(ontology.get("edge_types", [])),
            )
            return

        import warnings

        from pydantic import Field
        from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel

        # Suppress Pydantic v2 warnings about Field(default=None).
        # This usage is required by the Zep SDK; the warnings come from dynamic
        # class creation and can be safely ignored.
        warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')

        # Zep reserved names that cannot be used as attribute names
        reserved_names = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}

        def safe_attr_name(attr_name: str) -> str:
            # Prefix reserved names (case-insensitive) to avoid clashing with Zep fields.
            if attr_name.lower() in reserved_names:
                return f"entity_{attr_name}"
            return attr_name

        # Dynamically create entity types
        entity_types = {}
        for entity_def in ontology.get("entity_types", []):
            name = entity_def["name"]
            description = entity_def.get("description", f"A {name} entity.")

            attrs = {"__doc__": description}
            annotations = {}

            for attr_def in GraphBuilderService._normalize_entity_attributes(entity_def.get("attributes", [])):
                attr_name = safe_attr_name(attr_def["name"])
                attr_desc = attr_def.get("description", attr_name)
                attrs[attr_name] = Field(description=attr_desc, default=None)
                annotations[attr_name] = Optional[EntityText]

            attrs["__annotations__"] = annotations
            entity_class = type(name, (EntityModel,), attrs)
            entity_class.__doc__ = description
            entity_types[name] = entity_class

        from zep_cloud import EntityEdgeSourceTarget

        # Dynamically create edge types
        edge_definitions = {}
        for edge_def in ontology.get("edge_types", []):
            name = edge_def["name"]
            description = edge_def.get("description", f"A {name} relationship.")

            attrs = {"__doc__": description}
            annotations = {}

            for attr_def in GraphBuilderService._normalize_entity_attributes(edge_def.get("attributes", [])):
                attr_name = safe_attr_name(attr_def["name"])
                attr_desc = attr_def.get("description", attr_name)
                attrs[attr_name] = Field(description=attr_desc, default=None)
                annotations[attr_name] = Optional[str]

            attrs["__annotations__"] = annotations
            # e.g. "works_for" -> "WorksFor"
            class_name = ''.join(word.capitalize() for word in name.split('_'))
            edge_class = type(class_name, (EdgeModel,), attrs)
            edge_class.__doc__ = description

            source_targets = [
                EntityEdgeSourceTarget(
                    source=st.get("source", "Entity"),
                    target=st.get("target", "Entity")
                )
                for st in edge_def.get("source_targets", [])
            ]

            # Edge types without any source/target pair are not sent.
            if source_targets:
                edge_definitions[name] = (edge_class, source_targets)

        if entity_types or edge_definitions:
            self._graph.set_ontology(
                graph_ids=[graph_id],
                entities=entity_types if entity_types else None,
                edges=edge_definitions if edge_definitions else None,
            )

    def add_text_batches(
        self,
        graph_id: str,
        chunks: List[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None
    ) -> List[str]:
        """
        Upload text chunks to the graph as episodes, ``batch_size`` at a time.

        Args:
            graph_id: target graph.
            chunks: text chunks; each becomes one
                ``{"data": chunk, "type": "text"}`` episode.
            batch_size: chunks per ``add_batch`` call; must be >= 1.
            progress_callback: optional ``callback(message, fraction)``.
                ``fraction`` is in 0.0-1.0. It is called before each batch is
                sent, and with fraction 0 when a batch fails.

        Returns:
            Episode UUIDs returned by the backend, in upload order.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
            Exception: any error from the backend for a failed batch is re-raised.
        """
        if batch_size < 1:
            # range() would reject 0, but a negative step silently sends nothing.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        episode_uuids: List[str] = []
        total_chunks = len(chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for batch_index, start in enumerate(range(0, total_chunks, batch_size)):
            batch_chunks = chunks[start:start + batch_size]
            batch_num = batch_index + 1

            if progress_callback:
                progress_callback(
                    t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
                    (start + len(batch_chunks)) / total_chunks
                )

            episodes = [{"data": chunk, "type": "text"} for chunk in batch_chunks]

            try:
                returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes)
                episode_uuids.extend(returned_uuids)
            except Exception as e:
                if progress_callback:
                    progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
                raise

            # Avoid sending requests too quickly; no need to wait after the last batch.
            if batch_num < total_batches:
                time.sleep(1)

        return episode_uuids

    def _wait_for_episodes(
        self,
        episode_uuids: List[str],
        progress_callback: Optional[Callable] = None,
        timeout: int = 600
    ) -> bool:
        """
        Poll each episode's ``processed`` flag until all are done or time runs out.

        Episodes are polled every 3 seconds. A failed lookup for one episode
        is logged at debug level and retried on the next poll.

        Args:
            episode_uuids: episodes to wait for.
            progress_callback: optional ``callback(message, fraction)`` with
                ``fraction`` in 0.0-1.0.
            timeout: maximum number of seconds to wait.

        Returns:
            True if every episode was processed (or there were none). False if
            ``timeout`` elapsed first; in that case the last progress message
            is the timeout notice.
        """
        import logging

        if not episode_uuids:
            if progress_callback:
                progress_callback(t('progress.noEpisodesWait'), 1.0)
            return True

        logger = logging.getLogger(__name__)
        start_time = time.time()
        pending_episodes = set(episode_uuids)
        # Count unique episodes so the completion fraction can reach 1.0.
        total_episodes = len(pending_episodes)
        completed_count = 0

        if progress_callback:
            progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)

        while pending_episodes:
            if time.time() - start_time > timeout:
                if progress_callback:
                    progress_callback(
                        t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
                        completed_count / total_episodes
                    )
                # Return here so the timeout message is not overwritten by the
                # "processing complete" report below.
                return False

            # Check processing status of each episode
            for ep_uuid in list(pending_episodes):
                try:
                    episode = self._graph.get_episode(ep_uuid)
                except Exception as e:
                    # Best-effort: the episode stays pending and is retried next poll.
                    logger.debug("Polling episode %s failed: %s", ep_uuid, e)
                    continue

                if getattr(episode, 'processed', False):
                    pending_episodes.discard(ep_uuid)
                    completed_count += 1

            elapsed = int(time.time() - start_time)
            if progress_callback:
                progress_callback(
                    t('progress.zepProcessing', completed=completed_count, total=total_episodes,
                      pending=len(pending_episodes), elapsed=elapsed),
                    completed_count / total_episodes
                )

            if pending_episodes:
                time.sleep(3)  # Check every 3 seconds

        if progress_callback:
            progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
        return True

    def _get_graph_info(self, graph_id: str) -> "GraphInfo":
        """
        Summarise a graph: node/edge counts and its distinct entity labels.

        The generic labels "Entity" and "Node" are excluded. The remaining
        labels are returned sorted, so the output is deterministic.
        """
        nodes = self._graph.get_all_nodes(graph_id)
        edges = self._graph.get_all_edges(graph_id)

        entity_types = {
            label
            for node in nodes
            for label in (node.get("labels") or [])
            if label not in ("Entity", "Node")
        }

        return GraphInfo(
            graph_id=graph_id,
            node_count=len(nodes),
            edge_count=len(edges),
            entity_types=sorted(entity_types, key=str)
        )

    def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
        """
        Retrieve full graph data (nodes + edges, as returned by the backend).

        Each edge dict is copied and extended with ``source_node_name`` and
        ``target_node_name``, resolved from the node list ("" if unknown).
        """
        nodes = self._graph.get_all_nodes(graph_id)
        edges = self._graph.get_all_edges(graph_id)

        node_map = {n["uuid"]: n.get("name", "") for n in nodes}

        return {
            "graph_id": graph_id,
            "nodes": nodes,
            "edges": [
                {**e, "source_node_name": node_map.get(e.get("source_node_uuid", ""), ""),
                 "target_node_name": node_map.get(e.get("target_node_uuid", ""), "")}
                for e in edges
            ],
            "node_count": len(nodes),
            "edge_count": len(edges),
        }

    def delete_graph(self, graph_id: str):
        """Delete a graph via the backend."""
        self._graph.delete_graph(graph_id)

    @staticmethod
    def _normalize_entity_attributes(attributes: list) -> list:
        """
        Normalise attribute definitions to dicts.

        Strings become ``{"name": s, "type": "text", "description": s}``,
        dicts are kept as-is, and any other item is dropped.
        """
        result = []
        for attr in attributes:
            if isinstance(attr, str):
                result.append({"name": attr, "type": "text", "description": attr})
            elif isinstance(attr, dict):
                result.append(attr)
        return result
|
|
|