fix(graph): harden ontology normalization for Zep limits
This commit is contained in:
parent
25d43f8a4b
commit
a026178d67
|
|
@ -210,74 +210,111 @@ class GraphBuilderService:
|
|||
|
||||
# Zep 保留名称,不能作为属性名
|
||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||||
|
||||
|
||||
def safe_attr_name(attr_name: str) -> str:
|
||||
"""将保留名称转换为安全名称"""
|
||||
if attr_name.lower() in RESERVED_NAMES:
|
||||
return f"entity_{attr_name}"
|
||||
return attr_name
|
||||
|
||||
|
||||
def normalize_attributes(raw_attributes: Any) -> List[Dict[str, str]]:
|
||||
normalized: List[Dict[str, str]] = []
|
||||
for attr_def in raw_attributes or []:
|
||||
if isinstance(attr_def, str):
|
||||
attr_def = {"name": attr_def, "description": attr_def}
|
||||
if not isinstance(attr_def, dict):
|
||||
continue
|
||||
|
||||
attr_name = str(attr_def.get("name", "")).strip()
|
||||
if not attr_name:
|
||||
continue
|
||||
|
||||
normalized.append({
|
||||
"name": attr_name,
|
||||
"description": str(attr_def.get("description") or attr_name),
|
||||
})
|
||||
return normalized
|
||||
|
||||
def normalize_source_targets(raw_source_targets: Any) -> List[EntityEdgeSourceTarget]:
|
||||
normalized: List[EntityEdgeSourceTarget] = []
|
||||
for source_target in raw_source_targets or []:
|
||||
if not isinstance(source_target, dict):
|
||||
continue
|
||||
|
||||
normalized.append(
|
||||
EntityEdgeSourceTarget(
|
||||
source=str(source_target.get("source", "Entity")) or "Entity",
|
||||
target=str(source_target.get("target", "Entity")) or "Entity",
|
||||
)
|
||||
)
|
||||
|
||||
# Zep API allows max 10 source_targets per edge type.
|
||||
return normalized[:10]
|
||||
|
||||
# 动态创建实体类型
|
||||
entity_types = {}
|
||||
for entity_def in ontology.get("entity_types", []):
|
||||
name = entity_def["name"]
|
||||
if not isinstance(entity_def, dict):
|
||||
continue
|
||||
|
||||
name = str(entity_def.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
description = entity_def.get("description", f"A {name} entity.")
|
||||
|
||||
|
||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
for attr_def in entity_def.get("attributes", []):
|
||||
|
||||
for attr_def in normalize_attributes(entity_def.get("attributes", [])):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
||||
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
|
||||
# 动态创建类
|
||||
entity_class = type(name, (EntityModel,), attrs)
|
||||
entity_class.__doc__ = description
|
||||
entity_types[name] = entity_class
|
||||
|
||||
|
||||
# 动态创建边类型
|
||||
edge_definitions = {}
|
||||
for edge_def in ontology.get("edge_types", []):
|
||||
name = edge_def["name"]
|
||||
if not isinstance(edge_def, dict):
|
||||
continue
|
||||
|
||||
name = str(edge_def.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
description = edge_def.get("description", f"A {name} relationship.")
|
||||
|
||||
|
||||
# 创建属性字典和类型注解
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
for attr_def in edge_def.get("attributes", []):
|
||||
|
||||
for attr_def in normalize_attributes(edge_def.get("attributes", [])):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
||||
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
|
||||
# 动态创建类
|
||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
||||
edge_class.__doc__ = description
|
||||
|
||||
# 构建source_targets
|
||||
source_targets = []
|
||||
for st in edge_def.get("source_targets", []):
|
||||
source_targets.append(
|
||||
EntityEdgeSourceTarget(
|
||||
source=st.get("source", "Entity"),
|
||||
target=st.get("target", "Entity")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
source_targets = normalize_source_targets(edge_def.get("source_targets", []))
|
||||
if source_targets:
|
||||
edge_definitions[name] = (edge_class, source_targets)
|
||||
|
||||
|
||||
# 调用Zep API设置本体
|
||||
if entity_types or edge_definitions:
|
||||
self.backend.set_ontology(
|
||||
|
|
|
|||
|
|
@ -256,38 +256,84 @@ class OntologyGenerator:
|
|||
|
||||
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""验证和后处理结果"""
|
||||
|
||||
|
||||
if not isinstance(result, dict):
|
||||
result = {}
|
||||
|
||||
# 确保必要字段存在
|
||||
if "entity_types" not in result:
|
||||
if not isinstance(result.get("entity_types"), list):
|
||||
result["entity_types"] = []
|
||||
if "edge_types" not in result:
|
||||
if not isinstance(result.get("edge_types"), list):
|
||||
result["edge_types"] = []
|
||||
if "analysis_summary" not in result:
|
||||
result["analysis_summary"] = ""
|
||||
|
||||
|
||||
# 验证实体类型
|
||||
validated_entities = []
|
||||
for entity in result["entity_types"]:
|
||||
if "attributes" not in entity:
|
||||
entity["attributes"] = []
|
||||
if "examples" not in entity:
|
||||
entity["examples"] = []
|
||||
# 确保description不超过100字符
|
||||
if len(entity.get("description", "")) > 100:
|
||||
entity["description"] = entity["description"][:97] + "..."
|
||||
|
||||
if isinstance(entity, str):
|
||||
entity = {"name": entity, "description": f"Entity type: {entity}"}
|
||||
if not isinstance(entity, dict):
|
||||
continue
|
||||
|
||||
name = str(entity.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
attributes = entity.get("attributes")
|
||||
if not isinstance(attributes, list):
|
||||
attributes = []
|
||||
|
||||
examples = entity.get("examples")
|
||||
if not isinstance(examples, list):
|
||||
examples = []
|
||||
|
||||
normalized = dict(entity)
|
||||
normalized["name"] = name
|
||||
normalized["attributes"] = attributes
|
||||
normalized["examples"] = examples
|
||||
if len(normalized.get("description", "")) > 100:
|
||||
normalized["description"] = normalized["description"][:97] + "..."
|
||||
|
||||
validated_entities.append(normalized)
|
||||
|
||||
result["entity_types"] = validated_entities
|
||||
|
||||
# 验证关系类型
|
||||
validated_edges = []
|
||||
for edge in result["edge_types"]:
|
||||
if "source_targets" not in edge:
|
||||
edge["source_targets"] = []
|
||||
if "attributes" not in edge:
|
||||
edge["attributes"] = []
|
||||
if len(edge.get("description", "")) > 100:
|
||||
edge["description"] = edge["description"][:97] + "..."
|
||||
|
||||
if isinstance(edge, str):
|
||||
edge = {"name": edge, "description": f"Relationship type: {edge}"}
|
||||
if not isinstance(edge, dict):
|
||||
continue
|
||||
|
||||
name = str(edge.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
source_targets = edge.get("source_targets")
|
||||
if not isinstance(source_targets, list):
|
||||
source_targets = []
|
||||
|
||||
attributes = edge.get("attributes")
|
||||
if not isinstance(attributes, list):
|
||||
attributes = []
|
||||
|
||||
normalized = dict(edge)
|
||||
normalized["name"] = name
|
||||
normalized["source_targets"] = source_targets
|
||||
normalized["attributes"] = attributes
|
||||
if len(normalized.get("description", "")) > 100:
|
||||
normalized["description"] = normalized["description"][:97] + "..."
|
||||
|
||||
validated_edges.append(normalized)
|
||||
|
||||
result["edge_types"] = validated_edges
|
||||
|
||||
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
|
||||
MAX_ENTITY_TYPES = 10
|
||||
MAX_EDGE_TYPES = 10
|
||||
|
||||
|
||||
# 兜底类型定义
|
||||
person_fallback = {
|
||||
"name": "Person",
|
||||
|
|
@ -298,7 +344,7 @@ class OntologyGenerator:
|
|||
],
|
||||
"examples": ["ordinary citizen", "anonymous netizen"]
|
||||
}
|
||||
|
||||
|
||||
organization_fallback = {
|
||||
"name": "Organization",
|
||||
"description": "Any organization not fitting other specific organization types.",
|
||||
|
|
@ -308,40 +354,40 @@ class OntologyGenerator:
|
|||
],
|
||||
"examples": ["small business", "community group"]
|
||||
}
|
||||
|
||||
|
||||
# 检查是否已有兜底类型
|
||||
entity_names = {e["name"] for e in result["entity_types"]}
|
||||
has_person = "Person" in entity_names
|
||||
has_organization = "Organization" in entity_names
|
||||
|
||||
|
||||
# 需要添加的兜底类型
|
||||
fallbacks_to_add = []
|
||||
if not has_person:
|
||||
fallbacks_to_add.append(person_fallback)
|
||||
if not has_organization:
|
||||
fallbacks_to_add.append(organization_fallback)
|
||||
|
||||
|
||||
if fallbacks_to_add:
|
||||
current_count = len(result["entity_types"])
|
||||
needed_slots = len(fallbacks_to_add)
|
||||
|
||||
|
||||
# 如果添加后会超过 10 个,需要移除一些现有类型
|
||||
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
||||
# 计算需要移除多少个
|
||||
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
||||
# 从末尾移除(保留前面更重要的具体类型)
|
||||
result["entity_types"] = result["entity_types"][:-to_remove]
|
||||
|
||||
|
||||
# 添加兜底类型
|
||||
result["entity_types"].extend(fallbacks_to_add)
|
||||
|
||||
|
||||
# 最终确保不超过限制(防御性编程)
|
||||
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
||||
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
||||
|
||||
|
||||
if len(result["edge_types"]) > MAX_EDGE_TYPES:
|
||||
result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue