feat(llm): add structured output reliability to LLMClient and refactor services

- chat_json() gains configurable retry (max_attempts), temperature backoff
  (temperature_step), optional retry delay, and a fallback_parser hook for
  service-specific rescue logic
- _clean_response_text: strip <think> tags and markdown code fences
- _fix_truncated_json: use unescaped-quote parity instead of last-char check
  to avoid spuriously quoting numeric values; fixes broken repair for arrays
- _try_fix_json: generic near-valid JSON salvage (newline normalisation,
  control-character stripping, greedy object extraction)
- simulation_config_generator: replace raw OpenAI client with LLMClient,
  remove duplicated local retry loop and truncated-JSON repair, pass
  service-specific config salvage as fallback_parser
- oasis_profile_generator: same refactor; keep rule-based profile fallback

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
LoryGlory 2026-03-23 15:45:09 +01:00
parent 1536a79334
commit 52c177fd66
3 changed files with 170 additions and 158 deletions

View File

@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from zep_cloud.client import Zep
from ..config import Config
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -192,9 +192,10 @@ class OasisProfileGenerator:
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
self.llm_client = LLMClient(
api_key=self.api_key,
base_url=self.base_url
base_url=self.base_url,
model=self.model_name
)
# Zep客户端用于检索丰富上下文
@ -520,64 +521,36 @@ class OasisProfileGenerator:
entity_name, entity_type, entity_summary, entity_attributes, context
)
# 尝试多次生成,直到成功或达到最大重试次数
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": self._get_system_prompt(is_individual)},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
)
content = response.choices[0].message.content
# 检查是否被截断finish_reason不是'stop'
finish_reason = response.choices[0].finish_reason
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
result = json.loads(content)
# 验证必需字段
if "bio" not in result or not result["bio"]:
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
if "persona" not in result or not result["persona"]:
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}"
return result
except json.JSONDecodeError as je:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
# 尝试修复JSON
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
if result.get("_fixed"):
del result["_fixed"]
return result
last_error = je
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(1 * (attempt + 1)) # 指数退避
logger.warning(f"LLM生成人设失败{max_attempts}次尝试): {last_error}, 使用规则生成")
return self._generate_profile_rule_based(
entity_name, entity_type, entity_summary, entity_attributes
)
try:
result = self.llm_client.chat_json(
messages=[
{"role": "system", "content": self._get_system_prompt(is_individual)},
{"role": "user", "content": prompt}
],
temperature=0.7,
max_tokens=None,
max_attempts=3,
temperature_step=0.1,
fallback_parser=lambda content: self._try_fix_json(
content, entity_name, entity_type, entity_summary
),
retry_delay_seconds=1.0
)
if "bio" not in result or not result["bio"]:
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
if "persona" not in result or not result["persona"]:
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}"
if result.get("_fixed"):
del result["_fixed"]
return result
except Exception as e:
logger.warning(f"LLM生成人设失败3次尝试: {e}, 使用规则生成")
return self._generate_profile_rule_based(
entity_name, entity_type, entity_summary, entity_attributes
)
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON输出被max_tokens限制截断"""
@ -1197,4 +1170,3 @@ class OasisProfileGenerator:
"""[已废弃] 请使用 save_profiles() 方法"""
logger.warning("save_profiles_to_json已废弃请使用save_profiles方法")
self.save_profiles(profiles, file_path, platform)

View File

@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from openai import OpenAI
from ..config import Config
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -234,9 +233,10 @@ class SimulationConfigGenerator:
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
self.llm_client = LLMClient(
api_key=self.api_key,
base_url=self.base_url
base_url=self.base_url,
model=self.model_name
)
def generate_config(
@ -432,78 +432,23 @@ class SimulationConfigGenerator:
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
"""带重试的LLM调用包含JSON修复逻辑"""
import re
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
)
content = response.choices[0].message.content
finish_reason = response.choices[0].finish_reason
# 检查是否被截断
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
return json.loads(content)
except json.JSONDecodeError as e:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
# 尝试修复JSON
fixed = self._try_fix_config_json(content)
if fixed:
return fixed
last_error = e
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(2 * (attempt + 1))
raise last_error or Exception("LLM调用失败")
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON"""
content = content.strip()
# 计算未闭合的括号
open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']')
# 检查是否有未闭合的字符串
if content and content[-1] not in '",}]':
content += '"'
# 闭合括号
content += ']' * open_brackets
content += '}' * open_braces
return content
return self.llm_client.chat_json(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
temperature=0.7,
max_tokens=None,
max_attempts=3,
temperature_step=0.1,
fallback_parser=self._try_fix_config_json,
retry_delay_seconds=2.0
)
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
"""尝试修复配置JSON"""
import re
# 修复被截断的情况
content = self._fix_truncated_json(content)
# 提取JSON部分
json_match = re.search(r'\{[\s\S]*\}', content)
if json_match:
@ -984,4 +929,3 @@ class SimulationConfigGenerator:
"influence_weight": 1.0
}

View File

@ -4,11 +4,15 @@ LLM客户端封装
"""
import json
import time
import re
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Callable
from openai import OpenAI
from ..config import Config
from .logger import get_logger
logger = get_logger('mirofish.llm_client')
class LLMClient:
@ -36,7 +40,7 @@ class LLMClient:
self,
messages: List[Dict[str, str]],
temperature: float = 0.7,
max_tokens: int = 4096,
max_tokens: Optional[int] = 4096,
response_format: Optional[Dict] = None
) -> str:
"""
@ -55,23 +59,27 @@ class LLMClient:
"model": self.model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if response_format:
kwargs["response_format"] = response_format
response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
return content
content = response.choices[0].message.content or ""
return self._clean_response_text(content)
def chat_json(
self,
messages: List[Dict[str, str]],
temperature: float = 0.3,
max_tokens: int = 4096
max_tokens: Optional[int] = 4096,
max_attempts: int = 1,
temperature_step: float = 0.0,
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
retry_delay_seconds: float = 0.0
) -> Dict[str, Any]:
"""
发送聊天请求并返回JSON
@ -84,20 +92,108 @@ class LLMClient:
Returns:
解析后的JSON对象
"""
response = self.chat(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={"type": "json_object"}
)
# 清理markdown代码块标记
cleaned_response = response.strip()
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
cleaned_response = cleaned_response.strip()
last_error: Optional[Exception] = None
last_response = ""
for attempt in range(max_attempts):
current_temperature = max(0.0, temperature - (attempt * temperature_step))
try:
kwargs = {
"model": self.model,
"messages": messages,
"temperature": current_temperature,
"response_format": {"type": "json_object"}
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
response = self.client.chat.completions.create(**kwargs)
raw_content = response.choices[0].message.content or ""
finish_reason = response.choices[0].finish_reason
cleaned_response = self._clean_response_text(raw_content)
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
cleaned_response = self._fix_truncated_json(cleaned_response)
last_response = cleaned_response
try:
return self._parse_json_response(cleaned_response)
except json.JSONDecodeError as parse_error:
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
fixed = self._try_fix_json(cleaned_response)
if fixed is not None:
return fixed
if fallback_parser is not None:
fallback_result = fallback_parser(cleaned_response)
if fallback_result is not None:
return fallback_result
last_error = parse_error
except Exception as exc:
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
last_error = exc
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
time.sleep(retry_delay_seconds * (attempt + 1))
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
def _clean_response_text(self, content: str) -> str:
"""清理模型响应中的思考内容和Markdown包裹。"""
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r'\n?```\s*$', '', cleaned)
return cleaned.strip()
def _parse_json_response(self, content: str) -> Dict[str, Any]:
return json.loads(content)
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON内容。"""
content = content.strip()
# If the number of unescaped quotes is odd we are inside an open string.
unescaped_quote_count = len(re.findall(r'(?<!\\)"', content))
if unescaped_quote_count % 2 == 1:
content += '"'
open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']')
content += ']' * open_brackets
content += '}' * open_braces
return content
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
"""尝试从近似JSON内容中恢复结构化对象。"""
content = self._fix_truncated_json(content)
json_match = re.search(r'\{[\s\S]*\}', content)
if not json_match:
return None
json_str = json_match.group()
def fix_string_newlines(match: re.Match[str]) -> str:
value = match.group(0)
value = value.replace('\n', ' ').replace('\r', ' ')
value = re.sub(r'\s+', ' ', value)
return value
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
try:
return json.loads(cleaned_response)
return json.loads(json_str)
except json.JSONDecodeError:
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
json_str = re.sub(r'\s+', ' ', json_str)
try:
return json.loads(json_str)
except json.JSONDecodeError:
return None