feat(simulation): add max_agents selector via top-N connectivity in prepare
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a74c0975fe
commit
ef46ba1743
|
|
@ -448,6 +448,7 @@ def prepare_simulation():
|
|||
document_text = ProjectManager.get_extracted_text(state.project_id, get_storage()) or ""
|
||||
|
||||
entity_types_list = data.get('entity_types')
|
||||
max_agents = data.get('max_agents') # optional: limit to top-N most-connected entities
|
||||
use_llm_for_profiles = data.get('use_llm_for_profiles', True)
|
||||
parallel_profile_count = data.get('parallel_profile_count', 5)
|
||||
|
||||
|
|
@ -568,6 +569,7 @@ def prepare_simulation():
|
|||
simulation_requirement=simulation_requirement,
|
||||
document_text=document_text,
|
||||
defined_entity_types=entity_types_list,
|
||||
max_agents=max_agents,
|
||||
use_llm_for_profiles=use_llm_for_profiles,
|
||||
progress_callback=progress_callback,
|
||||
parallel_profile_count=parallel_profile_count
|
||||
|
|
|
|||
|
|
@ -245,6 +245,7 @@ class SimulationManager:
|
|||
simulation_requirement: str,
|
||||
document_text: str,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
max_agents: Optional[int] = None,
|
||||
use_llm_for_profiles: bool = True,
|
||||
progress_callback: Optional[callable] = None,
|
||||
parallel_profile_count: int = 3
|
||||
|
|
@ -290,11 +291,29 @@ class SimulationManager:
|
|||
if progress_callback:
|
||||
progress_callback("reading", 30, t('progress.readingNodeData'))
|
||||
|
||||
filtered = reader.filter_defined_entities(
|
||||
graph_id=state.graph_id,
|
||||
defined_entity_types=defined_entity_types,
|
||||
enrich_with_edges=True
|
||||
)
|
||||
if max_agents is not None and max_agents > 0:
|
||||
top_entities = reader.get_entities_by_connectivity(
|
||||
graph_id=state.graph_id,
|
||||
max_n=max_agents,
|
||||
defined_entity_types=defined_entity_types,
|
||||
)
|
||||
entity_types_found = set()
|
||||
for e in top_entities:
|
||||
et = e.get_entity_type()
|
||||
if et:
|
||||
entity_types_found.add(et)
|
||||
filtered = FilteredEntities(
|
||||
entities=top_entities,
|
||||
entity_types=entity_types_found,
|
||||
total_count=len(top_entities),
|
||||
filtered_count=len(top_entities),
|
||||
)
|
||||
else:
|
||||
filtered = reader.filter_defined_entities(
|
||||
graph_id=state.graph_id,
|
||||
defined_entity_types=defined_entity_types,
|
||||
enrich_with_edges=True
|
||||
)
|
||||
|
||||
state.entities_count = filtered.filtered_count
|
||||
state.entity_types = list(filtered.entity_types)
|
||||
|
|
|
|||
|
|
@ -360,6 +360,27 @@ class ZepEntityReader:
|
|||
logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_entities_by_connectivity(
|
||||
self,
|
||||
graph_id: str,
|
||||
max_n: Optional[int] = None,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
) -> List[EntityNode]:
|
||||
"""Return entities sorted by edge degree (descending), optionally capped at max_n."""
|
||||
filtered = self.filter_defined_entities(
|
||||
graph_id=graph_id,
|
||||
defined_entity_types=defined_entity_types,
|
||||
enrich_with_edges=True,
|
||||
)
|
||||
entities = sorted(
|
||||
filtered.entities,
|
||||
key=lambda e: len(e.related_edges),
|
||||
reverse=True,
|
||||
)
|
||||
if max_n is not None and max_n > 0:
|
||||
entities = entities[:max_n]
|
||||
return entities
|
||||
|
||||
def get_entities_by_type(
|
||||
self,
|
||||
graph_id: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
# backend/tests/test_zep_entity_reader.py
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from backend.app.services.zep_entity_reader import ZepEntityReader, EntityNode
|
||||
|
||||
def _make_entity(uuid, name, edge_count):
|
||||
e = EntityNode(uuid=uuid, name=name, labels=["Person", "Entity"], summary="", attributes={})
|
||||
e.related_edges = [{}] * edge_count # simulate edge_count edges
|
||||
return e
|
||||
|
||||
def test_get_entities_by_connectivity_returns_top_n():
|
||||
reader = ZepEntityReader.__new__(ZepEntityReader)
|
||||
entities = [
|
||||
_make_entity("u1", "Alice", 10),
|
||||
_make_entity("u2", "Bob", 3),
|
||||
_make_entity("u3", "Carol", 7),
|
||||
_make_entity("u4", "Dave", 1),
|
||||
_make_entity("u5", "Eve", 5),
|
||||
]
|
||||
|
||||
with patch.object(reader, 'filter_defined_entities') as mock_filter:
|
||||
from backend.app.services.zep_entity_reader import FilteredEntities
|
||||
mock_filter.return_value = FilteredEntities(
|
||||
entities=entities, entity_types=set(), total_count=5, filtered_count=5
|
||||
)
|
||||
result = reader.get_entities_by_connectivity(graph_id="g1", max_n=3)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "Alice" # 10 edges — top
|
||||
assert result[1].name == "Carol" # 7 edges
|
||||
assert result[2].name == "Eve" # 5 edges
|
||||
|
||||
def test_get_entities_by_connectivity_no_limit():
|
||||
reader = ZepEntityReader.__new__(ZepEntityReader)
|
||||
entities = [_make_entity(f"u{i}", f"E{i}", i) for i in range(5)]
|
||||
with patch.object(reader, 'filter_defined_entities') as mock_filter:
|
||||
from backend.app.services.zep_entity_reader import FilteredEntities
|
||||
mock_filter.return_value = FilteredEntities(
|
||||
entities=entities, entity_types=set(), total_count=5, filtered_count=5
|
||||
)
|
||||
result = reader.get_entities_by_connectivity(graph_id="g1", max_n=None)
|
||||
assert len(result) == 5
|
||||
Loading…
Reference in New Issue