Merge 8fe735c1c9 into 96096ea0ff
This commit is contained in:
commit
dc68ec6a76
|
|
@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
|
@ -193,9 +193,10 @@ class OasisProfileGenerator:
|
|||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
self.llm_client = LLMClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
model=self.model_name
|
||||
)
|
||||
|
||||
# Zep客户端用于检索丰富上下文
|
||||
|
|
@ -521,64 +522,36 @@ class OasisProfileGenerator:
|
|||
entity_name, entity_type, entity_summary, entity_attributes, context
|
||||
)
|
||||
|
||||
# 尝试多次生成,直到成功或达到最大重试次数
|
||||
max_attempts = 3
|
||||
last_error = None
|
||||
try:
|
||||
result = self.llm_client.chat_json(
|
||||
messages=[
|
||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
max_attempts=3,
|
||||
temperature_step=0.1,
|
||||
fallback_parser=lambda content: self._try_fix_json(
|
||||
content, entity_name, entity_type, entity_summary
|
||||
),
|
||||
retry_delay_seconds=1.0
|
||||
)
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
||||
# 不设置max_tokens,让LLM自由发挥
|
||||
)
|
||||
if "bio" not in result or not result["bio"]:
|
||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||
if "persona" not in result or not result["persona"]:
|
||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if result.get("_fixed"):
|
||||
del result["_fixed"]
|
||||
|
||||
# 检查是否被截断(finish_reason不是'stop')
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
result = json.loads(content)
|
||||
|
||||
# 验证必需字段
|
||||
if "bio" not in result or not result["bio"]:
|
||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||
if "persona" not in result or not result["persona"]:
|
||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as je:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
|
||||
|
||||
# 尝试修复JSON
|
||||
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
|
||||
if result.get("_fixed"):
|
||||
del result["_fixed"]
|
||||
return result
|
||||
|
||||
last_error = je
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
last_error = e
|
||||
import time
|
||||
time.sleep(1 * (attempt + 1)) # 指数退避
|
||||
|
||||
logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM生成人设失败(3次尝试): {e}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON(输出被max_tokens限制截断)"""
|
||||
|
|
@ -1202,4 +1175,3 @@ class OasisProfileGenerator:
|
|||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
|
|||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_language_instruction, t
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
|
@ -235,9 +234,10 @@ class SimulationConfigGenerator:
|
|||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
self.llm_client = LLMClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
model=self.model_name
|
||||
)
|
||||
|
||||
def generate_config(
|
||||
|
|
@ -433,78 +433,23 @@ class SimulationConfigGenerator:
|
|||
|
||||
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
|
||||
"""带重试的LLM调用,包含JSON修复逻辑"""
|
||||
import re
|
||||
|
||||
max_attempts = 3
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
||||
# 不设置max_tokens,让LLM自由发挥
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
# 检查是否被截断
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
|
||||
# 尝试修复JSON
|
||||
fixed = self._try_fix_config_json(content)
|
||||
if fixed:
|
||||
return fixed
|
||||
|
||||
last_error = e
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
last_error = e
|
||||
import time
|
||||
time.sleep(2 * (attempt + 1))
|
||||
|
||||
raise last_error or Exception("LLM调用失败")
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON"""
|
||||
content = content.strip()
|
||||
|
||||
# 计算未闭合的括号
|
||||
open_braces = content.count('{') - content.count('}')
|
||||
open_brackets = content.count('[') - content.count(']')
|
||||
|
||||
# 检查是否有未闭合的字符串
|
||||
if content and content[-1] not in '",}]':
|
||||
content += '"'
|
||||
|
||||
# 闭合括号
|
||||
content += ']' * open_brackets
|
||||
content += '}' * open_braces
|
||||
|
||||
return content
|
||||
return self.llm_client.chat_json(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
max_attempts=3,
|
||||
temperature_step=0.1,
|
||||
fallback_parser=self._try_fix_config_json,
|
||||
retry_delay_seconds=2.0
|
||||
)
|
||||
|
||||
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||
"""尝试修复配置JSON"""
|
||||
import re
|
||||
|
||||
# 修复被截断的情况
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 提取JSON部分
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if json_match:
|
||||
|
|
@ -988,4 +933,3 @@ class SimulationConfigGenerator:
|
|||
"influence_weight": 1.0
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,15 @@ LLM客户端封装
|
|||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.llm_client')
|
||||
|
||||
|
||||
class LLMClient:
|
||||
|
|
@ -36,7 +40,7 @@ class LLMClient:
|
|||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
max_tokens: Optional[int] = 4096,
|
||||
response_format: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -55,23 +59,27 @@ class LLMClient:
|
|||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
content = response.choices[0].message.content
|
||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
return content
|
||||
content = response.choices[0].message.content or ""
|
||||
return self._clean_response_text(content)
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 4096
|
||||
max_tokens: Optional[int] = 4096,
|
||||
max_attempts: int = 1,
|
||||
temperature_step: float = 0.0,
|
||||
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
|
||||
retry_delay_seconds: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送聊天请求并返回JSON
|
||||
|
|
@ -84,20 +92,108 @@ class LLMClient:
|
|||
Returns:
|
||||
解析后的JSON对象
|
||||
"""
|
||||
response = self.chat(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
# 清理markdown代码块标记
|
||||
cleaned_response = response.strip()
|
||||
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
||||
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
||||
cleaned_response = cleaned_response.strip()
|
||||
last_error: Optional[Exception] = None
|
||||
last_response = ""
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
current_temperature = max(0.0, temperature - (attempt * temperature_step))
|
||||
|
||||
try:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": current_temperature,
|
||||
"response_format": {"type": "json_object"}
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
cleaned_response = self._clean_response_text(raw_content)
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
|
||||
cleaned_response = self._fix_truncated_json(cleaned_response)
|
||||
|
||||
last_response = cleaned_response
|
||||
|
||||
try:
|
||||
return self._parse_json_response(cleaned_response)
|
||||
except json.JSONDecodeError as parse_error:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
|
||||
|
||||
fixed = self._try_fix_json(cleaned_response)
|
||||
if fixed is not None:
|
||||
return fixed
|
||||
|
||||
if fallback_parser is not None:
|
||||
fallback_result = fallback_parser(cleaned_response)
|
||||
if fallback_result is not None:
|
||||
return fallback_result
|
||||
|
||||
last_error = parse_error
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
|
||||
last_error = exc
|
||||
|
||||
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
|
||||
time.sleep(retry_delay_seconds * (attempt + 1))
|
||||
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
|
||||
|
||||
def _clean_response_text(self, content: str) -> str:
|
||||
"""清理模型响应中的思考内容和Markdown包裹。"""
|
||||
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'\n?```\s*$', '', cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def _parse_json_response(self, content: str) -> Dict[str, Any]:
|
||||
return json.loads(content)
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON内容。"""
|
||||
content = content.strip()
|
||||
|
||||
# If the number of unescaped quotes is odd we are inside an open string.
|
||||
unescaped_quote_count = len(re.findall(r'(?<!\\)"', content))
|
||||
if unescaped_quote_count % 2 == 1:
|
||||
content += '"'
|
||||
|
||||
open_braces = content.count('{') - content.count('}')
|
||||
open_brackets = content.count('[') - content.count(']')
|
||||
|
||||
content += ']' * open_brackets
|
||||
content += '}' * open_braces
|
||||
return content
|
||||
|
||||
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||
"""尝试从近似JSON内容中恢复结构化对象。"""
|
||||
content = self._fix_truncated_json(content)
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if not json_match:
|
||||
return None
|
||||
|
||||
json_str = json_match.group()
|
||||
|
||||
def fix_string_newlines(match: re.Match[str]) -> str:
|
||||
value = match.group(0)
|
||||
value = value.replace('\n', ' ').replace('\r', ' ')
|
||||
value = re.sub(r'\s+', ' ', value)
|
||||
return value
|
||||
|
||||
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
|
||||
|
||||
try:
|
||||
return json.loads(cleaned_response)
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
||||
|
||||
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
|
||||
json_str = re.sub(r'\s+', ' ', json_str)
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,299 @@
|
|||
"""
|
||||
Tests for LLMClient utility — no real API calls, all OpenAI interactions mocked.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
from app.utils.llm_client import LLMClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build fake OpenAI response objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_response(content: str, finish_reason: str = "stop"):
|
||||
choice = MagicMock()
|
||||
choice.message.content = content
|
||||
choice.finish_reason = finish_reason
|
||||
resp = MagicMock()
|
||||
resp.choices = [choice]
|
||||
return resp
|
||||
|
||||
|
||||
def _make_client(responses):
|
||||
"""
|
||||
Return a patched LLMClient whose underlying OpenAI client returns
|
||||
*responses* in order on successive .chat.completions.create() calls.
|
||||
"""
|
||||
with patch("app.utils.llm_client.OpenAI") as MockOpenAI:
|
||||
mock_openai_instance = MagicMock()
|
||||
mock_openai_instance.chat.completions.create.side_effect = responses
|
||||
MockOpenAI.return_value = mock_openai_instance
|
||||
|
||||
client = LLMClient(api_key="test-key", base_url="http://localhost", model="test-model")
|
||||
# Expose the underlying mock for assertions
|
||||
client._mock_create = mock_openai_instance.chat.completions.create
|
||||
return client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _clean_response_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanResponseText:
|
||||
|
||||
def setup_method(self):
|
||||
with patch("app.utils.llm_client.OpenAI"):
|
||||
self.client = LLMClient(api_key="k", base_url="u", model="m")
|
||||
|
||||
def test_passthrough_plain_json(self):
|
||||
raw = '{"a": 1}'
|
||||
assert self.client._clean_response_text(raw) == '{"a": 1}'
|
||||
|
||||
def test_strips_think_tags(self):
|
||||
raw = '<think>internal reasoning</think>{"a": 1}'
|
||||
assert self.client._clean_response_text(raw) == '{"a": 1}'
|
||||
|
||||
def test_strips_multiline_think_tags(self):
|
||||
raw = '<think>\nline1\nline2\n</think>\n{"b": 2}'
|
||||
assert self.client._clean_response_text(raw) == '{"b": 2}'
|
||||
|
||||
def test_strips_json_markdown_fence(self):
|
||||
raw = '```json\n{"c": 3}\n```'
|
||||
assert self.client._clean_response_text(raw) == '{"c": 3}'
|
||||
|
||||
def test_strips_plain_markdown_fence(self):
|
||||
raw = '```\n{"d": 4}\n```'
|
||||
assert self.client._clean_response_text(raw) == '{"d": 4}'
|
||||
|
||||
def test_strips_think_and_fence_combined(self):
|
||||
raw = '<think>reasoning</think>\n```json\n{"e": 5}\n```'
|
||||
assert self.client._clean_response_text(raw) == '{"e": 5}'
|
||||
|
||||
def test_empty_string(self):
|
||||
assert self.client._clean_response_text("") == ""
|
||||
|
||||
def test_no_fence_no_think(self):
|
||||
raw = ' {"f": 6} '
|
||||
assert self.client._clean_response_text(raw) == '{"f": 6}'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _fix_truncated_json
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFixTruncatedJson:
|
||||
|
||||
def setup_method(self):
|
||||
with patch("app.utils.llm_client.OpenAI"):
|
||||
self.client = LLMClient(api_key="k", base_url="u", model="m")
|
||||
|
||||
def test_closes_one_brace(self):
|
||||
truncated = '{"key": "val'
|
||||
fixed = self.client._fix_truncated_json(truncated)
|
||||
result = json.loads(fixed)
|
||||
assert result["key"] == "val"
|
||||
|
||||
def test_closes_nested_braces(self):
|
||||
truncated = '{"outer": {"inner": "x"'
|
||||
fixed = self.client._fix_truncated_json(truncated)
|
||||
result = json.loads(fixed)
|
||||
assert result["outer"]["inner"] == "x"
|
||||
|
||||
def test_closes_open_array(self):
|
||||
truncated = '{"list": [1, 2'
|
||||
fixed = self.client._fix_truncated_json(truncated)
|
||||
result = json.loads(fixed)
|
||||
assert result["list"] == [1, 2]
|
||||
|
||||
def test_closes_array_and_brace(self):
|
||||
truncated = '{"a": [1, 2, 3'
|
||||
fixed = self.client._fix_truncated_json(truncated)
|
||||
result = json.loads(fixed)
|
||||
assert result["a"] == [1, 2, 3]
|
||||
|
||||
def test_already_valid_unchanged(self):
|
||||
valid = '{"x": 1}'
|
||||
fixed = self.client._fix_truncated_json(valid)
|
||||
assert json.loads(fixed) == {"x": 1}
|
||||
|
||||
def test_trailing_dangling_value_gets_quote(self):
|
||||
# ends mid-string without closing quote
|
||||
truncated = '{"k": "incomplete'
|
||||
fixed = self.client._fix_truncated_json(truncated)
|
||||
# Should be parseable after repair
|
||||
result = json.loads(fixed)
|
||||
assert "k" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_fix_json
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTryFixJson:
|
||||
|
||||
def setup_method(self):
|
||||
with patch("app.utils.llm_client.OpenAI"):
|
||||
self.client = LLMClient(api_key="k", base_url="u", model="m")
|
||||
|
||||
def test_returns_valid_json(self):
|
||||
content = '{"name": "Alice", "age": 30}'
|
||||
result = self.client._try_fix_json(content)
|
||||
assert result == {"name": "Alice", "age": 30}
|
||||
|
||||
def test_extracts_json_from_surrounding_text(self):
|
||||
content = 'Here is the result: {"score": 42} end.'
|
||||
result = self.client._try_fix_json(content)
|
||||
assert result is not None
|
||||
assert result["score"] == 42
|
||||
|
||||
def test_fixes_newlines_inside_string_values(self):
|
||||
# Literal newline inside a JSON string value is invalid JSON
|
||||
content = '{"desc": "line one\nline two"}'
|
||||
result = self.client._try_fix_json(content)
|
||||
assert result is not None
|
||||
assert "desc" in result
|
||||
|
||||
def test_returns_none_for_no_json_object(self):
|
||||
assert self.client._try_fix_json("no json here at all") is None
|
||||
|
||||
def test_returns_none_for_empty_string(self):
|
||||
assert self.client._try_fix_json("") is None
|
||||
|
||||
def test_recovers_truncated_object(self):
|
||||
# Missing closing brace
|
||||
content = '{"city": "Beijing", "pop": 21'
|
||||
result = self.client._try_fix_json(content)
|
||||
assert result is not None
|
||||
assert result["city"] == "Beijing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat_json — retry behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestChatJsonRetry:
|
||||
|
||||
# --- succeeds on first attempt ---
|
||||
|
||||
def test_success_first_attempt(self):
|
||||
payload = {"status": "ok"}
|
||||
client = _make_client([_make_response(json.dumps(payload))])
|
||||
result = client.chat_json([{"role": "user", "content": "hi"}])
|
||||
assert result == payload
|
||||
assert client._mock_create.call_count == 1
|
||||
|
||||
# --- temperature backoff ---
|
||||
|
||||
def test_temperature_decreases_across_retries(self):
|
||||
bad = _make_response("not json at all {{{{")
|
||||
good = _make_response('{"ok": true}')
|
||||
client = _make_client([bad, bad, good])
|
||||
|
||||
result = client.chat_json(
|
||||
[{"role": "user", "content": "hi"}],
|
||||
temperature=0.9,
|
||||
temperature_step=0.3,
|
||||
max_attempts=3,
|
||||
)
|
||||
assert result == {"ok": True}
|
||||
|
||||
calls = client._mock_create.call_args_list
|
||||
assert calls[0].kwargs["temperature"] == pytest.approx(0.9)
|
||||
assert calls[1].kwargs["temperature"] == pytest.approx(0.6)
|
||||
assert calls[2].kwargs["temperature"] == pytest.approx(0.3)
|
||||
|
||||
def test_temperature_never_goes_below_zero(self):
|
||||
good = _make_response('{"x": 1}')
|
||||
client = _make_client([good])
|
||||
client.chat_json(
|
||||
[{"role": "user", "content": "hi"}],
|
||||
temperature=0.1,
|
||||
temperature_step=0.5,
|
||||
max_attempts=1,
|
||||
)
|
||||
calls = client._mock_create.call_args_list
|
||||
assert calls[0].kwargs["temperature"] == pytest.approx(0.1)
|
||||
|
||||
# --- fallback_parser ---
|
||||
|
||||
def test_fallback_parser_called_on_bad_json(self):
|
||||
bad = _make_response("this is not json")
|
||||
client = _make_client([bad])
|
||||
|
||||
fallback = MagicMock(return_value={"rescued": True})
|
||||
result = client.chat_json(
|
||||
[{"role": "user", "content": "q"}],
|
||||
max_attempts=1,
|
||||
fallback_parser=fallback,
|
||||
)
|
||||
assert result == {"rescued": True}
|
||||
fallback.assert_called_once()
|
||||
|
||||
def test_fallback_parser_returning_none_does_not_short_circuit(self):
|
||||
bad = _make_response("still not json ][")
|
||||
good = _make_response('{"second": "attempt"}')
|
||||
client = _make_client([bad, good])
|
||||
|
||||
fallback = MagicMock(return_value=None)
|
||||
result = client.chat_json(
|
||||
[{"role": "user", "content": "q"}],
|
||||
max_attempts=2,
|
||||
fallback_parser=fallback,
|
||||
)
|
||||
assert result == {"second": "attempt"}
|
||||
|
||||
# --- raises ValueError after all attempts fail ---
|
||||
|
||||
def test_raises_after_all_attempts_fail(self):
|
||||
bad = _make_response("invalid {{{ json")
|
||||
client = _make_client([bad, bad, bad])
|
||||
|
||||
with pytest.raises(ValueError, match="LLM返回的JSON格式无效"):
|
||||
client.chat_json(
|
||||
[{"role": "user", "content": "q"}],
|
||||
max_attempts=3,
|
||||
)
|
||||
assert client._mock_create.call_count == 3
|
||||
|
||||
def test_raises_after_single_attempt(self):
|
||||
bad = _make_response("nope")
|
||||
client = _make_client([bad])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
client.chat_json([{"role": "user", "content": "q"}], max_attempts=1)
|
||||
|
||||
# --- finish_reason == 'length' triggers truncation repair ---
|
||||
|
||||
def test_truncated_output_is_repaired(self):
|
||||
truncated_json = '{"items": [1, 2, 3'
|
||||
client = _make_client([_make_response(truncated_json, finish_reason="length")])
|
||||
result = client.chat_json([{"role": "user", "content": "q"}])
|
||||
assert result["items"] == [1, 2, 3]
|
||||
|
||||
# --- API exception counts as a failed attempt ---
|
||||
|
||||
def test_api_exception_retried(self):
|
||||
from openai import APIError
|
||||
exc = Exception("network failure")
|
||||
good = _make_response('{"recovered": true}')
|
||||
client = _make_client([exc, good])
|
||||
|
||||
result = client.chat_json(
|
||||
[{"role": "user", "content": "q"}],
|
||||
max_attempts=2,
|
||||
)
|
||||
assert result == {"recovered": True}
|
||||
assert client._mock_create.call_count == 2
|
||||
|
||||
def test_api_exception_all_attempts_raises_value_error(self):
|
||||
exc = Exception("always fails")
|
||||
client = _make_client([exc, exc])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
client.chat_json(
|
||||
[{"role": "user", "content": "q"}],
|
||||
max_attempts=2,
|
||||
)
|
||||
Loading…
Reference in New Issue