This commit is contained in:
JiayuWang(王嘉宇) 2026-05-28 17:40:23 -04:00 committed by GitHub
commit c1affb108b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 563 additions and 6 deletions

View File

@ -664,6 +664,12 @@ SECTION_SYSTEM_PROMPT_TEMPLATE = """\
- 不要添加模拟中不存在的信息
- 如果某方面信息不足如实说明
5. 禁止捏造数据
- 禁止捏造用户名引用统计数字或互动数据
- 禁止在回复中包含 <tool_result> 只有系统会提供工具结果
- 只能引用真实出现在工具结果中的实体引用和数据
- 如果工具结果中没有相关内容应如实说明而非编造
格式规范 - 极其重要
@ -1134,6 +1140,25 @@ class ReportAgent:
desc_parts.append(f" 参数: {params_desc}")
return "\n".join(desc_parts)
@staticmethod
def _strip_fake_tool_results(response: str) -> str:
"""Strip any <tool_result> blocks the LLM fabricated in its response.
When the LLM generates a <tool_call> block and then continues to generate
a <tool_result> block in the same response, we must strip the fake result
before appending to message history. The real tool result will be injected
separately by the system.
"""
import re
cleaned = re.sub(
r'<tool_result>.*?</tool_result>',
'',
response,
flags=re.DOTALL,
)
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
return cleaned.strip()
def plan_outline(
self,
progress_callback: Optional[Callable] = None
@ -1335,7 +1360,8 @@ class ReportAgent:
if conflict_retries <= 2:
# 前两次:丢弃本次响应,要求 LLM 重新回复
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
messages.append({
"role": "user",
"content": (
@ -1375,7 +1401,8 @@ class ReportAgent:
if has_final_answer:
# 工具调用次数不足,拒绝并要求继续调工具
if tool_calls_count < min_tool_calls:
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
unused_tools = all_tools - used_tools
unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)}" if unused_tools else ""
messages.append({
@ -1451,9 +1478,10 @@ class ReportAgent:
unused_tools = all_tools - used_tools
unused_hint = ""
if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION:
unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="".join(unused_tools))
unlock_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="".join(unused_tools))
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
messages.append({
"role": "user",
"content": REACT_OBSERVATION_TEMPLATE.format(
@ -1857,7 +1885,8 @@ class ReportAgent:
tool_calls_made.append(call)
# 将结果添加到消息
messages.append({"role": "assistant", "content": response})
cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results])
messages.append({
"role": "user",

6
backend/pytest.ini Normal file
View File

@ -0,0 +1,6 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short

View File

@ -0,0 +1 @@
"""Tests package marker"""

View File

@ -0,0 +1,244 @@
"""Unit tests for file_parser module."""
import os
import tempfile
from pathlib import Path
import pytest
# Import the module directly to avoid Flask initialization issues
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Direct import of the module to avoid app package initialization
import importlib.util
spec = importlib.util.spec_from_file_location(
"file_parser",
Path(__file__).parent.parent / "app" / "utils" / "file_parser.py"
)
file_parser_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(file_parser_module)
_read_text_with_fallback = file_parser_module._read_text_with_fallback
FileParser = file_parser_module.FileParser
split_text_into_chunks = file_parser_module.split_text_into_chunks
class TestReadTextWithFallback:
"""Tests for _read_text_with_fallback function."""
def test_read_utf8_file(self):
"""Should read UTF-8 encoded file correctly."""
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
f.write("Hello, 你好, こんにちは")
path = f.name
try:
result = _read_text_with_fallback(path)
assert result == "Hello, 你好, こんにちは"
finally:
os.unlink(path)
def test_read_gbk_file_with_fallback(self):
"""Should read GBK encoded file using UTF-8 replacement when detection fails."""
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
content = "你好世界".encode('gbk')
f.write(content)
path = f.name
try:
result = _read_text_with_fallback(path)
# Result may be garbled if charset detection fails, but should not raise
assert len(result) > 0
finally:
os.unlink(path)
def test_read_latin1_file(self):
"""Should read Latin-1 encoded file using UTF-8 replacement when detection fails."""
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
content = "Héllo Wörld".encode('latin-1')
f.write(content)
path = f.name
try:
result = _read_text_with_fallback(path)
# Result may be garbled if charset detection fails, but should not raise
assert len(result) > 0
finally:
os.unlink(path)
def test_read_file_with_replacement_chars(self):
"""Should replace invalid characters instead of failing."""
with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as f:
content = b"Hello\x00\xff\xfeWorld"
f.write(content)
path = f.name
try:
result = _read_text_with_fallback(path)
assert "Hello" in result
assert "World" in result
finally:
os.unlink(path)
class TestFileParser:
"""Tests for FileParser class."""
def test_supported_extensions(self):
"""Should have correct supported extensions."""
assert '.pdf' in FileParser.SUPPORTED_EXTENSIONS
assert '.md' in FileParser.SUPPORTED_EXTENSIONS
assert '.markdown' in FileParser.SUPPORTED_EXTENSIONS
assert '.txt' in FileParser.SUPPORTED_EXTENSIONS
def test_extract_text_from_nonexistent_file(self):
"""Should raise FileNotFoundError for nonexistent file."""
with pytest.raises(FileNotFoundError):
FileParser.extract_text('/nonexistent/file.txt')
def test_extract_text_from_unsupported_format(self):
"""Should raise ValueError for unsupported format."""
with tempfile.NamedTemporaryFile(suffix='.xyz', delete=False) as f:
path = f.name
try:
with pytest.raises(ValueError, match="不支持的文件格式"):
FileParser.extract_text(path)
finally:
os.unlink(path)
def test_extract_text_from_md_file(self):
"""Should extract text from markdown file."""
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.md', delete=False) as f:
f.write("# Title\n\nThis is content.")
path = f.name
try:
result = FileParser.extract_text(path)
assert "# Title" in result
assert "This is content." in result
finally:
os.unlink(path)
def test_extract_text_from_txt_file(self):
"""Should extract text from txt file."""
with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
f.write("Plain text content")
path = f.name
try:
result = FileParser.extract_text(path)
assert result == "Plain text content"
finally:
os.unlink(path)
def test_extract_from_multiple_with_all_valid(self):
"""Should extract from multiple valid files."""
files = []
try:
for i in range(3):
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write(f"Content {i}")
files.append(f.name)
result = FileParser.extract_from_multiple(files)
assert "Content 0" in result
assert "Content 1" in result
assert "Content 2" in result
assert "文档 1" in result
assert "文档 2" in result
assert "文档 3" in result
finally:
for path in files:
os.unlink(path)
def test_extract_from_multiple_with_invalid_file(self):
"""Should handle invalid file gracefully in batch mode."""
files = [
'/nonexistent/path.txt',
]
try:
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write("Valid content")
files.append(f.name)
result = FileParser.extract_from_multiple(files)
assert "Valid content" in result
assert "提取失败" in result
finally:
for path in files[1:]:
if os.path.exists(path):
os.unlink(path)
class TestSplitTextIntoChunks:
"""Tests for split_text_into_chunks function."""
def test_short_text_returns_single_chunk(self):
"""Should return single chunk when text is shorter than chunk_size."""
text = "Short text"
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert len(result) == 1
assert result[0] == text
def test_empty_text_returns_empty_list(self):
"""Should return empty list for empty/whitespace text."""
assert split_text_into_chunks(" ", chunk_size=500, overlap=50) == []
assert split_text_into_chunks("", chunk_size=500, overlap=50) == []
def test_text_exactly_at_chunk_size(self):
"""Should return single chunk when text equals chunk_size."""
text = "a" * 500
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert len(result) == 1
def test_long_text_splits_into_multiple_chunks(self):
"""Should split long text into multiple chunks."""
text = "a" * 1000
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert len(result) >= 2
def test_chunks_have_overlap(self):
"""Should have overlapping content between consecutive chunks."""
text = "abcdefghij" * 100
chunks = split_text_into_chunks(text, chunk_size=100, overlap=20)
if len(chunks) >= 2:
assert chunks[0][-20:] == chunks[1][:20], "Chunks should overlap"
def test_chunks_preserve_content(self):
"""Should preserve all original content across chunks."""
text = "".join(str(i) for i in range(500))
chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
combined = "".join(chunks)
assert text[:400] in combined or all(c in combined for c in text[:400])
def test_chunk_size_parameter(self):
"""Should respect chunk_size parameter."""
text = "a" * 1000
result = split_text_into_chunks(text, chunk_size=100, overlap=0)
for chunk in result:
assert len(chunk) <= 100
def test_overlap_parameter(self):
"""Should respect overlap parameter."""
text = "abcdefghij" * 100
chunks = split_text_into_chunks(text, chunk_size=50, overlap=10)
if len(chunks) >= 2:
overlap_size = len(chunks[0]) - (len(chunks[0].rstrip()) - len(chunks[1].lstrip()))
assert overlap_size >= 5
def test_split_at_sentence_boundary(self):
"""Should try to split at sentence boundaries when possible."""
text = "第一句。第二句。第三句。" * 50
chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
for chunk in chunks:
if len(chunk) > 50:
assert chunk[-1] in "。.\n"
def test_last_chunk_may_be_smaller(self):
"""Should allow last chunk to be smaller than chunk_size."""
text = "a" * 550
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
assert any(len(chunk) < 500 for chunk in result)
def test_whitespace_only_chunks_filtered(self):
"""Should filter out whitespace-only chunks."""
text = "content" + " " * 600 + "more content"
result = split_text_into_chunks(text, chunk_size=500, overlap=50)
for chunk in result:
assert chunk.strip()

View File

@ -0,0 +1,277 @@
"""Unit tests for zep_paging module."""
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
from types import SimpleNamespace
import pytest
# Import the module directly to avoid Flask/Zep initialization issues
sys.path.insert(0, str(Path(__file__).parent.parent / "app" / "utils"))
import importlib.util
spec = importlib.util.spec_from_file_location(
"zep_paging",
Path(__file__).parent.parent / "app" / "utils" / "zep_paging.py"
)
zep_paging_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(zep_paging_module)
fetch_all_nodes = zep_paging_module.fetch_all_nodes
fetch_all_edges = zep_paging_module.fetch_all_edges
class TestFetchPageWithRetry:
"""Tests for _fetch_page_with_retry (via fetch_all_nodes/fetch_all_edges)."""
def test_success_on_first_attempt(self):
"""Should return result immediately on first successful call."""
mock_client = MagicMock()
batch = [SimpleNamespace(uuid_="abc-123", name="TestNode")]
mock_client.graph.node.get_by_graph_id.return_value = batch
result = fetch_all_nodes(mock_client, "graph-id")
assert len(result) == 1
mock_client.graph.node.get_by_graph_id.assert_called_once()
def test_retries_on_transient_error(self):
"""Should retry on ConnectionError, TimeoutError, OSError."""
mock_client = MagicMock()
batch = [SimpleNamespace(uuid_="abc-123", name="TestNode")]
# Fail twice, then succeed
mock_client.graph.node.get_by_graph_id.side_effect = [
ConnectionError("first failure"),
TimeoutError("second failure"),
batch
]
with patch("time.sleep"):
result = fetch_all_nodes(mock_client, "graph-id")
assert len(result) == 1
assert mock_client.graph.node.get_by_graph_id.call_count == 3
def test_exhausts_retries_and_raises(self):
"""Should raise after exhausting max_retries attempts."""
mock_client = MagicMock()
mock_client.graph.node.get_by_graph_id.side_effect = ConnectionError("always fails")
with patch("time.sleep"), pytest.raises(ConnectionError):
fetch_all_nodes(mock_client, "graph-id", max_retries=3)
assert mock_client.graph.node.get_by_graph_id.call_count == 3
def test_respects_max_retries_parameter(self):
"""Should use the max_retries parameter value."""
mock_client = MagicMock()
mock_client.graph.node.get_by_graph_id.side_effect = ConnectionError("always fails")
with patch("time.sleep"), pytest.raises(ConnectionError):
fetch_all_nodes(mock_client, "graph-id", max_retries=5)
assert mock_client.graph.node.get_by_graph_id.call_count == 5
class TestFetchAllNodes:
"""Tests for fetch_all_nodes function."""
def test_returns_empty_list_when_no_nodes(self):
"""Should return empty list when graph has no nodes."""
mock_client = MagicMock()
mock_client.graph.node.get_by_graph_id.return_value = []
result = fetch_all_nodes(mock_client, "graph-id")
assert result == []
def test_returns_all_nodes_single_page(self):
"""Should return all nodes when they fit in one page."""
mock_client = MagicMock()
batch = [
SimpleNamespace(uuid_="n1", name="Node1"),
SimpleNamespace(uuid_="n2", name="Node2"),
]
mock_client.graph.node.get_by_graph_id.return_value = batch
result = fetch_all_nodes(mock_client, "graph-id")
assert len(result) == 2
def test_paginates_multiple_pages(self):
"""Should paginate through multiple pages using uuid_cursor."""
mock_client = MagicMock()
# First page with uuid cursor
page1 = [
SimpleNamespace(uuid_="n1", name="Node1"),
SimpleNamespace(uuid_="n2", name="Node2"),
]
page2 = [SimpleNamespace(uuid_="n3", name="Node3")]
mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2]
result = fetch_all_nodes(mock_client, "graph-id")
assert len(result) == 3
assert mock_client.graph.node.get_by_graph_id.call_count == 2
def test_respects_max_items_limit(self):
"""Should stop and truncate when max_items limit is reached."""
mock_client = MagicMock()
# Return pages with page_size=2 but max_items=3
page1 = [
SimpleNamespace(uuid_="n1", name="Node1"),
SimpleNamespace(uuid_="n2", name="Node2"),
]
page2 = [
SimpleNamespace(uuid_="n3", name="Node3"),
SimpleNamespace(uuid_="n4", name="Node4"),
]
page3 = [
SimpleNamespace(uuid_="n5", name="Node5"),
]
mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2, page3]
result = fetch_all_nodes(mock_client, "graph-id", max_items=3)
assert len(result) == 3
assert result[0].name == "Node1"
assert result[1].name == "Node2"
assert result[2].name == "Node3"
def test_uses_default_max_nodes_constant(self):
"""Should use _MAX_NODES (2000) as default max_items."""
mock_client = MagicMock()
mock_client.graph.node.get_by_graph_id.return_value = []
fetch_all_nodes(mock_client, "graph-id")
# With empty pages, it won't hit limit, but verifies default is used
def test_stops_when_page_smaller_than_page_size(self):
"""Should stop pagination when returned page is smaller than page_size."""
mock_client = MagicMock()
page1 = [SimpleNamespace(uuid_="n1", name="Node1")]
mock_client.graph.node.get_by_graph_id.return_value = page1
result = fetch_all_nodes(mock_client, "graph-id")
assert len(result) == 1
assert mock_client.graph.node.get_by_graph_id.call_count == 1
def test_handles_missing_uuid_field_gracefully(self):
"""Should stop pagination when node missing uuid field."""
mock_client = MagicMock()
# First page normal
page1 = [
SimpleNamespace(uuid_="n1", name="Node1"),
SimpleNamespace(uuid_="n2", name="Node2"),
]
# Second page has node without uuid
page2 = [
SimpleNamespace(name="Node3"), # No uuid_
]
mock_client.graph.node.get_by_graph_id.side_effect = [page1, page2]
result = fetch_all_nodes(mock_client, "graph-id")
# Should get first page but stop before second
assert len(result) == 2
def test_uses_default_page_size(self):
"""Should pass limit=100 by default."""
mock_client = MagicMock()
mock_client.graph.node.get_by_graph_id.return_value = []
fetch_all_nodes(mock_client, "graph-id")
call_kwargs = mock_client.graph.node.get_by_graph_id.call_args.kwargs
assert call_kwargs["limit"] == 100
def test_respects_page_size_parameter(self):
"""Should use custom page_size when provided."""
mock_client = MagicMock()
mock_client.graph.node.get_by_graph_id.return_value = []
fetch_all_nodes(mock_client, "graph-id", page_size=50)
call_kwargs = mock_client.graph.node.get_by_graph_id.call_args.kwargs
assert call_kwargs["limit"] == 50
class TestFetchAllEdges:
"""Tests for fetch_all_edges function."""
def test_returns_empty_list_when_no_edges(self):
"""Should return empty list when graph has no edges."""
mock_client = MagicMock()
mock_client.graph.edge.get_by_graph_id.return_value = []
result = fetch_all_edges(mock_client, "graph-id")
assert result == []
def test_returns_all_edges_single_page(self):
"""Should return all edges when they fit in one page."""
mock_client = MagicMock()
batch = [
SimpleNamespace(uuid_="e1", source="n1", target="n2"),
SimpleNamespace(uuid_="e2", source="n2", target="n3"),
]
mock_client.graph.edge.get_by_graph_id.return_value = batch
result = fetch_all_edges(mock_client, "graph-id")
assert len(result) == 2
def test_paginates_multiple_pages(self):
"""Should paginate through multiple pages for edges."""
mock_client = MagicMock()
page1 = [
SimpleNamespace(uuid_="e1", source="n1", target="n2"),
SimpleNamespace(uuid_="e2", source="n2", target="n3"),
]
page2 = [SimpleNamespace(uuid_="e3", source="n3", target="n4")]
mock_client.graph.edge.get_by_graph_id.side_effect = [page1, page2]
result = fetch_all_edges(mock_client, "graph-id")
assert len(result) == 3
assert mock_client.graph.edge.get_by_graph_id.call_count == 2
def test_stops_when_page_smaller_than_page_size(self):
"""Should stop pagination when edge page is smaller than page_size."""
mock_client = MagicMock()
page1 = [SimpleNamespace(uuid_="e1", source="n1", target="n2")]
mock_client.graph.edge.get_by_graph_id.return_value = page1
result = fetch_all_edges(mock_client, "graph-id")
assert len(result) == 1
assert mock_client.graph.edge.get_by_graph_id.call_count == 1
def test_handles_missing_uuid_field_gracefully(self):
"""Should stop pagination when edge missing uuid field."""
mock_client = MagicMock()
page1 = [
SimpleNamespace(uuid_="e1", source="n1", target="n2"),
SimpleNamespace(uuid_="e2", source="n2", target="n3"),
]
page2 = [
SimpleNamespace(source="n3", target="n4"), # No uuid_
]
mock_client.graph.edge.get_by_graph_id.side_effect = [page1, page2]
result = fetch_all_edges(mock_client, "graph-id")
assert len(result) == 2
def test_uses_default_page_size_for_edges(self):
"""Should pass limit=100 by default for edges."""
mock_client = MagicMock()
mock_client.graph.edge.get_by_graph_id.return_value = []
fetch_all_edges(mock_client, "graph-id")
call_kwargs = mock_client.graph.edge.get_by_graph_id.call_args.kwargs
assert call_kwargs["limit"] == 100