This commit is contained in:
Laura Roganovic 2026-05-28 17:33:02 -04:00 committed by GitHub
commit 5875f8ff0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1034 additions and 158 deletions

View File

@ -11,6 +11,7 @@ from flask import request, jsonify, send_file
from . import report_bp
from ..config import Config
from ..services.report_agent import ReportAgent, ReportManager, ReportStatus
from ..services.signal_extractor import SignalExtractor
from ..services.simulation_manager import SimulationManager
from ..models.project import ProjectManager
from ..models.task import TaskManager, TaskStatus
@ -930,6 +931,89 @@ def stream_console_log(report_id: str):
}), 500
# ============== 预测信号接口 ==============
@report_bp.route('/<report_id>/signal', methods=['POST'])
def extract_signal(report_id: str):
"""
从已完成的报告中提取结构化预测信号miro_signal
对报告的 markdown 内容执行一次 LLM 提取返回可供
外部预测市场管道直接消费的规范化概率信号
返回
{
"success": true,
"data": {
"signal_id": "uuid",
"schema_version": "1.1",
"report_id": "report_xxxx",
"simulation_id": "sim_xxxx",
"generated_at": "2026-...",
"thesis": {
"p_yes": 0.73,
"confidence": "high",
"action": "buy_yes",
"regime": "consensus_forming",
"summary": "...",
"drivers": ["...", "..."],
"invalidators": ["...", "..."]
}
}
}
"""
try:
report = ReportManager.get_report(report_id)
if not report:
return jsonify({
"success": False,
"error": f"报告不存在: {report_id}"
}), 404
if report.status != ReportStatus.COMPLETED:
return jsonify({
"success": False,
"error": f"报告尚未完成 (status={report.status.value}),无法提取信号"
}), 400
if not report.markdown_content:
return jsonify({
"success": False,
"error": "报告内容为空,无法提取信号"
}), 400
extractor = SignalExtractor()
signal = extractor.extract(
report_id=report_id,
simulation_id=report.simulation_id,
markdown_content=report.markdown_content,
simulation_requirement=report.simulation_requirement,
)
logger.info(f"信号提取完成: report={report_id} p_yes={signal.p_yes} action={signal.action}")
return jsonify({
"success": True,
"data": signal.to_dict()
})
except ValueError as e:
logger.error(f"信号提取失败 (LLM): {str(e)}")
return jsonify({
"success": False,
"error": str(e)
}), 422
except Exception as e:
logger.error(f"信号提取失败: {str(e)}")
return jsonify({
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}), 500
# ============== 工具调用接口(供调试使用)==============
@report_bp.route('/tools/search', methods=['POST'])

View File

@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from zep_cloud.client import Zep
from ..config import Config
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -193,9 +193,10 @@ class OasisProfileGenerator:
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
self.llm_client = LLMClient(
api_key=self.api_key,
base_url=self.base_url
base_url=self.base_url,
model=self.model_name
)
# Zep客户端用于检索丰富上下文
@ -521,64 +522,36 @@ class OasisProfileGenerator:
entity_name, entity_type, entity_summary, entity_attributes, context
)
# 尝试多次生成,直到成功或达到最大重试次数
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": self._get_system_prompt(is_individual)},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
)
content = response.choices[0].message.content
# 检查是否被截断finish_reason不是'stop'
finish_reason = response.choices[0].finish_reason
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
result = json.loads(content)
# 验证必需字段
if "bio" not in result or not result["bio"]:
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
if "persona" not in result or not result["persona"]:
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}"
return result
except json.JSONDecodeError as je:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}")
# 尝试修复JSON
result = self._try_fix_json(content, entity_name, entity_type, entity_summary)
if result.get("_fixed"):
del result["_fixed"]
return result
last_error = je
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(1 * (attempt + 1)) # 指数退避
logger.warning(f"LLM生成人设失败{max_attempts}次尝试): {last_error}, 使用规则生成")
return self._generate_profile_rule_based(
entity_name, entity_type, entity_summary, entity_attributes
)
try:
result = self.llm_client.chat_json(
messages=[
{"role": "system", "content": self._get_system_prompt(is_individual)},
{"role": "user", "content": prompt}
],
temperature=0.7,
max_tokens=None,
max_attempts=3,
temperature_step=0.1,
fallback_parser=lambda content: self._try_fix_json(
content, entity_name, entity_type, entity_summary
),
retry_delay_seconds=1.0
)
if "bio" not in result or not result["bio"]:
result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}"
if "persona" not in result or not result["persona"]:
result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}"
if result.get("_fixed"):
del result["_fixed"]
return result
except Exception as e:
logger.warning(f"LLM生成人设失败3次尝试: {e}, 使用规则生成")
return self._generate_profile_rule_based(
entity_name, entity_type, entity_summary, entity_attributes
)
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON输出被max_tokens限制截断"""
@ -1202,4 +1175,3 @@ class OasisProfileGenerator:
"""[已废弃] 请使用 save_profiles() 方法"""
logger.warning("save_profiles_to_json已废弃请使用save_profiles方法")
self.save_profiles(profiles, file_path, platform)

View File

@ -0,0 +1,245 @@
"""
Miro Signal Extractor
Distils a completed simulation report into a canonical machine-readable
probability signal that external pipelines (e.g. prediction-market bots)
can consume directly.
"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import List, Optional
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
logger = get_logger('mirofish.signal_extractor')
SCHEMA_VERSION = "1.1"
_SYSTEM_PROMPT = """\
You are a structured-signal extractor. You will be given the full markdown text
of a social-simulation analysis report and the original simulation requirement
(the prediction question). Your job is to distil the report into a concise,
machine-readable probability signal.
Rules:
- p_yes must be a float strictly between 0.0 and 1.0 (never exactly 0 or 1).
- confidence must be one of: "high", "medium", "low".
- action must be one of: "buy_yes", "buy_no", "hold".
Use "buy_yes" when p_yes > 0.55, "buy_no" when p_yes < 0.45, else "hold".
- regime describes the dominant social dynamic observed in the simulation,
e.g. "consensus_forming", "contested", "uncertain", "momentum_shift",
"echo_chamber", "fragmented".
- summary is one sentence ( 30 words).
- drivers is a list of 24 short strings (key factors supporting the thesis).
- invalidators is a list of 24 short strings (key risks or counter-factors).
- Do not reproduce large sections of the report. Be concise.
- Respond ONLY with valid JSON matching the schema below no prose, no fences.
Required JSON schema:
{
"p_yes": <float 0.01.0>,
"confidence": "high" | "medium" | "low",
"action": "buy_yes" | "buy_no" | "hold",
"regime": <string>,
"summary": <string>,
"drivers": [<string>, ...],
"invalidators": [<string>, ...]
}
"""
@dataclass
class MiroSignal:
"""Canonical prediction signal extracted from a simulation report."""
signal_id: str
schema_version: str
report_id: str
simulation_id: str
generated_at: str
# Core thesis fields
p_yes: float
confidence: str # high | medium | low
action: str # buy_yes | buy_no | hold
regime: str
summary: str
drivers: List[str] = field(default_factory=list)
invalidators: List[str] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"signal_id": self.signal_id,
"schema_version": self.schema_version,
"report_id": self.report_id,
"simulation_id": self.simulation_id,
"generated_at": self.generated_at,
"thesis": {
"p_yes": self.p_yes,
"confidence": self.confidence,
"action": self.action,
"regime": self.regime,
"summary": self.summary,
"drivers": self.drivers,
"invalidators": self.invalidators,
},
}
class SignalExtractor:
"""Extracts a MiroSignal from a completed report's markdown content."""
_VALID_CONFIDENCE = {"high", "medium", "low"}
_VALID_ACTIONS = {"buy_yes", "buy_no", "hold"}
def __init__(self, llm_client: Optional[LLMClient] = None):
self._client = llm_client or LLMClient()
def extract(
self,
report_id: str,
simulation_id: str,
markdown_content: str,
simulation_requirement: str,
) -> MiroSignal:
"""
Distil *markdown_content* into a MiroSignal.
Args:
report_id: The report this signal is derived from.
simulation_id: Parent simulation ID.
markdown_content: Full report text (may be long).
simulation_requirement: The original prediction question / goal.
Returns:
MiroSignal with validated fields.
Raises:
ValueError: If the LLM fails to produce a valid signal after retries.
"""
# Trim to avoid token limits while keeping the most analytical content.
# Reports can exceed 30 k chars; the last third is usually the conclusion.
body = self._trim_report(markdown_content)
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{
"role": "user",
"content": (
f"Simulation requirement (prediction question):\n{simulation_requirement}\n\n"
f"Report:\n{body}"
),
},
]
raw = self._client.chat_json(
messages=messages,
temperature=0.1,
max_tokens=512,
max_attempts=3,
temperature_step=0.05,
fallback_parser=self._salvage,
)
return self._build_signal(raw, report_id, simulation_id)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _trim_report(content: str, max_chars: int = 12_000) -> str:
"""Keep the tail of the report (conclusions) if it is very long."""
if len(content) <= max_chars:
return content
return "…[report truncated for signal extraction]\n\n" + content[-max_chars:]
def _build_signal(
self, raw: dict, report_id: str, simulation_id: str
) -> MiroSignal:
"""Validate and normalise the raw LLM dict into a MiroSignal."""
# p_yes
try:
p_yes = float(raw.get("p_yes", 0.5))
except (TypeError, ValueError):
p_yes = 0.5
p_yes = max(0.01, min(0.99, p_yes))
# confidence
confidence = str(raw.get("confidence", "medium")).lower()
if confidence not in self._VALID_CONFIDENCE:
confidence = "medium"
# action — recompute from p_yes if missing or invalid
action = str(raw.get("action", "")).lower()
if action not in self._VALID_ACTIONS:
if p_yes > 0.55:
action = "buy_yes"
elif p_yes < 0.45:
action = "buy_no"
else:
action = "hold"
# regime
regime = str(raw.get("regime", "uncertain")).strip() or "uncertain"
# summary
summary = str(raw.get("summary", "")).strip()
# list fields
drivers = [str(d) for d in raw.get("drivers", []) if d]
invalidators = [str(i) for i in raw.get("invalidators", []) if i]
return MiroSignal(
signal_id=str(uuid.uuid4()),
schema_version=SCHEMA_VERSION,
report_id=report_id,
simulation_id=simulation_id,
generated_at=datetime.now(timezone.utc).isoformat(),
p_yes=p_yes,
confidence=confidence,
action=action,
regime=regime,
summary=summary,
drivers=drivers,
invalidators=invalidators,
)
@staticmethod
def _salvage(raw_text: str) -> Optional[dict]:
"""
Last-resort fallback: scan for any float that looks like a probability
and a YES/NO sentiment to construct a minimal signal dict.
"""
import re
prob_match = re.search(r'\b(0\.\d+|1\.0+|0)\b', raw_text)
if not prob_match:
return None
try:
p = float(prob_match.group())
except ValueError:
return None
text_lower = raw_text.lower()
if "high" in text_lower:
confidence = "high"
elif "low" in text_lower:
confidence = "low"
else:
confidence = "medium"
return {
"p_yes": p,
"confidence": confidence,
"action": "buy_yes" if p > 0.55 else ("buy_no" if p < 0.45 else "hold"),
"regime": "uncertain",
"summary": "Signal salvaged from partial LLM output.",
"drivers": [],
"invalidators": [],
}

View File

@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass, field, asdict
from datetime import datetime
from openai import OpenAI
from ..config import Config
from ..utils.llm_client import LLMClient
from ..utils.logger import get_logger
from ..utils.locale import get_language_instruction, t
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -235,9 +234,10 @@ class SimulationConfigGenerator:
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
self.llm_client = LLMClient(
api_key=self.api_key,
base_url=self.base_url
base_url=self.base_url,
model=self.model_name
)
def generate_config(
@ -433,78 +433,23 @@ class SimulationConfigGenerator:
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
"""带重试的LLM调用包含JSON修复逻辑"""
import re
max_attempts = 3
last_error = None
for attempt in range(max_attempts):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度
# 不设置max_tokens让LLM自由发挥
)
content = response.choices[0].message.content
finish_reason = response.choices[0].finish_reason
# 检查是否被截断
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt+1})")
content = self._fix_truncated_json(content)
# 尝试解析JSON
try:
return json.loads(content)
except json.JSONDecodeError as e:
logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}")
# 尝试修复JSON
fixed = self._try_fix_config_json(content)
if fixed:
return fixed
last_error = e
except Exception as e:
logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}")
last_error = e
import time
time.sleep(2 * (attempt + 1))
raise last_error or Exception("LLM调用失败")
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON"""
content = content.strip()
# 计算未闭合的括号
open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']')
# 检查是否有未闭合的字符串
if content and content[-1] not in '",}]':
content += '"'
# 闭合括号
content += ']' * open_brackets
content += '}' * open_braces
return content
return self.llm_client.chat_json(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
temperature=0.7,
max_tokens=None,
max_attempts=3,
temperature_step=0.1,
fallback_parser=self._try_fix_config_json,
retry_delay_seconds=2.0
)
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
"""尝试修复配置JSON"""
import re
# 修复被截断的情况
content = self._fix_truncated_json(content)
# 提取JSON部分
json_match = re.search(r'\{[\s\S]*\}', content)
if json_match:
@ -988,4 +933,3 @@ class SimulationConfigGenerator:
"influence_weight": 1.0
}

View File

@ -4,11 +4,15 @@ LLM客户端封装
"""
import json
import time
import re
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Callable
from openai import OpenAI
from ..config import Config
from .logger import get_logger
logger = get_logger('mirofish.llm_client')
class LLMClient:
@ -36,7 +40,7 @@ class LLMClient:
self,
messages: List[Dict[str, str]],
temperature: float = 0.7,
max_tokens: int = 4096,
max_tokens: Optional[int] = 4096,
response_format: Optional[Dict] = None
) -> str:
"""
@ -55,23 +59,27 @@ class LLMClient:
"model": self.model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
if response_format:
kwargs["response_format"] = response_format
response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
return content
content = response.choices[0].message.content or ""
return self._clean_response_text(content)
def chat_json(
self,
messages: List[Dict[str, str]],
temperature: float = 0.3,
max_tokens: int = 4096
max_tokens: Optional[int] = 4096,
max_attempts: int = 1,
temperature_step: float = 0.0,
fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None,
retry_delay_seconds: float = 0.0
) -> Dict[str, Any]:
"""
发送聊天请求并返回JSON
@ -84,20 +92,108 @@ class LLMClient:
Returns:
解析后的JSON对象
"""
response = self.chat(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={"type": "json_object"}
)
# 清理markdown代码块标记
cleaned_response = response.strip()
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
cleaned_response = cleaned_response.strip()
last_error: Optional[Exception] = None
last_response = ""
for attempt in range(max_attempts):
current_temperature = max(0.0, temperature - (attempt * temperature_step))
try:
kwargs = {
"model": self.model,
"messages": messages,
"temperature": current_temperature,
"response_format": {"type": "json_object"}
}
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
response = self.client.chat.completions.create(**kwargs)
raw_content = response.choices[0].message.content or ""
finish_reason = response.choices[0].finish_reason
cleaned_response = self._clean_response_text(raw_content)
if finish_reason == 'length':
logger.warning(f"LLM输出被截断 (attempt {attempt + 1})")
cleaned_response = self._fix_truncated_json(cleaned_response)
last_response = cleaned_response
try:
return self._parse_json_response(cleaned_response)
except json.JSONDecodeError as parse_error:
logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}")
fixed = self._try_fix_json(cleaned_response)
if fixed is not None:
return fixed
if fallback_parser is not None:
fallback_result = fallback_parser(cleaned_response)
if fallback_result is not None:
return fallback_result
last_error = parse_error
except Exception as exc:
logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}")
last_error = exc
if attempt < max_attempts - 1 and retry_delay_seconds > 0:
time.sleep(retry_delay_seconds * (attempt + 1))
raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error
def _clean_response_text(self, content: str) -> str:
"""清理模型响应中的思考内容和Markdown包裹。"""
cleaned = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r'\n?```\s*$', '', cleaned)
return cleaned.strip()
def _parse_json_response(self, content: str) -> Dict[str, Any]:
return json.loads(content)
def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON内容。"""
content = content.strip()
# If the number of unescaped quotes is odd we are inside an open string.
unescaped_quote_count = len(re.findall(r'(?<!\\)"', content))
if unescaped_quote_count % 2 == 1:
content += '"'
open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']')
content += ']' * open_brackets
content += '}' * open_braces
return content
def _try_fix_json(self, content: str) -> Optional[Dict[str, Any]]:
"""尝试从近似JSON内容中恢复结构化对象。"""
content = self._fix_truncated_json(content)
json_match = re.search(r'\{[\s\S]*\}', content)
if not json_match:
return None
json_str = json_match.group()
def fix_string_newlines(match: re.Match[str]) -> str:
value = match.group(0)
value = value.replace('\n', ' ').replace('\r', ' ')
value = re.sub(r'\s+', ' ', value)
return value
json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str)
try:
return json.loads(cleaned_response)
return json.loads(json_str)
except json.JSONDecodeError:
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
json_str = re.sub(r'\s+', ' ', json_str)
try:
return json.loads(json_str)
except json.JSONDecodeError:
return None

View File

View File

View File

@ -0,0 +1,236 @@
"""
Tests for SignalExtractor no real API calls, LLMClient fully mocked.
"""
import pytest
from unittest.mock import MagicMock, patch
from app.services.signal_extractor import SignalExtractor, MiroSignal, SCHEMA_VERSION
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_extractor(chat_json_return=None, chat_json_side_effect=None):
"""Return a SignalExtractor with a mocked LLMClient."""
mock_client = MagicMock()
if chat_json_side_effect is not None:
mock_client.chat_json.side_effect = chat_json_side_effect
else:
mock_client.chat_json.return_value = chat_json_return or {}
return SignalExtractor(llm_client=mock_client), mock_client
_SAMPLE_REPORT = """
## Executive Summary
The simulation shows strong consensus forming around a YES outcome.
Seventy-three percent of agents expressed optimism.
## Key Findings
- Social momentum is strongly positive.
- Counter-narratives remain marginal.
## Conclusion
The dominant dynamic is consensus formation with high confidence.
"""
_SAMPLE_REQUIREMENT = "Will the proposal pass by end of Q2 2026?"
_GOOD_LLM_RESPONSE = {
"p_yes": 0.73,
"confidence": "high",
"action": "buy_yes",
"regime": "consensus_forming",
"summary": "Strong agent consensus supports a YES outcome with high confidence.",
"drivers": ["70%+ agent agreement", "positive social momentum"],
"invalidators": ["marginal counter-narrative", "low information diversity"],
}
# ---------------------------------------------------------------------------
# Happy path
# ---------------------------------------------------------------------------
class TestExtractHappyPath:
def test_returns_miro_signal(self):
extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
result = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert isinstance(result, MiroSignal)
def test_fields_match_llm_output(self):
extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.p_yes == pytest.approx(0.73)
assert sig.confidence == "high"
assert sig.action == "buy_yes"
assert sig.regime == "consensus_forming"
assert "YES" in sig.summary or "consensus" in sig.summary.lower()
assert len(sig.drivers) == 2
assert len(sig.invalidators) == 2
def test_metadata_fields(self):
extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
sig = extractor.extract("report_abc", "sim_xyz", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.report_id == "report_abc"
assert sig.simulation_id == "sim_xyz"
assert sig.schema_version == SCHEMA_VERSION
assert sig.signal_id # non-empty UUID
assert sig.generated_at # non-empty ISO timestamp
def test_to_dict_structure(self):
extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE)
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
d = sig.to_dict()
assert "thesis" in d
assert set(d["thesis"].keys()) == {
"p_yes", "confidence", "action", "regime",
"summary", "drivers", "invalidators",
}
def test_llm_called_with_low_temperature(self):
extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE)
extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
call_kwargs = mock_client.chat_json.call_args.kwargs
assert call_kwargs["temperature"] <= 0.2
def test_llm_called_with_retries(self):
extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE)
extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
call_kwargs = mock_client.chat_json.call_args.kwargs
assert call_kwargs.get("max_attempts", 1) >= 2
def test_simulation_requirement_in_messages(self):
extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE)
req = "Will the referendum pass?"
extractor.extract("r1", "s1", _SAMPLE_REPORT, req)
messages = mock_client.chat_json.call_args.kwargs["messages"]
user_content = next(m["content"] for m in messages if m["role"] == "user")
assert req in user_content
# ---------------------------------------------------------------------------
# Field validation and normalisation
# ---------------------------------------------------------------------------
class TestFieldValidation:
def test_p_yes_clamped_below_zero(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": -0.5})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.p_yes >= 0.01
def test_p_yes_clamped_above_one(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 1.5})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.p_yes <= 0.99
def test_invalid_confidence_falls_back_to_medium(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "confidence": "very_sure"})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.confidence == "medium"
def test_invalid_action_recomputed_from_p_yes_buy_yes(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.8, "action": "INVALID"})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.action == "buy_yes"
def test_invalid_action_recomputed_from_p_yes_buy_no(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.2, "action": "INVALID"})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.action == "buy_no"
def test_invalid_action_recomputed_from_p_yes_hold(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.5, "action": "INVALID"})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.action == "hold"
def test_missing_regime_defaults_to_uncertain(self):
resp = {k: v for k, v in _GOOD_LLM_RESPONSE.items() if k != "regime"}
extractor, _ = _make_extractor(resp)
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.regime == "uncertain"
def test_empty_drivers_list_accepted(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "drivers": []})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
assert sig.drivers == []
def test_non_list_drivers_handled(self):
extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "drivers": "some string"})
sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
# Should not crash; string is iterable so each char becomes an item — acceptable
assert isinstance(sig.drivers, list)
# ---------------------------------------------------------------------------
# Report trimming
# ---------------------------------------------------------------------------
class TestReportTrimming:
def test_short_report_unchanged(self):
short = "Short report content."
result = SignalExtractor._trim_report(short, max_chars=100)
assert result == short
def test_long_report_trimmed(self):
long_report = "x" * 20_000
result = SignalExtractor._trim_report(long_report, max_chars=12_000)
assert len(result) < 20_000
assert "truncated" in result
def test_trimmed_report_keeps_tail(self):
# The tail (conclusion) is most important for signal extraction
long_report = "A" * 10_000 + "CONCLUSION"
result = SignalExtractor._trim_report(long_report, max_chars=100)
assert "CONCLUSION" in result
# ---------------------------------------------------------------------------
# Fallback (_salvage)
# ---------------------------------------------------------------------------
class TestSalvage:
def test_salvage_extracts_probability(self):
result = SignalExtractor._salvage("The probability is 0.68 for YES outcome.")
assert result is not None
assert result["p_yes"] == pytest.approx(0.68)
def test_salvage_returns_none_when_no_probability(self):
assert SignalExtractor._salvage("no numbers here at all") is None
def test_salvage_sets_action_buy_yes(self):
result = SignalExtractor._salvage("probability 0.80")
assert result["action"] == "buy_yes"
def test_salvage_sets_action_buy_no(self):
result = SignalExtractor._salvage("probability 0.20")
assert result["action"] == "buy_no"
def test_salvage_sets_action_hold(self):
result = SignalExtractor._salvage("probability 0.50")
assert result["action"] == "hold"
def test_salvage_detects_high_confidence(self):
result = SignalExtractor._salvage("high confidence, p=0.72")
assert result["confidence"] == "high"
def test_salvage_detects_low_confidence(self):
result = SignalExtractor._salvage("low certainty, p=0.30")
assert result["confidence"] == "low"
# ---------------------------------------------------------------------------
# LLM failure propagates as ValueError
# ---------------------------------------------------------------------------
class TestLLMFailure:
def test_raises_value_error_on_llm_failure(self):
extractor, mock_client = _make_extractor()
mock_client.chat_json.side_effect = ValueError("LLM返回的JSON格式无效: ...")
with pytest.raises(ValueError):
extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)

View File

View File

@ -0,0 +1,299 @@
"""
Tests for LLMClient utility no real API calls, all OpenAI interactions mocked.
"""
import json
import pytest
from unittest.mock import MagicMock, patch, call
from app.utils.llm_client import LLMClient
# ---------------------------------------------------------------------------
# Helpers to build fake OpenAI response objects
# ---------------------------------------------------------------------------
def _make_response(content: str, finish_reason: str = "stop"):
choice = MagicMock()
choice.message.content = content
choice.finish_reason = finish_reason
resp = MagicMock()
resp.choices = [choice]
return resp
def _make_client(responses):
"""
Return a patched LLMClient whose underlying OpenAI client returns
*responses* in order on successive .chat.completions.create() calls.
"""
with patch("app.utils.llm_client.OpenAI") as MockOpenAI:
mock_openai_instance = MagicMock()
mock_openai_instance.chat.completions.create.side_effect = responses
MockOpenAI.return_value = mock_openai_instance
client = LLMClient(api_key="test-key", base_url="http://localhost", model="test-model")
# Expose the underlying mock for assertions
client._mock_create = mock_openai_instance.chat.completions.create
return client
# ---------------------------------------------------------------------------
# _clean_response_text
# ---------------------------------------------------------------------------
class TestCleanResponseText:
def setup_method(self):
with patch("app.utils.llm_client.OpenAI"):
self.client = LLMClient(api_key="k", base_url="u", model="m")
def test_passthrough_plain_json(self):
raw = '{"a": 1}'
assert self.client._clean_response_text(raw) == '{"a": 1}'
def test_strips_think_tags(self):
raw = '<think>internal reasoning</think>{"a": 1}'
assert self.client._clean_response_text(raw) == '{"a": 1}'
def test_strips_multiline_think_tags(self):
raw = '<think>\nline1\nline2\n</think>\n{"b": 2}'
assert self.client._clean_response_text(raw) == '{"b": 2}'
def test_strips_json_markdown_fence(self):
raw = '```json\n{"c": 3}\n```'
assert self.client._clean_response_text(raw) == '{"c": 3}'
def test_strips_plain_markdown_fence(self):
raw = '```\n{"d": 4}\n```'
assert self.client._clean_response_text(raw) == '{"d": 4}'
def test_strips_think_and_fence_combined(self):
raw = '<think>reasoning</think>\n```json\n{"e": 5}\n```'
assert self.client._clean_response_text(raw) == '{"e": 5}'
def test_empty_string(self):
assert self.client._clean_response_text("") == ""
def test_no_fence_no_think(self):
raw = ' {"f": 6} '
assert self.client._clean_response_text(raw) == '{"f": 6}'
# ---------------------------------------------------------------------------
# _fix_truncated_json
# ---------------------------------------------------------------------------
class TestFixTruncatedJson:
def setup_method(self):
with patch("app.utils.llm_client.OpenAI"):
self.client = LLMClient(api_key="k", base_url="u", model="m")
def test_closes_one_brace(self):
truncated = '{"key": "val'
fixed = self.client._fix_truncated_json(truncated)
result = json.loads(fixed)
assert result["key"] == "val"
def test_closes_nested_braces(self):
truncated = '{"outer": {"inner": "x"'
fixed = self.client._fix_truncated_json(truncated)
result = json.loads(fixed)
assert result["outer"]["inner"] == "x"
def test_closes_open_array(self):
truncated = '{"list": [1, 2'
fixed = self.client._fix_truncated_json(truncated)
result = json.loads(fixed)
assert result["list"] == [1, 2]
def test_closes_array_and_brace(self):
truncated = '{"a": [1, 2, 3'
fixed = self.client._fix_truncated_json(truncated)
result = json.loads(fixed)
assert result["a"] == [1, 2, 3]
def test_already_valid_unchanged(self):
valid = '{"x": 1}'
fixed = self.client._fix_truncated_json(valid)
assert json.loads(fixed) == {"x": 1}
def test_trailing_dangling_value_gets_quote(self):
# ends mid-string without closing quote
truncated = '{"k": "incomplete'
fixed = self.client._fix_truncated_json(truncated)
# Should be parseable after repair
result = json.loads(fixed)
assert "k" in result
# ---------------------------------------------------------------------------
# _try_fix_json
# ---------------------------------------------------------------------------
class TestTryFixJson:
def setup_method(self):
with patch("app.utils.llm_client.OpenAI"):
self.client = LLMClient(api_key="k", base_url="u", model="m")
def test_returns_valid_json(self):
content = '{"name": "Alice", "age": 30}'
result = self.client._try_fix_json(content)
assert result == {"name": "Alice", "age": 30}
def test_extracts_json_from_surrounding_text(self):
content = 'Here is the result: {"score": 42} end.'
result = self.client._try_fix_json(content)
assert result is not None
assert result["score"] == 42
def test_fixes_newlines_inside_string_values(self):
# Literal newline inside a JSON string value is invalid JSON
content = '{"desc": "line one\nline two"}'
result = self.client._try_fix_json(content)
assert result is not None
assert "desc" in result
def test_returns_none_for_no_json_object(self):
assert self.client._try_fix_json("no json here at all") is None
def test_returns_none_for_empty_string(self):
assert self.client._try_fix_json("") is None
def test_recovers_truncated_object(self):
# Missing closing brace
content = '{"city": "Beijing", "pop": 21'
result = self.client._try_fix_json(content)
assert result is not None
assert result["city"] == "Beijing"
# ---------------------------------------------------------------------------
# chat_json — retry behavior
# ---------------------------------------------------------------------------
class TestChatJsonRetry:
# --- succeeds on first attempt ---
def test_success_first_attempt(self):
payload = {"status": "ok"}
client = _make_client([_make_response(json.dumps(payload))])
result = client.chat_json([{"role": "user", "content": "hi"}])
assert result == payload
assert client._mock_create.call_count == 1
# --- temperature backoff ---
def test_temperature_decreases_across_retries(self):
bad = _make_response("not json at all {{{{")
good = _make_response('{"ok": true}')
client = _make_client([bad, bad, good])
result = client.chat_json(
[{"role": "user", "content": "hi"}],
temperature=0.9,
temperature_step=0.3,
max_attempts=3,
)
assert result == {"ok": True}
calls = client._mock_create.call_args_list
assert calls[0].kwargs["temperature"] == pytest.approx(0.9)
assert calls[1].kwargs["temperature"] == pytest.approx(0.6)
assert calls[2].kwargs["temperature"] == pytest.approx(0.3)
def test_temperature_never_goes_below_zero(self):
good = _make_response('{"x": 1}')
client = _make_client([good])
client.chat_json(
[{"role": "user", "content": "hi"}],
temperature=0.1,
temperature_step=0.5,
max_attempts=1,
)
calls = client._mock_create.call_args_list
assert calls[0].kwargs["temperature"] == pytest.approx(0.1)
# --- fallback_parser ---
def test_fallback_parser_called_on_bad_json(self):
bad = _make_response("this is not json")
client = _make_client([bad])
fallback = MagicMock(return_value={"rescued": True})
result = client.chat_json(
[{"role": "user", "content": "q"}],
max_attempts=1,
fallback_parser=fallback,
)
assert result == {"rescued": True}
fallback.assert_called_once()
def test_fallback_parser_returning_none_does_not_short_circuit(self):
bad = _make_response("still not json ][")
good = _make_response('{"second": "attempt"}')
client = _make_client([bad, good])
fallback = MagicMock(return_value=None)
result = client.chat_json(
[{"role": "user", "content": "q"}],
max_attempts=2,
fallback_parser=fallback,
)
assert result == {"second": "attempt"}
# --- raises ValueError after all attempts fail ---
def test_raises_after_all_attempts_fail(self):
bad = _make_response("invalid {{{ json")
client = _make_client([bad, bad, bad])
with pytest.raises(ValueError, match="LLM返回的JSON格式无效"):
client.chat_json(
[{"role": "user", "content": "q"}],
max_attempts=3,
)
assert client._mock_create.call_count == 3
def test_raises_after_single_attempt(self):
bad = _make_response("nope")
client = _make_client([bad])
with pytest.raises(ValueError):
client.chat_json([{"role": "user", "content": "q"}], max_attempts=1)
# --- finish_reason == 'length' triggers truncation repair ---
def test_truncated_output_is_repaired(self):
truncated_json = '{"items": [1, 2, 3'
client = _make_client([_make_response(truncated_json, finish_reason="length")])
result = client.chat_json([{"role": "user", "content": "q"}])
assert result["items"] == [1, 2, 3]
# --- API exception counts as a failed attempt ---
def test_api_exception_retried(self):
from openai import APIError
exc = Exception("network failure")
good = _make_response('{"recovered": true}')
client = _make_client([exc, good])
result = client.chat_json(
[{"role": "user", "content": "q"}],
max_attempts=2,
)
assert result == {"recovered": True}
assert client._mock_create.call_count == 2
def test_api_exception_all_attempts_raises_value_error(self):
exc = Exception("always fails")
client = _make_client([exc, exc])
with pytest.raises(ValueError):
client.chat_json(
[{"role": "user", "content": "q"}],
max_attempts=2,
)