diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 37c9969c..b7877cb6 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -17,6 +17,7 @@ from ..config import Config from ..models.task import TaskManager, TaskStatus from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor +from .ontology_schema import normalize_ontology_schema from ..utils.locale import t, get_locale, set_locale @@ -208,6 +209,8 @@ class GraphBuilderService: from typing import Optional from pydantic import Field from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel + + ontology = normalize_ontology_schema(ontology) # 抑制 Pydantic v2 关于 Field(default=None) 的警告 # 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 @@ -503,4 +506,3 @@ class GraphBuilderService: def delete_graph(self, graph_id: str): """删除图谱""" self.client.graph.delete(graph_id=graph_id) - diff --git a/backend/app/services/ontology_generator.py b/backend/app/services/ontology_generator.py index 01a3d799..f08a4d79 100644 --- a/backend/app/services/ontology_generator.py +++ b/backend/app/services/ontology_generator.py @@ -9,6 +9,7 @@ import re from typing import Dict, Any, List, Optional from ..utils.llm_client import LLMClient from ..utils.locale import get_language_instruction +from .ontology_schema import normalize_ontology_schema logger = logging.getLogger(__name__) @@ -276,6 +277,7 @@ class OntologyGenerator: def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: """验证和后处理结果""" + result = normalize_ontology_schema(result) # 确保必要字段存在 if "entity_types" not in result: @@ -407,6 +409,8 @@ class OntologyGenerator: Returns: Python代码字符串 """ + ontology = normalize_ontology_schema(ontology) + code_lines = [ '"""', '自定义实体类型定义', @@ -503,4 +507,3 @@ class OntologyGenerator: code_lines.append('}') return '\n'.join(code_lines) - diff --git 
"""
Ontology schema normalization helpers.

Ontology payloads (typically produced by an LLM) are coerced into the shape
the rest of the pipeline expects: every entity/edge type carries an
``attributes`` list of ``{"name", "type", "description"}`` dictionaries.
"""

import logging
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

_DEFAULT_ATTRIBUTE_TYPE = "text"


def _text_or_default(value: Any, default: str) -> str:
    """Return ``value`` as a stripped string, or ``default`` when it is empty.

    Falsy values (``None``, ``""``, ``0`` ...) and whitespace-only strings all
    resolve to ``default`` so callers never end up with a blank field.
    """
    text = str(value).strip() if value else ""
    return text or default


def normalize_attribute_definitions(
    attribute_defs: Any,
    owner_label: str,
) -> List[Dict[str, str]]:
    """Normalize attribute definitions into the name/type/description shape.

    Args:
        attribute_defs: Raw attribute payload. Accepted shapes are a list of
            ``{"name", "type", "description"}`` dicts, a list of legacy
            ``{attr_name: attr_description}`` maps, or a single dict of either
            kind. Anything else is rejected with a warning.
        owner_label: Human-readable owner (e.g. ``"entity 'Founder'"``) used
            only in log messages.

    Returns:
        A new list of ``{"name", "type", "description"}`` dicts with string
        values. Entries with an empty or missing name are dropped, and only the
        first occurrence of each name is kept.
    """
    if not attribute_defs:
        return []

    # A lone mapping is treated as a one-element list.
    if isinstance(attribute_defs, dict):
        attribute_defs = [attribute_defs]

    if not isinstance(attribute_defs, list):
        logger.warning(
            "Invalid attribute definitions for %s: expected list, got %s",
            owner_label,
            type(attribute_defs).__name__,
        )
        return []

    normalized: List[Dict[str, str]] = []
    seen_names = set()

    for index, attr_def in enumerate(attribute_defs):
        if not isinstance(attr_def, dict):
            logger.warning(
                "Skipping invalid attribute definition for %s at index %s: expected dict, got %s",
                owner_label,
                index,
                type(attr_def).__name__,
            )
            continue

        # Legacy payloads store one mapping of {attr_name: attr_description}.
        if "name" not in attr_def:
            logger.warning(
                "Normalizing legacy attribute map for %s at index %s",
                owner_label,
                index,
            )
            candidate_items = [
                {
                    "name": attr_name,
                    "type": _DEFAULT_ATTRIBUTE_TYPE,
                    "description": attr_desc,
                }
                for attr_name, attr_desc in attr_def.items()
            ]
        else:
            candidate_items = [attr_def]

        for candidate in candidate_items:
            # An explicit ``None`` name must count as missing rather than be
            # stringified into an attribute literally called "None".
            raw_name = candidate.get("name")
            attr_name = "" if raw_name is None else str(raw_name).strip()
            if not attr_name:
                logger.warning(
                    "Skipping attribute with empty name for %s at index %s",
                    owner_label,
                    index,
                )
                continue

            if attr_name in seen_names:
                logger.warning(
                    "Duplicate attribute '%s' removed for %s",
                    attr_name,
                    owner_label,
                )
                continue

            seen_names.add(attr_name)
            normalized.append(
                {
                    "name": attr_name,
                    "type": _text_or_default(
                        candidate.get("type"), _DEFAULT_ATTRIBUTE_TYPE
                    ),
                    # Fall back to the attribute name so the description is
                    # never blank, even for whitespace-only input.
                    "description": _text_or_default(
                        candidate.get("description"), attr_name
                    ),
                }
            )

    return normalized


def _normalize_type_definitions(definitions: Any, kind: str) -> List[Dict[str, Any]]:
    """Return shallow copies of the entity/edge definitions with clean attributes.

    ``kind`` is ``"entity"`` or ``"edge"`` and is used only for log messages.
    A missing or ``None`` collection yields an empty list; a collection that is
    not a list or tuple is rejected with a warning instead of raising.
    """
    if definitions is None:
        return []

    if not isinstance(definitions, (list, tuple)):
        logger.warning(
            "Invalid %s definitions: expected list, got %s",
            kind,
            type(definitions).__name__,
        )
        return []

    normalized: List[Dict[str, Any]] = []
    for definition in definitions:
        if not isinstance(definition, dict):
            logger.warning(
                "Skipping invalid %s definition: expected dict, got %s",
                kind,
                type(definition).__name__,
            )
            continue

        # Shallow copy so the caller's definition is never mutated.
        normalized_definition = dict(definition)
        normalized_definition["attributes"] = normalize_attribute_definitions(
            definition.get("attributes", []),
            f"{kind} '{definition.get('name', 'unknown')}'",
        )
        normalized.append(normalized_definition)

    return normalized


def normalize_ontology_schema(ontology: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ontology with normalized entity and edge attributes.

    Args:
        ontology: Ontology payload with optional ``entity_types`` and
            ``edge_types`` collections. Other top-level keys are passed through
            unchanged.

    Returns:
        A new dict whose ``entity_types`` and ``edge_types`` are always lists
        of dicts, each with a normalized ``attributes`` list. A non-dict
        payload yields ``{}``. The input is not mutated.
    """
    if not isinstance(ontology, dict):
        logger.warning(
            "Invalid ontology payload: expected dict, got %s",
            type(ontology).__name__,
        )
        return {}

    normalized = dict(ontology)
    normalized["entity_types"] = _normalize_type_definitions(
        ontology.get("entity_types"), "entity"
    )
    normalized["edge_types"] = _normalize_type_definitions(
        ontology.get("edge_types"), "edge"
    )
    return normalized
GraphBuilderService.__new__(GraphBuilderService) + builder.client = DummyClient() + + builder.set_ontology("graph_123", legacy_ontology()) + + assert captured["graph_ids"] == ["graph_123"] + assert set(captured["entities"]["Founder"].model_fields.keys()) == { + "description", + "full_name", + "role", + } + + edge_model, source_targets = captured["edges"]["FOUNDS"] + assert set(edge_model.model_fields.keys()) == {"started_at"} + assert len(source_targets) == 1 + assert source_targets[0].source == "Founder" + assert source_targets[0].target == "Organization"