"""
LLM客户端封装
统一使用OpenAI格式调用

LLM client wrapper: all requests use the OpenAI chat-completions format.
"""

import json
import logging
import re
from typing import Optional, Dict, Any, List

from openai import OpenAI

from ..config import Config


class LLMClient:
    """Thin client for OpenAI-compatible chat-completion endpoints.

    Any constructor argument left unset (or falsy) falls back to
    ``Config.LLM_API_KEY``, ``Config.LLM_BASE_URL`` or ``Config.LLM_MODEL_NAME``.
    """

    # Complete <think>...</think> reasoning blocks that some models
    # (e.g. MiniMax M2.5) emit inside the message content.
    _THINK_BLOCK_RE = re.compile(r'<think>[\s\S]*?</think>')
    # Leading ``` / ```json fence and trailing ``` fence around a JSON reply.
    _FENCE_OPEN_RE = re.compile(r'^```(?:json)?\s*\n?', re.IGNORECASE)
    _FENCE_CLOSE_RE = re.compile(r'\n?```\s*$')
    # Maximum number of trailing characters JSON repair will trim off.
    _MAX_REPAIR_TRIM = 2000
    # Separators/whitespace dropped from the end of a candidate before closing it.
    _REPAIR_STRIP_CHARS = ', \n\t\r'

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None
    ):
        """Resolve connection settings and create the underlying OpenAI client.

        Args:
            api_key: API key; defaults to ``Config.LLM_API_KEY``.
            base_url: Endpoint base URL; defaults to ``Config.LLM_BASE_URL``.
            model: Model name sent with every request; defaults to
                ``Config.LLM_MODEL_NAME``.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model = model or Config.LLM_MODEL_NAME

        if not self.api_key:
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        response_format: Optional[Dict] = None
    ) -> str:
        """Send a chat-completion request and return the reply text.

        Args:
            messages: Chat messages in OpenAI chat format.
            temperature: Sampling temperature.
            max_tokens: Upper bound on generated tokens, sent as
                ``max_completion_tokens``.
            response_format: Optional response format (e.g. JSON mode); only
                sent when truthy.

        Returns:
            The first choice's message content with complete
            ``<think>...</think>`` blocks removed and surrounding whitespace
            stripped. An empty string when the API returns no content.
        """
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "max_completion_tokens": max_tokens,
            "temperature": temperature,
        }
        if response_format:
            kwargs["response_format"] = response_format

        response = self.client.chat.completions.create(**kwargs)
        # message.content is Optional in the OpenAI schema; treat None as ""
        # rather than letting re.sub raise TypeError.
        content = response.choices[0].message.content or ""
        return self._THINK_BLOCK_RE.sub('', content).strip()

    def chat_json(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 16384
    ) -> Dict[str, Any]:
        """Send a chat request and parse the reply as JSON.

        Markdown code fences around the reply are removed first. If the text
        is still not valid JSON (e.g. the response was truncated), a
        best-effort repair trims an incomplete tail and closes any arrays and
        objects that are still open.

        Args:
            messages: Chat messages in OpenAI chat format.
            temperature: Sampling temperature.
            max_tokens: Upper bound on generated tokens.

        Returns:
            The parsed JSON value (a dict when the reply is a JSON object).

        Raises:
            ValueError: If the reply cannot be parsed, even after repair.
        """
        response = self.chat(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        cleaned_response = self._strip_code_fence(response)

        try:
            return json.loads(cleaned_response)
        except json.JSONDecodeError as err:
            repaired = self._repair_truncated_json(cleaned_response)
            if repaired is not None:
                return repaired
            raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response[:500]}") from err

    @classmethod
    def _strip_code_fence(cls, text: str) -> str:
        """Remove a leading ```/```json fence and a trailing ``` fence, then strip."""
        cleaned = cls._FENCE_OPEN_RE.sub('', text.strip())
        cleaned = cls._FENCE_CLOSE_RE.sub('', cleaned)
        return cleaned.strip()

    @staticmethod
    def _closing_suffixes(text: str) -> Dict[int, str]:
        """Map each safe cut position in *text* to the closers that complete it.

        A position ``end`` is a safe cut when ``text[:end]`` ends with ``]``,
        ``}`` or the closing quote of a string literal. The mapped value is the
        run of ``]``/``}`` characters that closes every container still open
        at that point, innermost first. Brackets inside string literals
        (including backslash-escaped quotes) are ignored. Scanning stops at the
        first mismatched closer, because no later prefix can then be completed
        into valid JSON.
        """
        suffixes: Dict[int, str] = {}
        open_closers: List[str] = []
        in_string = False
        escaped = False
        for index, ch in enumerate(text):
            if in_string:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == '"':
                    in_string = False
                    suffixes[index + 1] = ''.join(reversed(open_closers))
                continue
            if ch == '"':
                in_string = True
            elif ch == '{':
                open_closers.append('}')
            elif ch == '[':
                open_closers.append(']')
            elif ch in '}]':
                if not open_closers or open_closers.pop() != ch:
                    break
                suffixes[index + 1] = ''.join(reversed(open_closers))
        return suffixes

    @classmethod
    def _repair_truncated_json(cls, text: str) -> Optional[Any]:
        """Best-effort parse of JSON that was cut off before it finished.

        Drops up to ``_MAX_REPAIR_TRIM`` trailing characters, one at a time.
        Each time the remaining prefix (minus trailing commas/whitespace) ends
        at a safe cut point, the prefix is closed in correct nesting order and
        parsed. The first prefix that parses wins.

        Returns:
            The parsed value, or None when no trimmed prefix parses. A
            successful repair never yields None, because a safe cut never ends
            with the ``null`` literal.
        """
        suffixes = cls._closing_suffixes(text)
        tried = set()
        for trim in range(min(len(text), cls._MAX_REPAIR_TRIM)):
            candidate = text[:len(text) - trim].rstrip(cls._REPAIR_STRIP_CHARS)
            end = len(candidate)
            # Cutting inside a number or literal could silently change a value,
            # so only cut right after a closed string or container. Several
            # trims can yield the same candidate; parse each candidate once.
            if end in tried or end not in suffixes:
                continue
            tried.add(end)
            try:
                result = json.loads(candidate + suffixes[end])
            except json.JSONDecodeError:
                continue
            logging.getLogger(__name__).warning("JSON repaired by trimming %d chars", trim)
            return result
        return None