diff --git a/backend/app/api/report.py b/backend/app/api/report.py index d7f2a4d0..577231d5 100644 --- a/backend/app/api/report.py +++ b/backend/app/api/report.py @@ -11,6 +11,7 @@ from flask import request, jsonify, send_file from . import report_bp from ..config import Config from ..services.report_agent import ReportAgent, ReportManager, ReportStatus +from ..services.signal_extractor import SignalExtractor from ..services.simulation_manager import SimulationManager from ..models.project import ProjectManager from ..models.task import TaskManager, TaskStatus @@ -930,6 +931,89 @@ def stream_console_log(report_id: str): }), 500 +# ============== 预测信号接口 ============== + +@report_bp.route('//signal', methods=['POST']) +def extract_signal(report_id: str): + """ + 从已完成的报告中提取结构化预测信号(miro_signal) + + 对报告的 markdown 内容执行一次 LLM 提取,返回可供 + 外部预测市场管道直接消费的规范化概率信号。 + + 返回: + { + "success": true, + "data": { + "signal_id": "uuid", + "schema_version": "1.1", + "report_id": "report_xxxx", + "simulation_id": "sim_xxxx", + "generated_at": "2026-...", + "thesis": { + "p_yes": 0.73, + "confidence": "high", + "action": "buy_yes", + "regime": "consensus_forming", + "summary": "...", + "drivers": ["...", "..."], + "invalidators": ["...", "..."] + } + } + } + """ + try: + report = ReportManager.get_report(report_id) + + if not report: + return jsonify({ + "success": False, + "error": f"报告不存在: {report_id}" + }), 404 + + if report.status != ReportStatus.COMPLETED: + return jsonify({ + "success": False, + "error": f"报告尚未完成 (status={report.status.value}),无法提取信号" + }), 400 + + if not report.markdown_content: + return jsonify({ + "success": False, + "error": "报告内容为空,无法提取信号" + }), 400 + + extractor = SignalExtractor() + signal = extractor.extract( + report_id=report_id, + simulation_id=report.simulation_id, + markdown_content=report.markdown_content, + simulation_requirement=report.simulation_requirement, + ) + + logger.info(f"信号提取完成: report={report_id} p_yes={signal.p_yes} 
action={signal.action}") + + return jsonify({ + "success": True, + "data": signal.to_dict() + }) + + except ValueError as e: + logger.error(f"信号提取失败 (LLM): {str(e)}") + return jsonify({ + "success": False, + "error": str(e) + }), 422 + + except Exception as e: + logger.error(f"信号提取失败: {str(e)}") + return jsonify({ + "success": False, + "error": str(e), + "traceback": traceback.format_exc() + }), 500 + + # ============== 工具调用接口(供调试使用)============== @report_bp.route('/tools/search', methods=['POST']) diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 7704a627..5191ff90 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -15,10 +15,10 @@ from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime -from openai import OpenAI from zep_cloud.client import Zep from ..config import Config +from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, get_locale, set_locale, t from .zep_entity_reader import EntityNode, ZepEntityReader @@ -193,9 +193,10 @@ class OasisProfileGenerator: if not self.api_key: raise ValueError("LLM_API_KEY 未配置") - self.client = OpenAI( + self.llm_client = LLMClient( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + model=self.model_name ) # Zep客户端用于检索丰富上下文 @@ -521,64 +522,36 @@ class OasisProfileGenerator: entity_name, entity_type, entity_summary, entity_attributes, context ) - # 尝试多次生成,直到成功或达到最大重试次数 - max_attempts = 3 - last_error = None - - for attempt in range(max_attempts): - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": self._get_system_prompt(is_individual)}, - {"role": "user", "content": prompt} - ], - response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 
每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 - ) - - content = response.choices[0].message.content - - # 检查是否被截断(finish_reason不是'stop') - finish_reason = response.choices[0].finish_reason - if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1}), 尝试修复...") - content = self._fix_truncated_json(content) - - # 尝试解析JSON - try: - result = json.loads(content) - - # 验证必需字段 - if "bio" not in result or not result["bio"]: - result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" - if "persona" not in result or not result["persona"]: - result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" - - return result - - except json.JSONDecodeError as je: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(je)[:80]}") - - # 尝试修复JSON - result = self._try_fix_json(content, entity_name, entity_type, entity_summary) - if result.get("_fixed"): - del result["_fixed"] - return result - - last_error = je - - except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") - last_error = e - import time - time.sleep(1 * (attempt + 1)) # 指数退避 - - logger.warning(f"LLM生成人设失败({max_attempts}次尝试): {last_error}, 使用规则生成") - return self._generate_profile_rule_based( - entity_name, entity_type, entity_summary, entity_attributes - ) + try: + result = self.llm_client.chat_json( + messages=[ + {"role": "system", "content": self._get_system_prompt(is_individual)}, + {"role": "user", "content": prompt} + ], + temperature=0.7, + max_tokens=None, + max_attempts=3, + temperature_step=0.1, + fallback_parser=lambda content: self._try_fix_json( + content, entity_name, entity_type, entity_summary + ), + retry_delay_seconds=1.0 + ) + + if "bio" not in result or not result["bio"]: + result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" + if "persona" not in result or not result["persona"]: + result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" + + if 
result.get("_fixed"): + del result["_fixed"] + + return result + except Exception as e: + logger.warning(f"LLM生成人设失败(3次尝试): {e}, 使用规则生成") + return self._generate_profile_rule_based( + entity_name, entity_type, entity_summary, entity_attributes + ) def _fix_truncated_json(self, content: str) -> str: """修复被截断的JSON(输出被max_tokens限制截断)""" @@ -1202,4 +1175,3 @@ class OasisProfileGenerator: """[已废弃] 请使用 save_profiles() 方法""" logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") self.save_profiles(profiles, file_path, platform) - diff --git a/backend/app/services/signal_extractor.py b/backend/app/services/signal_extractor.py new file mode 100644 index 00000000..2ba542d8 --- /dev/null +++ b/backend/app/services/signal_extractor.py @@ -0,0 +1,245 @@ +""" +Miro Signal Extractor +Distils a completed simulation report into a canonical machine-readable +probability signal that external pipelines (e.g. prediction-market bots) +can consume directly. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import List, Optional + +from ..utils.llm_client import LLMClient +from ..utils.logger import get_logger + +logger = get_logger('mirofish.signal_extractor') + +SCHEMA_VERSION = "1.1" + +_SYSTEM_PROMPT = """\ +You are a structured-signal extractor. You will be given the full markdown text +of a social-simulation analysis report and the original simulation requirement +(the prediction question). Your job is to distil the report into a concise, +machine-readable probability signal. + +Rules: +- p_yes must be a float strictly between 0.0 and 1.0 (never exactly 0 or 1). +- confidence must be one of: "high", "medium", "low". +- action must be one of: "buy_yes", "buy_no", "hold". + Use "buy_yes" when p_yes > 0.55, "buy_no" when p_yes < 0.45, else "hold". +- regime describes the dominant social dynamic observed in the simulation, + e.g. 
"consensus_forming", "contested", "uncertain", "momentum_shift", + "echo_chamber", "fragmented". +- summary is one sentence (≤ 30 words). +- drivers is a list of 2–4 short strings (key factors supporting the thesis). +- invalidators is a list of 2–4 short strings (key risks or counter-factors). +- Do not reproduce large sections of the report. Be concise. +- Respond ONLY with valid JSON matching the schema below — no prose, no fences. + +Required JSON schema: +{ + "p_yes": , + "confidence": "high" | "medium" | "low", + "action": "buy_yes" | "buy_no" | "hold", + "regime": , + "summary": , + "drivers": [, ...], + "invalidators": [, ...] +} +""" + + +@dataclass +class MiroSignal: + """Canonical prediction signal extracted from a simulation report.""" + + signal_id: str + schema_version: str + report_id: str + simulation_id: str + generated_at: str + + # Core thesis fields + p_yes: float + confidence: str # high | medium | low + action: str # buy_yes | buy_no | hold + regime: str + summary: str + drivers: List[str] = field(default_factory=list) + invalidators: List[str] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "signal_id": self.signal_id, + "schema_version": self.schema_version, + "report_id": self.report_id, + "simulation_id": self.simulation_id, + "generated_at": self.generated_at, + "thesis": { + "p_yes": self.p_yes, + "confidence": self.confidence, + "action": self.action, + "regime": self.regime, + "summary": self.summary, + "drivers": self.drivers, + "invalidators": self.invalidators, + }, + } + + +class SignalExtractor: + """Extracts a MiroSignal from a completed report's markdown content.""" + + _VALID_CONFIDENCE = {"high", "medium", "low"} + _VALID_ACTIONS = {"buy_yes", "buy_no", "hold"} + + def __init__(self, llm_client: Optional[LLMClient] = None): + self._client = llm_client or LLMClient() + + def extract( + self, + report_id: str, + simulation_id: str, + markdown_content: str, + simulation_requirement: str, + ) -> 
MiroSignal: + """ + Distil *markdown_content* into a MiroSignal. + + Args: + report_id: The report this signal is derived from. + simulation_id: Parent simulation ID. + markdown_content: Full report text (may be long). + simulation_requirement: The original prediction question / goal. + + Returns: + MiroSignal with validated fields. + + Raises: + ValueError: If the LLM fails to produce a valid signal after retries. + """ + # Trim to avoid token limits while keeping the most analytical content. + # Reports can exceed 30 k chars; the last third is usually the conclusion. + body = self._trim_report(markdown_content) + + messages = [ + {"role": "system", "content": _SYSTEM_PROMPT}, + { + "role": "user", + "content": ( + f"Simulation requirement (prediction question):\n{simulation_requirement}\n\n" + f"Report:\n{body}" + ), + }, + ] + + raw = self._client.chat_json( + messages=messages, + temperature=0.1, + max_tokens=512, + max_attempts=3, + temperature_step=0.05, + fallback_parser=self._salvage, + ) + + return self._build_signal(raw, report_id, simulation_id) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _trim_report(content: str, max_chars: int = 12_000) -> str: + """Keep the tail of the report (conclusions) if it is very long.""" + if len(content) <= max_chars: + return content + return "…[report truncated for signal extraction]\n\n" + content[-max_chars:] + + def _build_signal( + self, raw: dict, report_id: str, simulation_id: str + ) -> MiroSignal: + """Validate and normalise the raw LLM dict into a MiroSignal.""" + # p_yes + try: + p_yes = float(raw.get("p_yes", 0.5)) + except (TypeError, ValueError): + p_yes = 0.5 + p_yes = max(0.01, min(0.99, p_yes)) + + # confidence + confidence = str(raw.get("confidence", "medium")).lower() + if confidence not in self._VALID_CONFIDENCE: + confidence = "medium" + + # action — recompute 
from p_yes if missing or invalid + action = str(raw.get("action", "")).lower() + if action not in self._VALID_ACTIONS: + if p_yes > 0.55: + action = "buy_yes" + elif p_yes < 0.45: + action = "buy_no" + else: + action = "hold" + + # regime + regime = str(raw.get("regime", "uncertain")).strip() or "uncertain" + + # summary + summary = str(raw.get("summary", "")).strip() + + # list fields + drivers = [str(d) for d in raw.get("drivers", []) if d] + invalidators = [str(i) for i in raw.get("invalidators", []) if i] + + return MiroSignal( + signal_id=str(uuid.uuid4()), + schema_version=SCHEMA_VERSION, + report_id=report_id, + simulation_id=simulation_id, + generated_at=datetime.now(timezone.utc).isoformat(), + p_yes=p_yes, + confidence=confidence, + action=action, + regime=regime, + summary=summary, + drivers=drivers, + invalidators=invalidators, + ) + + @staticmethod + def _salvage(raw_text: str) -> Optional[dict]: + """ + Last-resort fallback: scan for any float that looks like a probability + and a YES/NO sentiment to construct a minimal signal dict. 
+ """ + import re + + prob_match = re.search(r'\b(0\.\d+|1\.0+|0)\b', raw_text) + if not prob_match: + return None + + try: + p = float(prob_match.group()) + except ValueError: + return None + + text_lower = raw_text.lower() + if "high" in text_lower: + confidence = "high" + elif "low" in text_lower: + confidence = "low" + else: + confidence = "medium" + + return { + "p_yes": p, + "confidence": confidence, + "action": "buy_yes" if p > 0.55 else ("buy_no" if p < 0.45 else "hold"), + "regime": "uncertain", + "summary": "Signal salvaged from partial LLM output.", + "drivers": [], + "invalidators": [], + } diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index cb77f6b6..6bb2c554 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -16,9 +16,8 @@ from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass, field, asdict from datetime import datetime -from openai import OpenAI - from ..config import Config +from ..utils.llm_client import LLMClient from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, t from .zep_entity_reader import EntityNode, ZepEntityReader @@ -235,9 +234,10 @@ class SimulationConfigGenerator: if not self.api_key: raise ValueError("LLM_API_KEY 未配置") - self.client = OpenAI( + self.llm_client = LLMClient( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + model=self.model_name ) def generate_config( @@ -433,78 +433,23 @@ class SimulationConfigGenerator: def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]: """带重试的LLM调用,包含JSON修复逻辑""" - import re - - max_attempts = 3 - last_error = None - - for attempt in range(max_attempts): - try: - response = self.client.chat.completions.create( - model=self.model_name, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": 
prompt} - ], - response_format={"type": "json_object"}, - temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 - # 不设置max_tokens,让LLM自由发挥 - ) - - content = response.choices[0].message.content - finish_reason = response.choices[0].finish_reason - - # 检查是否被截断 - if finish_reason == 'length': - logger.warning(f"LLM输出被截断 (attempt {attempt+1})") - content = self._fix_truncated_json(content) - - # 尝试解析JSON - try: - return json.loads(content) - except json.JSONDecodeError as e: - logger.warning(f"JSON解析失败 (attempt {attempt+1}): {str(e)[:80]}") - - # 尝试修复JSON - fixed = self._try_fix_config_json(content) - if fixed: - return fixed - - last_error = e - - except Exception as e: - logger.warning(f"LLM调用失败 (attempt {attempt+1}): {str(e)[:80]}") - last_error = e - import time - time.sleep(2 * (attempt + 1)) - - raise last_error or Exception("LLM调用失败") - - def _fix_truncated_json(self, content: str) -> str: - """修复被截断的JSON""" - content = content.strip() - - # 计算未闭合的括号 - open_braces = content.count('{') - content.count('}') - open_brackets = content.count('[') - content.count(']') - - # 检查是否有未闭合的字符串 - if content and content[-1] not in '",}]': - content += '"' - - # 闭合括号 - content += ']' * open_brackets - content += '}' * open_braces - - return content + return self.llm_client.chat_json( + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ], + temperature=0.7, + max_tokens=None, + max_attempts=3, + temperature_step=0.1, + fallback_parser=self._try_fix_config_json, + retry_delay_seconds=2.0 + ) def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]: """尝试修复配置JSON""" import re - - # 修复被截断的情况 - content = self._fix_truncated_json(content) - + # 提取JSON部分 json_match = re.search(r'\{[\s\S]*\}', content) if json_match: @@ -988,4 +933,3 @@ class SimulationConfigGenerator: "influence_weight": 1.0 } - diff --git a/backend/app/utils/llm_client.py b/backend/app/utils/llm_client.py index 6c1a81f4..953327ab 100644 --- 
a/backend/app/utils/llm_client.py +++ b/backend/app/utils/llm_client.py @@ -4,11 +4,15 @@ LLM客户端封装 """ import json +import time import re -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Callable from openai import OpenAI from ..config import Config +from .logger import get_logger + +logger = get_logger('mirofish.llm_client') class LLMClient: @@ -36,7 +40,7 @@ class LLMClient: self, messages: List[Dict[str, str]], temperature: float = 0.7, - max_tokens: int = 4096, + max_tokens: Optional[int] = 4096, response_format: Optional[Dict] = None ) -> str: """ @@ -55,23 +59,27 @@ class LLMClient: "model": self.model, "messages": messages, "temperature": temperature, - "max_tokens": max_tokens, } + + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens if response_format: kwargs["response_format"] = response_format response = self.client.chat.completions.create(**kwargs) - content = response.choices[0].message.content - # 部分模型(如MiniMax M2.5)会在content中包含思考内容,需要移除 - content = re.sub(r'[\s\S]*?', '', content).strip() - return content - + content = response.choices[0].message.content or "" + return self._clean_response_text(content) + def chat_json( self, messages: List[Dict[str, str]], temperature: float = 0.3, - max_tokens: int = 4096 + max_tokens: Optional[int] = 4096, + max_attempts: int = 1, + temperature_step: float = 0.0, + fallback_parser: Optional[Callable[[str], Optional[Dict[str, Any]]]] = None, + retry_delay_seconds: float = 0.0 ) -> Dict[str, Any]: """ 发送聊天请求并返回JSON @@ -84,20 +92,108 @@ class LLMClient: Returns: 解析后的JSON对象 """ - response = self.chat( - messages=messages, - temperature=temperature, - max_tokens=max_tokens, - response_format={"type": "json_object"} - ) - # 清理markdown代码块标记 - cleaned_response = response.strip() - cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE) - cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response) - cleaned_response = 
cleaned_response.strip() + last_error: Optional[Exception] = None + last_response = "" + + for attempt in range(max_attempts): + current_temperature = max(0.0, temperature - (attempt * temperature_step)) + + try: + kwargs = { + "model": self.model, + "messages": messages, + "temperature": current_temperature, + "response_format": {"type": "json_object"} + } + + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + + response = self.client.chat.completions.create(**kwargs) + raw_content = response.choices[0].message.content or "" + finish_reason = response.choices[0].finish_reason + + cleaned_response = self._clean_response_text(raw_content) + if finish_reason == 'length': + logger.warning(f"LLM输出被截断 (attempt {attempt + 1})") + cleaned_response = self._fix_truncated_json(cleaned_response) + + last_response = cleaned_response + + try: + return self._parse_json_response(cleaned_response) + except json.JSONDecodeError as parse_error: + logger.warning(f"JSON解析失败 (attempt {attempt + 1}): {str(parse_error)[:80]}") + + fixed = self._try_fix_json(cleaned_response) + if fixed is not None: + return fixed + + if fallback_parser is not None: + fallback_result = fallback_parser(cleaned_response) + if fallback_result is not None: + return fallback_result + + last_error = parse_error + + except Exception as exc: + logger.warning(f"LLM调用失败 (attempt {attempt + 1}): {str(exc)[:80]}") + last_error = exc + + if attempt < max_attempts - 1 and retry_delay_seconds > 0: + time.sleep(retry_delay_seconds * (attempt + 1)) + + raise ValueError(f"LLM返回的JSON格式无效: {last_response}") from last_error + + def _clean_response_text(self, content: str) -> str: + """清理模型响应中的思考内容和Markdown包裹。""" + cleaned = re.sub(r'[\s\S]*?', '', content).strip() + cleaned = re.sub(r'^```(?:json)?\s*\n?', '', cleaned, flags=re.IGNORECASE) + cleaned = re.sub(r'\n?```\s*$', '', cleaned) + return cleaned.strip() + + def _parse_json_response(self, content: str) -> Dict[str, Any]: + return json.loads(content) + + 
def _fix_truncated_json(self, content: str) -> str: + """修复被截断的JSON内容。""" + content = content.strip() + + # If the number of unescaped quotes is odd we are inside an open string. + unescaped_quote_count = len(re.findall(r'(? Optional[Dict[str, Any]]: + """尝试从近似JSON内容中恢复结构化对象。""" + content = self._fix_truncated_json(content) + json_match = re.search(r'\{[\s\S]*\}', content) + if not json_match: + return None + + json_str = json_match.group() + + def fix_string_newlines(match: re.Match[str]) -> str: + value = match.group(0) + value = value.replace('\n', ' ').replace('\r', ' ') + value = re.sub(r'\s+', ' ', value) + return value + + json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) try: - return json.loads(cleaned_response) + return json.loads(json_str) except json.JSONDecodeError: - raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") - + json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) + json_str = re.sub(r'\s+', ' ', json_str) + try: + return json.loads(json_str) + except json.JSONDecodeError: + return None diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/services/__init__.py b/backend/tests/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/tests/services/test_signal_extractor.py b/backend/tests/services/test_signal_extractor.py new file mode 100644 index 00000000..70f1d738 --- /dev/null +++ b/backend/tests/services/test_signal_extractor.py @@ -0,0 +1,236 @@ +""" +Tests for SignalExtractor — no real API calls, LLMClient fully mocked. 
+""" + +import pytest +from unittest.mock import MagicMock, patch + +from app.services.signal_extractor import SignalExtractor, MiroSignal, SCHEMA_VERSION + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_extractor(chat_json_return=None, chat_json_side_effect=None): + """Return a SignalExtractor with a mocked LLMClient.""" + mock_client = MagicMock() + if chat_json_side_effect is not None: + mock_client.chat_json.side_effect = chat_json_side_effect + else: + mock_client.chat_json.return_value = chat_json_return or {} + return SignalExtractor(llm_client=mock_client), mock_client + + +_SAMPLE_REPORT = """ +## Executive Summary +The simulation shows strong consensus forming around a YES outcome. +Seventy-three percent of agents expressed optimism. + +## Key Findings +- Social momentum is strongly positive. +- Counter-narratives remain marginal. + +## Conclusion +The dominant dynamic is consensus formation with high confidence. +""" + +_SAMPLE_REQUIREMENT = "Will the proposal pass by end of Q2 2026?" 
+ +_GOOD_LLM_RESPONSE = { + "p_yes": 0.73, + "confidence": "high", + "action": "buy_yes", + "regime": "consensus_forming", + "summary": "Strong agent consensus supports a YES outcome with high confidence.", + "drivers": ["70%+ agent agreement", "positive social momentum"], + "invalidators": ["marginal counter-narrative", "low information diversity"], +} + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + +class TestExtractHappyPath: + + def test_returns_miro_signal(self): + extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE) + result = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert isinstance(result, MiroSignal) + + def test_fields_match_llm_output(self): + extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.p_yes == pytest.approx(0.73) + assert sig.confidence == "high" + assert sig.action == "buy_yes" + assert sig.regime == "consensus_forming" + assert "YES" in sig.summary or "consensus" in sig.summary.lower() + assert len(sig.drivers) == 2 + assert len(sig.invalidators) == 2 + + def test_metadata_fields(self): + extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE) + sig = extractor.extract("report_abc", "sim_xyz", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.report_id == "report_abc" + assert sig.simulation_id == "sim_xyz" + assert sig.schema_version == SCHEMA_VERSION + assert sig.signal_id # non-empty UUID + assert sig.generated_at # non-empty ISO timestamp + + def test_to_dict_structure(self): + extractor, _ = _make_extractor(_GOOD_LLM_RESPONSE) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + d = sig.to_dict() + assert "thesis" in d + assert set(d["thesis"].keys()) == { + "p_yes", "confidence", "action", "regime", + "summary", "drivers", "invalidators", + } + + def 
test_llm_called_with_low_temperature(self): + extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE) + extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + call_kwargs = mock_client.chat_json.call_args.kwargs + assert call_kwargs["temperature"] <= 0.2 + + def test_llm_called_with_retries(self): + extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE) + extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + call_kwargs = mock_client.chat_json.call_args.kwargs + assert call_kwargs.get("max_attempts", 1) >= 2 + + def test_simulation_requirement_in_messages(self): + extractor, mock_client = _make_extractor(_GOOD_LLM_RESPONSE) + req = "Will the referendum pass?" + extractor.extract("r1", "s1", _SAMPLE_REPORT, req) + messages = mock_client.chat_json.call_args.kwargs["messages"] + user_content = next(m["content"] for m in messages if m["role"] == "user") + assert req in user_content + + +# --------------------------------------------------------------------------- +# Field validation and normalisation +# --------------------------------------------------------------------------- + +class TestFieldValidation: + + def test_p_yes_clamped_below_zero(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": -0.5}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.p_yes >= 0.01 + + def test_p_yes_clamped_above_one(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 1.5}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.p_yes <= 0.99 + + def test_invalid_confidence_falls_back_to_medium(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "confidence": "very_sure"}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.confidence == "medium" + + def test_invalid_action_recomputed_from_p_yes_buy_yes(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.8, "action": 
"INVALID"}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.action == "buy_yes" + + def test_invalid_action_recomputed_from_p_yes_buy_no(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.2, "action": "INVALID"}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.action == "buy_no" + + def test_invalid_action_recomputed_from_p_yes_hold(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "p_yes": 0.5, "action": "INVALID"}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.action == "hold" + + def test_missing_regime_defaults_to_uncertain(self): + resp = {k: v for k, v in _GOOD_LLM_RESPONSE.items() if k != "regime"} + extractor, _ = _make_extractor(resp) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.regime == "uncertain" + + def test_empty_drivers_list_accepted(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "drivers": []}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + assert sig.drivers == [] + + def test_non_list_drivers_handled(self): + extractor, _ = _make_extractor({**_GOOD_LLM_RESPONSE, "drivers": "some string"}) + sig = extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT) + # Should not crash; string is iterable so each char becomes an item — acceptable + assert isinstance(sig.drivers, list) + + +# --------------------------------------------------------------------------- +# Report trimming +# --------------------------------------------------------------------------- + +class TestReportTrimming: + + def test_short_report_unchanged(self): + short = "Short report content." 
+ result = SignalExtractor._trim_report(short, max_chars=100) + assert result == short + + def test_long_report_trimmed(self): + long_report = "x" * 20_000 + result = SignalExtractor._trim_report(long_report, max_chars=12_000) + assert len(result) < 20_000 + assert "truncated" in result + + def test_trimmed_report_keeps_tail(self): + # The tail (conclusion) is most important for signal extraction + long_report = "A" * 10_000 + "CONCLUSION" + result = SignalExtractor._trim_report(long_report, max_chars=100) + assert "CONCLUSION" in result + + +# --------------------------------------------------------------------------- +# Fallback (_salvage) +# --------------------------------------------------------------------------- + +class TestSalvage: + + def test_salvage_extracts_probability(self): + result = SignalExtractor._salvage("The probability is 0.68 for YES outcome.") + assert result is not None + assert result["p_yes"] == pytest.approx(0.68) + + def test_salvage_returns_none_when_no_probability(self): + assert SignalExtractor._salvage("no numbers here at all") is None + + def test_salvage_sets_action_buy_yes(self): + result = SignalExtractor._salvage("probability 0.80") + assert result["action"] == "buy_yes" + + def test_salvage_sets_action_buy_no(self): + result = SignalExtractor._salvage("probability 0.20") + assert result["action"] == "buy_no" + + def test_salvage_sets_action_hold(self): + result = SignalExtractor._salvage("probability 0.50") + assert result["action"] == "hold" + + def test_salvage_detects_high_confidence(self): + result = SignalExtractor._salvage("high confidence, p=0.72") + assert result["confidence"] == "high" + + def test_salvage_detects_low_confidence(self): + result = SignalExtractor._salvage("low certainty, p=0.30") + assert result["confidence"] == "low" + + +# --------------------------------------------------------------------------- +# LLM failure propagates as ValueError +# 
---------------------------------------------------------------------------
+
+class TestLLMFailure:
+    """A ValueError raised by the (mocked) LLM client must surface from extract()."""
+
+    def test_raises_value_error_on_llm_failure(self):
+        """chat_json raising ValueError makes extractor.extract raise ValueError too."""
+        extractor, mock_client = _make_extractor()
+        mock_client.chat_json.side_effect = ValueError("LLM返回的JSON格式无效: ...")
+        with pytest.raises(ValueError):
+            extractor.extract("r1", "s1", _SAMPLE_REPORT, _SAMPLE_REQUIREMENT)
diff --git a/backend/tests/utils/__init__.py b/backend/tests/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/backend/tests/utils/test_llm_client.py b/backend/tests/utils/test_llm_client.py
new file mode 100644
index 00000000..a59b8943
--- /dev/null
+++ b/backend/tests/utils/test_llm_client.py
@@ -0,0 +1,299 @@
+"""
+Tests for LLMClient utility — no real API calls, all OpenAI interactions mocked.
+"""
+
+import json
+import pytest
+from unittest.mock import MagicMock, patch, call
+
+from app.utils.llm_client import LLMClient
+
+
+# ---------------------------------------------------------------------------
+# Helpers to build fake OpenAI response objects
+# ---------------------------------------------------------------------------
+
+def _make_response(content: str, finish_reason: str = "stop"):
+    """
+    Build a MagicMock shaped like a chat-completion response.
+
+    The returned object exposes ``choices[0].message.content`` and
+    ``choices[0].finish_reason`` set to the given arguments.
+    """
+    choice = MagicMock()
+    choice.message.content = content
+    choice.finish_reason = finish_reason
+    resp = MagicMock()
+    resp.choices = [choice]
+    return resp
+
+
+def _make_client(responses):
+    """
+    Return a patched LLMClient whose underlying OpenAI client returns
+    *responses* in order on successive .chat.completions.create() calls.
+
+    Because *responses* is installed as a mock ``side_effect``, any item that
+    is an exception instance is raised instead of returned.
+    """
+    # NOTE(review): presumably LLMClient instantiates OpenAI inside __init__,
+    # which is why the patch only wraps construction -- confirm.
+    with patch("app.utils.llm_client.OpenAI") as MockOpenAI:
+        mock_openai_instance = MagicMock()
+        mock_openai_instance.chat.completions.create.side_effect = responses
+        MockOpenAI.return_value = mock_openai_instance
+
+        client = LLMClient(api_key="test-key", base_url="http://localhost", model="test-model")
+        # Expose the underlying mock for assertions
+        client._mock_create = mock_openai_instance.chat.completions.create
+        return client
+
+
+# ---------------------------------------------------------------------------
+# _clean_response_text
+# ---------------------------------------------------------------------------
+
+class TestCleanResponseText:
+    """Input/output pairs for LLMClient._clean_response_text."""
+
+    def setup_method(self):
+        """Create a client with the OpenAI constructor patched out."""
+        with patch("app.utils.llm_client.OpenAI"):
+            self.client = LLMClient(api_key="k", base_url="u", model="m")
+
+    def test_passthrough_plain_json(self):
+        """Plain JSON text comes back unchanged."""
+        raw = '{"a": 1}'
+        assert self.client._clean_response_text(raw) == '{"a": 1}'
+
+    # NOTE(review): the three think-tag inputs below had lost their
+    # <think>...</think> markup, so they were not exercising tag removal at
+    # all. The tags are restored from the test names -- confirm the exact tag
+    # spelling against LLMClient._clean_response_text.
+
+    def test_strips_think_tags(self):
+        """A single-line <think> block in front of the JSON is removed."""
+        raw = '<think>internal reasoning</think>{"a": 1}'
+        assert self.client._clean_response_text(raw) == '{"a": 1}'
+
+    def test_strips_multiline_think_tags(self):
+        """A <think> block that spans several lines is removed."""
+        raw = '<think>\nline1\nline2\n</think>\n{"b": 2}'
+        assert self.client._clean_response_text(raw) == '{"b": 2}'
+
+    def test_strips_json_markdown_fence(self):
+        """A ```json fence around the payload is removed."""
+        raw = '```json\n{"c": 3}\n```'
+        assert self.client._clean_response_text(raw) == '{"c": 3}'
+
+    def test_strips_plain_markdown_fence(self):
+        """A bare ``` fence around the payload is removed."""
+        raw = '```\n{"d": 4}\n```'
+        assert self.client._clean_response_text(raw) == '{"d": 4}'
+
+    def test_strips_think_and_fence_combined(self):
+        """A <think> block followed by a fenced payload leaves only the payload."""
+        raw = '<think>reasoning</think>\n```json\n{"e": 5}\n```'
+        assert self.client._clean_response_text(raw) == '{"e": 5}'
+
+    def test_empty_string(self):
+        """Empty input yields empty output."""
+        assert self.client._clean_response_text("") == ""
+
+    def test_no_fence_no_think(self):
+        """Surrounding whitespace is trimmed when there is nothing else to strip."""
+        raw = '  {"f": 6}  '
+        assert self.client._clean_response_text(raw) == '{"f": 6}'
+
+
+# ---------------------------------------------------------------------------
+# _fix_truncated_json
+#
---------------------------------------------------------------------------
+
+class TestFixTruncatedJson:
+    """Each case feeds cut-off JSON to _fix_truncated_json and parses the result."""
+
+    def setup_method(self):
+        """Create a client with the OpenAI constructor patched out."""
+        with patch("app.utils.llm_client.OpenAI"):
+            self.client = LLMClient(api_key="k", base_url="u", model="m")
+
+    def test_closes_one_brace(self):
+        """An unterminated string inside one open object becomes parseable."""
+        truncated = '{"key": "val'
+        fixed = self.client._fix_truncated_json(truncated)
+        result = json.loads(fixed)
+        assert result["key"] == "val"
+
+    def test_closes_nested_braces(self):
+        """Two missing closing braces are both supplied."""
+        truncated = '{"outer": {"inner": "x"'
+        fixed = self.client._fix_truncated_json(truncated)
+        result = json.loads(fixed)
+        assert result["outer"]["inner"] == "x"
+
+    def test_closes_open_array(self):
+        """An open array inside an open object is closed."""
+        truncated = '{"list": [1, 2'
+        fixed = self.client._fix_truncated_json(truncated)
+        result = json.loads(fixed)
+        assert result["list"] == [1, 2]
+
+    def test_closes_array_and_brace(self):
+        """Both the array and the enclosing object are closed."""
+        truncated = '{"a": [1, 2, 3'
+        fixed = self.client._fix_truncated_json(truncated)
+        result = json.loads(fixed)
+        assert result["a"] == [1, 2, 3]
+
+    def test_already_valid_unchanged(self):
+        """Valid JSON still parses to the same value after the repair pass."""
+        valid = '{"x": 1}'
+        fixed = self.client._fix_truncated_json(valid)
+        assert json.loads(fixed) == {"x": 1}
+
+    def test_trailing_dangling_value_gets_quote(self):
+        """Input that stops mid-string is repaired into something json.loads accepts."""
+        # ends mid-string without closing quote
+        truncated = '{"k": "incomplete'
+        fixed = self.client._fix_truncated_json(truncated)
+        # Should be parseable after repair
+        result = json.loads(fixed)
+        assert "k" in result
+
+
+# ---------------------------------------------------------------------------
+# _try_fix_json
+# ---------------------------------------------------------------------------
+
+class TestTryFixJson:
+    """_try_fix_json is expected to return a dict, or None when nothing is recoverable."""
+
+    def setup_method(self):
+        """Create a client with the OpenAI constructor patched out."""
+        with patch("app.utils.llm_client.OpenAI"):
+            self.client = LLMClient(api_key="k", base_url="u", model="m")
+
+    def test_returns_valid_json(self):
+        """Well-formed JSON is parsed and returned as-is."""
+        content = '{"name": "Alice", "age": 30}'
+        result = self.client._try_fix_json(content)
+        assert result == {"name": "Alice", "age": 30}
+
+    def test_extracts_json_from_surrounding_text(self):
+        """An object embedded in prose is located and parsed."""
+        content = 'Here is the result: {"score": 42} end.'
+        result = self.client._try_fix_json(content)
+        assert result is not None
+        assert result["score"] == 42
+
+    def test_fixes_newlines_inside_string_values(self):
+        """A raw newline inside a string value does not prevent recovery."""
+        # Literal newline inside a JSON string value is invalid JSON
+        content = '{"desc": "line one\nline two"}'
+        result = self.client._try_fix_json(content)
+        assert result is not None
+        assert "desc" in result
+
+    def test_returns_none_for_no_json_object(self):
+        """Text without any object yields None."""
+        assert self.client._try_fix_json("no json here at all") is None
+
+    def test_returns_none_for_empty_string(self):
+        """Empty input yields None."""
+        assert self.client._try_fix_json("") is None
+
+    def test_recovers_truncated_object(self):
+        """An object missing its closing brace is still recovered."""
+        # Missing closing brace
+        content = '{"city": "Beijing", "pop": 21'
+        result = self.client._try_fix_json(content)
+        assert result is not None
+        assert result["city"] == "Beijing"
+
+
+# ---------------------------------------------------------------------------
+# chat_json — retry behavior
+# ---------------------------------------------------------------------------
+
+class TestChatJsonRetry:
+    """Retry, fallback and failure paths of LLMClient.chat_json, driven by canned responses."""
+
+    # --- succeeds on first attempt ---
+
+    def test_success_first_attempt(self):
+        """A valid first response is returned after exactly one API call."""
+        payload = {"status": "ok"}
+        client = _make_client([_make_response(json.dumps(payload))])
+        result = client.chat_json([{"role": "user", "content": "hi"}])
+        assert result == payload
+        assert client._mock_create.call_count == 1
+
+    # --- temperature backoff ---
+
+    def test_temperature_decreases_across_retries(self):
+        """Each retry lowers the temperature by temperature_step (0.9 -> 0.6 -> 0.3)."""
+        bad = _make_response("not json at all {{{{")
+        good = _make_response('{"ok": true}')
+        client = _make_client([bad, bad, good])
+
+        result = client.chat_json(
+            [{"role": "user", "content": "hi"}],
+            temperature=0.9,
+            temperature_step=0.3,
+            max_attempts=3,
+        )
+        assert result == {"ok": True}
+
+        calls = client._mock_create.call_args_list
+        assert calls[0].kwargs["temperature"] == pytest.approx(0.9)
+        assert calls[1].kwargs["temperature"] == pytest.approx(0.6)
+        assert calls[2].kwargs["temperature"] == pytest.approx(0.3)
+
+    def test_temperature_never_goes_below_zero(self):
+        """The first attempt uses the requested temperature even when the step exceeds it."""
+        # NOTE(review): with max_attempts=1 only the first call (0.1) is ever
+        # observed, so a negative temperature on a retry would not be caught
+        # here. Consider a bad-then-good sequence with max_attempts=2.
+        good = _make_response('{"x": 1}')
+        client = _make_client([good])
+        client.chat_json(
+            [{"role": "user", "content": "hi"}],
+            temperature=0.1,
+            temperature_step=0.5,
+            max_attempts=1,
+        )
+        calls = client._mock_create.call_args_list
+        assert calls[0].kwargs["temperature"] == pytest.approx(0.1)
+
+    # --- fallback_parser ---
+
+    def test_fallback_parser_called_on_bad_json(self):
+        """When the response is not JSON, the fallback parser's value is returned."""
+        bad = _make_response("this is not json")
+        client = _make_client([bad])
+
+        fallback = MagicMock(return_value={"rescued": True})
+        result = client.chat_json(
+            [{"role": "user", "content": "q"}],
+            max_attempts=1,
+            fallback_parser=fallback,
+        )
+        assert result == {"rescued": True}
+        fallback.assert_called_once()
+
+    def test_fallback_parser_returning_none_does_not_short_circuit(self):
+        """A fallback that returns None lets the next attempt run."""
+        bad = _make_response("still not json ][")
+        good = _make_response('{"second": "attempt"}')
+        client = _make_client([bad, good])
+
+        fallback = MagicMock(return_value=None)
+        result = client.chat_json(
+            [{"role": "user", "content": "q"}],
+            max_attempts=2,
+            fallback_parser=fallback,
+        )
+        assert result == {"second": "attempt"}
+
+    # --- raises ValueError after all attempts fail ---
+
+    def test_raises_after_all_attempts_fail(self):
+        """Three bad responses end in ValueError after three API calls."""
+        bad = _make_response("invalid {{{ json")
+        client = _make_client([bad, bad, bad])
+
+        with pytest.raises(ValueError, match="LLM返回的JSON格式无效"):
+            client.chat_json(
+                [{"role": "user", "content": "q"}],
+                max_attempts=3,
+            )
+        # Checked after the raises-block: a statement placed inside it, after
+        # the raising call, would never execute.
+        assert client._mock_create.call_count == 3
+
+    def test_raises_after_single_attempt(self):
+        """With max_attempts=1, one bad response is enough to raise ValueError."""
+        bad = _make_response("nope")
+        client = _make_client([bad])
+
+        with pytest.raises(ValueError):
+            client.chat_json([{"role": "user", "content": "q"}], max_attempts=1)
+
+    # --- finish_reason == 'length' triggers truncation repair ---
+
+    def test_truncated_output_is_repaired(self):
+        """A response cut off with finish_reason="length" is repaired and parsed."""
+        truncated_json = '{"items": [1, 2, 3'
+        client = _make_client([_make_response(truncated_json, finish_reason="length")])
+        result = client.chat_json([{"role": "user", "content": "q"}])
+        assert result["items"] == [1, 2, 3]
+
+    # --- API exception counts as a failed attempt ---
+
+    def test_api_exception_retried(self):
+        """An exception from the API call consumes one attempt and the next one succeeds."""
+        # An exception instance in a side_effect list is raised by the mock.
+        exc = Exception("network failure")
+        good = _make_response('{"recovered": true}')
+        client = _make_client([exc, good])
+
+        result = client.chat_json(
+            [{"role": "user", "content": "q"}],
+            max_attempts=2,
+        )
+        assert result == {"recovered": True}
+        assert client._mock_create.call_count == 2
+
+    def test_api_exception_all_attempts_raises_value_error(self):
+        """If every attempt raises, chat_json surfaces a ValueError."""
+        exc = Exception("always fails")
+        client = _make_client([exc, exc])
+
+        with pytest.raises(ValueError):
+            client.chat_json(
+                [{"role": "user", "content": "q"}],
+                max_attempts=2,
+            )