feat(llm): add structured output reliability to LLMClient and refactor services
- chat_json() gains configurable retry (max_attempts), temperature backoff (temperature_step), optional retry delay, and a fallback_parser hook for service-specific rescue logic
- _clean_response_text: strip <think> tags and markdown code fences
- _fix_truncated_json: use unescaped-quote parity instead of last-char check to avoid spuriously quoting numeric values; fixes broken repair for arrays
- _try_fix_json: generic near-valid JSON salvage (newline normalisation, control-character stripping, greedy object extraction)
- simulation_config_generator: replace raw OpenAI client with LLMClient, remove duplicated local retry loop and truncated-JSON repair, pass service-specific config salvage as fallback_parser
- oasis_profile_generator: same refactor; keep rule-based profile fallback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1536a79334
commit
52c177fd66
|
|
@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from openai import OpenAI
|
|
||||||
from zep_cloud.client import Zep
|
from zep_cloud.client import Zep
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from ..utils.llm_client import LLMClient
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||||
|
|
||||||
|
|
@ -192,9 +192,10 @@ class OasisProfileGenerator:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("LLM_API_KEY 未配置")
|
raise ValueError("LLM_API_KEY 未配置")
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.llm_client = LLMClient(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=self.base_url
|
base_url=self.base_url,
|
||||||
|
model=self.model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# Zep客户端用于检索丰富上下文
|
# Zep客户端用于检索丰富上下文
|
||||||
|
|
@ -520,64 +521,36 @@ class OasisProfileGenerator:
|
||||||
entity_name, entity_type, entity_summary, entity_attributes, context
|
entity_name, entity_type, entity_summary, entity_attributes, context
|
||||||
)
|
)
|
||||||
|
|
||||||
# 尝试多次生成,直到成功或达到最大重试次数
|
try:
|
||||||
max_attempts = 3
|
result = self.llm_client.chat_json(
|
||||||
last_error = None
|
messages=[
|
||||||
|
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=None,
|
||||||
|
max_attempts=3,
|
||||||
|
temperature_step=0.1,
|
||||||
|
fallback_parser=lambda content: self._try_fix_json(
|
||||||
|
content, entity_name, entity_type, entity_summary
|
||||||
|
),
|
||||||
|
retry_delay_seconds=1.0
|
||||||
|
)
|
||||||
|
|
||||||
for attempt in range(max_attempts):
|
if "bio" not in result or not result["bio"]:
|
||||||
try:
|
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||||
response = self.client.chat.completions.create(
|
if "persona" not in result or not result["persona"]:
|
||||||
model=self.model_name,
|
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
],
|
|
||||||
response_format={"type": "json_object"},
|
|
||||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
|
||||||
# 不设置max_tokens,让LLM自由发挥
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
if result.get("_fixed"):
|
||||||
|
del result["_fixed"]
|
||||||
|
|
||||||
# 检查是否被截断(finish_reason不是'stop')
|
return result
|
||||||
finish_reason = response.choices[0].finish_reason
|
except Exception as e:
|
||||||
if finish_reason == 'length':
|
logger.warning(f"LLM生成人设失败(3次尝试): {e}, 使用规则生成")
|
||||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
|
return self._generate_profile_rule_based(
|
||||||
content = self._fix_truncated_json(content)
|
entity_name, entity_type, entity_summary, entity_attributes
|
||||||
|
)
|
||||||
# 尝试解析JSON
|
|
||||||
try:
|
|
||||||
result = json.loads(content)
|
|
||||||
|
|
||||||
# 验证必需字段
|
|
||||||
if "bio" not in result or not result["bio"]:
|
|
||||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
|
||||||
if "persona" not in result or not result["persona"]:
|
|
||||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except json.JSONDecodeError as je:
|
|
||||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
|
|
||||||
|
|
||||||
# 尝试修复JSON
|
|
||||||
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
|
|
||||||
if result.get("_fixed"):
|
|
||||||
del result["_fixed"]
|
|
||||||
return result
|
|
||||||
|
|
||||||
last_error = je
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
|
||||||
last_error = e
|
|
||||||
import time
|
|
||||||
time.sleep(1 * (attempt + 1)) # 指数退避
|
|
||||||
|
|
||||||
logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成")
|
|
||||||
return self._generate_profile_rule_based(
|
|
||||||
entity_name, entity_type, entity_summary, entity_attributes
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fix_truncated_json(self, content: str) -> str:
|
def _fix_truncated_json(self, content: str) -> str:
|
||||||
"""修复被截断的JSON(输出被max_tokens限制截断)"""
|
"""修复被截断的JSON(输出被max_tokens限制截断)"""
|
||||||
|
|
@ -1197,4 +1170,3 @@ class OasisProfileGenerator:
|
||||||
"""[已废弃] 请使用 save_profiles() 方法"""
|
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||||
self.save_profiles(profiles, file_path, platform)
|
self.save_profiles(profiles, file_path, platform)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
|
||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from ..utils.llm_client import LLMClient
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||||
|
|
||||||
|
|
@ -234,9 +233,10 @@ class SimulationConfigGenerator:
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("LLM_API_KEY 未配置")
|
raise ValueError("LLM_API_KEY 未配置")
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.llm_client = LLMClient(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=self.base_url
|
base_url=self.base_url,
|
||||||
|
model=self.model_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_config(
|
def generate_config(
|
||||||
|
|
@ -432,78 +432,23 @@ class SimulationConfigGenerator:
|
||||||
|
|
||||||
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
|
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
|
||||||
"""带重试的LLM调用,包含JSON修复逻辑"""
|
"""带重试的LLM调用,包含JSON修复逻辑"""
|
||||||
import re
|
return self.llm_client.chat_json(
|
||||||
|
messages=[
|
||||||
max_attempts = 3
|
{"role": "system", "content": system_prompt},
|
||||||
last_error = None
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
for attempt in range(max_attempts):
|
temperature=0.7,
|
||||||
try:
|
max_tokens=None,
|
||||||
response = self.client.chat.completions.create(
|
max_attempts=3,
|
||||||
model=self.model_name,
|
temperature_step=0.1,
|
||||||
messages=[
|
fallback_parser=self._try_fix_config_json,
|
||||||
{"role": "system", "content": system_prompt},
|
retry_delay_seconds=2.0
|
||||||
{"role": "user", "content": prompt}
|
)
|
||||||
],
|
|
||||||
response_format={"type": "json_object"},
|
|
||||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
|
||||||
# 不设置max_tokens,让LLM自由发挥
|
|
||||||
)
|
|
||||||
|
|
||||||
content = response.choices[0].message.content
|
|
||||||
finish_reason = response.choices[0].finish_reason
|
|
||||||
|
|
||||||
# 检查是否被截断
|
|
||||||
if finish_reason == 'length':
|
|
||||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
|
|
||||||
content = self._fix_truncated_json(content)
|
|
||||||
|
|
||||||
# 尝试解析JSON
|
|
||||||
try:
|
|
||||||
return json.loads(content)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
|
|
||||||
|
|
||||||
# 尝试修复JSON
|
|
||||||
fixed = self._try_fix_config_json(content)
|
|
||||||
if fixed:
|
|
||||||
return fixed
|
|
||||||
|
|
||||||
last_error = e
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
|
||||||
last_error = e
|
|
||||||
import time
|
|
||||||
time.sleep(2 * (attempt + 1))
|
|
||||||
|
|
||||||
raise last_error or Exception("LLM调用失败")
|
|
||||||
|
|
||||||
def _fix_truncated_json(self, content: str) -> str:
|
|
||||||
"""修复被截断的JSON"""
|
|
||||||
content = content.strip()
|
|
||||||
|
|
||||||
# 计算未闭合的括号
|
|
||||||
open_braces = content.count('{') - content.count('}')
|
|
||||||
open_brackets = content.count('[') - content.count(']')
|
|
||||||
|
|
||||||
# 检查是否有未闭合的字符串
|
|
||||||
if content and content[-1] not in '",}]':
|
|
||||||
content += '"'
|
|
||||||
|
|
||||||
# 闭合括号
|
|
||||||
content += ']' * open_brackets
|
|
||||||
content += '}' * open_braces
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
|
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||||
"""尝试修复配置JSON"""
|
"""尝试修复配置JSON"""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 修复被截断的情况
|
|
||||||
content = self._fix_truncated_json(content)
|
|
||||||
|
|
||||||
# 提取JSON部分
|
# 提取JSON部分
|
||||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||||
if json_match:
|
if json_match:
|
||||||
|
|
@ -984,4 +929,3 @@ class SimulationConfigGenerator:
|
||||||
"influence_weight": 1.0
|
"influence_weight": 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,15 @@ LLM客户端封装
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List, Callable
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from .logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.llm_client')
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
|
|
@ -36,7 +40,7 @@ class LLMClient:
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 4096,
|
max_tokens: Optional[int] = 4096,
|
||||||
response_format: Optional[Dict] = None
|
response_format: Optional[Dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -55,23 +59,27 @@ class LLMClient:
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if max_tokens is not None:
|
||||||
|
kwargs["max_tokens"] = max_tokens
|
||||||
|
|
||||||
if response_format:
|
if response_format:
|
||||||
kwargs["response_format"] = response_format
|
kwargs["response_format"] = response_format
|
||||||
|
|
||||||
response = self.client.chat.completions.create(**kwargs)
|
response = self.client.chat.completions.create(**kwargs)
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content or ""
|
||||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
return self._clean_response_text(content)
|
||||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
|
||||||
return content
|
|
||||||
|
|
||||||
def chat_json(
|
def chat_json(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
temperature: float = 0.3,
|
temperature: float = 0.3,
|
||||||
max_tokens: int = 4096
|
max_tokens: Optional[int] = 4096,
|
||||||
|
max_attempts: int = 1,
|
||||||
|
temperature_step: float = 0.0,
|
||||||
|
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
|
||||||
|
retry_delay_seconds: float = 0.0
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
发送聊天请求并返回JSON
|
发送聊天请求并返回JSON
|
||||||
|
|
@ -84,20 +92,108 @@ class LLMClient:
|
||||||
Returns:
|
Returns:
|
||||||
解析后的JSON对象
|
解析后的JSON对象
|
||||||
"""
|
"""
|
||||||
response = self.chat(
|
last_error: Optional[Exception] = None
|
||||||
messages=messages,
|
last_response = ""
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
for attempt in range(max_attempts):
|
||||||
response_format={"type": "json_object"}
|
current_temperature = max(0.0, temperature - (attempt * temperature_step))
|
||||||
)
|
|
||||||
# 清理markdown代码块标记
|
try:
|
||||||
cleaned_response = response.strip()
|
kwargs = {
|
||||||
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
"model": self.model,
|
||||||
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
"messages": messages,
|
||||||
cleaned_response = cleaned_response.strip()
|
"temperature": current_temperature,
|
||||||
|
"response_format": {"type": "json_object"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_tokens is not None:
|
||||||
|
kwargs["max_tokens"] = max_tokens
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(**kwargs)
|
||||||
|
raw_content = response.choices[0].message.content or ""
|
||||||
|
finish_reason = response.choices[0].finish_reason
|
||||||
|
|
||||||
|
cleaned_response = self._clean_response_text(raw_content)
|
||||||
|
if finish_reason == 'length':
|
||||||
|
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
|
||||||
|
cleaned_response = self._fix_truncated_json(cleaned_response)
|
||||||
|
|
||||||
|
last_response = cleaned_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._parse_json_response(cleaned_response)
|
||||||
|
except json.JSONDecodeError as parse_error:
|
||||||
|
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
|
||||||
|
|
||||||
|
fixed = self._try_fix_json(cleaned_response)
|
||||||
|
if fixed is not None:
|
||||||
|
return fixed
|
||||||
|
|
||||||
|
if fallback_parser is not None:
|
||||||
|
fallback_result = fallback_parser(cleaned_response)
|
||||||
|
if fallback_result is not None:
|
||||||
|
return fallback_result
|
||||||
|
|
||||||
|
last_error = parse_error
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
|
||||||
|
last_error = exc
|
||||||
|
|
||||||
|
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
|
||||||
|
time.sleep(retry_delay_seconds * (attempt + 1))
|
||||||
|
|
||||||
|
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
|
||||||
|
|
||||||
|
def _clean_response_text(self, content: str) -> str:
|
||||||
|
"""清理模型响应中的思考内容和Markdown包裹。"""
|
||||||
|
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||||
|
cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE)
|
||||||
|
cleaned = re.sub(r'\n?```\s*$', '', cleaned)
|
||||||
|
return cleaned.strip()
|
||||||
|
|
||||||
|
def _parse_json_response(self, content: str) -> Dict[str, Any]:
|
||||||
|
return json.loads(content)
|
||||||
|
|
||||||
|
def _fix_truncated_json(self, content: str) -> str:
|
||||||
|
"""修复被截断的JSON内容。"""
|
||||||
|
content = content.strip()
|
||||||
|
|
||||||
|
# If the number of unescaped quotes is odd we are inside an open string.
|
||||||
|
unescaped_quote_count = len(re.findall(r'(?<!\\)"', content))
|
||||||
|
if unescaped_quote_count % 2 == 1:
|
||||||
|
content += '"'
|
||||||
|
|
||||||
|
open_braces = content.count('{') - content.count('}')
|
||||||
|
open_brackets = content.count('[') - content.count(']')
|
||||||
|
|
||||||
|
content += ']' * open_brackets
|
||||||
|
content += '}' * open_braces
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""尝试从近似JSON内容中恢复结构化对象。"""
|
||||||
|
content = self._fix_truncated_json(content)
|
||||||
|
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||||
|
if not json_match:
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_str = json_match.group()
|
||||||
|
|
||||||
|
def fix_string_newlines(match: re.Match[str]) -> str:
|
||||||
|
value = match.group(0)
|
||||||
|
value = value.replace('\n', ' ').replace('\r', ' ')
|
||||||
|
value = re.sub(r'\s+', ' ', value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return json.loads(cleaned_response)
|
return json.loads(json_str)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
|
||||||
|
json_str = re.sub(r'\s+', ' ', json_str)
|
||||||
|
try:
|
||||||
|
return json.loads(json_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue