Merge 9c1706f71d into 96096ea0ff
This commit is contained in:
commit
c1affb108b
|
|
@ -664,6 +664,12 @@ SECTION_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
- 不要添加模拟中不存在的信息
|
- 不要添加模拟中不存在的信息
|
||||||
- 如果某方面信息不足,如实说明
|
- 如果某方面信息不足,如实说明
|
||||||
|
|
||||||
|
5. 【禁止捏造数据】
|
||||||
|
- ❌ 禁止捏造用户名、引用、统计数字或互动数据
|
||||||
|
- ❌ 禁止在回复中包含 <tool_result> 块 — 只有系统会提供工具结果
|
||||||
|
- ✅ 只能引用真实出现在工具结果中的实体、引用和数据
|
||||||
|
- 如果工具结果中没有相关内容,应如实说明,而非编造
|
||||||
|
|
||||||
═══════════════════════════════════════════════════════════════
|
═══════════════════════════════════════════════════════════════
|
||||||
【⚠️ 格式规范 - 极其重要!】
|
【⚠️ 格式规范 - 极其重要!】
|
||||||
═══════════════════════════════════════════════════════════════
|
═══════════════════════════════════════════════════════════════
|
||||||
|
|
@ -1133,7 +1139,26 @@ class ReportAgent:
|
||||||
if params_desc:
|
if params_desc:
|
||||||
desc_parts.append(f" 参数: {params_desc}")
|
desc_parts.append(f" 参数: {params_desc}")
|
||||||
return "\n".join(desc_parts)
|
return "\n".join(desc_parts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_fake_tool_results(response: str) -> str:
|
||||||
|
"""Strip any <tool_result> blocks the LLM fabricated in its response.
|
||||||
|
|
||||||
|
When the LLM generates a <tool_call> block and then continues to generate
|
||||||
|
a <tool_result> block in the same response, we must strip the fake result
|
||||||
|
before appending to message history. The real tool result will be injected
|
||||||
|
separately by the system.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
cleaned = re.sub(
|
||||||
|
r'<tool_result>.*?</tool_result>',
|
||||||
|
'',
|
||||||
|
response,
|
||||||
|
flags=re.DOTALL,
|
||||||
|
)
|
||||||
|
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
|
||||||
|
return cleaned.strip()
|
||||||
|
|
||||||
def plan_outline(
|
def plan_outline(
|
||||||
self,
|
self,
|
||||||
progress_callback: Optional[Callable] = None
|
progress_callback: Optional[Callable] = None
|
||||||
|
|
@ -1335,7 +1360,8 @@ class ReportAgent:
|
||||||
|
|
||||||
if conflict_retries <= 2:
|
if conflict_retries <= 2:
|
||||||
# 前两次:丢弃本次响应,要求 LLM 重新回复
|
# 前两次:丢弃本次响应,要求 LLM 重新回复
|
||||||
messages.append({"role": "assistant", "content": response})
|
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||||
|
messages.append({"role": "assistant", "content": cleaned_response})
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": (
|
"content": (
|
||||||
|
|
@ -1375,7 +1401,8 @@ class ReportAgent:
|
||||||
if has_final_answer:
|
if has_final_answer:
|
||||||
# 工具调用次数不足,拒绝并要求继续调工具
|
# 工具调用次数不足,拒绝并要求继续调工具
|
||||||
if tool_calls_count < min_tool_calls:
|
if tool_calls_count < min_tool_calls:
|
||||||
messages.append({"role": "assistant", "content": response})
|
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||||
|
messages.append({"role": "assistant", "content": cleaned_response})
|
||||||
unused_tools = all_tools - used_tools
|
unused_tools = all_tools - used_tools
|
||||||
unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else ""
|
unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)})" if unused_tools else ""
|
||||||
messages.append({
|
messages.append({
|
||||||
|
|
@ -1451,9 +1478,10 @@ class ReportAgent:
|
||||||
unused_tools = all_tools - used_tools
|
unused_tools = all_tools - used_tools
|
||||||
unused_hint = ""
|
unused_hint = ""
|
||||||
if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION:
|
if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION:
|
||||||
unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools))
|
unlock_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="、".join(unused_tools))
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": response})
|
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||||
|
messages.append({"role": "assistant", "content": cleaned_response})
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": REACT_OBSERVATION_TEMPLATE.format(
|
"content": REACT_OBSERVATION_TEMPLATE.format(
|
||||||
|
|
@ -1857,7 +1885,8 @@ class ReportAgent:
|
||||||
tool_calls_made.append(call)
|
tool_calls_made.append(call)
|
||||||
|
|
||||||
# 将结果添加到消息
|
# 将结果添加到消息
|
||||||
messages.append({"role": "assistant", "content": response})
|
cleaned_response = ReportAgent._strip_fake_tool_results(response)
|
||||||
|
messages.append({"role": "assistant", "content": cleaned_response})
|
||||||
observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results])
|
observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results])
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
[pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
addopts = -v --tb=short
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""Tests package marker"""
|
||||||
|
|
@ -0,0 +1,244 @@
|
||||||
|
"""Unit tests for file_parser module."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Import the module directly to avoid Flask initialization issues
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
# Direct import of the module to avoid app package initialization
|
||||||
|
import importlib.util
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"file_parser",
|
||||||
|
Path(__file__).parent.parent / "app" / "utils" / "file_parser.py"
|
||||||
|
)
|
||||||
|
file_parser_module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(file_parser_module)
|
||||||
|
|
||||||
|
_read_text_with_fallback = file_parser_module._read_text_with_fallback
|
||||||
|
FileParser = file_parser_module.FileParser
|
||||||
|
split_text_into_chunks = file_parser_module.split_text_into_chunks
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadTextWithFallback:
|
||||||
|
"""Tests for _read_text_with_fallback function."""
|
||||||
|
|
||||||
|
def test_read_utf8_file(self):
|
||||||
|
"""Should read UTF-8 encoded file correctly."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
|
||||||
|
f.write("Hello, 你好, こんにちは")
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
result = _read_text_with_fallback(path)
|
||||||
|
assert result == "Hello, 你好, こんにちは"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_read_gbk_file_with_fallback(self):
|
||||||
|
"""Should read GBK encoded file using UTF-8 replacement when detection fails."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
|
||||||
|
content = "你好世界".encode('gbk')
|
||||||
|
f.write(content)
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
result = _read_text_with_fallback(path)
|
||||||
|
# Result may be garbled if charset detection fails, but should not raise
|
||||||
|
assert len(result) > 0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_read_latin1_file(self):
|
||||||
|
"""Should read Latin-1 encoded file using UTF-8 replacement when detection fails."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
|
||||||
|
content = "Héllo Wörld".encode('latin-1')
|
||||||
|
f.write(content)
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
result = _read_text_with_fallback(path)
|
||||||
|
# Result may be garbled if charset detection fails, but should not raise
|
||||||
|
assert len(result) > 0
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_read_file_with_replacement_chars(self):
|
||||||
|
"""Should replace invalid characters instead of failing."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
|
||||||
|
content = b"Hello\x00\xff\xfeWorld"
|
||||||
|
f.write(content)
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
result = _read_text_with_fallback(path)
|
||||||
|
assert "Hello" in result
|
||||||
|
assert "World" in result
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileParser:
|
||||||
|
"""Tests for FileParser class."""
|
||||||
|
|
||||||
|
def test_supported_extensions(self):
|
||||||
|
"""Should have correct supported extensions."""
|
||||||
|
assert '.pdf' in FileParser.SUPPORTED_EXTENSIONS
|
||||||
|
assert '.md' in FileParser.SUPPORTED_EXTENSIONS
|
||||||
|
assert '.markdown' in FileParser.SUPPORTED_EXTENSIONS
|
||||||
|
assert '.txt' in FileParser.SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
|
def test_extract_text_from_nonexistent_file(self):
|
||||||
|
"""Should raise FileNotFoundError for nonexistent file."""
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
FileParser.extract_text('/nonexistent/file.txt')
|
||||||
|
|
||||||
|
def test_extract_text_from_unsupported_format(self):
|
||||||
|
"""Should raise ValueError for unsupported format."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.xyz', delete=False) as f:
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
with pytest.raises(ValueError, match="不支持的文件格式"):
|
||||||
|
FileParser.extract_text(path)
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_extract_text_from_md_file(self):
|
||||||
|
"""Should extract text from markdown file."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.md', delete=False) as f:
|
||||||
|
f.write("# Title\n\nThis is content.")
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
result = FileParser.extract_text(path)
|
||||||
|
assert "# Title" in result
|
||||||
|
assert "This is content." in result
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_extract_text_from_txt_file(self):
|
||||||
|
"""Should extract text from txt file."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
|
||||||
|
f.write("Plain text content")
|
||||||
|
path = f.name
|
||||||
|
try:
|
||||||
|
result = FileParser.extract_text(path)
|
||||||
|
assert result == "Plain text content"
|
||||||
|
finally:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_extract_from_multiple_with_all_valid(self):
|
||||||
|
"""Should extract from multiple valid files."""
|
||||||
|
files = []
|
||||||
|
try:
|
||||||
|
for i in range(3):
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
||||||
|
f.write(f"Content {i}")
|
||||||
|
files.append(f.name)
|
||||||
|
|
||||||
|
result = FileParser.extract_from_multiple(files)
|
||||||
|
assert "Content 0" in result
|
||||||
|
assert "Content 1" in result
|
||||||
|
assert "Content 2" in result
|
||||||
|
assert "文档 1" in result
|
||||||
|
assert "文档 2" in result
|
||||||
|
assert "文档 3" in result
|
||||||
|
finally:
|
||||||
|
for path in files:
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
def test_extract_from_multiple_with_invalid_file(self):
|
||||||
|
"""Should handle invalid file gracefully in batch mode."""
|
||||||
|
files = [
|
||||||
|
'/nonexistent/path.txt',
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
||||||
|
f.write("Valid content")
|
||||||
|
files.append(f.name)
|
||||||
|
|
||||||
|
result = FileParser.extract_from_multiple(files)
|
||||||
|
assert "Valid content" in result
|
||||||
|
assert "提取失败" in result
|
||||||
|
finally:
|
||||||
|
for path in files[1:]:
|
||||||
|
if os.path.exists(path):
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitTextIntoChunks:
|
||||||
|
"""Tests for split_text_into_chunks function."""
|
||||||
|
|
||||||
|
def test_short_text_returns_single_chunk(self):
|
||||||
|
"""Should return single chunk when text is shorter than chunk_size."""
|
||||||
|
text = "Short text"
|
||||||
|
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0] == text
|
||||||
|
|
||||||
|
def test_empty_text_returns_empty_list(self):
|
||||||
|
"""Should return empty list for empty/whitespace text."""
|
||||||
|
assert split_text_into_chunks(" ", chunk_size=500, overlap=50) == []
|
||||||
|
assert split_text_into_chunks("", chunk_size=500, overlap=50) == []
|
||||||
|
|
||||||
|
def test_text_exactly_at_chunk_size(self):
|
||||||
|
"""Should return single chunk when text equals chunk_size."""
|
||||||
|
text = "a" * 500
|
||||||
|
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
def test_long_text_splits_into_multiple_chunks(self):
|
||||||
|
"""Should split long text into multiple chunks."""
|
||||||
|
text = "a" * 1000
|
||||||
|
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
|
||||||
|
assert len(result) >= 2
|
||||||
|
|
||||||
|
def test_chunks_have_overlap(self):
|
||||||
|
"""Should have overlapping content between consecutive chunks."""
|
||||||
|
text = "abcdefghij" * 100
|
||||||
|
chunks = split_text_into_chunks(text, chunk_size=100, overlap=20)
|
||||||
|
if len(chunks) >= 2:
|
||||||
|
assert chunks[0][-20:] == chunks[1][:20], "Chunks should overlap"
|
||||||
|
|
||||||
|
def test_chunks_preserve_content(self):
|
||||||
|
"""Should preserve all original content across chunks."""
|
||||||
|
text = "".join(str(i) for i in range(500))
|
||||||
|
chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
|
||||||
|
combined = "".join(chunks)
|
||||||
|
assert text[:400] in combined or all(c in combined for c in text[:400])
|
||||||
|
|
||||||
|
def test_chunk_size_parameter(self):
|
||||||
|
"""Should respect chunk_size parameter."""
|
||||||
|
text = "a" * 1000
|
||||||
|
result = split_text_into_chunks(text, chunk_size=100, overlap=0)
|
||||||
|
for chunk in result:
|
||||||
|
assert len(chunk) <= 100
|
||||||
|
|
||||||
|
def test_overlap_parameter(self):
|
||||||
|
"""Should respect overlap parameter."""
|
||||||
|
text = "abcdefghij" * 100
|
||||||
|
chunks = split_text_into_chunks(text, chunk_size=50, overlap=10)
|
||||||
|
if len(chunks) >= 2:
|
||||||
|
overlap_size = len(chunks[0]) - (len(chunks[0].rstrip()) - len(chunks[1].lstrip()))
|
||||||
|
assert overlap_size >= 5
|
||||||
|
|
||||||
|
def test_split_at_sentence_boundary(self):
|
||||||
|
"""Should try to split at sentence boundaries when possible."""
|
||||||
|
text = "第一句。第二句。第三句。" * 50
|
||||||
|
chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
|
||||||
|
for chunk in chunks:
|
||||||
|
if len(chunk) > 50:
|
||||||
|
assert chunk[-1] in "。.\n"
|
||||||
|
|
||||||
|
def test_last_chunk_may_be_smaller(self):
|
||||||
|
"""Should allow last chunk to be smaller than chunk_size."""
|
||||||
|
text = "a" * 550
|
||||||
|
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
|
||||||
|
assert any(len(chunk) < 500 for chunk in result)
|
||||||
|
|
||||||
|
def test_whitespace_only_chunks_filtered(self):
|
||||||
|
"""Should filter out whitespace-only chunks."""
|
||||||
|
text = "content" + " " * 600 + "more content"
|
||||||
|
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
|
||||||
|
for chunk in result:
|
||||||
|
assert chunk.strip()
|
||||||
|
|
@ -0,0 +1,277 @@
|
||||||
|
"""Unit tests for zep_paging module."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Import the module directly to avoid Flask/Zep initialization issues
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent / "app" / "utils"))
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"zep_paging",
|
||||||
|
Path(__file__).parent.parent / "app" / "utils" / "zep_paging.py"
|
||||||
|
)
|
||||||
|
zep_paging_module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(zep_paging_module)
|
||||||
|
|
||||||
|
fetch_all_nodes = zep_paging_module.fetch_all_nodes
|
||||||
|
fetch_all_edges = zep_paging_module.fetch_all_edges
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchPageWithRetry:
|
||||||
|
"""Tests for _fetch_page_with_retry (via fetch_all_nodes/fetch_all_edges)."""
|
||||||
|
|
||||||
|
def test_success_on_first_attempt(self):
|
||||||
|
"""Should return result immediately on first successful call."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
batch = [SimpleNamespace(uuid_="abc-123", name="TestNode")]
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = batch
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
mock_client.graph.node.get_by_graph_id.assert_called_once()
|
||||||
|
|
||||||
|
def test_retries_on_transient_error(self):
|
||||||
|
"""Should retry on ConnectionError, TimeoutError, OSError."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
batch = [SimpleNamespace(uuid_="abc-123", name="TestNode")]
|
||||||
|
# Fail twice, then succeed
|
||||||
|
mock_client.graph.node.get_by_graph_id.side_effect = [
|
||||||
|
ConnectionError("first failure"),
|
||||||
|
TimeoutError("second failure"),
|
||||||
|
batch
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("time.sleep"):
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert mock_client.graph.node.get_by_graph_id.call_count == 3
|
||||||
|
|
||||||
|
def test_exhausts_retries_and_raises(self):
|
||||||
|
"""Should raise after exhausting max_retries attempts."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.node.get_by_graph_id.side_effect = ConnectionError("always fails")
|
||||||
|
|
||||||
|
with patch("time.sleep"), pytest.raises(ConnectionError):
|
||||||
|
fetch_all_nodes(mock_client, "graph-id", max_retries=3)
|
||||||
|
|
||||||
|
assert mock_client.graph.node.get_by_graph_id.call_count == 3
|
||||||
|
|
||||||
|
def test_respects_max_retries_parameter(self):
|
||||||
|
"""Should use the max_retries parameter value."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.node.get_by_graph_id.side_effect = ConnectionError("always fails")
|
||||||
|
|
||||||
|
with patch("time.sleep"), pytest.raises(ConnectionError):
|
||||||
|
fetch_all_nodes(mock_client, "graph-id", max_retries=5)
|
||||||
|
|
||||||
|
assert mock_client.graph.node.get_by_graph_id.call_count == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchAllNodes:
|
||||||
|
"""Tests for fetch_all_nodes function."""
|
||||||
|
|
||||||
|
def test_returns_empty_list_when_no_nodes(self):
|
||||||
|
"""Should return empty list when graph has no nodes."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = []
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_returns_all_nodes_single_page(self):
|
||||||
|
"""Should return all nodes when they fit in one page."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
batch = [
|
||||||
|
SimpleNamespace(uuid_="n1", name="Node1"),
|
||||||
|
SimpleNamespace(uuid_="n2", name="Node2"),
|
||||||
|
]
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = batch
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
def test_paginates_multiple_pages(self):
|
||||||
|
"""Should paginate through multiple pages using uuid_cursor."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
# First page with uuid cursor
|
||||||
|
page1 = [
|
||||||
|
SimpleNamespace(uuid_="n1", name="Node1"),
|
||||||
|
SimpleNamespace(uuid_="n2", name="Node2"),
|
||||||
|
]
|
||||||
|
page2 = [SimpleNamespace(uuid_="n3", name="Node3")]
|
||||||
|
mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2]
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert mock_client.graph.node.get_by_graph_id.call_count == 2
|
||||||
|
|
||||||
|
def test_respects_max_items_limit(self):
|
||||||
|
"""Should stop and truncate when max_items limit is reached."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
# Return pages with page_size=2 but max_items=3
|
||||||
|
page1 = [
|
||||||
|
SimpleNamespace(uuid_="n1", name="Node1"),
|
||||||
|
SimpleNamespace(uuid_="n2", name="Node2"),
|
||||||
|
]
|
||||||
|
page2 = [
|
||||||
|
SimpleNamespace(uuid_="n3", name="Node3"),
|
||||||
|
SimpleNamespace(uuid_="n4", name="Node4"),
|
||||||
|
]
|
||||||
|
page3 = [
|
||||||
|
SimpleNamespace(uuid_="n5", name="Node5"),
|
||||||
|
]
|
||||||
|
mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2, page3]
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id", max_items=3)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result[0].name == "Node1"
|
||||||
|
assert result[1].name == "Node2"
|
||||||
|
assert result[2].name == "Node3"
|
||||||
|
|
||||||
|
def test_uses_default_max_nodes_constant(self):
|
||||||
|
"""Should use _MAX_NODES (2000) as default max_items."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = []
|
||||||
|
|
||||||
|
fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
# With empty pages, it won't hit limit, but verifies default is used
|
||||||
|
|
||||||
|
def test_stops_when_page_smaller_than_page_size(self):
|
||||||
|
"""Should stop pagination when returned page is smaller than page_size."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
page1 = [SimpleNamespace(uuid_="n1", name="Node1")]
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = page1
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert mock_client.graph.node.get_by_graph_id.call_count == 1
|
||||||
|
|
||||||
|
def test_handles_missing_uuid_field_gracefully(self):
|
||||||
|
"""Should stop pagination when node missing uuid field."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
# First page normal
|
||||||
|
page1 = [
|
||||||
|
SimpleNamespace(uuid_="n1", name="Node1"),
|
||||||
|
SimpleNamespace(uuid_="n2", name="Node2"),
|
||||||
|
]
|
||||||
|
# Second page has node without uuid
|
||||||
|
page2 = [
|
||||||
|
SimpleNamespace(name="Node3"), # No uuid_
|
||||||
|
]
|
||||||
|
mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2]
|
||||||
|
|
||||||
|
result = fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
# Should get first page but stop before second
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
def test_uses_default_page_size(self):
|
||||||
|
"""Should pass limit=100 by default."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = []
|
||||||
|
|
||||||
|
fetch_all_nodes(mock_client, "graph-id")
|
||||||
|
|
||||||
|
call_kwargs = mock_client.graph.node.get_by_graph_id.call_args.kwargs
|
||||||
|
assert call_kwargs["limit"] == 100
|
||||||
|
|
||||||
|
def test_respects_page_size_parameter(self):
|
||||||
|
"""Should use custom page_size when provided."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.node.get_by_graph_id.return_value = []
|
||||||
|
|
||||||
|
fetch_all_nodes(mock_client, "graph-id", page_size=50)
|
||||||
|
|
||||||
|
call_kwargs = mock_client.graph.node.get_by_graph_id.call_args.kwargs
|
||||||
|
assert call_kwargs["limit"] == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchAllEdges:
|
||||||
|
"""Tests for fetch_all_edges function."""
|
||||||
|
|
||||||
|
def test_returns_empty_list_when_no_edges(self):
|
||||||
|
"""Should return empty list when graph has no edges."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.edge.get_by_graph_id.return_value = []
|
||||||
|
|
||||||
|
result = fetch_all_edges(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_returns_all_edges_single_page(self):
|
||||||
|
"""Should return all edges when they fit in one page."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
batch = [
|
||||||
|
SimpleNamespace(uuid_="e1", source="n1", target="n2"),
|
||||||
|
SimpleNamespace(uuid_="e2", source="n2", target="n3"),
|
||||||
|
]
|
||||||
|
mock_client.graph.edge.get_by_graph_id.return_value = batch
|
||||||
|
|
||||||
|
result = fetch_all_edges(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
def test_paginates_multiple_pages(self):
|
||||||
|
"""Should paginate through multiple pages for edges."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
page1 = [
|
||||||
|
SimpleNamespace(uuid_="e1", source="n1", target="n2"),
|
||||||
|
SimpleNamespace(uuid_="e2", source="n2", target="n3"),
|
||||||
|
]
|
||||||
|
page2 = [SimpleNamespace(uuid_="e3", source="n3", target="n4")]
|
||||||
|
mock_client.graph.edge.get_by_graph_id.side_effect = [page1, page2]
|
||||||
|
|
||||||
|
result = fetch_all_edges(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert mock_client.graph.edge.get_by_graph_id.call_count == 2
|
||||||
|
|
||||||
|
def test_stops_when_page_smaller_than_page_size(self):
|
||||||
|
"""Should stop pagination when edge page is smaller than page_size."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
page1 = [SimpleNamespace(uuid_="e1", source="n1", target="n2")]
|
||||||
|
mock_client.graph.edge.get_by_graph_id.return_value = page1
|
||||||
|
|
||||||
|
result = fetch_all_edges(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert mock_client.graph.edge.get_by_graph_id.call_count == 1
|
||||||
|
|
||||||
|
def test_handles_missing_uuid_field_gracefully(self):
|
||||||
|
"""Should stop pagination when edge missing uuid field."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
page1 = [
|
||||||
|
SimpleNamespace(uuid_="e1", source="n1", target="n2"),
|
||||||
|
SimpleNamespace(uuid_="e2", source="n2", target="n3"),
|
||||||
|
]
|
||||||
|
page2 = [
|
||||||
|
SimpleNamespace(source="n3", target="n4"), # No uuid_
|
||||||
|
]
|
||||||
|
mock_client.graph.edge.get_by_graph_id.side_effect = [page1, page2]
|
||||||
|
|
||||||
|
result = fetch_all_edges(mock_client, "graph-id")
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
def test_uses_default_page_size_for_edges(self):
|
||||||
|
"""Should pass limit=100 by default for edges."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.graph.edge.get_by_graph_id.return_value = []
|
||||||
|
|
||||||
|
fetch_all_edges(mock_client, "graph-id")
|
||||||
|
|
||||||
|
call_kwargs = mock_client.graph.edge.get_by_graph_id.call_args.kwargs
|
||||||
|
assert call_kwargs["limit"] == 100
|
||||||
Loading…
Reference in New Issue