Merge 00a2150365 into 96096ea0ff
This commit is contained in:
commit
5875f8ff0b
|
|
@ -11,6 +11,7 @@ from flask import request, jsonify, send_file
|
|||
from . import report_bp
|
||||
from ..config import Config
|
||||
from ..services.report_agent import ReportAgent, ReportManager, ReportStatus
|
||||
from ..services.signal_extractor import SignalExtractor
|
||||
from ..services.simulation_manager import SimulationManager
|
||||
from ..models.project import ProjectManager
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
|
|
@ -930,6 +931,89 @@ def stream_console_log(report_id: str):
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 预测信号接口 ==============
|
||||
|
||||
@report_bp.route('/<report_id>/signal', methods=['POST'])
def extract_signal(report_id: str):
    """
    Extract a structured prediction signal (miro_signal) from a completed report.

    Runs a single LLM extraction pass over the report's markdown content and
    returns a normalised probability signal that external prediction-market
    pipelines can consume directly.

    Status codes:
        404 - no report is found for ``report_id``
        400 - the report is not completed yet, or its markdown content is empty
        422 - the extraction raised ``ValueError``
        500 - any other exception (the traceback is included in the payload)

    Returns:
        {
            "success": true,
            "data": {
                "signal_id": "uuid",
                "schema_version": "1.1",
                "report_id": "report_xxxx",
                "simulation_id": "sim_xxxx",
                "generated_at": "2026-...",
                "thesis": {
                    "p_yes": 0.73,
                    "confidence": "high",
                    "action": "buy_yes",
                    "regime": "consensus_forming",
                    "summary": "...",
                    "drivers": ["...", "..."],
                    "invalidators": ["...", "..."]
                }
            }
        }
    """

    def _failure(message, status, **extra):
        # Every error response shares the {"success": False, "error": ...} envelope.
        payload = {"success": False, "error": message}
        payload.update(extra)
        return jsonify(payload), status

    try:
        report = ReportManager.get_report(report_id)

        # Guard clauses: reject anything a signal cannot be derived from.
        if not report:
            return _failure(f"报告不存在: {report_id}", 404)

        if report.status != ReportStatus.COMPLETED:
            return _failure(
                f"报告尚未完成 (status={report.status.value}),无法提取信号", 400
            )

        if not report.markdown_content:
            return _failure("报告内容为空,无法提取信号", 400)

        signal = SignalExtractor().extract(
            report_id=report_id,
            simulation_id=report.simulation_id,
            markdown_content=report.markdown_content,
            simulation_requirement=report.simulation_requirement,
        )

        logger.info(f"信号提取完成: report={report_id} p_yes={signal.p_yes} action={signal.action}")

        success_payload = {"success": True, "data": signal.to_dict()}
        return jsonify(success_payload)

    except ValueError as e:
        # ValueError is reported as an unprocessable report rather than a server fault.
        logger.error(f"信号提取失败 (LLM): {str(e)}")
        return _failure(str(e), 422)

    except Exception as e:
        logger.error(f"信号提取失败: {str(e)}")
        return _failure(str(e), 500, traceback=traceback.format_exc())
|
||||
|
||||
|
||||
# ============== 工具调用接口(供调试使用)==============
|
||||
|
||||
@report_bp.route('/tools/search', methods=['POST'])
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
|
@ -193,9 +193,10 @@ class OasisProfileGenerator:
|
|||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
self.llm_client = LLMClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
model=self.model_name
|
||||
)
|
||||
|
||||
# Zep客户端用于检索丰富上下文
|
||||
|
|
@ -521,64 +522,36 @@ class OasisProfileGenerator:
|
|||
entity_name, entity_type, entity_summary, entity_attributes, context
|
||||
)
|
||||
|
||||
# 尝试多次生成,直到成功或达到最大重试次数
|
||||
max_attempts = 3
|
||||
last_error = None
|
||||
try:
|
||||
result = self.llm_client.chat_json(
|
||||
messages=[
|
||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
max_attempts=3,
|
||||
temperature_step=0.1,
|
||||
fallback_parser=lambda content: self._try_fix_json(
|
||||
content, entity_name, entity_type, entity_summary
|
||||
),
|
||||
retry_delay_seconds=1.0
|
||||
)
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": self._get_system_prompt(is_individual)},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
||||
# 不设置max_tokens,让LLM自由发挥
|
||||
)
|
||||
if "bio" not in result or not result["bio"]:
|
||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||
if "persona" not in result or not result["persona"]:
|
||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if result.get("_fixed"):
|
||||
del result["_fixed"]
|
||||
|
||||
# 检查是否被截断(finish_reason不是'stop')
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
result = json.loads(content)
|
||||
|
||||
# 验证必需字段
|
||||
if "bio" not in result or not result["bio"]:
|
||||
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
|
||||
if "persona" not in result or not result["persona"]:
|
||||
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。"
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as je:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
|
||||
|
||||
# 尝试修复JSON
|
||||
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
|
||||
if result.get("_fixed"):
|
||||
del result["_fixed"]
|
||||
return result
|
||||
|
||||
last_error = je
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
last_error = e
|
||||
import time
|
||||
time.sleep(1 * (attempt + 1)) # 指数退避
|
||||
|
||||
logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM生成人设失败(3次尝试): {e}, 使用规则生成")
|
||||
return self._generate_profile_rule_based(
|
||||
entity_name, entity_type, entity_summary, entity_attributes
|
||||
)
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON(输出被max_tokens限制截断)"""
|
||||
|
|
@ -1202,4 +1175,3 @@ class OasisProfileGenerator:
|
|||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
Miro Signal Extractor
|
||||
Distils a completed simulation report into a canonical machine-readable
|
||||
probability signal that external pipelines (e.g. prediction-market bots)
|
||||
can consume directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.signal_extractor')
|
||||
|
||||
SCHEMA_VERSION = "1.1"
|
||||
|
||||
_SYSTEM_PROMPT = """\
|
||||
You are a structured-signal extractor. You will be given the full markdown text
|
||||
of a social-simulation analysis report and the original simulation requirement
|
||||
(the prediction question). Your job is to distil the report into a concise,
|
||||
machine-readable probability signal.
|
||||
|
||||
Rules:
|
||||
- p_yes must be a float strictly between 0.0 and 1.0 (never exactly 0 or 1).
|
||||
- confidence must be one of: "high", "medium", "low".
|
||||
- action must be one of: "buy_yes", "buy_no", "hold".
|
||||
Use "buy_yes" when p_yes > 0.55, "buy_no" when p_yes < 0.45, else "hold".
|
||||
- regime describes the dominant social dynamic observed in the simulation,
|
||||
e.g. "consensus_forming", "contested", "uncertain", "momentum_shift",
|
||||
"echo_chamber", "fragmented".
|
||||
- summary is one sentence (≤ 30 words).
|
||||
- drivers is a list of 2–4 short strings (key factors supporting the thesis).
|
||||
- invalidators is a list of 2–4 short strings (key risks or counter-factors).
|
||||
- Do not reproduce large sections of the report. Be concise.
|
||||
- Respond ONLY with valid JSON matching the schema below — no prose, no fences.
|
||||
|
||||
Required JSON schema:
|
||||
{
|
||||
"p_yes": <float 0.0–1.0>,
|
||||
"confidence": "high" | "medium" | "low",
|
||||
"action": "buy_yes" | "buy_no" | "hold",
|
||||
"regime": <string>,
|
||||
"summary": <string>,
|
||||
"drivers": [<string>, ...],
|
||||
"invalidators": [<string>, ...]
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
class MiroSignal:
    """Canonical prediction signal extracted from a simulation report.

    The first five fields identify the signal and where it came from; the
    remaining fields form the thesis and are nested under ``"thesis"`` by
    :meth:`to_dict`.
    """

    # Identification / provenance
    signal_id: str
    schema_version: str
    report_id: str
    simulation_id: str
    generated_at: str

    # Core thesis fields
    p_yes: float
    confidence: str   # high | medium | low
    action: str       # buy_yes | buy_no | hold
    regime: str
    summary: str
    drivers: List[str] = field(default_factory=list)
    invalidators: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return the signal as a plain dict with the thesis fields grouped under ``"thesis"``."""
        thesis = {
            "p_yes": self.p_yes,
            "confidence": self.confidence,
            "action": self.action,
            "regime": self.regime,
            "summary": self.summary,
            "drivers": self.drivers,
            "invalidators": self.invalidators,
        }
        payload = {
            "signal_id": self.signal_id,
            "schema_version": self.schema_version,
            "report_id": self.report_id,
            "simulation_id": self.simulation_id,
            "generated_at": self.generated_at,
        }
        payload["thesis"] = thesis
        return payload
|
||||
|
||||
|
||||
class SignalExtractor:
    """Extracts a MiroSignal from a completed report's markdown content.

    The LLM reply is treated as untrusted input: every field is validated and
    normalised before it is placed into the resulting signal.
    """

    _VALID_CONFIDENCE = {"high", "medium", "low"}
    _VALID_ACTIONS = {"buy_yes", "buy_no", "hold"}

    def __init__(self, llm_client: Optional["LLMClient"] = None):
        """Use *llm_client* when given, otherwise create a default ``LLMClient``."""
        self._client = llm_client or LLMClient()

    def extract(
        self,
        report_id: str,
        simulation_id: str,
        markdown_content: str,
        simulation_requirement: str,
    ) -> "MiroSignal":
        """
        Distil *markdown_content* into a MiroSignal.

        Args:
            report_id: The report this signal is derived from.
            simulation_id: Parent simulation ID.
            markdown_content: Full report text (may be long).
            simulation_requirement: The original prediction question / goal.

        Returns:
            MiroSignal with validated fields.

        Raises:
            ValueError: If the LLM fails to produce a valid signal after retries,
                or if the parsed payload is not a JSON object.
        """
        # Trim to avoid token limits while keeping the most analytical content.
        # Reports can exceed 30 k chars; the last third is usually the conclusion.
        body = self._trim_report(markdown_content)

        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {
                "role": "user",
                "content": (
                    f"Simulation requirement (prediction question):\n{simulation_requirement}\n\n"
                    f"Report:\n{body}"
                ),
            },
        ]

        raw = self._client.chat_json(
            messages=messages,
            temperature=0.1,
            max_tokens=512,
            max_attempts=3,
            temperature_step=0.05,
            fallback_parser=self._salvage,
        )

        return self._build_signal(raw, report_id, simulation_id)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _trim_report(content: str, max_chars: int = 12_000) -> str:
        """Keep the tail of the report (conclusions) if it is very long.

        A non-positive *max_chars* keeps none of the report text; only the
        truncation marker is returned for a non-empty report.
        """
        if len(content) <= max_chars:
            return content
        # ``content[-0:]`` is the WHOLE string, so a zero/negative budget must be
        # handled explicitly instead of being passed to the slice.
        tail = content[-max_chars:] if max_chars > 0 else ""
        return "…[report truncated for signal extraction]\n\n" + tail

    @staticmethod
    def _action_for(p_yes: float) -> str:
        """Map a probability to an action using the thresholds stated in the prompt."""
        if p_yes > 0.55:
            return "buy_yes"
        if p_yes < 0.45:
            return "buy_no"
        return "hold"

    @staticmethod
    def _coerce_probability(value: object, default: float = 0.5) -> float:
        """Convert *value* to a probability clamped to [0.01, 0.99].

        Returns *default* when the value cannot be converted or is NaN.
        """
        import math

        try:
            p_yes = float(value)
        except (TypeError, ValueError, OverflowError):
            return default
        # NaN survives float() and defeats min/max clamping (every comparison
        # with NaN is False), which would silently turn it into 0.99.
        if math.isnan(p_yes):
            return default
        return max(0.01, min(0.99, p_yes))

    @staticmethod
    def _coerce_text(value: object, default: str = "") -> str:
        """Return *value* as a stripped string, or *default* for None / blank."""
        if value is None:
            # str(None) would leak the literal text "None" into the signal.
            return default
        return str(value).strip() or default

    @staticmethod
    def _coerce_str_list(value: object) -> List[str]:
        """Normalise a list-like field into a list of strings, dropping falsy items.

        A bare string is treated as a single item (not iterated per character);
        anything that is not a string, list or tuple yields an empty list.
        """
        if isinstance(value, str):
            value = [value]
        elif not isinstance(value, (list, tuple)):
            return []
        return [str(item) for item in value if item]

    def _build_signal(
        self, raw: dict, report_id: str, simulation_id: str
    ) -> "MiroSignal":
        """Validate and normalise the raw LLM dict into a MiroSignal.

        Raises:
            ValueError: If *raw* is not a dict (e.g. the LLM returned a JSON
                array or scalar).
        """
        if not isinstance(raw, dict):
            raise ValueError(
                f"LLM signal payload must be a JSON object, got {type(raw).__name__}"
            )

        p_yes = self._coerce_probability(raw.get("p_yes", 0.5))

        confidence = self._coerce_text(raw.get("confidence"), "medium").lower()
        if confidence not in self._VALID_CONFIDENCE:
            confidence = "medium"

        # action — recompute from p_yes if missing or invalid.
        # NOTE(review): a valid action that contradicts p_yes is kept as-is;
        # confirm whether it should always be derived from p_yes instead.
        action = self._coerce_text(raw.get("action")).lower()
        if action not in self._VALID_ACTIONS:
            action = self._action_for(p_yes)

        return MiroSignal(
            signal_id=str(uuid.uuid4()),
            schema_version=SCHEMA_VERSION,
            report_id=report_id,
            simulation_id=simulation_id,
            generated_at=datetime.now(timezone.utc).isoformat(),
            p_yes=p_yes,
            confidence=confidence,
            action=action,
            regime=self._coerce_text(raw.get("regime"), "uncertain"),
            summary=self._coerce_text(raw.get("summary")),
            drivers=self._coerce_str_list(raw.get("drivers")),
            invalidators=self._coerce_str_list(raw.get("invalidators")),
        )

    @staticmethod
    def _salvage(raw_text: str) -> Optional[dict]:
        """
        Last-resort fallback: build a minimal signal dict from unparseable text.

        A number explicitly labelled ``p_yes`` is preferred; otherwise the first
        probability-looking number is used. Confidence is read from a labelled
        ``confidence`` field when present, else from the whole words
        "high" / "low". Returns None when no probability can be found.
        """
        import re

        # Prefer the labelled value so an unrelated number appearing earlier in
        # the text (e.g. a stray "0") is not mistaken for the probability.
        prob_match = re.search(
            r'"?p_yes"?\s*[:=]\s*"?(0(?:\.\d+)?|1(?:\.0+)?)\b', raw_text, re.IGNORECASE
        ) or re.search(r'\b(0\.\d+|1\.0+|0)\b', raw_text)
        if not prob_match:
            return None

        try:
            p = float(prob_match.group(1))
        except ValueError:
            return None

        text_lower = raw_text.lower()
        labelled = re.search(r'"?confidence"?\s*[:=]\s*"?(high|medium|low)\b', text_lower)
        if labelled:
            confidence = labelled.group(1)
        elif re.search(r'\bhigh\b', text_lower):
            # Whole-word match: "highly" or "below" must not count as a level.
            confidence = "high"
        elif re.search(r'\blow\b', text_lower):
            confidence = "low"
        else:
            confidence = "medium"

        return {
            "p_yes": p,
            "confidence": confidence,
            "action": SignalExtractor._action_for(p),
            "regime": "uncertain",
            "summary": "Signal salvaged from partial LLM output.",
            "drivers": [],
            "invalidators": [],
        }
|
||||
|
|
@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
|
|||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_language_instruction, t
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
|
@ -235,9 +234,10 @@ class SimulationConfigGenerator:
|
|||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
self.client = OpenAI(
|
||||
self.llm_client = LLMClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
model=self.model_name
|
||||
)
|
||||
|
||||
def generate_config(
|
||||
|
|
@ -433,78 +433,23 @@ class SimulationConfigGenerator:
|
|||
|
||||
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
|
||||
"""带重试的LLM调用,包含JSON修复逻辑"""
|
||||
import re
|
||||
|
||||
max_attempts = 3
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
|
||||
# 不设置max_tokens,让LLM自由发挥
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
# 检查是否被截断
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
|
||||
# 尝试修复JSON
|
||||
fixed = self._try_fix_config_json(content)
|
||||
if fixed:
|
||||
return fixed
|
||||
|
||||
last_error = e
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
|
||||
last_error = e
|
||||
import time
|
||||
time.sleep(2 * (attempt + 1))
|
||||
|
||||
raise last_error or Exception("LLM调用失败")
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
|
||||
"""修复被截断的JSON"""
|
||||
content = content.strip()
|
||||
|
||||
# 计算未闭合的括号
|
||||
open_braces = content.count('{') - content.count('}')
|
||||
open_brackets = content.count('[') - content.count(']')
|
||||
|
||||
# 检查是否有未闭合的字符串
|
||||
if content and content[-1] not in '",}]':
|
||||
content += '"'
|
||||
|
||||
# 闭合括号
|
||||
content += ']' * open_brackets
|
||||
content += '}' * open_braces
|
||||
|
||||
return content
|
||||
return self.llm_client.chat_json(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=None,
|
||||
max_attempts=3,
|
||||
temperature_step=0.1,
|
||||
fallback_parser=self._try_fix_config_json,
|
||||
retry_delay_seconds=2.0
|
||||
)
|
||||
|
||||
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||
"""尝试修复配置JSON"""
|
||||
import re
|
||||
|
||||
# 修复被截断的情况
|
||||
content = self._fix_truncated_json(content)
|
||||
|
||||
# 提取JSON部分
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if json_match:
|
||||
|
|
@ -988,4 +933,3 @@ class SimulationConfigGenerator:
|
|||
"influence_weight": 1.0
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,15 @@ LLM客户端封装
|
|||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.llm_client')
|
||||
|
||||
|
||||
class LLMClient:
|
||||
|
|
@ -36,7 +40,7 @@ class LLMClient:
|
|||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
max_tokens: Optional[int] = 4096,
|
||||
response_format: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -55,23 +59,27 @@ class LLMClient:
|
|||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
content = response.choices[0].message.content
|
||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
return content
|
||||
content = response.choices[0].message.content or ""
|
||||
return self._clean_response_text(content)
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 4096
|
||||
max_tokens: Optional[int] = 4096,
|
||||
max_attempts: int = 1,
|
||||
temperature_step: float = 0.0,
|
||||
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
|
||||
retry_delay_seconds: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送聊天请求并返回JSON
|
||||
|
|
@ -84,20 +92,108 @@ class LLMClient:
|
|||
Returns:
|
||||
解析后的JSON对象
|
||||
"""
|
||||
response = self.chat(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
# 清理markdown代码块标记
|
||||
cleaned_response = response.strip()
|
||||
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
||||
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
||||
cleaned_response = cleaned_response.strip()
|
||||
last_error: Optional[Exception] = None
|
||||
last_response = ""
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
current_temperature = max(0.0, temperature - (attempt * temperature_step))
|
||||
|
||||
try:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": current_temperature,
|
||||
"response_format": {"type": "json_object"}
|
||||
}
|
||||
|
||||
if max_tokens is not None:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
raw_content = response.choices[0].message.content or ""
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
cleaned_response = self._clean_response_text(raw_content)
|
||||
if finish_reason == 'length':
|
||||
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
|
||||
cleaned_response = self._fix_truncated_json(cleaned_response)
|
||||
|
||||
last_response = cleaned_response
|
||||
|
||||
try:
|
||||
return self._parse_json_response(cleaned_response)
|
||||
except json.JSONDecodeError as parse_error:
|
||||
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
|
||||
|
||||
fixed = self._try_fix_json(cleaned_response)
|
||||
if fixed is not None:
|
||||
return fixed
|
||||
|
||||
if fallback_parser is not None:
|
||||
fallback_result = fallback_parser(cleaned_response)
|
||||
if fallback_result is not None:
|
||||
return fallback_result
|
||||
|
||||
last_error = parse_error
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
|
||||
last_error = exc
|
||||
|
||||
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
|
||||
time.sleep(retry_delay_seconds * (attempt + 1))
|
||||
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
|
||||
|
||||
def _clean_response_text(self, content: str) -> str:
    """Strip ``<think>…</think>`` blocks and a wrapping Markdown code fence from a model reply."""
    think_block = re.compile(r'<think>[\s\S]*?</think>')
    opening_fence = re.compile(r'^```(?:json)?\s*\n?', re.IGNORECASE)
    closing_fence = re.compile(r'\n?```\s*$')

    # Reasoning blocks go first so that a fence following them ends up at the
    # very start of the text, where the opening-fence pattern is anchored.
    text = think_block.sub('', content).strip()
    text = opening_fence.sub('', text)
    text = closing_fence.sub('', text)
    return text.strip()
|
||||
|
||||
def _parse_json_response(self, content: str) -> Dict[str, Any]:
    """Decode *content* as JSON; ``json.JSONDecodeError`` propagates to the caller."""
    parsed = json.loads(content)
    return parsed
|
||||
|
||||
def _fix_truncated_json(self, content: str) -> str:
    """Close whatever a truncated JSON document left open.

    Scans the text once, tracking string/escape state and the stack of open
    ``{`` / ``[`` containers, then:

    * terminates a string that was cut off mid-value,
    * drops a trailing comma left in front of the cut, and
    * appends the missing closers innermost-first.

    Text that is already balanced is returned unchanged (apart from
    stripping surrounding whitespace). Cuts after a key or a colon are not
    repaired here; the result may still be invalid JSON in that case.

    Args:
        content: Possibly truncated JSON text.

    Returns:
        The text with open strings and containers closed.
    """
    content = content.strip()

    closer_for = {'{': '}', '[': ']'}
    pending = []        # closers still owed, innermost container last
    in_string = False
    escaped = False     # previous character was a backslash inside a string

    for ch in content:
        if in_string:
            # Brackets and braces inside a string are data, not structure.
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch in closer_for:
            pending.append(closer_for[ch])
        elif pending and ch == pending[-1]:
            pending.pop()

    if in_string:
        if escaped:
            # A dangling backslash would escape the quote we are about to add.
            content = content[:-1]
        content += '"'
    elif pending and content.endswith(','):
        # "[1, 2," -> "[1, 2" so that the appended closer yields valid JSON.
        content = content[:-1].rstrip()

    # Close in reverse opening order so nesting stays well-formed.
    return content + ''.join(reversed(pending))
|
||||
|
||||
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
|
||||
"""尝试从近似JSON内容中恢复结构化对象。"""
|
||||
content = self._fix_truncated_json(content)
|
||||
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||
if not json_match:
|
||||
return None
|
||||
|
||||
json_str = json_match.group()
|
||||
|
||||
def fix_string_newlines(match: re.Match[str]) -> str:
|
||||
value = match.group(0)
|
||||
value = value.replace('\n', ' ').replace('\r', ' ')
|
||||
value = re.sub(r'\s+', ' ', value)
|
||||
return value
|
||||
|
||||
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
|
||||
|
||||
try:
|
||||
return json.loads(cleaned_response)
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
||||
|
||||
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
|
||||
json_str = re.sub(r'\s+', ' ', json_str)
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,236 @@
|
|||
"""
|
||||
Tests for SignalExtractor — no real API calls, LLMClient fully mocked.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.services.signal_extractor import SignalExtractor, MiroSignal, SCHEMA_VERSION
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_extractor(chat_json_return=None, chat_json_side_effect=None):
    """Build a SignalExtractor around a MagicMock client and return ``(extractor, mock)``.

    When *chat_json_side_effect* is given it is installed as the mock's
    ``chat_json`` side effect; otherwise ``chat_json`` returns
    *chat_json_return* (an empty dict when that is falsy).
    """
    llm = MagicMock()
    if chat_json_side_effect is None:
        llm.chat_json.return_value = chat_json_return if chat_json_return else {}
    else:
        llm.chat_json.side_effect = chat_json_side_effect
    extractor = SignalExtractor(llm_client=llm)
    return extractor, llm
|
||||
|
||||
|
||||
_SAMPLE_REPORT = """
|
||||
## Executive Summary
|
||||
The simulation shows strong consensus forming around a YES outcome.
|
||||
Seventy-three percent of agents expressed optimism.
|
||||
|
||||
## Key Findings
|
||||
- Social momentum is strongly positive.
|
||||
- Counter-narratives remain marginal.
|
||||
|
||||
## Conclusion
|
||||
The dominant dynamic is consensus formation with high confidence.
|
||||
"""
|
||||
|
||||
_SAMPLE_REQUIREMENT = "Will the proposal pass by end of Q2 2026?"
|
||||
|
||||
_GOOD_LLM_RESPONSE = {
|
||||
"p_yes": 0.73,
|
||||
"confidence": "high",
|
||||
"action": "buy_yes",
|
||||
"regime": "consensus_forming",
|
||||
"summary": "Strong agent consensus supports a YES outcome with high confidence.",
|
||||
"drivers": ["70%+ agent agreement", "positive social momentum"],
|
||||
"invalidators": ["marginal counter-narrative", "low information diversity"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractHappyPath:
    """extract() with a well-formed LLM payload (the LLM client is mocked)."""

    def test_returns_miro_signal(self):
        """extract() returns a MiroSignal instance."""
        extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
        result = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert isinstance(result, MiroSignal)

    def test_fields_match_llm_output(self):
        """Thesis fields are carried over from the mocked LLM payload."""
        extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.p_yes == pytest.approx(0.73)
        assert sig.confidence == "high"
        assert sig.action == "buy_yes"
        assert sig.regime == "consensus_forming"
        assert "YES" in sig.summary or "consensus" in sig.summary.lower()
        assert len(sig.drivers) == 2
        assert len(sig.invalidators) == 2

    def test_metadata_fields(self):
        """IDs passed to extract() and the schema version end up on the signal."""
        extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
        sig = extractor.extract("report_abc", "sim_xyz", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.report_id == "report_abc"
        assert sig.simulation_id == "sim_xyz"
        assert sig.schema_version == SCHEMA_VERSION
        assert sig.signal_id  # non-empty UUID
        assert sig.generated_at  # non-empty ISO timestamp

    def test_to_dict_structure(self):
        """to_dict() nests exactly the seven thesis keys under "thesis"."""
        extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        d = sig.to_dict()
        assert "thesis" in d
        assert set(d["thesis"].keys()) == {
            "p_yes", "confidence", "action", "regime",
            "summary", "drivers", "invalidators",
        }

    def test_llm_called_with_low_temperature(self):
        """chat_json is called with a temperature of at most 0.2."""
        extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE)
        extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        call_kwargs = mock_client.chat_json.call_args.kwargs
        assert call_kwargs["temperature"] <= 0.2

    def test_llm_called_with_retries(self):
        """chat_json is called with max_attempts of at least 2."""
        extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE)
        extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        call_kwargs = mock_client.chat_json.call_args.kwargs
        assert call_kwargs.get("max_attempts", 1) >= 2

    def test_simulation_requirement_in_messages(self):
        """The prediction question appears in the user message sent to the LLM."""
        extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE)
        req = "Will the referendum pass?"
        extractor.extract("r1", "s1", _SAMPLE_REPORT, req)
        messages = mock_client.chat_json.call_args.kwargs["messages"]
        user_content = next(m["content"] for m in messages if m["role"] == "user")
        assert req in user_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Field validation and normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFieldValidation:
    """Normalisation of out-of-range, invalid or missing fields in the LLM payload."""

    def test_p_yes_clamped_below_zero(self):
        """A negative p_yes is raised to at least 0.01."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": -0.5})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.p_yes >= 0.01

    def test_p_yes_clamped_above_one(self):
        """A p_yes above 1 is lowered to at most 0.99."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 1.5})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.p_yes <= 0.99

    def test_invalid_confidence_falls_back_to_medium(self):
        """An unknown confidence label is replaced by "medium"."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "confidence": "very_sure"})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.confidence == "medium"

    def test_invalid_action_recomputed_from_p_yes_buy_yes(self):
        """An invalid action with p_yes=0.8 is recomputed as "buy_yes"."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.8, "action": "INVALID"})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.action == "buy_yes"

    def test_invalid_action_recomputed_from_p_yes_buy_no(self):
        """An invalid action with p_yes=0.2 is recomputed as "buy_no"."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.2, "action": "INVALID"})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.action == "buy_no"

    def test_invalid_action_recomputed_from_p_yes_hold(self):
        """An invalid action with p_yes=0.5 is recomputed as "hold"."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.5, "action": "INVALID"})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.action == "hold"

    def test_missing_regime_defaults_to_uncertain(self):
        """A payload without a "regime" key yields regime "uncertain"."""
        resp = {k: v for k, v in _GOOD_LLM_RESPONSE.items() if k != "regime"}
        extractor, _ = _make_extractor(resp)
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.regime == "uncertain"

    def test_empty_drivers_list_accepted(self):
        """An empty drivers list is passed through as an empty list."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "drivers": []})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        assert sig.drivers == []

    def test_non_list_drivers_handled(self):
        """A string in place of the drivers list must not crash extraction."""
        extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "drivers": "some string"})
        sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
        # Should not crash; only the resulting type is asserted, not how the
        # string is turned into list items.
        assert isinstance(sig.drivers, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Report trimming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReportTrimming:
    """Behaviour of ``SignalExtractor._trim_report`` for short and oversized input."""

    def test_short_report_unchanged(self):
        """Text shorter than max_chars is returned as-is."""
        text = "Short report content."
        assert SignalExtractor._trim_report(text, max_chars=100) == text

    def test_long_report_trimmed(self):
        """Oversized text gets shorter and carries a "truncated" marker."""
        trimmed = SignalExtractor._trim_report("x" * 20_000, max_chars=12_000)
        assert len(trimmed) < 20_000
        assert "truncated" in trimmed

    def test_trimmed_report_keeps_tail(self):
        """The end of the report must survive trimming."""
        # The tail (conclusion) is most important for signal extraction
        filler = "A" * 10_000
        trimmed = SignalExtractor._trim_report(filler + "CONCLUSION", max_chars=100)
        assert "CONCLUSION" in trimmed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback (_salvage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSalvage:
    """Exercises ``SignalExtractor._salvage``, the fallback used on free text."""

    @staticmethod
    def _salvaged_field(text, key):
        """Salvage *text* and return the value stored under *key*."""
        return SignalExtractor._salvage(text)[key]

    def test_salvage_extracts_probability(self):
        """A decimal probability in the text becomes p_yes."""
        salvaged = SignalExtractor._salvage("The probability is 0.68 for YES outcome.")
        assert salvaged is not None
        assert salvaged["p_yes"] == pytest.approx(0.68)

    def test_salvage_returns_none_when_no_probability(self):
        """Text without any number cannot be salvaged."""
        salvaged = SignalExtractor._salvage("no numbers here at all")
        assert salvaged is None

    def test_salvage_sets_action_buy_yes(self):
        """A salvaged probability of 0.80 maps to "buy_yes"."""
        assert self._salvaged_field("probability 0.80", "action") == "buy_yes"

    def test_salvage_sets_action_buy_no(self):
        """A salvaged probability of 0.20 maps to "buy_no"."""
        assert self._salvaged_field("probability 0.20", "action") == "buy_no"

    def test_salvage_sets_action_hold(self):
        """A salvaged probability of 0.50 maps to "hold"."""
        assert self._salvaged_field("probability 0.50", "action") == "hold"

    def test_salvage_detects_high_confidence(self):
        """The word "high" in the text yields confidence "high"."""
        assert self._salvaged_field("high confidence, p=0.72", "confidence") == "high"

    def test_salvage_detects_low_confidence(self):
        """The word "low" in the text yields confidence "low"."""
        assert self._salvaged_field("low certainty, p=0.30", "confidence") == "low"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM failure propagates as ValueError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLLMFailure:
    """A ValueError raised by the LLM client must surface from ``extract``."""

    def test_raises_value_error_on_llm_failure(self):
        """``extract`` re-raises (or lets through) the client's ValueError."""
        extractor, mock_client = _make_extractor()
        llm_error = ValueError("LLM返回的JSON格式无效: ...")
        mock_client.chat_json.side_effect = llm_error

        with pytest.raises(ValueError):
            extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
|
||||
|
|
@ -0,0 +1,299 @@
|
|||
"""
|
||||
Tests for LLMClient utility — no real API calls, all OpenAI interactions mocked.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
from app.utils.llm_client import LLMClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers to build fake OpenAI response objects
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_response(content: str, finish_reason: str = "stop"):
|
||||
choice = MagicMock()
|
||||
choice.message.content = content
|
||||
choice.finish_reason = finish_reason
|
||||
resp = MagicMock()
|
||||
resp.choices = [choice]
|
||||
return resp
|
||||
|
||||
|
||||
def _make_client(responses):
    """
    Build an LLMClient wired to a fake OpenAI client.

    Successive ``.chat.completions.create()`` calls on the fake yield the
    items of *responses* in order; as with any mock ``side_effect`` iterable,
    an item that is an exception instance is raised instead of returned.
    The ``create`` mock is attached to the client as ``_mock_create``.
    """
    fake_openai = MagicMock()
    create_mock = fake_openai.chat.completions.create
    create_mock.side_effect = responses

    # Patch only while the client is constructed; it keeps the fake afterwards.
    with patch("app.utils.llm_client.OpenAI", return_value=fake_openai):
        client = LLMClient(api_key="test-key", base_url="http://localhost", model="test-model")

    # Expose the underlying mock for assertions
    client._mock_create = create_mock
    return client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _clean_response_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanResponseText:
    """``_clean_response_text`` strips think tags, code fences and padding."""

    def setup_method(self):
        """Create a client with the OpenAI constructor patched out."""
        with patch("app.utils.llm_client.OpenAI"):
            self.client = LLMClient(api_key="k", base_url="u", model="m")

    def _clean(self, raw):
        """Shorthand for the method under test."""
        return self.client._clean_response_text(raw)

    def test_passthrough_plain_json(self):
        """Already-clean JSON text is returned untouched."""
        assert self._clean('{"a": 1}') == '{"a": 1}'

    def test_strips_think_tags(self):
        """A leading <think>...</think> section is removed."""
        assert self._clean('<think>internal reasoning</think>{"a": 1}') == '{"a": 1}'

    def test_strips_multiline_think_tags(self):
        """Think sections spanning several lines are removed as well."""
        assert self._clean('<think>\nline1\nline2\n</think>\n{"b": 2}') == '{"b": 2}'

    def test_strips_json_markdown_fence(self):
        """A ```json fence around the payload is removed."""
        assert self._clean('```json\n{"c": 3}\n```') == '{"c": 3}'

    def test_strips_plain_markdown_fence(self):
        """A bare ``` fence around the payload is removed."""
        assert self._clean('```\n{"d": 4}\n```') == '{"d": 4}'

    def test_strips_think_and_fence_combined(self):
        """Think section followed by a fenced payload leaves only the payload."""
        wrapped = '<think>reasoning</think>\n```json\n{"e": 5}\n```'
        assert self._clean(wrapped) == '{"e": 5}'

    def test_empty_string(self):
        """Empty input stays empty."""
        assert self._clean("") == ""

    def test_no_fence_no_think(self):
        """Surrounding whitespace is trimmed when nothing else needs removing."""
        assert self._clean(' {"f": 6} ') == '{"f": 6}'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _fix_truncated_json
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFixTruncatedJson:
    """``_fix_truncated_json`` must turn cut-off JSON into parseable JSON."""

    def setup_method(self):
        """Create a client with the OpenAI constructor patched out."""
        with patch("app.utils.llm_client.OpenAI"):
            self.client = LLMClient(api_key="k", base_url="u", model="m")

    def _repair_and_parse(self, text):
        """Repair *text* and parse the outcome with ``json.loads``."""
        return json.loads(self.client._fix_truncated_json(text))

    def test_closes_one_brace(self):
        """An object cut off inside its only value is completed."""
        parsed = self._repair_and_parse('{"key": "val')
        assert parsed["key"] == "val"

    def test_closes_nested_braces(self):
        """Two unclosed object levels are both closed."""
        parsed = self._repair_and_parse('{"outer": {"inner": "x"')
        assert parsed["outer"]["inner"] == "x"

    def test_closes_open_array(self):
        """An unclosed array inside an object is closed."""
        parsed = self._repair_and_parse('{"list": [1, 2')
        assert parsed["list"] == [1, 2]

    def test_closes_array_and_brace(self):
        """Array and enclosing object are closed in the right order."""
        parsed = self._repair_and_parse('{"a": [1, 2, 3')
        assert parsed["a"] == [1, 2, 3]

    def test_already_valid_unchanged(self):
        """Valid JSON still parses to the same value after the repair pass."""
        assert self._repair_and_parse('{"x": 1}') == {"x": 1}

    def test_trailing_dangling_value_gets_quote(self):
        """Input that ends mid-string, without a closing quote, is repaired."""
        # Should be parseable after repair
        parsed = self._repair_and_parse('{"k": "incomplete')
        assert "k" in parsed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _try_fix_json
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTryFixJson:
    """``_try_fix_json`` returns a parsed object or None when nothing is recoverable."""

    def setup_method(self):
        """Create a client with the OpenAI constructor patched out."""
        with patch("app.utils.llm_client.OpenAI"):
            self.client = LLMClient(api_key="k", base_url="u", model="m")

    def _fix(self, content):
        """Shorthand for the method under test."""
        return self.client._try_fix_json(content)

    def test_returns_valid_json(self):
        """Well-formed JSON is simply parsed."""
        parsed = self._fix('{"name": "Alice", "age": 30}')
        assert parsed == {"name": "Alice", "age": 30}

    def test_extracts_json_from_surrounding_text(self):
        """An object embedded in prose is located and parsed."""
        parsed = self._fix('Here is the result: {"score": 42} end.')
        assert parsed is not None
        assert parsed["score"] == 42

    def test_fixes_newlines_inside_string_values(self):
        """A raw newline inside a string value does not prevent recovery."""
        # Literal newline inside a JSON string value is invalid JSON
        parsed = self._fix('{"desc": "line one\nline two"}')
        assert parsed is not None
        assert "desc" in parsed

    def test_returns_none_for_no_json_object(self):
        """Text without any object yields None."""
        assert self._fix("no json here at all") is None

    def test_returns_none_for_empty_string(self):
        """Empty input yields None."""
        assert self._fix("") is None

    def test_recovers_truncated_object(self):
        """An object missing its closing brace is still recovered."""
        # Missing closing brace
        parsed = self._fix('{"city": "Beijing", "pop": 21')
        assert parsed is not None
        assert parsed["city"] == "Beijing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# chat_json — retry behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestChatJsonRetry:
    """Retry, temperature back-off, fallback and failure paths of ``chat_json``.

    Every test drives a client built by ``_make_client``, so the sequence of
    fake API responses (or exceptions) fully determines what each attempt sees.
    """

    # --- succeeds on first attempt ---

    def test_success_first_attempt(self):
        """Valid JSON on the first call is returned without any retry."""
        payload = {"status": "ok"}
        client = _make_client([_make_response(json.dumps(payload))])
        result = client.chat_json([{"role": "user", "content": "hi"}])
        assert result == payload
        assert client._mock_create.call_count == 1

    # --- temperature backoff ---

    def test_temperature_decreases_across_retries(self):
        """Each retry lowers the temperature by ``temperature_step``."""
        bad = _make_response("not json at all {{{{")
        good = _make_response('{"ok": true}')
        client = _make_client([bad, bad, good])

        result = client.chat_json(
            [{"role": "user", "content": "hi"}],
            temperature=0.9,
            temperature_step=0.3,
            max_attempts=3,
        )
        assert result == {"ok": True}

        calls = client._mock_create.call_args_list
        assert calls[0].kwargs["temperature"] == pytest.approx(0.9)
        assert calls[1].kwargs["temperature"] == pytest.approx(0.6)
        assert calls[2].kwargs["temperature"] == pytest.approx(0.3)

    def test_temperature_never_goes_below_zero(self):
        """A step larger than the starting temperature must not go negative."""
        # A retry is required here: with a single attempt no back-off step is
        # ever applied, so the lower bound would not be exercised at all.
        bad = _make_response("not json at all {{{{")
        good = _make_response('{"x": 1}')
        client = _make_client([bad, good])

        result = client.chat_json(
            [{"role": "user", "content": "hi"}],
            temperature=0.1,
            temperature_step=0.5,
            max_attempts=2,
        )
        assert result == {"x": 1}

        temperatures = [
            recorded.kwargs["temperature"]
            for recorded in client._mock_create.call_args_list
        ]
        assert len(temperatures) == 2
        assert temperatures[0] == pytest.approx(0.1)
        assert all(value >= 0 for value in temperatures)

    # --- fallback_parser ---

    def test_fallback_parser_called_on_bad_json(self):
        """When parsing fails, the fallback parser's result is returned."""
        bad = _make_response("this is not json")
        client = _make_client([bad])

        fallback = MagicMock(return_value={"rescued": True})
        result = client.chat_json(
            [{"role": "user", "content": "q"}],
            max_attempts=1,
            fallback_parser=fallback,
        )
        assert result == {"rescued": True}
        fallback.assert_called_once()

    def test_fallback_parser_returning_none_does_not_short_circuit(self):
        """A fallback that returns None lets the next attempt run."""
        bad = _make_response("still not json ][")
        good = _make_response('{"second": "attempt"}')
        client = _make_client([bad, good])

        fallback = MagicMock(return_value=None)
        result = client.chat_json(
            [{"role": "user", "content": "q"}],
            max_attempts=2,
            fallback_parser=fallback,
        )
        assert result == {"second": "attempt"}

    # --- raises ValueError after all attempts fail ---

    def test_raises_after_all_attempts_fail(self):
        """Exhausting every attempt raises ValueError with the expected message."""
        bad = _make_response("invalid {{{ json")
        client = _make_client([bad, bad, bad])

        with pytest.raises(ValueError, match="LLM返回的JSON格式无效"):
            client.chat_json(
                [{"role": "user", "content": "q"}],
                max_attempts=3,
            )
        # Checked after the ``raises`` block so the assertion is actually reached.
        assert client._mock_create.call_count == 3

    def test_raises_after_single_attempt(self):
        """With one attempt allowed, a bad response raises immediately."""
        bad = _make_response("nope")
        client = _make_client([bad])

        with pytest.raises(ValueError):
            client.chat_json([{"role": "user", "content": "q"}], max_attempts=1)

    # --- finish_reason == 'length' triggers truncation repair ---

    def test_truncated_output_is_repaired(self):
        """Output cut off with finish_reason="length" is repaired and parsed."""
        truncated_json = '{"items": [1, 2, 3'
        client = _make_client([_make_response(truncated_json, finish_reason="length")])
        result = client.chat_json([{"role": "user", "content": "q"}])
        assert result["items"] == [1, 2, 3]

    # --- API exception counts as a failed attempt ---

    def test_api_exception_retried(self):
        """An exception raised by the API call consumes one attempt, then retries."""
        exc = Exception("network failure")
        good = _make_response('{"recovered": true}')
        client = _make_client([exc, good])

        result = client.chat_json(
            [{"role": "user", "content": "q"}],
            max_attempts=2,
        )
        assert result == {"recovered": True}
        assert client._mock_create.call_count == 2

    def test_api_exception_all_attempts_raises_value_error(self):
        """If every attempt raises, the caller sees a ValueError."""
        exc = Exception("always fails")
        client = _make_client([exc, exc])

        with pytest.raises(ValueError):
            client.chat_json(
                [{"role": "user", "content": "q"}],
                max_attempts=2,
            )
|
||||
Loading…
Reference in New Issue