diff --git a/backend/app/graph/graphiti_backend.py b/backend/app/graph/graphiti_backend.py index 40b31eac..e8f9576f 100644 --- a/backend/app/graph/graphiti_backend.py +++ b/backend/app/graph/graphiti_backend.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional from .base import GraphBackend from ..config import Config from ..utils.logger import get_logger +from ..utils.llm_client import parse_azure_url def _neo4j_val(v: Any) -> Any: @@ -137,21 +138,6 @@ class GraphitiBackend(GraphBackend): self._edge_defs: Dict[str, Any] = {} self._client = self._build_client() - @staticmethod - def _parse_azure_url(raw_url: str): - """Strip /chat/completions or /embeddings suffix from Azure endpoint URLs. - Returns (clean_base_url, default_query_dict).""" - from urllib.parse import urlparse, parse_qs, urlunparse - default_query = {} - if raw_url and ('/chat/completions' in raw_url or '/embeddings' in raw_url): - parsed = urlparse(raw_url) - qs = parse_qs(parsed.query) - if 'api-version' in qs: - default_query['api-version'] = qs['api-version'][0] - clean_path = parsed.path.replace('/chat/completions', '').replace('/embeddings', '').rstrip('/') - raw_url = urlunparse(parsed._replace(path=clean_path, query='')) - return raw_url, default_query - def _build_client(self): from graphiti_core import Graphiti from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient @@ -160,9 +146,9 @@ class GraphitiBackend(GraphBackend): from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient from openai import AsyncOpenAI - llm_base_url, llm_query = self._parse_azure_url(Config.LLM_BASE_URL) - small_base_url, small_query = self._parse_azure_url(Config.LLM_SMALL_BASE_URL) - embed_base_url, embed_query = self._parse_azure_url(Config.LLM_EMBED_BASE_URL) + llm_base_url, llm_query = parse_azure_url(Config.LLM_BASE_URL) + small_base_url, small_query = parse_azure_url(Config.LLM_SMALL_BASE_URL) + embed_base_url, embed_query = 
parse_azure_url(Config.LLM_EMBED_BASE_URL) # Pre-built async clients so api-version is passed as default_query (Azure requirement) async_llm_client = AsyncOpenAI( diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 2670454e..7a42b7b6 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -22,6 +22,7 @@ from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, get_locale, set_locale, t +from ..utils.llm_client import parse_azure_url from .zep_entity_reader import EntityNode, ZepEntityReader logger = get_logger('mirofish.oasis_profile') @@ -190,15 +191,17 @@ class OasisProfileGenerator: graph_id: Optional[str] = None ): self.api_key = api_key or Config.LLM_API_KEY - self.base_url = base_url or Config.LLM_BASE_URL + raw_url = base_url or Config.LLM_BASE_URL self.model_name = model_name or Config.LLM_MODEL_NAME - + if not self.api_key: raise ValueError("LLM_API_KEY is not configured") + self.base_url, _default_query = parse_azure_url(raw_url) self.client = OpenAI( api_key=self.api_key, - base_url=self.base_url + base_url=self.base_url, + default_query=_default_query if _default_query else None ) # Zep client for enriching context via retrieval diff --git a/backend/app/services/simulation_config_generator.py b/backend/app/services/simulation_config_generator.py index 71d4be10..ad4b38bb 100644 --- a/backend/app/services/simulation_config_generator.py +++ b/backend/app/services/simulation_config_generator.py @@ -23,6 +23,7 @@ from openai import OpenAI from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, t +from ..utils.llm_client import parse_azure_url from .zep_entity_reader import EntityNode, ZepEntityReader logger = get_logger('mirofish.simulation_config') @@ -231,15 +232,17 @@ class 
def parse_azure_url(raw_url: str):
    """Split a full Azure endpoint URL into an SDK base URL and default query.

    Azure Portal hands out complete endpoint URLs such as
    ``https://<resource>.cognitiveservices.azure.com/openai/deployments/<deployment>/chat/completions?api-version=...``.
    The OpenAI SDK expects a ``base_url`` and appends ``/chat/completions``
    (or ``/embeddings``) itself, so the endpoint segment has to be removed and
    ``api-version`` has to travel as a default query parameter instead.

    The endpoint segment is only removed when it is the *last* segment of the
    URL path (a trailing ``/`` is tolerated). URLs that merely contain the
    text elsewhere -- in the host, in the query string, or in the middle of
    the path -- are returned untouched, query string included.

    Args:
        raw_url: Endpoint or base URL. Falsy values (``None``, ``""``) are
            passed straight through.

    Returns:
        A ``(base_url, default_query)`` tuple. ``default_query`` is a dict
        holding ``api-version`` when the stripped URL carried one, otherwise
        an empty dict. When a suffix is stripped, every other query parameter
        of the original URL is discarded.
    """
    endpoint_suffixes = ('/chat/completions', '/embeddings')
    default_query: Dict[str, str] = {}

    # Cheap pre-check: URLs without either marker are never parsed, so
    # ordinary base URLs pass through byte-for-byte.
    if not raw_url or not any(suffix in raw_url for suffix in endpoint_suffixes):
        return raw_url, default_query

    parsed = urlparse(raw_url)
    trimmed_path = parsed.path.rstrip('/')
    matched = next(
        (suffix for suffix in endpoint_suffixes if trimmed_path.endswith(suffix)),
        None,
    )
    if matched is None:
        # The marker sits in the host, the query or mid-path: not an endpoint URL.
        return raw_url, default_query

    api_versions = parse_qs(parsed.query).get('api-version')
    if api_versions:
        default_query['api-version'] = api_versions[0]

    clean_path = trimmed_path[:-len(matched)].rstrip('/')
    base_url = urlunparse(parsed._replace(path=clean_path, query=''))
    return base_url, default_query