Merge 9c1706f71d into 96096ea0ff
This commit is contained in:
commit
c1affb108b
|
|
@ -664,6 +664,12 @@ SECTION_SYSTEM_PROMPT_TEMPLATE = """\
|
|||
- 不要添加模拟中不存在的信息
|
||||
- 如果某方面信息不足,如实说明
|
||||
|
||||
5. 【禁止捏造数据】
|
||||
- ❌ 禁止捏造用户名、引用、统计数字或互动数据
|
||||
- ❌ 禁止在回复中包含 <tool_result> 块 — 只有系统会提供工具结果
|
||||
- ✅ 只能引用真实出现在工具结果中的实体、引用和数据
|
||||
- 如果工具结果中没有相关内容,应如实说明,而非编造
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
【⚠️ 格式规范 - 极其重要!】
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
|
@ -1134,6 +1140,25 @@ class ReportAgent:
|
|||
desc_parts.append(f" 参数: {params_desc}")
|
||||
return "\n".join(desc_parts)
|
||||
|
||||
@staticmethod
|
||||
def _strip_fake_tool_results(response: str) -> str:
|
||||
"""Strip any <tool_result> blocks the LLM fabricated in its response.
|
||||
|
||||
When the LLM generates a <tool_call> block and then continues to generate
|
||||
a <tool_result> block in the same response, we must strip the fake result
|
||||
before appending to message history. The real tool result will be injected
|
||||
separately by the system.
|
||||
"""
|
||||
import re
|
||||
cleaned = re.sub(
|
||||
r'<tool_result>.*?</tool_result>',
|
||||
'',
|
||||
response,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def plan_outline(
|
||||
self,
|
||||
progress_callback: Optional[Callable] = None
|
||||
|
|
@ -1335,7 +1360,8 @@ class ReportAgent:
|
|||
|
||||
if conflict_retries <= 2:
|
||||
# 前两次:丢弃本次响应,要求 LLM 重新回复
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||
messages.append({"role": "assistant", "content": cleaned_response})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": (
|
||||
|
|
@ -1375,7 +1401,8 @@ class ReportAgent:
|
|||
if has_final_answer:
|
||||
# 工具调用次数不足,拒绝并要求继续调工具
|
||||
if tool_calls_count < min_tool_calls:
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||
messages.append({"role": "assistant", "content": cleaned_response})
|
||||
unused_tools = all_tools - used_tools
|
||||
unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else ""
|
||||
messages.append({
|
||||
|
|
@ -1451,9 +1478,10 @@ class ReportAgent:
|
|||
unused_tools = all_tools - used_tools
|
||||
unused_hint = ""
|
||||
if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION:
|
||||
unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools))
|
||||
# Rebind `unused_hint` (initialised to "" just above); assigning to any other
# name would leave the hint permanently empty.
unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools))
|
||||
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||
messages.append({"role": "assistant", "content": cleaned_response})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": REACT_OBSERVATION_TEMPLATE.format(
|
||||
|
|
@ -1857,7 +1885,8 @@ class ReportAgent:
|
|||
tool_calls_made.append(call)
|
||||
|
||||
# 将结果添加到消息
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||
messages.append({"role": "assistant", "content": cleaned_response})
|
||||
observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results])
|
||||
messages.append({
|
||||
"role": "user",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts = -v --tb=short
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Tests package marker"""
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
"""Unit tests for file_parser module."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the module directly to avoid Flask initialization issues
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Direct import of the module to avoid app package initialization
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"file_parser",
|
||||
Path(__file__).parent.parent / "app" / "utils" / "file_parser.py"
|
||||
)
|
||||
file_parser_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(file_parser_module)
|
||||
|
||||
_read_text_with_fallback = file_parser_module._read_text_with_fallback
|
||||
FileParser = file_parser_module.FileParser
|
||||
split_text_into_chunks = file_parser_module.split_text_into_chunks
|
||||
|
||||
|
||||
class TestReadTextWithFallback:
    """Tests for the ``_read_text_with_fallback`` function.

    Every test writes its input below pytest's ``tmp_path`` fixture, so the
    files are cleaned up by pytest instead of manual ``os.unlink`` calls.
    """

    def test_read_utf8_file(self, tmp_path):
        """Should read a UTF-8 encoded file correctly."""
        expected = "Hello, 你好, こんにちは"
        target = tmp_path / "utf8.txt"
        target.write_text(expected, encoding='utf-8')

        assert _read_text_with_fallback(str(target)) == expected

    def test_read_gbk_file_with_fallback(self, tmp_path):
        """Should read a GBK encoded file without raising."""
        target = tmp_path / "gbk.txt"
        target.write_bytes("你好世界".encode('gbk'))

        result = _read_text_with_fallback(str(target))

        # The text may be garbled if charset detection fails; the contract
        # checked here is only that reading does not raise and yields text.
        assert len(result) > 0

    def test_read_latin1_file(self, tmp_path):
        """Should read a Latin-1 encoded file without raising."""
        target = tmp_path / "latin1.txt"
        target.write_bytes("Héllo Wörld".encode('latin-1'))

        result = _read_text_with_fallback(str(target))

        # The text may be garbled if charset detection fails; the contract
        # checked here is only that reading does not raise and yields text.
        assert len(result) > 0

    def test_read_file_with_replacement_chars(self, tmp_path):
        """Should replace invalid bytes instead of failing."""
        target = tmp_path / "invalid.txt"
        target.write_bytes(b"Hello\x00\xff\xfeWorld")

        result = _read_text_with_fallback(str(target))

        # The valid ASCII on both sides of the invalid bytes must survive.
        assert "Hello" in result
        assert "World" in result
|
||||
|
||||
|
||||
class TestFileParser:
    """Tests for the ``FileParser`` class.

    Input files are created below pytest's ``tmp_path`` fixture, so they are
    removed by pytest and "missing" paths are guaranteed not to exist.
    """

    def test_supported_extensions(self):
        """Should list every supported extension."""
        for extension in ('.pdf', '.md', '.markdown', '.txt'):
            assert extension in FileParser.SUPPORTED_EXTENSIONS

    def test_extract_text_from_nonexistent_file(self, tmp_path):
        """Should raise FileNotFoundError for a nonexistent file."""
        missing = tmp_path / "missing.txt"

        with pytest.raises(FileNotFoundError):
            FileParser.extract_text(str(missing))

    def test_extract_text_from_unsupported_format(self, tmp_path):
        """Should raise ValueError for an unsupported format."""
        # The file exists, so only the unknown extension can be rejected.
        unsupported = tmp_path / "data.xyz"
        unsupported.touch()

        with pytest.raises(ValueError, match="不支持的文件格式"):
            FileParser.extract_text(str(unsupported))

    def test_extract_text_from_md_file(self, tmp_path):
        """Should extract text from a markdown file."""
        source = tmp_path / "doc.md"
        source.write_text("# Title\n\nThis is content.", encoding='utf-8')

        result = FileParser.extract_text(str(source))

        assert "# Title" in result
        assert "This is content." in result

    def test_extract_text_from_txt_file(self, tmp_path):
        """Should extract text from a txt file."""
        source = tmp_path / "doc.txt"
        source.write_text("Plain text content", encoding='utf-8')

        assert FileParser.extract_text(str(source)) == "Plain text content"

    def test_extract_from_multiple_with_all_valid(self, tmp_path):
        """Should extract from multiple valid files."""
        files = []
        for i in range(3):
            source = tmp_path / f"doc{i}.txt"
            source.write_text(f"Content {i}", encoding='utf-8')
            files.append(str(source))

        result = FileParser.extract_from_multiple(files)

        for i in range(3):
            assert f"Content {i}" in result
        # Each document is expected to be announced by a numbered header.
        assert "文档 1" in result
        assert "文档 2" in result
        assert "文档 3" in result

    def test_extract_from_multiple_with_invalid_file(self, tmp_path):
        """Should handle an invalid file gracefully in batch mode."""
        missing = tmp_path / "missing.txt"
        valid = tmp_path / "valid.txt"
        valid.write_text("Valid content", encoding='utf-8')

        # The broken path comes first, as a failure must not abort the batch.
        result = FileParser.extract_from_multiple([str(missing), str(valid)])

        assert "Valid content" in result
        assert "提取失败" in result
|
||||
|
||||
|
||||
class TestSplitTextIntoChunks:
    """Tests for the ``split_text_into_chunks`` function."""

    def test_short_text_returns_single_chunk(self):
        """Should return a single chunk when text is shorter than chunk_size."""
        text = "Short text"

        result = split_text_into_chunks(text, chunk_size=500, overlap=50)

        assert len(result) == 1
        assert result[0] == text

    def test_empty_text_returns_empty_list(self):
        """Should return an empty list for empty/whitespace text."""
        assert split_text_into_chunks("   ", chunk_size=500, overlap=50) == []
        assert split_text_into_chunks("", chunk_size=500, overlap=50) == []

    def test_text_exactly_at_chunk_size(self):
        """Should return a single chunk when text length equals chunk_size."""
        text = "a" * 500

        result = split_text_into_chunks(text, chunk_size=500, overlap=50)

        assert len(result) == 1

    def test_long_text_splits_into_multiple_chunks(self):
        """Should split long text into multiple chunks."""
        text = "a" * 1000

        result = split_text_into_chunks(text, chunk_size=500, overlap=50)

        assert len(result) >= 2

    def test_chunks_have_overlap(self):
        """Should have overlapping content between consecutive chunks."""
        text = "abcdefghij" * 100

        chunks = split_text_into_chunks(text, chunk_size=100, overlap=20)

        # Assert the precondition instead of guarding with ``if``; a guarded
        # check would pass silently if the text were never split.
        assert len(chunks) >= 2
        assert chunks[0][-20:] == chunks[1][:20], "Chunks should overlap"

    def test_chunks_preserve_content(self):
        """Should preserve original content across chunks."""
        text = "".join(str(i) for i in range(500))

        chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
        combined = "".join(chunks)

        # NOTE(review): the second alternative only checks that each character
        # occurs somewhere in the output, which is a very weak guarantee.
        # Tighten it once the exact overlap layout is pinned down.
        assert text[:400] in combined or all(c in combined for c in text[:400])

    def test_chunk_size_parameter(self):
        """Should respect the chunk_size parameter."""
        text = "a" * 1000

        result = split_text_into_chunks(text, chunk_size=100, overlap=0)

        for chunk in result:
            assert len(chunk) <= 100

    def test_overlap_parameter(self):
        """Should respect the overlap parameter."""
        text = "abcdefghij" * 100

        chunks = split_text_into_chunks(text, chunk_size=50, overlap=10)

        # Compare the tail of the first chunk with the head of the second one,
        # mirroring test_chunks_have_overlap with a different overlap width.
        assert len(chunks) >= 2
        assert chunks[0][-10:] == chunks[1][:10], "Chunks should overlap"

    def test_split_at_sentence_boundary(self):
        """Should try to split at sentence boundaries when possible."""
        text = "第一句。第二句。第三句。" * 50

        chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)

        # Short chunks are exempt from the boundary check.
        for chunk in chunks:
            if len(chunk) > 50:
                assert chunk[-1] in "。.\n"

    def test_last_chunk_may_be_smaller(self):
        """Should allow the last chunk to be smaller than chunk_size."""
        text = "a" * 550

        result = split_text_into_chunks(text, chunk_size=500, overlap=50)

        assert any(len(chunk) < 500 for chunk in result)

    def test_whitespace_only_chunks_filtered(self):
        """Should filter out whitespace-only chunks."""
        text = "content" + " " * 600 + "more content"

        result = split_text_into_chunks(text, chunk_size=500, overlap=50)

        for chunk in result:
            assert chunk.strip()
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
"""Unit tests for zep_paging module."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the module directly to avoid Flask/Zep initialization issues
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "app" / "utils"))
|
||||
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"zep_paging",
|
||||
Path(__file__).parent.parent / "app" / "utils" / "zep_paging.py"
|
||||
)
|
||||
zep_paging_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(zep_paging_module)
|
||||
|
||||
fetch_all_nodes = zep_paging_module.fetch_all_nodes
|
||||
fetch_all_edges = zep_paging_module.fetch_all_edges
|
||||
|
||||
|
||||
class TestFetchPageWithRetry:
    """Retry behaviour of page fetching, exercised through fetch_all_nodes."""

    @staticmethod
    def _node_client(**fetch_config):
        """Return a mock client whose node fetch is configured by keyword."""
        client = MagicMock()
        client.graph.node.get_by_graph_id.configure_mock(**fetch_config)
        return client

    def test_success_on_first_attempt(self):
        """A successful first call returns its result without retrying."""
        nodes = [SimpleNamespace(uuid_="abc-123", name="TestNode")]
        client = self._node_client(return_value=nodes)

        fetched = fetch_all_nodes(client, "graph-id")

        assert len(fetched) == 1
        client.graph.node.get_by_graph_id.assert_called_once()

    def test_retries_on_transient_error(self):
        """Transient errors are retried until a call succeeds."""
        nodes = [SimpleNamespace(uuid_="abc-123", name="TestNode")]
        # Two failures followed by one success.
        outcomes = [
            ConnectionError("first failure"),
            TimeoutError("second failure"),
            nodes,
        ]
        client = self._node_client(side_effect=outcomes)

        # Sleeping is patched out so the retry delay does not slow the test.
        with patch("time.sleep"):
            fetched = fetch_all_nodes(client, "graph-id")

        assert len(fetched) == 1
        assert client.graph.node.get_by_graph_id.call_count == 3

    def test_exhausts_retries_and_raises(self):
        """The error propagates once max_retries attempts have failed."""
        client = self._node_client(side_effect=ConnectionError("always fails"))

        with patch("time.sleep"):
            with pytest.raises(ConnectionError):
                fetch_all_nodes(client, "graph-id", max_retries=3)

        assert client.graph.node.get_by_graph_id.call_count == 3

    def test_respects_max_retries_parameter(self):
        """The number of attempts follows the max_retries argument."""
        client = self._node_client(side_effect=ConnectionError("always fails"))

        with patch("time.sleep"):
            with pytest.raises(ConnectionError):
                fetch_all_nodes(client, "graph-id", max_retries=5)

        assert client.graph.node.get_by_graph_id.call_count == 5
|
||||
|
||||
|
||||
class TestFetchAllNodes:
    """Tests for the ``fetch_all_nodes`` function."""

    def test_returns_empty_list_when_no_nodes(self):
        """Should return an empty list when the graph has no nodes."""
        mock_client = MagicMock()
        mock_client.graph.node.get_by_graph_id.return_value = []

        result = fetch_all_nodes(mock_client, "graph-id")

        assert result == []

    def test_returns_all_nodes_single_page(self):
        """Should return all nodes when they fit in one page."""
        mock_client = MagicMock()
        batch = [
            SimpleNamespace(uuid_="n1", name="Node1"),
            SimpleNamespace(uuid_="n2", name="Node2"),
        ]
        mock_client.graph.node.get_by_graph_id.return_value = batch

        result = fetch_all_nodes(mock_client, "graph-id")

        assert len(result) == 2

    def test_paginates_multiple_pages(self):
        """Should request a further page after a full first page."""
        mock_client = MagicMock()
        page1 = [
            SimpleNamespace(uuid_="n1", name="Node1"),
            SimpleNamespace(uuid_="n2", name="Node2"),
        ]
        page2 = [SimpleNamespace(uuid_="n3", name="Node3")]
        mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2]

        # page_size must equal len(page1): a page shorter than page_size ends
        # pagination (see test_stops_when_page_smaller_than_page_size), so
        # with the default limit the second page would never be requested.
        result = fetch_all_nodes(mock_client, "graph-id", page_size=2)

        assert len(result) == 3
        assert mock_client.graph.node.get_by_graph_id.call_count == 2

    def test_respects_max_items_limit(self):
        """Should stop and truncate when the max_items limit is reached."""
        mock_client = MagicMock()
        # Full pages of two items each; the limit of three falls inside page2.
        page1 = [
            SimpleNamespace(uuid_="n1", name="Node1"),
            SimpleNamespace(uuid_="n2", name="Node2"),
        ]
        page2 = [
            SimpleNamespace(uuid_="n3", name="Node3"),
            SimpleNamespace(uuid_="n4", name="Node4"),
        ]
        page3 = [
            SimpleNamespace(uuid_="n5", name="Node5"),
        ]
        mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2, page3]

        # page_size=2 makes page1 a full page, so pagination continues into
        # page2, where the limit is hit.
        result = fetch_all_nodes(mock_client, "graph-id", page_size=2, max_items=3)

        assert len(result) == 3
        assert result[0].name == "Node1"
        assert result[1].name == "Node2"
        assert result[2].name == "Node3"

    def test_uses_default_max_nodes_constant(self):
        """Should work without an explicit max_items argument."""
        mock_client = MagicMock()
        mock_client.graph.node.get_by_graph_id.return_value = []

        result = fetch_all_nodes(mock_client, "graph-id")

        # NOTE(review): with an empty page the default limit is never reached,
        # so this only verifies that the call succeeds with defaults.
        assert result == []

    def test_stops_when_page_smaller_than_page_size(self):
        """Should stop pagination when the page is smaller than page_size."""
        mock_client = MagicMock()
        page1 = [SimpleNamespace(uuid_="n1", name="Node1")]
        mock_client.graph.node.get_by_graph_id.return_value = page1

        result = fetch_all_nodes(mock_client, "graph-id")

        assert len(result) == 1
        assert mock_client.graph.node.get_by_graph_id.call_count == 1

    def test_handles_missing_uuid_field_gracefully(self):
        """Should not fail when a node is missing the uuid field."""
        mock_client = MagicMock()
        # First page is well formed.
        page1 = [
            SimpleNamespace(uuid_="n1", name="Node1"),
            SimpleNamespace(uuid_="n2", name="Node2"),
        ]
        # Second page holds a node without uuid_.
        page2 = [
            SimpleNamespace(name="Node3"),  # No uuid_
        ]
        mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2]

        result = fetch_all_nodes(mock_client, "graph-id")

        # NOTE(review): page1 is shorter than the default page size, so
        # pagination presumably ends before page2 is requested and the
        # missing-uuid path is not exercised here - confirm and tighten.
        assert len(result) == 2

    def test_uses_default_page_size(self):
        """Should pass limit=100 by default."""
        mock_client = MagicMock()
        mock_client.graph.node.get_by_graph_id.return_value = []

        fetch_all_nodes(mock_client, "graph-id")

        call_kwargs = mock_client.graph.node.get_by_graph_id.call_args.kwargs
        assert call_kwargs["limit"] == 100

    def test_respects_page_size_parameter(self):
        """Should use a custom page_size when provided."""
        mock_client = MagicMock()
        mock_client.graph.node.get_by_graph_id.return_value = []

        fetch_all_nodes(mock_client, "graph-id", page_size=50)

        call_kwargs = mock_client.graph.node.get_by_graph_id.call_args.kwargs
        assert call_kwargs["limit"] == 50
|
||||
|
||||
|
||||
class TestFetchAllEdges:
    """Tests for the ``fetch_all_edges`` function."""

    def test_returns_empty_list_when_no_edges(self):
        """Should return an empty list when the graph has no edges."""
        mock_client = MagicMock()
        mock_client.graph.edge.get_by_graph_id.return_value = []

        result = fetch_all_edges(mock_client, "graph-id")

        assert result == []

    def test_returns_all_edges_single_page(self):
        """Should return all edges when they fit in one page."""
        mock_client = MagicMock()
        batch = [
            SimpleNamespace(uuid_="e1", source="n1", target="n2"),
            SimpleNamespace(uuid_="e2", source="n2", target="n3"),
        ]
        mock_client.graph.edge.get_by_graph_id.return_value = batch

        result = fetch_all_edges(mock_client, "graph-id")

        assert len(result) == 2

    def test_paginates_multiple_pages(self):
        """Should request a further page after a full first page."""
        mock_client = MagicMock()
        # The first page must be full: a page shorter than the page size ends
        # pagination (see test_stops_when_page_smaller_than_page_size). 100 is
        # the default limit checked in test_uses_default_page_size_for_edges.
        default_limit = 100
        page1 = [
            SimpleNamespace(uuid_=f"e{i}", source=f"n{i}", target=f"n{i + 1}")
            for i in range(1, default_limit + 1)
        ]
        page2 = [SimpleNamespace(uuid_="e101", source="n101", target="n102")]
        mock_client.graph.edge.get_by_graph_id.side_effect = [page1, page2]

        result = fetch_all_edges(mock_client, "graph-id")

        assert len(result) == default_limit + 1
        assert mock_client.graph.edge.get_by_graph_id.call_count == 2

    def test_stops_when_page_smaller_than_page_size(self):
        """Should stop pagination when the edge page is smaller than page_size."""
        mock_client = MagicMock()
        page1 = [SimpleNamespace(uuid_="e1", source="n1", target="n2")]
        mock_client.graph.edge.get_by_graph_id.return_value = page1

        result = fetch_all_edges(mock_client, "graph-id")

        assert len(result) == 1
        assert mock_client.graph.edge.get_by_graph_id.call_count == 1

    def test_handles_missing_uuid_field_gracefully(self):
        """Should not fail when an edge is missing the uuid field."""
        mock_client = MagicMock()
        page1 = [
            SimpleNamespace(uuid_="e1", source="n1", target="n2"),
            SimpleNamespace(uuid_="e2", source="n2", target="n3"),
        ]
        page2 = [
            SimpleNamespace(source="n3", target="n4"),  # No uuid_
        ]
        mock_client.graph.edge.get_by_graph_id.side_effect = [page1, page2]

        result = fetch_all_edges(mock_client, "graph-id")

        # NOTE(review): page1 is shorter than the default page size, so
        # pagination presumably ends before page2 is requested and the
        # missing-uuid path is not exercised here - confirm and tighten.
        assert len(result) == 2

    def test_uses_default_page_size_for_edges(self):
        """Should pass limit=100 by default for edges."""
        mock_client = MagicMock()
        mock_client.graph.edge.get_by_graph_id.return_value = []

        fetch_all_edges(mock_client, "graph-id")

        call_kwargs = mock_client.graph.edge.get_by_graph_id.call_args.kwargs
        assert call_kwargs["limit"] == 100
|
||||
Loading…
Reference in New Issue