fix(graph): harden ontology normalization for Zep limits
This commit is contained in:
parent
25d43f8a4b
commit
a026178d67
|
|
@ -210,74 +210,111 @@ class GraphBuilderService:
|
||||||
|
|
||||||
# Zep 保留名称,不能作为属性名
|
# Zep 保留名称,不能作为属性名
|
||||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||||||
|
|
||||||
def safe_attr_name(attr_name: str) -> str:
    """Return an attribute name that does not clash with Zep reserved names.

    The reserved-name check is case-insensitive. A clashing name is returned
    with an ``entity_`` prefix (original casing kept); any other name is
    returned unchanged.
    """
    is_reserved = attr_name.lower() in RESERVED_NAMES
    return f"entity_{attr_name}" if is_reserved else attr_name
|
||||||
|
|
||||||
|
def normalize_attributes(raw_attributes: Any) -> List[Dict[str, str]]:
    """Normalize a raw attribute list into ``{"name", "description"}`` dicts.

    Args:
        raw_attributes: Value taken from an ontology definition. Expected to
            be a list whose items are either attribute names (``str``) or
            dicts with a ``name`` and an optional ``description``.

    Returns:
        One dict per usable attribute, in input order. Each dict has a
        stripped, non-empty ``name`` and a string ``description`` (falling
        back to the name when the description is missing or empty).
        Anything that is not a list/tuple yields an empty list; items that
        are neither ``str`` nor ``dict``, or that have no usable name, are
        skipped.
    """
    # A bare string would be iterated character by character and a number
    # would raise TypeError, so only real sequences are accepted.
    if not isinstance(raw_attributes, (list, tuple)):
        return []

    normalized: List[Dict[str, str]] = []
    for attr_def in raw_attributes:
        if isinstance(attr_def, str):
            # Shorthand form: the string is both the name and the description.
            attr_def = {"name": attr_def, "description": attr_def}
        if not isinstance(attr_def, dict):
            continue

        raw_name = attr_def.get("name")
        if raw_name is None:
            # str(None) would produce the bogus attribute name "None".
            continue
        attr_name = str(raw_name).strip()
        if not attr_name:
            continue

        normalized.append({
            "name": attr_name,
            "description": str(attr_def.get("description") or attr_name),
        })
    return normalized
|
||||||
|
|
||||||
|
def normalize_source_targets(raw_source_targets: Any) -> List[EntityEdgeSourceTarget]:
    """Build the ``EntityEdgeSourceTarget`` list for one edge type.

    Args:
        raw_source_targets: Value taken from an edge definition. Expected to
            be a list of dicts with optional ``source`` and ``target`` keys.

    Returns:
        At most 10 ``EntityEdgeSourceTarget`` objects, in input order.
        Non-dict items are skipped, a missing/``None``/blank endpoint
        becomes ``"Entity"``, and repeated (source, target) pairs are
        emitted only once so they do not use up the limited slots.
        Anything that is not a list/tuple yields an empty list.
    """
    # Zep API allows max 10 source_targets per edge type.
    max_source_targets = 10

    def endpoint(value: Any) -> str:
        # str(None) is the truthy string "None", so a plain
        # `str(value) or "Entity"` would never fall back for None.
        if value is None:
            return "Entity"
        return str(value).strip() or "Entity"

    if not isinstance(raw_source_targets, (list, tuple)):
        return []

    normalized: List[EntityEdgeSourceTarget] = []
    seen_pairs = set()
    for source_target in raw_source_targets:
        if not isinstance(source_target, dict):
            continue

        pair = (
            endpoint(source_target.get("source")),
            endpoint(source_target.get("target")),
        )
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)

        normalized.append(EntityEdgeSourceTarget(source=pair[0], target=pair[1]))
        if len(normalized) == max_source_targets:
            break

    return normalized
|
||||||
|
|
||||||
# 动态创建实体类型
|
# 动态创建实体类型
|
||||||
entity_types = {}
|
entity_types = {}
|
||||||
for entity_def in ontology.get("entity_types", []):
|
for entity_def in ontology.get("entity_types", []):
|
||||||
name = entity_def["name"]
|
if not isinstance(entity_def, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = str(entity_def.get("name", "")).strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
|
||||||
description = entity_def.get("description", f"A {name} entity.")
|
description = entity_def.get("description", f"A {name} entity.")
|
||||||
|
|
||||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
# 创建属性字典和类型注解(Pydantic v2 需要)
|
||||||
attrs = {"__doc__": description}
|
attrs = {"__doc__": description}
|
||||||
annotations = {}
|
annotations = {}
|
||||||
|
|
||||||
for attr_def in entity_def.get("attributes", []):
|
for attr_def in normalize_attributes(entity_def.get("attributes", [])):
|
||||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||||
attr_desc = attr_def.get("description", attr_name)
|
attr_desc = attr_def.get("description", attr_name)
|
||||||
# Zep API 需要 Field 的 description,这是必需的
|
# Zep API 需要 Field 的 description,这是必需的
|
||||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
annotations[attr_name] = Optional[EntityText] # 类型注解
|
||||||
|
|
||||||
attrs["__annotations__"] = annotations
|
attrs["__annotations__"] = annotations
|
||||||
|
|
||||||
# 动态创建类
|
# 动态创建类
|
||||||
entity_class = type(name, (EntityModel,), attrs)
|
entity_class = type(name, (EntityModel,), attrs)
|
||||||
entity_class.__doc__ = description
|
entity_class.__doc__ = description
|
||||||
entity_types[name] = entity_class
|
entity_types[name] = entity_class
|
||||||
|
|
||||||
# 动态创建边类型
|
# 动态创建边类型
|
||||||
edge_definitions = {}
|
edge_definitions = {}
|
||||||
for edge_def in ontology.get("edge_types", []):
|
for edge_def in ontology.get("edge_types", []):
|
||||||
name = edge_def["name"]
|
if not isinstance(edge_def, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = str(edge_def.get("name", "")).strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
|
||||||
description = edge_def.get("description", f"A {name} relationship.")
|
description = edge_def.get("description", f"A {name} relationship.")
|
||||||
|
|
||||||
# 创建属性字典和类型注解
|
# 创建属性字典和类型注解
|
||||||
attrs = {"__doc__": description}
|
attrs = {"__doc__": description}
|
||||||
annotations = {}
|
annotations = {}
|
||||||
|
|
||||||
for attr_def in edge_def.get("attributes", []):
|
for attr_def in normalize_attributes(edge_def.get("attributes", [])):
|
||||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||||
attr_desc = attr_def.get("description", attr_name)
|
attr_desc = attr_def.get("description", attr_name)
|
||||||
# Zep API 需要 Field 的 description,这是必需的
|
# Zep API 需要 Field 的 description,这是必需的
|
||||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
annotations[attr_name] = Optional[str] # 边属性用str类型
|
||||||
|
|
||||||
attrs["__annotations__"] = annotations
|
attrs["__annotations__"] = annotations
|
||||||
|
|
||||||
# 动态创建类
|
# 动态创建类
|
||||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
edge_class = type(class_name, (EdgeModel,), attrs)
|
||||||
edge_class.__doc__ = description
|
edge_class.__doc__ = description
|
||||||
|
|
||||||
# 构建source_targets
|
source_targets = normalize_source_targets(edge_def.get("source_targets", []))
|
||||||
source_targets = []
|
|
||||||
for st in edge_def.get("source_targets", []):
|
|
||||||
source_targets.append(
|
|
||||||
EntityEdgeSourceTarget(
|
|
||||||
source=st.get("source", "Entity"),
|
|
||||||
target=st.get("target", "Entity")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if source_targets:
|
if source_targets:
|
||||||
edge_definitions[name] = (edge_class, source_targets)
|
edge_definitions[name] = (edge_class, source_targets)
|
||||||
|
|
||||||
# 调用Zep API设置本体
|
# 调用Zep API设置本体
|
||||||
if entity_types or edge_definitions:
|
if entity_types or edge_definitions:
|
||||||
self.backend.set_ontology(
|
self.backend.set_ontology(
|
||||||
|
|
|
||||||
|
|
@ -256,38 +256,84 @@ class OntologyGenerator:
|
||||||
|
|
||||||
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""验证和后处理结果"""
|
"""验证和后处理结果"""
|
||||||
|
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
result = {}
|
||||||
|
|
||||||
# 确保必要字段存在
|
# 确保必要字段存在
|
||||||
if "entity_types" not in result:
|
if not isinstance(result.get("entity_types"), list):
|
||||||
result["entity_types"] = []
|
result["entity_types"] = []
|
||||||
if "edge_types" not in result:
|
if not isinstance(result.get("edge_types"), list):
|
||||||
result["edge_types"] = []
|
result["edge_types"] = []
|
||||||
if "analysis_summary" not in result:
|
if "analysis_summary" not in result:
|
||||||
result["analysis_summary"] = ""
|
result["analysis_summary"] = ""
|
||||||
|
|
||||||
# 验证实体类型
|
# 验证实体类型
|
||||||
|
validated_entities = []
|
||||||
for entity in result["entity_types"]:
|
for entity in result["entity_types"]:
|
||||||
if "attributes" not in entity:
|
if isinstance(entity, str):
|
||||||
entity["attributes"] = []
|
entity = {"name": entity, "description": f"Entity type: {entity}"}
|
||||||
if "examples" not in entity:
|
if not isinstance(entity, dict):
|
||||||
entity["examples"] = []
|
continue
|
||||||
# 确保description不超过100字符
|
|
||||||
if len(entity.get("description", "")) > 100:
|
name = str(entity.get("name", "")).strip()
|
||||||
entity["description"] = entity["description"][:97] + "..."
|
if not name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
attributes = entity.get("attributes")
|
||||||
|
if not isinstance(attributes, list):
|
||||||
|
attributes = []
|
||||||
|
|
||||||
|
examples = entity.get("examples")
|
||||||
|
if not isinstance(examples, list):
|
||||||
|
examples = []
|
||||||
|
|
||||||
|
normalized = dict(entity)
|
||||||
|
normalized["name"] = name
|
||||||
|
normalized["attributes"] = attributes
|
||||||
|
normalized["examples"] = examples
|
||||||
|
if len(normalized.get("description", "")) > 100:
|
||||||
|
normalized["description"] = normalized["description"][:97] + "..."
|
||||||
|
|
||||||
|
validated_entities.append(normalized)
|
||||||
|
|
||||||
|
result["entity_types"] = validated_entities
|
||||||
|
|
||||||
# 验证关系类型
|
# 验证关系类型
|
||||||
|
validated_edges = []
|
||||||
for edge in result["edge_types"]:
|
for edge in result["edge_types"]:
|
||||||
if "source_targets" not in edge:
|
if isinstance(edge, str):
|
||||||
edge["source_targets"] = []
|
edge = {"name": edge, "description": f"Relationship type: {edge}"}
|
||||||
if "attributes" not in edge:
|
if not isinstance(edge, dict):
|
||||||
edge["attributes"] = []
|
continue
|
||||||
if len(edge.get("description", "")) > 100:
|
|
||||||
edge["description"] = edge["description"][:97] + "..."
|
name = str(edge.get("name", "")).strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
source_targets = edge.get("source_targets")
|
||||||
|
if not isinstance(source_targets, list):
|
||||||
|
source_targets = []
|
||||||
|
|
||||||
|
attributes = edge.get("attributes")
|
||||||
|
if not isinstance(attributes, list):
|
||||||
|
attributes = []
|
||||||
|
|
||||||
|
normalized = dict(edge)
|
||||||
|
normalized["name"] = name
|
||||||
|
normalized["source_targets"] = source_targets
|
||||||
|
normalized["attributes"] = attributes
|
||||||
|
if len(normalized.get("description", "")) > 100:
|
||||||
|
normalized["description"] = normalized["description"][:97] + "..."
|
||||||
|
|
||||||
|
validated_edges.append(normalized)
|
||||||
|
|
||||||
|
result["edge_types"] = validated_edges
|
||||||
|
|
||||||
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
|
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
|
||||||
MAX_ENTITY_TYPES = 10
|
MAX_ENTITY_TYPES = 10
|
||||||
MAX_EDGE_TYPES = 10
|
MAX_EDGE_TYPES = 10
|
||||||
|
|
||||||
# 兜底类型定义
|
# 兜底类型定义
|
||||||
person_fallback = {
|
person_fallback = {
|
||||||
"name": "Person",
|
"name": "Person",
|
||||||
|
|
@ -298,7 +344,7 @@ class OntologyGenerator:
|
||||||
],
|
],
|
||||||
"examples": ["ordinary citizen", "anonymous netizen"]
|
"examples": ["ordinary citizen", "anonymous netizen"]
|
||||||
}
|
}
|
||||||
|
|
||||||
organization_fallback = {
|
organization_fallback = {
|
||||||
"name": "Organization",
|
"name": "Organization",
|
||||||
"description": "Any organization not fitting other specific organization types.",
|
"description": "Any organization not fitting other specific organization types.",
|
||||||
|
|
@ -308,40 +354,40 @@ class OntologyGenerator:
|
||||||
],
|
],
|
||||||
"examples": ["small business", "community group"]
|
"examples": ["small business", "community group"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查是否已有兜底类型
|
# 检查是否已有兜底类型
|
||||||
entity_names = {e["name"] for e in result["entity_types"]}
|
entity_names = {e["name"] for e in result["entity_types"]}
|
||||||
has_person = "Person" in entity_names
|
has_person = "Person" in entity_names
|
||||||
has_organization = "Organization" in entity_names
|
has_organization = "Organization" in entity_names
|
||||||
|
|
||||||
# 需要添加的兜底类型
|
# 需要添加的兜底类型
|
||||||
fallbacks_to_add = []
|
fallbacks_to_add = []
|
||||||
if not has_person:
|
if not has_person:
|
||||||
fallbacks_to_add.append(person_fallback)
|
fallbacks_to_add.append(person_fallback)
|
||||||
if not has_organization:
|
if not has_organization:
|
||||||
fallbacks_to_add.append(organization_fallback)
|
fallbacks_to_add.append(organization_fallback)
|
||||||
|
|
||||||
if fallbacks_to_add:
|
if fallbacks_to_add:
|
||||||
current_count = len(result["entity_types"])
|
current_count = len(result["entity_types"])
|
||||||
needed_slots = len(fallbacks_to_add)
|
needed_slots = len(fallbacks_to_add)
|
||||||
|
|
||||||
# 如果添加后会超过 10 个,需要移除一些现有类型
|
# 如果添加后会超过 10 个,需要移除一些现有类型
|
||||||
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
||||||
# 计算需要移除多少个
|
# 计算需要移除多少个
|
||||||
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
||||||
# 从末尾移除(保留前面更重要的具体类型)
|
# 从末尾移除(保留前面更重要的具体类型)
|
||||||
result["entity_types"] = result["entity_types"][:-to_remove]
|
result["entity_types"] = result["entity_types"][:-to_remove]
|
||||||
|
|
||||||
# 添加兜底类型
|
# 添加兜底类型
|
||||||
result["entity_types"].extend(fallbacks_to_add)
|
result["entity_types"].extend(fallbacks_to_add)
|
||||||
|
|
||||||
# 最终确保不超过限制(防御性编程)
|
# 最终确保不超过限制(防御性编程)
|
||||||
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
||||||
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
||||||
|
|
||||||
if len(result["edge_types"]) > MAX_EDGE_TYPES:
|
if len(result["edge_types"]) > MAX_EDGE_TYPES:
|
||||||
result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES]
|
result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue