171 lines
6.5 KiB
Python
171 lines
6.5 KiB
Python
"""Ollama-backed cross-encoder reranker for Graphiti search.
|
|
|
|
Replaces the no-op ``_PassthroughReranker`` injected into Graphiti by default
|
|
with a real reranker that scores passages against a query through an Ollama
|
|
chat model exposed over its OpenAI-compatible ``/v1`` surface.
|
|
|
|
The class implements only ``CrossEncoderClient.rank`` (the sole abstract
|
|
member Graphiti requires) and is constructed by ``graphiti_adapter._get_graphiti``
|
|
when ``Config.RERANKER_PROVIDER == "ollama"``. It does not perform any
|
|
network I/O at construction time so the Flask app can boot even when the
|
|
Ollama daemon is unreachable; failures are handled inside ``rank`` and never
|
|
propagate, so graph search remains functional under degradation.
|
|
"""
|
|
|
|
import asyncio
import json
import math
import re
from typing import List, Optional, Tuple

from graphiti_core.cross_encoder.client import CrossEncoderClient
from openai import AsyncOpenAI

from ..utils.logger import get_logger
|
|
|
|
# Module logger. ``rank`` emits one WARNING per whole-call failure, and
# ``_score_passage`` emits DEBUG records for individual parse failures.
logger = get_logger('mirofish.ollama_reranker')
|
|
|
|
|
|
_THINK_BLOCK = re.compile(r"<think>[\s\S]*?</think>", re.IGNORECASE)
|
|
_CODE_FENCE_START = re.compile(r"^```(?:json)?\s*\n?", re.IGNORECASE)
|
|
_CODE_FENCE_END = re.compile(r"\n?```\s*$")
|
|
_FIRST_FLOAT = re.compile(r"-?\d+(?:\.\d+)?")
|
|
|
|
_SYSTEM_PROMPT = (
|
|
"You are a relevance grader. Given a user query and a single passage, "
|
|
"rate how relevant the passage is to the query on a continuous scale "
|
|
"from 0.0 (not relevant at all) to 1.0 (perfectly relevant). "
|
|
"Respond with a single JSON object of the form {\"score\": <float>} "
|
|
"and nothing else."
|
|
)
|
|
|
|
|
|
def _clip_unit(value: float) -> float:
|
|
"""Clamp ``value`` into the closed interval [0.0, 1.0]."""
|
|
if value < 0.0:
|
|
return 0.0
|
|
if value > 1.0:
|
|
return 1.0
|
|
return value
|
|
|
|
|
|
# Matches ``"score": <number>`` (or ``score = <number>``) anywhere in the text.
# Used when the reply is not pure JSON, e.g. a valid object preceded by a
# sentence that itself contains numbers.
_SCORE_FIELD = re.compile(
    r"""score["']?\s*[:=]\s*["']?(-?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)""",
    re.IGNORECASE,
)


def _parse_score(raw: str) -> float:
    """Parse a model response into a relevance score in [0.0, 1.0].

    Strips reasoning ``<think>`` blocks and markdown fences (the same
    defensive pattern used in ``utils/llm_client.py``), then tries, in order:

    1. ``json.loads`` on the cleaned text, reading its ``score`` key.
    2. A ``score: <number>`` field anywhere in the text. This recovers a JSON
       object wrapped in prose (e.g. ``Rating on a 0-1 scale:
       {"score": 0.9}``), where the first number would be the wrong one.
    3. The first number anywhere in the cleaned text.

    Args:
        raw: Message content returned by the chat model. ``None`` and the
            empty string are treated as "no answer".

    Returns:
        The recovered score, clamped to [0.0, 1.0].

    Raises:
        ValueError: If no numeric value can be recovered.
    """
    text = _THINK_BLOCK.sub("", raw or "").strip()
    text = _CODE_FENCE_START.sub("", text)
    text = _CODE_FENCE_END.sub("", text).strip()

    try:
        parsed = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        parsed = None

    if isinstance(parsed, dict) and "score" in parsed:
        try:
            return _clip_unit(float(parsed["score"]))
        except (TypeError, ValueError):
            # Non-numeric "score" value: fall through to the text heuristics.
            pass

    # The captured group is always a valid float literal by construction,
    # so ``float`` cannot raise here.
    field = _SCORE_FIELD.search(text)
    if field is not None:
        return _clip_unit(float(field.group(1)))

    match = _FIRST_FLOAT.search(text)
    if match is not None:
        try:
            return _clip_unit(float(match.group(0)))
        except ValueError:
            pass

    raise ValueError(f"no numeric score in model response: {text!r}")
|
|
|
|
|
|
class OllamaReranker(CrossEncoderClient):
    """Cross-encoder reranker that scores passages via an Ollama chat model.

    Subclass of :class:`graphiti_core.cross_encoder.client.CrossEncoderClient`
    that implements ``rank`` by issuing one chat-completion request per
    passage through ``openai.AsyncOpenAI`` (which speaks the OpenAI-compatible
    surface exposed by Ollama on ``/v1``).

    Construction is side-effect-free: building the underlying ``AsyncOpenAI``
    client does not perform any network I/O, so ``_get_graphiti`` can wire
    this class up at startup even when the Ollama daemon is unavailable.
    Failures surface only at ``rank`` call time and are degraded to a
    passthrough-style result with a single ``WARNING`` log per failed call.
    """

    def __init__(
        self,
        *,
        model: str,
        base_url: str,
        api_key: str,
        max_concurrency: Optional[int] = None,
    ) -> None:
        """Configure the reranker.

        Args:
            model: Name of the Ollama chat model used to score passages
                (for example ``qwen2.5:3b``). The operator is expected to
                have run ``ollama pull <model>`` before reranking is exercised.
            base_url: OpenAI-compatible endpoint for the Ollama server, for
                example ``http://localhost:11434/v1``.
            api_key: API key forwarded to the OpenAI client. Ollama ignores
                the value but the SDK requires a non-empty string.
            max_concurrency: Optional cap on how many scoring requests a
                single ``rank`` call keeps in flight at the same time.
                ``None`` (the default) sends every passage's request
                concurrently.

        Raises:
            ValueError: If ``max_concurrency`` is provided and is below 1.
        """
        if max_concurrency is not None and max_concurrency < 1:
            raise ValueError(
                f"max_concurrency must be a positive integer or None, got {max_concurrency!r}"
            )
        self._model = model
        self._max_concurrency = max_concurrency
        self._client = AsyncOpenAI(base_url=base_url, api_key=api_key)

    async def _score_passage(self, query: str, passage: str, index: int) -> float:
        """Score one passage against ``query``.

        Returns the relevance score in [0.0, 1.0] parsed from the model reply.
        If the reply (or an empty ``choices`` list) yields no usable number,
        returns ``-0.001 * (index + 1)`` instead. That value is deterministic,
        lies below every parsed score, and keeps unparseable passages in their
        original relative order. Errors raised by the request itself are not
        caught here; they propagate to ``rank``.
        """
        user_prompt = (
            f"Query:\n{query}\n\n"
            f"Passage:\n{passage}\n\n"
            "Reply with only the JSON object described in the system prompt."
        )
        response = await self._client.chat.completions.create(
            model=self._model,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.0,
            max_tokens=32,
        )
        # An empty ``choices`` list is treated like an unparseable reply
        # rather than raising IndexError and discarding the whole ranking.
        choices = response.choices or []
        raw = (choices[0].message.content or "") if choices else ""
        try:
            return _parse_score(raw)
        except ValueError as exc:
            logger.debug(
                "Reranker parse failure (model=%s, passage_index=%d): %s",
                self._model, index, exc,
            )
            return -0.001 * (index + 1)

    async def rank(
        self,
        query: str,
        passages: List[str],
    ) -> List[Tuple[str, float]]:
        """Return ``(passage, score)`` tuples sorted by score descending.

        Empty ``passages`` returns ``[]`` without any model call. On a
        whole-call failure (connection refused, model 404, timeout, etc.)
        the method logs a single ``WARNING``, cancels any scoring requests
        still in flight, and returns the passages in their original order
        with synthetic descending scores so graph search keeps functioning.
        The method does not raise (cancellation of the caller still
        propagates).
        """
        if not passages:
            return []

        # Create the semaphore per call, inside the running coroutine, so it
        # always belongs to the loop executing this call. On Python < 3.10,
        # asyncio primitives bind to a loop at construction.
        semaphore = (
            asyncio.Semaphore(self._max_concurrency)
            if self._max_concurrency is not None
            else None
        )

        async def _bounded_score(index: int, passage: str) -> float:
            # Apply the optional concurrency cap around one passage request.
            if semaphore is None:
                return await self._score_passage(query, passage, index)
            async with semaphore:
                return await self._score_passage(query, passage, index)

        # Keep explicit task handles. ``asyncio.gather`` does not cancel its
        # remaining children when one of them raises, so without this the
        # outstanding requests would keep running against Ollama after this
        # call has already fallen back.
        tasks = [
            asyncio.ensure_future(_bounded_score(i, p))
            for i, p in enumerate(passages)
        ]
        try:
            scores = await asyncio.gather(*tasks)
        except Exception as exc:  # noqa: BLE001 — graceful degrade per design R5
            for task in tasks:
                task.cancel()  # no-op for tasks that already finished
            logger.warning(
                "Ollama reranker failed (model=%s, error=%s); falling back to passthrough order.",
                self._model, type(exc).__name__,
            )
            return [(p, 1.0 - 0.01 * i) for i, p in enumerate(passages)]

        scored = list(zip(passages, scores))
        # ``list.sort`` is stable (also with reverse=True), so equal scores
        # keep the caller's original passage order.
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored
|