This commit is contained in:
JiayuWang(王嘉宇) 2026-05-28 17:40:23 -04:00 committed by GitHub
commit c52e4efb17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 384 additions and 6 deletions

View File

@ -664,6 +664,12 @@ SECTION_SYSTEM_PROMPT_TEMPLATE = """\
- 不要添加模拟中不存在的信息 - 不要添加模拟中不存在的信息
- 如果某方面信息不足如实说明 - 如果某方面信息不足如实说明
5. 禁止捏造数据
- 禁止捏造用户名引用统计数字或互动数据
- 禁止在回复中包含 <tool_result> 只有系统会提供工具结果
- 只能引用真实出现在工具结果中的实体引用和数据
- 如果工具结果中没有相关内容应如实说明而非编造
格式规范 - 极其重要 格式规范 - 极其重要
@ -1133,7 +1139,26 @@ class ReportAgent:
if params_desc: if params_desc:
desc_parts.append(f" 参数: {params_desc}") desc_parts.append(f" 参数: {params_desc}")
return "\n".join(desc_parts) return "\n".join(desc_parts)
@staticmethod
def _strip_fake_tool_results(response: str) -> str:
    """Strip any <tool_result> blocks the LLM fabricated in its response.

    When the LLM generates a <tool_call> block and then continues to generate
    a <tool_result> block in the same response, the fake result must be
    removed before the response is appended to the message history. The real
    tool result is injected separately by the system.

    Two shapes are removed:

    * closed blocks, ``<tool_result>...</tool_result>`` (there may be several);
    * an unterminated trailing block, i.e. an opening ``<tool_result>`` with
      no closing tag (for example when generation was cut off). Everything
      from that opening tag to the end of the response is dropped.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The response without fabricated tool results, with runs of three or
        more newlines collapsed to a single blank line and surrounding
        whitespace stripped.
    """
    import re

    # Pass 1: properly closed blocks. Non-greedy so that text between two
    # separate blocks is preserved.
    cleaned = re.sub(
        r'<tool_result>.*?</tool_result>',
        '',
        response,
        flags=re.DOTALL,
    )
    # Pass 2: any opening tag still present has no matching close, so the
    # remainder of the response belongs to the fabricated result.
    cleaned = re.sub(r'<tool_result>.*', '', cleaned, flags=re.DOTALL)
    # Removing blocks can leave long runs of blank lines behind.
    cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
    return cleaned.strip()
def plan_outline( def plan_outline(
self, self,
progress_callback: Optional[Callable] = None progress_callback: Optional[Callable] = None
@ -1335,7 +1360,8 @@ class ReportAgent:
if conflict_retries <= 2: if conflict_retries <= 2:
# 前两次:丢弃本次响应,要求 LLM 重新回复 # 前两次:丢弃本次响应,要求 LLM 重新回复
messages.append({"role": "assistant", "content": response}) cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
messages.append({ messages.append({
"role": "user", "role": "user",
"content": ( "content": (
@ -1375,7 +1401,8 @@ class ReportAgent:
if has_final_answer: if has_final_answer:
# 工具调用次数不足,拒绝并要求继续调工具 # 工具调用次数不足,拒绝并要求继续调工具
if tool_calls_count < min_tool_calls: if tool_calls_count < min_tool_calls:
messages.append({"role": "assistant", "content": response}) cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
unused_tools = all_tools - used_tools unused_tools = all_tools - used_tools
unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)}" if unused_tools else "" unused_hint = f"(这些工具还未使用,推荐用一下他们: {', '.join(unused_tools)}" if unused_tools else ""
messages.append({ messages.append({
@ -1451,9 +1478,10 @@ class ReportAgent:
unused_tools = all_tools - used_tools unused_tools = all_tools - used_tools
unused_hint = "" unused_hint = ""
if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION: if unused_tools and tool_calls_count < self.MAX_TOOL_CALLS_PER_SECTION:
unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="".join(unused_tools)) unused_hint = REACT_UNUSED_TOOLS_HINT.format(unused_list="".join(unused_tools))
messages.append({"role": "assistant", "content": response}) cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
messages.append({ messages.append({
"role": "user", "role": "user",
"content": REACT_OBSERVATION_TEMPLATE.format( "content": REACT_OBSERVATION_TEMPLATE.format(
@ -1857,7 +1885,8 @@ class ReportAgent:
tool_calls_made.append(call) tool_calls_made.append(call)
# 将结果添加到消息 # 将结果添加到消息
messages.append({"role": "assistant", "content": response}) cleaned_response = ReportAgent._strip_fake_tool_results(response)
messages.append({"role": "assistant", "content": cleaned_response})
observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results]) observation = "\n".join([f"[{r['tool']}结果]\n{r['result']}" for r in tool_results])
messages.append({ messages.append({
"role": "user", "role": "user",

6
backend/pytest.ini Normal file
View File

@ -0,0 +1,6 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short

View File

@ -0,0 +1 @@
"""Tests package marker"""

View File

@ -0,0 +1,244 @@
"""Unit tests for file_parser module."""
import os
import tempfile
from pathlib import Path
import pytest
# Import the module directly to avoid Flask initialization issues
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Direct import of the module to avoid app package initialization
import importlib.util
# Load app/utils/file_parser.py by its file path rather than through the
# ``app`` package, so that the package's own initialization is not triggered.
_file_parser_path = Path(__file__).parent.parent / "app" / "utils" / "file_parser.py"
spec = importlib.util.spec_from_file_location("file_parser", _file_parser_path)
file_parser_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(file_parser_module)

# Bind the objects under test to the module-level names the test classes use.
FileParser = file_parser_module.FileParser
split_text_into_chunks = file_parser_module.split_text_into_chunks
_read_text_with_fallback = file_parser_module._read_text_with_fallback
class TestReadTextWithFallback:
    """Tests for the ``_read_text_with_fallback`` function.

    Each test writes its fixture into pytest's per-test ``tmp_path``
    directory, so no manual cleanup is needed and nothing is left behind
    when a test fails part-way through.
    """

    def test_read_utf8_file(self, tmp_path):
        """Should read UTF-8 encoded file correctly."""
        path = tmp_path / "utf8.txt"
        path.write_text("Hello, 你好, こんにちは", encoding="utf-8")

        result = _read_text_with_fallback(str(path))

        assert result == "Hello, 你好, こんにちは"

    def test_read_gbk_file_with_fallback(self, tmp_path):
        """Should return non-empty text for a GBK encoded file without raising."""
        path = tmp_path / "gbk.txt"
        path.write_bytes("你好世界".encode("gbk"))

        result = _read_text_with_fallback(str(path))

        # The decoded text may be garbled if charset detection fails, so only
        # the "does not raise and returns something" contract is checked.
        assert len(result) > 0

    def test_read_latin1_file(self, tmp_path):
        """Should return non-empty text for a Latin-1 encoded file without raising."""
        path = tmp_path / "latin1.txt"
        path.write_bytes("Héllo Wörld".encode("latin-1"))

        result = _read_text_with_fallback(str(path))

        # The decoded text may be garbled if charset detection fails, so only
        # the "does not raise and returns something" contract is checked.
        assert len(result) > 0

    def test_read_file_with_replacement_chars(self, tmp_path):
        """Should replace invalid characters instead of failing."""
        path = tmp_path / "invalid.txt"
        # NUL plus bytes that are not valid UTF-8, surrounded by plain ASCII.
        path.write_bytes(b"Hello\x00\xff\xfeWorld")

        result = _read_text_with_fallback(str(path))

        assert "Hello" in result
        assert "World" in result
class TestFileParser:
    """Tests for the ``FileParser`` class.

    Fixture files live in pytest's per-test ``tmp_path`` directory, which
    pytest manages itself, so the tests carry no cleanup code.
    """

    def test_supported_extensions(self):
        """Should have correct supported extensions."""
        for extension in ('.pdf', '.md', '.markdown', '.txt'):
            assert extension in FileParser.SUPPORTED_EXTENSIONS

    def test_extract_text_from_nonexistent_file(self, tmp_path):
        """Should raise FileNotFoundError for nonexistent file."""
        # A path inside the fresh tmp_path directory is guaranteed not to exist.
        missing = tmp_path / "missing.txt"

        with pytest.raises(FileNotFoundError):
            FileParser.extract_text(str(missing))

    def test_extract_text_from_unsupported_format(self, tmp_path):
        """Should raise ValueError for unsupported format."""
        path = tmp_path / "data.xyz"
        path.touch()

        with pytest.raises(ValueError, match="不支持的文件格式"):
            FileParser.extract_text(str(path))

    def test_extract_text_from_md_file(self, tmp_path):
        """Should extract text from markdown file."""
        path = tmp_path / "doc.md"
        path.write_text("# Title\n\nThis is content.", encoding="utf-8")

        result = FileParser.extract_text(str(path))

        assert "# Title" in result
        assert "This is content." in result

    def test_extract_text_from_txt_file(self, tmp_path):
        """Should extract text from txt file."""
        path = tmp_path / "doc.txt"
        path.write_text("Plain text content", encoding="utf-8")

        result = FileParser.extract_text(str(path))

        assert result == "Plain text content"

    def test_extract_from_multiple_with_all_valid(self, tmp_path):
        """Should extract from multiple valid files."""
        files = []
        for i in range(3):
            path = tmp_path / f"doc_{i}.txt"
            path.write_text(f"Content {i}", encoding="utf-8")
            files.append(str(path))

        result = FileParser.extract_from_multiple(files)

        for i in range(3):
            assert f"Content {i}" in result
        # Each document is expected to be labelled with its 1-based position.
        for position in (1, 2, 3):
            assert f"文档 {position}" in result

    def test_extract_from_multiple_with_invalid_file(self, tmp_path):
        """Should handle invalid file gracefully in batch mode."""
        valid = tmp_path / "valid.txt"
        valid.write_text("Valid content", encoding="utf-8")
        # The missing file comes first so the batch has to continue past it.
        files = [str(tmp_path / "missing.txt"), str(valid)]

        result = FileParser.extract_from_multiple(files)

        assert "Valid content" in result
        assert "提取失败" in result
class TestSplitTextIntoChunks:
    """Tests for the ``split_text_into_chunks`` function."""

    def test_short_text_returns_single_chunk(self):
        """Should return single chunk when text is shorter than chunk_size."""
        text = "Short text"
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert len(result) == 1
        assert result[0] == text

    def test_empty_text_returns_empty_list(self):
        """Should return empty list for empty/whitespace text."""
        assert split_text_into_chunks(" ", chunk_size=500, overlap=50) == []
        assert split_text_into_chunks("", chunk_size=500, overlap=50) == []

    def test_text_exactly_at_chunk_size(self):
        """Should return single chunk when text equals chunk_size."""
        text = "a" * 500
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert len(result) == 1

    def test_long_text_splits_into_multiple_chunks(self):
        """Should split long text into multiple chunks."""
        text = "a" * 1000
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert len(result) >= 2

    def test_chunks_have_overlap(self):
        """Should have overlapping content between consecutive chunks."""
        text = "abcdefghij" * 100
        chunks = split_text_into_chunks(text, chunk_size=100, overlap=20)
        # Assert the precondition instead of guarding on it, so the overlap
        # check below can never be skipped silently.
        assert len(chunks) >= 2
        assert chunks[0][-20:] == chunks[1][:20], "Chunks should overlap"

    def test_chunks_preserve_content(self):
        """Should preserve all original content across chunks."""
        text = "".join(str(i) for i in range(500))
        chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
        combined = "".join(chunks)
        # NOTE(review): the second disjunct only checks that each character
        # of the prefix occurs somewhere in `combined`, so this assertion is
        # weak. Tightening it needs the exact chunk-boundary rules of
        # split_text_into_chunks, which are not visible from this file.
        assert text[:400] in combined or all(c in combined for c in text[:400])

    def test_chunk_size_parameter(self):
        """Should respect chunk_size parameter."""
        text = "a" * 1000
        result = split_text_into_chunks(text, chunk_size=100, overlap=0)
        for chunk in result:
            assert len(chunk) <= 100

    def test_overlap_parameter(self):
        """Should respect overlap parameter."""
        text = "abcdefghij" * 100
        chunks = split_text_into_chunks(text, chunk_size=50, overlap=10)
        assert len(chunks) >= 2
        # The tail of one chunk must reappear at the head of the next, with
        # the requested overlap length (same check as
        # test_chunks_have_overlap, different parameters).
        assert chunks[0][-10:] == chunks[1][:10]

    def test_split_at_sentence_boundary(self):
        """Should try to split at sentence boundaries when possible."""
        text = "第一句。第二句。第三句。" * 50
        chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
        for chunk in chunks:
            # Chunks of 50 characters or fewer are exempt from the check.
            if len(chunk) > 50:
                assert chunk[-1] in "。.\n"

    def test_last_chunk_may_be_smaller(self):
        """Should allow last chunk to be smaller than chunk_size."""
        text = "a" * 550
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert any(len(chunk) < 500 for chunk in result)

    def test_whitespace_only_chunks_filtered(self):
        """Should filter out whitespace-only chunks."""
        text = "content" + " " * 600 + "more content"
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        for chunk in result:
            assert chunk.strip()

View File

@ -0,0 +1,98 @@
"""Unit tests for locale module."""
import json
from pathlib import Path
import pytest
def _load_locale_files(locales_dir=None):
    """Load ``languages.json`` and every translation file from a locale directory.

    Args:
        locales_dir: Directory containing the locale JSON files. When omitted,
            the ``locales`` directory three levels above this file is used
            (this file -> ``tests/`` -> ``backend/`` -> source root).

    Returns:
        A ``(languages, translations)`` tuple. ``languages`` is the parsed
        content of ``languages.json``; ``translations`` maps the stem of each
        other ``*.json`` file in the directory to its parsed content.

    Raises:
        OSError: If ``languages.json`` or a translation file cannot be opened
            (``FileNotFoundError`` when it is missing).
        json.JSONDecodeError: If any of the files is not valid JSON.
    """
    if locales_dir is None:
        # tests/test_locale.py -> tests/ -> backend/ -> source root
        source_root = Path(__file__).resolve().parent.parent.parent
        locales_dir = source_root / "locales"
    locales_dir = Path(locales_dir)

    with open(locales_dir / "languages.json", "r", encoding="utf-8") as f:
        languages = json.load(f)

    translations = {}
    # Sorted so the dict's insertion order does not depend on the order in
    # which the filesystem happens to list the directory.
    for locale_path in sorted(locales_dir.iterdir()):
        if locale_path.suffix != ".json" or locale_path.name == "languages.json":
            continue
        with open(locale_path, "r", encoding="utf-8") as f:
            translations[locale_path.stem] = json.load(f)

    return languages, translations
# Load the locale data once, when this module is imported, so that every
# test class below reads the same parsed dictionaries.
_languages, _translations = _load_locale_files()
class TestLocaleStructure:
    """Structural checks on languages.json and the zh/en translation files."""

    def test_languages_has_required_fields(self):
        """languages.json should list zh and en with their display labels."""
        expected_labels = {"zh": "中文", "en": "English"}
        for code in expected_labels:
            assert code in _languages
        for code, label in expected_labels.items():
            assert _languages[code]["label"] == label

    def test_languages_have_llm_instruction(self):
        """Every language entry should carry a non-empty llmInstruction."""
        for config in _languages.values():
            assert "llmInstruction" in config
            assert len(config["llmInstruction"]) > 0

    def test_zh_translation_has_common_keys(self):
        """The zh translation should define the shared ``common`` entries."""
        zh = _translations.get("zh", {})
        assert "common" in zh
        for key in ("confirm", "cancel", "loading"):
            assert key in zh["common"]

    def test_en_translation_has_common_keys(self):
        """The en translation should define the shared ``common`` entries."""
        en = _translations.get("en", {})
        assert "common" in en
        for key in ("confirm", "cancel", "loading"):
            assert key in en["common"]

    def test_zh_and_en_have_same_top_level_keys(self):
        """zh and en should expose exactly the same top-level sections."""
        zh_keys = set(_translations.get("zh", {}).keys())
        en_keys = set(_translations.get("en", {}).keys())
        # Equality covers both directions: nothing missing on either side.
        assert zh_keys == en_keys, "zh and en should have same top-level keys"

    def test_translation_interpolation_format(self):
        """A translation with variables should use the {var} placeholder form."""
        zh = _translations.get("zh", {})
        # Only checked when the known interpolated key is actually present.
        if "home" not in zh or "heroDesc" not in zh["home"]:
            return
        hero_desc = zh["home"]["heroDesc"]
        assert "{" in hero_desc, "heroDesc should have interpolation placeholders"
class TestLanguagesCompleteness:
    """Checks that every translation file covers the zh reference sections."""

    def test_translation_files_have_same_keys(self):
        """No translation file should lack a top-level key that zh defines."""
        reference_keys = set(_translations.get("zh", {}).keys())
        for lang, trans in _translations.items():
            if lang == "zh":
                # zh is the reference itself; nothing to compare.
                continue
            missing_in_trans = reference_keys - set(trans.keys())
            assert not missing_in_trans, f"{lang} missing keys: {missing_in_trans}"