MicroFish/backend/tests/test_file_parser.py

244 lines
9.4 KiB
Python

"""Unit tests for file_parser module."""
import os
import tempfile
from pathlib import Path
import pytest
# Import the module directly to avoid Flask initialization issues
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Direct import of the module to avoid app package initialization
import importlib.util
# Load app/utils/file_parser.py straight from its path rather than through the
# ``app`` package, so importing it does not run the package's initialization.
_FILE_PARSER_PATH = Path(__file__).parent.parent / "app" / "utils" / "file_parser.py"
spec = importlib.util.spec_from_file_location("file_parser", _FILE_PARSER_PATH)
if spec is None or spec.loader is None:
    # Fail with an explicit message instead of an AttributeError on ``None``.
    raise ImportError(f"Cannot build an import spec for {_FILE_PARSER_PATH}")
file_parser_module = importlib.util.module_from_spec(spec)
# Register before executing, as in the importlib "import a source file
# directly" recipe, so code inside file_parser.py that looks up its own module
# in sys.modules (e.g. dataclasses) works during execution.
sys.modules[spec.name] = file_parser_module
spec.loader.exec_module(file_parser_module)

# Names under test, re-exported at module level for the test classes below.
_read_text_with_fallback = file_parser_module._read_text_with_fallback
FileParser = file_parser_module.FileParser
split_text_into_chunks = file_parser_module.split_text_into_chunks
class TestReadTextWithFallback:
    """Tests for the ``_read_text_with_fallback`` helper.

    Every test writes raw bytes to a throwaway ``.txt`` file, reads it back
    through the helper, and deletes the file whether or not the test passes.
    """

    @staticmethod
    def _write_temp_txt(payload):
        """Write ``payload`` (bytes) to a new temporary ``.txt`` file; return its path."""
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.txt', delete=False) as handle:
            handle.write(payload)
        return handle.name

    def test_read_utf8_file(self):
        """Should read UTF-8 encoded file correctly."""
        expected = "Hello, 你好, こんにちは"
        temp_path = self._write_temp_txt(expected.encode('utf-8'))
        try:
            assert _read_text_with_fallback(temp_path) == expected
        finally:
            os.unlink(temp_path)

    def test_read_gbk_file_with_fallback(self):
        """Should read GBK encoded file using UTF-8 replacement when detection fails."""
        temp_path = self._write_temp_txt("你好世界".encode('gbk'))
        try:
            decoded = _read_text_with_fallback(temp_path)
            # Detection may garble the text; only "no exception, non-empty
            # result" is required here.
            assert len(decoded) > 0
        finally:
            os.unlink(temp_path)

    def test_read_latin1_file(self):
        """Should read Latin-1 encoded file using UTF-8 replacement when detection fails."""
        temp_path = self._write_temp_txt("Héllo Wörld".encode('latin-1'))
        try:
            decoded = _read_text_with_fallback(temp_path)
            # Same relaxed contract as the GBK case: must not raise, must not be empty.
            assert len(decoded) > 0
        finally:
            os.unlink(temp_path)

    def test_read_file_with_replacement_chars(self):
        """Should replace invalid characters instead of failing."""
        temp_path = self._write_temp_txt(b"Hello\x00\xff\xfeWorld")
        try:
            decoded = _read_text_with_fallback(temp_path)
            # The undecodable bytes sit between two ASCII words; both words
            # must survive whatever the helper substitutes for them.
            assert "Hello" in decoded
            assert "World" in decoded
        finally:
            os.unlink(temp_path)
class TestFileParser:
    """Tests for the ``FileParser`` class."""

    def test_supported_extensions(self):
        """Should have correct supported extensions."""
        for extension in ('.pdf', '.md', '.markdown', '.txt'):
            assert extension in FileParser.SUPPORTED_EXTENSIONS, extension

    def test_extract_text_from_nonexistent_file(self):
        """Should raise FileNotFoundError for nonexistent file."""
        with pytest.raises(FileNotFoundError):
            FileParser.extract_text('/nonexistent/file.txt')

    def test_extract_text_from_unsupported_format(self):
        """Should raise ValueError for unsupported format."""
        with tempfile.NamedTemporaryFile(suffix='.xyz', delete=False) as f:
            path = f.name
        try:
            # Expected message fragment is Chinese for "unsupported file format".
            with pytest.raises(ValueError, match="不支持的文件格式"):
                FileParser.extract_text(path)
        finally:
            os.unlink(path)

    def test_extract_text_from_md_file(self):
        """Should extract text from markdown file."""
        with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.md', delete=False) as f:
            f.write("# Title\n\nThis is content.")
            path = f.name
        try:
            result = FileParser.extract_text(path)
            assert "# Title" in result
            assert "This is content." in result
        finally:
            os.unlink(path)

    def test_extract_text_from_txt_file(self):
        """Should extract text from txt file."""
        with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.txt', delete=False) as f:
            f.write("Plain text content")
            path = f.name
        try:
            result = FileParser.extract_text(path)
            assert result == "Plain text content"
        finally:
            os.unlink(path)

    def test_extract_from_multiple_with_all_valid(self):
        """Should extract from multiple valid files."""
        files = []
        try:
            for i in range(3):
                # Explicit UTF-8 like the single-file tests, so fixture bytes do
                # not depend on the platform's default encoding.
                with tempfile.NamedTemporaryFile(
                    mode='w', encoding='utf-8', suffix='.txt', delete=False
                ) as f:
                    # Register for cleanup before writing so a failed write
                    # cannot leak the already-created file.
                    files.append(f.name)
                    f.write(f"Content {i}")
            result = FileParser.extract_from_multiple(files)
            for i in range(3):
                assert f"Content {i}" in result
            # Each document is labelled "文档 <n>" ("document <n>"), numbered from 1.
            for n in range(1, 4):
                assert f"文档 {n}" in result
        finally:
            for path in files:
                os.unlink(path)

    def test_extract_from_multiple_with_invalid_file(self):
        """Should handle invalid file gracefully in batch mode."""
        files = [
            '/nonexistent/path.txt',
        ]
        try:
            with tempfile.NamedTemporaryFile(
                mode='w', encoding='utf-8', suffix='.txt', delete=False
            ) as f:
                files.append(f.name)
                f.write("Valid content")
            result = FileParser.extract_from_multiple(files)
            assert "Valid content" in result
            # "提取失败" means "extraction failed": the missing file must be
            # reported in the output rather than abort the whole batch.
            assert "提取失败" in result
        finally:
            # files[0] is the deliberately missing path; only clean up real files.
            for path in files[1:]:
                if os.path.exists(path):
                    os.unlink(path)
class TestSplitTextIntoChunks:
    """Tests for the ``split_text_into_chunks`` function."""

    def test_short_text_returns_single_chunk(self):
        """Should return single chunk when text is shorter than chunk_size."""
        text = "Short text"
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert len(result) == 1
        assert result[0] == text

    def test_empty_text_returns_empty_list(self):
        """Should return empty list for empty/whitespace text."""
        assert split_text_into_chunks(" ", chunk_size=500, overlap=50) == []
        assert split_text_into_chunks("", chunk_size=500, overlap=50) == []

    def test_text_exactly_at_chunk_size(self):
        """Should return single chunk when text equals chunk_size."""
        text = "a" * 500
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert len(result) == 1

    def test_long_text_splits_into_multiple_chunks(self):
        """Should split long text into multiple chunks."""
        text = "a" * 1000
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert len(result) >= 2

    def test_chunks_have_overlap(self):
        """Should have overlapping content between consecutive chunks."""
        text = "abcdefghij" * 100
        chunks = split_text_into_chunks(text, chunk_size=100, overlap=20)
        # 1000 chars at chunk_size=100 must split; assert it so the overlap
        # check below can never be skipped silently.
        assert len(chunks) >= 2
        assert chunks[0][-20:] == chunks[1][:20], "Chunks should overlap"

    def test_chunks_preserve_content(self):
        """Should preserve all original content across chunks."""
        text = "".join(str(i) for i in range(500))
        chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
        assert chunks
        # Walk the chunks in order: each must be a contiguous slice of the
        # input starting no later than where coverage so far ends (no gaps),
        # and together they must reach the end of the text.
        search_from = 0
        covered_until = 0
        for chunk in chunks:
            position = text.find(chunk, search_from)
            assert position != -1, "chunk is not a contiguous slice of the input"
            assert position <= covered_until, "content was dropped between chunks"
            covered_until = max(covered_until, position + len(chunk))
            search_from = position
        assert covered_until == len(text)

    def test_chunk_size_parameter(self):
        """Should respect chunk_size parameter."""
        text = "a" * 1000
        result = split_text_into_chunks(text, chunk_size=100, overlap=0)
        for chunk in result:
            assert len(chunk) <= 100

    def test_overlap_parameter(self):
        """Should respect overlap parameter."""
        text = "abcdefghij" * 100
        chunks = split_text_into_chunks(text, chunk_size=50, overlap=10)
        assert len(chunks) >= 2
        # The input has no whitespace or sentence separators, so the last
        # ``overlap`` characters of one chunk should reappear verbatim at the
        # start of the next.
        assert chunks[0][-10:] == chunks[1][:10]

    def test_split_at_sentence_boundary(self):
        """Should try to split at sentence boundaries when possible."""
        text = "第一句。第二句。第三句。" * 50
        chunks = split_text_into_chunks(text, chunk_size=100, overlap=10)
        for chunk in chunks:
            # Short (typically trailing) chunks are exempt from the boundary rule.
            if len(chunk) > 50:
                assert chunk[-1] in "。.\n"

    def test_last_chunk_may_be_smaller(self):
        """Should allow last chunk to be smaller than chunk_size."""
        text = "a" * 550
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        assert any(len(chunk) < 500 for chunk in result)

    def test_whitespace_only_chunks_filtered(self):
        """Should filter out whitespace-only chunks."""
        text = "content" + " " * 600 + "more content"
        result = split_text_into_chunks(text, chunk_size=500, overlap=50)
        for chunk in result:
            assert chunk.strip()