diff --git a/backend/app/graph/graphiti_backend.py b/backend/app/graph/graphiti_backend.py index f185e098..8ce10f97 100644 --- a/backend/app/graph/graphiti_backend.py +++ b/backend/app/graph/graphiti_backend.py @@ -520,6 +520,39 @@ class GraphitiBackend(GraphBackend): ) ) + async def _execute_neo4j_query(self, query: str, parameters: dict = None): + """Execute a raw Cypher query against Neo4j via the sync driver in the shared event loop.""" + result = await asyncio.get_event_loop().run_in_executor( + None, + lambda: self._client.driver.execute_query(query, **({"parameters_": parameters} if parameters else {})) + ) + return result + + async def clone_graph(self, src_group_id: str, dst_group_id: str) -> None: + """Clone all nodes and relationships from src_group_id to dst_group_id.""" + clone_nodes_query = """ + MATCH (n) WHERE n.group_id = $src + WITH n, properties(n) AS props + CREATE (m) SET m = props SET m.group_id = $dst + """ + await self._execute_neo4j_query(clone_nodes_query, {"src": src_group_id, "dst": dst_group_id}) + + clone_rels_query = """ + MATCH (n)-[r]->(m) + WHERE n.group_id = $src AND m.group_id = $src + MATCH (n2 {uuid: n.uuid, group_id: $dst}) + MATCH (m2 {uuid: m.uuid, group_id: $dst}) + CALL apoc.create.relationship(n2, type(r), properties(r), m2) YIELD rel + SET rel.group_id = $dst + RETURN rel + """ + await self._execute_neo4j_query(clone_rels_query, {"src": src_group_id, "dst": dst_group_id}) + + async def delete_graph_async(self, group_id: str) -> None: + """Delete all nodes and relationships for a given group_id (async version).""" + delete_query = "MATCH (n) WHERE n.group_id = $gid DETACH DELETE n" + await self._execute_neo4j_query(delete_query, {"gid": group_id}) + def delete_graph(self, graph_id: str) -> None: _run_async( self._client.driver.execute_query( diff --git a/backend/tests/test_graph_clone.py b/backend/tests/test_graph_clone.py new file mode 100644 index 00000000..82e3802f --- /dev/null +++ b/backend/tests/test_graph_clone.py @@ -0,0 +1,38 @@ +# backend/tests/test_graph_clone.py +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import asyncio + +def test_clone_graph_executes_two_queries(): + """clone_graph should run exactly 2 Cypher queries: one for nodes, one for relationships.""" + from backend.app.graph.graphiti_backend import GraphitiBackend + + backend = GraphitiBackend.__new__(GraphitiBackend) + + executed_queries = [] + async def fake_execute_query(query, parameters=None, **kwargs): + executed_queries.append(query) + return [] + + backend._execute_neo4j_query = fake_execute_query + asyncio.run(backend.clone_graph("src_group", "dst_group")) + + assert len(executed_queries) == 2 + combined = " ".join(executed_queries).lower() + assert "group_id" in combined + +def test_delete_graph_executes_detach_delete(): + """delete_graph should run a DETACH DELETE query.""" + from backend.app.graph.graphiti_backend import GraphitiBackend + + backend = GraphitiBackend.__new__(GraphitiBackend) + + executed_queries = [] + async def fake_execute_query(query, parameters=None, **kwargs): + executed_queries.append(query) + return [] + + backend._execute_neo4j_query = fake_execute_query + asyncio.run(backend.delete_graph_async("group_to_delete")) + + assert any("DETACH DELETE" in q for q in executed_queries)