docs(i18n): translate Chinese docstrings/comments in backend/app/{models,utils} and partial services

This commit is contained in:
Dominik Seemann 2026-05-07 14:44:08 +00:00
parent 74997fd088
commit e3f7defefc
13 changed files with 464 additions and 518 deletions

View File

@ -1,6 +1,4 @@
""" """Data model package."""
数据模型模块
"""
from .task import TaskManager, TaskStatus from .task import TaskManager, TaskStatus
from .project import Project, ProjectStatus, ProjectManager from .project import Project, ProjectStatus, ProjectManager

View File

@ -1,6 +1,7 @@
""" """Project context management.
项目上下文管理
用于在服务端持久化项目状态避免前端在接口间传递大量数据 Persists project state on the server so the frontend does not have to round-trip
large blobs of context between API calls.
""" """
import os import os
@ -15,45 +16,45 @@ from ..config import Config
class ProjectStatus(str, Enum):
    """Lifecycle status of a project.

    Members also subclass ``str``, so each one compares equal to its raw
    string value; ``Project.from_dict`` rebuilds a member from that string.
    """
    CREATED = "created"  # just created, files uploaded
    ONTOLOGY_GENERATED = "ontology_generated"  # ontology has been generated
    GRAPH_BUILDING = "graph_building"  # graph build in progress
    GRAPH_COMPLETED = "graph_completed"  # graph build finished
    FAILED = "failed"  # build failed
@dataclass @dataclass
class Project: class Project:
"""项目数据模型""" """Project data model."""
project_id: str project_id: str
name: str name: str
status: ProjectStatus status: ProjectStatus
created_at: str created_at: str
updated_at: str updated_at: str
# 文件信息 # File information
files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}] files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}]
total_text_length: int = 0 total_text_length: int = 0
# 本体信息接口1生成后填充 # Ontology information (filled in after step 1 generates it)
ontology: Optional[Dict[str, Any]] = None ontology: Optional[Dict[str, Any]] = None
analysis_summary: Optional[str] = None analysis_summary: Optional[str] = None
# 图谱信息接口2完成后填充 # Graph information (filled in after step 2 finishes)
graph_id: Optional[str] = None graph_id: Optional[str] = None
graph_build_task_id: Optional[str] = None graph_build_task_id: Optional[str] = None
# 配置 # Configuration
simulation_requirement: Optional[str] = None simulation_requirement: Optional[str] = None
chunk_size: int = 500 chunk_size: int = 500
chunk_overlap: int = 50 chunk_overlap: int = 50
# 错误信息 # Error message when status == FAILED
error: Optional[str] = None error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Serialize the project to a JSON-friendly dict."""
return { return {
"project_id": self.project_id, "project_id": self.project_id,
"name": self.name, "name": self.name,
@ -71,14 +72,14 @@ class Project:
"chunk_overlap": self.chunk_overlap, "chunk_overlap": self.chunk_overlap,
"error": self.error "error": self.error
} }
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Project': def from_dict(cls, data: Dict[str, Any]) -> 'Project':
"""从字典创建""" """Reconstruct a project from its serialized dict."""
status = data.get('status', 'created') status = data.get('status', 'created')
if isinstance(status, str): if isinstance(status, str):
status = ProjectStatus(status) status = ProjectStatus(status)
return cls( return cls(
project_id=data['project_id'], project_id=data['project_id'],
name=data.get('name', 'Unnamed Project'), name=data.get('name', 'Unnamed Project'),
@ -99,52 +100,51 @@ class Project:
class ProjectManager: class ProjectManager:
"""项目管理器 - 负责项目的持久化存储和检索""" """Project manager: handles persistence and retrieval of projects on disk."""
# 项目存储根目录 # Root directory for project storage
PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects') PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects')
@classmethod
def _ensure_projects_dir(cls):
    """Create the projects root directory if it is missing.

    ``exist_ok=True`` makes the call a no-op when the directory already exists.
    """
    projects_root = cls.PROJECTS_DIR
    os.makedirs(projects_root, exist_ok=True)
@classmethod
def _get_project_dir(cls, project_id: str) -> str:
    """Return the directory under ``PROJECTS_DIR`` that belongs to ``project_id``.

    The path is only computed; nothing is created or checked on disk.
    """
    projects_root = cls.PROJECTS_DIR
    return os.path.join(projects_root, project_id)
@classmethod
def _get_project_meta_path(cls, project_id: str) -> str:
    """Return the path of the ``project.json`` metadata file for ``project_id``."""
    project_dir = cls._get_project_dir(project_id)
    return os.path.join(project_dir, 'project.json')
@classmethod
def _get_project_files_dir(cls, project_id: str) -> str:
    """Return the ``files`` sub-directory that stores a project's source files."""
    project_dir = cls._get_project_dir(project_id)
    return os.path.join(project_dir, 'files')
@classmethod
def _get_project_text_path(cls, project_id: str) -> str:
    """Return the path of the ``extracted_text.txt`` file for ``project_id``."""
    project_dir = cls._get_project_dir(project_id)
    return os.path.join(project_dir, 'extracted_text.txt')
@classmethod @classmethod
def create_project(cls, name: str = "Unnamed Project") -> Project: def create_project(cls, name: str = "Unnamed Project") -> Project:
""" """Create a new project.
创建新项目
Args: Args:
name: 项目名称 name: Display name for the project.
Returns: Returns:
新创建的Project对象 The newly created ``Project`` instance.
""" """
cls._ensure_projects_dir() cls._ensure_projects_dir()
project_id = f"proj_{uuid.uuid4().hex[:12]}" project_id = f"proj_{uuid.uuid4().hex[:12]}"
now = datetime.now().isoformat() now = datetime.now().isoformat()
project = Project( project = Project(
project_id=project_id, project_id=project_id,
name=name, name=name,
@ -152,154 +152,147 @@ class ProjectManager:
created_at=now, created_at=now,
updated_at=now updated_at=now
) )
# 创建项目目录结构 # Create the on-disk project directory layout
project_dir = cls._get_project_dir(project_id) project_dir = cls._get_project_dir(project_id)
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
os.makedirs(project_dir, exist_ok=True) os.makedirs(project_dir, exist_ok=True)
os.makedirs(files_dir, exist_ok=True) os.makedirs(files_dir, exist_ok=True)
# 保存项目元数据 # Persist project metadata
cls.save_project(project) cls.save_project(project)
return project return project
@classmethod
def save_project(cls, project: Project) -> None:
    """Persist project metadata to disk.

    Refreshes ``project.updated_at`` to the current local time (the object
    passed in is mutated) and writes ``project.to_dict()`` as UTF-8 JSON to
    the project's metadata file.

    The JSON is first written to a uniquely named temporary file next to the
    target and then moved into place with ``os.replace``. An interrupted
    write therefore cannot leave a truncated ``project.json`` behind, which
    would make later loads fail with a JSON decode error.

    Args:
        project: Project whose metadata should be stored. Its project
            directory must already exist.

    Raises:
        OSError: If the temporary file cannot be written or moved into place.
    """
    project.updated_at = datetime.now().isoformat()
    meta_path = cls._get_project_meta_path(project.project_id)

    # Unique suffix so two concurrent saves never share a temp file.
    tmp_path = f"{meta_path}.{uuid.uuid4().hex[:8]}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(project.to_dict(), f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, meta_path)
    except BaseException:
        # Do not leave a partial temp file lying around in the project dir.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
@classmethod
def get_project(cls, project_id: str) -> Optional[Project]:
    """Load a project from its metadata file.

    Args:
        project_id: Project identifier.

    Returns:
        The ``Project`` rebuilt from ``project.json``, or ``None`` when that
        file does not exist.
    """
    meta_path = cls._get_project_meta_path(project_id)
    if os.path.exists(meta_path):
        with open(meta_path, 'r', encoding='utf-8') as meta_file:
            payload = json.load(meta_file)
        return Project.from_dict(payload)
    return None
@classmethod
def list_projects(cls, limit: int = 50) -> List[Project]:
    """List stored projects, newest first.

    Every entry of the projects root is treated as a candidate project id;
    entries for which ``get_project`` returns nothing are left out.

    Args:
        limit: Maximum number of projects to return.

    Returns:
        Projects ordered by ``created_at`` descending, cut to ``limit``.
    """
    cls._ensure_projects_dir()

    loaded = (cls.get_project(entry) for entry in os.listdir(cls.PROJECTS_DIR))
    found = [project for project in loaded if project]

    newest_first = sorted(found, key=lambda project: project.created_at, reverse=True)
    return newest_first[:limit]
@classmethod
def delete_project(cls, project_id: str) -> bool:
    """Delete a project and all of its files.

    Only a direct child of ``PROJECTS_DIR`` is ever removed. Ids that resolve
    to the projects root itself or to a location outside of it (``""``,
    ``"."``, ``".."``, absolute paths, ids containing path separators) are
    treated as unknown projects. A malformed id therefore cannot make
    ``shutil.rmtree`` wipe the whole storage tree or an unrelated directory.

    Args:
        project_id: Project identifier.

    Returns:
        ``True`` if the project existed and was removed, ``False`` otherwise.
    """
    projects_root = os.path.abspath(cls.PROJECTS_DIR)
    project_dir = os.path.abspath(cls._get_project_dir(project_id))

    # abspath() collapses "." / ".." segments, so comparing the parent with
    # the root rejects every id that does not name an immediate child.
    if os.path.dirname(project_dir) != projects_root:
        return False

    if not os.path.exists(project_dir):
        return False

    shutil.rmtree(project_dir)
    return True
@classmethod
def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]:
    """Store an uploaded file inside the project's files directory.

    Args:
        project_id: Project identifier.
        file_storage: Flask ``FileStorage`` object from the request; only its
            ``save(path)`` method is used here.
        original_filename: The user-supplied filename.

    Returns:
        Dict describing the stored file:
        ``{original_filename, saved_filename, path, size}``.
    """
    target_dir = cls._get_project_files_dir(project_id)
    os.makedirs(target_dir, exist_ok=True)

    # Only the lower-cased extension of the user-supplied name is kept; the
    # stem is replaced by 8 random hex characters.
    _, extension = os.path.splitext(original_filename)
    stored_name = f"{uuid.uuid4().hex[:8]}{extension.lower()}"
    destination = os.path.join(target_dir, stored_name)

    file_storage.save(destination)
    size_on_disk = os.path.getsize(destination)

    return {
        "original_filename": original_filename,
        "saved_filename": stored_name,
        "path": destination,
        "size": size_on_disk
    }
@classmethod
def save_extracted_text(cls, project_id: str, text: str) -> None:
    """Write the project's extracted text to its text file as UTF-8.

    Any previous content of the file is overwritten.
    """
    destination = cls._get_project_text_path(project_id)
    with open(destination, 'w', encoding='utf-8') as text_file:
        text_file.write(text)
@classmethod
def get_extracted_text(cls, project_id: str) -> Optional[str]:
    """Return the project's stored extracted text, or ``None`` if no file exists."""
    source = cls._get_project_text_path(project_id)
    if os.path.exists(source):
        with open(source, 'r', encoding='utf-8') as text_file:
            return text_file.read()
    return None
@classmethod
def get_project_files(cls, project_id: str) -> List[str]:
    """Return the paths of all regular files in the project's files directory.

    Sub-directories are skipped; a missing files directory yields ``[]``.
    """
    files_dir = cls._get_project_files_dir(project_id)
    if not os.path.exists(files_dir):
        return []

    candidates = (os.path.join(files_dir, entry) for entry in os.listdir(files_dir))
    return [path for path in candidates if os.path.isfile(path)]

View File

@ -1,6 +1,6 @@
""" """Task state management.
任务状态管理
用于跟踪长时间运行的任务如图谱构建 Tracks long-running tasks (e.g. graph build) so callers can poll progress.
""" """
import uuid import uuid
@ -14,30 +14,30 @@ from ..utils.locale import t
class TaskStatus(str, Enum):
    """Status of a tracked task.

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """
    PENDING = "pending"  # waiting
    PROCESSING = "processing"  # in progress
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"  # finished with error
@dataclass @dataclass
class Task: class Task:
"""任务数据类""" """Task data class."""
task_id: str task_id: str
task_type: str task_type: str
status: TaskStatus status: TaskStatus
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
progress: int = 0 # 总进度百分比 0-100 progress: int = 0 # overall progress percentage 0-100
message: str = "" # 状态消息 message: str = "" # human-readable status message
result: Optional[Dict] = None # 任务结果 result: Optional[Dict] = None # task result payload
error: Optional[str] = None # 错误信息 error: Optional[str] = None # error message when failed
metadata: Dict = field(default_factory=dict) # 额外元数据 metadata: Dict = field(default_factory=dict) # arbitrary caller metadata
progress_detail: Dict = field(default_factory=dict) # 详细进度信息 progress_detail: Dict = field(default_factory=dict) # fine-grained progress info
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Serialize the task to a JSON-friendly dict."""
return { return {
"task_id": self.task_id, "task_id": self.task_id,
"task_type": self.task_type, "task_type": self.task_type,
@ -54,16 +54,12 @@ class Task:
class TaskManager: class TaskManager:
""" """Thread-safe singleton task registry."""
任务管理器
线程安全的任务状态管理
"""
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls): def __new__(cls):
"""单例模式"""
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
@ -71,21 +67,20 @@ class TaskManager:
cls._instance._tasks: Dict[str, Task] = {} cls._instance._tasks: Dict[str, Task] = {}
cls._instance._task_lock = threading.Lock() cls._instance._task_lock = threading.Lock()
return cls._instance return cls._instance
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str: def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
""" """Create a new task.
创建新任务
Args: Args:
task_type: 任务类型 task_type: Task type identifier.
metadata: 额外元数据 metadata: Optional caller-supplied metadata.
Returns: Returns:
任务ID The newly created task id.
""" """
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
now = datetime.now() now = datetime.now()
task = Task( task = Task(
task_id=task_id, task_id=task_id,
task_type=task_type, task_type=task_type,
@ -94,17 +89,17 @@ class TaskManager:
updated_at=now, updated_at=now,
metadata=metadata or {} metadata=metadata or {}
) )
with self._task_lock: with self._task_lock:
self._tasks[task_id] = task self._tasks[task_id] = task
return task_id return task_id
def get_task(self, task_id: str) -> Optional[Task]:
    """Look up a task by id while holding the task lock.

    Returns:
        The stored ``Task`` object itself (not a copy), or ``None`` when
        ``task_id`` is unknown.
    """
    with self._task_lock:
        found = self._tasks.get(task_id)
    return found
def update_task( def update_task(
self, self,
task_id: str, task_id: str,
@ -115,17 +110,16 @@ class TaskManager:
error: Optional[str] = None, error: Optional[str] = None,
progress_detail: Optional[Dict] = None progress_detail: Optional[Dict] = None
): ):
""" """Update mutable fields on an existing task.
更新任务状态
Args: Args:
task_id: 任务ID task_id: Task id to update.
status: 新状态 status: New status, if changing.
progress: 进度 progress: New overall progress (0-100), if changing.
message: 消息 message: New status message, if changing.
result: 结果 result: New result payload, if changing.
error: 错误信息 error: New error message, if changing.
progress_detail: 详细进度信息 progress_detail: New fine-grained progress info, if changing.
""" """
with self._task_lock: with self._task_lock:
task = self._tasks.get(task_id) task = self._tasks.get(task_id)
@ -143,9 +137,9 @@ class TaskManager:
task.error = error task.error = error
if progress_detail is not None: if progress_detail is not None:
task.progress_detail = progress_detail task.progress_detail = progress_detail
def complete_task(self, task_id: str, result: Dict): def complete_task(self, task_id: str, result: Dict):
"""标记任务完成""" """Mark a task as completed and attach the result."""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.COMPLETED, status=TaskStatus.COMPLETED,
@ -153,29 +147,29 @@ class TaskManager:
message=t('progress.taskComplete'), message=t('progress.taskComplete'),
result=result result=result
) )
def fail_task(self, task_id: str, error: str):
    """Flag a task as failed and record why.

    Delegates to ``update_task``, setting the status to
    ``TaskStatus.FAILED``, the localized "task failed" message and the
    given error text.

    Args:
        task_id: Id of the task to mark as failed.
        error: Error message to store on the task.
    """
    failure_fields = {
        "status": TaskStatus.FAILED,
        "message": t('progress.taskFailed'),
        "error": error,
    }
    self.update_task(task_id, **failure_fields)
def list_tasks(self, task_type: Optional[str] = None) -> list:
    """Return serialized tasks, newest first.

    Args:
        task_type: When truthy, only tasks whose ``task_type`` equals this
            value are returned; otherwise every task is included.

    Returns:
        A list of ``to_dict()`` results ordered by ``created_at`` descending.
    """
    with self._task_lock:
        selected = [
            task for task in self._tasks.values()
            if not task_type or task.task_type == task_type
        ]
        selected.sort(key=lambda task: task.created_at, reverse=True)
        return [task.to_dict() for task in selected]
def cleanup_old_tasks(self, max_age_hours: int = 24): def cleanup_old_tasks(self, max_age_hours: int = 24):
"""清理旧任务""" """Drop completed/failed tasks older than ``max_age_hours``."""
from datetime import timedelta from datetime import timedelta
cutoff = datetime.now() - timedelta(hours=max_age_hours) cutoff = datetime.now() - timedelta(hours=max_age_hours)
with self._task_lock: with self._task_lock:
old_ids = [ old_ids = [
tid for tid, task in self._tasks.items() tid for tid, task in self._tasks.items()

View File

@ -1,6 +1,4 @@
""" """Business services package."""
业务服务模块
"""
from .ontology_generator import OntologyGenerator from .ontology_generator import OntologyGenerator
from .graph_builder import GraphBuilderService from .graph_builder import GraphBuilderService

View File

@ -1,6 +1,7 @@
""" """Graph build service.
图谱构建服务
接口2使用Zep API构建Standalone Graph Pipeline step 2: build the project's standalone knowledge graph through the
Zep/Graphiti API.
""" """
import os import os
@ -69,7 +70,7 @@ def _classify_entity_type(name: str, summary: str, ontology: Optional[Dict]) ->
@dataclass @dataclass
class GraphInfo: class GraphInfo:
"""图谱信息""" """Summary information about a built graph."""
graph_id: str graph_id: str
node_count: int node_count: int
edge_count: int edge_count: int
@ -85,10 +86,7 @@ class GraphInfo:
class GraphBuilderService: class GraphBuilderService:
""" """Drives knowledge-graph construction via the Zep/Graphiti API."""
图谱构建服务
负责调用Zep API构建知识图谱
"""
def __init__(self, api_key: Optional[str] = None): def __init__(self, api_key: Optional[str] = None):
self.client = GraphitiAdapter() self.client = GraphitiAdapter()
@ -103,21 +101,20 @@ class GraphBuilderService:
chunk_overlap: int = 50, chunk_overlap: int = 50,
batch_size: int = 3 batch_size: int = 3
) -> str: ) -> str:
""" """Kick off a graph build asynchronously.
异步构建图谱
Args: Args:
text: 输入文本 text: Source text to ingest.
ontology: 本体定义来自接口1的输出 ontology: Ontology definition (the output of pipeline step 1).
graph_name: 图谱名称 graph_name: Display name for the graph.
chunk_size: 文本块大小 chunk_size: Characters per text chunk.
chunk_overlap: 块重叠大小 chunk_overlap: Overlap (in characters) between consecutive chunks.
batch_size: 每批发送的块数量 batch_size: Number of chunks pushed to Zep per batch.
Returns: Returns:
任务ID The id of the task tracking the build.
""" """
# 创建任务 # Register a task to track build progress.
task_id = self.task_manager.create_task( task_id = self.task_manager.create_task(
task_type="graph_build", task_type="graph_build",
metadata={ metadata={
@ -130,7 +127,7 @@ class GraphBuilderService:
# Capture locale before spawning background thread # Capture locale before spawning background thread
current_locale = get_locale() current_locale = get_locale()
# 在后台线程中执行构建 # Run the build on a background thread so the request returns immediately.
thread = threading.Thread( thread = threading.Thread(
target=self._build_graph_worker, target=self._build_graph_worker,
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale) args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
@ -151,7 +148,7 @@ class GraphBuilderService:
batch_size: int, batch_size: int,
locale: str = 'zh' locale: str = 'zh'
): ):
"""图谱构建工作线程""" """Background worker that performs the graph build."""
set_locale(locale) set_locale(locale)
try: try:
self.task_manager.update_task( self.task_manager.update_task(
@ -161,7 +158,7 @@ class GraphBuilderService:
message=t('progress.startBuildingGraph') message=t('progress.startBuildingGraph')
) )
# 1. 创建图谱 # 1. Create the graph.
graph_id = self.create_graph(graph_name) graph_id = self.create_graph(graph_name)
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
@ -169,7 +166,7 @@ class GraphBuilderService:
message=t('progress.graphCreated', graphId=graph_id) message=t('progress.graphCreated', graphId=graph_id)
) )
# 2. 设置本体 # 2. Set the ontology.
self.set_ontology(graph_id, ontology) self.set_ontology(graph_id, ontology)
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
@ -177,7 +174,7 @@ class GraphBuilderService:
message=t('progress.ontologySet') message=t('progress.ontologySet')
) )
# 3. 文本分块 # 3. Split source text into chunks.
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
total_chunks = len(chunks) total_chunks = len(chunks)
self.task_manager.update_task( self.task_manager.update_task(
@ -186,7 +183,7 @@ class GraphBuilderService:
message=t('progress.textSplit', count=total_chunks) message=t('progress.textSplit', count=total_chunks)
) )
# 4. 分批发送数据 # 4. Push chunks to the graph in batches.
episode_uuids = self.add_text_batches( episode_uuids = self.add_text_batches(
graph_id, chunks, batch_size, graph_id, chunks, batch_size,
lambda msg, prog: self.task_manager.update_task( lambda msg, prog: self.task_manager.update_task(
@ -196,7 +193,7 @@ class GraphBuilderService:
) )
) )
# 5. 等待Zep处理完成 # 5. Wait for Zep to finish processing the episodes.
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
progress=60, progress=60,
@ -212,7 +209,7 @@ class GraphBuilderService:
) )
) )
# 6. 获取图谱信息 # 6. Fetch the final graph metadata.
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
progress=90, progress=90,
@ -220,8 +217,7 @@ class GraphBuilderService:
) )
graph_info = self._get_graph_info(graph_id) graph_info = self._get_graph_info(graph_id)
# 完成
self.task_manager.complete_task(task_id, { self.task_manager.complete_task(task_id, {
"graph_id": graph_id, "graph_id": graph_id,
"graph_info": graph_info.to_dict(), "graph_info": graph_info.to_dict(),
@ -234,7 +230,7 @@ class GraphBuilderService:
self.task_manager.fail_task(task_id, error_msg) self.task_manager.fail_task(task_id, error_msg)
def create_graph(self, name: str) -> str: def create_graph(self, name: str) -> str:
"""创建Zep图谱公开方法""" """Create a new Zep graph and return its id (public API)."""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self.client.graph.create( self.client.graph.create(
@ -246,7 +242,7 @@ class GraphBuilderService:
return graph_id return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
"""设置图谱本体提示Graphiti自动提取实体本体作为提示存储""" """Register the ontology with the graph (Graphiti uses it as an extraction prompt)."""
self.client.graph.set_ontology( self.client.graph.set_ontology(
graph_ids=[graph_id], graph_ids=[graph_id],
entities=ontology.get("entity_types"), entities=ontology.get("entity_types"),
@ -261,8 +257,11 @@ class GraphBuilderService:
progress_callback: Optional[Callable] = None, progress_callback: Optional[Callable] = None,
skip_chunks: int = 0, skip_chunks: int = 0,
) -> List[str]: ) -> List[str]:
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表。 """Push chunks to the graph in batches; returns the uuids of all episodes added.
skip_chunks: 跳过已处理的块数用于断点续传"""
Args:
skip_chunks: Number of chunks to skip (used for resume-after-restart).
"""
episode_uuids = [] episode_uuids = []
total_chunks = len(chunks) total_chunks = len(chunks)
@ -279,27 +278,26 @@ class GraphBuilderService:
) )
# 构建episode数据 # Build the per-episode payload structures expected by the client.
episodes = [ episodes = [
type('Episode', (), {'data': chunk, 'type': 'text'})() type('Episode', (), {'data': chunk, 'type': 'text'})()
for chunk in batch_chunks for chunk in batch_chunks
] ]
# 发送到Zep
try: try:
batch_result = self.client.graph.add_batch( batch_result = self.client.graph.add_batch(
graph_id=graph_id, graph_id=graph_id,
episodes=episodes episodes=episodes
) )
# 收集返回的 episode uuid # Collect the uuids returned for each episode.
if batch_result and isinstance(batch_result, list): if batch_result and isinstance(batch_result, list):
for ep in batch_result: for ep in batch_result:
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
if ep_uuid: if ep_uuid:
episode_uuids.append(ep_uuid) episode_uuids.append(ep_uuid)
# 避免请求过快 # Throttle to avoid overwhelming the upstream API.
time.sleep(1) time.sleep(1)
except Exception as e: except Exception as e:
@ -315,7 +313,7 @@ class GraphBuilderService:
progress_callback: Optional[Callable] = None, progress_callback: Optional[Callable] = None,
timeout: int = 600 timeout: int = 600
): ):
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" """Poll each episode until Zep marks it processed, or the timeout expires."""
if not episode_uuids: if not episode_uuids:
if progress_callback: if progress_callback:
progress_callback(t('progress.noEpisodesWait'), 1.0) progress_callback(t('progress.noEpisodesWait'), 1.0)
@ -338,18 +336,18 @@ class GraphBuilderService:
) )
break break
# 检查每个 episode 的处理状态 # Check the processing state of each pending episode.
for ep_uuid in list(pending_episodes): for ep_uuid in list(pending_episodes):
try: try:
episode = self.client.graph.episode.get(uuid_=ep_uuid) episode = self.client.graph.episode.get(uuid_=ep_uuid)
is_processed = getattr(episode, 'processed', False) is_processed = getattr(episode, 'processed', False)
if is_processed: if is_processed:
pending_episodes.remove(ep_uuid) pending_episodes.remove(ep_uuid)
completed_count += 1 completed_count += 1
except Exception as e: except Exception as e:
# 忽略单个查询错误,继续 # Tolerate a single failed query; the next loop iteration retries.
pass pass
elapsed = int(time.time() - start_time) elapsed = int(time.time() - start_time)
@ -360,20 +358,17 @@ class GraphBuilderService:
) )
if pending_episodes: if pending_episodes:
time.sleep(3) # 每3秒检查一次 time.sleep(3) # poll every 3 seconds
if progress_callback: if progress_callback:
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0) progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
def _get_graph_info(self, graph_id: str) -> GraphInfo: def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息""" """Fetch summary info (counts and entity types) for a graph."""
# 获取节点(分页)
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
# 获取边(分页)
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
# 统计实体类型 # Tally distinct entity types across all nodes.
entity_types = set() entity_types = set()
for node in nodes: for node in nodes:
if node.labels: if node.labels:
@ -389,26 +384,24 @@ class GraphBuilderService:
) )
def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]: def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]:
""" """Return the full graph payload including timestamps, attributes, and edges.
获取完整图谱数据包含详细信息
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
Returns: Returns:
包含nodes和edges的字典包括时间信息属性等详细数据 Dict with ``nodes``, ``edges``, and aggregate counts.
""" """
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
# 创建节点映射用于获取节点名称 # Build a uuid->name map so edge endpoints can be labeled.
node_map = {} node_map = {}
for node in nodes: for node in nodes:
node_map[node.uuid_] = node.name or "" node_map[node.uuid_] = node.name or ""
nodes_data = [] nodes_data = []
for node in nodes: for node in nodes:
# 获取创建时间
created_at = getattr(node, 'created_at', None) created_at = getattr(node, 'created_at', None)
if created_at: if created_at:
created_at = str(created_at) created_at = str(created_at)
@ -429,20 +422,18 @@ class GraphBuilderService:
edges_data = [] edges_data = []
for edge in edges: for edge in edges:
# 获取时间信息
created_at = getattr(edge, 'created_at', None) created_at = getattr(edge, 'created_at', None)
valid_at = getattr(edge, 'valid_at', None) valid_at = getattr(edge, 'valid_at', None)
invalid_at = getattr(edge, 'invalid_at', None) invalid_at = getattr(edge, 'invalid_at', None)
expired_at = getattr(edge, 'expired_at', None) expired_at = getattr(edge, 'expired_at', None)
# 获取 episodes # Normalize the episode list (the field may be missing or a single id).
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
if episodes and not isinstance(episodes, list): if episodes and not isinstance(episodes, list):
episodes = [str(episodes)] episodes = [str(episodes)]
elif episodes: elif episodes:
episodes = [str(e) for e in episodes] episodes = [str(e) for e in episodes]
# 获取 fact_type
fact_type = getattr(edge, 'fact_type', None) or edge.name or "" fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
edges_data.append({ edges_data.append({
@ -471,6 +462,6 @@ class GraphBuilderService:
} }
def delete_graph(self, graph_id: str) -> None:
    """Delete the graph identified by ``graph_id``.

    Forwards the id to ``self.client.graph.delete``. The call's result is
    discarded, and any exception it raises propagates to the caller.

    Args:
        graph_id: Identifier of the graph to delete.
    """
    graph_api = self.client.graph
    graph_api.delete(graph_id=graph_id)

View File

@ -1,6 +1,7 @@
""" """Ontology generation service.
本体生成服务
接口1分析文本内容生成适合社会模拟的实体和关系类型定义 Pipeline step 1: analyze the source text and propose entity and relationship
types that fit a social-media opinion simulation.
""" """
import json import json
@ -14,19 +15,19 @@ logger = logging.getLogger(__name__)
def _to_pascal_case(name: str) -> str:
    """Convert an arbitrary identifier to PascalCase.

    The name is first split on runs of non-alphanumeric ASCII characters,
    then each piece is split again at lower-to-upper camelCase transitions.
    Every resulting word is passed through ``str.capitalize`` (first letter
    upper-cased, the rest lower-cased) and the words are concatenated.

    Examples: ``works_for`` -> ``WorksFor``, ``person`` -> ``Person``,
    ``camelCase`` -> ``CamelCase``.

    Args:
        name: The identifier to convert.

    Returns:
        The PascalCase form, or ``'Unknown'`` when ``name`` contains no
        ASCII letters or digits.
    """
    capitalized = []
    for segment in re.split(r'[^a-zA-Z0-9]+', name):
        # Segments cannot contain '_' here (it is non-alphanumeric), so the
        # underscore is a safe marker for camelCase boundaries.
        marked = re.sub(r'([a-z])([A-Z])', r'\1_\2', segment)
        capitalized += [word.capitalize() for word in marked.split('_') if word]

    pascal = ''.join(capitalized)
    return pascal or 'Unknown'
# 本体生成的系统提示词 # System prompt template for ontology generation.
ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。 ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。
**重要你必须输出有效的JSON格式数据不要输出任何其他内容** **重要你必须输出有效的JSON格式数据不要输出任何其他内容**
@ -174,10 +175,7 @@ B. **具体类型8个根据文本内容设计**
class OntologyGenerator: class OntologyGenerator:
""" """Generate an entity- and edge-type ontology from arbitrary input text."""
本体生成器
分析文本内容生成实体和关系类型定义
"""
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm_client = llm_client or LLMClient() self.llm_client = llm_client or LLMClient()
@ -188,18 +186,17 @@ class OntologyGenerator:
simulation_requirement: str, simulation_requirement: str,
additional_context: Optional[str] = None additional_context: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """Generate an ontology definition.
生成本体定义
Args: Args:
document_texts: 文档文本列表 document_texts: Source document text segments.
simulation_requirement: 模拟需求描述 simulation_requirement: Description of the simulation goal.
additional_context: 额外上下文 additional_context: Optional supplemental context.
Returns: Returns:
本体定义entity_types, edge_types等 The ontology dict with ``entity_types``, ``edge_types``, and a summary.
""" """
# 构建用户消息 # Compose the user message that frames the LLM request.
user_message = self._build_user_message( user_message = self._build_user_message(
document_texts, document_texts,
simulation_requirement, simulation_requirement,
@ -213,19 +210,19 @@ class OntologyGenerator:
{"role": "user", "content": user_message} {"role": "user", "content": user_message}
] ]
# 调用LLM # Invoke the LLM.
result = self.llm_client.chat_json( result = self.llm_client.chat_json(
messages=messages, messages=messages,
temperature=0.3, temperature=0.3,
max_tokens=4096 max_tokens=4096
) )
# 验证和后处理 # Validate the LLM response and post-process it.
result = self._validate_and_process(result) result = self._validate_and_process(result)
return result return result
# 传给 LLM 的文本最大长度5万字 # Maximum length of source text passed to the LLM (50k characters).
MAX_TEXT_LENGTH_FOR_LLM = 50000 MAX_TEXT_LENGTH_FOR_LLM = 50000
def _build_user_message( def _build_user_message(
@ -234,13 +231,14 @@ class OntologyGenerator:
simulation_requirement: str, simulation_requirement: str,
additional_context: Optional[str] additional_context: Optional[str]
) -> str: ) -> str:
"""构建用户消息""" """Build the user-message string for the ontology LLM call."""
# 合并文本 # Concatenate the source documents into a single string.
combined_text = "\n\n---\n\n".join(document_texts) combined_text = "\n\n---\n\n".join(document_texts)
original_length = len(combined_text) original_length = len(combined_text)
# 如果文本超过5万字截断仅影响传给LLM的内容不影响图谱构建 # If the combined text exceeds the LLM input cap, truncate it for the
# LLM call only. The full text is still used for graph construction.
if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM: if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM:
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM] combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM]
combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..." combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..."
@ -275,9 +273,9 @@ class OntologyGenerator:
return message return message
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""验证和后处理结果""" """Validate and post-process the LLM-generated ontology dict."""
# 确保必要字段存在 # Ensure required top-level fields exist.
if "entity_types" not in result: if "entity_types" not in result:
result["entity_types"] = [] result["entity_types"] = []
if "edge_types" not in result: if "edge_types" not in result:
@ -285,11 +283,12 @@ class OntologyGenerator:
if "analysis_summary" not in result: if "analysis_summary" not in result:
result["analysis_summary"] = "" result["analysis_summary"] = ""
# 验证实体类型 # Validate entity types.
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用 # Track original-name -> PascalCase mapping so edge source_targets
# references can be fixed up consistently below.
entity_name_map = {} entity_name_map = {}
for entity in result["entity_types"]: for entity in result["entity_types"]:
# 强制将 entity name 转为 PascalCaseZep API 要求) # Force entity names to PascalCase (required by the Zep API).
if "name" in entity: if "name" in entity:
original_name = entity["name"] original_name = entity["name"]
entity["name"] = _to_pascal_case(original_name) entity["name"] = _to_pascal_case(original_name)
@ -300,19 +299,20 @@ class OntologyGenerator:
entity["attributes"] = [] entity["attributes"] = []
if "examples" not in entity: if "examples" not in entity:
entity["examples"] = [] entity["examples"] = []
# 确保description不超过100字符 # Truncate descriptions longer than 100 characters.
if len(entity.get("description", "")) > 100: if len(entity.get("description", "")) > 100:
entity["description"] = entity["description"][:97] + "..." entity["description"] = entity["description"][:97] + "..."
# 验证关系类型 # Validate edge types.
for edge in result["edge_types"]: for edge in result["edge_types"]:
# 强制将 edge name 转为 SCREAMING_SNAKE_CASEZep API 要求) # Force edge names to SCREAMING_SNAKE_CASE (required by the Zep API).
if "name" in edge: if "name" in edge:
original_name = edge["name"] original_name = edge["name"]
edge["name"] = original_name.upper() edge["name"] = original_name.upper()
if edge["name"] != original_name: if edge["name"] != original_name:
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'") logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致 # Rewrite source_targets entity-name references to match the
# PascalCase-normalized entity names.
for st in edge.get("source_targets", []): for st in edge.get("source_targets", []):
if st.get("source") in entity_name_map: if st.get("source") in entity_name_map:
st["source"] = entity_name_map[st["source"]] st["source"] = entity_name_map[st["source"]]
@ -325,11 +325,11 @@ class OntologyGenerator:
if len(edge.get("description", "")) > 100: if len(edge.get("description", "")) > 100:
edge["description"] = edge["description"][:97] + "..." edge["description"] = edge["description"][:97] + "..."
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 # Zep API caps: at most 10 custom entity types and 10 custom edge types.
MAX_ENTITY_TYPES = 10 MAX_ENTITY_TYPES = 10
MAX_EDGE_TYPES = 10 MAX_EDGE_TYPES = 10
# 去重:按 name 去重,保留首次出现的 # Deduplicate by name, keeping the first occurrence.
seen_names = set() seen_names = set()
deduped = [] deduped = []
for entity in result["entity_types"]: for entity in result["entity_types"]:
@ -341,7 +341,7 @@ class OntologyGenerator:
logger.warning(f"Duplicate entity type '{name}' removed during validation") logger.warning(f"Duplicate entity type '{name}' removed during validation")
result["entity_types"] = deduped result["entity_types"] = deduped
# 兜底类型定义 # Fallback entity-type definitions used when the LLM omits them.
person_fallback = { person_fallback = {
"name": "Person", "name": "Person",
"description": "Any individual person not fitting other specific person types.", "description": "Any individual person not fitting other specific person types.",
@ -362,33 +362,31 @@ class OntologyGenerator:
"examples": ["small business", "community group"] "examples": ["small business", "community group"]
} }
# 检查是否已有兜底类型 # Check whether the fallback types are already present.
entity_names = {e["name"] for e in result["entity_types"]} entity_names = {e["name"] for e in result["entity_types"]}
has_person = "Person" in entity_names has_person = "Person" in entity_names
has_organization = "Organization" in entity_names has_organization = "Organization" in entity_names
# 需要添加的兜底类型 # Collect missing fallback types to add below.
fallbacks_to_add = [] fallbacks_to_add = []
if not has_person: if not has_person:
fallbacks_to_add.append(person_fallback) fallbacks_to_add.append(person_fallback)
if not has_organization: if not has_organization:
fallbacks_to_add.append(organization_fallback) fallbacks_to_add.append(organization_fallback)
if fallbacks_to_add: if fallbacks_to_add:
current_count = len(result["entity_types"]) current_count = len(result["entity_types"])
needed_slots = len(fallbacks_to_add) needed_slots = len(fallbacks_to_add)
# 如果添加后会超过 10 个,需要移除一些现有类型 # If adding the fallbacks would exceed the cap, drop some existing types.
if current_count + needed_slots > MAX_ENTITY_TYPES: if current_count + needed_slots > MAX_ENTITY_TYPES:
# 计算需要移除多少个
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
# 从末尾移除(保留前面更重要的具体类型) # Drop trailing types first; the more specific types come earlier.
result["entity_types"] = result["entity_types"][:-to_remove] result["entity_types"] = result["entity_types"][:-to_remove]
# 添加兜底类型
result["entity_types"].extend(fallbacks_to_add) result["entity_types"].extend(fallbacks_to_add)
# 最终确保不超过限制(防御性编程) # Defensive cap enforcement: hard-trim if anything slipped through.
if len(result["entity_types"]) > MAX_ENTITY_TYPES: if len(result["entity_types"]) > MAX_ENTITY_TYPES:
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES] result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
@ -398,14 +396,13 @@ class OntologyGenerator:
return result return result
def generate_python_code(self, ontology: Dict[str, Any]) -> str: def generate_python_code(self, ontology: Dict[str, Any]) -> str:
""" """Render the ontology definition as Python source code.
将本体定义转换为Python代码类似ontology.py
Args: Args:
ontology: 本体定义 ontology: Ontology definition dict.
Returns: Returns:
Python代码字符串 Python source code as a single string.
""" """
code_lines = [ code_lines = [
'"""', '"""',
@ -421,7 +418,7 @@ class OntologyGenerator:
'', '',
] ]
# 生成实体类型 # Emit each entity type as a Python class.
for entity in ontology.get("entity_types", []): for entity in ontology.get("entity_types", []):
name = entity["name"] name = entity["name"]
desc = entity.get("description", f"A {name} entity.") desc = entity.get("description", f"A {name} entity.")
@ -447,10 +444,10 @@ class OntologyGenerator:
code_lines.append('# ============== 关系类型定义 ==============') code_lines.append('# ============== 关系类型定义 ==============')
code_lines.append('') code_lines.append('')
# 生成关系类型 # Emit each edge type as a Python class.
for edge in ontology.get("edge_types", []): for edge in ontology.get("edge_types", []):
name = edge["name"] name = edge["name"]
# 转换为PascalCase类名 # Convert SCREAMING_SNAKE_CASE -> PascalCase for the class name.
class_name = ''.join(word.capitalize() for word in name.split('_')) class_name = ''.join(word.capitalize() for word in name.split('_'))
desc = edge.get("description", f"A {name} relationship.") desc = edge.get("description", f"A {name} relationship.")
@ -472,7 +469,7 @@ class OntologyGenerator:
code_lines.append('') code_lines.append('')
code_lines.append('') code_lines.append('')
# 生成类型字典 # Emit the type registries.
code_lines.append('# ============== 类型配置 ==============') code_lines.append('# ============== 类型配置 ==============')
code_lines.append('') code_lines.append('')
code_lines.append('ENTITY_TYPES = {') code_lines.append('ENTITY_TYPES = {')
@ -489,7 +486,7 @@ class OntologyGenerator:
code_lines.append('}') code_lines.append('}')
code_lines.append('') code_lines.append('')
# 生成边的source_targets映射 # Emit the edge source_targets map.
code_lines.append('EDGE_SOURCE_TARGETS = {') code_lines.append('EDGE_SOURCE_TARGETS = {')
for edge in ontology.get("edge_types", []): for edge in ontology.get("edge_types", []):
name = edge["name"] name = edge["name"]

View File

@ -1,68 +1,64 @@
""" """Text processing service."""
文本处理服务
"""
from typing import List, Optional from typing import List, Optional
from ..utils.file_parser import FileParser, split_text_into_chunks from ..utils.file_parser import FileParser, split_text_into_chunks
class TextProcessor: class TextProcessor:
"""文本处理器""" """Facade for the text-extraction and chunking pipeline."""
@staticmethod @staticmethod
def extract_from_files(file_paths: List[str]) -> str:
    """Extract text from several files and return it as one string.

    Delegates to ``FileParser.extract_from_multiple``.

    Args:
        file_paths: Paths of the files to read.

    Returns:
        The combined text returned by the parser.
    """
    merged_text = FileParser.extract_from_multiple(file_paths)
    return merged_text
@staticmethod @staticmethod
def split_text(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> List[str]:
    """Split text into chunks by delegating to ``split_text_into_chunks``.

    Args:
        text: The source text.
        chunk_size: Target characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        A list of chunk strings.
    """
    pieces = split_text_into_chunks(text, chunk_size, overlap)
    return pieces
@staticmethod @staticmethod
def preprocess_text(text: str) -> str:
    """Normalize line endings, trim each line, and collapse blank-line runs.

    Steps, in order:

    1. Convert ``\\r\\n`` and bare ``\\r`` line endings to ``\\n``.
    2. Strip leading and trailing whitespace from every line.
    3. Collapse runs of three or more newlines down to exactly two.
    4. Strip leading and trailing whitespace from the whole text.

    Lines are trimmed *before* blank-line runs are collapsed so that lines
    holding only spaces or tabs are treated as blank lines too.

    Args:
        text: The source text.

    Returns:
        The cleaned text.
    """
    import re

    normalized = text.replace('\r\n', '\n').replace('\r', '\n')
    trimmed = '\n'.join(line.strip() for line in normalized.split('\n'))
    collapsed = re.sub(r'\n{3,}', '\n\n', trimmed)
    return collapsed.strip()
@staticmethod @staticmethod
def get_text_stats(text: str) -> dict: def get_text_stats(text: str) -> dict:
"""获取文本统计信息""" """Return basic text statistics: total chars, lines, and words."""
return { return {
"total_chars": len(text), "total_chars": len(text),
"total_lines": text.count('\n') + 1, "total_lines": text.count('\n') + 1,

View File

@ -1,6 +1,4 @@
""" """Backend utilities package."""
工具模块
"""
from .file_parser import FileParser from .file_parser import FileParser
from .llm_client import LLMClient from .llm_client import LLMClient

View File

@ -1,6 +1,6 @@
""" """File parsing utilities.
文件解析工具
支持PDFMarkdownTXT文件的文本提取 Supports text extraction from PDF, Markdown, and plain-text files.
""" """
import os import os
@ -9,30 +9,27 @@ from typing import List, Optional
def _read_text_with_fallback(file_path: str) -> str: def _read_text_with_fallback(file_path: str) -> str:
""" """Read a text file, falling back through encoding detectors when UTF-8 fails.
读取文本文件UTF-8失败时自动探测编码
Multi-stage fallback strategy:
采用多级回退策略 1. Try UTF-8 first.
1. 首先尝试 UTF-8 解码 2. Use ``charset_normalizer`` to detect the encoding.
2. 使用 charset_normalizer 检测编码 3. Fall back to ``chardet``.
3. 回退到 chardet 检测编码 4. Last resort: decode with UTF-8 + ``errors='replace'``.
4. 最终使用 UTF-8 + errors='replace' 兜底
Args: Args:
file_path: 文件路径 file_path: Path to the file to read.
Returns: Returns:
解码后的文本内容 The decoded text content.
""" """
data = Path(file_path).read_bytes() data = Path(file_path).read_bytes()
# 首先尝试 UTF-8
try: try:
return data.decode('utf-8') return data.decode('utf-8')
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
# 尝试使用 charset_normalizer 检测编码
encoding = None encoding = None
try: try:
from charset_normalizer import from_bytes from charset_normalizer import from_bytes
@ -41,8 +38,7 @@ def _read_text_with_fallback(file_path: str) -> str:
encoding = best.encoding encoding = best.encoding
except Exception: except Exception:
pass pass
# 回退到 chardet
if not encoding: if not encoding:
try: try:
import chardet import chardet
@ -50,89 +46,86 @@ def _read_text_with_fallback(file_path: str) -> str:
encoding = result.get('encoding') if result else None encoding = result.get('encoding') if result else None
except Exception: except Exception:
pass pass
# 最终兜底:使用 UTF-8 + replace
if not encoding: if not encoding:
encoding = 'utf-8' encoding = 'utf-8'
return data.decode(encoding, errors='replace') return data.decode(encoding, errors='replace')
class FileParser: class FileParser:
"""文件解析器""" """Parser for the supported document formats."""
SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'} SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}
@classmethod @classmethod
def extract_text(cls, file_path: str) -> str:
    """Extract plain text from one file, dispatching on its extension.

    Args:
        file_path: Path to the file.

    Returns:
        The text produced by the extractor that handles the extension.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the lower-cased extension is not listed in
            ``SUPPORTED_EXTENSIONS`` or has no extractor.
    """
    target = Path(file_path)
    if not target.exists():
        raise FileNotFoundError(f"文件不存在: {file_path}")

    suffix = target.suffix.lower()
    if suffix not in cls.SUPPORTED_EXTENSIONS:
        raise ValueError(f"不支持的文件格式: {suffix}")

    # Extension -> extractor name; looked up on ``cls`` only when needed.
    extractor_names = {
        '.pdf': '_extract_from_pdf',
        '.md': '_extract_from_md',
        '.markdown': '_extract_from_md',
        '.txt': '_extract_from_txt',
    }
    extractor_name = extractor_names.get(suffix)
    if extractor_name is None:
        # Reached only if an extension is declared supported without a handler.
        raise ValueError(f"无法处理的文件格式: {suffix}")
    return getattr(cls, extractor_name)(file_path)
@staticmethod @staticmethod
def _extract_from_pdf(file_path: str) -> str:
    """Extract text from a PDF file using PyMuPDF (``fitz``).

    Pages whose text is empty or whitespace-only are skipped; the remaining
    page texts are joined with a blank line between them.

    Args:
        file_path: Path to the PDF file.

    Returns:
        The concatenated page text.

    Raises:
        ImportError: If PyMuPDF is not installed.
    """
    try:
        import fitz  # PyMuPDF
    except ImportError:
        raise ImportError("需要安装PyMuPDF: pip install PyMuPDF")

    with fitz.open(file_path) as pdf:
        page_texts = [page.get_text() for page in pdf]

    return "\n\n".join(content for content in page_texts if content.strip())
@staticmethod @staticmethod
def _extract_from_md(file_path: str) -> str:
    """Read a Markdown file as text via ``_read_text_with_fallback``.

    Args:
        file_path: Path to the Markdown file.

    Returns:
        The decoded file content.
    """
    markdown_text = _read_text_with_fallback(file_path)
    return markdown_text
@staticmethod @staticmethod
def _extract_from_txt(file_path: str) -> str:
    """Read a plain-text file as text via ``_read_text_with_fallback``.

    Args:
        file_path: Path to the text file.

    Returns:
        The decoded file content.
    """
    plain_text = _read_text_with_fallback(file_path)
    return plain_text
@classmethod @classmethod
def extract_from_multiple(cls, file_paths: List[str]) -> str: def extract_from_multiple(cls, file_paths: List[str]) -> str:
""" """Extract and concatenate text from multiple files.
从多个文件提取文本并合并
Args: Args:
file_paths: 文件路径列表 file_paths: Paths of files to read.
Returns: Returns:
合并后的文本 The merged text, with per-file headers separating each section.
""" """
all_texts = [] all_texts = []
for i, file_path in enumerate(file_paths, 1): for i, file_path in enumerate(file_paths, 1):
try: try:
text = cls.extract_text(file_path) text = cls.extract_text(file_path)
@ -140,50 +133,48 @@ class FileParser:
all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}") all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}")
except Exception as e: except Exception as e:
all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===") all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===")
return "\n\n".join(all_texts) return "\n\n".join(all_texts)
def split_text_into_chunks(
    text: str,
    chunk_size: int = 500,
    overlap: int = 50
) -> List[str]:
    """Split text into overlapping chunks, preferring sentence boundaries.

    Text no longer than ``chunk_size`` is returned as a single (unstripped)
    chunk, or as an empty list when it is blank. Longer text is cut into
    windows of at most ``chunk_size`` characters. Each window that is not
    the last one is shortened to end just after the last sentence terminator
    found past 30% of ``chunk_size``, when such a terminator exists. Chunks
    are stripped, and empty chunks are dropped.

    Args:
        text: The source text to split.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Number of characters shared between consecutive chunks.

    Returns:
        A list of chunk strings.

    Raises:
        ValueError: If ``chunk_size`` is not positive and the text is
            non-empty (the split could never make progress).
    """
    if len(text) <= chunk_size:
        return [text] if text.strip() else []

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Tried in priority order; the first separator found far enough into the
    # window wins. An empty-string separator must never appear here because
    # ``rfind('')`` always matches at the end of the window, which would
    # disable boundary detection entirely.
    separators = ('。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ')
    min_boundary = chunk_size * 0.3
    text_length = len(text)

    chunks = []
    start = 0

    while start < text_length:
        end = start + chunk_size

        # Prefer splitting on a sentence boundary near the end of the window.
        if end < text_length:
            window = text[start:end]
            for sep in separators:
                last_sep = window.rfind(sep)
                if last_sep != -1 and last_sep > min_boundary:
                    end = start + last_sep + len(sep)
                    break

        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)

        if end >= text_length:
            break

        # Step back by ``overlap`` for the next chunk, but always move
        # forward: a large overlap would otherwise stall or rewind ``start``
        # and the loop would never terminate.
        next_start = end - overlap
        if next_start <= start:
            next_start = end
        start = next_start

    return chunks

View File

@ -1,6 +1,6 @@
""" """LLM client wrapper.
LLM客户端封装
统一使用OpenAI格式调用 All providers are called through the OpenAI-compatible API surface.
""" """
import json import json
@ -13,7 +13,7 @@ from ..config import Config
class LLMClient: class LLMClient:
"""LLM客户端""" """Thin wrapper around the OpenAI-compatible chat completions API."""
def __init__( def __init__(
self, self,
@ -37,17 +37,16 @@ class LLMClient:
max_tokens: int = 4096, max_tokens: int = 4096,
response_format: Optional[Dict] = None, response_format: Optional[Dict] = None,
) -> str: ) -> str:
""" """Send a chat completion request.
发送聊天请求
Args: Args:
messages: 消息列表 messages: Chat messages in OpenAI format.
temperature: 温度参数 temperature: Sampling temperature.
max_tokens: 最大token数 max_tokens: Maximum number of tokens to generate.
response_format: 响应格式如JSON模式 response_format: Optional response format hint (e.g. JSON mode).
Returns: Returns:
模型响应文本 The assistant's response text.
""" """
kwargs = { kwargs = {
"model": self.model, "model": self.model,
@ -61,7 +60,7 @@ class LLMClient:
response = self.client.chat.completions.create(**kwargs) response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除 # Some reasoning models (e.g. MiniMax M2.5) embed <think>...</think> blocks; strip them.
content = re.sub(r"<think>[\s\S]*?</think>", "", content).strip() content = re.sub(r"<think>[\s\S]*?</think>", "", content).strip()
return content return content
@ -79,7 +78,7 @@ class LLMClient:
messages=messages, temperature=temperature, max_tokens=max_tokens messages=messages, temperature=temperature, max_tokens=max_tokens
) )
# 清理markdown代码块标记 # Strip surrounding markdown code-fence markers if present.
cleaned_response = response.strip() cleaned_response = response.strip()
cleaned_response = re.sub( cleaned_response = re.sub(
r"^```(?:json)?\s*\n?", "", cleaned_response, flags=re.IGNORECASE r"^```(?:json)?\s*\n?", "", cleaned_response, flags=re.IGNORECASE

View File

@ -1,6 +1,7 @@
""" """Logger configuration module.
日志配置模块
提供统一的日志管理同时输出到控制台和文件 Provides unified logging that writes simultaneously to the console and a
rotating log file.
""" """
import os import os
@ -11,59 +12,55 @@ from logging.handlers import RotatingFileHandler
def _ensure_utf8_stdout():
    """Switch stdout and stderr to UTF-8 when running on Windows.

    On any other platform this is a no-op. A stream is reconfigured (with
    ``errors='replace'``) only if it exposes a ``reconfigure`` method.
    """
    if sys.platform != 'win32':
        return

    for stream in (sys.stdout, sys.stderr):
        if hasattr(stream, 'reconfigure'):
            stream.reconfigure(encoding='utf-8', errors='replace')
# 日志目录 # Directory that holds rotated log files.
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs') LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger: def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
""" """Configure and return a logger.
设置日志器
Args: Args:
name: 日志器名称 name: Logger name.
level: 日志级别 level: Minimum log level for the logger.
Returns: Returns:
配置好的日志器 The configured logger.
""" """
# 确保日志目录存在
os.makedirs(LOG_DIR, exist_ok=True) os.makedirs(LOG_DIR, exist_ok=True)
# 创建日志器
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(level) logger.setLevel(level)
# 阻止日志向上传播到根 logger避免重复输出 # Prevent propagation to the root logger to avoid duplicate output.
logger.propagate = False logger.propagate = False
# 如果已经有处理器,不重复添加 # If handlers are already attached, do not re-add them.
if logger.handlers: if logger.handlers:
return logger return logger
# 日志格式
detailed_formatter = logging.Formatter( detailed_formatter = logging.Formatter(
'[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s', '[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S' datefmt='%Y-%m-%d %H:%M:%S'
) )
simple_formatter = logging.Formatter( simple_formatter = logging.Formatter(
'[%(asctime)s] %(levelname)s: %(message)s', '[%(asctime)s] %(levelname)s: %(message)s',
datefmt='%H:%M:%S' datefmt='%H:%M:%S'
) )
# 1. 文件处理器 - 详细日志(按日期命名,带轮转) # 1. File handler — detailed log, named by date and rotated by size.
log_filename = datetime.now().strftime('%Y-%m-%d') + '.log' log_filename = datetime.now().strftime('%Y-%m-%d') + '.log'
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
os.path.join(LOG_DIR, log_filename), os.path.join(LOG_DIR, log_filename),
@ -73,30 +70,28 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
) )
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(detailed_formatter) file_handler.setFormatter(detailed_formatter)
# 2. 控制台处理器 - 简洁日志INFO及以上 # 2. Console handler — concise log, INFO and above.
# 确保 Windows 下使用 UTF-8 编码,避免中文乱码 # Ensure UTF-8 on Windows so non-ASCII characters render correctly.
_ensure_utf8_stdout() _ensure_utf8_stdout()
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
console_handler.setFormatter(simple_formatter) console_handler.setFormatter(simple_formatter)
# 添加处理器
logger.addHandler(file_handler) logger.addHandler(file_handler)
logger.addHandler(console_handler) logger.addHandler(console_handler)
return logger return logger
def get_logger(name: str = 'mirofish') -> logging.Logger: def get_logger(name: str = 'mirofish') -> logging.Logger:
""" """Return an existing logger by name, creating it lazily if needed.
获取日志器如果不存在则创建
Args: Args:
name: 日志器名称 name: Logger name.
Returns: Returns:
日志器实例 The logger instance.
""" """
logger = logging.getLogger(name) logger = logging.getLogger(name)
if not logger.handlers: if not logger.handlers:
@ -104,11 +99,11 @@ def get_logger(name: str = 'mirofish') -> logging.Logger:
return logger return logger
# 创建默认日志器 # Default module-level logger.
logger = setup_logger() logger = setup_logger()
# 便捷方法 # Convenience module-level helpers.
def debug(msg, *args, **kwargs):
    """Log ``msg`` at DEBUG level on the module-level default logger.

    ``stacklevel`` defaults to 2 so that the ``funcName`` and ``lineno``
    fields of the emitted record describe the caller of this helper rather
    than the helper itself. An explicitly passed ``stacklevel`` is kept.

    Args:
        msg: Message or %-style format string.
        *args: Arguments merged into ``msg`` by the logging machinery.
        **kwargs: Extra keyword arguments forwarded to ``Logger.debug``.
    """
    kwargs.setdefault('stacklevel', 2)
    logger.debug(msg, *args, **kwargs)

View File

@ -1,6 +1,7 @@
""" """API call retry primitives.
API调用重试机制
用于处理LLM等外部API调用的重试逻辑 Helpers for retrying calls to external APIs (LLMs, etc.) with exponential
backoff and jitter.
""" """
import time import time
@ -21,18 +22,17 @@ def retry_with_backoff(
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
on_retry: Optional[Callable[[Exception, int], None]] = None on_retry: Optional[Callable[[Exception, int], None]] = None
): ):
""" """Decorator that retries a callable with exponential backoff.
带指数退避的重试装饰器
Args: Args:
max_retries: 最大重试次数 max_retries: Maximum number of retries before giving up.
initial_delay: 初始延迟 initial_delay: Initial delay in seconds before the first retry.
max_delay: 最大延迟 max_delay: Cap on the delay between retries (seconds).
backoff_factor: 退避因子 backoff_factor: Multiplicative factor applied to the delay each retry.
jitter: 是否添加随机抖动 jitter: When ``True``, randomize the delay to avoid thundering herd.
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry.
on_retry: 重试时的回调函数 (exception, retry_count) on_retry: Optional callback invoked on each retry as ``(exception, retry_count)``.
Usage: Usage:
@retry_with_backoff(max_retries=3) @retry_with_backoff(max_retries=3)
def call_llm_api(): def call_llm_api():
@ -55,7 +55,7 @@ def retry_with_backoff(
logger.error(f"函数 {func.__name__}{max_retries} 次重试后仍失败: {str(e)}") logger.error(f"函数 {func.__name__}{max_retries} 次重试后仍失败: {str(e)}")
raise raise
# 计算延迟 # Compute the next delay, capped at ``max_delay``.
current_delay = min(delay, max_delay) current_delay = min(delay, max_delay)
if jitter: if jitter:
current_delay = current_delay * (0.5 + random.random()) current_delay = current_delay * (0.5 + random.random())
@ -86,9 +86,7 @@ def retry_with_backoff_async(
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
on_retry: Optional[Callable[[Exception, int], None]] = None on_retry: Optional[Callable[[Exception, int], None]] = None
): ):
""" """Async variant of :func:`retry_with_backoff`."""
异步版本的重试装饰器
"""
import asyncio import asyncio
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@ -130,9 +128,7 @@ def retry_with_backoff_async(
class RetryableAPIClient: class RetryableAPIClient:
""" """Class-based wrapper around the retry helpers."""
可重试的API客户端封装
"""
def __init__( def __init__(
self, self,
@ -153,17 +149,16 @@ class RetryableAPIClient:
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
**kwargs **kwargs
) -> Any: ) -> Any:
""" """Invoke ``func`` with retry on failure.
执行函数调用并在失败时重试
Args: Args:
func: 要调用的函数 func: Callable to invoke.
*args: 函数参数 *args: Positional arguments forwarded to ``func``.
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry.
**kwargs: 函数关键字参数 **kwargs: Keyword arguments forwarded to ``func``.
Returns: Returns:
函数返回值 The value returned by ``func``.
""" """
last_exception = None last_exception = None
delay = self.initial_delay delay = self.initial_delay
@ -199,17 +194,17 @@ class RetryableAPIClient:
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
continue_on_failure: bool = True continue_on_failure: bool = True
) -> Tuple[list, list]: ) -> Tuple[list, list]:
""" """Process ``items`` in sequence, retrying each independently on failure.
批量调用并对每个失败项单独重试
Args: Args:
items: 要处理的项目列表 items: Items to process.
process_func: 处理函数接收单个item作为参数 process_func: Callable invoked once per item.
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry.
continue_on_failure: 单项失败后是否继续处理其他项 continue_on_failure: When ``True``, keep processing remaining items after a failure.
Returns: Returns:
(成功结果列表, 失败项列表) ``(successes, failures)`` — a list of successful results and a list
of failure descriptors ``{"index", "item", "error"}``.
""" """
results = [] results = []
failures = [] failures = []

View File

@ -1,7 +1,8 @@
"""Zep Graph 分页读取工具。 """Zep Graph paging helpers.
Zep node/edge 列表接口使用 UUID cursor 分页 Zep's node/edge list APIs paginate with a UUID cursor. This module wraps the
本模块封装自动翻页逻辑含单页重试对调用方透明地返回完整列表 auto-paging loop (including per-page retry) so callers see the full list
transparently.
""" """
from __future__ import annotations from __future__ import annotations
@ -30,7 +31,7 @@ def _fetch_page_with_retry(
page_description: str = "page", page_description: str = "page",
**kwargs: Any, **kwargs: Any,
) -> list[Any]: ) -> list[Any]:
"""单页请求失败时指数退避重试。自动处理429限速。""" """Fetch one page, retrying with exponential backoff. Handles 429 rate limits."""
if max_retries < 1: if max_retries < 1:
raise ValueError("max_retries must be >= 1") raise ValueError("max_retries must be >= 1")
@ -43,7 +44,7 @@ def _fetch_page_with_retry(
except Exception as e: except Exception as e:
last_exception = e last_exception = e
if attempt < max_retries - 1: if attempt < max_retries - 1:
# 检测429限速使用retry-after头部指定的等待时间 # NOTE(review): intended to honor the retry-after header on 429 rate limits, but the wait below is just the current backoff delay; confirm.
wait = delay wait = delay
logger.warning( logger.warning(
f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {wait:.1f}s..." f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {wait:.1f}s..."
@ -65,7 +66,7 @@ def fetch_all_nodes(
max_retries: int = _DEFAULT_MAX_RETRIES, max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY, retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]: ) -> list[Any]:
"""分页获取图谱节点,最多返回 max_items 条(默认 2000。每页请求自带重试。""" """Page through graph nodes; return at most ``max_items`` (default 2000). Each page is retried internally."""
all_nodes: list[Any] = [] all_nodes: list[Any] = []
cursor: str | None = None cursor: str | None = None
page_num = 0 page_num = 0
@ -110,7 +111,7 @@ def fetch_all_edges(
max_retries: int = _DEFAULT_MAX_RETRIES, max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY, retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]: ) -> list[Any]:
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。""" """Page through every graph edge and return the full list. Each page is retried internally."""
all_edges: list[Any] = [] all_edges: list[Any] = []
cursor: str | None = None cursor: str | None = None
page_num = 0 page_num = 0