feat(simulation): add max_agents selector via top-N connectivity in prepare
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a74c0975fe
commit
ef46ba1743
|
|
@ -448,6 +448,7 @@ def prepare_simulation():
|
||||||
document_text = ProjectManager.get_extracted_text(state.project_id, get_storage()) or ""
|
document_text = ProjectManager.get_extracted_text(state.project_id, get_storage()) or ""
|
||||||
|
|
||||||
entity_types_list = data.get('entity_types')
|
entity_types_list = data.get('entity_types')
|
||||||
|
max_agents = data.get('max_agents') # optional: limit to top-N most-connected entities
|
||||||
use_llm_for_profiles = data.get('use_llm_for_profiles', True)
|
use_llm_for_profiles = data.get('use_llm_for_profiles', True)
|
||||||
parallel_profile_count = data.get('parallel_profile_count', 5)
|
parallel_profile_count = data.get('parallel_profile_count', 5)
|
||||||
|
|
||||||
|
|
@ -568,6 +569,7 @@ def prepare_simulation():
|
||||||
simulation_requirement=simulation_requirement,
|
simulation_requirement=simulation_requirement,
|
||||||
document_text=document_text,
|
document_text=document_text,
|
||||||
defined_entity_types=entity_types_list,
|
defined_entity_types=entity_types_list,
|
||||||
|
max_agents=max_agents,
|
||||||
use_llm_for_profiles=use_llm_for_profiles,
|
use_llm_for_profiles=use_llm_for_profiles,
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
parallel_profile_count=parallel_profile_count
|
parallel_profile_count=parallel_profile_count
|
||||||
|
|
|
||||||
|
|
@ -245,6 +245,7 @@ class SimulationManager:
|
||||||
simulation_requirement: str,
|
simulation_requirement: str,
|
||||||
document_text: str,
|
document_text: str,
|
||||||
defined_entity_types: Optional[List[str]] = None,
|
defined_entity_types: Optional[List[str]] = None,
|
||||||
|
max_agents: Optional[int] = None,
|
||||||
use_llm_for_profiles: bool = True,
|
use_llm_for_profiles: bool = True,
|
||||||
progress_callback: Optional[callable] = None,
|
progress_callback: Optional[callable] = None,
|
||||||
parallel_profile_count: int = 3
|
parallel_profile_count: int = 3
|
||||||
|
|
@ -290,11 +291,29 @@ class SimulationManager:
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback("reading", 30, t('progress.readingNodeData'))
|
progress_callback("reading", 30, t('progress.readingNodeData'))
|
||||||
|
|
||||||
filtered = reader.filter_defined_entities(
|
if max_agents is not None and max_agents > 0:
|
||||||
graph_id=state.graph_id,
|
top_entities = reader.get_entities_by_connectivity(
|
||||||
defined_entity_types=defined_entity_types,
|
graph_id=state.graph_id,
|
||||||
enrich_with_edges=True
|
max_n=max_agents,
|
||||||
)
|
defined_entity_types=defined_entity_types,
|
||||||
|
)
|
||||||
|
entity_types_found = set()
|
||||||
|
for e in top_entities:
|
||||||
|
et = e.get_entity_type()
|
||||||
|
if et:
|
||||||
|
entity_types_found.add(et)
|
||||||
|
filtered = FilteredEntities(
|
||||||
|
entities=top_entities,
|
||||||
|
entity_types=entity_types_found,
|
||||||
|
total_count=len(top_entities),
|
||||||
|
filtered_count=len(top_entities),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filtered = reader.filter_defined_entities(
|
||||||
|
graph_id=state.graph_id,
|
||||||
|
defined_entity_types=defined_entity_types,
|
||||||
|
enrich_with_edges=True
|
||||||
|
)
|
||||||
|
|
||||||
state.entities_count = filtered.filtered_count
|
state.entities_count = filtered.filtered_count
|
||||||
state.entity_types = list(filtered.entity_types)
|
state.entity_types = list(filtered.entity_types)
|
||||||
|
|
|
||||||
|
|
@ -360,6 +360,27 @@ class ZepEntityReader:
|
||||||
logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
|
logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_entities_by_connectivity(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
max_n: Optional[int] = None,
|
||||||
|
defined_entity_types: Optional[List[str]] = None,
|
||||||
|
) -> List[EntityNode]:
|
||||||
|
"""Return entities sorted by edge degree (descending), optionally capped at max_n."""
|
||||||
|
filtered = self.filter_defined_entities(
|
||||||
|
graph_id=graph_id,
|
||||||
|
defined_entity_types=defined_entity_types,
|
||||||
|
enrich_with_edges=True,
|
||||||
|
)
|
||||||
|
entities = sorted(
|
||||||
|
filtered.entities,
|
||||||
|
key=lambda e: len(e.related_edges),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
if max_n is not None and max_n > 0:
|
||||||
|
entities = entities[:max_n]
|
||||||
|
return entities
|
||||||
|
|
||||||
def get_entities_by_type(
|
def get_entities_by_type(
|
||||||
self,
|
self,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
# backend/tests/test_zep_entity_reader.py
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from backend.app.services.zep_entity_reader import ZepEntityReader, EntityNode
|
||||||
|
|
||||||
|
def _make_entity(uuid, name, edge_count):
|
||||||
|
e = EntityNode(uuid=uuid, name=name, labels=["Person", "Entity"], summary="", attributes={})
|
||||||
|
e.related_edges = [{}] * edge_count # simulate edge_count edges
|
||||||
|
return e
|
||||||
|
|
||||||
|
def test_get_entities_by_connectivity_returns_top_n():
|
||||||
|
reader = ZepEntityReader.__new__(ZepEntityReader)
|
||||||
|
entities = [
|
||||||
|
_make_entity("u1", "Alice", 10),
|
||||||
|
_make_entity("u2", "Bob", 3),
|
||||||
|
_make_entity("u3", "Carol", 7),
|
||||||
|
_make_entity("u4", "Dave", 1),
|
||||||
|
_make_entity("u5", "Eve", 5),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(reader, 'filter_defined_entities') as mock_filter:
|
||||||
|
from backend.app.services.zep_entity_reader import FilteredEntities
|
||||||
|
mock_filter.return_value = FilteredEntities(
|
||||||
|
entities=entities, entity_types=set(), total_count=5, filtered_count=5
|
||||||
|
)
|
||||||
|
result = reader.get_entities_by_connectivity(graph_id="g1", max_n=3)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result[0].name == "Alice" # 10 edges — top
|
||||||
|
assert result[1].name == "Carol" # 7 edges
|
||||||
|
assert result[2].name == "Eve" # 5 edges
|
||||||
|
|
||||||
|
def test_get_entities_by_connectivity_no_limit():
|
||||||
|
reader = ZepEntityReader.__new__(ZepEntityReader)
|
||||||
|
entities = [_make_entity(f"u{i}", f"E{i}", i) for i in range(5)]
|
||||||
|
with patch.object(reader, 'filter_defined_entities') as mock_filter:
|
||||||
|
from backend.app.services.zep_entity_reader import FilteredEntities
|
||||||
|
mock_filter.return_value = FilteredEntities(
|
||||||
|
entities=entities, entity_types=set(), total_count=5, filtered_count=5
|
||||||
|
)
|
||||||
|
result = reader.get_entities_by_connectivity(graph_id="g1", max_n=None)
|
||||||
|
assert len(result) == 5
|
||||||
Loading…
Reference in New Issue