fix(report_agent): handle API token overflow crash with context length error recovery
This commit is contained in:
parent
985f89f49a
commit
20c830af12
|
|
@ -4,6 +4,8 @@
|
|||
LLM_API_KEY=your_api_key_here
|
||||
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
LLM_MODEL_NAME=qwen-plus
|
||||
# LLM最大输出token数(根据模型能力调整,默认4096)
|
||||
# LLM_MAX_TOKENS=4096
|
||||
|
||||
# ===== ZEP记忆图谱配置 =====
|
||||
# 每月免费额度即可支撑简单使用:https://app.getzep.com/
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class Config:
|
|||
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||
LLM_MAX_TOKENS = int(os.environ.get('LLM_MAX_TOKENS', '4096'))
|
||||
|
||||
# Zep配置
|
||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||
|
|
|
|||
|
|
@ -1294,11 +1294,17 @@ class ReportAgent:
|
|||
for iteration in range(max_iterations):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating",
|
||||
"generating",
|
||||
int((iteration / max_iterations) * 100),
|
||||
f"深度检索与撰写中 ({tool_calls_count}/{self.MAX_TOOL_CALLS_PER_SECTION})"
|
||||
)
|
||||
|
||||
|
||||
# 防止消息历史无限增长导致上下文溢出
|
||||
# 保留 system + user prompt(前2条)和最近的对话轮次
|
||||
if len(messages) > 14:
|
||||
messages = messages[:2] + messages[-12:]
|
||||
logger.info(f"章节 {section.title}: 消息历史已裁剪至 {len(messages)} 条以防止上下文溢出")
|
||||
|
||||
# 调用LLM
|
||||
response = self.llm.chat(
|
||||
messages=messages,
|
||||
|
|
@ -1502,7 +1508,11 @@ class ReportAgent:
|
|||
# 达到最大迭代次数,强制生成内容
|
||||
logger.warning(f"章节 {section.title} 达到最大迭代次数,强制生成")
|
||||
messages.append({"role": "user", "content": REACT_FORCE_FINAL_MSG})
|
||||
|
||||
|
||||
# 裁剪消息以防止强制收尾时上下文溢出
|
||||
if len(messages) > 14:
|
||||
messages = messages[:2] + messages[-12:]
|
||||
|
||||
response = self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=0.5,
|
||||
|
|
|
|||
|
|
@ -4,16 +4,19 @@ LLM客户端封装
|
|||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, BadRequestError, APIError
|
||||
|
||||
from ..config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""LLM客户端"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
|
|
@ -23,64 +26,105 @@ class LLMClient:
|
|||
self.api_key = api_key or Config.LLM_API_KEY
|
||||
self.base_url = base_url or Config.LLM_BASE_URL
|
||||
self.model = model or Config.LLM_MODEL_NAME
|
||||
|
||||
self.default_max_tokens = Config.LLM_MAX_TOKENS
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _trim_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
当消息列表过长导致上下文溢出时,裁剪中间的历史消息。
|
||||
保留第一条(system prompt)和最后几条消息,移除中间部分。
|
||||
"""
|
||||
if len(messages) <= 4:
|
||||
return messages
|
||||
|
||||
# 保留 system prompt(第1条)+ 最近3轮对话(最后6条)
|
||||
keep_tail = min(6, len(messages) - 1)
|
||||
trimmed = [messages[0]] + messages[-keep_tail:]
|
||||
logger.warning(
|
||||
f"消息上下文过长,已裁剪: {len(messages)} -> {len(trimmed)} 条消息"
|
||||
)
|
||||
return trimmed
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
max_tokens: Optional[int] = None,
|
||||
response_format: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
发送聊天请求
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
max_tokens: 最大token数(默认使用配置值)
|
||||
response_format: 响应格式(如JSON模式)
|
||||
|
||||
|
||||
Returns:
|
||||
模型响应文本
|
||||
"""
|
||||
if max_tokens is None:
|
||||
max_tokens = self.default_max_tokens
|
||||
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
except BadRequestError as e:
|
||||
error_msg = str(e).lower()
|
||||
# 处理上下文长度超限错误:自动裁剪消息后重试一次
|
||||
if "context_length" in error_msg or "maximum context" in error_msg or "token" in error_msg:
|
||||
logger.warning(
|
||||
f"上下文长度超限,尝试裁剪消息后重试: {e}"
|
||||
)
|
||||
trimmed_messages = self._trim_messages(messages)
|
||||
if len(trimmed_messages) == len(messages):
|
||||
# 无法进一步裁剪,向上抛出
|
||||
raise
|
||||
kwargs["messages"] = trimmed_messages
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
else:
|
||||
raise
|
||||
except APIError as e:
|
||||
logger.error(f"LLM API 调用失败: {e}")
|
||||
raise
|
||||
|
||||
content = response.choices[0].message.content
|
||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
return content
|
||||
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 4096
|
||||
max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送聊天请求并返回JSON
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
max_tokens: 最大token数(默认使用配置值)
|
||||
|
||||
Returns:
|
||||
解析后的JSON对象
|
||||
"""
|
||||
|
|
@ -100,4 +144,3 @@ class LLMClient:
|
|||
return json.loads(cleaned_response)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue