feat(llm): add structured output reliability to LLMClient and refactor services
- chat_json() gains configurable retry (max_attempts), temperature backoff (temperature_step), optional retry delay, and a fallback_parser hook for service-specific rescue logic - _clean_response_text: strip <think> tags and markdown code fences - _fix_truncated_json: use unescaped-quote parity instead of last-char check to avoid spuriously quoting numeric values; fixes broken repair for arrays - _try_fix_json: generic near-valid JSON salvage (newline normalisation, control-character stripping, greedy object extraction) - simulation_config_generator: replace raw OpenAI client with LLMClient, remove duplicated local retry loop and truncated-JSON repair, pass service-specific config salvage as fallback_parser - oasis_profile_generator: same refactor; keep rule-based profile fallback Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1536a79334
commit
52c177fd66
|
|
@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
|
|
@ -192,9 +192,10 @@ class OasisProfileGenerator:
|
|||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
self.llm_client = LLMClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
model=self.model_name
|
||||
)
|
||||
|
||||
# Zep客户端用于检索丰富上下文
|
||||
|
|
@ -520,64 +521,36 @@ class OasisProfileGenerator:
|
|||
entity_name, entity_type, entity_summary, entity_attributes, context
|
||||
)
|
||||
|
||||
# 尝试多次生成,直到成功或达到最大重试次数
|
||||
max_attempts = 3
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
||||
# 不设置max_tokens,让LLM自由发挥
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# 检查是否被截断(finish_reason不是'stop')
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
result = json.loads(content)
|
||||
|
||||
# 验证必需字段
|
||||
if "bio" not in result or not result["bio"]:
|
||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||
if "persona" not in result or not result["persona"]:
|
||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as je:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
|
||||
|
||||
# 尝试修复JSON
|
||||
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
|
||||
if result.get("_fixed"):
|
||||
del result["_fixed"]
|
||||
return result
|
||||
|
||||
last_error = je
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
last_error = e
|
||||
import time
|
||||
time.sleep(1 * (attempt + 1)) # 指数退避
|
||||
|
||||
logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
try:
|
||||
result = self.llm_client.chat_json(
|
||||
messages=[
|
||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
max_attempts=3,
|
||||
temperature_step=0.1,
|
||||
fallback_parser=lambda content: self._try_fix_json(
|
||||
content, entity_name, entity_type, entity_summary
|
||||
),
|
||||
retry_delay_seconds=1.0
|
||||
)
|
||||
|
||||
if "bio" not in result or not result["bio"]:
|
||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||
if "persona" not in result or not result["persona"]:
|
||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||
|
||||
if result.get("_fixed"):
|
||||
del result["_fixed"]
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM生成人设失败(3次尝试): {e}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON(输出被max_tokens限制截断)"""
|
||||
|
|
@ -1197,4 +1170,3 @@ class OasisProfileGenerator:
|
|||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
|
|||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
|
|
@ -234,9 +233,10 @@ class SimulationConfigGenerator:
|
|||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
self.llm_client = LLMClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
model=self.model_name
|
||||
)
|
||||
|
||||
def generate_config(
|
||||
|
|
@ -432,78 +432,23 @@ class SimulationConfigGenerator:
|
|||
|
||||
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
|
||||
"""带重试的LLM调用,包含JSON修复逻辑"""
|
||||
import re
|
||||
|
||||
max_attempts = 3
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
||||
# 不设置max_tokens,让LLM自由发挥
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
# 检查是否被截断
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
|
||||
# 尝试修复JSON
|
||||
fixed = self._try_fix_config_json(content)
|
||||
if fixed:
|
||||
return fixed
|
||||
|
||||
last_error = e
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
last_error = e
|
||||
import time
|
||||
time.sleep(2 * (attempt + 1))
|
||||
|
||||
raise last_error or Exception("LLM调用失败")
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON"""
|
||||
content = content.strip()
|
||||
|
||||
# 计算未闭合的括号
|
||||
open_braces = content.count('{') - content.count('}')
|
||||
open_brackets = content.count('[') - content.count(']')
|
||||
|
||||
# 检查是否有未闭合的字符串
|
||||
if content and content[-1] not in '",}]':
|
||||
content += '"'
|
||||
|
||||
# 闭合括号
|
||||
content += ']' * open_brackets
|
||||
content += '}' * open_braces
|
||||
|
||||
return content
|
||||
return self.llm_client.chat_json(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
max_attempts=3,
|
||||
temperature_step=0.1,
|
||||
fallback_parser=self._try_fix_config_json,
|
||||
retry_delay_seconds=2.0
|
||||
)
|
||||
|
||||
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||
"""尝试修复配置JSON"""
|
||||
import re
|
||||
|
||||
# 修复被截断的情况
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
|
||||
# 提取JSON部分
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if json_match:
|
||||
|
|
@ -984,4 +929,3 @@ class SimulationConfigGenerator:
|
|||
"influence_weight": 1.0
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,15 @@ LLM客户端封装
|
|||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.llm_client')
|
||||
|
||||
|
||||
class LLMClient:
|
||||
|
|
@ -36,7 +40,7 @@ class LLMClient:
|
|||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
max_tokens: Optional[int] = 4096,
|
||||
response_format: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -55,23 +59,27 @@ class LLMClient:
|
|||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
content = response.choices[0].message.content
|
||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
return content
|
||||
|
||||
content = response.choices[0].message.content or ""
|
||||
return self._clean_response_text(content)
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 4096
|
||||
max_tokens: Optional[int] = 4096,
|
||||
max_attempts: int = 1,
|
||||
temperature_step: float = 0.0,
|
||||
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
|
||||
retry_delay_seconds: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送聊天请求并返回JSON
|
||||
|
|
@ -84,20 +92,108 @@ class LLMClient:
|
|||
Returns:
|
||||
解析后的JSON对象
|
||||
"""
|
||||
response = self.chat(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
# 清理markdown代码块标记
|
||||
cleaned_response = response.strip()
|
||||
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
||||
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
||||
cleaned_response = cleaned_response.strip()
|
||||
last_error: Optional[Exception] = None
|
||||
last_response = ""
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
current_temperature = max(0.0, temperature - (attempt * temperature_step))
|
||||
|
||||
try:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": current_temperature,
|
||||
"response_format": {"type": "json_object"}
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
cleaned_response = self._clean_response_text(raw_content)
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
|
||||
cleaned_response = self._fix_truncated_json(cleaned_response)
|
||||
|
||||
last_response = cleaned_response
|
||||
|
||||
try:
|
||||
return self._parse_json_response(cleaned_response)
|
||||
except json.JSONDecodeError as parse_error:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
|
||||
|
||||
fixed = self._try_fix_json(cleaned_response)
|
||||
if fixed is not None:
|
||||
return fixed
|
||||
|
||||
if fallback_parser is not None:
|
||||
fallback_result = fallback_parser(cleaned_response)
|
||||
if fallback_result is not None:
|
||||
return fallback_result
|
||||
|
||||
last_error = parse_error
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
|
||||
last_error = exc
|
||||
|
||||
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
|
||||
time.sleep(retry_delay_seconds * (attempt + 1))
|
||||
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
|
||||
|
||||
def _clean_response_text(self, content: str) -> str:
|
||||
"""清理模型响应中的思考内容和Markdown包裹。"""
|
||||
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'\n?```\s*$', '', cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def _parse_json_response(self, content: str) -> Dict[str, Any]:
|
||||
return json.loads(content)
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON内容。"""
|
||||
content = content.strip()
|
||||
|
||||
# If the number of unescaped quotes is odd we are inside an open string.
|
||||
unescaped_quote_count = len(re.findall(r'(?<!\\)"', content))
|
||||
if unescaped_quote_count % 2 == 1:
|
||||
content += '"'
|
||||
|
||||
open_braces = content.count('{') - content.count('}')
|
||||
open_brackets = content.count('[') - content.count(']')
|
||||
|
||||
content += ']' * open_brackets
|
||||
content += '}' * open_braces
|
||||
return content
|
||||
|
||||
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||
"""尝试从近似JSON内容中恢复结构化对象。"""
|
||||
content = self._fix_truncated_json(content)
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if not json_match:
|
||||
return None
|
||||
|
||||
json_str = json_match.group()
|
||||
|
||||
def fix_string_newlines(match: re.Match[str]) -> str:
|
||||
value = match.group(0)
|
||||
value = value.replace('\n', ' ').replace('\r', ' ')
|
||||
value = re.sub(r'\s+', ' ', value)
|
||||
return value
|
||||
|
||||
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
|
||||
|
||||
try:
|
||||
return json.loads(cleaned_response)
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
||||
|
||||
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
|
||||
json_str = re.sub(r'\s+', ' ', json_str)
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue