fix(report_agent): handle API token overflow crash with context length error recovery
This commit is contained in:
parent
985f89f49a
commit
20c830af12
|
|
@ -4,6 +4,8 @@
|
||||||
LLM_API_KEY=your_api_key_here
|
LLM_API_KEY=your_api_key_here
|
||||||
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||||
LLM_MODEL_NAME=qwen-plus
|
LLM_MODEL_NAME=qwen-plus
|
||||||
|
# LLM最大输出token数(根据模型能力调整,默认4096)
|
||||||
|
# LLM_MAX_TOKENS=4096
|
||||||
|
|
||||||
# ===== ZEP记忆图谱配置 =====
|
# ===== ZEP记忆图谱配置 =====
|
||||||
# 每月免费额度即可支撑简单使用:https://app.getzep.com/
|
# 每月免费额度即可支撑简单使用:https://app.getzep.com/
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ class Config:
|
||||||
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
||||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||||
|
LLM_MAX_TOKENS = int(os.environ.get('LLM_MAX_TOKENS', '4096'))
|
||||||
|
|
||||||
# Zep配置
|
# Zep配置
|
||||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||||
|
|
|
||||||
|
|
@ -1294,11 +1294,17 @@ class ReportAgent:
|
||||||
for iteration in range(max_iterations):
|
for iteration in range(max_iterations):
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(
|
progress_callback(
|
||||||
"generating",
|
"generating",
|
||||||
int((iteration / max_iterations) * 100),
|
int((iteration / max_iterations) * 100),
|
||||||
f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})"
|
f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 防止消息历史无限增长导致上下文溢出
|
||||||
|
# 保留 system + user prompt(前2条)和最近的对话轮次
|
||||||
|
if len(messages) > 14:
|
||||||
|
messages = messages[:2] + messages[-12:]
|
||||||
|
logger.info(f"章节 {section.title}: 消息历史已裁剪至 {len(messages)} 条以防止上下文溢出")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm.chat(
|
response = self.llm.chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
@ -1502,7 +1508,11 @@ class ReportAgent:
|
||||||
# 达到最大迭代次数,强制生成内容
|
# 达到最大迭代次数,强制生成内容
|
||||||
logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成")
|
logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成")
|
||||||
messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG})
|
messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG})
|
||||||
|
|
||||||
|
# 裁剪消息以防止强制收尾时上下文溢出
|
||||||
|
if len(messages) > 14:
|
||||||
|
messages = messages[:2] + messages[-12:]
|
||||||
|
|
||||||
response = self.llm.chat(
|
response = self.llm.chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,19 @@ LLM客户端封装
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from openai import OpenAI
|
from openai import OpenAI, BadRequestError, APIError
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""LLM客户端"""
|
"""LLM客户端"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
|
@ -23,64 +26,105 @@ class LLMClient:
|
||||||
self.api_key = api_key or Config.LLM_API_KEY
|
self.api_key = api_key or Config.LLM_API_KEY
|
||||||
self.base_url = base_url or Config.LLM_BASE_URL
|
self.base_url = base_url or Config.LLM_BASE_URL
|
||||||
self.model = model or Config.LLM_MODEL_NAME
|
self.model = model or Config.LLM_MODEL_NAME
|
||||||
|
self.default_max_tokens = Config.LLM_MAX_TOKENS
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("LLM_API_KEY 未配置")
|
raise ValueError("LLM_API_KEY 未配置")
|
||||||
|
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=self.base_url
|
base_url=self.base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def _trim_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Drop the middle of an over-long conversation so a retry can fit the context window.

    Keeps the first message (assumed to be the system prompt) and the six
    most recent messages; everything in between is discarded.

    Args:
        messages: Chat messages in chronological order.

    Returns:
        A new, shorter list when something could be removed; otherwise the
        original list object, unchanged. Callers can compare lengths to tell
        "trimmed" apart from "nothing left to trim".
    """
    keep_tail = 6  # roughly the last three request/response rounds

    # With at most 1 + keep_tail messages there is no middle section to remove.
    # Returning early avoids building an identical copy and logging a
    # misleading "trimmed N -> N" warning for what is actually a no-op.
    if len(messages) <= keep_tail + 1:
        return messages

    trimmed = [messages[0]] + messages[-keep_tail:]
    logger.warning(
        f"消息上下文过长,已裁剪: {len(messages)} -> {len(trimmed)} 条消息"
    )
    return trimmed
|
||||||
|
|
||||||
def chat(
    self,
    messages: List[Dict[str, str]],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    response_format: Optional[Dict] = None
) -> str:
    """
    Send a chat completion request and return the reply text.

    If the API rejects the request with a bad-request error whose message
    looks like a context-length overflow, the history is trimmed once via
    ``_trim_messages`` and the request is retried a single time.

    Args:
        messages: Chat messages to send.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate; ``None`` falls back
            to ``self.default_max_tokens``.
        response_format: Optional response format (e.g. JSON mode).

    Returns:
        The reply text with any ``<think>...</think>`` sections removed, or
        an empty string when the response carries no text content.

    Raises:
        BadRequestError: The request was rejected and could not be recovered
            by trimming.
        APIError: Any other API failure, including a failed retry.
    """
    if max_tokens is None:
        max_tokens = self.default_max_tokens

    kwargs = {
        "model": self.model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    if response_format:
        kwargs["response_format"] = response_format

    try:
        response = self.client.chat.completions.create(**kwargs)
    except BadRequestError as e:
        error_msg = str(e).lower()
        is_context_overflow = (
            "context_length" in error_msg
            or "maximum context" in error_msg
            or "token" in error_msg
        )
        if not is_context_overflow:
            # Not recoverable by trimming; log it like any other API failure.
            logger.error(f"LLM API 调用失败: {e}")
            raise

        # Context-length overflow: trim the history and retry exactly once.
        logger.warning(
            f"上下文长度超限,尝试裁剪消息后重试: {e}"
        )
        trimmed_messages = self._trim_messages(messages)
        if len(trimmed_messages) == len(messages):
            # Nothing could be removed, so a retry would fail the same way.
            logger.error(f"LLM API 调用失败: {e}")
            raise

        kwargs["messages"] = trimmed_messages
        # An exception raised inside this handler is NOT seen by the sibling
        # ``except APIError`` clause below, so the retry needs its own
        # handler to keep failures logged.
        try:
            response = self.client.chat.completions.create(**kwargs)
        except APIError as retry_error:
            logger.error(f"LLM API 调用失败: {retry_error}")
            raise
    except APIError as e:
        logger.error(f"LLM API 调用失败: {e}")
        raise

    # ``content`` can be None (e.g. a reply without text); normalise it so
    # the regex below does not raise TypeError.
    content = response.choices[0].message.content or ""
    # Some models (e.g. MiniMax M2.5) embed <think> reasoning in the content; strip it.
    content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
    return content
|
||||||
|
|
||||||
def chat_json(
|
def chat_json(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
temperature: float = 0.3,
|
temperature: float = 0.3,
|
||||||
max_tokens: int = 4096
|
max_tokens: Optional[int] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
发送聊天请求并返回JSON
|
发送聊天请求并返回JSON
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
temperature: 温度参数
|
temperature: 温度参数
|
||||||
max_tokens: 最大token数
|
max_tokens: 最大token数(默认使用配置值)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
解析后的JSON对象
|
解析后的JSON对象
|
||||||
"""
|
"""
|
||||||
|
|
@ -100,4 +144,3 @@ class LLMClient:
|
||||||
return json.loads(cleaned_response)
|
return json.loads(cleaned_response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue