MicroFish/backend/app/services/graph_builder.py

477 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图谱构建服务
接口2使用Zep API构建Standalone Graph
"""
import os
import uuid
import time
import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from .graphiti_adapter import GraphitiAdapter
from ..config import Config
from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale
def _classify_entity_type(name: str, summary: str, ontology: Optional[Dict]) -> str:
"""
Classify an entity into an ontology type using keyword matching
against entity type names, descriptions, and examples.
Falls back to 'Entity' if no ontology or no match found.
"""
if not ontology:
return "Entity"
entity_types = ontology.get("entity_types", [])
if not entity_types:
return "Entity"
name_lower = (name or "").lower()
summary_lower = (summary or "").lower()
search_text = f"{name_lower} {summary_lower}"
best_type = "Entity"
best_score = 0
for et in entity_types:
score = 0
type_name = et.get("name", "")
type_name_lower = type_name.lower()
# Exact name match in type name
if type_name_lower in name_lower:
score += 10
# Check examples list
for example in et.get("examples", []):
if example.lower() in search_text:
score += 8
elif name_lower in example.lower():
score += 6
# Check description keywords
desc_words = (et.get("description", "")).lower().split()
for word in desc_words:
if len(word) > 4 and word in search_text:
score += 1
if score > best_score:
best_score = score
best_type = type_name
return best_type if best_score > 0 else "Entity"
@dataclass
class GraphInfo:
    """Summary statistics describing one graph.

    ``GraphBuilderService._get_graph_info`` fills it from the fetched nodes
    and edges; ``to_dict`` exposes the same four fields as a plain dict.
    """

    graph_id: str  # identifier of the graph
    node_count: int  # number of nodes in the graph
    edge_count: int  # number of edges in the graph
    entity_types: List[str]  # distinct node labels (generic labels excluded by the builder)

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a dict; ``entity_types`` is the same list object, not a copy."""
        return dict(
            graph_id=self.graph_id,
            node_count=self.node_count,
            edge_count=self.edge_count,
            entity_types=self.entity_types,
        )
class GraphBuilderService:
    """
    Graph building service.

    Builds a knowledge graph through ``GraphitiAdapter``:
      1. create the graph,
      2. register the ontology,
      3. split the input text into chunks,
      4. upload the chunks as episodes in batches,
      5. wait for the episodes to be processed,
      6. report node/edge statistics.

    ``build_graph_async`` runs the pipeline on a daemon thread and tracks
    progress through ``TaskManager``.
    """

    def __init__(self, api_key: Optional[str] = None):
        # NOTE(review): api_key is accepted for backward compatibility but is
        # not used; GraphitiAdapter() is constructed without arguments.
        self.client = GraphitiAdapter()
        self.task_manager = TaskManager()

    def build_graph_async(
        self,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str = "MiroFish Graph",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        batch_size: int = 3
    ) -> str:
        """
        Build a graph asynchronously.

        Args:
            text: Input text.
            ontology: Ontology definition (output of interface 1).
            graph_name: Graph name.
            chunk_size: Text chunk size.
            chunk_overlap: Overlap between consecutive chunks.
            batch_size: Number of chunks sent per batch.

        Returns:
            The ID of the created task; progress is reported on that task.
        """
        task_id = self.task_manager.create_task(
            task_type="graph_build",
            metadata={
                "graph_name": graph_name,
                "chunk_size": chunk_size,
                "text_length": len(text),
            }
        )

        # Capture the caller's locale so the worker thread emits messages in it.
        current_locale = get_locale()

        thread = threading.Thread(
            target=self._build_graph_worker,
            args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
        )
        thread.daemon = True
        thread.start()

        return task_id

    def _build_graph_worker(
        self,
        task_id: str,
        text: str,
        ontology: Dict[str, Any],
        graph_name: str,
        chunk_size: int,
        chunk_overlap: int,
        batch_size: int,
        locale: str = 'zh'
    ):
        """
        Background worker that runs the full build pipeline for one task.

        Progress milestones:
          - 5%: start
          - 10%: graph created
          - 15%: ontology set
          - 20%: text split
          - 20-60%: uploading batches
          - 60-90%: waiting for episode processing
          - 90%: fetching graph info

        Any exception marks the task failed, with the message plus traceback.
        """
        set_locale(locale)
        try:
            self.task_manager.update_task(
                task_id,
                status=TaskStatus.PROCESSING,
                progress=5,
                message=t('progress.startBuildingGraph')
            )

            # 1. Create the graph
            graph_id = self.create_graph(graph_name)
            self.task_manager.update_task(
                task_id,
                progress=10,
                message=t('progress.graphCreated', graphId=graph_id)
            )

            # 2. Register the ontology
            self.set_ontology(graph_id, ontology)
            self.task_manager.update_task(
                task_id,
                progress=15,
                message=t('progress.ontologySet')
            )

            # 3. Split the text into chunks
            chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
            total_chunks = len(chunks)
            self.task_manager.update_task(
                task_id,
                progress=20,
                message=t('progress.textSplit', count=total_chunks)
            )

            # 4. Upload chunks in batches (mapped to 20-60%)
            episode_uuids = self.add_text_batches(
                graph_id, chunks, batch_size,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=20 + int(prog * 0.4),
                    message=msg
                )
            )

            # 5. Wait for the backend to process the episodes (mapped to 60-90%)
            self.task_manager.update_task(
                task_id,
                progress=60,
                message=t('progress.waitingZepProcess')
            )
            self._wait_for_episodes(
                episode_uuids,
                lambda msg, prog: self.task_manager.update_task(
                    task_id,
                    progress=60 + int(prog * 0.3),
                    message=msg
                )
            )

            # 6. Fetch graph statistics
            self.task_manager.update_task(
                task_id,
                progress=90,
                message=t('progress.fetchingGraphInfo')
            )
            graph_info = self._get_graph_info(graph_id)

            self.task_manager.complete_task(task_id, {
                "graph_id": graph_id,
                "graph_info": graph_info.to_dict(),
                "chunks_processed": total_chunks,
            })

        except Exception as e:
            import traceback
            error_msg = f"{str(e)}\n{traceback.format_exc()}"
            self.task_manager.fail_task(task_id, error_msg)

    def create_graph(self, name: str) -> str:
        """Create a new graph with a random ``mirofish_`` ID and return that ID."""
        graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description="MiroFish Social Simulation Graph"
        )
        return graph_id

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
        """
        Register the ontology's entity and edge types on the graph.

        Per the original author's note, Graphiti extracts entities on its own
        and the ontology is stored as a hint. This is adapter behaviour, not
        verified here.
        """
        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=ontology.get("entity_types"),
            edges=ontology.get("edge_types"),
        )

    def add_text_batches(
        self,
        graph_id: str,
        chunks: List[str],
        batch_size: int = 3,
        progress_callback: Optional[Callable] = None,
        skip_chunks: int = 0,
    ) -> List[str]:
        """
        Upload text chunks to the graph in batches.

        Args:
            graph_id: Target graph ID.
            chunks: Text chunks to upload, in order.
            batch_size: Number of chunks per ``add_batch`` call; must be > 0.
            progress_callback: Optional ``callback(message, progress)``. The
                ``progress`` argument is the fraction of chunks sent so far.
            skip_chunks: Number of leading chunks to skip, i.e. chunks already
                processed (used to resume an interrupted build).

        Returns:
            UUIDs of the created episodes, in upload order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
            Exception: Errors raised by ``add_batch`` are reported through
                ``progress_callback`` and then re-raised.
        """
        if batch_size <= 0:
            # range() raises for 0 but would silently upload nothing for < 0.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        episode_uuids: List[str] = []
        total_chunks = len(chunks)
        total_batches = (total_chunks + batch_size - 1) // batch_size

        for i in range(skip_chunks, total_chunks, batch_size):
            batch_chunks = chunks[i:i + batch_size]
            batch_num = i // batch_size + 1

            if progress_callback:
                progress = (i + len(batch_chunks)) / total_chunks
                progress_callback(
                    t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
                    progress
                )

            # Lightweight episode objects exposing ``.data`` and ``.type``.
            episodes = [
                type('Episode', (), {'data': chunk, 'type': 'text'})()
                for chunk in batch_chunks
            ]

            try:
                batch_result = self.client.graph.add_batch(
                    graph_id=graph_id,
                    episodes=episodes
                )
            except Exception as e:
                if progress_callback:
                    progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
                raise

            # Collect episode UUIDs; the adapter may expose ``uuid_`` or ``uuid``.
            if batch_result and isinstance(batch_result, list):
                for ep in batch_result:
                    ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
                    if ep_uuid:
                        episode_uuids.append(ep_uuid)

            # Throttle requests to the backend.
            time.sleep(1)

        return episode_uuids

    def _wait_for_episodes(
        self,
        episode_uuids: List[str],
        progress_callback: Optional[Callable] = None,
        timeout: int = 600
    ):
        """
        Poll each episode's ``processed`` flag until all are done or time runs out.

        Args:
            episode_uuids: Episode UUIDs to wait for; duplicates count once.
            progress_callback: Optional ``callback(message, progress)``. The
                ``progress`` argument is a fraction in [0, 1].
            timeout: Seconds to wait before giving up. On timeout the method
                returns normally, after a timeout report, without raising.
        """
        if not episode_uuids:
            if progress_callback:
                progress_callback(t('progress.noEpisodesWait'), 1.0)
            return

        start_time = time.time()
        pending_episodes = set(episode_uuids)
        # Count de-duplicated UUIDs so progress can actually reach 100%.
        total_episodes = len(pending_episodes)
        completed_count = 0
        timed_out = False

        if progress_callback:
            progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)

        while pending_episodes:
            if time.time() - start_time > timeout:
                timed_out = True
                if progress_callback:
                    progress_callback(
                        t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
                        completed_count / total_episodes
                    )
                break

            for ep_uuid in list(pending_episodes):
                try:
                    episode = self.client.graph.episode.get(uuid_=ep_uuid)
                except Exception:
                    # Best effort: a failed lookup is retried on the next poll.
                    continue
                if getattr(episode, 'processed', False):
                    pending_episodes.discard(ep_uuid)
                    completed_count += 1

            elapsed = int(time.time() - start_time)
            if progress_callback:
                progress_callback(
                    t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
                    completed_count / total_episodes
                )

            if pending_episodes:
                time.sleep(3)  # poll every 3 seconds

        # Don't claim completion after a timeout; that would overwrite the
        # timeout report with a misleading 100%.
        if progress_callback and not timed_out:
            progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)

    def _get_graph_info(self, graph_id: str) -> "GraphInfo":
        """
        Fetch all nodes and edges (paged) and summarise them as a GraphInfo.

        Entity types are the distinct node labels other than the generic
        "Entity"/"Node" labels, sorted for deterministic output.
        """
        nodes = fetch_all_nodes(self.client, graph_id)
        edges = fetch_all_edges(self.client, graph_id)

        entity_types = {
            label
            for node in nodes
            for label in (node.labels or [])
            if label not in ("Entity", "Node")
        }

        return GraphInfo(
            graph_id=graph_id,
            node_count=len(nodes),
            edge_count=len(edges),
            entity_types=sorted(entity_types),
        )

    @staticmethod
    def _to_str_or_none(value: Any) -> Optional[str]:
        """Stringify a (timestamp) value, mapping missing/falsy values to None."""
        return str(value) if value else None

    @staticmethod
    def _normalize_episode_ids(episodes: Any) -> List[str]:
        """
        Coerce an edge's episode reference(s) into a list of strings.

        Accepts None/empty, a single scalar ID (including a str), or a
        list/tuple/set of IDs. Collections are expanded element-wise instead
        of being stringified as a whole.
        """
        if not episodes:
            return []
        if isinstance(episodes, (list, tuple, set, frozenset)):
            return [str(e) for e in episodes]
        return [str(episodes)]

    def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Fetch the complete graph data with per-node and per-edge details.

        Args:
            graph_id: Graph ID.
            ontology: Optional ontology. When given, each node is classified
                with ``_classify_entity_type``, and the resulting type is
                prepended to the node's labels (replacing "Entity").

        Returns:
            Dict with ``graph_id``, ``nodes``, ``edges``, ``node_count`` and
            ``edge_count``. Timestamps are strings, or None when missing.
        """
        nodes = fetch_all_nodes(self.client, graph_id)
        edges = fetch_all_edges(self.client, graph_id)

        # uuid -> name, used to label edge endpoints.
        node_map = {node.uuid_: node.name or "" for node in nodes}

        nodes_data = []
        for node in nodes:
            entity_type = _classify_entity_type(node.name, node.summary or "", ontology)
            # Copy so the output never aliases the node object's own list.
            labels = list(node.labels or [])
            if entity_type != "Entity" and entity_type not in labels:
                labels = [entity_type] + [l for l in labels if l != "Entity"]

            nodes_data.append({
                "uuid": node.uuid_,
                "name": node.name,
                "labels": labels,
                "summary": node.summary or "",
                "attributes": node.attributes or {},
                "created_at": self._to_str_or_none(getattr(node, 'created_at', None)),
            })

        edges_data = []
        for edge in edges:
            # Adapters may expose episode references as ``episodes`` or ``episode_ids``.
            episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
            fact_type = getattr(edge, 'fact_type', None) or edge.name or ""

            edges_data.append({
                "uuid": edge.uuid_,
                "name": edge.name or "",
                "fact": edge.fact or "",
                "fact_type": fact_type,
                "source_node_uuid": edge.source_node_uuid,
                "target_node_uuid": edge.target_node_uuid,
                "source_node_name": node_map.get(edge.source_node_uuid, ""),
                "target_node_name": node_map.get(edge.target_node_uuid, ""),
                "attributes": edge.attributes or {},
                "created_at": self._to_str_or_none(getattr(edge, 'created_at', None)),
                "valid_at": self._to_str_or_none(getattr(edge, 'valid_at', None)),
                "invalid_at": self._to_str_or_none(getattr(edge, 'invalid_at', None)),
                "expired_at": self._to_str_or_none(getattr(edge, 'expired_at', None)),
                "episodes": self._normalize_episode_ids(episodes),
            })

        return {
            "graph_id": graph_id,
            "nodes": nodes_data,
            "edges": edges_data,
            "node_count": len(nodes_data),
            "edge_count": len(edges_data),
        }

    def delete_graph(self, graph_id: str):
        """Delete the graph with the given ID."""
        self.client.graph.delete(graph_id=graph_id)