93 lines
2.9 KiB
Python
93 lines
2.9 KiB
Python
"""LLM client wrapper.
|
|
|
|
All providers are called through the OpenAI-compatible API surface.
|
|
"""
|
|
|
|
import json
|
|
import re
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from openai import OpenAI
|
|
|
|
from ..config import Config
|
|
|
|
|
|
class LLMClient:
    """Thin wrapper around the OpenAI-compatible chat completions API.

    Attributes:
        api_key: API key handed to the underlying ``OpenAI`` client.
        base_url: Base URL of the OpenAI-compatible endpoint.
        model: Model name sent with every request.
        client: The underlying ``OpenAI`` client instance.
    """

    # Complete <think>...</think> blocks. Some reasoning models
    # (e.g. MiniMax M2.5) embed these in the visible answer.
    _THINK_BLOCK_RE = re.compile(r"<think>[\s\S]*?</think>")
    # A markdown code fence (optionally tagged "json") opening the reply.
    _LEADING_FENCE_RE = re.compile(r"^```(?:json)?\s*\n?", re.IGNORECASE)
    # A markdown code fence closing the reply.
    _TRAILING_FENCE_RE = re.compile(r"\n?```\s*$")
    # The first fenced block anywhere in the reply; captures its body.
    _FENCED_BLOCK_RE = re.compile(
        r"```(?:json)?\s*\n?([\s\S]*?)\n?```", re.IGNORECASE
    )

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
    ) -> None:
        """Create the client.

        Each argument falls back to the matching ``Config`` value
        (``LLM_API_KEY``, ``LLM_BASE_URL``, ``LLM_MODEL_NAME``) when it is
        not given (or is otherwise falsy).

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or Config.LLM_API_KEY
        self.base_url = base_url or Config.LLM_BASE_URL
        self.model = model or Config.LLM_MODEL_NAME

        if not self.api_key:
            # Message reads "LLM_API_KEY is not configured".
            raise ValueError("LLM_API_KEY 未配置")

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        response_format: Optional[Dict] = None,
    ) -> str:
        """Send a chat completion request.

        Args:
            messages: Chat messages in OpenAI format.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            response_format: Optional response format hint (e.g. JSON mode).
                Only sent when truthy.

        Returns:
            The text of the first choice, with complete ``<think>...</think>``
            blocks removed and surrounding whitespace stripped. An empty
            string if the provider returned no text content.
        """
        kwargs: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if response_format:
            kwargs["response_format"] = response_format

        response = self.client.chat.completions.create(**kwargs)
        return self._strip_think(response.choices[0].message.content)

    def chat_json(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.3,
        max_tokens: int = 4096,
    ) -> Any:
        """Send a chat request and parse the reply as JSON.

        JSON mode (``response_format={"type": "json_object"}``) is tried
        first. If that request raises, it is repeated once without
        ``response_format``.

        Args:
            messages: Chat messages in OpenAI format.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            The parsed JSON value.

        Raises:
            ValueError: If the reply cannot be parsed as JSON.
        """
        try:
            response = self.chat(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                response_format={"type": "json_object"},
            )
        except Exception:
            # Deliberately broad best-effort fallback: retry without
            # response_format for providers that do not support it. If the
            # plain request also fails, its exception propagates with this
            # one attached as __context__.
            response = self.chat(
                messages=messages, temperature=temperature, max_tokens=max_tokens
            )

        return self._parse_json_response(response)

    @classmethod
    def _strip_think(cls, content: Optional[str]) -> str:
        """Return *content* without ``<think>`` blocks or outer whitespace."""
        # The OpenAI SDK types message.content as Optional[str]; treat None
        # as empty instead of letting re.sub raise TypeError.
        if not content:
            return ""
        return cls._THINK_BLOCK_RE.sub("", content).strip()

    @classmethod
    def _parse_json_response(cls, response: str) -> Any:
        """Parse an LLM reply as JSON, tolerating markdown code fences.

        First strips a code fence wrapping the whole reply and parses the
        remainder. If that fails, falls back to the first fenced block found
        anywhere in the reply (e.g. "Here is the JSON: ```json {...} ```").

        Raises:
            ValueError: If no candidate parses. The ``JSONDecodeError`` from
                the first attempt is chained as ``__cause__``.
        """
        cleaned = response.strip()
        cleaned = cls._LEADING_FENCE_RE.sub("", cleaned)
        cleaned = cls._TRAILING_FENCE_RE.sub("", cleaned)
        cleaned = cleaned.strip()

        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as err:
            decode_error = err

        match = cls._FENCED_BLOCK_RE.search(response)
        if match:
            try:
                return json.loads(match.group(1).strip())
            except json.JSONDecodeError:
                pass  # Report the primary (whole-reply) failure below.

        raise ValueError(f"LLM返回的JSON格式无效: {cleaned}") from decode_error
|