fix(report_agent): handle API token overflow crash with context length error recovery

This commit is contained in:
albert 2026-03-11 10:37:06 +08:00
parent 985f89f49a
commit 20c830af12
4 changed files with 77 additions and 21 deletions

View File

@ -4,6 +4,8 @@
LLM_API_KEY=your_api_key_here LLM_API_KEY=your_api_key_here
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
LLM_MODEL_NAME=qwen-plus LLM_MODEL_NAME=qwen-plus
# LLM最大输出token数根据模型能力调整默认4096
# LLM_MAX_TOKENS=4096
# ===== ZEP记忆图谱配置 ===== # ===== ZEP记忆图谱配置 =====
# 每月免费额度即可支撑简单使用https://app.getzep.com/ # 每月免费额度即可支撑简单使用https://app.getzep.com/

View File

@ -31,6 +31,7 @@ class Config:
LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_API_KEY = os.environ.get('LLM_API_KEY')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
LLM_MAX_TOKENS = int(os.environ.get('LLM_MAX_TOKENS', '4096'))
# Zep配置 # Zep配置
ZEP_API_KEY = os.environ.get('ZEP_API_KEY') ZEP_API_KEY = os.environ.get('ZEP_API_KEY')

View File

@ -1294,11 +1294,17 @@ class ReportAgent:
for iteration in range(max_iterations): for iteration in range(max_iterations):
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating", "generating",
int((iteration / max_iterations) * 100), int((iteration / max_iterations) * 100),
f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})" f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})"
) )
# 防止消息历史无限增长导致上下文溢出
# 保留 system + user prompt前2条和最近的对话轮次
if len(messages) > 14:
messages = messages[:2] + messages[-12:]
logger.info(f"章节 {section.title}: 消息历史已裁剪至 {len(messages)} 条以防止上下文溢出")
# 调用LLM # 调用LLM
response = self.llm.chat( response = self.llm.chat(
messages=messages, messages=messages,
@ -1502,7 +1508,11 @@ class ReportAgent:
# 达到最大迭代次数,强制生成内容 # 达到最大迭代次数,强制生成内容
logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成") logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成")
messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG}) messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG})
# 裁剪消息以防止强制收尾时上下文溢出
if len(messages) > 14:
messages = messages[:2] + messages[-12:]
response = self.llm.chat( response = self.llm.chat(
messages=messages, messages=messages,
temperature=0.5, temperature=0.5,

View File

@ -4,16 +4,19 @@ LLM客户端封装
""" """
import json import json
import logging
import re import re
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from openai import OpenAI from openai import OpenAI, BadRequestError, APIError
from ..config import Config from ..config import Config
logger = logging.getLogger(__name__)
class LLMClient: class LLMClient:
"""LLM客户端""" """LLM客户端"""
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Optional[str] = None,
@ -23,64 +26,105 @@ class LLMClient:
self.api_key = api_key or Config.LLM_API_KEY self.api_key = api_key or Config.LLM_API_KEY
self.base_url = base_url or Config.LLM_BASE_URL self.base_url = base_url or Config.LLM_BASE_URL
self.model = model or Config.LLM_MODEL_NAME self.model = model or Config.LLM_MODEL_NAME
self.default_max_tokens = Config.LLM_MAX_TOKENS
if not self.api_key: if not self.api_key:
raise ValueError("LLM_API_KEY 未配置") raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI( self.client = OpenAI(
api_key=self.api_key, api_key=self.api_key,
base_url=self.base_url base_url=self.base_url
) )
@staticmethod
def _trim_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""
当消息列表过长导致上下文溢出时裁剪中间的历史消息
保留第一条system prompt和最后几条消息移除中间部分
"""
if len(messages) <= 4:
return messages
# 保留 system prompt第1条+ 最近3轮对话最后6条
keep_tail = min(6, len(messages) - 1)
trimmed = [messages[0]] + messages[-keep_tail:]
logger.warning(
f"消息上下文过长,已裁剪: {len(messages)} -> {len(trimmed)} 条消息"
)
return trimmed
def chat( def chat(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 4096, max_tokens: Optional[int] = None,
response_format: Optional[Dict] = None response_format: Optional[Dict] = None
) -> str: ) -> str:
""" """
发送聊天请求 发送聊天请求
Args: Args:
messages: 消息列表 messages: 消息列表
temperature: 温度参数 temperature: 温度参数
max_tokens: 最大token数 max_tokens: 最大token数默认使用配置值
response_format: 响应格式如JSON模式 response_format: 响应格式如JSON模式
Returns: Returns:
模型响应文本 模型响应文本
""" """
if max_tokens is None:
max_tokens = self.default_max_tokens
kwargs = { kwargs = {
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens, "max_tokens": max_tokens,
} }
if response_format: if response_format:
kwargs["response_format"] = response_format kwargs["response_format"] = response_format
response = self.client.chat.completions.create(**kwargs) try:
response = self.client.chat.completions.create(**kwargs)
except BadRequestError as e:
error_msg = str(e).lower()
# 处理上下文长度超限错误:自动裁剪消息后重试一次
if "context_length" in error_msg or "maximum context" in error_msg or "token" in error_msg:
logger.warning(
f"上下文长度超限,尝试裁剪消息后重试: {e}"
)
trimmed_messages = self._trim_messages(messages)
if len(trimmed_messages) == len(messages):
# 无法进一步裁剪,向上抛出
raise
kwargs["messages"] = trimmed_messages
response = self.client.chat.completions.create(**kwargs)
else:
raise
except APIError as e:
logger.error(f"LLM API 调用失败: {e}")
raise
content = response.choices[0].message.content content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除 # 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip() content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
return content return content
def chat_json( def chat_json(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
temperature: float = 0.3, temperature: float = 0.3,
max_tokens: int = 4096 max_tokens: Optional[int] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
发送聊天请求并返回JSON 发送聊天请求并返回JSON
Args: Args:
messages: 消息列表 messages: 消息列表
temperature: 温度参数 temperature: 温度参数
max_tokens: 最大token数 max_tokens: 最大token数默认使用配置值
Returns: Returns:
解析后的JSON对象 解析后的JSON对象
""" """
@ -100,4 +144,3 @@ class LLMClient:
return json.loads(cleaned_response) return json.loads(cleaned_response)
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")