1643 lines
62 KiB
Python
1643 lines
62 KiB
Python
from __future__ import annotations

import json
import logging
import math
import os
import re
import sqlite3
import threading
import uuid
from contextlib import closing
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

from .embeddings import EmbeddingClient, cosine_similarity
from .extraction import GraphExtractor
from .models import GraphEdge, GraphEpisode, GraphNode, GraphRecord, GraphSearchResults
from .reranker import RerankerClient
from .settings import settings
|
|
|
|
logger = logging.getLogger("mirofish.local_zep")
|
|
_TOKEN_RE = re.compile(r"[\w\u4e00-\u9fff]+", re.UNICODE)
|
|
_CONFLICTING_EDGE_NAMES = {
|
|
"SUPPORTS": {"OPPOSES"},
|
|
"OPPOSES": {"SUPPORTS"},
|
|
"APPROVES": {"REJECTS", "OPPOSES"},
|
|
"REJECTS": {"APPROVES", "SUPPORTS"},
|
|
"LIKES": {"DISLIKES"},
|
|
"DISLIKES": {"LIKES"},
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class _SearchCandidate:
|
|
kind: str
|
|
uuid: str
|
|
text: str
|
|
item: GraphNode | GraphEdge | GraphEpisode
|
|
embedding: list[float] = field(default_factory=list)
|
|
semantic_score: float = 0.0
|
|
lexical_score: float = 0.0
|
|
score: float = 0.0
|
|
relevance: float | None = None
|
|
episode_count: int = 0
|
|
distance: int | None = None
|
|
|
|
|
|
def _now_iso() -> str:
|
|
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
|
|
|
|
|
|
def _coerce_iso(value: str | None) -> str:
|
|
value = (value or "").strip()
|
|
return value or _now_iso()
|
|
|
|
|
|
def _normalize_name(value: str) -> str:
|
|
return " ".join((value or "").strip().lower().split())
|
|
|
|
|
|
def _normalize_fact(value: str) -> str:
|
|
return re.sub(r"\s+", " ", (value or "").strip().lower().rstrip("."))
|
|
|
|
|
|
def _primary_label(labels: list[str]) -> str:
|
|
for label in labels:
|
|
if label not in {"Entity", "Node"}:
|
|
return label
|
|
return "Entity"
|
|
|
|
|
|
def _json_dumps(value: Any) -> str:
|
|
return json.dumps(value or {}, ensure_ascii=False, sort_keys=True)
|
|
|
|
|
|
def _json_loads(value: str | None, default: Any) -> Any:
|
|
if not value:
|
|
return default
|
|
try:
|
|
return json.loads(value)
|
|
except json.JSONDecodeError:
|
|
return default
|
|
|
|
|
|
def _tokenize(text: str) -> list[str]:
    """Return the lower-cased tokens of *text* matched by ``_TOKEN_RE``; falsy text gives ``[]``."""
    lowered = (text or "").lower()
    return _TOKEN_RE.findall(lowered)
|
|
|
|
|
|
def _camel_case(value: str) -> str:
|
|
parts = value.split("_")
|
|
return parts[0] + "".join(part[:1].upper() + part[1:] for part in parts[1:])
|
|
|
|
|
|
def _get_value(source: Any, key: str, default: Any = None) -> Any:
|
|
if source is None:
|
|
return default
|
|
|
|
keys = [key, _camel_case(key)]
|
|
for candidate in keys:
|
|
if isinstance(source, dict) and candidate in source:
|
|
return source[candidate]
|
|
if hasattr(source, candidate):
|
|
return getattr(source, candidate)
|
|
|
|
return default
|
|
|
|
|
|
def _as_list(value: Any) -> list[Any]:
|
|
if value is None:
|
|
return []
|
|
if isinstance(value, list):
|
|
return value
|
|
if isinstance(value, tuple) or isinstance(value, set):
|
|
return list(value)
|
|
return [value]
|
|
|
|
|
|
def _bm25_scores(query: str, documents: list[str]) -> list[float]:
    """Score each document against *query* with BM25 plus an exact-phrase bonus.

    Args:
        query: Free-text query. Only its first 400 characters are tokenised.
        documents: Candidate texts; falsy entries are treated as empty.

    Returns:
        One score per document, in input order. All zeros when the query has
        no tokens or there are no documents.
    """
    query = query or ""
    documents = documents or []
    query_terms = _tokenize(query[:400])
    if not query_terms or not documents:
        return [0.0] * len(documents)

    tokenized_docs = [_tokenize(document) for document in documents]
    doc_count = len(tokenized_docs)
    avg_len = sum(len(tokens) for tokens in tokenized_docs) / max(doc_count, 1)
    if avg_len <= 0:
        avg_len = 1.0

    document_frequency: dict[str, int] = {}
    for tokens in tokenized_docs:
        for token in set(tokens):
            document_frequency[token] = document_frequency.get(token, 0) + 1

    # IDF depends only on the corpus, so compute it once per distinct query term.
    idf_by_term: dict[str, float] = {}
    for term in set(query_terms):
        df = document_frequency.get(term, 0)
        idf_by_term[term] = math.log(1.0 + (doc_count - df + 0.5) / (df + 0.5))

    # The normalised phrase is loop-invariant; build it once.
    phrase = query.lower().strip()

    k1 = 1.5
    b = 0.75
    scores: list[float] = []
    for document, tokens in zip(documents, tokenized_docs):
        term_counts: dict[str, int] = {}
        for token in tokens:
            term_counts[token] = term_counts.get(token, 0) + 1

        score = 0.0
        doc_len = max(len(tokens), 1)
        for token in query_terms:
            tf = term_counts.get(token, 0)
            if tf <= 0:
                continue
            denominator = tf + k1 * (1.0 - b + b * doc_len / avg_len)
            score += idf_by_term[token] * (tf * (k1 + 1.0)) / denominator

        # Flat bonus when the whole query appears verbatim (case-insensitive).
        if phrase and phrase in (document or "").lower():
            score += 1.5
        scores.append(score)

    return scores
|
|
|
|
|
|
def _rank_positions(candidates: list[_SearchCandidate], attr: str) -> dict[str, int]:
|
|
ranked = sorted(candidates, key=lambda candidate: getattr(candidate, attr), reverse=True)
|
|
return {
|
|
candidate.uuid: rank
|
|
for rank, candidate in enumerate(ranked, start=1)
|
|
if getattr(candidate, attr) > 0
|
|
}
|
|
|
|
|
|
def _matches_labels(labels: list[str], include: list[str], exclude: list[str]) -> bool:
|
|
label_set = set(labels)
|
|
if include and not label_set.intersection(include):
|
|
return False
|
|
if exclude and label_set.intersection(exclude):
|
|
return False
|
|
return True
|
|
|
|
|
|
def _compare_value(value: Any, operator: str, expected: Any = None) -> bool:
|
|
operator = (operator or "=").upper()
|
|
if operator == "IS NULL":
|
|
return value is None
|
|
if operator == "IS NOT NULL":
|
|
return value is not None
|
|
|
|
if value is None:
|
|
return False
|
|
|
|
left = value
|
|
right = expected
|
|
try:
|
|
left = float(value)
|
|
right = float(expected)
|
|
except (TypeError, ValueError):
|
|
left = str(value)
|
|
right = str(expected)
|
|
|
|
if operator == "=":
|
|
return left == right
|
|
if operator == "<>":
|
|
return left != right
|
|
if operator == ">":
|
|
return left > right
|
|
if operator == "<":
|
|
return left < right
|
|
if operator == ">=":
|
|
return left >= right
|
|
if operator == "<=":
|
|
return left <= right
|
|
return True
|
|
|
|
|
|
class LocalZepStore:
|
|
def __init__(self, db_path: str | None = None) -> None:
|
|
self.db_path = db_path or settings.local_zep_db_path
|
|
self._lock = threading.RLock()
|
|
self._embedding_client: EmbeddingClient | None = None
|
|
self._extractor: GraphExtractor | None = None
|
|
self._reranker_client: RerankerClient | None = None
|
|
|
|
db_dir = os.path.dirname(self.db_path)
|
|
if db_dir:
|
|
os.makedirs(db_dir, exist_ok=True)
|
|
|
|
self._ensure_schema()
|
|
|
|
def _connect(self) -> sqlite3.Connection:
|
|
conn = sqlite3.connect(self.db_path, timeout=30, check_same_thread=False)
|
|
conn.row_factory = sqlite3.Row
|
|
conn.execute("PRAGMA foreign_keys = ON")
|
|
return conn
|
|
|
|
def _ensure_schema(self) -> None:
|
|
with self._connect() as conn:
|
|
conn.executescript(
|
|
"""
|
|
CREATE TABLE IF NOT EXISTS graphs (
|
|
graph_id TEXT PRIMARY KEY,
|
|
name TEXT NOT NULL,
|
|
description TEXT DEFAULT '',
|
|
ontology_json TEXT DEFAULT '{}',
|
|
created_at TEXT NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS episodes (
|
|
uuid TEXT PRIMARY KEY,
|
|
graph_id TEXT NOT NULL,
|
|
data TEXT NOT NULL,
|
|
type TEXT NOT NULL,
|
|
processed INTEGER NOT NULL DEFAULT 0,
|
|
error TEXT,
|
|
metadata_json TEXT DEFAULT '{}',
|
|
source_description TEXT,
|
|
role TEXT,
|
|
role_type TEXT,
|
|
thread_id TEXT,
|
|
task_id TEXT,
|
|
created_at TEXT NOT NULL,
|
|
FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS nodes (
|
|
uuid TEXT PRIMARY KEY,
|
|
graph_id TEXT NOT NULL,
|
|
name TEXT NOT NULL,
|
|
normalized_name TEXT NOT NULL,
|
|
primary_label TEXT NOT NULL,
|
|
labels_json TEXT NOT NULL,
|
|
summary TEXT DEFAULT '',
|
|
attributes_json TEXT DEFAULT '{}',
|
|
created_at TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL,
|
|
FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE
|
|
);
|
|
|
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_nodes_identity
|
|
ON nodes(graph_id, normalized_name, primary_label);
|
|
|
|
CREATE TABLE IF NOT EXISTS edges (
|
|
uuid TEXT PRIMARY KEY,
|
|
graph_id TEXT NOT NULL,
|
|
name TEXT NOT NULL,
|
|
fact TEXT NOT NULL,
|
|
source_node_uuid TEXT NOT NULL,
|
|
target_node_uuid TEXT NOT NULL,
|
|
attributes_json TEXT DEFAULT '{}',
|
|
created_at TEXT NOT NULL,
|
|
valid_at TEXT,
|
|
invalid_at TEXT,
|
|
expired_at TEXT,
|
|
FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE,
|
|
FOREIGN KEY(source_node_uuid) REFERENCES nodes(uuid) ON DELETE CASCADE,
|
|
FOREIGN KEY(target_node_uuid) REFERENCES nodes(uuid) ON DELETE CASCADE
|
|
);
|
|
|
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_edges_identity
|
|
ON edges(graph_id, source_node_uuid, target_node_uuid, name, fact);
|
|
|
|
CREATE TABLE IF NOT EXISTS edge_episodes (
|
|
edge_uuid TEXT NOT NULL,
|
|
episode_uuid TEXT NOT NULL,
|
|
PRIMARY KEY(edge_uuid, episode_uuid),
|
|
FOREIGN KEY(edge_uuid) REFERENCES edges(uuid) ON DELETE CASCADE,
|
|
FOREIGN KEY(episode_uuid) REFERENCES episodes(uuid) ON DELETE CASCADE
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS node_embeddings (
|
|
node_uuid TEXT PRIMARY KEY,
|
|
embedding_json TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL,
|
|
FOREIGN KEY(node_uuid) REFERENCES nodes(uuid) ON DELETE CASCADE
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS edge_embeddings (
|
|
edge_uuid TEXT PRIMARY KEY,
|
|
embedding_json TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL,
|
|
FOREIGN KEY(edge_uuid) REFERENCES edges(uuid) ON DELETE CASCADE
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS episode_embeddings (
|
|
episode_uuid TEXT PRIMARY KEY,
|
|
embedding_json TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL,
|
|
FOREIGN KEY(episode_uuid) REFERENCES episodes(uuid) ON DELETE CASCADE
|
|
);
|
|
"""
|
|
)
|
|
self._ensure_column(conn, "episodes", "metadata_json", "TEXT DEFAULT '{}'")
|
|
self._ensure_column(conn, "episodes", "source_description", "TEXT")
|
|
self._ensure_column(conn, "episodes", "role", "TEXT")
|
|
self._ensure_column(conn, "episodes", "role_type", "TEXT")
|
|
self._ensure_column(conn, "episodes", "thread_id", "TEXT")
|
|
self._ensure_column(conn, "episodes", "task_id", "TEXT")
|
|
|
|
def _ensure_column(self, conn: sqlite3.Connection, table: str, column: str, definition: str) -> None:
|
|
rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
|
existing = {row["name"] for row in rows}
|
|
if column not in existing:
|
|
conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {definition}")
|
|
|
|
def _get_embedding_client(self) -> EmbeddingClient:
|
|
if self._embedding_client is None:
|
|
self._embedding_client = EmbeddingClient()
|
|
return self._embedding_client
|
|
|
|
def _get_extractor(self) -> GraphExtractor:
|
|
if self._extractor is None:
|
|
self._extractor = GraphExtractor()
|
|
return self._extractor
|
|
|
|
def _get_reranker_client(self) -> RerankerClient:
|
|
if self._reranker_client is None:
|
|
self._reranker_client = RerankerClient()
|
|
return self._reranker_client
|
|
|
|
def create_graph(self, graph_id: str, name: str, description: str = "") -> GraphRecord:
    """Create a graph, or update the name/description of an existing one.

    Args:
        graph_id: Primary key of the graph.
        name: Display name.
        description: Optional free-text description.

    Returns:
        A ``GraphRecord`` whose ``created_at`` is the value stored in the
        database. For an existing graph that is its original creation time,
        because the upsert leaves ``created_at`` untouched.
    """
    created_at = _now_iso()
    # ``closing`` closes the connection; the inner ``conn`` context commits.
    with self._lock, closing(self._connect()) as conn, conn:
        conn.execute(
            """
            INSERT INTO graphs(graph_id, name, description, created_at)
            VALUES (?, ?, ?, ?)
            ON CONFLICT(graph_id) DO UPDATE SET
                name = excluded.name,
                description = excluded.description
            """,
            (graph_id, name, description, created_at),
        )
        # Read back the timestamp so the returned record matches the row.
        row = conn.execute(
            "SELECT created_at FROM graphs WHERE graph_id = ?",
            (graph_id,),
        ).fetchone()
    stored_created_at = row["created_at"] if row else created_at
    return GraphRecord(
        graph_id=graph_id,
        name=name,
        description=description,
        created_at=stored_created_at,
    )
|
|
|
|
def delete_graph(self, graph_id: str) -> None:
|
|
with self._lock, self._connect() as conn:
|
|
conn.execute("DELETE FROM graphs WHERE graph_id = ?", (graph_id,))
|
|
|
|
def set_ontology(self, graph_id: str, ontology: dict[str, Any]) -> None:
    """Store *ontology* as JSON on the graph row; a no-op when the graph does not exist."""
    # ``closing`` closes the connection; the inner ``conn`` context commits.
    with self._lock, closing(self._connect()) as conn, conn:
        conn.execute(
            "UPDATE graphs SET ontology_json = ? WHERE graph_id = ?",
            (_json_dumps(ontology or {}), graph_id),
        )
|
|
|
|
def get_ontology(self, graph_id: str) -> dict[str, Any]:
    """Return the graph's stored ontology, or ``{}`` when the graph or its JSON is missing or invalid."""
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        row = conn.execute(
            "SELECT ontology_json FROM graphs WHERE graph_id = ?",
            (graph_id,),
        ).fetchone()
    return _json_loads(row["ontology_json"], {}) if row else {}
|
|
|
|
def get_graph(self, graph_id: str) -> GraphRecord | None:
    """Return the graph record for *graph_id*, or ``None`` when it does not exist."""
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        row = conn.execute("SELECT * FROM graphs WHERE graph_id = ?", (graph_id,)).fetchone()
    if not row:
        return None
    return GraphRecord(
        graph_id=row["graph_id"],
        name=row["name"],
        description=row["description"] or "",
        created_at=row["created_at"],
    )
|
|
|
|
def add(
    self,
    graph_id: str,
    data: str,
    type_: str = "text",
    created_at: str | None = None,
    metadata: dict[str, Any] | None = None,
    source_description: str | None = None,
) -> GraphEpisode:
    """Persist an episode, extract graph data from it and refresh embeddings.

    The episode row is written first as unprocessed. Extraction, graph
    updates and embedding refreshes follow; on success the row is marked
    processed.

    Args:
        graph_id: Graph the episode belongs to.
        data: Episode text.
        type_: Episode type stored with the row.
        created_at: Optional timestamp; blank or missing means "now".
        metadata: Optional metadata dict stored as JSON.
        source_description: Optional description of the data source.

    Returns:
        The episode as re-read from the database (or the in-memory object if
        the row cannot be found).

    Raises:
        Exception: Any processing failure is logged, recorded on the episode
            row (``processed = 0`` plus the error text) and re-raised.
    """
    episode_created_at = _coerce_iso(created_at)
    episode = GraphEpisode(
        uuid_=uuid.uuid4().hex,
        graph_id=graph_id,
        data=data,
        type=type_,
        processed=False,
        created_at=episode_created_at,
        metadata=metadata or {},
        source_description=source_description,
    )

    # ``closing`` closes each short-lived connection; the inner ``conn``
    # context commits or rolls back the transaction.
    with self._lock, closing(self._connect()) as conn, conn:
        conn.execute(
            """
            INSERT INTO episodes(
                uuid, graph_id, data, type, processed, error, metadata_json,
                source_description, role, role_type, thread_id, task_id, created_at
            )
            VALUES (?, ?, ?, ?, 0, NULL, ?, ?, NULL, NULL, NULL, NULL, ?)
            """,
            (
                episode.uuid_,
                graph_id,
                data,
                type_,
                _json_dumps(metadata or {}),
                source_description,
                episode.created_at,
            ),
        )

    try:
        ontology = self.get_ontology(graph_id)
        extracted = self._get_extractor().extract(data, ontology)
        touched_nodes, touched_edges = self._apply_extraction(
            graph_id,
            episode.uuid_,
            extracted,
            ontology,
            episode.created_at or _now_iso(),
        )
        episode.processed = True
        with self._lock, closing(self._connect()) as conn, conn:
            conn.execute(
                "UPDATE episodes SET processed = 1, error = NULL WHERE uuid = ?",
                (episode.uuid_,),
            )
        self._refresh_node_embeddings(graph_id, touched_nodes)
        self._refresh_edge_embeddings(graph_id, touched_edges)
        self._refresh_episode_embeddings(graph_id, {episode.uuid_})
    except Exception as exc:
        logger.exception("Local graph episode processing failed: %s", exc)
        # Record the failure on the row before propagating it.
        with self._lock, closing(self._connect()) as conn, conn:
            conn.execute(
                "UPDATE episodes SET processed = 0, error = ? WHERE uuid = ?",
                (str(exc), episode.uuid_),
            )
        episode.error = str(exc)
        raise

    return self.get_episode(episode.uuid_) or episode
|
|
|
|
def add_batch(self, graph_id: str, episodes: list[Any]) -> list[GraphEpisode]:
|
|
results = []
|
|
for episode in episodes:
|
|
data = getattr(episode, "data", "") if episode is not None else ""
|
|
type_ = getattr(episode, "type", "text") if episode is not None else "text"
|
|
created_at = getattr(episode, "created_at", None) if episode is not None else None
|
|
metadata = getattr(episode, "metadata", None) if episode is not None else None
|
|
source_description = getattr(episode, "source_description", None) if episode is not None else None
|
|
results.append(
|
|
self.add(
|
|
graph_id=graph_id,
|
|
data=data,
|
|
type_=type_,
|
|
created_at=created_at,
|
|
metadata=metadata,
|
|
source_description=source_description,
|
|
)
|
|
)
|
|
return results
|
|
|
|
def get_episode(self, uuid_: str) -> GraphEpisode | None:
    """Return the episode with the given uuid, or ``None`` when it does not exist."""
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        row = conn.execute(
            "SELECT * FROM episodes WHERE uuid = ?",
            (uuid_,),
        ).fetchone()
    return self._row_to_episode(row) if row else None
|
|
|
|
def get_episodes_by_graph_id(self, graph_id: str, lastn: int | None = None):
    """Return the graph's episodes, newest first, on an object with an ``episodes`` list.

    Args:
        graph_id: Graph to read from.
        lastn: When truthy, limits the result to that many most recent episodes.
    """
    query = "SELECT * FROM episodes WHERE graph_id = ? ORDER BY created_at DESC, uuid DESC"
    params: list[Any] = [graph_id]
    if lastn:
        query += " LIMIT ?"
        params.append(lastn)
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        rows = conn.execute(query, params).fetchall()
    # Ad-hoc wrapper type so callers can use ``result.episodes``.
    return type("EpisodeList", (), {"episodes": [self._row_to_episode(row) for row in rows]})()
|
|
|
|
def get_nodes_page(self, graph_id: str, limit: int = 100, uuid_cursor: str | None = None) -> list[GraphNode]:
    """Return up to *limit* nodes of the graph ordered by uuid.

    Args:
        graph_id: Graph to read from.
        limit: Maximum number of nodes returned.
        uuid_cursor: When given, only nodes with a uuid strictly greater than
            this value are returned (keyset pagination).
    """
    query = "SELECT * FROM nodes WHERE graph_id = ?"
    params: list[Any] = [graph_id]
    if uuid_cursor:
        query += " AND uuid > ?"
        params.append(uuid_cursor)
    query += " ORDER BY uuid LIMIT ?"
    params.append(limit)

    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        rows = conn.execute(query, params).fetchall()
    return [self._row_to_node(row) for row in rows]
|
|
|
|
def get_edges_page(self, graph_id: str, limit: int = 100, uuid_cursor: str | None = None) -> list[GraphEdge]:
    """Return up to *limit* edges of the graph ordered by uuid, with their episode ids.

    Args:
        graph_id: Graph to read from.
        limit: Maximum number of edges returned.
        uuid_cursor: When given, only edges with a uuid strictly greater than
            this value are returned (keyset pagination).
    """
    query = "SELECT * FROM edges WHERE graph_id = ?"
    params: list[Any] = [graph_id]
    if uuid_cursor:
        query += " AND uuid > ?"
        params.append(uuid_cursor)
    query += " ORDER BY uuid LIMIT ?"
    params.append(limit)

    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        rows = conn.execute(query, params).fetchall()
        edge_ids = [row["uuid"] for row in rows]
        episode_map = self._load_edge_episode_map(conn, edge_ids)
    return [self._row_to_edge(row, episode_map.get(row["uuid"], [])) for row in rows]
|
|
|
|
def get_node(self, uuid_: str) -> GraphNode | None:
    """Return the node with the given uuid, or ``None`` when it does not exist."""
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        row = conn.execute("SELECT * FROM nodes WHERE uuid = ?", (uuid_,)).fetchone()
    return self._row_to_node(row) if row else None
|
|
|
|
def get_edge(self, uuid_: str) -> GraphEdge | None:
    """Return the edge with the given uuid (including its episode ids), or ``None``."""
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        row = conn.execute("SELECT * FROM edges WHERE uuid = ?", (uuid_,)).fetchone()
        if not row:
            return None
        episode_map = self._load_edge_episode_map(conn, [uuid_])
    return self._row_to_edge(row, episode_map.get(uuid_, []))
|
|
|
|
def get_entity_edges(self, node_uuid: str) -> list[GraphEdge]:
    """Return every edge that starts or ends at *node_uuid*, newest first."""
    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        rows = conn.execute(
            """
            SELECT * FROM edges
            WHERE source_node_uuid = ? OR target_node_uuid = ?
            ORDER BY created_at DESC, uuid
            """,
            (node_uuid, node_uuid),
        ).fetchall()
        edge_ids = [row["uuid"] for row in rows]
        episode_map = self._load_edge_episode_map(conn, edge_ids)
    return [self._row_to_edge(row, episode_map.get(row["uuid"], [])) for row in rows]
|
|
|
|
def search(
    self,
    graph_id: str,
    query: str,
    limit: int = 10,
    scope: str = "edges",
    reranker: str = "rrf",
    mmr_lambda: float | None = None,
    center_node_uuid: str | None = None,
    search_filters: Any = None,
    bfs_origin_node_uuids: list[str] | None = None,
) -> GraphSearchResults:
    """Search the graph with combined lexical and semantic scoring.

    Args:
        graph_id: Graph to search.
        query: Free-text query; trimmed and cut to 400 characters.
        limit: Maximum number of results across all kinds (negative acts as 0).
        scope: ``"edges"``, ``"nodes"``, ``"both"`` or ``"episodes"``.
        reranker: ``"rrf"``, ``"mmr"``, ``"cross_encoder"``,
            ``"episode_mentions"`` or ``"node_distance"``; anything else
            falls back to RRF.
        mmr_lambda: Relevance/diversity trade-off for MMR (default 0.5).
        center_node_uuid: Origin node for the ``node_distance`` reranker.
        search_filters: Optional filter object or dict.
        bfs_origin_node_uuids: Optional nodes from which graph distances are
            computed and attached to candidates before ranking.

    Returns:
        ``GraphSearchResults`` with the top-ranked items split by kind; empty
        when the query is blank or nothing matches.
    """
    results = GraphSearchResults()
    query = (query or "").strip()[:400]
    if not query:
        return results

    query_embedding: list[float] = []
    try:
        query_embedding = self._get_embedding_client().embed_text(query)
    except Exception as exc:
        # Best effort: without an embedding only lexical scores are used.
        logger.warning("Embedding lookup failed, falling back to lexical search: %s", exc)

    # Read-only; ``closing`` makes sure the connection does not stay open.
    with closing(self._connect()) as conn:
        candidates = self._build_search_candidates(
            conn=conn,
            graph_id=graph_id,
            query=query,
            query_embedding=query_embedding,
            scope=(scope or "edges").lower(),
            search_filters=search_filters,
        )
        if not candidates:
            return results

        if bfs_origin_node_uuids:
            distances = self._graph_distances(conn, graph_id, bfs_origin_node_uuids)
            self._apply_distances(conn, candidates, distances)

        self._rank_candidates(
            conn=conn,
            graph_id=graph_id,
            query=query,
            query_embedding=query_embedding,
            candidates=candidates,
            reranker=reranker or "rrf",
            mmr_lambda=mmr_lambda,
            center_node_uuid=center_node_uuid,
        )

    ranked = sorted(candidates, key=lambda candidate: candidate.score, reverse=True)[: max(limit, 0)]
    results.edges = [
        self._scored_item(candidate)
        for candidate in ranked
        if candidate.kind == "edge"
    ]
    results.nodes = [
        self._scored_item(candidate)
        for candidate in ranked
        if candidate.kind == "node"
    ]
    results.episodes = [
        self._scored_item(candidate)
        for candidate in ranked
        if candidate.kind == "episode"
    ]

    return results
|
|
|
|
def _build_search_candidates(
    self,
    conn: sqlite3.Connection,
    graph_id: str,
    query: str,
    query_embedding: list[float],
    scope: str,
    search_filters: Any,
) -> list[_SearchCandidate]:
    """Load every item in *scope* that passes the filters and score it against the query.

    Args:
        conn: Open database connection.
        graph_id: Graph to read from.
        query: Normalised query text used for lexical scoring.
        query_embedding: Query vector; empty disables semantic scoring.
        scope: ``"edges"``, ``"nodes"``, ``"both"`` (edges and nodes) or
            ``"episodes"``. Any other value yields no candidates.
        search_filters: Optional filter object or dict.

    Returns:
        Candidates with ``lexical_score`` set and, where both vectors exist,
        ``semantic_score`` set.
    """
    candidates: list[_SearchCandidate] = []

    if scope in {"edges", "both"}:
        rows = conn.execute(
            """
            SELECT
                e.*,
                ee.embedding_json,
                src.name AS source_name,
                src.labels_json AS source_labels_json,
                dst.name AS target_name,
                dst.labels_json AS target_labels_json
            FROM edges e
            JOIN nodes src ON src.uuid = e.source_node_uuid
            JOIN nodes dst ON dst.uuid = e.target_node_uuid
            LEFT JOIN edge_embeddings ee ON ee.edge_uuid = e.uuid
            WHERE e.graph_id = ?
            """,
            (graph_id,),
        ).fetchall()
        edge_ids = [row["uuid"] for row in rows]
        episode_map = self._load_edge_episode_map(conn, edge_ids)
        for row in rows:
            if not self._edge_matches_filters(row, search_filters):
                continue
            edge = self._row_to_edge(row, episode_map.get(row["uuid"], []))
            if not self._episode_metadata_matches_any(conn, edge.episodes, search_filters):
                continue
            text = " ".join(filter(None, [row["name"], row["fact"], row["source_name"], row["target_name"]]))
            candidates.append(
                _SearchCandidate(
                    kind="edge",
                    uuid=row["uuid"],
                    text=text,
                    item=edge,
                    embedding=_json_loads(row["embedding_json"], []),
                    episode_count=len(edge.episodes),
                )
            )

    if scope in {"nodes", "both"}:
        rows = conn.execute(
            """
            SELECT n.*, ne.embedding_json
            FROM nodes n
            LEFT JOIN node_embeddings ne ON ne.node_uuid = n.uuid
            WHERE n.graph_id = ?
            """,
            (graph_id,),
        ).fetchall()
        # One query serves both the episode ids and the distinct-episode
        # counts; the counts were previously fetched by a second, identical query.
        node_episode_ids = self._node_episode_ids(conn, graph_id)
        episode_counts = {
            node_uuid: len(set(episode_ids))
            for node_uuid, episode_ids in node_episode_ids.items()
        }
        for row in rows:
            if not self._node_matches_filters(row, search_filters):
                continue
            if not self._episode_metadata_matches_any(conn, node_episode_ids.get(row["uuid"], []), search_filters):
                continue
            labels = _json_loads(row["labels_json"], [])
            attributes = _json_loads(row["attributes_json"], {})
            text = " ".join(
                filter(
                    None,
                    [
                        row["name"],
                        row["summary"],
                        " ".join(labels),
                        json.dumps(attributes, ensure_ascii=False),
                    ],
                )
            )
            candidates.append(
                _SearchCandidate(
                    kind="node",
                    uuid=row["uuid"],
                    text=text,
                    item=self._row_to_node(row),
                    embedding=_json_loads(row["embedding_json"], []),
                    episode_count=episode_counts.get(row["uuid"], 0),
                )
            )

    if scope == "episodes":
        rows = conn.execute(
            """
            SELECT ep.*, ee.embedding_json
            FROM episodes ep
            LEFT JOIN episode_embeddings ee ON ee.episode_uuid = ep.uuid
            WHERE ep.graph_id = ?
            """,
            (graph_id,),
        ).fetchall()
        for row in rows:
            if not self._episode_metadata_matches(_json_loads(row["metadata_json"], {}), search_filters):
                continue
            candidates.append(
                _SearchCandidate(
                    kind="episode",
                    uuid=row["uuid"],
                    text=row["data"] or "",
                    item=self._row_to_episode(row),
                    embedding=_json_loads(row["embedding_json"], []),
                    episode_count=1,
                )
            )

    lexical_scores = _bm25_scores(query, [candidate.text for candidate in candidates])
    for candidate, lexical_score in zip(candidates, lexical_scores):
        candidate.lexical_score = lexical_score
        # Semantic score needs both the query vector and a stored vector.
        if query_embedding and candidate.embedding:
            candidate.semantic_score = cosine_similarity(query_embedding, candidate.embedding)

    return candidates
|
|
|
|
def _rank_candidates(
    self,
    conn: sqlite3.Connection,
    graph_id: str,
    query: str,
    query_embedding: list[float],
    candidates: list[_SearchCandidate],
    reranker: str,
    mmr_lambda: float | None,
    center_node_uuid: str | None,
) -> None:
    """Assign ``score`` (and for the cross-encoder, ``relevance``) to every candidate in place.

    *reranker* picks the strategy: ``cross_encoder``, ``mmr``,
    ``episode_mentions`` or ``node_distance`` (only with a
    *center_node_uuid*). Everything else uses reciprocal-rank fusion.
    """
    mode = (reranker or "rrf").lower()

    if mode == "cross_encoder":
        # RRF preselects the pool and is the fallback when no reranker scores come back.
        self._rank_rrf(candidates)
        by_score = sorted(candidates, key=lambda entry: entry.score, reverse=True)
        pool = by_score[: settings.local_zep_rerank_top_k]
        scores = self._get_reranker_client().rerank(query, [entry.text for entry in pool])
        if scores is None:
            logger.info("Cross-encoder reranker is not configured; using local RRF fallback")
            return
        for entry, raw_score in zip(pool, scores):
            entry.score = float(raw_score) + self._distance_boost(entry)
            entry.relevance = max(0.0, min(1.0, float(raw_score)))
        reranked_ids = {entry.uuid for entry in pool}
        for entry in candidates:
            # Candidates outside the reranked pool are pushed down.
            if entry.uuid not in reranked_ids:
                entry.score *= 0.01
        return

    if mode == "mmr":
        lambda_value = 0.5 if mmr_lambda is None else mmr_lambda
        self._rank_mmr(candidates, query_embedding, lambda_value)
        return

    if mode == "episode_mentions":
        self._rank_rrf(candidates)
        for entry in candidates:
            entry.score += math.log1p(entry.episode_count) * 0.1
        return

    if mode == "node_distance" and center_node_uuid:
        distances = self._graph_distances(conn, graph_id, [center_node_uuid])
        self._apply_distances(conn, candidates, distances)
        self._rank_node_distance(candidates)
        return

    self._rank_rrf(candidates)
|
|
|
|
def _rank_rrf(self, candidates: list[_SearchCandidate]) -> None:
    """Score candidates by reciprocal-rank fusion of semantic and lexical ranks, plus the distance boost.

    Each ranking contributes ``1 / (60 + rank)``; a candidate missing from a
    ranking gets nothing from it.
    """
    semantic_ranks = _rank_positions(candidates, "semantic_score")
    lexical_ranks = _rank_positions(candidates, "lexical_score")
    for entry in candidates:
        fused = 0.0
        for ranking in (semantic_ranks, lexical_ranks):
            position = ranking.get(entry.uuid)
            if position is not None:
                fused += 1.0 / (60.0 + position)
        entry.score = fused + self._distance_boost(entry)
|
|
|
|
def _rank_mmr(self, candidates: list[_SearchCandidate], query_embedding: list[float], lambda_value: float) -> None:
|
|
lambda_value = max(0.0, min(1.0, lambda_value))
|
|
remaining = candidates[:]
|
|
selected: list[_SearchCandidate] = []
|
|
|
|
while remaining:
|
|
best: _SearchCandidate | None = None
|
|
best_score = -float("inf")
|
|
for candidate in remaining:
|
|
relevance = candidate.semantic_score + (candidate.lexical_score * 0.05)
|
|
diversity_penalty = 0.0
|
|
if query_embedding and candidate.embedding and selected:
|
|
similarities = [
|
|
cosine_similarity(candidate.embedding, selected_candidate.embedding)
|
|
for selected_candidate in selected
|
|
if selected_candidate.embedding
|
|
]
|
|
diversity_penalty = max(similarities) if similarities else 0.0
|
|
mmr_score = lambda_value * relevance - (1.0 - lambda_value) * diversity_penalty
|
|
if mmr_score > best_score:
|
|
best_score = mmr_score
|
|
best = candidate
|
|
|
|
if best is None:
|
|
break
|
|
remaining.remove(best)
|
|
selected.append(best)
|
|
best.score = best_score + self._distance_boost(best)
|
|
|
|
rank_count = len(selected)
|
|
for rank, candidate in enumerate(selected):
|
|
candidate.score += (rank_count - rank) * 1e-6
|
|
|
|
def _rank_node_distance(self, candidates: list[_SearchCandidate]) -> None:
    """Apply RRF scoring, then favour candidates with a known graph distance.

    A known distance adds ``1 / (1 + distance)``; an unknown distance scales
    the score down to 1%.
    """
    self._rank_rrf(candidates)
    for entry in candidates:
        hops = entry.distance
        if hops is not None:
            entry.score += 1.0 / (1.0 + hops)
        else:
            entry.score *= 0.01
|
|
|
|
def _distance_boost(self, candidate: _SearchCandidate) -> float:
|
|
if candidate.distance is None:
|
|
return 0.0
|
|
return 0.15 / (1.0 + candidate.distance)
|
|
|
|
def _scored_item(self, candidate: _SearchCandidate):
|
|
candidate.item.score = candidate.score
|
|
candidate.item.relevance = candidate.relevance
|
|
return candidate.item
|
|
|
|
def _node_matches_filters(self, row: sqlite3.Row, search_filters: Any) -> bool:
|
|
if not search_filters:
|
|
return True
|
|
|
|
labels = _json_loads(row["labels_json"], [])
|
|
include_labels = [str(value) for value in _as_list(_get_value(search_filters, "node_labels"))]
|
|
exclude_labels = [str(value) for value in _as_list(_get_value(search_filters, "exclude_node_labels"))]
|
|
if not _matches_labels(labels, include_labels, exclude_labels):
|
|
return False
|
|
|
|
attributes = _json_loads(row["attributes_json"], {})
|
|
return self._properties_match(attributes, search_filters)
|
|
|
|
def _edge_matches_filters(self, row: sqlite3.Row, search_filters: Any) -> bool:
|
|
if not search_filters:
|
|
return True
|
|
|
|
include_edge_types = [str(value) for value in _as_list(_get_value(search_filters, "edge_types"))]
|
|
exclude_edge_types = [str(value) for value in _as_list(_get_value(search_filters, "exclude_edge_types"))]
|
|
if include_edge_types and row["name"] not in include_edge_types:
|
|
return False
|
|
if exclude_edge_types and row["name"] in exclude_edge_types:
|
|
return False
|
|
|
|
source_labels = _json_loads(row["source_labels_json"], [])
|
|
target_labels = _json_loads(row["target_labels_json"], [])
|
|
labels = sorted({*source_labels, *target_labels})
|
|
include_labels = [str(value) for value in _as_list(_get_value(search_filters, "node_labels"))]
|
|
exclude_labels = [str(value) for value in _as_list(_get_value(search_filters, "exclude_node_labels"))]
|
|
if not _matches_labels(labels, include_labels, exclude_labels):
|
|
return False
|
|
|
|
attributes = _json_loads(row["attributes_json"], {})
|
|
if not self._properties_match(attributes, search_filters):
|
|
return False
|
|
|
|
for field_name in ("created_at", "valid_at", "invalid_at", "expired_at"):
|
|
if not self._date_filters_match(row[field_name], _get_value(search_filters, field_name)):
|
|
return False
|
|
|
|
return True
|
|
|
|
def _properties_match(self, attributes: dict[str, Any], search_filters: Any) -> bool:
    """Return True when *attributes* satisfy every entry of ``property_filters``.

    Entries without a ``property_name`` are ignored; the operator defaults to ``=``.
    """
    for entry in _as_list(_get_value(search_filters, "property_filters")):
        name = _get_value(entry, "property_name")
        if name:
            comparison = _get_value(entry, "comparison_operator", "=")
            wanted = _get_value(entry, "property_value")
            actual = attributes.get(str(name))
            if not _compare_value(actual, str(comparison), wanted):
                return False
    return True
|
|
|
|
def _episode_metadata_matches_any(
    self,
    conn: sqlite3.Connection,
    episode_ids: list[str],
    search_filters: Any,
) -> bool:
    """Return True when at least one listed episode satisfies the episode metadata filter.

    With no metadata filter everything matches. With a filter but no
    episode ids nothing does.
    """
    if not _get_value(search_filters, "episode_metadata_filters"):
        return True
    if not episode_ids:
        return False

    placeholders = ",".join("?" * len(episode_ids))
    rows = conn.execute(
        f"""
        SELECT metadata_json
        FROM episodes
        WHERE uuid IN ({placeholders})
        """,
        episode_ids,
    ).fetchall()
    for row in rows:
        if self._episode_metadata_matches(_json_loads(row["metadata_json"], {}), search_filters):
            return True
    return False
|
|
|
|
def _episode_metadata_matches(self, metadata: dict[str, Any], search_filters: Any) -> bool:
    """Return True when *metadata* satisfies ``episode_metadata_filters`` (or none is set)."""
    filter_group = _get_value(search_filters, "episode_metadata_filters")
    if filter_group:
        return self._metadata_group_matches(metadata, filter_group)
    return True
|
|
|
|
def _metadata_group_matches(self, metadata: dict[str, Any], group: Any) -> bool:
    """Evaluate a (possibly nested) metadata filter group against *metadata*.

    The group's ``type`` (default ``"and"``) decides whether all or any of
    its ``filters`` and nested ``groups`` must hold. Filters without a
    ``property_name`` are ignored, and an empty group matches.
    """
    combine_with_or = str(_get_value(group, "type", "and")).lower() == "or"
    outcomes: list[bool] = []

    for entry in _as_list(_get_value(group, "filters")):
        name = _get_value(entry, "property_name")
        if name:
            comparison = str(_get_value(entry, "comparison_operator", "="))
            wanted = _get_value(entry, "property_value")
            outcomes.append(_compare_value(metadata.get(str(name)), comparison, wanted))

    for child in _as_list(_get_value(group, "groups")):
        outcomes.append(self._metadata_group_matches(metadata, child))

    if not outcomes:
        return True
    return any(outcomes) if combine_with_or else all(outcomes)
|
|
|
|
def _date_filters_match(self, value: str | None, filter_groups: Any) -> bool:
    """Check ``value`` against date predicates grouped as OR-of-ANDs.

    ``filter_groups`` is a sequence of groups; the value matches when every
    predicate of at least one group holds. A flat sequence of predicates
    (first element is not a list/tuple/set) is treated as a single group.
    No filter at all matches everything.
    """
    if not filter_groups:
        return True

    or_groups = _as_list(filter_groups)
    if or_groups and not isinstance(or_groups[0], (list, tuple, set)):
        # Flat predicate list: wrap it so it behaves as one AND-group.
        or_groups = [or_groups]

    def _holds(predicate: Any) -> bool:
        comparison = str(_get_value(predicate, "comparison_operator", "="))
        return _compare_value(value, comparison, _get_value(predicate, "date"))

    return any(
        all(_holds(predicate) for predicate in _as_list(and_group))
        for and_group in or_groups
    )
|
|
|
|
def _node_episode_counts(self, conn: sqlite3.Connection, graph_id: str) -> dict[str, int]:
    """Map each node UUID to the number of distinct episodes it is linked to.

    Built from ``_node_episode_ids``; duplicate episode ids for a node are
    counted once.
    """
    counts: dict[str, int] = {}
    for node_uuid, episode_ids in self._node_episode_ids(conn, graph_id).items():
        counts[node_uuid] = len(set(episode_ids))
    return counts
|
|
|
|
def _node_episode_ids(self, conn: sqlite3.Connection, graph_id: str) -> dict[str, list[str]]:
    """Collect, per node, the episode UUIDs of every edge touching that node.

    A node is linked to an episode when it is the source or the target of an
    edge in ``graph_id`` that is recorded in ``edge_episodes``. The lists may
    contain duplicates (the query uses ``UNION ALL``).
    """
    query = """
        SELECT e.source_node_uuid AS node_uuid, ee.episode_uuid
        FROM edges e
        JOIN edge_episodes ee ON ee.edge_uuid = e.uuid
        WHERE e.graph_id = ?
        UNION ALL
        SELECT e.target_node_uuid AS node_uuid, ee.episode_uuid
        FROM edges e
        JOIN edge_episodes ee ON ee.edge_uuid = e.uuid
        WHERE e.graph_id = ?
    """
    grouped: dict[str, list[str]] = {}
    for row in conn.execute(query, (graph_id, graph_id)).fetchall():
        node_uuid = row["node_uuid"]
        bucket = grouped.get(node_uuid)
        if bucket is None:
            bucket = grouped[node_uuid] = []
        bucket.append(row["episode_uuid"])
    return grouped
|
|
|
|
def _graph_distances(
    self,
    conn: sqlite3.Connection,
    graph_id: str,
    origin_node_uuids: list[str] | None,
) -> dict[str, int]:
    """Compute hop distances from the origin set to every reachable node.

    Each non-empty origin value is matched both against node UUIDs in the
    graph and against episode UUIDs; for a matching episode, the endpoints of
    its linked edges become seeds. Seeds get distance 0 and a breadth-first
    walk over the graph's edges (treated as undirected) assigns hop counts.

    Returns an empty dict when no usable origin is given; nodes that cannot
    be reached from any seed are absent from the result.
    """
    origins = [origin for origin in (origin_node_uuids or []) if origin]
    if not origins:
        return {}

    # Undirected adjacency over every edge of the graph.
    neighbors: dict[str, set[str]] = {}
    edge_rows = conn.execute(
        "SELECT uuid, source_node_uuid, target_node_uuid FROM edges WHERE graph_id = ?",
        (graph_id,),
    ).fetchall()
    for edge_row in edge_rows:
        head = edge_row["source_node_uuid"]
        tail = edge_row["target_node_uuid"]
        neighbors.setdefault(head, set()).add(tail)
        neighbors.setdefault(tail, set()).add(head)

    marks = ",".join("?" for _ in origins)
    params = [graph_id, *origins]

    # Origins that name nodes of this graph.
    node_rows = conn.execute(
        f"SELECT uuid FROM nodes WHERE graph_id = ? AND uuid IN ({marks})",
        params,
    ).fetchall()
    seeds = {node_row["uuid"] for node_row in node_rows}

    # Origins that name episodes: seed with the endpoints of their edges.
    episode_edge_rows = conn.execute(
        f"""
        SELECT e.source_node_uuid, e.target_node_uuid
        FROM edge_episodes ee
        JOIN edges e ON e.uuid = ee.edge_uuid
        WHERE e.graph_id = ? AND ee.episode_uuid IN ({marks})
        """,
        params,
    ).fetchall()
    for episode_edge_row in episode_edge_rows:
        seeds.update((episode_edge_row["source_node_uuid"], episode_edge_row["target_node_uuid"]))

    # Multi-source BFS, expanded one hop level at a time.
    distances: dict[str, int] = dict.fromkeys(seeds, 0)
    frontier = set(seeds)
    depth = 0
    while frontier:
        depth += 1
        upcoming: set[str] = set()
        for node_uuid in frontier:
            for neighbor in neighbors.get(node_uuid, ()):
                if neighbor not in distances:
                    distances[neighbor] = depth
                    upcoming.add(neighbor)
        frontier = upcoming
    return distances
|
|
|
|
def _apply_distances(
    self,
    conn: sqlite3.Connection,
    candidates: list[_SearchCandidate],
    distances: dict[str, int],
) -> None:
    """Annotate each search candidate with its graph distance, in place.

    * node candidates take their own entry from ``distances``;
    * edge candidates take the smaller of their endpoints' distances;
    * any other candidate is treated as an episode: the endpoints of all
      edges linked to it are looked up and the smallest distance is used.

    ``distance`` is set to None when nothing relevant is in ``distances``.
    Candidates are left untouched when ``distances`` is empty.
    """
    if not distances:
        return

    def _closest(node_uuids: Any) -> int | None:
        reachable = [hops for hops in map(distances.get, node_uuids) if hops is not None]
        return min(reachable) if reachable else None

    episode_edges_query = """
        SELECT e.source_node_uuid, e.target_node_uuid
        FROM edge_episodes ee
        JOIN edges e ON e.uuid = ee.edge_uuid
        WHERE ee.episode_uuid = ?
    """

    for candidate in candidates:
        if candidate.kind == "node":
            candidate.distance = distances.get(candidate.uuid)
        elif candidate.kind == "edge":
            edge = candidate.item
            candidate.distance = _closest((edge.source_node_uuid, edge.target_node_uuid))
        else:
            linked = conn.execute(episode_edges_query, (candidate.uuid,)).fetchall()
            endpoints = [
                endpoint
                for row in linked
                for endpoint in (row["source_node_uuid"], row["target_node_uuid"])
            ]
            candidate.distance = _closest(endpoints)
|
|
|
|
def _apply_extraction(
    self,
    graph_id: str,
    episode_uuid: str,
    extracted: dict[str, list[dict[str, Any]]],
    ontology: dict[str, Any],
    episode_created_at: str,
) -> tuple[set[str], set[str]]:
    """Persist extracted entities and edges for one episode.

    Entities are upserted first and remembered in a per-call lookup keyed by
    ``(normalized name, label)`` so edge endpoints can reuse them. Each edge
    then resolves (or creates) its source and target nodes and is upserted.
    All writes happen under ``self._lock`` on a single connection.

    Returns:
        ``(touched_node_uuids, touched_edge_uuids)``.
    """
    node_ids: set[str] = set()
    edge_ids: set[str] = set()
    known_nodes: dict[tuple[str, str], GraphNode] = {}

    with self._lock, self._connect() as conn:
        for entity_payload in extracted.get("entities", []):
            node = self._upsert_node(conn, graph_id, entity_payload)
            key_name = _normalize_name(node.name)
            # Register under both the specific label and the generic one.
            known_nodes[(key_name, _primary_label(node.labels))] = node
            known_nodes[(key_name, "Entity")] = node
            node_ids.add(node.uuid_)

        for edge_payload in extracted.get("edges", []):
            # Source is resolved before target, as each may create a node.
            source_node, target_node = [
                self._resolve_edge_node(
                    conn,
                    graph_id,
                    edge_payload.get(role, ""),
                    ontology,
                    edge_payload.get("name", ""),
                    role == "source",
                    known_nodes,
                )
                for role in ("source", "target")
            ]
            node_ids |= {source_node.uuid_, target_node.uuid_}
            stored = self._upsert_edge(
                conn,
                graph_id,
                episode_uuid,
                edge_payload,
                source_node,
                target_node,
                episode_created_at,
            )
            edge_ids.add(stored.uuid_)

    return node_ids, edge_ids
|
|
|
|
def _upsert_node(self, conn: sqlite3.Connection, graph_id: str, entity: dict[str, Any]) -> GraphNode:
    """Insert a node for ``entity`` or merge it into the matching existing one.

    Nodes are matched on ``(graph_id, normalized name, primary label)``. On a
    match, labels are unioned, new attributes override existing ones, the
    summary is merged via ``_merge_summary`` and ``updated_at`` is refreshed.
    Otherwise a new row is inserted, with a fallback summary when the entity
    provides none.
    """
    name = (entity.get("name") or "").strip()
    entity_type = (entity.get("type") or "Entity").strip() or "Entity"
    summary = (entity.get("summary") or "").strip()
    attributes = entity.get("attributes") or {}

    labels = ["Entity"]
    if entity_type != "Entity":
        labels.append(entity_type)

    normalized_name = _normalize_name(name)
    label = _primary_label(labels)
    existing = conn.execute(
        """
        SELECT * FROM nodes
        WHERE graph_id = ? AND normalized_name = ? AND primary_label = ?
        """,
        (graph_id, normalized_name, label),
    ).fetchone()
    timestamp = _now_iso()

    if not existing:
        # No match: create a brand-new node.
        node_uuid = uuid.uuid4().hex
        summary = summary or self._fallback_summary(name, attributes)
        conn.execute(
            """
            INSERT INTO nodes(
                uuid, graph_id, name, normalized_name, primary_label, labels_json,
                summary, attributes_json, created_at, updated_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                node_uuid,
                graph_id,
                name,
                normalized_name,
                label,
                _json_dumps(labels),
                summary,
                _json_dumps(attributes),
                timestamp,
                timestamp,
            ),
        )
        return GraphNode(
            uuid_=node_uuid,
            graph_id=graph_id,
            name=name,
            labels=labels,
            summary=summary,
            attributes=attributes,
            created_at=timestamp,
        )

    # Match found: merge the incoming data into the stored row.
    stored_labels = _json_loads(existing["labels_json"], [])
    merged_labels = sorted({*stored_labels, *labels})
    stored_attributes = _json_loads(existing["attributes_json"], {})
    merged_attributes = {**stored_attributes, **attributes}
    merged_summary = self._merge_summary(existing["summary"], summary, merged_attributes)
    conn.execute(
        """
        UPDATE nodes
        SET labels_json = ?, summary = ?, attributes_json = ?, updated_at = ?
        WHERE uuid = ?
        """,
        (
            _json_dumps(merged_labels),
            merged_summary,
            _json_dumps(merged_attributes),
            timestamp,
            existing["uuid"],
        ),
    )
    refreshed = conn.execute("SELECT * FROM nodes WHERE uuid = ?", (existing["uuid"],)).fetchone()
    return self._row_to_node(refreshed)
|
|
|
|
def _resolve_edge_node(
    self,
    conn: sqlite3.Connection,
    graph_id: str,
    node_name: str,
    ontology: dict[str, Any],
    edge_name: str,
    is_source: bool,
    entity_lookup: dict[tuple[str, str], GraphNode],
) -> GraphNode:
    """Find or create the node that an edge endpoint name refers to.

    Resolution order:
      1. the per-call ``entity_lookup`` cache, trying the labels the ontology
         allows for this side of the edge and then the generic "Entity" key;
      2. the most recently updated stored node with that normalized name and
         an allowed primary label;
      3. the most recently updated stored node with that normalized name;
      4. a newly created node typed with the first allowed label (or
         "Entity"), using the name as its summary.

    Nodes found in steps 2-4 are added to ``entity_lookup``.
    """
    key_name = _normalize_name(node_name)
    allowed = self._allowed_labels_for_edge(ontology, edge_name, is_source)

    def _remember(node: GraphNode) -> GraphNode:
        entity_lookup[(key_name, _primary_label(node.labels))] = node
        entity_lookup[(key_name, "Entity")] = node
        return node

    for candidate_label in [*allowed, "Entity"]:
        cached = entity_lookup.get((key_name, candidate_label))
        if cached:
            return cached

    found = None
    if allowed:
        marks = ",".join("?" for _ in allowed)
        found = conn.execute(
            f"""
            SELECT * FROM nodes
            WHERE graph_id = ? AND normalized_name = ? AND primary_label IN ({marks})
            ORDER BY updated_at DESC
            LIMIT 1
            """,
            [graph_id, key_name, *allowed],
        ).fetchone()
    if found is None:
        # Fall back to any node with this name, regardless of label.
        found = conn.execute(
            """
            SELECT * FROM nodes
            WHERE graph_id = ? AND normalized_name = ?
            ORDER BY updated_at DESC
            LIMIT 1
            """,
            (graph_id, key_name),
        ).fetchone()
    if found:
        return _remember(self._row_to_node(found))

    placeholder_entity = {
        "name": node_name,
        "type": allowed[0] if allowed else "Entity",
        "summary": node_name,
        "attributes": {},
    }
    return _remember(self._upsert_node(conn, graph_id, placeholder_entity))
|
|
|
|
def _allowed_labels_for_edge(self, ontology: dict[str, Any], edge_name: str, is_source: bool) -> list[str]:
    """Return the node labels the ontology allows on one side of an edge type.

    Only the first ``edge_types`` entry whose ``name`` equals ``edge_name`` is
    consulted. Its ``source_targets`` pairs contribute their ``source`` label
    when ``is_source`` is true and their ``target`` label otherwise. Labels
    are returned in first-seen order without duplicates; empty labels and the
    generic "Entity" label are skipped.

    A missing or ``None`` ontology, ``edge_types`` or ``source_targets`` is
    treated as empty instead of raising.

    Returns:
        The allowed labels, or an empty list when the edge type is unknown.
    """
    side = "source" if is_source else "target"
    # ``or []`` also covers keys that are present but explicitly null.
    for edge_type in (ontology or {}).get("edge_types") or []:
        if edge_type.get("name") != edge_name:
            continue
        labels: list[str] = []
        for pair in edge_type.get("source_targets") or []:
            label = pair.get(side)
            if label and label != "Entity" and label not in labels:
                labels.append(label)
        return labels
    return []
|
|
|
|
def _upsert_edge(
    self,
    conn: sqlite3.Connection,
    graph_id: str,
    episode_uuid: str,
    edge: dict[str, Any],
    source_node: GraphNode,
    target_node: GraphNode,
    episode_created_at: str,
) -> GraphEdge:
    """Store an extracted edge and link it to the episode that produced it.

    An edge with the same graph, endpoints, name and exact fact text is
    reused: its attributes are merged (new values win) and ``valid_at`` is
    filled in only if it was NULL. Otherwise conflicting active edges are
    invalidated via ``_invalidate_superseded_edges`` and a new edge row is
    inserted. In both cases an ``edge_episodes`` link is recorded.

    Defaults: the name falls back to "RELATED_TO", the fact to
    "<source> <name> <target>", and ``valid_at`` to the episode timestamp.

    Returns:
        The stored edge; its ``episodes`` list holds only ``episode_uuid``.
    """
    name = (edge.get("name") or "RELATED_TO").strip() or "RELATED_TO"
    fact = (edge.get("fact") or f"{source_node.name} {name} {target_node.name}").strip()
    attributes = edge.get("attributes") or {}
    learned_at = _now_iso()
    valid_at = _coerce_iso(edge.get("valid_at") or episode_created_at)

    match = conn.execute(
        """
        SELECT * FROM edges
        WHERE graph_id = ? AND source_node_uuid = ? AND target_node_uuid = ? AND name = ? AND fact = ?
        """,
        (graph_id, source_node.uuid_, target_node.uuid_, name, fact),
    ).fetchone()

    if not match:
        # New fact: retire whatever it supersedes, then insert it as active.
        self._invalidate_superseded_edges(
            conn=conn,
            graph_id=graph_id,
            source_node_uuid=source_node.uuid_,
            target_node_uuid=target_node.uuid_,
            edge_name=name,
            new_fact=fact,
            invalid_at=valid_at,
            expired_at=learned_at,
        )
        edge_uuid = uuid.uuid4().hex
        conn.execute(
            """
            INSERT INTO edges(
                uuid, graph_id, name, fact, source_node_uuid, target_node_uuid,
                attributes_json, created_at, valid_at, invalid_at, expired_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, NULL)
            """,
            (
                edge_uuid,
                graph_id,
                name,
                fact,
                source_node.uuid_,
                target_node.uuid_,
                _json_dumps(attributes),
                learned_at,
                valid_at,
            ),
        )
    else:
        # Known fact: merge attributes, keep an already-set valid_at.
        stored_attributes = _json_loads(match["attributes_json"], {})
        merged_attributes = {**stored_attributes, **attributes}
        conn.execute(
            """
            UPDATE edges
            SET attributes_json = ?, valid_at = COALESCE(valid_at, ?)
            WHERE uuid = ?
            """,
            (_json_dumps(merged_attributes), valid_at, match["uuid"]),
        )
        edge_uuid = match["uuid"]

    conn.execute(
        """
        INSERT OR IGNORE INTO edge_episodes(edge_uuid, episode_uuid)
        VALUES (?, ?)
        """,
        (edge_uuid, episode_uuid),
    )

    stored = conn.execute("SELECT * FROM edges WHERE uuid = ?", (edge_uuid,)).fetchone()
    return self._row_to_edge(stored, [episode_uuid])
|
|
|
|
def _invalidate_superseded_edges(
    self,
    conn: sqlite3.Connection,
    graph_id: str,
    source_node_uuid: str,
    target_node_uuid: str,
    edge_name: str,
    new_fact: str,
    invalid_at: str,
    expired_at: str,
) -> None:
    """Approximate Zep/Graphiti temporal fact invalidation.

    Looks at the still-active edges (``invalid_at`` and ``expired_at`` both
    NULL) between the same source and target whose name is either
    ``edge_name`` or one of its explicitly conflicting names. Every such edge
    whose normalized fact differs from the new fact is marked superseded by
    stamping ``invalid_at`` and ``expired_at``; rows are never deleted, so
    history is preserved while active facts stay current.
    """
    names = self._conflicting_names(edge_name)
    name_marks = ",".join("?" for _ in names)
    active_rows = conn.execute(
        f"""
        SELECT uuid, fact
        FROM edges
        WHERE graph_id = ?
          AND source_node_uuid = ?
          AND target_node_uuid = ?
          AND name IN ({name_marks})
          AND invalid_at IS NULL
          AND expired_at IS NULL
        """,
        [graph_id, source_node_uuid, target_node_uuid, *names],
    ).fetchall()

    incoming = _normalize_fact(new_fact)
    stale_ids = [
        active["uuid"]
        for active in active_rows
        if _normalize_fact(active["fact"]) != incoming
    ]
    if stale_ids:
        id_marks = ",".join("?" for _ in stale_ids)
        conn.execute(
            f"""
            UPDATE edges
            SET invalid_at = ?, expired_at = ?
            WHERE uuid IN ({id_marks})
            """,
            [invalid_at, expired_at, *stale_ids],
        )
|
|
|
|
def _conflicting_names(self, edge_name: str) -> list[str]:
    """Return ``edge_name`` plus its registered opposite relation names, sorted.

    Opposites are looked up in ``_CONFLICTING_EDGE_NAMES`` by the upper-cased
    name; ``edge_name`` itself is kept with its original casing.
    """
    opposites = _CONFLICTING_EDGE_NAMES.get(edge_name.upper(), set())
    return sorted({edge_name, *opposites})
|
|
|
|
def _refresh_node_embeddings(self, graph_id: str, node_ids: set[str]) -> None:
    """Recompute and store embeddings for the given nodes of a graph.

    The embedded text joins each node's name, summary, labels and
    JSON-encoded attributes. Embedding failures are logged as a warning and
    leave the stored embeddings unchanged. Writes are upserts performed
    under ``self._lock``.
    """
    if not node_ids:
        return

    wanted = list(node_ids)
    marks = ",".join("?" for _ in wanted)
    with self._connect() as conn:
        rows = conn.execute(
            f"""
            SELECT * FROM nodes
            WHERE graph_id = ? AND uuid IN ({marks})
            """,
            [graph_id, *wanted],
        ).fetchall()
    if not rows:
        return

    def _describe(row: Any) -> str:
        parts = [
            row["name"],
            row["summary"],
            " ".join(_json_loads(row["labels_json"], [])),
            json.dumps(_json_loads(row["attributes_json"], {}), ensure_ascii=False),
        ]
        return " ".join(part for part in parts if part)

    texts = [_describe(row) for row in rows]
    ids = [row["uuid"] for row in rows]

    try:
        embeddings = self._get_embedding_client().embed_texts(texts)
    except Exception as exc:
        # Best effort: keep going without embeddings rather than fail the caller.
        logger.warning("Failed to refresh node embeddings: %s", exc)
        return

    stamp = _now_iso()
    with self._lock, self._connect() as conn:
        for node_id, vector in zip(ids, embeddings):
            conn.execute(
                """
                INSERT INTO node_embeddings(node_uuid, embedding_json, updated_at)
                VALUES (?, ?, ?)
                ON CONFLICT(node_uuid) DO UPDATE SET
                    embedding_json = excluded.embedding_json,
                    updated_at = excluded.updated_at
                """,
                (node_id, _json_dumps(vector), stamp),
            )
|
|
|
|
def _refresh_edge_embeddings(self, graph_id: str, edge_ids: set[str]) -> None:
    """Recompute and store embeddings for the given edges of a graph.

    The embedded text joins the edge name, its fact and the names of its
    source and target nodes. Embedding failures are logged as a warning and
    leave the stored embeddings unchanged. Writes are upserts performed
    under ``self._lock``.
    """
    if not edge_ids:
        return

    wanted = list(edge_ids)
    marks = ",".join("?" for _ in wanted)
    with self._connect() as conn:
        rows = conn.execute(
            f"""
            SELECT e.*, src.name AS source_name, dst.name AS target_name
            FROM edges e
            JOIN nodes src ON src.uuid = e.source_node_uuid
            JOIN nodes dst ON dst.uuid = e.target_node_uuid
            WHERE e.graph_id = ? AND e.uuid IN ({marks})
            """,
            [graph_id, *wanted],
        ).fetchall()
    if not rows:
        return

    text_columns = ("name", "fact", "source_name", "target_name")
    ids = [row["uuid"] for row in rows]
    texts = [
        " ".join(part for part in (row[column] for column in text_columns) if part)
        for row in rows
    ]

    try:
        embeddings = self._get_embedding_client().embed_texts(texts)
    except Exception as exc:
        # Best effort: keep going without embeddings rather than fail the caller.
        logger.warning("Failed to refresh edge embeddings: %s", exc)
        return

    stamp = _now_iso()
    with self._lock, self._connect() as conn:
        for edge_id, vector in zip(ids, embeddings):
            conn.execute(
                """
                INSERT INTO edge_embeddings(edge_uuid, embedding_json, updated_at)
                VALUES (?, ?, ?)
                ON CONFLICT(edge_uuid) DO UPDATE SET
                    embedding_json = excluded.embedding_json,
                    updated_at = excluded.updated_at
                """,
                (edge_id, _json_dumps(vector), stamp),
            )
|
|
|
|
def _refresh_episode_embeddings(self, graph_id: str, episode_ids: set[str]) -> None:
    """Recompute and store embeddings for the given episodes of a graph.

    The embedded text is the episode's ``data`` column (empty string when
    NULL). Embedding failures are logged as a warning and leave the stored
    embeddings unchanged. Writes are upserts performed under ``self._lock``.
    """
    if not episode_ids:
        return

    wanted = list(episode_ids)
    marks = ",".join("?" for _ in wanted)
    with self._connect() as conn:
        rows = conn.execute(
            f"""
            SELECT * FROM episodes
            WHERE graph_id = ? AND uuid IN ({marks})
            """,
            [graph_id, *wanted],
        ).fetchall()
    if not rows:
        return

    ids: list[str] = []
    texts: list[str] = []
    for row in rows:
        ids.append(row["uuid"])
        texts.append(row["data"] or "")

    try:
        embeddings = self._get_embedding_client().embed_texts(texts)
    except Exception as exc:
        # Best effort: keep going without embeddings rather than fail the caller.
        logger.warning("Failed to refresh episode embeddings: %s", exc)
        return

    stamp = _now_iso()
    with self._lock, self._connect() as conn:
        for episode_id, vector in zip(ids, embeddings):
            conn.execute(
                """
                INSERT INTO episode_embeddings(episode_uuid, embedding_json, updated_at)
                VALUES (?, ?, ?)
                ON CONFLICT(episode_uuid) DO UPDATE SET
                    embedding_json = excluded.embedding_json,
                    updated_at = excluded.updated_at
                """,
                (episode_id, _json_dumps(vector), stamp),
            )
|
|
|
|
def _load_edge_episode_map(self, conn: sqlite3.Connection, edge_ids: list[str]) -> dict[str, list[str]]:
    """Map each requested edge UUID to the episode UUIDs linked to it.

    Episode UUIDs within each list are in ascending order. Edges without any
    linked episode are absent from the result; an empty ``edge_ids`` yields
    an empty dict without touching the database.
    """
    if not edge_ids:
        return {}

    marks = ",".join("?" for _ in edge_ids)
    links = conn.execute(
        f"""
        SELECT edge_uuid, episode_uuid
        FROM edge_episodes
        WHERE edge_uuid IN ({marks})
        ORDER BY episode_uuid
        """,
        edge_ids,
    ).fetchall()

    by_edge: dict[str, list[str]] = {}
    for link in links:
        edge_uuid = link["edge_uuid"]
        if edge_uuid not in by_edge:
            by_edge[edge_uuid] = []
        by_edge[edge_uuid].append(link["episode_uuid"])
    return by_edge
|
|
|
|
def _load_edge_endpoint_names(self, conn: sqlite3.Connection, row: sqlite3.Row) -> tuple[str, str]:
    """Return ``(source name, target name)`` for an edge row.

    Each endpoint's name is read from ``nodes`` by UUID; an endpoint whose
    node row does not exist yields an empty string.
    """

    def _name_of(node_uuid: Any) -> str:
        hit = conn.execute("SELECT name FROM nodes WHERE uuid = ?", (node_uuid,)).fetchone()
        return hit["name"] if hit else ""

    source_name = _name_of(row["source_node_uuid"])
    target_name = _name_of(row["target_node_uuid"])
    return (source_name, target_name)
|
|
|
|
def _row_to_node(self, row: sqlite3.Row) -> GraphNode:
    """Build a ``GraphNode`` from a ``nodes`` table row.

    JSON columns are decoded (labels default to ``[]``, attributes to
    ``{}``) and a NULL summary becomes an empty string.
    """
    fields = {
        "uuid_": row["uuid"],
        "graph_id": row["graph_id"],
        "name": row["name"],
        "labels": _json_loads(row["labels_json"], []),
        "summary": row["summary"] or "",
        "attributes": _json_loads(row["attributes_json"], {}),
        "created_at": row["created_at"],
    }
    return GraphNode(**fields)
|
|
|
|
def _row_to_edge(self, row: sqlite3.Row, episodes: list[str]) -> GraphEdge:
    """Build a ``GraphEdge`` from an ``edges`` table row.

    ``episodes`` is passed through as the edge's episode list; the
    attributes JSON column is decoded, defaulting to ``{}``.
    """
    fields: dict[str, Any] = {"uuid_": row["uuid"]}
    # Columns whose name matches the GraphEdge field one-to-one.
    for column in ("graph_id", "name", "fact", "source_node_uuid", "target_node_uuid"):
        fields[column] = row[column]
    fields["attributes"] = _json_loads(row["attributes_json"], {})
    for column in ("created_at", "valid_at", "invalid_at", "expired_at"):
        fields[column] = row[column]
    fields["episodes"] = episodes
    return GraphEdge(**fields)
|
|
|
|
def _row_to_episode(self, row: sqlite3.Row) -> GraphEpisode:
    """Build a ``GraphEpisode`` from an ``episodes`` table row.

    ``processed`` is coerced to ``bool`` and the metadata JSON column is
    decoded, defaulting to ``{}``; all other columns are copied as stored.
    """
    fields: dict[str, Any] = {"uuid_": row["uuid"]}
    for column in ("graph_id", "data", "type"):
        fields[column] = row[column]
    fields["processed"] = bool(row["processed"])
    for column in ("created_at", "error"):
        fields[column] = row[column]
    fields["metadata"] = _json_loads(row["metadata_json"], {})
    for column in ("source_description", "role", "role_type", "thread_id", "task_id"):
        fields[column] = row[column]
    return GraphEpisode(**fields)
|
|
|
|
def _merge_summary(self, existing: str, new_value: str, attributes: dict[str, Any]) -> str:
    """Combine a stored summary with a newly extracted one.

    Both inputs are stripped first. If only one is non-empty it wins; if one
    contains the other, the longer (containing) text is kept; otherwise the
    two are joined with a space, existing text first. When both are empty a
    summary is derived from ``attributes`` via ``_fallback_summary``.
    """
    old = (existing or "").strip()
    new = (new_value or "").strip()

    if not old and not new:
        return self._fallback_summary("", attributes)
    if not old:
        return new
    if not new:
        return old

    # Both present: avoid repeating text that is already covered.
    if new in old:
        return old
    if old in new:
        return new
    return f"{old} {new}".strip()
|
|
|
|
def _fallback_summary(self, name: str, attributes: dict[str, Any]) -> str:
    """Derive a short summary from a node's name and attributes.

    Up to four ``key: value`` pairs are listed, prefixed with ``"<name> - "``
    when a name is given. Attributes whose value is ``None`` or an empty
    string/container are skipped, but numeric and boolean values are always
    kept so that meaningful ``0`` / ``False`` values are not dropped.

    Returns:
        The generated summary, or ``name`` (possibly empty) when no attribute
        is usable.
    """

    def _has_content(value: Any) -> bool:
        # bool is a subclass of int, so False and 0 both count as content.
        return isinstance(value, (int, float)) or bool(value)

    pairs = [
        f"{key}: {value}"
        for key, value in (attributes or {}).items()
        if _has_content(value)
    ]
    if pairs:
        prefix = f"{name} - " if name else ""
        return prefix + ", ".join(pairs[:4])
    return name or ""
|