From ef46ba174367a9b0d75b6ba6db95bdc884d164eb Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 3 May 2026 21:54:34 +0000 Subject: [PATCH] feat(simulation): add max_agents selector via top-N connectivity in prepare Co-Authored-By: Claude Sonnet 4.6 --- backend/app/api/simulation.py | 2 ++ backend/app/services/simulation_manager.py | 29 ++++++++++++--- backend/app/services/zep_entity_reader.py | 21 +++++++++++ backend/tests/test_zep_entity_reader.py | 42 ++++++++++++++++++++++ 4 files changed, 89 insertions(+), 5 deletions(-) create mode 100644 backend/tests/test_zep_entity_reader.py diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 348ccba5..aa72b61d 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -448,6 +448,7 @@ def prepare_simulation(): document_text = ProjectManager.get_extracted_text(state.project_id, get_storage()) or "" entity_types_list = data.get('entity_types') + max_agents = data.get('max_agents') # optional: limit to top-N most-connected entities use_llm_for_profiles = data.get('use_llm_for_profiles', True) parallel_profile_count = data.get('parallel_profile_count', 5) @@ -568,6 +569,7 @@ def prepare_simulation(): simulation_requirement=simulation_requirement, document_text=document_text, defined_entity_types=entity_types_list, + max_agents=max_agents, use_llm_for_profiles=use_llm_for_profiles, progress_callback=progress_callback, parallel_profile_count=parallel_profile_count diff --git a/backend/app/services/simulation_manager.py b/backend/app/services/simulation_manager.py index e14fed6f..f9857c3e 100644 --- a/backend/app/services/simulation_manager.py +++ b/backend/app/services/simulation_manager.py @@ -245,6 +245,7 @@ class SimulationManager: simulation_requirement: str, document_text: str, defined_entity_types: Optional[List[str]] = None, + max_agents: Optional[int] = None, use_llm_for_profiles: bool = True, progress_callback: Optional[callable] = None, parallel_profile_count: int = 3 @@ -290,11 +291,29 @@ class SimulationManager: if progress_callback: progress_callback("reading", 30, t('progress.readingNodeData')) - filtered = reader.filter_defined_entities( - graph_id=state.graph_id, - defined_entity_types=defined_entity_types, - enrich_with_edges=True - ) + if max_agents is not None and max_agents > 0: + top_entities = reader.get_entities_by_connectivity( + graph_id=state.graph_id, + max_n=max_agents, + defined_entity_types=defined_entity_types, + ) + entity_types_found = set() + for e in top_entities: + et = e.get_entity_type() + if et: + entity_types_found.add(et) + filtered = FilteredEntities( + entities=top_entities, + entity_types=entity_types_found, + total_count=len(top_entities), + filtered_count=len(top_entities), + ) + else: + filtered = reader.filter_defined_entities( + graph_id=state.graph_id, + defined_entity_types=defined_entity_types, + enrich_with_edges=True + ) state.entities_count = filtered.filtered_count state.entity_types = list(filtered.entity_types) diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 30a5fef2..162b78e0 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -360,6 +360,27 @@ class ZepEntityReader: logger.error(f"Failed to get entity {entity_uuid}: {str(e)}") return None + def get_entities_by_connectivity( + self, + graph_id: str, + max_n: Optional[int] = None, + defined_entity_types: Optional[List[str]] = None, + ) -> List[EntityNode]: + """Return entities sorted by edge degree (descending), optionally capped at max_n.""" + filtered = self.filter_defined_entities( + graph_id=graph_id, + defined_entity_types=defined_entity_types, + enrich_with_edges=True, + ) + entities = sorted( + filtered.entities, + key=lambda e: len(e.related_edges), + reverse=True, + ) + if max_n is not None and max_n > 0: + entities = entities[:max_n] + return entities + def get_entities_by_type( self, graph_id: str, diff --git a/backend/tests/test_zep_entity_reader.py b/backend/tests/test_zep_entity_reader.py new file mode 100644 index 00000000..63b7ff38 --- /dev/null +++ b/backend/tests/test_zep_entity_reader.py @@ -0,0 +1,42 @@ +# backend/tests/test_zep_entity_reader.py +import pytest +from unittest.mock import patch, MagicMock +from backend.app.services.zep_entity_reader import ZepEntityReader, EntityNode + +def _make_entity(uuid, name, edge_count): + e = EntityNode(uuid=uuid, name=name, labels=["Person", "Entity"], summary="", attributes={}) + e.related_edges = [{}] * edge_count # simulate edge_count edges + return e + +def test_get_entities_by_connectivity_returns_top_n(): + reader = ZepEntityReader.__new__(ZepEntityReader) + entities = [ + _make_entity("u1", "Alice", 10), + _make_entity("u2", "Bob", 3), + _make_entity("u3", "Carol", 7), + _make_entity("u4", "Dave", 1), + _make_entity("u5", "Eve", 5), + ] + + with patch.object(reader, 'filter_defined_entities') as mock_filter: + from backend.app.services.zep_entity_reader import FilteredEntities + mock_filter.return_value = FilteredEntities( + entities=entities, entity_types=set(), total_count=5, filtered_count=5 + ) + result = reader.get_entities_by_connectivity(graph_id="g1", max_n=3) + + assert len(result) == 3 + assert result[0].name == "Alice" # 10 edges — top + assert result[1].name == "Carol" # 7 edges + assert result[2].name == "Eve" # 5 edges + +def test_get_entities_by_connectivity_no_limit(): + reader = ZepEntityReader.__new__(ZepEntityReader) + entities = [_make_entity(f"u{i}", f"E{i}", i) for i in range(5)] + with patch.object(reader, 'filter_defined_entities') as mock_filter: + from backend.app.services.zep_entity_reader import FilteredEntities + mock_filter.return_value = FilteredEntities( + entities=entities, entity_types=set(), total_count=5, filtered_count=5 + ) + result = reader.get_entities_by_connectivity(graph_id="g1", max_n=None) + assert len(result) == 5