diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 7704a627..5191ff90 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime -from openai import OpenAI from zep_cloud.client import Zep from ..config import Config +from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, get_locale, set_locale, t from .zep_entity_reader import EntityNode, ZepEntityReader @@ -193,9 +193,10 @@ class OasisProfileGenerator: if not self.api_key: raise ValueError("LLM_API_KEY 未配置") - self.client = OpenAI( + self.llm_client = LLMClient( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + model=self.model_name ) # Zep客户端用于检索丰富上下文 @@ -521,64 +522,36 @@ class OasisProfileGenerator: entity_name, entity_type, entity_summary, entity_attributes, context ) - # 尝试多次生成,直到成功或达到最大重试次数 - max_attempts = 3 - last_error = None - - for attempt in range(max_attempts): - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": self._get_system_prompt(is_individual)}, - {"role": "user", "content": prompt} - ], - response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 - ) - - content = response.choices[0].message.content - - # 检查是否被截断(finish_reason不是'stop') - finish_reason = response.choices[0].finish_reason - if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") - content = self._fix_truncated_json(content) - - # 尝试解析JSON - try: - result = json.loads(content) - - # 验证必需字段 - if "bio" not in result or not result["bio"]: - result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" - if "persona" not in result or not result["persona"]: - result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" - - return result - - except json.JSONDecodeError as je: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}") - - # 尝试修复JSON - result = self._try_fix_json(content, entity_name, entity_type, entity_summary) - if result.get("_fixed"): - del result["_fixed"] - return result - - last_error = je - - except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") - last_error = e - import time - time.sleep(1 * (attempt + 1)) # 指数退避 - - logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成") - return self._generate_profile_rule_based( - entity_name, entity_type, entity_summary, entity_attributes - ) + try: + result = self.llm_client.chat_json( + messages=[ + {"role": "system", "content": self._get_system_prompt(is_individual)}, + {"role": "user", "content": prompt} + ], + temperature=0.7, + max_tokens=None, + max_attempts=3, + temperature_step=0.1, + fallback_parser=lambda content: self._try_fix_json( + content, entity_name, entity_type, entity_summary + ), + retry_delay_seconds=1.0 + ) + + if "bio" not in result or not result["bio"]: + result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" + if "persona" not in result or not result["persona"]: + result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" + + if result.get("_fixed"): + del result["_fixed"] + + return result + except Exception as e: + logger.warning(f"LLM生成人设失败(3次尝试): {e}, 使用规则生成") + return self._generate_profile_rule_based( + entity_name, entity_type, entity_summary, entity_attributes + ) def _fix_truncated_json(self, content: str) -> str: """修复被截断的JSON(输出被max_tokens限制截断)""" @@ -1202,4 +1175,3 @@ class OasisProfileGenerator: """[已废弃] 请使用 save_profiles() 方法""" logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") self.save_profiles(profiles, file_path, platform) - diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index cb77f6b6..6bb2c554 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass, field, asdict from datetime import datetime -from openai import OpenAI - from ..config import Config +from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, t from .zep_entity_reader import EntityNode, ZepEntityReader @@ -235,9 +234,10 @@ class SimulationConfigGenerator: if not self.api_key: raise ValueError("LLM_API_KEY 未配置") - self.client = OpenAI( + self.llm_client = LLMClient( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + model=self.model_name ) def generate_config( @@ -433,78 +433,23 @@ class SimulationConfigGenerator: def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]: """带重试的LLM调用,包含JSON修复逻辑""" - import re - - max_attempts = 3 - last_error = None - - for attempt in range(max_attempts): - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": prompt} - ], - response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 - ) - - content = response.choices[0].message.content - finish_reason = response.choices[0].finish_reason - - # 检查是否被截断 - if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1})") - content = self._fix_truncated_json(content) - - # 尝试解析JSON - try: - return json.loads(content) - except json.JSONDecodeError as e: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}") - - # 尝试修复JSON - fixed = self._try_fix_config_json(content) - if fixed: - return fixed - - last_error = e - - except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") - last_error = e - import time - time.sleep(2 * (attempt + 1)) - - raise last_error or Exception("LLM调用失败") - - def _fix_truncated_json(self, content: str) -> str: - """修复被截断的JSON""" - content = content.strip() - - # 计算未闭合的括号 - open_braces = content.count('{') - content.count('}') - open_brackets = content.count('[') - content.count(']') - - # 检查是否有未闭合的字符串 - if content and content[-1] not in '",}]': - content += '"' - - # 闭合括号 - content += ']' * open_brackets - content += '}' * open_braces - - return content + return self.llm_client.chat_json( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ], + temperature=0.7, + max_tokens=None, + max_attempts=3, + temperature_step=0.1, + fallback_parser=self._try_fix_config_json, + retry_delay_seconds=2.0 + ) def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]: """尝试修复配置JSON""" import re - - # 修复被截断的情况 - content = self._fix_truncated_json(content) - + # 提取JSON部分 json_match = re.search(r'\{[\s\S]*\}', content) if json_match: @@ -988,4 +933,3 @@ class SimulationConfigGenerator: "influence_weight": 1.0 } - diff --git a/backend/app/utils/llm_client.py b/backend/app/utils/llm_client.py index 6c1a81f4..953327ab 100644 --- a/backend/app/utils/llm_client.py +++ b/backend/app/utils/llm_client.py @@ -4,11 +4,15 @@ LLM客户端封装 """ import json +import time import re -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Callable from openai import OpenAI from ..config import Config +from .logger import get_logger + +logger = get_logger('mirofish.llm_client') class LLMClient: @@ -36,7 +40,7 @@ class LLMClient: self, messages: List[Dict[str, str]], temperature: float = 0.7, - max_tokens: int = 4096, + max_tokens: Optional[int] = 4096, response_format: Optional[Dict] = None ) -> str: """ @@ -55,23 +59,27 @@ class LLMClient: "model": self.model, "messages": messages, "temperature": temperature, - "max_tokens": max_tokens, } + + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens if response_format: kwargs["response_format"] = response_format response = self.client.chat.completions.create(**kwargs) - content = response.choices[0].message.content - # 部分模型(如MiniMax M2.5)会在content中包含思考内容,需要移除 - content = re.sub(r'[\s\S]*?', '', content).strip() - return content - + content = response.choices[0].message.content or "" + return self._clean_response_text(content) + def chat_json( self, messages: List[Dict[str, str]], temperature: float = 0.3, - max_tokens: int = 4096 + max_tokens: Optional[int] = 4096, + max_attempts: int = 1, + temperature_step: float = 0.0, + fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None, + retry_delay_seconds: float = 0.0 ) -> Dict[str, Any]: """ 发送聊天请求并返回JSON @@ -84,20 +92,108 @@ class LLMClient: Returns: 解析后的JSON对象 """ - response = self.chat( - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"} - ) - # 清理markdown代码块标记 - cleaned_response = response.strip() - cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE) - cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response) - cleaned_response = cleaned_response.strip() + last_error: Optional[Exception] = None + last_response = "" + + for attempt in range(max_attempts): + current_temperature = max(0.0, temperature - (attempt * temperature_step)) + + try: + kwargs = { + "model": self.model, + "messages": messages, + "temperature": current_temperature, + "response_format": {"type": "json_object"} + } + + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + + response = self.client.chat.completions.create(**kwargs) + raw_content = response.choices[0].message.content or "" + finish_reason = response.choices[0].finish_reason + + cleaned_response = self._clean_response_text(raw_content) + if finish_reason == 'length': + logger.warning(f"LLM输出被截断 (attempt {attempt + 1})") + cleaned_response = self._fix_truncated_json(cleaned_response) + + last_response = cleaned_response + + try: + return self._parse_json_response(cleaned_response) + except json.JSONDecodeError as parse_error: + logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}") + + fixed = self._try_fix_json(cleaned_response) + if fixed is not None: + return fixed + + if fallback_parser is not None: + fallback_result = fallback_parser(cleaned_response) + if fallback_result is not None: + return fallback_result + + last_error = parse_error + + except Exception as exc: + logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}") + last_error = exc + + if attempt < max_attempts - 1 and retry_delay_seconds > 0: + time.sleep(retry_delay_seconds * (attempt + 1)) + + raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error + + def _clean_response_text(self, content: str) -> str: + """清理模型响应中的思考内容和Markdown包裹。""" + cleaned = re.sub(r'[\s\S]*?', '', content).strip() + cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE) + cleaned = re.sub(r'\n?```\s*$', '', cleaned) + return cleaned.strip() + + def _parse_json_response(self, content: str) -> Dict[str, Any]: + return json.loads(content) + + def _fix_truncated_json(self, content: str) -> str: + """修复被截断的JSON内容。""" + content = content.strip() + + # If the number of unescaped quotes is odd we are inside an open string. + unescaped_quote_count = len(re.findall(r'(? Optional[Dict[str, Any]]: + """尝试从近似JSON内容中恢复结构化对象。""" + content = self._fix_truncated_json(content) + json_match = re.search(r'\{[\s\S]*\}', content) + if not json_match: + return None + + json_str = json_match.group() + + def fix_string_newlines(match: re.Match[str]) -> str: + value = match.group(0) + value = value.replace('\n', ' ').replace('\r', ' ') + value = re.sub(r'\s+', ' ', value) + return value + + json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) try: - return json.loads(cleaned_response) + return json.loads(json_str) except json.JSONDecodeError: - raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") - + json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) + json_str = re.sub(r'\s+', ' ', json_str) + try: + return json.loads(json_str) + except json.JSONDecodeError: + return None diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/utils/__init__.py b/backend/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/utils/test_llm_client.py b/backend/tests/utils/test_llm_client.py new file mode 100644 index 00000000..a59b8943 --- /dev/null +++ b/backend/tests/utils/test_llm_client.py @@ -0,0 +1,299 @@ +""" +Tests for LLMClient utility — no real API calls, all OpenAI interactions mocked. +""" + +import json +import pytest +from unittest.mock import MagicMock, patch, call + +from app.utils.llm_client import LLMClient + + +# --------------------------------------------------------------------------- +# Helpers to build fake OpenAI response objects +# --------------------------------------------------------------------------- + +def _make_response(content: str, finish_reason: str = "stop"): + choice = MagicMock() + choice.message.content = content + choice.finish_reason = finish_reason + resp = MagicMock() + resp.choices = [choice] + return resp + + +def _make_client(responses): + """ + Return a patched LLMClient whose underlying OpenAI client returns + *responses* in order on successive .chat.completions.create() calls. + """ + with patch("app.utils.llm_client.OpenAI") as MockOpenAI: + mock_openai_instance = MagicMock() + mock_openai_instance.chat.completions.create.side_effect = responses + MockOpenAI.return_value = mock_openai_instance + + client = LLMClient(api_key="test-key", base_url="http://localhost", model="test-model") + # Expose the underlying mock for assertions + client._mock_create = mock_openai_instance.chat.completions.create + return client + + +# --------------------------------------------------------------------------- +# _clean_response_text +# --------------------------------------------------------------------------- + +class TestCleanResponseText: + + def setup_method(self): + with patch("app.utils.llm_client.OpenAI"): + self.client = LLMClient(api_key="k", base_url="u", model="m") + + def test_passthrough_plain_json(self): + raw = '{"a": 1}' + assert self.client._clean_response_text(raw) == '{"a": 1}' + + def test_strips_think_tags(self): + raw = 'internal reasoning{"a": 1}' + assert self.client._clean_response_text(raw) == '{"a": 1}' + + def test_strips_multiline_think_tags(self): + raw = '\nline1\nline2\n\n{"b": 2}' + assert self.client._clean_response_text(raw) == '{"b": 2}' + + def test_strips_json_markdown_fence(self): + raw = '```json\n{"c": 3}\n```' + assert self.client._clean_response_text(raw) == '{"c": 3}' + + def test_strips_plain_markdown_fence(self): + raw = '```\n{"d": 4}\n```' + assert self.client._clean_response_text(raw) == '{"d": 4}' + + def test_strips_think_and_fence_combined(self): + raw = 'reasoning\n```json\n{"e": 5}\n```' + assert self.client._clean_response_text(raw) == '{"e": 5}' + + def test_empty_string(self): + assert self.client._clean_response_text("") == "" + + def test_no_fence_no_think(self): + raw = ' {"f": 6} ' + assert self.client._clean_response_text(raw) == '{"f": 6}' + + +# --------------------------------------------------------------------------- +# _fix_truncated_json +# --------------------------------------------------------------------------- + +class TestFixTruncatedJson: + + def setup_method(self): + with patch("app.utils.llm_client.OpenAI"): + self.client = LLMClient(api_key="k", base_url="u", model="m") + + def test_closes_one_brace(self): + truncated = '{"key": "val' + fixed = self.client._fix_truncated_json(truncated) + result = json.loads(fixed) + assert result["key"] == "val" + + def test_closes_nested_braces(self): + truncated = '{"outer": {"inner": "x"' + fixed = self.client._fix_truncated_json(truncated) + result = json.loads(fixed) + assert result["outer"]["inner"] == "x" + + def test_closes_open_array(self): + truncated = '{"list": [1, 2' + fixed = self.client._fix_truncated_json(truncated) + result = json.loads(fixed) + assert result["list"] == [1, 2] + + def test_closes_array_and_brace(self): + truncated = '{"a": [1, 2, 3' + fixed = self.client._fix_truncated_json(truncated) + result = json.loads(fixed) + assert result["a"] == [1, 2, 3] + + def test_already_valid_unchanged(self): + valid = '{"x": 1}' + fixed = self.client._fix_truncated_json(valid) + assert json.loads(fixed) == {"x": 1} + + def test_trailing_dangling_value_gets_quote(self): + # ends mid-string without closing quote + truncated = '{"k": "incomplete' + fixed = self.client._fix_truncated_json(truncated) + # Should be parseable after repair + result = json.loads(fixed) + assert "k" in result + + +# --------------------------------------------------------------------------- +# _try_fix_json +# --------------------------------------------------------------------------- + +class TestTryFixJson: + + def setup_method(self): + with patch("app.utils.llm_client.OpenAI"): + self.client = LLMClient(api_key="k", base_url="u", model="m") + + def test_returns_valid_json(self): + content = '{"name": "Alice", "age": 30}' + result = self.client._try_fix_json(content) + assert result == {"name": "Alice", "age": 30} + + def test_extracts_json_from_surrounding_text(self): + content = 'Here is the result: {"score": 42} end.' + result = self.client._try_fix_json(content) + assert result is not None + assert result["score"] == 42 + + def test_fixes_newlines_inside_string_values(self): + # Literal newline inside a JSON string value is invalid JSON + content = '{"desc": "line one\nline two"}' + result = self.client._try_fix_json(content) + assert result is not None + assert "desc" in result + + def test_returns_none_for_no_json_object(self): + assert self.client._try_fix_json("no json here at all") is None + + def test_returns_none_for_empty_string(self): + assert self.client._try_fix_json("") is None + + def test_recovers_truncated_object(self): + # Missing closing brace + content = '{"city": "Beijing", "pop": 21' + result = self.client._try_fix_json(content) + assert result is not None + assert result["city"] == "Beijing" + + +# --------------------------------------------------------------------------- +# chat_json — retry behavior +# --------------------------------------------------------------------------- + +class TestChatJsonRetry: + + # --- succeeds on first attempt --- + + def test_success_first_attempt(self): + payload = {"status": "ok"} + client = _make_client([_make_response(json.dumps(payload))]) + result = client.chat_json([{"role": "user", "content": "hi"}]) + assert result == payload + assert client._mock_create.call_count == 1 + + # --- temperature backoff --- + + def test_temperature_decreases_across_retries(self): + bad = _make_response("not json at all {{{{") + good = _make_response('{"ok": true}') + client = _make_client([bad, bad, good]) + + result = client.chat_json( + [{"role": "user", "content": "hi"}], + temperature=0.9, + temperature_step=0.3, + max_attempts=3, + ) + assert result == {"ok": True} + + calls = client._mock_create.call_args_list + assert calls[0].kwargs["temperature"] == pytest.approx(0.9) + assert calls[1].kwargs["temperature"] == pytest.approx(0.6) + assert calls[2].kwargs["temperature"] == pytest.approx(0.3) + + def test_temperature_never_goes_below_zero(self): + good = _make_response('{"x": 1}') + client = _make_client([good]) + client.chat_json( + [{"role": "user", "content": "hi"}], + temperature=0.1, + temperature_step=0.5, + max_attempts=1, + ) + calls = client._mock_create.call_args_list + assert calls[0].kwargs["temperature"] == pytest.approx(0.1) + + # --- fallback_parser --- + + def test_fallback_parser_called_on_bad_json(self): + bad = _make_response("this is not json") + client = _make_client([bad]) + + fallback = MagicMock(return_value={"rescued": True}) + result = client.chat_json( + [{"role": "user", "content": "q"}], + max_attempts=1, + fallback_parser=fallback, + ) + assert result == {"rescued": True} + fallback.assert_called_once() + + def test_fallback_parser_returning_none_does_not_short_circuit(self): + bad = _make_response("still not json ][") + good = _make_response('{"second": "attempt"}') + client = _make_client([bad, good]) + + fallback = MagicMock(return_value=None) + result = client.chat_json( + [{"role": "user", "content": "q"}], + max_attempts=2, + fallback_parser=fallback, + ) + assert result == {"second": "attempt"} + + # --- raises ValueError after all attempts fail --- + + def test_raises_after_all_attempts_fail(self): + bad = _make_response("invalid {{{ json") + client = _make_client([bad, bad, bad]) + + with pytest.raises(ValueError, match="LLM返回的JSON格式无效"): + client.chat_json( + [{"role": "user", "content": "q"}], + max_attempts=3, + ) + assert client._mock_create.call_count == 3 + + def test_raises_after_single_attempt(self): + bad = _make_response("nope") + client = _make_client([bad]) + + with pytest.raises(ValueError): + client.chat_json([{"role": "user", "content": "q"}], max_attempts=1) + + # --- finish_reason == 'length' triggers truncation repair --- + + def test_truncated_output_is_repaired(self): + truncated_json = '{"items": [1, 2, 3' + client = _make_client([_make_response(truncated_json, finish_reason="length")]) + result = client.chat_json([{"role": "user", "content": "q"}]) + assert result["items"] == [1, 2, 3] + + # --- API exception counts as a failed attempt --- + + def test_api_exception_retried(self): + from openai import APIError + exc = Exception("network failure") + good = _make_response('{"recovered": true}') + client = _make_client([exc, good]) + + result = client.chat_json( + [{"role": "user", "content": "q"}], + max_attempts=2, + ) + assert result == {"recovered": True} + assert client._mock_create.call_count == 2 + + def test_api_exception_all_attempts_raises_value_error(self): + exc = Exception("always fails") + client = _make_client([exc, exc]) + + with pytest.raises(ValueError): + client.chat_json( + [{"role": "user", "content": "q"}], + max_attempts=2, + )