This commit is contained in:
Laura Roganovic 2026-05-28 17:33:03 -04:00 committed by GitHub
commit dc68ec6a76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 469 additions and 158 deletions

View File

@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from zep_cloud.client import Zep
from ..config import Config
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -193,9 +193,10 @@ class OasisProfileGenerator:
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
self.llm_client = LLMClient(
api_key=self.api_key,
base_url=self.base_url
base_url=self.base_url,
model=self.model_name
)
# Zep客户端用于检索丰富上下文
@ -521,61 +522,33 @@ class OasisProfileGenerator:
entity_name, entity_type, entity_summary, entity_attributes, context
)
# 尝试多次生成,直到成功或达到最大重试次数
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
result = self.llm_client.chat_json(
messages=[
{"role": "system", "content": self._get_system_prompt(is_individual)},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
temperature=0.7,
max_tokens=None,
max_attempts=3,
temperature_step=0.1,
fallback_parser=lambda content: self._try_fix_json(
content, entity_name, entity_type, entity_summary
),
retry_delay_seconds=1.0
)
content = response.choices[0].message.content
# 检查是否被截断finish_reason不是'stop'
finish_reason = response.choices[0].finish_reason
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
result = json.loads(content)
# 验证必需字段
if "bio" not in result or not result["bio"]:
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
if "persona" not in result or not result["persona"]:
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}"
return result
except json.JSONDecodeError as je:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
# 尝试修复JSON
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
if result.get("_fixed"):
del result["_fixed"]
return result
last_error = je
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(1 * (attempt + 1)) # 指数退避
logger.warning(f"LLM生成人设失败{max_attempts}次尝试): {last_error}, 使用规则生成")
logger.warning(f"LLM生成人设失败3次尝试: {e}, 使用规则生成")
return self._generate_profile_rule_based(
entity_name, entity_type, entity_summary, entity_attributes
)
@ -1202,4 +1175,3 @@ class OasisProfileGenerator:
"""[已废弃] 请使用 save_profiles() 方法"""
logger.warning("save_profiles_to_json已废弃请使用save_profiles方法")
self.save_profiles(profiles, file_path, platform)

View File

@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from openai import OpenAI
from ..config import Config
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
from ..utils.locale import get_language_instruction, t
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -235,9 +234,10 @@ class SimulationConfigGenerator:
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
self.llm_client = LLMClient(
api_key=self.api_key,
base_url=self.base_url
base_url=self.base_url,
model=self.model_name
)
def generate_config(
@ -433,78 +433,23 @@ class SimulationConfigGenerator:
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
"""带重试的LLM调用包含JSON修复逻辑"""
import re
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
return self.llm_client.chat_json(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
temperature=0.7,
max_tokens=None,
max_attempts=3,
temperature_step=0.1,
fallback_parser=self._try_fix_config_json,
retry_delay_seconds=2.0
)
content = response.choices[0].message.content
finish_reason = response.choices[0].finish_reason
# 检查是否被截断
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
return json.loads(content)
except json.JSONDecodeError as e:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
# 尝试修复JSON
fixed = self._try_fix_config_json(content)
if fixed:
return fixed
last_error = e
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(2 * (attempt + 1))
raise last_error or Exception("LLM调用失败")
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON"""
content = content.strip()
# 计算未闭合的括号
open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']')
# 检查是否有未闭合的字符串
if content and content[-1] not in '",}]':
content += '"'
# 闭合括号
content += ']' * open_brackets
content += '}' * open_braces
return content
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
"""尝试修复配置JSON"""
import re
# 修复被截断的情况
content = self._fix_truncated_json(content)
# 提取JSON部分
json_match = re.search(r'\{[\s\S]*\}', content)
if json_match:
@ -988,4 +933,3 @@ class SimulationConfigGenerator:
"influence_weight": 1.0
}

View File

@ -4,11 +4,15 @@ LLM客户端封装
"""
import json
import time
import re
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Callable
from openai import OpenAI
from ..config import Config
from .logger import get_logger
logger = get_logger('mirofish.llm_client')
class LLMClient:
@ -36,7 +40,7 @@ class LLMClient:
self,
messages: List[Dict[str, str]],
temperature: float = 0.7,
max_tokens: int = 4096,
max_tokens: Optional[int] = 4096,
response_format: Optional[Dict] = None
) -> str:
"""
@ -55,23 +59,27 @@ class LLMClient:
"model": self.model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if response_format:
kwargs["response_format"] = response_format
response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
return content
content = response.choices[0].message.content or ""
return self._clean_response_text(content)
def chat_json(
self,
messages: List[Dict[str, str]],
temperature: float = 0.3,
max_tokens: int = 4096
max_tokens: Optional[int] = 4096,
max_attempts: int = 1,
temperature_step: float = 0.0,
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
retry_delay_seconds: float = 0.0
) -> Dict[str, Any]:
"""
发送聊天请求并返回JSON
@ -84,20 +92,108 @@ class LLMClient:
Returns:
解析后的JSON对象
"""
response = self.chat(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={"type": "json_object"}
)
# 清理markdown代码块标记
cleaned_response = response.strip()
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
cleaned_response = cleaned_response.strip()
last_error: Optional[Exception] = None
last_response = ""
for attempt in range(max_attempts):
current_temperature = max(0.0, temperature - (attempt * temperature_step))
try:
return json.loads(cleaned_response)
except json.JSONDecodeError:
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
kwargs = {
"model": self.model,
"messages": messages,
"temperature": current_temperature,
"response_format": {"type": "json_object"}
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
response = self.client.chat.completions.create(**kwargs)
raw_content = response.choices[0].message.content or ""
finish_reason = response.choices[0].finish_reason
cleaned_response = self._clean_response_text(raw_content)
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
cleaned_response = self._fix_truncated_json(cleaned_response)
last_response = cleaned_response
try:
return self._parse_json_response(cleaned_response)
except json.JSONDecodeError as parse_error:
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
fixed = self._try_fix_json(cleaned_response)
if fixed is not None:
return fixed
if fallback_parser is not None:
fallback_result = fallback_parser(cleaned_response)
if fallback_result is not None:
return fallback_result
last_error = parse_error
except Exception as exc:
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
last_error = exc
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
time.sleep(retry_delay_seconds * (attempt + 1))
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
def _clean_response_text(self, content: str) -> str:
    """Strip reasoning blocks and a Markdown code fence from a model reply.

    Args:
        content: Raw text taken from the model response.

    Returns:
        The text with every ``<think>...</think>`` span removed, one leading
        fence marker (three backticks, optionally followed by ``json`` in any
        letter case) and one trailing fence marker dropped, and the
        surrounding whitespace trimmed.
    """
    without_reasoning = re.sub(r'<think>[\s\S]*?</think>', '', content)
    body = without_reasoning.strip()

    # Opening fence first (case-insensitive "json" tag), then the closing one.
    fence_rules = (
        (r'^```(?:json)?\s*\n?', re.IGNORECASE),
        (r'\n?```\s*$', 0),
    )
    for fence_pattern, fence_flags in fence_rules:
        body = re.sub(fence_pattern, '', body, flags=fence_flags)

    return body.strip()
def _parse_json_response(self, content: str) -> Dict[str, Any]:
    """Decode *content* as JSON and return the decoded value.

    Raises:
        json.JSONDecodeError: If *content* is not valid JSON.
    """
    decoded = json.loads(content)
    return decoded
def _fix_truncated_json(self, content: str) -> str:
    """Close whatever a truncated JSON document left open.

    The text is scanned once while tracking whether the scanner is inside a
    string literal, so braces, brackets and quotes that appear inside string
    values are not mistaken for structure. Containers that are still open at
    the end are closed innermost-first, so ``{"a": [{"b": 1`` becomes
    ``{"a": [{"b": 1}]}``.

    Args:
        content: Possibly truncated JSON text.

    Returns:
        The stripped text with a closing quote appended if it ended inside a
        string, followed by the closers for every unclosed ``{`` / ``[``.
        Text that is already balanced is returned unchanged (apart from
        stripping). Other damage, such as a trailing comma or a dangling
        key, is not repaired here.
    """
    content = content.strip()

    closer_for = {'{': '}', '[': ']'}
    pending = []  # closers owed for open containers, innermost last
    in_string = False
    escaped = False

    for char in content:
        if in_string:
            if escaped:
                # The character after a backslash never ends the string.
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char in closer_for:
            pending.append(closer_for[char])
        elif pending and char == pending[-1]:
            pending.pop()
        # Unmatched closers are left alone; they cannot be repaired here.

    if in_string:
        if escaped:
            # Cut right after a backslash: drop it, otherwise the quote we
            # append would be read as an escaped quote.
            content = content[:-1]
        content += '"'

    return content + ''.join(reversed(pending))
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
    """Best-effort recovery of a JSON object from almost-JSON text.

    Steps, in order:
      1. Run the text through ``self._fix_truncated_json``.
      2. Take the span from the first ``{`` to the last ``}``.
      3. Inside every double-quoted literal, turn newlines into spaces and
         collapse whitespace runs to a single space, then try to decode.
      4. If that fails, replace control characters with spaces, collapse all
         whitespace, and try once more.

    Args:
        content: Model output that failed normal JSON decoding.

    Returns:
        The decoded value, or ``None`` when no ``{...}`` span exists or both
        decode attempts fail.
    """
    repaired = self._fix_truncated_json(content)
    object_span = re.search(r'\{[\s\S]*\}', repaired)
    if object_span is None:
        return None

    def flatten_literal(literal: re.Match[str]) -> str:
        # Raw line breaks inside a string literal are invalid JSON.
        single_line = literal.group(0).replace('\n', ' ').replace('\r', ' ')
        return re.sub(r'\s+', ' ', single_line)

    candidate = re.sub(
        r'"[^"\\]*(?:\\.[^"\\]*)*"', flatten_literal, object_span.group()
    )
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass

    # Second pass: blank out control characters and squeeze whitespace.
    candidate = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', candidate)
    candidate = re.sub(r'\s+', ' ', candidate)
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None

View File

View File

View File

@ -0,0 +1,299 @@
"""
Tests for the LLMClient utility. No real API calls are made; all OpenAI interactions are mocked.
"""
import json
import pytest
from unittest.mock import MagicMock, patch, call
from app.utils.llm_client import LLMClient
# ---------------------------------------------------------------------------
# Helpers to build fake OpenAI response objects
# ---------------------------------------------------------------------------
def _make_response(content: str, finish_reason: str = "stop"):
    """Build a mock chat-completion response that carries a single choice.

    The returned mock exposes ``choices[0].message.content`` and
    ``choices[0].finish_reason`` set to the given values.
    """
    only_choice = MagicMock(finish_reason=finish_reason)
    only_choice.message.content = content
    return MagicMock(choices=[only_choice])
def _make_client(responses):
    """
    Build an LLMClient while ``app.utils.llm_client.OpenAI`` is patched.

    The patched class returns a mock whose ``chat.completions.create`` uses
    *responses* as its ``side_effect``: items are returned in order on
    successive calls, and exception instances are raised instead of returned.
    The ``create`` mock is attached to the client as ``_mock_create`` so
    tests can inspect the calls made to it.
    """
    with patch("app.utils.llm_client.OpenAI") as openai_cls:
        fake_openai = MagicMock()
        create_mock = fake_openai.chat.completions.create
        create_mock.side_effect = responses
        openai_cls.return_value = fake_openai
        llm = LLMClient(api_key="test-key", base_url="http://localhost", model="test-model")

    # Expose the underlying mock for assertions
    llm._mock_create = create_mock
    return llm
# ---------------------------------------------------------------------------
# _clean_response_text
# ---------------------------------------------------------------------------
class TestCleanResponseText:
    """Checks for LLMClient._clean_response_text (think tags and Markdown fences)."""

    def setup_method(self):
        # OpenAI is patched only while the client is being constructed.
        with patch("app.utils.llm_client.OpenAI"):
            self.client = LLMClient(api_key="k", base_url="u", model="m")

    def _cleaned(self, raw):
        """Run *raw* through the method under test."""
        return self.client._clean_response_text(raw)

    def test_passthrough_plain_json(self):
        assert self._cleaned('{"a": 1}') == '{"a": 1}'

    def test_strips_think_tags(self):
        assert self._cleaned('<think>internal reasoning</think>{"a": 1}') == '{"a": 1}'

    def test_strips_multiline_think_tags(self):
        assert self._cleaned('<think>\nline1\nline2\n</think>\n{"b": 2}') == '{"b": 2}'

    def test_strips_json_markdown_fence(self):
        assert self._cleaned('```json\n{"c": 3}\n```') == '{"c": 3}'

    def test_strips_plain_markdown_fence(self):
        assert self._cleaned('```\n{"d": 4}\n```') == '{"d": 4}'

    def test_strips_think_and_fence_combined(self):
        wrapped = '<think>reasoning</think>\n```json\n{"e": 5}\n```'
        assert self._cleaned(wrapped) == '{"e": 5}'

    def test_empty_string(self):
        assert self._cleaned("") == ""

    def test_no_fence_no_think(self):
        assert self._cleaned('  {"f": 6}  ') == '{"f": 6}'
# ---------------------------------------------------------------------------
# _fix_truncated_json
# ---------------------------------------------------------------------------
class TestFixTruncatedJson:
    """Checks that LLMClient._fix_truncated_json yields decodable JSON."""

    def setup_method(self):
        # OpenAI is patched only while the client is being constructed.
        with patch("app.utils.llm_client.OpenAI"):
            self.client = LLMClient(api_key="k", base_url="u", model="m")

    def _repair_and_load(self, text):
        """Repair *text* with the method under test and decode the result."""
        return json.loads(self.client._fix_truncated_json(text))

    def test_closes_one_brace(self):
        data = self._repair_and_load('{"key": "val')
        assert data["key"] == "val"

    def test_closes_nested_braces(self):
        data = self._repair_and_load('{"outer": {"inner": "x"')
        assert data["outer"]["inner"] == "x"

    def test_closes_open_array(self):
        data = self._repair_and_load('{"list": [1, 2')
        assert data["list"] == [1, 2]

    def test_closes_array_and_brace(self):
        data = self._repair_and_load('{"a": [1, 2, 3')
        assert data["a"] == [1, 2, 3]

    def test_already_valid_unchanged(self):
        assert self._repair_and_load('{"x": 1}') == {"x": 1}

    def test_trailing_dangling_value_gets_quote(self):
        # Input ends mid-string, without the closing quote; it must still
        # decode after repair.
        data = self._repair_and_load('{"k": "incomplete')
        assert "k" in data
# ---------------------------------------------------------------------------
# _try_fix_json
# ---------------------------------------------------------------------------
class TestTryFixJson:
    """Checks for LLMClient._try_fix_json recovery of near-JSON text."""

    def setup_method(self):
        # OpenAI is patched only while the client is being constructed.
        with patch("app.utils.llm_client.OpenAI"):
            self.client = LLMClient(api_key="k", base_url="u", model="m")

    def _recover(self, text):
        """Run *text* through the method under test."""
        return self.client._try_fix_json(text)

    def test_returns_valid_json(self):
        assert self._recover('{"name": "Alice", "age": 30}') == {"name": "Alice", "age": 30}

    def test_extracts_json_from_surrounding_text(self):
        recovered = self._recover('Here is the result: {"score": 42} end.')
        assert recovered is not None
        assert recovered["score"] == 42

    def test_fixes_newlines_inside_string_values(self):
        # A literal newline inside a JSON string value is invalid JSON.
        recovered = self._recover('{"desc": "line one\nline two"}')
        assert recovered is not None
        assert "desc" in recovered

    def test_returns_none_for_no_json_object(self):
        assert self._recover("no json here at all") is None

    def test_returns_none_for_empty_string(self):
        assert self._recover("") is None

    def test_recovers_truncated_object(self):
        # Closing brace is missing.
        recovered = self._recover('{"city": "Beijing", "pop": 21')
        assert recovered is not None
        assert recovered["city"] == "Beijing"
# ---------------------------------------------------------------------------
# chat_json — retry behavior
# ---------------------------------------------------------------------------
class TestChatJsonRetry:
    """Retry, temperature back-off, fallback and error behaviour of chat_json."""

    # --- succeeds on first attempt ---
    def test_success_first_attempt(self):
        payload = {"status": "ok"}
        client = _make_client([_make_response(json.dumps(payload))])
        result = client.chat_json([{"role": "user", "content": "hi"}])
        assert result == payload
        assert client._mock_create.call_count == 1

    # --- temperature backoff ---
    def test_temperature_decreases_across_retries(self):
        bad = _make_response("not json at all {{{{")
        good = _make_response('{"ok": true}')
        client = _make_client([bad, bad, good])
        result = client.chat_json(
            [{"role": "user", "content": "hi"}],
            temperature=0.9,
            temperature_step=0.3,
            max_attempts=3,
        )
        assert result == {"ok": True}
        calls = client._mock_create.call_args_list
        assert calls[0].kwargs["temperature"] == pytest.approx(0.9)
        assert calls[1].kwargs["temperature"] == pytest.approx(0.6)
        assert calls[2].kwargs["temperature"] == pytest.approx(0.3)

    def test_temperature_never_goes_below_zero(self):
        # A retry is required to exercise the clamp: the first attempt always
        # uses the starting temperature unchanged, and 0.1 - 0.5 would be
        # negative on the second attempt if it were not clamped.
        bad = _make_response("not json")
        good = _make_response('{"x": 1}')
        client = _make_client([bad, good])
        result = client.chat_json(
            [{"role": "user", "content": "hi"}],
            temperature=0.1,
            temperature_step=0.5,
            max_attempts=2,
        )
        assert result == {"x": 1}
        calls = client._mock_create.call_args_list
        assert len(calls) == 2
        assert calls[0].kwargs["temperature"] == pytest.approx(0.1)
        assert calls[1].kwargs["temperature"] == pytest.approx(0.0)

    # --- fallback_parser ---
    def test_fallback_parser_called_on_bad_json(self):
        bad = _make_response("this is not json")
        client = _make_client([bad])
        fallback = MagicMock(return_value={"rescued": True})
        result = client.chat_json(
            [{"role": "user", "content": "q"}],
            max_attempts=1,
            fallback_parser=fallback,
        )
        assert result == {"rescued": True}
        fallback.assert_called_once()

    def test_fallback_parser_returning_none_does_not_short_circuit(self):
        bad = _make_response("still not json ][")
        good = _make_response('{"second": "attempt"}')
        client = _make_client([bad, good])
        fallback = MagicMock(return_value=None)
        result = client.chat_json(
            [{"role": "user", "content": "q"}],
            max_attempts=2,
            fallback_parser=fallback,
        )
        assert result == {"second": "attempt"}

    # --- raises ValueError after all attempts fail ---
    def test_raises_after_all_attempts_fail(self):
        bad = _make_response("invalid {{{ json")
        client = _make_client([bad, bad, bad])
        with pytest.raises(ValueError, match="LLM返回的JSON格式无效"):
            client.chat_json(
                [{"role": "user", "content": "q"}],
                max_attempts=3,
            )
        # Checked outside the raises-block; code after the raising call
        # inside it would never run.
        assert client._mock_create.call_count == 3

    def test_raises_after_single_attempt(self):
        bad = _make_response("nope")
        client = _make_client([bad])
        with pytest.raises(ValueError):
            client.chat_json([{"role": "user", "content": "q"}], max_attempts=1)

    # --- finish_reason == 'length' triggers truncation repair ---
    def test_truncated_output_is_repaired(self):
        truncated_json = '{"items": [1, 2, 3'
        client = _make_client([_make_response(truncated_json, finish_reason="length")])
        result = client.chat_json([{"role": "user", "content": "q"}])
        assert result["items"] == [1, 2, 3]

    # --- API exception counts as a failed attempt ---
    def test_api_exception_retried(self):
        # A plain Exception is enough here: chat_json catches Exception, so
        # no SDK-specific error class is needed.
        exc = Exception("network failure")
        good = _make_response('{"recovered": true}')
        client = _make_client([exc, good])
        result = client.chat_json(
            [{"role": "user", "content": "q"}],
            max_attempts=2,
        )
        assert result == {"recovered": True}
        assert client._mock_create.call_count == 2

    def test_api_exception_all_attempts_raises_value_error(self):
        exc = Exception("always fails")
        client = _make_client([exc, exc])
        with pytest.raises(ValueError):
            client.chat_json(
                [{"role": "user", "content": "q"}],
                max_attempts=2,
            )