fix: normalize structured LLM profile fields before serialization
This commit is contained in:
parent
985f89f49a
commit
77870cce90
|
|
@ -25,6 +25,47 @@ from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||||
logger = get_logger('mirofish.oasis_profile')
|
logger = get_logger('mirofish.oasis_profile')
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_to_str(value: Any) -> str:
    """Coerce a value to a plain string.

    LLM JSON output sometimes carries a dict or list where a plain string
    was requested; this flattens such values into text.

    Rules, applied in order:
      * ``str``  -> returned unchanged.
      * ``None`` -> ``''`` (not the literal text ``'None'``), so that
        truthiness-based fallbacks on the result still trigger.
      * ``dict`` -> the first string found under one of the well-known
        keys ``text``, ``value``, ``description``, ``content``,
        ``summary``, ``name`` (checked in that order); otherwise the
        dict serialized as JSON.
      * ``list`` -> each non-``None`` item coerced recursively and joined
        with ``', '``.
      * anything else -> ``str(value)``.

    Args:
        value: Arbitrary value, typically produced by ``json.loads``.

    Returns:
        A plain string; never ``None``.
    """
    if isinstance(value, str):
        return value
    if value is None:
        # str(None) would leak the text "None" into the profile.
        return ''
    if isinstance(value, dict):
        for key in ('text', 'value', 'description', 'content', 'summary', 'name'):
            candidate = value.get(key)
            if isinstance(candidate, str):
                return candidate
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            # Not JSON-serializable (or circular): this helper is a
            # defensive normalizer, so fall back to repr text, not a crash.
            return str(value)
    if isinstance(value, list):
        # Skip None entries so the joined text has no "None" fragments.
        return ', '.join(
            _coerce_to_str(item) for item in value if item is not None
        )
    return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_to_str_list(value: Any) -> List[str]:
    """Coerce a value to a list of strings.

    Handles nested structures that may be returned by LLM JSON parsing.

    Rules:
      * ``list`` -> one output string per non-``None`` item; strings are
        kept as-is, every other item is flattened with
        ``_coerce_to_str`` (so nested dicts/lists become readable text
        instead of a Python ``repr``).
      * ``str``  -> a single-element list containing that string.
      * ``dict`` -> a single-element list with the dict flattened by
        ``_coerce_to_str``.
      * anything else (including ``None``) -> an empty list.

    Args:
        value: Arbitrary value, typically produced by ``json.loads``.

    Returns:
        A new list containing only ``str`` items.
    """
    if isinstance(value, list):
        return [
            item if isinstance(item, str) else _coerce_to_str(item)
            for item in value
            # A None entry carries no topic; do not emit the text "None".
            if item is not None
        ]
    if isinstance(value, str):
        return [value]
    if isinstance(value, dict):
        return [_coerce_to_str(value)]
    return []
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OasisAgentProfile:
|
class OasisAgentProfile:
|
||||||
"""OASIS Agent Profile数据结构"""
|
"""OASIS Agent Profile数据结构"""
|
||||||
|
|
@ -57,6 +98,16 @@ class OasisAgentProfile:
|
||||||
|
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
|
created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
|
||||||
|
|
||||||
|
def __post_init__(self):
    """Normalize field types after dataclass initialization.

    Guards against structured LLM outputs (e.g. a dict or list where a
    plain string was expected):

      * ``bio`` and ``persona`` are always left as strings; a ``None``
        becomes ``''`` rather than the text ``'None'``, so truthiness
        checks on these fields keep working.
      * ``country``, ``profession`` and ``gender`` are coerced to
        strings only when set; ``None`` is preserved.
      * ``interested_topics`` is coerced to a list of strings.
    """
    # Text fields that must end up as str.
    for attr in ('bio', 'persona'):
        raw = getattr(self, attr)
        setattr(self, attr, '' if raw is None else _coerce_to_str(raw))

    # Optional text fields: keep None as "not provided".
    for attr in ('country', 'profession', 'gender'):
        raw = getattr(self, attr)
        if raw is not None:
            setattr(self, attr, _coerce_to_str(raw))

    self.interested_topics = _coerce_to_str_list(self.interested_topics)
|
||||||
|
|
||||||
def to_reddit_format(self) -> Dict[str, Any]:
|
def to_reddit_format(self) -> Dict[str, Any]:
|
||||||
"""转换为Reddit平台格式"""
|
"""转换为Reddit平台格式"""
|
||||||
profile = {
|
profile = {
|
||||||
|
|
@ -549,6 +600,15 @@ class OasisProfileGenerator:
|
||||||
try:
|
try:
|
||||||
result = json.loads(content)
|
result = json.loads(content)
|
||||||
|
|
||||||
|
# Normalize types from LLM output
|
||||||
|
for str_field in ('bio', 'persona', 'country', 'profession'):
|
||||||
|
if str_field in result and result[str_field] is not None:
|
||||||
|
result[str_field] = _coerce_to_str(result[str_field])
|
||||||
|
if 'interested_topics' in result:
|
||||||
|
result['interested_topics'] = _coerce_to_str_list(
|
||||||
|
result['interested_topics']
|
||||||
|
)
|
||||||
|
|
||||||
# 验证必需字段
|
# 验证必需字段
|
||||||
if "bio" not in result or not result["bio"]:
|
if "bio" not in result or not result["bio"]:
|
||||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||||
|
|
@ -1092,15 +1152,19 @@ class OasisProfileGenerator:
|
||||||
|
|
||||||
# 写入数据行
|
# 写入数据行
|
||||||
for idx, profile in enumerate(profiles):
|
for idx, profile in enumerate(profiles):
|
||||||
|
# Defensive coercion in case __post_init__ was bypassed
|
||||||
|
bio = _coerce_to_str(profile.bio) if profile.bio else profile.name
|
||||||
|
persona = _coerce_to_str(profile.persona) if profile.persona else ''
|
||||||
|
|
||||||
# user_char: 完整人设(bio + persona),用于LLM系统提示
|
# user_char: 完整人设(bio + persona),用于LLM系统提示
|
||||||
user_char = profile.bio
|
user_char = bio
|
||||||
if profile.persona and profile.persona != profile.bio:
|
if persona and persona != bio:
|
||||||
user_char = f"{profile.bio} {profile.persona}"
|
user_char = f"{bio} {persona}"
|
||||||
# 处理换行符(CSV中用空格替代)
|
# 处理换行符(CSV中用空格替代)
|
||||||
user_char = user_char.replace('\n', ' ').replace('\r', ' ')
|
user_char = user_char.replace('\n', ' ').replace('\r', ' ')
|
||||||
|
|
||||||
# description: 简短简介,用于外部显示
|
# description: 简短简介,用于外部显示
|
||||||
description = profile.bio.replace('\n', ' ').replace('\r', ' ')
|
description = bio.replace('\n', ' ').replace('\r', ' ')
|
||||||
|
|
||||||
row = [
|
row = [
|
||||||
idx, # user_id: 从0开始的顺序ID
|
idx, # user_id: 从0开始的顺序ID
|
||||||
|
|
@ -1158,27 +1222,40 @@ class OasisProfileGenerator:
|
||||||
"""
|
"""
|
||||||
data = []
|
data = []
|
||||||
for idx, profile in enumerate(profiles):
|
for idx, profile in enumerate(profiles):
|
||||||
|
# Defensive coercion in case __post_init__ was bypassed
|
||||||
|
bio = _coerce_to_str(profile.bio) if profile.bio else f"{profile.name}"
|
||||||
|
persona = _coerce_to_str(profile.persona) if profile.persona else (
|
||||||
|
f"{profile.name} is a participant in social discussions."
|
||||||
|
)
|
||||||
|
country = _coerce_to_str(profile.country) if profile.country else "中国"
|
||||||
|
profession = _coerce_to_str(profile.profession) if profile.profession else None
|
||||||
|
interested_topics = (
|
||||||
|
_coerce_to_str_list(profile.interested_topics)
|
||||||
|
if profile.interested_topics
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# 使用与 to_reddit_format() 一致的格式
|
# 使用与 to_reddit_format() 一致的格式
|
||||||
item = {
|
item = {
|
||||||
"user_id": profile.user_id if profile.user_id is not None else idx, # 关键:必须包含 user_id
|
"user_id": profile.user_id if profile.user_id is not None else idx, # 关键:必须包含 user_id
|
||||||
"username": profile.user_name,
|
"username": profile.user_name,
|
||||||
"name": profile.name,
|
"name": profile.name,
|
||||||
"bio": profile.bio[:150] if profile.bio else f"{profile.name}",
|
"bio": bio[:150],
|
||||||
"persona": profile.persona or f"{profile.name} is a participant in social discussions.",
|
"persona": persona,
|
||||||
"karma": profile.karma if profile.karma else 1000,
|
"karma": profile.karma if profile.karma else 1000,
|
||||||
"created_at": profile.created_at,
|
"created_at": profile.created_at,
|
||||||
# OASIS必需字段 - 确保都有默认值
|
# OASIS必需字段 - 确保都有默认值
|
||||||
"age": profile.age if profile.age else 30,
|
"age": profile.age if profile.age else 30,
|
||||||
"gender": self._normalize_gender(profile.gender),
|
"gender": self._normalize_gender(profile.gender),
|
||||||
"mbti": profile.mbti if profile.mbti else "ISTJ",
|
"mbti": profile.mbti if profile.mbti else "ISTJ",
|
||||||
"country": profile.country if profile.country else "中国",
|
"country": country,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 可选字段
|
# 可选字段
|
||||||
if profile.profession:
|
if profession:
|
||||||
item["profession"] = profile.profession
|
item["profession"] = profession
|
||||||
if profile.interested_topics:
|
if interested_topics:
|
||||||
item["interested_topics"] = profile.interested_topics
|
item["interested_topics"] = interested_topics
|
||||||
|
|
||||||
data.append(item)
|
data.append(item)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue