Merge pull request #36 from salestech-group/docs/i18n-7-translate-backend-comments

docs(i18n): translate chinese docstrings/comments in backend (full ticket #7)
Dominik Seemann 2026-05-11 11:03:30 +02:00 committed by GitHub
commit 056f3664be
GPG Key ID: B5690EEEBB952194
43 changed files with 4869 additions and 4328 deletions


@ -1,5 +1,5 @@
 # Per-path CJK baseline for the i18n CI guard.
 # Format: <path>\t<count>. Sorted lexicographically.
 # Refresh via: python scripts/ci/i18n_cjk_guard.py --update-baseline
-backend/app 2792
-frontend/src 902
+backend/app 307
+frontend/src 124


@ -0,0 +1,78 @@
# Handoff — `i18n-translate-backend-comments` (Issue #7)
## Status
**Complete.** All in-scope Chinese docstrings and `#` comments under `backend/` have been translated to English.
This second installment of the ticket-#7 cleanup builds on the first installment (PR #20) and finishes the remaining 12 files. Together, the two installments cover the full 35-file in-scope set.
## Completed across both installments (35 files)
### First installment (PR #20 — landed on `feat/i18n-6-externalize-backend-logs`, then merged here via `merge main` into this branch)
- **Root**: `backend/app/__init__.py`, `backend/app/config.py`, `backend/run.py`
- **API package init**: `backend/app/api/__init__.py`
- **Models** (full package): `backend/app/models/__init__.py`, `project.py`, `task.py`
- **Utils** (full package): `backend/app/utils/__init__.py`, `file_parser.py`, `llm_client.py`, `locale.py`, `logger.py`, `retry.py`, `zep_paging.py`
- **Services** (partial): `backend/app/services/__init__.py`, `graph_builder.py`, `ontology_generator.py`, `simulation_ipc.py`, `simulation_manager.py`, `text_processor.py`, `zep_entity_reader.py`
- **Scripts** (partial): `backend/scripts/action_logger.py`, `backend/scripts/test_profile_format.py`
### Second installment (this PR — finishes the ticket)
| File | Starting in-scope hits | Comment-the-obvious deletions |
| --- | --- | --- |
| `backend/app/api/graph.py` | 70 | 25 |
| `backend/app/api/report.py` | 104 | 11 |
| `backend/app/api/simulation.py` | 351 | ~25 |
| `backend/app/services/oasis_profile_generator.py` | 185 | ~14 |
| `backend/app/services/report_agent.py` | 335 | 8 |
| `backend/app/services/simulation_config_generator.py` | 148 | 0 |
| `backend/app/services/simulation_runner.py` | 277 | ~31 |
| `backend/app/services/zep_graph_memory_updater.py` | 97 | 5 |
| `backend/app/services/zep_tools.py` | 269 | 6 |
| `backend/scripts/run_parallel_simulation.py` | 227 | ~7 |
| `backend/scripts/run_reddit_simulation.py` | 75 | 12 |
| `backend/scripts/run_twitter_simulation.py` | 97 | 21 |
| **Total** | **2,235** | **~165** |
After the pass, every file in the table reports zero in-scope hits from the AST scanner.
## Remaining residuals (out of scope — owned by sibling tickets)
After this PR, the only files under `backend/` that still contain CJK characters do so exclusively inside string literals. These are owned by sibling tickets and are intentional residuals for this spec:
- LLM prompt template strings: `oasis_profile_generator.py`, `ontology_generator.py`, `simulation_config_generator.py`, `report_agent.py` — owned by tickets #2 / #3 / #4 / #5.
- Runtime log strings, API response messages, exception arguments, CLI prints: distributed across `api/`, `services/`, `scripts/`, `utils/retry.py`, `utils/locale.py`, `run.py`, `app/config.py` — owned by ticket #6 (with follow-up tickets #18, #24 for residuals).
- Sample-data values returned to clients: `services/zep_tools.py`, `services/zep_graph_memory_updater.py`, `services/zep_entity_reader.py`, etc.
The CJK CI guard (`scripts/ci/i18n_cjk_guard.py`) enforces that this set never grows; the per-path baseline at `.kiro/specs/i18n-ci-guard/baseline.txt` is updated as part of this PR to reflect the new (lower) count.
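As a rough illustration of how a per-path baseline like this can be enforced, the sketch below parses the `<path>\t<count>` format and flags any path whose count has grown. It is not the code in `scripts/ci/i18n_cjk_guard.py`: the helper names `parse_baseline` and `regressions` are hypothetical, and treating a path that is missing from the baseline as a baseline of zero is an assumption.

```python
def parse_baseline(text: str) -> dict[str, int]:
    """Parse tab-separated path/count lines, skipping blanks and `#` comments."""
    baseline: dict[str, int] = {}
    for line in text.splitlines():
        if not line.strip() or line.startswith("#"):
            continue
        path, count = line.split("\t")
        baseline[path] = int(count)
    return baseline


def regressions(baseline: dict[str, int], current: dict[str, int]) -> list[str]:
    """Return the paths whose current count exceeds the recorded baseline."""
    # Assumption: a path absent from the baseline is held to a baseline of 0.
    return sorted(path for path, count in current.items() if count > baseline.get(path, 0))


baseline = parse_baseline("# Format: <path>\\t<count>\nbackend/app\t307\nfrontend/src\t124\n")
print(regressions(baseline, {"backend/app": 307, "frontend/src": 125}))  # ['frontend/src']
```

A guard built this way only fails when a count rises, so lowering the baseline after a translation pass (as this PR does) tightens the limit for later changes.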
## Verification methodology
The AST-aware scanner at `.kiro/specs/i18n-translate-backend-comments/scan_chinese.py` (committed in this branch) classifies every CJK-bearing line into one of three buckets:
- `DOCSTRING` — line lies inside a module/class/function docstring (in scope).
- `COMMENT` — line contains a `#` and is not inside a docstring or string-literal span (in scope).
- `STRING` — line is part of a string-literal value (out of scope, owned by sibling tickets).
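The three buckets can be approximated with the standard library alone. The sketch below is not the committed `scan_chinese.py`; it is a minimal reconstruction that assumes docstrings are located with `ast` and comments and strings with `tokenize`. The names `classify`, `docstring_lines`, and `has_cjk` are hypothetical, and letting `COMMENT` win when a line carries both a CJK string and a CJK comment is an assumption.

```python
import ast
import io
import tokenize

DOC_OWNERS = (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
# Python 3.12+ tokenizes f-string text as FSTRING_MIDDLE; older versions use STRING.
STRING_TOKENS = {tokenize.STRING, getattr(tokenize, "FSTRING_MIDDLE", tokenize.STRING)}


def has_cjk(text: str) -> bool:
    return any("\u4e00" <= ch <= "\u9fff" for ch in text)


def docstring_lines(source: str) -> set[int]:
    """Line numbers covered by module, class, and function docstrings."""
    lines: set[int] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, DOC_OWNERS) and ast.get_docstring(node, clean=False) is not None:
            first = node.body[0]
            lines.update(range(first.lineno, first.end_lineno + 1))
    return lines


def classify(source: str) -> dict[int, str]:
    """Map each CJK-bearing line number to DOCSTRING, COMMENT, or STRING."""
    doc_lines = docstring_lines(source)
    buckets: dict[int, str] = {}
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if tok.type == tokenize.COMMENT and has_cjk(tok.string):
            buckets[tok.start[0]] = "COMMENT"  # assumption: in-scope bucket wins on mixed lines
        elif tok.type in STRING_TOKENS:
            for lineno, text in enumerate(tok.string.split("\n"), start=tok.start[0]):
                if has_cjk(text):
                    buckets.setdefault(lineno, "DOCSTRING" if lineno in doc_lines else "STRING")
    return buckets


SAMPLE = '''"""任务状态管理"""
PROMPT = """你好"""
def f():
    # 标准化换行
    return "等待中"
'''
print(classify(SAMPLE))  # {1: 'DOCSTRING', 2: 'STRING', 4: 'COMMENT', 5: 'STRING'}
```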
For every translated file in this installment:
1. `python3 -m py_compile <file>` succeeds.
2. The scanner reports `0` in-scope hits.
3. `git diff <file>` shows only docstring lines and `#` comment lines changed; no signature, import, decorator, expression, or string-literal byte changes.
For two of the largest files (`api/simulation.py`, `report_agent.py`), the implementing agent additionally ran an AST-equivalence check (parsing both before and after, stripping docstrings, and confirming structural equality) to validate that no executable surface changed.
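An AST-equivalence check of this kind could look roughly like the following. The agent's actual script is not part of this PR, so this is only a sketch under stated assumptions: `strip_docstrings` and `same_executable_surface` are hypothetical names, and comparing `ast.dump` output (which omits line numbers by default) is one possible definition of "structural equality".

```python
import ast

DOC_OWNERS = (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)


def strip_docstrings(tree: ast.AST) -> ast.AST:
    """Remove the leading docstring statement from every module, class, and function."""
    for node in ast.walk(tree):
        if isinstance(node, DOC_OWNERS) and ast.get_docstring(node, clean=False) is not None:
            # Keep the body non-empty so docstring-only functions still compare equal.
            node.body = node.body[1:] or [ast.Pass()]
    return tree


def _dump(source: str) -> str:
    return ast.dump(strip_docstrings(ast.parse(source)))


def same_executable_surface(before: str, after: str) -> bool:
    """True when the two sources differ only in comments and docstrings."""
    return _dump(before) == _dump(after)


before = 'def f(x):\n    """返回"""\n    # 注释\n    return x + 1\n'
after = 'def f(x):\n    """Return x plus one."""\n    return x + 1\n'
print(same_executable_surface(before, after))  # True
```

Comments never reach the AST, so deleting a redundant comment passes the check, while any change to a string-literal value or an expression alters the dump and fails it.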
## Test environment caveat
The repo's `uv sync` builds `tiktoken` from source, which requires a Rust toolchain. The sandbox running this implementation pass does not have Rust, so `cd backend && uv run python -m pytest scripts/test_profile_format.py` cannot be executed end-to-end here. Because the change set is comments-and-docstrings-only, runtime behavior cannot be affected; the syntactic-validity check (`py_compile` across all 12 files) stands in for the test run in this environment.
A developer with the project's normal dev environment (Rust toolchain installed, full `uv sync` succeeded) should re-run `cd backend && uv run python -m pytest scripts/test_profile_format.py` against this branch before merging to confirm.
## What is NOT changed
- No string literal anywhere in the touched files (verified by AST classification).
- No executable Python statement.
- No symbol renamed; `zep_*` legacy filenames preserved per steering rule.
- No file added or removed (other than the AST scanner inside `.kiro/specs/i18n-translate-backend-comments/`).
- No dependency added or version-bumped.
## Branch & PR
- Branch: `docs/i18n-7-translate-backend-comments` (re-used from PR #20; that PR was merged into `feat/i18n-6-externalize-backend-logs` after `feat/i18n-6` had already merged into `main`, which orphaned PR #20's content from `main`).
- This PR re-targets the branch at `main`, including: the four prior commits from PR #20, a `Merge branch 'main'` commit (one conflict resolved in `services/ontology_generator.py` to combine PR #20's translated comment with main's English prompt-string), and the new commits for the 12 files completed here.
- Commits follow Conventional Commits in the form `docs(i18n): translate chinese docstrings/comments in backend/<area>`.
- The PR description references issue #7 with `Closes #7`.
- No `Co-Authored-By:` watermarks.


@ -0,0 +1,316 @@
# Design Document — `i18n-translate-backend-comments`
## Overview
**Purpose**: Translate Chinese-language docstrings and `#` comments across `backend/` Python files into English, so that English-speaking maintainers can read and review the codebase without translation overhead.
**Users**: Backend maintainers and code reviewers who do not read Chinese.
**Impact**: Improves developer ergonomics and review throughput. No runtime, behavior, or interface change. Adjacent i18n tickets (#2/#3/#4/#5/#6), which own the string-literal Chinese, remain unaffected.
### Goals
- Eliminate Chinese characters from docstrings and `#` comments under the in-scope paths.
- Preserve Google-style docstring shape and project formatting rules (4-space indent, ≤120 chars/line, double-quoted strings).
- Keep the diff comments-and-docstrings-only — no executable, string-literal, or symbol changes.
### Non-Goals
- Translating Chinese inside string literals (prompt templates, `logger.{info,warning,error}` arguments, API responses, error messages). These are owned by issues #2/#3/#4/#5/#6.
- Refactoring code, reformatting style, or renaming symbols.
- Introducing new tooling, linters, or CI rules.
- Translating `backend/tests/test_locale*.py` (Chinese there is intentional test data inside string literals; outside ticket scope).
## Boundary Commitments
### This Spec Owns
- Comment and docstring text under: `backend/app/__init__.py`, `backend/app/config.py`, `backend/app/api/`, `backend/app/models/`, `backend/app/services/`, `backend/app/utils/`, `backend/run.py`, `backend/scripts/`.
- The decision rule for distinguishing docstrings from value strings (first-statement rule).
- The Chinese→English Google-style docstring key map.
- The verification workflow (residual `grep`, `pytest`, diff sanity check).
### Out of Boundary
- All string-literal content, including triple-quoted strings used as values.
- Files under `backend/tests/`, `backend/.venv/`, and any non-Python file.
- Refactors, renames, formatting changes, or new dependencies.
- Front-end localization, locale JSON files, or i18n runtime behavior.
### Allowed Dependencies
- The repository's Python source (read + write for in-scope files only).
- The existing test suite (`backend/scripts/test_profile_format.py`) for verification.
- The existing `grep`-based residual scan for verification.
### Revalidation Triggers
- A new in-scope file added under the listed paths (would expand the file list).
- A change to `dev-guidelines.md` regarding docstring style (would change the key map or quote/indent rule).
- A merge of any adjacent i18n ticket (#2/#3/#4/#5/#6) that turns a string literal into a docstring or vice versa.
## Architecture
### Existing Architecture Analysis
This change touches only commentary; no architectural element of the backend is modified. The work spans the following packages:
- `backend/app/__init__.py`, `backend/app/config.py` (Flask app and configuration entrypoint).
- `backend/app/api/` (Flask blueprints).
- `backend/app/models/` (`Project`, `Task` models).
- `backend/app/services/` (graph builder, simulation runner, report agent, etc.).
- `backend/app/utils/` (LLM client, file parser, retry, logger, locale, paging).
- `backend/run.py` (process entrypoint).
- `backend/scripts/` (simulation runners, profile-format test).
### Architecture Pattern & Boundary Map
```mermaid
graph TB
Discovery[Residual Grep Scan]
Plan[Per-Package Plan]
Translator[Translation Pass]
Verify[Verification Gate]
Commit[Per-Package Commit]
PR[Single PR to main]
Discovery --> Plan
Plan --> Translator
Translator --> Verify
Verify -->|all checks pass| Commit
Verify -->|any check fails| Translator
Commit --> Plan
Commit -->|all packages done| PR
```
**Architecture Integration**:
- Selected pattern: **Iterative pass per package** with a verification gate after each pass. Linear, deterministic, low-coordination.
- Domain/feature boundaries: One pass per backend package; commits are package-scoped to keep review chunks small.
- Existing patterns preserved: 4-space indent, double-quoted strings, Google-style docstrings, `snake_case`, project file layout.
- New components rationale: None — no new code, no new files.
- Steering compliance: Conforms to repo-level coding rules and the commits ruleset.
### Technology Stack
| Layer | Choice / Version | Role in Feature | Notes |
|-------|------------------|-----------------|-------|
| Backend / Services | Python ≥3.11 | Source language whose docstrings/comments are being translated | No version change; no dependency change |
| Tooling | `git`, `grep`, `pytest` (existing) | Discovery, verification, regression check | No new tools |
No frontend, data, messaging, or infrastructure layer is touched.
## File Structure Plan
### Directory Structure (no additions, no deletions)
```
backend/
├── app/
│ ├── __init__.py # docstrings/comments only
│ ├── config.py # docstrings/comments only
│ ├── api/ # all *.py: docstrings/comments only
│ ├── models/ # all *.py: docstrings/comments only
│ ├── services/ # all *.py: docstrings/comments only
│ └── utils/ # all *.py: docstrings/comments only
├── run.py # docstrings/comments only
└── scripts/ # all *.py: docstrings/comments only
```
### Modified Files
The 35 in-scope files identified in `gap-analysis.md` are modified, on comment and docstring lines only. No other paths are touched.
## Translation Rules
These rules drive the translation pass and the verification gate. They are normative; the implementation must follow them exactly.
### Rule 1 — Docstring vs Value String Disambiguation
A triple-quoted string is treated as a **docstring** (in scope) iff it is the first statement of a module, class, or function body. All other triple-quoted strings are **values** (out of scope) and must not be modified.
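Python's own `ast.get_docstring` applies the same first-statement rule, which makes it a convenient way to see the distinction. The snippet below is an illustrative sketch; the function `build_prompt` and its contents are invented for the example.

```python
import ast

SOURCE = '''\
def build_prompt():
    """Build the prompt."""
    template = """Value string: out of scope."""
    return template
'''

func = ast.parse(SOURCE).body[0]

# The first statement of the function body is the docstring (in scope).
docstring = ast.get_docstring(func)

# Every other string constant in the function is a value (out of scope).
values = [
    node.value
    for node in ast.walk(func)
    if isinstance(node, ast.Constant) and isinstance(node.value, str) and node.value != docstring
]

print(docstring)  # Build the prompt.
print(values)     # ['Value string: out of scope.']
```

Both strings use triple quotes, so quoting style alone cannot separate them; only the position in the body does.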
### Rule 2 — Translate Docstrings to English Google-style
- Translate Chinese narrative text to faithful English.
- Convert the following Chinese section keys to canonical English Google-style keys when present:
| Chinese key | English key |
| --- | --- |
| `参数:` | `Args:` |
| `返回:` | `Returns:` |
| `异常:` | `Raises:` |
| `产生:` / `生成:` | `Yields:` |
| `示例:` | `Examples:` |
| `注意:` / `备注:` | `Note:` |
- Preserve double-quoted triple-quoted form (`"""..."""`).
- Preserve indentation matching the surrounding scope.
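The key map above lends itself to a small lookup table. The following sketch only rewrites lines that consist of a section key and a colon; it does not translate narrative text. `KEY_MAP` mirrors the table, while `convert_section_keys` is a hypothetical helper, and accepting the full-width colon `:` alongside `:` is an assumption about how such keys may appear.

```python
import re

KEY_MAP = {
    "参数": "Args",
    "返回": "Returns",
    "异常": "Raises",
    "产生": "Yields",
    "生成": "Yields",
    "示例": "Examples",
    "注意": "Note",
    "备注": "Note",
}

# A key line is: indentation, one known key, a colon, nothing else.
_KEY_LINE = re.compile(
    r"^([ \t]*)(" + "|".join(map(re.escape, KEY_MAP)) + r")[ \t]*[::][ \t]*$",
    re.MULTILINE,
)


def convert_section_keys(docstring: str) -> str:
    """Replace standalone Chinese section keys with their Google-style English keys."""
    return _KEY_LINE.sub(lambda m: f"{m.group(1)}{KEY_MAP[m.group(2)]}:", docstring)


doc = "    参数:\n        x: int\n    返回:\n        int\n"
print(convert_section_keys(doc) == "    Args:\n        x: int\n    Returns:\n        int\n")  # True
```

Because the pattern preserves the captured indentation, the converted key stays aligned with the surrounding scope.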
### Rule 3 — Translate Inline `#` Comments to English
- Translate the comment text to English.
- If the translated comment would merely restate the immediately following executable line (a redundant verb-phrase paraphrase), delete the comment.
- Preserve `TODO:` / `FIXME:` markers and any embedded ticket reference verbatim.
- Preserve trailing in-line comments on the same line as code (e.g. `PENDING = "pending" # waiting`).
### Rule 4 — Style Compliance
- Keep every translated line ≤120 characters.
- Do not introduce trailing whitespace.
- Preserve the original indentation of each comment/docstring.
- Use double quotes for any docstring rewritten.
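The first two bullets of Rule 4 are mechanical enough to check with a few lines of Python. This is a minimal sketch, not a tool that ships with the repo; `style_violations` and `MAX_LEN` are hypothetical names.

```python
MAX_LEN = 120  # line-length limit stated in Rule 4


def style_violations(source: str) -> list[tuple[int, str]]:
    """Return (line number, problem) pairs for over-long lines and trailing whitespace."""
    problems: list[tuple[int, str]] = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if len(line) > MAX_LEN:
            problems.append((lineno, f"line is {len(line)} chars (limit {MAX_LEN})"))
        if line != line.rstrip():
            problems.append((lineno, "trailing whitespace"))
    return problems


sample = "x = 1  \n" + "#" * 121 + "\nok = True\n"
print(style_violations(sample))
# [(1, 'trailing whitespace'), (2, 'line is 121 chars (limit 120)')]
```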
### Rule 5 — Preservation
- Do not modify any executable Python statement.
- Do not modify any string literal (single-, double-, triple-quoted, f-string, raw, byte) that is not a docstring under Rule 1. The single exception is the docstring being rewritten under Rule 2: quote-style normalization to triple double-quoted form (`"""..."""`) is permitted on the docstring only, since it is the artifact under translation.
- Do not rename any symbol.
## System Flows
### Per-package iteration
```mermaid
sequenceDiagram
participant Dev as Translator
participant Repo as Repo
participant Tests as Test Suite
Dev->>Repo: git checkout docs/i18n-7-translate-backend-comments
loop For each package in [models, utils, services, api, scripts, root]
Dev->>Repo: Translate docstrings/comments
Dev->>Repo: git diff --stat (sanity check)
Dev->>Tests: cd backend then uv run python -m pytest scripts/test_profile_format.py
Tests-->>Dev: pass / fail
Dev->>Repo: Re-run residual grep
Repo-->>Dev: residual hits (string-literal only)
Dev->>Repo: git commit -m "docs(i18n): translate chinese docstrings/comments in backend/<area>"
end
Dev->>Repo: gh pr create -> single PR closing #7
```
## Requirements Traceability
| Requirement | Summary | Components | Interfaces | Flows |
|-------------|---------|------------|------------|-------|
| 1.1 | No Chinese in docstrings under in-scope paths | Translation Pass | Rule 1, Rule 2 | Per-package iteration |
| 1.2 | No Chinese in `#` comments under in-scope paths | Translation Pass | Rule 3 | Per-package iteration |
| 1.3 | Residual grep returns only string-literal Chinese | Verification Gate | Residual grep workflow | Per-package iteration |
| 1.4 | Google-style docstring shape preserved | Translation Pass | Rule 2 (key map) | — |
| 2.1 | No executable statement modified | Verification Gate | Rule 5 | Per-package iteration |
| 2.2 | No string literal modified | Verification Gate | Rule 1 (first-statement rule), Rule 5 | Per-package iteration |
| 2.3 | No symbol renamed | Verification Gate | Rule 5 | Per-package iteration |
| 2.4 | `pytest` passes | Verification Gate | Test suite invocation | Per-package iteration |
| 2.5 | Hunks touching code rejected | Verification Gate | `git diff --stat` review | Per-package iteration |
| 3.1 | Drop redundant comments | Translation Pass | Rule 3 | — |
| 3.2 | Translate the *why* faithfully | Translation Pass | Rule 3 | — |
| 3.3 | Preserve `TODO:`/`FIXME:` and ticket refs | Translation Pass | Rule 3 | — |
| 3.4 | No new comments introduced | Translation Pass | Rule 3 | — |
| 4.1 | ≤120 chars/line | Verification Gate | Rule 4 | — |
| 4.2 | No trailing whitespace | Verification Gate | Rule 4 | — |
| 4.3 | Preserve indentation | Translation Pass | Rule 4 | — |
| 4.4 | Double quotes on rewritten docstrings | Translation Pass | Rule 4 | — |
| 4.5 | Preserve 4-space indentation | Translation Pass | Rule 4 | — |
| 5.1 | Use grep for discovery | Verification Gate | Discovery scan | — |
| 5.2 | Re-run grep after each batch | Verification Gate | Residual grep workflow | Per-package iteration |
| 5.3 | Continue until non-string-literal residual cleared | Verification Gate | Rule 1 disambiguation | Per-package iteration |
| 5.4 | `git diff --stat` only in-scope paths | Verification Gate | Diff sanity check | Per-package iteration |
| 6.1 | Branch `docs/i18n-7-translate-backend-comments` | Tracking & Branching | `/done` skill | — |
| 6.2 | Reference issue #7 | Tracking & Branching | Commit/PR template | — |
| 6.3 | Conventional Commits `docs(i18n)` | Tracking & Branching | `.claude/rules/commits.md` | — |
| 6.4 | No unrelated changes | Verification Gate | Diff sanity check | — |
## Components and Interfaces
| Component | Domain/Layer | Intent | Req Coverage | Key Dependencies (P0/P1) | Contracts |
|-----------|--------------|--------|--------------|--------------------------|-----------|
| Translation Pass | Process | Apply Rules 1–5 to one package's `*.py` | 1.1, 1.2, 1.4, 3.1, 3.2, 3.3, 3.4, 4.3, 4.4, 4.5 | None (manual + AI-assisted) | Process |
| Verification Gate | Process | Run residual grep, `pytest`, and diff sanity check after each package | 1.3, 2.1, 2.2, 2.3, 2.4, 2.5, 4.1, 4.2, 5.1, 5.2, 5.3, 5.4, 6.4 | `git`, `grep`, `pytest` (P0) | Process |
| Tracking & Branching | Process | Branching, commit messages, PR | 6.1, 6.2, 6.3 | `/done` skill, `gh` CLI (P0) | Process |
### Process
#### Translation Pass
| Field | Detail |
|-------|--------|
| Intent | Translate docstrings and `#` comments in one package without touching code or string literals |
| Requirements | 1.1, 1.2, 1.4, 3.1, 3.2, 3.3, 3.4, 4.3, 4.4, 4.5 |
**Responsibilities & Constraints**
- Apply Rule 1 (first-statement disambiguation) before editing any triple-quoted string.
- Apply Rule 2 (key map) for any Chinese Google-style key encountered.
- Apply Rule 3 to inline comments; delete redundant ones.
- Operate on one package at a time; do not interleave packages.
**Dependencies**
- Inbound: Verification Gate (provides feedback if a previous batch failed).
- Outbound: Verification Gate (hands off post-pass).
- External: None.
**Contracts**: Process [x] / Service [ ] / API [ ] / Event [ ] / Batch [ ] / State [ ]
**Implementation Notes**
- Integration: Operates directly on the working tree on branch `docs/i18n-7-translate-backend-comments`.
- Validation: After each file is rewritten, sanity-check that the diff for that file shows changes only on comment/docstring lines.
- Risks: Accidental edit to a string-literal triple-quoted value — mitigated by Rule 1 + diff review.
#### Verification Gate
| Field | Detail |
|-------|--------|
| Intent | Confirm a package's translation pass left runtime behavior intact |
| Requirements | 1.3, 2.1, 2.2, 2.3, 2.4, 2.5, 4.1, 4.2, 5.1, 5.2, 5.3, 5.4, 6.4 |
**Responsibilities & Constraints**
- Re-run `grep -rln '[一-鿿]' backend/ --include='*.py'` after each package and confirm residual hits are limited to string-literal Chinese owned by adjacent tickets.
- Run `uv run python -m pytest backend/scripts/test_profile_format.py` and confirm exit 0.
- Run `git diff --stat` and confirm only in-scope file paths are listed.
- Spot-check a sample of changed files to confirm only comment/docstring lines changed.
**Dependencies**
- Inbound: Translation Pass.
- Outbound: Tracking & Branching (commits) when all checks pass; loops back to Translation Pass otherwise.
- External: `git`, `grep`, `pytest` (P0 — required for verification).
**Contracts**: Process [x] / Service [ ] / API [ ] / Event [ ] / Batch [ ] / State [ ]
**Implementation Notes**
- Integration: Run from the repo root; no environment variables required beyond what `uv run` already provides.
- Validation: All four checks (grep / pytest / diff scope / spot diff) must pass before committing.
- Risks: A flaky `pytest` run unrelated to this change would block progress — mitigated by reading the failure and re-running once.
#### Tracking & Branching
| Field | Detail |
|-------|--------|
| Intent | Branch, commit, push, and open PR per project conventions |
| Requirements | 6.1, 6.2, 6.3 |
**Responsibilities & Constraints**
- Branch name: `docs/i18n-7-translate-backend-comments`.
- Commit messages follow Conventional Commits with `docs(i18n)` scope (e.g. `docs(i18n): translate chinese docstrings/comments in backend/services`).
- PR closes #7 and references the spec.
**Dependencies**
- Inbound: Verification Gate (only commits when all checks pass).
- External: `gh` CLI (P0), `/done` skill (P0).
**Contracts**: Process [x] / Service [ ] / API [ ] / Event [ ] / Batch [ ] / State [ ]
**Implementation Notes**
- Integration: Use `/done` skill at the end to handle branch/push/PR uniformly.
- Validation: Confirm PR body references issue #7 with `Closes #7` and lists each commit.
- Risks: None.
## Error Handling
### Error Strategy
This is a build-time / source-edit task — there is no runtime error path. Errors are caught by the Verification Gate.
### Error Categories and Responses
- **Translation slipped into a string literal**: caught by `git diff --stat` + spot diff. Response: revert that hunk, re-apply translation against the docstring/comment only.
- **Test suite fails after a pass**: caught by `pytest`. Response: read failure, identify which line was incorrectly modified (likely a string the translator misclassified as a docstring), revert that hunk, re-apply.
- **Residual grep returns non-string-literal Chinese**: caught by post-pass grep. Response: classify those hits as in-scope and translate them in the next sub-pass.
- **Line exceeds 120 chars after translation**: caught by spot diff. Response: reflow the comment/docstring without changing executable code.
### Monitoring
None — this is a one-shot change. No production observability required.
## Testing Strategy
The repository's existing tests are the safety net. No new tests are added.
### Default sections
- **Unit Tests**: Not applicable; nothing executable changes.
- **Integration Tests**: `uv run python -m pytest backend/scripts/test_profile_format.py` must continue to pass after each commit.
- **E2E/UI Tests**: Not applicable.
- **Verification checks (per package commit)**:
1. Residual `grep -rln '[一-鿿]' backend/ --include='*.py'` (run from repo root) returns only files whose remaining Chinese is in string literals owned by adjacent tickets.
2. `cd backend && uv run python -m pytest scripts/test_profile_format.py` exits 0.
3. `git diff --stat HEAD~..HEAD` shows only in-scope file paths.
4. Spot diff on three random changed files confirms only comment/docstring lines changed.
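For environments where `grep` handles the Unicode range inconsistently, the residual scan in check 1 can be approximated in Python. This is a sketch rather than part of the spec's workflow: `files_with_cjk` is a hypothetical helper, and skipping `.venv` follows the exclusion noted in the gap analysis rather than the literal `grep` command.

```python
import re
from pathlib import Path

CJK = re.compile("[\u4e00-\u9fff]")  # U+4E00..U+9FFF, the range the grep pattern covers


def files_with_cjk(root: str, skip_dirs: frozenset[str] = frozenset({".venv"})) -> list[Path]:
    """List *.py files under root that contain at least one CJK character."""
    hits: list[Path] = []
    for path in sorted(Path(root).rglob("*.py")):
        if skip_dirs.intersection(path.parts):
            continue  # assumption: virtualenv contents are never in scope
        if CJK.search(path.read_text(encoding="utf-8", errors="replace")):
            hits.append(path)
    return hits
```

Calling `files_with_cjk("backend")` from the repo root would list the same kind of residual files as the `grep -rln` command. Like the `grep`, it cannot tell string-literal Chinese from comment Chinese, so each hit still needs classifying.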
## Supporting References (Optional)
- `gap-analysis.md` — full file enumeration and pattern survey.
- `research.md` — discovery log, alternatives, and decisions.


@ -0,0 +1,92 @@
# Gap Analysis — `i18n-translate-backend-comments`
## Scope Recap
- **Ticket**: salestech-group/MiroFish#7
- **Goal**: Translate Chinese docstrings and `#` comments in `backend/` to English without behavior changes.
- **Blast radius**: Comments and docstrings only; runtime semantics preserved.
## Current State Investigation
### Discovered files
A scan with the regex `[一-鿿]` across `backend/**/*.py` (excluding `.venv`) returns **35 in-app files** plus 2 test files (37 in total):
| Area | Count | Files |
| --- | --- | --- |
| `backend/app/__init__.py` | 1 | `__init__.py` |
| `backend/app/config.py` | 1 | `config.py` |
| `backend/app/api/` | 4 | `__init__.py`, `graph.py`, `report.py`, `simulation.py` |
| `backend/app/models/` | 3 | `__init__.py`, `project.py`, `task.py` |
| `backend/app/services/` | 13 | `__init__.py`, `graph_builder.py`, `oasis_profile_generator.py`, `ontology_generator.py`, `report_agent.py`, `simulation_config_generator.py`, `simulation_ipc.py`, `simulation_manager.py`, `simulation_runner.py`, `text_processor.py`, `zep_entity_reader.py`, `zep_graph_memory_updater.py`, `zep_tools.py` |
| `backend/app/utils/` | 7 | `__init__.py`, `file_parser.py`, `llm_client.py`, `locale.py`, `logger.py`, `retry.py`, `zep_paging.py` |
| `backend/run.py` | 1 | `run.py` |
| `backend/scripts/` | 5 | `action_logger.py`, `run_parallel_simulation.py`, `run_reddit_simulation.py`, `run_twitter_simulation.py`, `test_profile_format.py` |
| `backend/tests/` (extra, not in ticket file list) | 2 | `test_locale.py`, `test_locale_request_resolution.py` |
Spot checks (`models/task.py`, `models/project.py`, `services/text_processor.py`, `utils/locale.py`):
- Module-level docstrings in Chinese (e.g. `"""任务状态管理"""`).
- Class/method docstrings in Chinese, often Google-shaped but with Chinese section keys (e.g. `参数:` in place of `Args:`).
- Inline `#` comments tagging fields, sections, or restating obvious code (e.g. `# 标准化换行` above an `\n` normalization call).
- Status-enum trailing comments (e.g. `PENDING = "pending" # 等待中`).
### Conventions to preserve
- Project guideline: 4-space indent, max 120 char/line, double-quoted strings (Python).
- Docstring style: Google-style per `dev-guidelines.md`. Existing files mix English-shape `Args:`/`Returns:` keys with Chinese descriptions, or use Chinese keys (`参数:`, `返回:`). Translate both to canonical Google-style English.
- File-level convention: `snake_case` filenames, Python `__init__.py` modules typically have a one-line module docstring.
### Integration surfaces
None. This work touches only commentary; no API contracts, schemas, or imports change.
## Requirements Feasibility
| Requirement | Status | Notes |
| --- | --- | --- |
| R1 (coverage) | Feasible — straightforward | Files identified by `grep` rule. |
| R2 (behavior preservation) | Feasible | Achieved by limiting diffs to comment/docstring lines. Need to be careful with multi-line triple-quoted docstrings vs string literals (they are syntactically identical to strings — disambiguation: docstring is the *first* statement of a module/class/function body). |
| R3 (comment hygiene) | Feasible | Some judgment required; will adopt heuristic: drop comments whose translated form would be a single verb-phrase paraphrase of the next executable line. |
| R4 (style compliance) | Feasible | Watch line-length when translating dense Chinese to English (English is typically longer); rewrap as needed without changing executable code. |
| R5 (verification) | Feasible | The `grep -rln '[一-鿿]'` rule is reliable. Residual hits should land only in: prompt template strings (#2/#3/#4/#5), logger/API string literals (#6), and the `tests/test_locale*` files (intentional Chinese test data). |
| R6 (tracking/branching) | Feasible | Branch + commit conventions are standard for this repo; `/done` skill enforces them. |
### Gaps and constraints
- **Constraint**: Triple-quoted strings used as values (not as docstrings) must NOT be edited if their content is in scope of issues #2–#6 (prompts/log messages/error messages). Disambiguation matters.
- **Constraint**: Chinese characters appearing inside f-string literal segments must remain. They are out of scope.
- **Unknown / Research Needed**: None — task is mechanical and well-bounded.
### Adjacent specs / overlap with other tickets
- `i18n-externalize-backend-logs` (#6) owns translating `logger.{info,warning,error}` Chinese arguments and API response strings.
- `i18n-report-agent-prompts` (#5), and tickets #2/#3/#4 own prompt template strings.
- We must NOT touch any string literal that those tickets own. After this PR, residual `grep` hits should reduce by exactly the count of comments and docstrings translated and nothing else.
- The two `backend/tests/test_locale*.py` files are **not in the ticket's listed file scope**, and inspection shows their Chinese is exclusively in string literals (test data and a Unicode range check). They are out of scope by R1's enumerated paths and remain untouched.
## Implementation Approach Options
### Option A — Single-pass file-by-file translation (recommended)
- Walk the 35 in-scope files in a deterministic order (alphabetical), translating docstrings/comments per file, running the residual grep after each batch.
- Group commit by area (models, utils, services, api, scripts, root) to keep PR diff readable.
- ✅ Simple, low risk, easy to revert per-area.
- ✅ Maps directly to the requirements; easy to verify.
- ❌ Larger PR than option B, but ticket explicitly allows a single PR.
### Option B — Multi-PR per package
- Split into one PR per package (`models/`, `utils/`, …). The ticket allows this.
- ✅ Smaller diffs to review.
- ❌ More overhead (multiple branches/PRs); not necessary for a mechanical change of this size.
### Option C — Tooling-assisted bulk script
- Build a one-shot translation script (LLM-driven) that rewrites docstrings/comments.
- ✅ Could scale to other repos.
- ❌ Out of proportion for a single-ticket task; risk of errant edits to string literals; tooling itself becomes a deliverable to test and maintain.
## Effort and Risk
- **Effort**: **M (3–7 days of focused work)**, covering 35 files and hundreds of comments. In an interactive AI-assisted run, this collapses to a few hours.
- **Risk**: **Low** — comments-only diff; covered by mechanical verification (grep + pytest); easy to rollback per file/area.
## Recommendations for Design Phase
- **Preferred approach**: Option A (single-pass file-by-file, package-grouped commits, single PR).
- **Key decisions to capture in design**:
- Order of traversal (proposed: `models/` → `utils/` → `services/` → `api/` → `scripts/` → root files `__init__.py`, `config.py`, `run.py`).
- Heuristic for "drops the obvious comment" (one-line rule).
- How to handle Google-style docstring keys: always translate `参数:` → `Args:`, `返回:` → `Returns:`, `异常:` → `Raises:`.
- Verification cadence: re-run the grep after each package batch.
- **Research items to carry forward**: None.


@ -0,0 +1,67 @@
# Requirements Document
## Introduction
This specification covers the developer-facing internationalization of `backend/` Python source: translating Chinese docstrings and inline comments to English so that English-speaking maintainers can read and review the code without translation overhead. The change is mechanical — no behavior, no public strings, no symbol names are modified. It is one of several i18n tickets (#2, #3, #4, #5, #6, #7); this spec covers ticket #7 only.
## Boundary Context
- **In scope**: Translation of Chinese-language characters that appear in Python docstrings (module/class/function) and inline `#` comments under `backend/`. Removal of comments that merely restate the code. Preservation of `TODO:` / `FIXME:` markers and embedded ticket references.
- **Out of scope**: Chinese characters inside string literals (prompt templates, `logger.{info,warning,error}` arguments, API response bodies, error messages returned to clients) — these are tracked separately by issues #2/#3/#4/#5/#6. No refactoring, reformatting, renaming, or behavior changes.
- **Adjacent expectations**: Spec `i18n-externalize-backend-logs` (issue #6) and the prompt-translation specs handle string-literal Chinese; this spec must leave those untouched so the other tickets remain mergeable.
## Requirements
### Requirement 1: Translation Coverage of In-Scope Files
**Objective:** As a maintainer, I want every Chinese docstring and inline comment in the in-scope backend files translated to English, so that I can read and review the code without translation tools.
#### Acceptance Criteria
1. The Backend Codebase shall contain no Chinese characters (Unicode range U+4E00–U+9FFF) inside Python docstrings under `backend/app/__init__.py`, `backend/app/config.py`, `backend/app/models/`, `backend/app/services/`, `backend/app/api/`, `backend/app/utils/`, `backend/run.py`, and `backend/scripts/`.
2. The Backend Codebase shall contain no Chinese characters inside Python `#` inline comments under the same paths.
3. When `grep -rln '[一-鿿]' backend/ --include='*.py'` is run after this change, the Backend Codebase shall return only files whose remaining Chinese is contained within string literals owned by issues #2/#3/#4/#5/#6.
4. When a docstring is translated, the Translator shall preserve Google-style docstring shape (`Args:`, `Returns:`, `Raises:`, `Yields:` sections) per `dev-guidelines.md`.
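A plain `grep` cannot tell a `#` comment from a `#` inside a string literal, so criterion 2 can be checked more precisely with the standard-library tokenizer. The following is a minimal sketch, not part of the ticket's tooling; the function name `cjk_comment_lines` and the sample source are hypothetical.

```python
import io
import re
import tokenize

# CJK Unified Ideographs, the same range the requirements name (U+4E00 to U+9FFF).
CJK_RE = re.compile(r"[\u4e00-\u9fff]")


def cjk_comment_lines(source: str) -> list[int]:
    """Return 1-based line numbers of ``#`` comments that contain CJK ideographs."""
    hits = []
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        # COMMENT tokens never overlap string literals, so a '#' inside a
        # string is ignored here.
        if tok.type == tokenize.COMMENT and CJK_RE.search(tok.string):
            hits.append(tok.start[0])
    return hits


# Hypothetical sample: line 1 has Chinese and '#' inside a string (out of scope),
# line 2 is a Chinese comment (in scope), line 3 has an English comment.
sample = 'x = "中文 # not a comment"\n# 初始化客户端\ny = 1  # ok\n'
print(cjk_comment_lines(sample))  # [2]
```

Because the tokenizer, not a substring test, decides what counts as a comment, string-literal Chinese owned by the adjacent tickets is never reported.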
### Requirement 2: Preservation of Code Behavior
**Objective:** As a maintainer, I want the translation to be comments-and-docstrings-only, so that runtime behavior is provably unchanged.
#### Acceptance Criteria
1. The Translator shall not modify any executable Python statement (assignments, function calls, control flow, decorators, imports).
2. The Translator shall not modify any Python string literal (single-, double-, triple-quoted, f-string, raw, byte) regardless of whether it contains Chinese characters.
3. The Translator shall not rename any symbol (variable, function, class, module, parameter).
4. When `uv run python -m pytest backend/scripts/test_profile_format.py` is run after the change, the Backend Codebase shall exit with status 0.
5. If a diff line touches any non-comment, non-docstring code, the Translator shall reject that diff hunk and revise.
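One way to make "provably unchanged" concrete is to compare the syntax trees of a file before and after translation with docstrings removed; comments never appear in the AST, so only executable code and string literals are compared. This is a sketch under that assumption, and the helper names `strip_docstrings` and `same_behavior` are hypothetical rather than taken from the repository.

```python
import ast

_SCOPES = (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)


def strip_docstrings(tree: ast.AST) -> ast.AST:
    """Remove first-statement docstrings in place and return the tree."""
    for node in ast.walk(tree):
        if not isinstance(node, _SCOPES):
            continue
        body = node.body
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            # Keep the body non-empty so both sides normalize the same way.
            node.body = body[1:] or [ast.Pass()]
    return tree


def same_behavior(before: str, after: str) -> bool:
    """True if both sources have identical ASTs once docstrings are ignored."""
    # ast.dump omits line/column attributes by default, so reflowed comments
    # and shorter docstrings do not affect the comparison.
    return ast.dump(strip_docstrings(ast.parse(before))) == ast.dump(
        strip_docstrings(ast.parse(after))
    )


before = 'def f(x):\n    """返回两倍"""\n    # 计算\n    return x * 2\n'
after = 'def f(x):\n    """Return twice x."""\n    return x * 2\n'
print(same_behavior(before, after))  # True
```

A changed string literal or statement makes the dumps differ, which is the condition criterion 5 asks the Translator to reject.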
### Requirement 3: Comment Quality Hygiene
**Objective:** As a maintainer, I want translated comments to add value, so that the codebase remains easy to read after the migration.
#### Acceptance Criteria
1. When a Chinese comment merely restates the immediately following code (e.g. `# 初始化客户端` above `client = Client()`), the Translator shall delete the comment rather than translate it.
2. When a Chinese comment captures non-obvious *why* (constraints, workarounds, invariants), the Translator shall translate it to a faithful English equivalent.
3. The Translator shall preserve any `TODO:` / `FIXME:` marker and any embedded ticket reference (e.g. `#1234`, `PROJ-456`) verbatim within the translated comment.
4. The Translator shall not introduce new comments that did not exist (or had no Chinese equivalent) in the original source.
### Requirement 4: Style and Format Compliance
**Objective:** As a maintainer, I want the translated output to comply with project style rules, so that no follow-up cleanup PR is needed.
#### Acceptance Criteria
1. The Translator shall keep all translated docstrings and comments at or below 120 characters per line.
2. The Translator shall not introduce trailing whitespace on any line.
3. The Translator shall preserve the original indentation (tabs/spaces) of every comment and docstring.
4. The Translator shall use double quotes for any docstring it rewrites, matching the existing Python convention in the file.
5. Where a file already uses 4-space indentation, the Translator shall preserve that indentation.
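Criteria 1 and 2 lend themselves to a mechanical check. The sketch below is a simplification: it inspects every line of a file rather than only comment and docstring lines, and the function name `style_violations` is hypothetical.

```python
def style_violations(source: str, max_len: int = 120) -> list[tuple[int, str]]:
    """Return (line number, problem) pairs for over-long lines and trailing whitespace."""
    problems = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if len(line) > max_len:
            problems.append((lineno, "too long"))
        if line != line.rstrip():
            problems.append((lineno, "trailing whitespace"))
    return problems


# Hypothetical input: line 2 ends with a space, line 3 is 132 characters long.
sample = "# ok\n# trailing \n# " + "x" * 130 + "\n"
print(style_violations(sample))
```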
### Requirement 5: Discovery and Verification Workflow
**Objective:** As a reviewer, I want a reproducible discovery and verification workflow, so that I can confirm coverage and absence of regressions in CI or locally.
#### Acceptance Criteria
1. The Translator shall enumerate candidate files using `grep -rln '[一-鿿]' backend/ --include='*.py'` before beginning work.
2. The Translator shall re-run the same `grep` after each batch and confirm the residual hits are limited to string-literal Chinese owned by adjacent tickets (#2/#3/#4/#5/#6).
3. When the residual `grep` hits include any non-string-literal Chinese, the Translator shall classify those hits as in-scope and continue translation until they are gone.
4. The Translator shall verify that `git diff --stat` only reports changes inside the in-scope file paths listed in Requirement 1.
### Requirement 6: Tracking and Branching
**Objective:** As a release manager, I want the work tracked against ticket #7 on a dedicated branch, so that the PR remains scoped and traceable.
#### Acceptance Criteria
1. The Translator shall produce changes on a branch named `docs/i18n-7-translate-backend-comments`.
2. The Translator shall reference issue `salestech-group/MiroFish#7` in commit messages or PR description.
3. When committing, the Translator shall use Conventional Commits with type `docs` and scope `i18n` (e.g. `docs(i18n): translate chinese docstrings/comments in backend/<area>`).
4. The Translator shall not include unrelated changes (e.g. dependency bumps, config changes, refactors) in the resulting PR.
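The commit-message rule in criterion 3 can be checked with a short regular expression. This is only an illustrative sketch; the pattern and the name `is_valid_subject` are assumptions, not an existing hook in the repository.

```python
import re

# Subject line must start with type "docs" and scope "i18n", followed by a description.
SUBJECT_RE = re.compile(r"^docs\(i18n\): \S.*$")


def is_valid_subject(subject: str) -> bool:
    """Return True if the commit subject uses ``docs(i18n): <description>``."""
    return bool(SUBJECT_RE.match(subject))


print(is_valid_subject("docs(i18n): translate chinese docstrings/comments in backend/models"))  # True
print(is_valid_subject("fix: bump dependencies"))  # False
```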

View File

@ -0,0 +1,80 @@
# Research & Design Decisions — `i18n-translate-backend-comments`
## Summary
- **Feature**: `i18n-translate-backend-comments`
- **Discovery Scope**: Simple Addition (mechanical translation, no architectural change)
- **Key Findings**:
- 37 in-scope `backend/` Python files contain Chinese characters in docstrings or `#` comments. The full list is in `gap-analysis.md`.
- Existing docstrings mix English-shape Google-style keys (`Args:`/`Returns:`) with Chinese descriptions, and a smaller subset uses Chinese keys (`参数:`/`返回:`/`异常:`). Both patterns must converge to canonical English Google-style.
- Several `tests/test_locale*.py` files contain Chinese only inside string literals (intentional test data) and are out of scope by the ticket's enumerated paths.
## Research Log
### Discovery scan: where is Chinese in `backend/`?
- **Context**: Need a deterministic enumeration of files to translate.
- **Sources Consulted**: `grep`/Python-driven scan against `backend/**/*.py`.
- **Findings**:
- 37 in-app files (under `backend/app/`, `backend/run.py`, `backend/scripts/`).
- 2 additional test files in `backend/tests/` whose Chinese is only in string literals; not in ticket scope.
- `.venv/` matches are noise and excluded.
- **Implications**: The ticket-listed paths are exhaustive; no unexpected location. Order of traversal can be alphabetical within package groups.
### Disambiguation: docstring vs string literal
- **Context**: A triple-quoted string is a docstring iff it is the first statement of a module, class, or function body. Otherwise it is a value (e.g. a prompt template) owned by adjacent tickets.
- **Sources Consulted**: Python language reference; spot inspection of `services/ontology_generator.py`, `services/report_agent.py`.
- **Findings**:
- In-scope files contain both kinds of triple-quoted strings.
- Translating only the *first-statement* triple-quoted string per scope keeps the change comments-and-docstrings-only.
- **Implications**: Translation pass must visually verify each triple-quoted string is the first statement before rewriting; otherwise leave it alone.
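The first-statement rule can be illustrated with Python's `ast` module. This is a minimal sketch on a made-up sample; the function name `docstring_spans` and the sample source are hypothetical.

```python
import ast

_SCOPES = (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)


def docstring_spans(source: str) -> list[tuple[int, int]]:
    """Return (start_line, end_line) for every first-statement docstring."""
    spans = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, _SCOPES) or not node.body:
            continue
        first = node.body[0]
        # A docstring is a bare string expression in first position; any other
        # triple-quoted string is a value and stays out of scope.
        if (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            spans.append((first.value.lineno, first.value.end_lineno))
    return sorted(spans)


sample = '''"""模块说明"""
PROMPT = """你是一个助手"""
def f():
    """函数说明"""
    template = """值字符串"""
    return template
'''
print(docstring_spans(sample))  # [(1, 1), (4, 4)]
```

Lines 2 and 5 of the sample are triple-quoted values, so they are not reported even though they contain Chinese.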
### Google-style docstring conversions
- **Context**: `dev-guidelines.md` requires Google-style docstrings; existing Chinese docstrings sometimes use Chinese keys.
- **Findings**: The following key map applies:
- `参数:` → `Args:`
- `返回:` → `Returns:`
- `异常:` → `Raises:`
- `产生:` / `生成:` → `Yields:`
- `示例:` → `Example:` (or `Examples:`)
- `注意:` / `备注:` → `Note:` (or `Notes:`)
- **Implications**: Document this mapping in design.md so the implementation pass is mechanical.
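The key map above could be expressed as a small lookup. The sketch below rewrites only section-header lines (a key alone on its line) and leaves descriptions untouched; accepting the full-width colon (U+FF1A) alongside the ASCII colon is an assumption about the source files, and `convert_keys` is a hypothetical name.

```python
import re

KEY_MAP = {
    "参数": "Args",
    "返回": "Returns",
    "异常": "Raises",
    "产生": "Yields",
    "生成": "Yields",
    "示例": "Example",
    "注意": "Note",
    "备注": "Note",
}

# Match a line consisting only of indentation, a known key, and a colon
# (ASCII ':' or full-width U+FF1A, the latter being an assumption).
_KEY_RE = re.compile(r"^(\s*)(" + "|".join(KEY_MAP) + r")\s*[:\uff1a]\s*$")


def convert_keys(docstring: str) -> str:
    """Replace Chinese Google-style section keys with their English equivalents."""
    out = []
    for line in docstring.split("\n"):
        m = _KEY_RE.match(line)
        out.append(f"{m.group(1)}{KEY_MAP[m.group(2)]}:" if m else line)
    return "\n".join(out)


print(convert_keys("    参数:\n        x: 值\n    返回:\n        结果"))
```

The descriptions under each key still need a human translation; the lookup only makes the section headers mechanical.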
## Architecture Pattern Evaluation
| Option | Description | Strengths | Risks / Limitations | Notes |
|--------|-------------|-----------|---------------------|-------|
| Manual file-by-file pass | Walk in alphabetical order, package-grouped commits | Predictable, easy to review per package | Human time required | Selected approach |
| Multi-PR per package | One PR per backend package | Smaller diffs to review | Higher overhead, more PR churn | Allowed by ticket but not required |
| Tooling-assisted bulk script | LLM-driven find-and-replace tool | Reusable | Risk of touching string literals; tool itself becomes a deliverable | Out of proportion |
## Design Decisions
### Decision: Single-pass, package-grouped commits, single PR
- **Context**: 37 files, mechanical change, ticket allows either single or split PRs.
- **Alternatives Considered**:
1. Multi-PR per package — more granular review but higher overhead.
2. Tooling-assisted bulk script — overkill for one ticket.
- **Selected Approach**: Single PR with one or more commits, grouped by package (`models/`, `utils/`, `services/`, `api/`, `scripts/`, root) so reviewers can read the diff one package at a time.
- **Rationale**: Mechanical change with low risk; ticket explicitly allows it; reduces PR overhead; `/done` produces one PR per branch by default.
- **Trade-offs**: One large PR, but partitioned by commit. Reviewer can use commit history to navigate.
- **Follow-up**: After each package commit, re-run residual `grep` and `pytest` to maintain the invariant.
### Decision: First-statement disambiguation rule
- **Context**: Distinguish docstrings (in scope) from value strings (out of scope).
- **Selected Approach**: A triple-quoted string is treated as a docstring (in scope) only if it is the first statement of a module / class / function body. All other triple-quoted strings are values (out of scope).
- **Rationale**: Matches Python's own definition; keeps boundary with adjacent tickets unambiguous.
### Decision: Drop comments that restate code
- **Context**: R3 requires deletion of comments whose translated form would merely paraphrase the next line.
- **Selected Approach**: Apply a one-line heuristic: if the translated comment would be a verb phrase that mirrors the immediately following executable line, delete the comment instead of writing it.
- **Rationale**: Aligns with project rule "comment the why, not the what".
## Risks & Mitigations
- **Risk**: Accidental edit to a string literal (would belong to ticket #2/#3/#4/#5/#6) — **Mitigation**: After each package commit, run `git diff --stat` and a per-file diff sanity check; verify only `#` lines and docstring lines change.
- **Risk**: Tests failing because a string-shape changed — **Mitigation**: Run `uv run python -m pytest backend/scripts/test_profile_format.py` after each commit.
- **Risk**: Line length violations after English expansion — **Mitigation**: Reflow long English at <= 120 chars within the docstring/comment only; never reflow code.
## References
- `dev-guidelines.md` — repo-level coding standards, Google-style docstring requirement.
- `.claude/rules/commits.md` — Conventional Commits standard for the commit message.
- Issue #7 — salestech-group/MiroFish: source ticket.
- Issues #2/#3/#4/#5/#6 — adjacent i18n tickets that own the string-literal Chinese.

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""AST-aware classifier of Chinese characters in a Python source file.

Usage::

    python3 .kiro/specs/i18n-translate-backend-comments/scan_chinese.py <path>

Classifies every line containing CJK Unified Ideographs (U+4E00..U+9FFF)
into one of three buckets:

* ``DOCSTRING``: line lies within a module/class/function docstring (in
  scope for ticket #7).
* ``COMMENT``: line contains a ``#`` and is not inside a docstring or
  a string literal span (in scope for ticket #7).
* ``STRING``: line is part of a string literal value (out of scope;
  owned by sibling tickets #2/#3/#4/#5/#6).

Exit code is 0 when there are no in-scope hits (DOCSTRING + COMMENT), 1
when at least one remains, and 2 on a usage error. Stdout lists each
in-scope hit as ``<line> <bucket>: <content>`` so callers can inspect them.
"""
from __future__ import annotations

import ast
import pathlib
import re
import sys

CJK_RE = re.compile(r"[一-鿿]")


def classify(path: pathlib.Path) -> int:
    text = path.read_text(encoding="utf-8")
    lines = text.split("\n")
    tree = ast.parse(text)

    # Lines covered by a first-statement docstring of a module/class/function.
    docstring_lines: set[int] = set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)):
            ds = ast.get_docstring(node, clean=False)
            if ds is None:
                continue
            body = node.body
            if not body or not isinstance(body[0], ast.Expr):
                continue
            const = body[0].value
            if isinstance(const, ast.Constant) and isinstance(const.value, str):
                start = const.lineno
                end = getattr(const, "end_lineno", start)
                for ln in range(start, end + 1):
                    docstring_lines.add(ln)

    # Lines covered by any string constant (docstrings included).
    string_value_lines: set[int] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Constant) and isinstance(node.value, str):
            start = node.lineno
            end = getattr(node, "end_lineno", start)
            for ln in range(start, end + 1):
                string_value_lines.add(ln)

    in_scope_count = 0
    for i, line in enumerate(lines, start=1):
        if not CJK_RE.search(line):
            continue
        if i in docstring_lines:
            print(f"{i:5d} DOCSTRING: {line.rstrip()[:120]}")
            in_scope_count += 1
        elif i in string_value_lines:
            # Out of scope: owned by sibling tickets. Note that a trailing
            # comment on the same line as a string literal also lands here.
            pass
        elif "#" in line:
            print(f"{i:5d} COMMENT  : {line.rstrip()[:120]}")
            in_scope_count += 1
        # else: unclassified; treat as out of scope (STRING value spanning).
    return in_scope_count


def main(argv: list[str]) -> int:
    if len(argv) < 2:
        print("usage: scan_chinese.py <path>", file=sys.stderr)
        return 2
    path = pathlib.Path(argv[1])
    in_scope = classify(path)
    print("---", file=sys.stderr)
    print(f"in-scope CJK hits in {path}: {in_scope}", file=sys.stderr)
    return 0 if in_scope == 0 else 1


if __name__ == "__main__":
    raise SystemExit(main(sys.argv))

View File

@ -0,0 +1,24 @@
{
  "feature_name": "i18n-translate-backend-comments",
  "created_at": "2026-05-07T14:24:17Z",
  "updated_at": "2026-05-07T14:26:00Z",
  "language": "en",
  "phase": "tasks-generated",
  "ticket": 7,
  "ticket_url": "https://github.com/salestech-group/MiroFish/issues/7",
  "approvals": {
    "requirements": {
      "generated": true,
      "approved": true
    },
    "design": {
      "generated": true,
      "approved": true
    },
    "tasks": {
      "generated": true,
      "approved": true
    }
  },
  "ready_for_implementation": true
}

View File

@ -0,0 +1,97 @@
# Implementation Plan
## Foundation
- [x] 1. Establish baseline and working branch
- [x] 1.1 Create translation working branch and capture baseline state
- Create branch `docs/i18n-7-translate-backend-comments` from `main`.
- Capture the baseline residual hits by running the discovery scan (the regex `[一-鿿]` against `backend/**/*.py`, excluding `.venv`); record the file list as the work queue.
- Run `cd backend && uv run python -m pytest scripts/test_profile_format.py` and confirm a green baseline before any edits.
- Observable: a fresh branch exists, the baseline file list of 37 in-scope files is captured, and the baseline pytest run passes.
- _Requirements: 5.1, 6.1_
## Core — Per-Package Translation
- [x] 2. Translate Chinese docstrings and inline comments per package
- [x] 2.1 (P) Translate `backend/app/models/`
- Translate Chinese module/class/function docstrings and `#` comments in `backend/app/models/__init__.py`, `backend/app/models/project.py`, and `backend/app/models/task.py`.
- Apply the docstring-vs-value disambiguation rule (first-statement only) so that no string literal is touched.
- Apply the Google-style key map (`参数:` → `Args:`, `返回:` → `Returns:`, `异常:` → `Raises:`, `产生:`/`生成:` → `Yields:`, `示例:` → `Examples:`, `注意:`/`备注:` → `Note:`).
- Drop comments that merely restate the next executable line; preserve `TODO:`/`FIXME:` and any embedded ticket reference verbatim.
- Re-run the residual scan and confirm `backend/app/models/` no longer has Chinese in non-string-literal positions.
- Re-run `cd backend && uv run python -m pytest scripts/test_profile_format.py` and confirm exit 0.
- Observable: zero non-string-literal Chinese remains in `backend/app/models/*.py`, and the test command exits 0.
- _Requirements: 1.1, 1.2, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, 4.1, 4.2, 4.3, 4.4, 4.5_
- _Boundary: backend/app/models/_
- [x] 2.2 (P) Translate `backend/app/utils/`
- Translate Chinese docstrings and `#` comments in `backend/app/utils/__init__.py`, `file_parser.py`, `llm_client.py`, `locale.py`, `logger.py`, `retry.py`, and `zep_paging.py`.
- Be especially careful with `locale.py` and `logger.py`: they intentionally route Chinese strings through their value paths; only docstrings and `#` comments are in scope.
- Apply Rules 1–5 from `design.md` (disambiguation, key map, comment hygiene, style, preservation).
- Re-run the residual scan and confirm `backend/app/utils/` no longer has Chinese in non-string-literal positions.
- Re-run the pytest command and confirm exit 0.
- Observable: zero non-string-literal Chinese remains in `backend/app/utils/*.py`, and the test command exits 0.
- _Requirements: 1.1, 1.2, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, 4.1, 4.2, 4.3, 4.4, 4.5_
- _Boundary: backend/app/utils/_
- [x] 2.3 (P) Translate `backend/app/services/` — complete (all 12 files; finished in this installment)
- Translate Chinese docstrings and `#` comments across all 12 service files: `__init__.py`, `graph_builder.py`, `ontology_generator.py`, `oasis_profile_generator.py`, `report_agent.py`, `simulation_config_generator.py`, `simulation_ipc.py`, `simulation_manager.py`, `simulation_runner.py`, `text_processor.py`, `zep_entity_reader.py`, `zep_graph_memory_updater.py`, `zep_tools.py`.
- Treat all triple-quoted prompt templates and value strings as out of scope (owned by issues #2/#3/#4/#5/#6) — only the first-statement docstrings of modules/classes/functions are in scope.
- Apply Rules 1–5 from `design.md`.
- Re-run the residual scan and confirm `backend/app/services/` no longer has Chinese in non-string-literal positions.
- Re-run the pytest command and confirm exit 0.
- Observable: zero non-string-literal Chinese remains in `backend/app/services/*.py`, and the test command exits 0.
- _Requirements: 1.1, 1.2, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, 4.1, 4.2, 4.3, 4.4, 4.5_
- _Boundary: backend/app/services/_
- [x] 2.4 (P) Translate `backend/app/api/` — complete (all 4 files; finished in this installment)
- Translate Chinese docstrings and `#` comments in `__init__.py`, `graph.py`, `report.py`, `simulation.py`.
- Treat any user-facing string-literal Chinese in API responses as out of scope (owned by issue #6).
- Apply Rules 1–5 from `design.md`.
- Re-run the residual scan and confirm `backend/app/api/` no longer has Chinese in non-string-literal positions.
- Re-run the pytest command and confirm exit 0.
- Observable: zero non-string-literal Chinese remains in `backend/app/api/*.py`, and the test command exits 0.
- _Requirements: 1.1, 1.2, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, 4.1, 4.2, 4.3, 4.4, 4.5_
- _Boundary: backend/app/api/_
- [x] 2.5 (P) Translate `backend/scripts/` — complete (all 5 files; finished in this installment)
- Translate Chinese docstrings and `#` comments in `action_logger.py`, `run_parallel_simulation.py`, `run_reddit_simulation.py`, `run_twitter_simulation.py`, `test_profile_format.py`.
- Apply Rules 1–5 from `design.md`.
- Be especially careful with `test_profile_format.py`: any Chinese in test data string literals is out of scope; only docstrings and `#` comments are in scope.
- Re-run the residual scan and confirm `backend/scripts/` no longer has Chinese in non-string-literal positions.
- Re-run the pytest command and confirm exit 0.
- Observable: zero non-string-literal Chinese remains in `backend/scripts/*.py`, and the test command exits 0.
- _Requirements: 1.1, 1.2, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, 4.1, 4.2, 4.3, 4.4, 4.5_
- _Boundary: backend/scripts/_
- [x] 2.6 (P) Translate root backend files
- Translate Chinese docstrings and `#` comments in `backend/app/__init__.py`, `backend/app/config.py`, and `backend/run.py`.
- Apply Rules 1–5 from `design.md`.
- Be especially careful with `backend/app/config.py`: any Chinese in default-value string literals is out of scope; only docstrings and `#` comments are in scope.
- Re-run the residual scan and confirm these three files no longer have Chinese in non-string-literal positions.
- Re-run the pytest command and confirm exit 0.
- Observable: zero non-string-literal Chinese remains in `backend/app/__init__.py`, `backend/app/config.py`, and `backend/run.py`, and the test command exits 0.
- _Requirements: 1.1, 1.2, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, 4.1, 4.2, 4.3, 4.4, 4.5_
- _Boundary: backend/app (root), backend/run.py_
## Validation
- [x] 3. Final verification and PR preparation
- [x] 3.1 Run the final verification gate — scanner + py_compile pass on all 12 newly-translated files; CJK guard baseline updated (backend/app: 2792 → 307); pytest blocked by pre-existing env issues, see HANDOFF.md
- Run the residual scan one more time and confirm the only remaining hits are files where the Chinese is in string literals owned by issues #2/#3/#4/#5/#6, plus the intentional Chinese in `backend/tests/test_locale*.py`.
- Run `cd backend && uv run python -m pytest scripts/test_profile_format.py` and confirm exit 0.
- Run `git diff --stat origin/main...HEAD` and confirm only in-scope file paths under `backend/app/`, `backend/run.py`, and `backend/scripts/` are listed.
- Spot-check three random changed files with `git diff <path>` and confirm only `#` lines and docstring lines changed (no executable lines, no string-literal lines).
- Observable: residual scan, pytest, diff scope, and spot diff all pass.
- _Depends: 2.1, 2.2, 2.3, 2.4, 2.5, 2.6_
- _Requirements: 1.3, 2.5, 5.1, 5.2, 5.3, 5.4, 6.4_
- [ ] 3.2 Open PR and reference ticket #7
- Use `/done` to commit any remaining changes per Conventional Commits with type `docs` and scope `i18n` (e.g. `docs(i18n): translate chinese docstrings/comments in backend/<area>`), push the branch, and open a PR.
- The PR body must include `Closes #7` and reference the spec at `.kiro/specs/i18n-translate-backend-comments/`.
- Verify the PR contains no unrelated changes (no dependency bumps, no config changes, no refactors).
- Observable: a PR exists on GitHub from `docs/i18n-7-translate-backend-comments` to `main` that closes #7 and contains only docstring/comment translation diffs.
- _Depends: 3.1_
- _Requirements: 6.1, 6.2, 6.3, 6.4_

View File

@ -1,12 +1,10 @@
""" """MiroFish backend Flask application factory."""
MiroFish Backend - Flask应用工厂
"""
import os import os
import warnings import warnings
# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers # Silence multiprocessing.resource_tracker warnings emitted by some third-party
# 需要在所有其他导入之前设置 # libraries (e.g. transformers); must run before those modules are imported.
warnings.filterwarnings("ignore", message=".*resource_tracker.*") warnings.filterwarnings("ignore", message=".*resource_tracker.*")
from flask import Flask, request from flask import Flask, request
@ -18,62 +16,65 @@ from .utils.locale import t
def create_app(config_class=Config): def create_app(config_class=Config):
"""Flask应用工厂函数""" """Flask application factory."""
app = Flask(__name__) app = Flask(__name__)
app.config.from_object(config_class) app.config.from_object(config_class)
# 设置JSON编码确保中文直接显示而不是 \uXXXX 格式) # Configure JSON encoding so non-ASCII characters render literally
# Flask >= 2.3 使用 app.json.ensure_ascii旧版本使用 JSON_AS_ASCII 配置 # rather than as \uXXXX escape sequences. Flask >= 2.3 exposes
# ``app.json.ensure_ascii``; older versions use ``JSON_AS_ASCII``.
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'): if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
app.json.ensure_ascii = False app.json.ensure_ascii = False
# 设置日志 # Configure logging.
logger = setup_logger('mirofish') logger = setup_logger('mirofish')
# 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次) # Only print startup banners in the reloader child process to avoid
# double-printing in debug mode.
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
debug_mode = app.config.get('DEBUG', False) debug_mode = app.config.get('DEBUG', False)
should_log_startup = not debug_mode or is_reloader_process should_log_startup = not debug_mode or is_reloader_process
if should_log_startup: if should_log_startup:
logger.info("=" * 50) logger.info("=" * 50)
logger.info(t("log.bootstrap.m001")) logger.info(t("log.bootstrap.m001"))
logger.info("=" * 50) logger.info("=" * 50)
# 启用CORS # Enable CORS.
CORS(app, resources={r"/api/*": {"origins": "*"}}) CORS(app, resources={r"/api/*": {"origins": "*"}})
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程) # Register simulation-process cleanup so all child processes are torn down
# when the Flask server shuts down.
from .services.simulation_runner import SimulationRunner from .services.simulation_runner import SimulationRunner
SimulationRunner.register_cleanup() SimulationRunner.register_cleanup()
if should_log_startup: if should_log_startup:
logger.info(t("log.bootstrap.m002")) logger.info(t("log.bootstrap.m002"))
# 请求日志中间件 # Request-logging middleware.
@app.before_request @app.before_request
def log_request(): def log_request():
logger = get_logger('mirofish.request') logger = get_logger('mirofish.request')
logger.debug(t("log.bootstrap.m003", request=request.method, request_2=request.path)) logger.debug(t("log.bootstrap.m003", request=request.method, request_2=request.path))
if request.content_type and 'json' in request.content_type: if request.content_type and 'json' in request.content_type:
logger.debug(t("log.bootstrap.m004", request=request.get_json(silent=True))) logger.debug(t("log.bootstrap.m004", request=request.get_json(silent=True)))
@app.after_request @app.after_request
def log_response(response): def log_response(response):
logger = get_logger('mirofish.request') logger = get_logger('mirofish.request')
logger.debug(t("log.bootstrap.m005", response=response.status_code)) logger.debug(t("log.bootstrap.m005", response=response.status_code))
return response return response
# 注册蓝图 # Register API blueprints.
from .api import graph_bp, simulation_bp, report_bp from .api import graph_bp, simulation_bp, report_bp
app.register_blueprint(graph_bp, url_prefix='/api/graph') app.register_blueprint(graph_bp, url_prefix='/api/graph')
app.register_blueprint(simulation_bp, url_prefix='/api/simulation') app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
app.register_blueprint(report_bp, url_prefix='/api/report') app.register_blueprint(report_bp, url_prefix='/api/report')
# 健康检查 # Health-check endpoint.
@app.route('/health') @app.route('/health')
def health(): def health():
return {'status': 'ok', 'service': 'MiroFish Backend'} return {'status': 'ok', 'service': 'MiroFish Backend'}
# On startup: recover any projects stuck in graph_building (task was killed by restart) # On startup: recover any projects stuck in graph_building (task was killed by restart)
if should_log_startup: if should_log_startup:
_recover_stuck_projects() _recover_stuck_projects()

View File

@ -1,6 +1,4 @@
""" """API blueprints package."""
API路由模块
"""
from flask import Blueprint from flask import Blueprint

View File

@ -1,6 +1,7 @@
""" """
图谱相关API路由 Graph-related API routes.
采用项目上下文机制服务端持久化状态
Uses a project context mechanism with server-side state persistence.
""" """
import os import os
@ -26,25 +27,22 @@ _graph_data_cache: dict = {} # graph_id -> {"data": ..., "ts": float}
_graph_refresh_locks: dict = {} # graph_id -> threading.Lock (one refresh at a time) _graph_refresh_locks: dict = {} # graph_id -> threading.Lock (one refresh at a time)
_GRAPH_CACHE_TTL = 300 # seconds before triggering a background refresh _GRAPH_CACHE_TTL = 300 # seconds before triggering a background refresh
# 获取日志器
logger = get_logger('mirofish.api') logger = get_logger('mirofish.api')
def allowed_file(filename: str) -> bool: def allowed_file(filename: str) -> bool:
"""检查文件扩展名是否允许""" """Return True if the file extension is in the allowed list."""
if not filename or '.' not in filename: if not filename or '.' not in filename:
return False return False
ext = os.path.splitext(filename)[1].lower().lstrip('.') ext = os.path.splitext(filename)[1].lower().lstrip('.')
return ext in Config.ALLOWED_EXTENSIONS return ext in Config.ALLOWED_EXTENSIONS
# ============== 项目管理接口 ============== # ============== Project management endpoints ==============
@graph_bp.route('/project/<project_id>', methods=['GET']) @graph_bp.route('/project/<project_id>', methods=['GET'])
def get_project(project_id: str): def get_project(project_id: str):
""" """Get project details."""
获取项目详情
"""
project = ProjectManager.get_project(project_id) project = ProjectManager.get_project(project_id)
if not project: if not project:
@ -61,9 +59,7 @@ def get_project(project_id: str):
@graph_bp.route('/project/list', methods=['GET']) @graph_bp.route('/project/list', methods=['GET'])
def list_projects(): def list_projects():
""" """List all projects."""
列出所有项目
"""
limit = request.args.get('limit', 50, type=int) limit = request.args.get('limit', 50, type=int)
projects = ProjectManager.list_projects(limit=limit) projects = ProjectManager.list_projects(limit=limit)
@ -76,9 +72,7 @@ def list_projects():
@graph_bp.route('/project/<project_id>', methods=['DELETE']) @graph_bp.route('/project/<project_id>', methods=['DELETE'])
def delete_project(project_id: str): def delete_project(project_id: str):
""" """Delete a project."""
删除项目
"""
success = ProjectManager.delete_project(project_id) success = ProjectManager.delete_project(project_id)
if not success: if not success:
@ -95,9 +89,7 @@ def delete_project(project_id: str):
@graph_bp.route('/project/<project_id>/reset', methods=['POST']) @graph_bp.route('/project/<project_id>/reset', methods=['POST'])
def reset_project(project_id: str): def reset_project(project_id: str):
""" """Reset project state (used to rebuild the graph from scratch)."""
重置项目状态用于重新构建图谱
"""
project = ProjectManager.get_project(project_id) project = ProjectManager.get_project(project_id)
if not project: if not project:
@ -106,7 +98,8 @@ def reset_project(project_id: str):
"error": t("api.error.graph.m004", project_id=project_id) "error": t("api.error.graph.m004", project_id=project_id)
}), 404 }), 404
# 重置到本体已生成状态 # Roll back to the "ontology generated" state so the next build can resume
# from the existing ontology rather than re-running ontology generation.
if project.ontology: if project.ontology:
project.status = ProjectStatus.ONTOLOGY_GENERATED project.status = ProjectStatus.ONTOLOGY_GENERATED
else: else:
@ -124,22 +117,21 @@ def reset_project(project_id: str):
}) })
# ============== 接口1上传文件并生成本体 ============== # ============== Endpoint 1: upload files and generate ontology ==============
@graph_bp.route('/ontology/generate', methods=['POST']) @graph_bp.route('/ontology/generate', methods=['POST'])
def generate_ontology(): def generate_ontology():
""" """Endpoint 1: upload files, analyze them, and generate an ontology definition.
接口1上传文件分析生成本体定义
Request format: multipart/form-data.
请求方式multipart/form-data
Args:
参数 files: Uploaded files (PDF/MD/TXT); one or more.
files: 上传的文件PDF/MD/TXT可多个 simulation_requirement: Description of the simulation requirement (required).
simulation_requirement: 模拟需求描述必填 project_name: Project name (optional).
project_name: 项目名称可选 additional_context: Additional context (optional).
additional_context: 额外说明可选
Returns:
返回
{ {
"success": true, "success": true,
"data": { "data": {
@ -156,8 +148,7 @@ def generate_ontology():
""" """
try: try:
logger.info(t("log.graph_api.m006")) logger.info(t("log.graph_api.m006"))
# 获取参数
simulation_requirement = request.form.get('simulation_requirement', '') simulation_requirement = request.form.get('simulation_requirement', '')
project_name = request.form.get('project_name', 'Unnamed Project') project_name = request.form.get('project_name', 'Unnamed Project')
additional_context = request.form.get('additional_context', '') additional_context = request.form.get('additional_context', '')
@ -171,7 +162,6 @@ def generate_ontology():
"error": t("api.error.graph.m009") "error": t("api.error.graph.m009")
}), 400 }), 400
# 获取上传的文件
uploaded_files = request.files.getlist('files') uploaded_files = request.files.getlist('files')
if not uploaded_files or all(not f.filename for f in uploaded_files): if not uploaded_files or all(not f.filename for f in uploaded_files):
return jsonify({ return jsonify({
@ -179,18 +169,17 @@ def generate_ontology():
"error": t("api.error.graph.m010") "error": t("api.error.graph.m010")
}), 400 }), 400
# 创建项目
project = ProjectManager.create_project(name=project_name) project = ProjectManager.create_project(name=project_name)
project.simulation_requirement = simulation_requirement project.simulation_requirement = simulation_requirement
logger.info(t("log.graph_api.m011", project=project.project_id)) logger.info(t("log.graph_api.m011", project=project.project_id))
# 保存文件并提取文本 # Persist each uploaded file under the project's directory and pull its
# text out so the ontology generator has plain text to work with.
document_texts = [] document_texts = []
all_text = "" all_text = ""
for file in uploaded_files: for file in uploaded_files:
if file and file.filename and allowed_file(file.filename): if file and file.filename and allowed_file(file.filename):
# 保存文件到项目目录
file_info = ProjectManager.save_file_to_project( file_info = ProjectManager.save_file_to_project(
project.project_id, project.project_id,
file, file,
@ -201,7 +190,6 @@ def generate_ontology():
"size": file_info["size"] "size": file_info["size"]
}) })
# 提取文本
text = FileParser.extract_text(file_info["path"]) text = FileParser.extract_text(file_info["path"])
text = TextProcessor.preprocess_text(text) text = TextProcessor.preprocess_text(text)
document_texts.append(text) document_texts.append(text)
@ -214,12 +202,10 @@ def generate_ontology():
"error": t("api.error.graph.m012") "error": t("api.error.graph.m012")
}), 400 }), 400
# 保存提取的文本
project.total_text_length = len(all_text) project.total_text_length = len(all_text)
ProjectManager.save_extracted_text(project.project_id, all_text) ProjectManager.save_extracted_text(project.project_id, all_text)
logger.info(t("log.graph_api.m013", len=len(all_text))) logger.info(t("log.graph_api.m013", len=len(all_text)))
# 生成本体
logger.info(t("log.graph_api.m014")) logger.info(t("log.graph_api.m014"))
generator = OntologyGenerator() generator = OntologyGenerator()
ontology = generator.generate( ontology = generator.generate(
@ -228,7 +214,6 @@ def generate_ontology():
additional_context=additional_context if additional_context else None additional_context=additional_context if additional_context else None
) )
# 保存本体到项目
entity_count = len(ontology.get("entity_types", [])) entity_count = len(ontology.get("entity_types", []))
edge_count = len(ontology.get("edge_types", [])) edge_count = len(ontology.get("edge_types", []))
logger.info(t("log.graph_api.m015", entity_count=entity_count, edge_count=edge_count)) logger.info(t("log.graph_api.m015", entity_count=entity_count, edge_count=edge_count))
@ -262,35 +247,33 @@ def generate_ontology():
}), 500 }), 500
# ============== 接口2构建图谱 ============== # ============== Endpoint 2: build graph ==============
@graph_bp.route('/build', methods=['POST']) @graph_bp.route('/build', methods=['POST'])
def build_graph(): def build_graph():
""" """Endpoint 2: build the graph for the given project_id.
接口2根据project_id构建图谱
Request (JSON):
请求JSON
{ {
"project_id": "proj_xxxx", // 必填来自接口1 "project_id": "proj_xxxx", // required, from endpoint 1
"graph_name": "图谱名称", // 可选 "graph_name": "Graph name", // optional
"chunk_size": 500, // 可选默认500 "chunk_size": 500, // optional, default 500
"chunk_overlap": 50 // 可选默认50 "chunk_overlap": 50 // optional, default 50
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"project_id": "proj_xxxx", "project_id": "proj_xxxx",
"task_id": "task_xxxx", "task_id": "task_xxxx",
"message": "图谱构建任务已启动" "message": "Graph build task started"
} }
} }
""" """
try: try:
logger.info(t("log.graph_api.m017")) logger.info(t("log.graph_api.m017"))
# 检查配置
errors = [] errors = []
if not Config.NEO4J_PASSWORD: if not Config.NEO4J_PASSWORD:
errors.append("NEO4J未配置") errors.append("NEO4J未配置")
@ -301,7 +284,6 @@ def build_graph():
"error": "配置错误: " + "; ".join(errors) "error": "配置错误: " + "; ".join(errors)
}), 500 }), 500
# 解析请求
data = request.get_json() or {} data = request.get_json() or {}
project_id = data.get('project_id') project_id = data.get('project_id')
logger.debug(t("log.graph_api.m019", project_id=project_id)) logger.debug(t("log.graph_api.m019", project_id=project_id))
@ -312,7 +294,6 @@ def build_graph():
"error": t("api.error.graph.m020") "error": t("api.error.graph.m020")
}), 400 }), 400
# 获取项目
project = ProjectManager.get_project(project_id) project = ProjectManager.get_project(project_id)
if not project: if not project:
return jsonify({ return jsonify({
@ -320,8 +301,8 @@ def build_graph():
"error": t("api.error.graph.m021", project_id=project_id) "error": t("api.error.graph.m021", project_id=project_id)
}), 404 }), 404
# 检查项目状态 # If True, abandon any existing build progress and rebuild from scratch.
force = data.get('force', False) # 强制重新构建 force = data.get('force', False)
if project.status == ProjectStatus.CREATED: if project.status == ProjectStatus.CREATED:
return jsonify({ return jsonify({
@ -336,23 +317,20 @@ def build_graph():
"task_id": project.graph_build_task_id "task_id": project.graph_build_task_id
}), 400 }), 400
# 如果强制重建,重置状态 # On a forced rebuild, drop any prior build artifacts so we restart cleanly.
if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]: if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
project.status = ProjectStatus.ONTOLOGY_GENERATED project.status = ProjectStatus.ONTOLOGY_GENERATED
project.graph_id = None project.graph_id = None
project.graph_build_task_id = None project.graph_build_task_id = None
project.error = None project.error = None
# 获取配置
graph_name = data.get('graph_name', project.name or 'MiroFish Graph') graph_name = data.get('graph_name', project.name or 'MiroFish Graph')
chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE) chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE)
chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP) chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP)
# 更新项目配置
project.chunk_size = chunk_size project.chunk_size = chunk_size
project.chunk_overlap = chunk_overlap project.chunk_overlap = chunk_overlap
# 获取提取的文本
text = ProjectManager.get_extracted_text(project_id) text = ProjectManager.get_extracted_text(project_id)
if not text: if not text:
return jsonify({ return jsonify({
@ -360,7 +338,6 @@ def build_graph():
"error": t("api.error.graph.m024") "error": t("api.error.graph.m024")
}), 400 }), 400
# 获取本体
ontology = project.ontology ontology = project.ontology
if not ontology: if not ontology:
return jsonify({ return jsonify({
@ -368,17 +345,14 @@ def build_graph():
"error": t("api.error.graph.m025") "error": t("api.error.graph.m025")
}), 400 }), 400
# 创建异步任务
task_manager = TaskManager() task_manager = TaskManager()
task_id = task_manager.create_task(f"构建图谱: {graph_name}") task_id = task_manager.create_task(f"构建图谱: {graph_name}")
logger.info(t("log.graph_api.m026", task_id=task_id, project_id=project_id)) logger.info(t("log.graph_api.m026", task_id=task_id, project_id=project_id))
# 更新项目状态
project.status = ProjectStatus.GRAPH_BUILDING project.status = ProjectStatus.GRAPH_BUILDING
project.graph_build_task_id = task_id project.graph_build_task_id = task_id
ProjectManager.save_project(project) ProjectManager.save_project(project)
# 启动后台任务
def build_task(): def build_task():
build_logger = get_logger('mirofish.build') build_logger = get_logger('mirofish.build')
try: try:
@ -389,10 +363,8 @@ def build_graph():
message="初始化图谱构建服务..." message="初始化图谱构建服务..."
) )
# 创建图谱构建服务
builder = GraphBuilderService() builder = GraphBuilderService()
# 分块
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message="文本分块中...", message="文本分块中...",
@ -404,30 +376,27 @@ def build_graph():
overlap=chunk_overlap overlap=chunk_overlap
) )
total_chunks = len(chunks) total_chunks = len(chunks)
# 创建图谱
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message="创建Zep图谱...", message="创建Zep图谱...",
progress=10 progress=10
) )
graph_id = builder.create_graph(name=graph_name) graph_id = builder.create_graph(name=graph_name)
# 更新项目的graph_id
project.graph_id = graph_id project.graph_id = graph_id
ProjectManager.save_project(project) ProjectManager.save_project(project)
# 设置本体
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message="设置本体定义...", message="设置本体定义...",
progress=15 progress=15
) )
builder.set_ontology(graph_id, ontology) builder.set_ontology(graph_id, ontology)
# 添加文本progress_callback 签名是 (msg, progress_ratio) # Add text. The progress_callback signature is (msg, progress_ratio).
def add_progress_callback(msg, progress_ratio): def add_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 40) # 15% - 55% progress = 15 + int(progress_ratio * 40) # maps ratio onto 15%-55%
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=msg, message=msg,
@ -460,7 +429,7 @@ def build_graph():
skip_chunks=skip_chunks, skip_chunks=skip_chunks,
) )
# 等待Zep处理完成查询每个episode的processed状态 # Wait for Zep to finish processing (poll each episode's processed flag).
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message="等待Zep处理数据...", message="等待Zep处理数据...",
@ -468,7 +437,7 @@ def build_graph():
) )
def wait_progress_callback(msg, progress_ratio): def wait_progress_callback(msg, progress_ratio):
progress = 55 + int(progress_ratio * 35) # 55% - 90% progress = 55 + int(progress_ratio * 35) # maps ratio onto 55%-90%
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=msg, message=msg,
@ -476,16 +445,14 @@ def build_graph():
) )
builder._wait_for_episodes(episode_uuids, wait_progress_callback) builder._wait_for_episodes(episode_uuids, wait_progress_callback)
# 获取图谱数据
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message="获取图谱数据...", message="获取图谱数据...",
progress=95 progress=95
) )
graph_data = builder.get_graph_data(graph_id) graph_data = builder.get_graph_data(graph_id)
# 更新项目状态
project.status = ProjectStatus.GRAPH_COMPLETED project.status = ProjectStatus.GRAPH_COMPLETED
ProjectManager.save_project(project) ProjectManager.save_project(project)
@ -498,8 +465,7 @@ def build_graph():
node_count=node_count, node_count=node_count,
edge_count=edge_count, edge_count=edge_count,
)) ))
# 完成
task_manager.update_task( task_manager.update_task(
task_id, task_id,
status=TaskStatus.COMPLETED, status=TaskStatus.COMPLETED,
@ -515,7 +481,7 @@ def build_graph():
) )
except Exception as e: except Exception as e:
# 更新项目状态为失败 # Mark the project as FAILED so the UI can surface the error.
build_logger.error(t("log.graph_api.m029", task_id=task_id, e=str(e))) build_logger.error(t("log.graph_api.m029", task_id=task_id, e=str(e)))
build_logger.debug(traceback.format_exc()) build_logger.debug(traceback.format_exc())
@ -530,7 +496,6 @@ def build_graph():
error=traceback.format_exc() error=traceback.format_exc()
) )
# 启动后台线程
thread = threading.Thread(target=build_task, daemon=True) thread = threading.Thread(target=build_task, daemon=True)
thread.start() thread.start()
@ -551,13 +516,11 @@ def build_graph():
}), 500 }), 500
# ============== 任务查询接口 ============== # ============== Task query endpoints ==============
@graph_bp.route('/task/<task_id>', methods=['GET']) @graph_bp.route('/task/<task_id>', methods=['GET'])
def get_task(task_id: str): def get_task(task_id: str):
""" """Query the status of a task."""
查询任务状态
"""
task = TaskManager().get_task(task_id) task = TaskManager().get_task(task_id)
if not task: if not task:
@ -574,9 +537,7 @@ def get_task(task_id: str):
@graph_bp.route('/tasks', methods=['GET']) @graph_bp.route('/tasks', methods=['GET'])
def list_tasks(): def list_tasks():
""" """List all tasks."""
列出所有任务
"""
tasks = TaskManager().list_tasks() tasks = TaskManager().list_tasks()
return jsonify({ return jsonify({
@ -586,7 +547,7 @@ def list_tasks():
}) })
# ============== 图谱数据接口 ============== # ============== Graph data endpoints ==============
def _refresh_graph_cache(graph_id: str): def _refresh_graph_cache(graph_id: str):
"""Background thread: fetch graph data from Neo4j and update cache.""" """Background thread: fetch graph data from Neo4j and update cache."""
@ -613,11 +574,11 @@ def _refresh_graph_cache(graph_id: str):
@graph_bp.route('/data/<graph_id>', methods=['GET']) @graph_bp.route('/data/<graph_id>', methods=['GET'])
def get_graph_data(graph_id: str): def get_graph_data(graph_id: str):
""" """Return graph data (nodes and edges).
获取图谱数据节点和边
- 有缓存且未过期直接返回缓存不调用 Zep - Fresh cache: serve from cache without hitting Zep.
- 有缓存但已过期立即返回旧缓存后台异步刷新 - Stale cache: return the old cache immediately and refresh in the background.
- 无缓存后台线程拉取返回 202 让前端稍后重试 - No cache: kick off a background fetch and return 202 so the frontend retries.
""" """
if not Config.NEO4J_PASSWORD: if not Config.NEO4J_PASSWORD:
return jsonify({"success": False, "error": t("api.error.graph.m028")}), 500 return jsonify({"success": False, "error": t("api.error.graph.m028")}), 500
@ -645,9 +606,7 @@ def get_graph_data(graph_id: str):
@graph_bp.route('/delete/<graph_id>', methods=['DELETE']) @graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
def delete_graph(graph_id: str): def delete_graph(graph_id: str):
""" """Delete a Zep graph."""
删除Zep图谱
"""
try: try:
if not Config.NEO4J_PASSWORD: if not Config.NEO4J_PASSWORD:
return jsonify({ return jsonify({
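
The three-way cache policy described in the `get_graph_data` docstring above (fresh cache, stale cache, no cache) resembles a stale-while-revalidate pattern. The following is a minimal sketch of that idea, not the code in this commit: `GraphCache`, `ttl_seconds`, `fetch`, and the `(status, payload)` return shape are hypothetical names chosen for illustration, and `fetch` is assumed to be any callable that returns the graph payload for a `graph_id`.

```python
import threading
import time
from typing import Any, Callable, Dict, Optional, Set, Tuple


class GraphCache:
    """Hypothetical stale-while-revalidate cache (illustrative only)."""

    def __init__(self, fetch: Callable[[str], Any], ttl_seconds: float = 30.0):
        self._fetch = fetch              # assumed: loads graph data for a graph_id
        self._ttl = ttl_seconds          # assumed freshness window
        self._entries: Dict[str, Tuple[float, Any]] = {}
        self._refreshing: Set[str] = set()
        self._lock = threading.Lock()

    def _refresh(self, graph_id: str) -> None:
        # Runs in a background thread: fetch, then store with a timestamp.
        try:
            data = self._fetch(graph_id)
            with self._lock:
                self._entries[graph_id] = (time.monotonic(), data)
        finally:
            with self._lock:
                self._refreshing.discard(graph_id)

    def _refresh_in_background(self, graph_id: str) -> Optional[threading.Thread]:
        # Start at most one refresh per graph_id at a time.
        with self._lock:
            if graph_id in self._refreshing:
                return None
            self._refreshing.add(graph_id)
        thread = threading.Thread(target=self._refresh, args=(graph_id,), daemon=True)
        thread.start()
        return thread

    def get(self, graph_id: str) -> Tuple[int, Any]:
        with self._lock:
            entry = self._entries.get(graph_id)
        if entry is None:
            # No cache: kick off a fetch and tell the caller to retry later.
            self._refresh_in_background(graph_id)
            return 202, None
        stored_at, data = entry
        if time.monotonic() - stored_at > self._ttl:
            # Stale cache: serve the old data now, refresh behind the scenes.
            self._refresh_in_background(graph_id)
        return 200, data
```

In a Flask handler, the returned status could be passed straight through as the HTTP status code, so a client that receives 202 knows to poll again after a short delay.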

View File

@ -1,6 +1,7 @@
""" """
Report API路由 Report API routes.
提供模拟报告生成获取对话等接口
Provides endpoints for generating, retrieving, and chatting about simulation reports.
""" """
import os import os
@ -20,30 +21,30 @@ from ..utils.locale import t, get_locale, set_locale
logger = get_logger('mirofish.api.report') logger = get_logger('mirofish.api.report')
# ============== 报告生成接口 ============== # ============== Report generation endpoints ==============
@report_bp.route('/generate', methods=['POST']) @report_bp.route('/generate', methods=['POST'])
def generate_report(): def generate_report():
""" """
生成模拟分析报告异步任务 Generate a simulation analysis report (asynchronous task).
这是一个耗时操作接口会立即返回task_id This is a long-running operation. The endpoint returns a task_id immediately;
使用 GET /api/report/generate/status 查询进度 use GET /api/report/generate/status to poll progress.
请求JSON Request (JSON):
{ {
"simulation_id": "sim_xxxx", // 必填模拟ID "simulation_id": "sim_xxxx", // required, simulation ID
"force_regenerate": false // 可选强制重新生成 "force_regenerate": false // optional, force regeneration
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"simulation_id": "sim_xxxx", "simulation_id": "sim_xxxx",
"task_id": "task_xxxx", "task_id": "task_xxxx",
"status": "generating", "status": "generating",
"message": "报告生成任务已启动" "message": "Report generation task started"
} }
} }
""" """
@ -58,8 +59,7 @@ def generate_report():
}), 400 }), 400
force_regenerate = data.get('force_regenerate', False) force_regenerate = data.get('force_regenerate', False)
# 获取模拟信息
manager = SimulationManager() manager = SimulationManager()
state = manager.get_simulation(simulation_id) state = manager.get_simulation(simulation_id)
@ -69,7 +69,7 @@ def generate_report():
"error": t('api.simulationNotFound', id=simulation_id) "error": t('api.simulationNotFound', id=simulation_id)
}), 404 }), 404
# 检查是否已有报告 # Skip regeneration if a completed report already exists for this simulation.
if not force_regenerate: if not force_regenerate:
existing_report = ReportManager.get_report_by_simulation(simulation_id) existing_report = ReportManager.get_report_by_simulation(simulation_id)
if existing_report and existing_report.status == ReportStatus.COMPLETED: if existing_report and existing_report.status == ReportStatus.COMPLETED:
@ -84,7 +84,6 @@ def generate_report():
} }
}) })
# 获取项目信息
project = ProjectManager.get_project(state.project_id) project = ProjectManager.get_project(state.project_id)
if not project: if not project:
return jsonify({ return jsonify({
@ -106,11 +105,11 @@ def generate_report():
"error": t('api.missingSimRequirement') "error": t('api.missingSimRequirement')
}), 400 }), 400
# 提前生成 report_id以便立即返回给前端 # Generate report_id eagerly so the frontend can use it immediately
# (before the background task has actually persisted anything).
import uuid import uuid
report_id = f"report_{uuid.uuid4().hex[:12]}" report_id = f"report_{uuid.uuid4().hex[:12]}"
# 创建异步任务
task_manager = TaskManager() task_manager = TaskManager()
task_id = task_manager.create_task( task_id = task_manager.create_task(
task_type="report_generate", task_type="report_generate",
@ -124,7 +123,6 @@ def generate_report():
# Capture locale before spawning background thread # Capture locale before spawning background thread
current_locale = get_locale() current_locale = get_locale()
# 定义后台任务
def run_generate(): def run_generate():
set_locale(current_locale) set_locale(current_locale)
try: try:
@ -134,15 +132,13 @@ def generate_report():
progress=0, progress=0,
message=t('api.initReportAgent') message=t('api.initReportAgent')
) )
# 创建Report Agent
agent = ReportAgent( agent = ReportAgent(
graph_id=graph_id, graph_id=graph_id,
simulation_id=simulation_id, simulation_id=simulation_id,
simulation_requirement=simulation_requirement simulation_requirement=simulation_requirement
) )
# 进度回调
def progress_callback(stage, progress, message): def progress_callback(stage, progress, message):
task_manager.update_task( task_manager.update_task(
task_id, task_id,
@ -150,13 +146,13 @@ def generate_report():
message=f"[{stage}] {message}" message=f"[{stage}] {message}"
) )
# 生成报告(传入预先生成的 report_id # Pass in the pre-generated report_id so the persisted report matches
# the id we already returned to the frontend.
report = agent.generate_report( report = agent.generate_report(
progress_callback=progress_callback, progress_callback=progress_callback,
report_id=report_id report_id=report_id
) )
# 保存报告
ReportManager.save_report(report) ReportManager.save_report(report)
if report.status == ReportStatus.COMPLETED: if report.status == ReportStatus.COMPLETED:
@ -174,8 +170,7 @@ def generate_report():
except Exception as e: except Exception as e:
logger.error(t("log.report_api.m001", str=str(e))) logger.error(t("log.report_api.m001", str=str(e)))
task_manager.fail_task(task_id, str(e)) task_manager.fail_task(task_id, str(e))
# 启动后台线程
thread = threading.Thread(target=run_generate, daemon=True) thread = threading.Thread(target=run_generate, daemon=True)
thread.start() thread.start()
@ -203,15 +198,15 @@ def generate_report():
@report_bp.route('/generate/status', methods=['POST']) @report_bp.route('/generate/status', methods=['POST'])
def get_generate_status(): def get_generate_status():
""" """
查询报告生成任务进度 Query the progress of a report generation task.
请求JSON Request (JSON):
{ {
"task_id": "task_xxxx", // 可选generate返回的task_id "task_id": "task_xxxx", // optional, task_id returned by generate
"simulation_id": "sim_xxxx" // 可选模拟ID "simulation_id": "sim_xxxx" // optional, simulation ID
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -228,7 +223,8 @@ def get_generate_status():
task_id = data.get('task_id') task_id = data.get('task_id')
simulation_id = data.get('simulation_id') simulation_id = data.get('simulation_id')
# 如果提供了simulation_id先检查是否已有完成的报告 # If simulation_id is provided, short-circuit when a completed report already exists
# so callers don't have to track a stale task_id after a successful run.
if simulation_id: if simulation_id:
existing_report = ReportManager.get_report_by_simulation(simulation_id) existing_report = ReportManager.get_report_by_simulation(simulation_id)
if existing_report and existing_report.status == ReportStatus.COMPLETED: if existing_report and existing_report.status == ReportStatus.COMPLETED:
@ -272,14 +268,14 @@ def get_generate_status():
}), 500 }), 500
# ============== 报告获取接口 ============== # ============== Report retrieval endpoints ==============
@report_bp.route('/<report_id>', methods=['GET']) @report_bp.route('/<report_id>', methods=['GET'])
def get_report(report_id: str): def get_report(report_id: str):
""" """
获取报告详情 Get report details.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -319,9 +315,9 @@ def get_report(report_id: str):
@report_bp.route('/by-simulation/<simulation_id>', methods=['GET']) @report_bp.route('/by-simulation/<simulation_id>', methods=['GET'])
def get_report_by_simulation(simulation_id: str): def get_report_by_simulation(simulation_id: str):
""" """
根据模拟ID获取报告 Get the report for a given simulation ID.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -358,13 +354,13 @@ def get_report_by_simulation(simulation_id: str):
@report_bp.route('/list', methods=['GET']) @report_bp.route('/list', methods=['GET'])
def list_reports(): def list_reports():
""" """
列出所有报告 List all reports.
Query参数 Query parameters:
simulation_id: 按模拟ID过滤可选 simulation_id: optional filter by simulation ID.
limit: 返回数量限制默认50 limit: maximum number of reports to return (default 50).
返回 Returns:
{ {
"success": true, "success": true,
"data": [...], "data": [...],
@ -398,9 +394,9 @@ def list_reports():
@report_bp.route('/<report_id>/download', methods=['GET']) @report_bp.route('/<report_id>/download', methods=['GET'])
def download_report(report_id: str): def download_report(report_id: str):
""" """
下载报告Markdown格式 Download a report as a Markdown file.
返回Markdown文件 Returns the Markdown file as an attachment.
""" """
try: try:
report = ReportManager.get_report(report_id) report = ReportManager.get_report(report_id)
@ -414,7 +410,8 @@ def download_report(report_id: str):
md_path = ReportManager._get_report_markdown_path(report_id) md_path = ReportManager._get_report_markdown_path(report_id)
if not os.path.exists(md_path): if not os.path.exists(md_path):
# 如果MD文件不存在生成一个临时文件 # MD file is missing on disk; materialize a temp file from the in-memory content
# so the download still succeeds for older reports that were never persisted.
import tempfile import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f: with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
f.write(report.markdown_content) f.write(report.markdown_content)
@ -443,7 +440,7 @@ def download_report(report_id: str):
@report_bp.route('/<report_id>', methods=['DELETE']) @report_bp.route('/<report_id>', methods=['DELETE'])
def delete_report(report_id: str): def delete_report(report_id: str):
"""删除报告""" """Delete a report."""
try: try:
success = ReportManager.delete_report(report_id) success = ReportManager.delete_report(report_id)
@ -467,32 +464,33 @@ def delete_report(report_id: str):
}), 500 }), 500
# ============== Report Agent对话接口 ============== # ============== Report Agent chat endpoints ==============
@report_bp.route('/chat', methods=['POST']) @report_bp.route('/chat', methods=['POST'])
def chat_with_report_agent(): def chat_with_report_agent():
""" """
与Report Agent对话 Chat with the Report Agent.
Report Agent可以在对话中自主调用检索工具来回答问题 The Report Agent can autonomously invoke retrieval tools during the conversation
to answer the user's question.
请求JSON
Request (JSON):
{ {
"simulation_id": "sim_xxxx", // 必填模拟ID "simulation_id": "sim_xxxx", // required, simulation ID
"message": "请解释一下舆情走向", // 必填用户消息 "message": "Explain the sentiment trend", // required, user message
"chat_history": [ // 可选对话历史 "chat_history": [ // optional, prior turns
{"role": "user", "content": "..."}, {"role": "user", "content": "..."},
{"role": "assistant", "content": "..."} {"role": "assistant", "content": "..."}
] ]
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"response": "Agent回复...", "response": "Agent reply...",
"tool_calls": [调用的工具列表], "tool_calls": [list of tools invoked],
"sources": [信息来源] "sources": [information sources]
} }
} }
""" """
@ -515,7 +513,6 @@ def chat_with_report_agent():
"error": t('api.requireMessage') "error": t('api.requireMessage')
}), 400 }), 400
# 获取模拟和项目信息
manager = SimulationManager() manager = SimulationManager()
state = manager.get_simulation(simulation_id) state = manager.get_simulation(simulation_id)
@ -540,8 +537,7 @@ def chat_with_report_agent():
}), 400 }), 400
simulation_requirement = project.simulation_requirement or "" simulation_requirement = project.simulation_requirement or ""
# 创建Agent并进行对话
agent = ReportAgent( agent = ReportAgent(
graph_id=graph_id, graph_id=graph_id,
simulation_id=simulation_id, simulation_id=simulation_id,
@ -564,22 +560,22 @@ def chat_with_report_agent():
}), 500 }), 500
# ============== 报告进度与分章节接口 ============== # ============== Report progress and section endpoints ==============
@report_bp.route('/<report_id>/progress', methods=['GET']) @report_bp.route('/<report_id>/progress', methods=['GET'])
def get_report_progress(report_id: str): def get_report_progress(report_id: str):
""" """
获取报告生成进度实时 Get real-time report generation progress.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"status": "generating", "status": "generating",
"progress": 45, "progress": 45,
"message": "正在生成章节: 关键发现", "message": "Generating section: Key Findings",
"current_section": "关键发现", "current_section": "Key Findings",
"completed_sections": ["执行摘要", "模拟背景"], "completed_sections": ["Executive Summary", "Simulation Background"],
"updated_at": "2025-12-09T..." "updated_at": "2025-12-09T..."
} }
} }
@ -610,11 +606,12 @@ def get_report_progress(report_id: str):
@report_bp.route('/<report_id>/sections', methods=['GET']) @report_bp.route('/<report_id>/sections', methods=['GET'])
def get_report_sections(report_id: str): def get_report_sections(report_id: str):
""" """
获取已生成的章节列表分章节输出 Get the list of sections generated so far (per-section streaming output).
前端可以轮询此接口获取已生成的章节内容无需等待整个报告完成 The frontend can poll this endpoint to render sections incrementally,
without waiting for the entire report to finish.
返回
Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -623,7 +620,7 @@ def get_report_sections(report_id: str):
{ {
"filename": "section_01.md", "filename": "section_01.md",
"section_index": 1, "section_index": 1,
"content": "## 执行摘要\\n\\n..." "content": "## Executive Summary\\n\\n..."
}, },
... ...
], ],
@ -634,8 +631,7 @@ def get_report_sections(report_id: str):
""" """
try: try:
sections = ReportManager.get_generated_sections(report_id) sections = ReportManager.get_generated_sections(report_id)
# 获取报告状态
report = ReportManager.get_report(report_id) report = ReportManager.get_report(report_id)
is_complete = report is not None and report.status == ReportStatus.COMPLETED is_complete = report is not None and report.status == ReportStatus.COMPLETED
@ -661,14 +657,14 @@ def get_report_sections(report_id: str):
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET']) @report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
def get_single_section(report_id: str, section_index: int): def get_single_section(report_id: str, section_index: int):
""" """
获取单个章节内容 Get the content of a single section.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"filename": "section_01.md", "filename": "section_01.md",
"content": "## 执行摘要\\n\\n..." "content": "## Executive Summary\\n\\n..."
} }
} }
""" """
@ -702,16 +698,16 @@ def get_single_section(report_id: str, section_index: int):
}), 500 }), 500
# ============== 报告状态检查接口 ============== # ============== Report status check endpoints ==============
@report_bp.route('/check/<simulation_id>', methods=['GET']) @report_bp.route('/check/<simulation_id>', methods=['GET'])
def check_report_status(simulation_id: str): def check_report_status(simulation_id: str):
""" """
检查模拟是否有报告以及报告状态 Check whether a simulation has a report, and report its status.
用于前端判断是否解锁Interview功能 Used by the frontend to decide whether to unlock the Interview feature.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -730,7 +726,7 @@ def check_report_status(simulation_id: str):
report_status = report.status.value if report else None report_status = report.status.value if report else None
report_id = report.report_id if report else None report_id = report.report_id if report else None
# 只有报告完成后才解锁interview # Interview feature is only unlocked once a report has finished generating.
interview_unlocked = has_report and report.status == ReportStatus.COMPLETED interview_unlocked = has_report and report.status == ReportStatus.COMPLETED
return jsonify({ return jsonify({
@ -753,22 +749,22 @@ def check_report_status(simulation_id: str):
}), 500 }), 500
# ============== Agent 日志接口 ============== # ============== Agent log endpoints ==============
@report_bp.route('/<report_id>/agent-log', methods=['GET']) @report_bp.route('/<report_id>/agent-log', methods=['GET'])
def get_agent_log(report_id: str): def get_agent_log(report_id: str):
""" """
获取 Report Agent 的详细执行日志 Get the detailed execution log of the Report Agent.
实时获取报告生成过程中的每一步动作包括 Streams every step the agent took while generating the report, including:
- 报告开始规划开始/完成 - Report start, planning start/complete.
- 每个章节的开始工具调用LLM响应完成 - Per-section start, tool calls, LLM responses, and completion.
- 报告完成或失败 - Final report completion or failure.
Query参数 Query parameters:
from_line: 从第几行开始读取可选默认0用于增量获取 from_line: line offset to start reading from (optional, default 0, for incremental polling).
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -779,7 +775,7 @@ def get_agent_log(report_id: str):
"report_id": "report_xxxx", "report_id": "report_xxxx",
"action": "tool_call", "action": "tool_call",
"stage": "generating", "stage": "generating",
"section_title": "执行摘要", "section_title": "Executive Summary",
"section_index": 1, "section_index": 1,
"details": { "details": {
"tool_name": "insight_forge", "tool_name": "insight_forge",
@ -817,9 +813,9 @@ def get_agent_log(report_id: str):
@report_bp.route('/<report_id>/agent-log/stream', methods=['GET']) @report_bp.route('/<report_id>/agent-log/stream', methods=['GET'])
def stream_agent_log(report_id: str): def stream_agent_log(report_id: str):
""" """
获取完整的 Agent 日志一次性获取全部 Get the full Agent log in one shot (no pagination).
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -848,27 +844,27 @@ def stream_agent_log(report_id: str):
}), 500 }), 500
# ============== 控制台日志接口 ============== # ============== Console log endpoints ==============
@report_bp.route('/<report_id>/console-log', methods=['GET']) @report_bp.route('/<report_id>/console-log', methods=['GET'])
def get_console_log(report_id: str): def get_console_log(report_id: str):
""" """
获取 Report Agent 的控制台输出日志 Get the Report Agent's console output log.
实时获取报告生成过程中的控制台输出INFOWARNING等 Streams the console output produced during report generation (INFO, WARNING, etc.).
这与 agent-log 接口返回的结构化 JSON 日志不同 Unlike the structured JSON returned by the agent-log endpoint, this is plain-text
是纯文本格式的控制台风格日志 console-style output.
Query参数 Query parameters:
from_line: 从第几行开始读取可选默认0用于增量获取 from_line: line offset to start reading from (optional, default 0, for incremental polling).
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"logs": [ "logs": [
"[19:46:14] INFO: 搜索完成: 找到 15 条相关事实", "[19:46:14] INFO: Search complete: found 15 relevant facts",
"[19:46:14] INFO: 图谱搜索: graph_id=xxx, query=...", "[19:46:14] INFO: Graph search: graph_id=xxx, query=...",
... ...
], ],
"total_lines": 100, "total_lines": 100,
@ -899,9 +895,9 @@ def get_console_log(report_id: str):
@report_bp.route('/<report_id>/console-log/stream', methods=['GET']) @report_bp.route('/<report_id>/console-log/stream', methods=['GET'])
def stream_console_log(report_id: str): def stream_console_log(report_id: str):
""" """
获取完整的控制台日志一次性获取全部 Get the full console log in one shot (no pagination).
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -930,17 +926,17 @@ def stream_console_log(report_id: str):
}), 500 }), 500
# ============== 工具调用接口(供调试使用)============== # ============== Tool invocation endpoints (for debugging) ==============
@report_bp.route('/tools/search', methods=['POST']) @report_bp.route('/tools/search', methods=['POST'])
def search_graph_tool(): def search_graph_tool():
""" """
图谱搜索工具接口供调试使用 Graph search tool endpoint (for debugging).
请求JSON Request (JSON):
{ {
"graph_id": "mirofish_xxxx", "graph_id": "mirofish_xxxx",
"query": "搜索查询", "query": "search query",
"limit": 10 "limit": 10
} }
""" """
@ -983,9 +979,9 @@ def search_graph_tool():
@report_bp.route('/tools/statistics', methods=['POST']) @report_bp.route('/tools/statistics', methods=['POST'])
def get_graph_statistics_tool(): def get_graph_statistics_tool():
""" """
图谱统计工具接口供调试使用 Graph statistics tool endpoint (for debugging).
请求JSON Request (JSON):
{ {
"graph_id": "mirofish_xxxx" "graph_id": "mirofish_xxxx"
} }

File diff suppressed because it is too large


@ -1,38 +1,40 @@
""" """Configuration management.
配置管理
统一从项目根目录的 .env 文件加载配置 Loads configuration values from the project-root ``.env`` file.
""" """
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
# 加载项目根目录的 .env 文件 # Load the project-root .env file.
# 路径: MiroFish/.env (相对于 backend/app/config.py) # Path: MiroFish/.env (relative to backend/app/config.py).
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env') project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
if os.path.exists(project_root_env): if os.path.exists(project_root_env):
load_dotenv(project_root_env, override=True) load_dotenv(project_root_env, override=True)
else: else:
# 如果根目录没有 .env尝试加载环境变量用于生产环境 # If the project root has no .env, fall back to the process environment
# (used in production deployments).
load_dotenv(override=True) load_dotenv(override=True)
class Config: class Config:
"""Flask配置类""" """Flask configuration class."""
# Flask配置 # Flask settings.
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key') SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true' DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
# JSON配置 - 禁用ASCII转义让中文直接显示而不是 \uXXXX 格式) # JSON settings: disable ASCII escaping so non-ASCII output renders literally
# rather than as \uXXXX escape sequences.
JSON_AS_ASCII = False JSON_AS_ASCII = False
# LLM配置统一使用OpenAI格式 # LLM settings (called via the OpenAI-compatible API surface).
LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_API_KEY = os.environ.get('LLM_API_KEY')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
# Neo4j + Graphiti配置(替代 Zep Cloud # Neo4j + Graphiti settings (replacement for Zep Cloud).
NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j') NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD', 'mirofish123') NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD', 'mirofish123')
@ -50,23 +52,23 @@ class Config:
EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY') EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY')
EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL') EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL')
# Zep配置(保留兼容性,已废弃) # Zep settings (kept for backwards compatibility; deprecated).
ZEP_API_KEY = os.environ.get('ZEP_API_KEY', '') ZEP_API_KEY = os.environ.get('ZEP_API_KEY', '')
# 文件上传配置 # File upload settings.
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads') UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'} ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
# 文本处理配置 # Text processing settings.
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小 DEFAULT_CHUNK_SIZE = 500 # default chunk size in characters
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小 DEFAULT_CHUNK_OVERLAP = 50 # default overlap in characters
# OASIS模拟配置 # OASIS simulation settings.
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations') OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
# OASIS平台可用动作配置 # OASIS per-platform allowed action lists.
OASIS_TWITTER_ACTIONS = [ OASIS_TWITTER_ACTIONS = [
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST' 'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
] ]
@ -76,14 +78,14 @@ class Config:
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE' 'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
] ]
# Report Agent配置 # Report agent settings.
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5')) REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')) REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5')) REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
@classmethod @classmethod
def validate(cls): def validate(cls):
"""验证必要配置""" """Validate that required configuration values are present."""
errors = [] errors = []
if not cls.LLM_API_KEY: if not cls.LLM_API_KEY:
errors.append("LLM_API_KEY 未配置") errors.append("LLM_API_KEY 未配置")
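
The `Config` class above reads typed settings from the environment with string defaults and collects missing required values in `validate`. A minimal sketch of that pattern follows, assuming a hypothetical `load_settings` helper and hypothetical `DEMO_*` variable names that are not part of this codebase:

```python
import os


def load_settings(env=None):
    """Read typed settings from a mapping (defaults to os.environ).

    Hypothetical helper for illustration; the DEMO_* names are assumptions.
    """
    env = os.environ if env is None else env
    return {
        # Booleans are parsed by comparing the lower-cased string to 'true'.
        "debug": env.get("DEMO_DEBUG", "True").lower() == "true",
        # Numeric settings use a string default so int()/float() always get a str.
        "max_rounds": int(env.get("DEMO_MAX_ROUNDS", "10")),
        "temperature": float(env.get("DEMO_TEMPERATURE", "0.5")),
        # Required value with no default; checked in validate().
        "api_key": env.get("DEMO_API_KEY"),
    }


def validate(settings):
    """Return a list of human-readable errors for missing required values."""
    errors = []
    if not settings["api_key"]:
        errors.append("DEMO_API_KEY is not configured")
    return errors
```

Passing the mapping explicitly keeps the parsing testable without mutating the process environment.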


@ -1,6 +1,4 @@
""" """Data model package."""
数据模型模块
"""
from .task import TaskManager, TaskStatus from .task import TaskManager, TaskStatus
from .project import Project, ProjectStatus, ProjectManager from .project import Project, ProjectStatus, ProjectManager


@ -1,6 +1,7 @@
""" """Project context management.
项目上下文管理
用于在服务端持久化项目状态避免前端在接口间传递大量数据 Persists project state on the server so the frontend does not have to round-trip
large blobs of context between API calls.
""" """
import os import os
@ -15,45 +16,45 @@ from ..config import Config
class ProjectStatus(str, Enum): class ProjectStatus(str, Enum):
"""项目状态""" """Project lifecycle status."""
CREATED = "created" # 刚创建,文件已上传 CREATED = "created" # just created, files uploaded
ONTOLOGY_GENERATED = "ontology_generated" # 本体已生成 ONTOLOGY_GENERATED = "ontology_generated" # ontology has been generated
GRAPH_BUILDING = "graph_building" # 图谱构建中 GRAPH_BUILDING = "graph_building" # graph build in progress
GRAPH_COMPLETED = "graph_completed" # 图谱构建完成 GRAPH_COMPLETED = "graph_completed" # graph build finished
FAILED = "failed" # 失败 FAILED = "failed" # build failed
@dataclass @dataclass
class Project: class Project:
"""项目数据模型""" """Project data model."""
project_id: str project_id: str
name: str name: str
status: ProjectStatus status: ProjectStatus
created_at: str created_at: str
updated_at: str updated_at: str
# 文件信息 # File information
files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}] files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}]
total_text_length: int = 0 total_text_length: int = 0
# 本体信息接口1生成后填充 # Ontology information (filled in after step 1 generates it)
ontology: Optional[Dict[str, Any]] = None ontology: Optional[Dict[str, Any]] = None
analysis_summary: Optional[str] = None analysis_summary: Optional[str] = None
# 图谱信息接口2完成后填充 # Graph information (filled in after step 2 finishes)
graph_id: Optional[str] = None graph_id: Optional[str] = None
graph_build_task_id: Optional[str] = None graph_build_task_id: Optional[str] = None
# 配置 # Configuration
simulation_requirement: Optional[str] = None simulation_requirement: Optional[str] = None
chunk_size: int = 500 chunk_size: int = 500
chunk_overlap: int = 50 chunk_overlap: int = 50
# 错误信息 # Error message when status == FAILED
error: Optional[str] = None error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Serialize the project to a JSON-friendly dict."""
return { return {
"project_id": self.project_id, "project_id": self.project_id,
"name": self.name, "name": self.name,
@ -71,14 +72,14 @@ class Project:
"chunk_overlap": self.chunk_overlap, "chunk_overlap": self.chunk_overlap,
"error": self.error "error": self.error
} }
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Project': def from_dict(cls, data: Dict[str, Any]) -> 'Project':
"""从字典创建""" """Reconstruct a project from its serialized dict."""
status = data.get('status', 'created') status = data.get('status', 'created')
if isinstance(status, str): if isinstance(status, str):
status = ProjectStatus(status) status = ProjectStatus(status)
return cls( return cls(
project_id=data['project_id'], project_id=data['project_id'],
name=data.get('name', 'Unnamed Project'), name=data.get('name', 'Unnamed Project'),
@ -99,52 +100,51 @@ class Project:
class ProjectManager: class ProjectManager:
"""项目管理器 - 负责项目的持久化存储和检索""" """Project manager: handles persistence and retrieval of projects on disk."""
# 项目存储根目录 # Root directory for project storage
PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects') PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects')
@classmethod @classmethod
def _ensure_projects_dir(cls): def _ensure_projects_dir(cls):
"""确保项目目录存在""" """Ensure the projects root directory exists."""
os.makedirs(cls.PROJECTS_DIR, exist_ok=True) os.makedirs(cls.PROJECTS_DIR, exist_ok=True)
@classmethod @classmethod
def _get_project_dir(cls, project_id: str) -> str: def _get_project_dir(cls, project_id: str) -> str:
"""获取项目目录路径""" """Return the on-disk directory for a project."""
return os.path.join(cls.PROJECTS_DIR, project_id) return os.path.join(cls.PROJECTS_DIR, project_id)
@classmethod @classmethod
def _get_project_meta_path(cls, project_id: str) -> str: def _get_project_meta_path(cls, project_id: str) -> str:
"""获取项目元数据文件路径""" """Return the path to a project's metadata JSON file."""
return os.path.join(cls._get_project_dir(project_id), 'project.json') return os.path.join(cls._get_project_dir(project_id), 'project.json')
@classmethod @classmethod
def _get_project_files_dir(cls, project_id: str) -> str: def _get_project_files_dir(cls, project_id: str) -> str:
"""获取项目文件存储目录""" """Return the directory where project source files are stored."""
return os.path.join(cls._get_project_dir(project_id), 'files') return os.path.join(cls._get_project_dir(project_id), 'files')
@classmethod @classmethod
def _get_project_text_path(cls, project_id: str) -> str: def _get_project_text_path(cls, project_id: str) -> str:
"""获取项目提取文本存储路径""" """Return the path to a project's extracted text file."""
return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt') return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt')
@classmethod @classmethod
def create_project(cls, name: str = "Unnamed Project") -> Project: def create_project(cls, name: str = "Unnamed Project") -> Project:
""" """Create a new project.
创建新项目
Args: Args:
name: 项目名称 name: Display name for the project.
Returns: Returns:
新创建的Project对象 The newly created ``Project`` instance.
""" """
cls._ensure_projects_dir() cls._ensure_projects_dir()
project_id = f"proj_{uuid.uuid4().hex[:12]}" project_id = f"proj_{uuid.uuid4().hex[:12]}"
now = datetime.now().isoformat() now = datetime.now().isoformat()
project = Project( project = Project(
project_id=project_id, project_id=project_id,
name=name, name=name,
@ -152,154 +152,147 @@ class ProjectManager:
created_at=now, created_at=now,
updated_at=now updated_at=now
) )
# 创建项目目录结构 # Create the on-disk project directory layout
project_dir = cls._get_project_dir(project_id) project_dir = cls._get_project_dir(project_id)
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
os.makedirs(project_dir, exist_ok=True) os.makedirs(project_dir, exist_ok=True)
os.makedirs(files_dir, exist_ok=True) os.makedirs(files_dir, exist_ok=True)
# 保存项目元数据 # Persist project metadata
cls.save_project(project) cls.save_project(project)
return project return project
@classmethod @classmethod
def save_project(cls, project: Project) -> None: def save_project(cls, project: Project) -> None:
"""保存项目元数据""" """Persist project metadata to disk."""
project.updated_at = datetime.now().isoformat() project.updated_at = datetime.now().isoformat()
meta_path = cls._get_project_meta_path(project.project_id) meta_path = cls._get_project_meta_path(project.project_id)
with open(meta_path, 'w', encoding='utf-8') as f: with open(meta_path, 'w', encoding='utf-8') as f:
json.dump(project.to_dict(), f, ensure_ascii=False, indent=2) json.dump(project.to_dict(), f, ensure_ascii=False, indent=2)
@classmethod @classmethod
def get_project(cls, project_id: str) -> Optional[Project]: def get_project(cls, project_id: str) -> Optional[Project]:
""" """Load a project by id.
获取项目
Args: Args:
project_id: 项目ID project_id: Project identifier.
Returns: Returns:
Project对象如果不存在返回None The ``Project`` if it exists, otherwise ``None``.
""" """
meta_path = cls._get_project_meta_path(project_id) meta_path = cls._get_project_meta_path(project_id)
if not os.path.exists(meta_path): if not os.path.exists(meta_path):
return None return None
with open(meta_path, 'r', encoding='utf-8') as f: with open(meta_path, 'r', encoding='utf-8') as f:
data = json.load(f) data = json.load(f)
return Project.from_dict(data) return Project.from_dict(data)
@classmethod @classmethod
def list_projects(cls, limit: int = 50) -> List[Project]: def list_projects(cls, limit: int = 50) -> List[Project]:
""" """List existing projects, newest first.
列出所有项目
Args: Args:
limit: 返回数量限制 limit: Maximum number of projects to return.
Returns: Returns:
项目列表按创建时间倒序 Projects ordered by ``created_at`` descending.
""" """
cls._ensure_projects_dir() cls._ensure_projects_dir()
projects = [] projects = []
for project_id in os.listdir(cls.PROJECTS_DIR): for project_id in os.listdir(cls.PROJECTS_DIR):
project = cls.get_project(project_id) project = cls.get_project(project_id)
if project: if project:
projects.append(project) projects.append(project)
# 按创建时间倒序排序
projects.sort(key=lambda p: p.created_at, reverse=True) projects.sort(key=lambda p: p.created_at, reverse=True)
return projects[:limit] return projects[:limit]
@classmethod @classmethod
def delete_project(cls, project_id: str) -> bool: def delete_project(cls, project_id: str) -> bool:
""" """Delete a project and all of its files.
删除项目及其所有文件
Args: Args:
project_id: 项目ID project_id: Project identifier.
Returns: Returns:
是否删除成功 ``True`` if the project existed and was removed, ``False`` otherwise.
""" """
project_dir = cls._get_project_dir(project_id) project_dir = cls._get_project_dir(project_id)
if not os.path.exists(project_dir): if not os.path.exists(project_dir):
return False return False
shutil.rmtree(project_dir) shutil.rmtree(project_dir)
return True return True
@classmethod @classmethod
def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]: def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]:
""" """Save an uploaded file under the project's files directory.
保存上传的文件到项目目录
Args: Args:
project_id: 项目ID project_id: Project identifier.
file_storage: Flask的FileStorage对象 file_storage: Flask ``FileStorage`` object from the request.
original_filename: 原始文件名 original_filename: The user-supplied filename.
Returns: Returns:
文件信息字典 {filename, path, size} Dict describing the saved file: ``{original_filename, saved_filename, path, size}``.
""" """
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
os.makedirs(files_dir, exist_ok=True) os.makedirs(files_dir, exist_ok=True)
# 生成安全的文件名 # Generate a safe randomized filename to avoid collisions
ext = os.path.splitext(original_filename)[1].lower() ext = os.path.splitext(original_filename)[1].lower()
safe_filename = f"{uuid.uuid4().hex[:8]}{ext}" safe_filename = f"{uuid.uuid4().hex[:8]}{ext}"
file_path = os.path.join(files_dir, safe_filename) file_path = os.path.join(files_dir, safe_filename)
# 保存文件
file_storage.save(file_path) file_storage.save(file_path)
# 获取文件大小
file_size = os.path.getsize(file_path) file_size = os.path.getsize(file_path)
return { return {
"original_filename": original_filename, "original_filename": original_filename,
"saved_filename": safe_filename, "saved_filename": safe_filename,
"path": file_path, "path": file_path,
"size": file_size "size": file_size
} }
@classmethod @classmethod
def save_extracted_text(cls, project_id: str, text: str) -> None: def save_extracted_text(cls, project_id: str, text: str) -> None:
"""保存提取的文本""" """Persist the project's extracted full text to disk."""
text_path = cls._get_project_text_path(project_id) text_path = cls._get_project_text_path(project_id)
with open(text_path, 'w', encoding='utf-8') as f: with open(text_path, 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
@classmethod @classmethod
def get_extracted_text(cls, project_id: str) -> Optional[str]: def get_extracted_text(cls, project_id: str) -> Optional[str]:
"""获取提取的文本""" """Read back the project's extracted full text, or ``None`` if absent."""
text_path = cls._get_project_text_path(project_id) text_path = cls._get_project_text_path(project_id)
if not os.path.exists(text_path): if not os.path.exists(text_path):
return None return None
with open(text_path, 'r', encoding='utf-8') as f: with open(text_path, 'r', encoding='utf-8') as f:
return f.read() return f.read()
@classmethod @classmethod
def get_project_files(cls, project_id: str) -> List[str]: def get_project_files(cls, project_id: str) -> List[str]:
"""获取项目的所有文件路径""" """Return the on-disk paths of all files in the project."""
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
if not os.path.exists(files_dir): if not os.path.exists(files_dir):
return [] return []
return [ return [
os.path.join(files_dir, f) os.path.join(files_dir, f)
for f in os.listdir(files_dir) for f in os.listdir(files_dir)
if os.path.isfile(os.path.join(files_dir, f)) if os.path.isfile(os.path.join(files_dir, f))
] ]
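
The `to_dict` / `from_dict` pair on `Project` serializes a dataclass with a `str`-based enum to a JSON-friendly dict and rebuilds it later. A reduced sketch of that round trip is shown here; `Record` and `Status` are hypothetical stand-ins, not the project's real classes:

```python
import json
from dataclasses import dataclass, asdict
from enum import Enum


class Status(str, Enum):
    """Hypothetical two-state lifecycle used only for this sketch."""
    CREATED = "created"
    FAILED = "failed"


@dataclass
class Record:
    record_id: str
    status: Status = Status.CREATED

    def to_dict(self):
        data = asdict(self)
        # Store the plain string value so the dict is JSON-friendly.
        data["status"] = self.status.value
        return data

    @classmethod
    def from_dict(cls, data):
        # Missing status falls back to 'created', mirroring the default above.
        return cls(
            record_id=data["record_id"],
            status=Status(data.get("status", "created")),
        )


# Round trip through JSON text, as a metadata file on disk would.
restored = Record.from_dict(json.loads(json.dumps(Record("r1", Status.FAILED).to_dict())))
```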


@ -1,6 +1,6 @@
""" """Task state management.
任务状态管理
用于跟踪长时间运行的任务如图谱构建 Tracks long-running tasks (e.g. graph build) so callers can poll progress.
""" """
import uuid import uuid
@ -14,30 +14,30 @@ from ..utils.locale import t
class TaskStatus(str, Enum): class TaskStatus(str, Enum):
"""任务状态枚举""" """Task status enum."""
PENDING = "pending" # 等待中 PENDING = "pending" # waiting
PROCESSING = "processing" # 处理中 PROCESSING = "processing" # in progress
COMPLETED = "completed" # 已完成 COMPLETED = "completed" # finished successfully
FAILED = "failed" # 失败 FAILED = "failed" # finished with error
@dataclass @dataclass
class Task: class Task:
"""任务数据类""" """Task data class."""
task_id: str task_id: str
task_type: str task_type: str
status: TaskStatus status: TaskStatus
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
progress: int = 0 # 总进度百分比 0-100 progress: int = 0 # overall progress percentage 0-100
message: str = "" # 状态消息 message: str = "" # human-readable status message
result: Optional[Dict] = None # 任务结果 result: Optional[Dict] = None # task result payload
error: Optional[str] = None # 错误信息 error: Optional[str] = None # error message when failed
metadata: Dict = field(default_factory=dict) # 额外元数据 metadata: Dict = field(default_factory=dict) # arbitrary caller metadata
progress_detail: Dict = field(default_factory=dict) # 详细进度信息 progress_detail: Dict = field(default_factory=dict) # fine-grained progress info
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Serialize the task to a JSON-friendly dict."""
return { return {
"task_id": self.task_id, "task_id": self.task_id,
"task_type": self.task_type, "task_type": self.task_type,
@ -54,16 +54,12 @@ class Task:
class TaskManager: class TaskManager:
""" """Thread-safe singleton task registry."""
任务管理器
线程安全的任务状态管理
"""
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls): def __new__(cls):
"""单例模式"""
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
@ -71,21 +67,20 @@ class TaskManager:
cls._instance._tasks: Dict[str, Task] = {} cls._instance._tasks: Dict[str, Task] = {}
cls._instance._task_lock = threading.Lock() cls._instance._task_lock = threading.Lock()
return cls._instance return cls._instance
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str: def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
""" """Create a new task.
创建新任务
Args: Args:
task_type: 任务类型 task_type: Task type identifier.
metadata: 额外元数据 metadata: Optional caller-supplied metadata.
Returns: Returns:
任务ID The newly created task id.
""" """
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
now = datetime.now() now = datetime.now()
task = Task( task = Task(
task_id=task_id, task_id=task_id,
task_type=task_type, task_type=task_type,
@ -94,17 +89,17 @@ class TaskManager:
updated_at=now, updated_at=now,
metadata=metadata or {} metadata=metadata or {}
) )
with self._task_lock: with self._task_lock:
self._tasks[task_id] = task self._tasks[task_id] = task
return task_id return task_id
def get_task(self, task_id: str) -> Optional[Task]: def get_task(self, task_id: str) -> Optional[Task]:
"""获取任务""" """Return the task for ``task_id`` or ``None`` if unknown."""
with self._task_lock: with self._task_lock:
return self._tasks.get(task_id) return self._tasks.get(task_id)
def update_task( def update_task(
self, self,
task_id: str, task_id: str,
@ -115,17 +110,16 @@ class TaskManager:
error: Optional[str] = None, error: Optional[str] = None,
progress_detail: Optional[Dict] = None progress_detail: Optional[Dict] = None
): ):
""" """Update mutable fields on an existing task.
更新任务状态
Args: Args:
task_id: 任务ID task_id: Task id to update.
status: 新状态 status: New status, if changing.
progress: 进度 progress: New overall progress (0-100), if changing.
message: 消息 message: New status message, if changing.
result: 结果 result: New result payload, if changing.
error: 错误信息 error: New error message, if changing.
progress_detail: 详细进度信息 progress_detail: New fine-grained progress info, if changing.
""" """
with self._task_lock: with self._task_lock:
task = self._tasks.get(task_id) task = self._tasks.get(task_id)
@ -143,9 +137,9 @@ class TaskManager:
task.error = error task.error = error
if progress_detail is not None: if progress_detail is not None:
task.progress_detail = progress_detail task.progress_detail = progress_detail
def complete_task(self, task_id: str, result: Dict): def complete_task(self, task_id: str, result: Dict):
"""标记任务完成""" """Mark a task as completed and attach the result."""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.COMPLETED, status=TaskStatus.COMPLETED,
@ -153,29 +147,29 @@ class TaskManager:
message=t('progress.taskComplete'), message=t('progress.taskComplete'),
result=result result=result
) )
def fail_task(self, task_id: str, error: str): def fail_task(self, task_id: str, error: str):
"""标记任务失败""" """Mark a task as failed and attach the error message."""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.FAILED, status=TaskStatus.FAILED,
message=t('progress.taskFailed'), message=t('progress.taskFailed'),
error=error error=error
) )
def list_tasks(self, task_type: Optional[str] = None) -> list: def list_tasks(self, task_type: Optional[str] = None) -> list:
"""列出任务""" """List tasks, optionally filtered by ``task_type``, newest first."""
with self._task_lock: with self._task_lock:
tasks = list(self._tasks.values()) tasks = list(self._tasks.values())
if task_type: if task_type:
tasks = [t for t in tasks if t.task_type == task_type] tasks = [t for t in tasks if t.task_type == task_type]
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)] return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
def cleanup_old_tasks(self, max_age_hours: int = 24): def cleanup_old_tasks(self, max_age_hours: int = 24):
"""清理旧任务""" """Drop completed/failed tasks older than ``max_age_hours``."""
from datetime import timedelta from datetime import timedelta
cutoff = datetime.now() - timedelta(hours=max_age_hours) cutoff = datetime.now() - timedelta(hours=max_age_hours)
with self._task_lock: with self._task_lock:
old_ids = [ old_ids = [
tid for tid, task in self._tasks.items() tid for tid, task in self._tasks.items()
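
`TaskManager.__new__` uses double-checked locking to build a single shared instance, and a second lock guards the task dict. The sketch below isolates that pattern under stated assumptions: `Registry`, `put`, and `get` are hypothetical names chosen for illustration and do not exist in the codebase.

```python
import threading


class Registry:
    """Hypothetical thread-safe singleton holding a guarded dict."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # First check avoids taking the lock on the common path.
        if cls._instance is None:
            with cls._lock:
                # Second check handles a race between two first-time callers.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._items = {}
                    instance._items_lock = threading.Lock()
                    cls._instance = instance
        return cls._instance

    def put(self, key, value):
        with self._items_lock:
            self._items[key] = value

    def get(self, key):
        with self._items_lock:
            return self._items.get(key)
```

State is initialized inside `__new__` rather than `__init__` because `__init__` would run again on every `Registry()` call and reset the dict.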


@ -1,6 +1,4 @@
""" """Business services package."""
业务服务模块
"""
from .ontology_generator import OntologyGenerator from .ontology_generator import OntologyGenerator
from .graph_builder import GraphBuilderService from .graph_builder import GraphBuilderService


@ -1,6 +1,7 @@
""" """Graph build service.
图谱构建服务
接口2使用Zep API构建Standalone Graph Pipeline step 2: build the project's standalone knowledge graph through the
Zep/Graphiti API.
""" """
import os import os
@ -69,7 +70,7 @@ def _classify_entity_type(name: str, summary: str, ontology: Optional[Dict]) ->
@dataclass @dataclass
class GraphInfo: class GraphInfo:
"""图谱信息""" """Summary information about a built graph."""
graph_id: str graph_id: str
node_count: int node_count: int
edge_count: int edge_count: int
@ -85,10 +86,7 @@ class GraphInfo:
class GraphBuilderService: class GraphBuilderService:
""" """Drives knowledge-graph construction via the Zep/Graphiti API."""
图谱构建服务
负责调用Zep API构建知识图谱
"""
def __init__(self, api_key: Optional[str] = None): def __init__(self, api_key: Optional[str] = None):
self.client = GraphitiAdapter() self.client = GraphitiAdapter()
@ -103,21 +101,20 @@ class GraphBuilderService:
chunk_overlap: int = 50, chunk_overlap: int = 50,
batch_size: int = 3 batch_size: int = 3
) -> str: ) -> str:
""" """Kick off a graph build asynchronously.
异步构建图谱
Args: Args:
text: 输入文本 text: Source text to ingest.
ontology: 本体定义来自接口1的输出 ontology: Ontology definition (the output of pipeline step 1).
graph_name: 图谱名称 graph_name: Display name for the graph.
chunk_size: 文本块大小 chunk_size: Characters per text chunk.
chunk_overlap: 块重叠大小 chunk_overlap: Overlap (in characters) between consecutive chunks.
batch_size: 每批发送的块数量 batch_size: Number of chunks pushed to Zep per batch.
Returns: Returns:
任务ID The id of the task tracking the build.
""" """
# 创建任务 # Register a task to track build progress.
task_id = self.task_manager.create_task( task_id = self.task_manager.create_task(
task_type="graph_build", task_type="graph_build",
metadata={ metadata={
@ -130,7 +127,7 @@ class GraphBuilderService:
# Capture locale before spawning background thread # Capture locale before spawning background thread
current_locale = get_locale() current_locale = get_locale()
# 在后台线程中执行构建 # Run the build on a background thread so the request returns immediately.
thread = threading.Thread( thread = threading.Thread(
target=self._build_graph_worker, target=self._build_graph_worker,
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale) args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
@ -151,7 +148,7 @@ class GraphBuilderService:
batch_size: int, batch_size: int,
locale: str = 'zh' locale: str = 'zh'
): ):
"""图谱构建工作线程""" """Background worker that performs the graph build."""
set_locale(locale) set_locale(locale)
try: try:
self.task_manager.update_task( self.task_manager.update_task(
@ -161,7 +158,7 @@ class GraphBuilderService:
message=t('progress.startBuildingGraph') message=t('progress.startBuildingGraph')
) )
# 1. 创建图谱 # 1. Create the graph.
graph_id = self.create_graph(graph_name) graph_id = self.create_graph(graph_name)
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
@ -169,7 +166,7 @@ class GraphBuilderService:
message=t('progress.graphCreated', graphId=graph_id) message=t('progress.graphCreated', graphId=graph_id)
) )
# 2. 设置本体 # 2. Set the ontology.
self.set_ontology(graph_id, ontology) self.set_ontology(graph_id, ontology)
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
@ -177,7 +174,7 @@ class GraphBuilderService:
message=t('progress.ontologySet') message=t('progress.ontologySet')
) )
# 3. 文本分块 # 3. Split source text into chunks.
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
total_chunks = len(chunks) total_chunks = len(chunks)
self.task_manager.update_task( self.task_manager.update_task(
@ -186,7 +183,7 @@ class GraphBuilderService:
message=t('progress.textSplit', count=total_chunks) message=t('progress.textSplit', count=total_chunks)
) )
# 4. 分批发送数据 # 4. Push chunks to the graph in batches.
episode_uuids = self.add_text_batches( episode_uuids = self.add_text_batches(
graph_id, chunks, batch_size, graph_id, chunks, batch_size,
lambda msg, prog: self.task_manager.update_task( lambda msg, prog: self.task_manager.update_task(
@ -196,7 +193,7 @@ class GraphBuilderService:
) )
) )
# 5. 等待Zep处理完成 # 5. Wait for Zep to finish processing the episodes.
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
progress=60, progress=60,
@ -212,7 +209,7 @@ class GraphBuilderService:
) )
) )
# 6. 获取图谱信息 # 6. Fetch the final graph metadata.
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
progress=90, progress=90,
@ -220,8 +217,7 @@ class GraphBuilderService:
) )
graph_info = self._get_graph_info(graph_id) graph_info = self._get_graph_info(graph_id)
# 完成
self.task_manager.complete_task(task_id, { self.task_manager.complete_task(task_id, {
"graph_id": graph_id, "graph_id": graph_id,
"graph_info": graph_info.to_dict(), "graph_info": graph_info.to_dict(),
@ -234,7 +230,7 @@ class GraphBuilderService:
self.task_manager.fail_task(task_id, error_msg) self.task_manager.fail_task(task_id, error_msg)
def create_graph(self, name: str) -> str: def create_graph(self, name: str) -> str:
"""创建Zep图谱公开方法""" """Create a new Zep graph and return its id (public API)."""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self.client.graph.create( self.client.graph.create(
@ -246,7 +242,7 @@ class GraphBuilderService:
return graph_id return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
"""设置图谱本体提示Graphiti自动提取实体本体作为提示存储""" """Register the ontology with the graph (Graphiti uses it as an extraction prompt)."""
self.client.graph.set_ontology( self.client.graph.set_ontology(
graph_ids=[graph_id], graph_ids=[graph_id],
entities=ontology.get("entity_types"), entities=ontology.get("entity_types"),
@ -261,8 +257,11 @@ class GraphBuilderService:
progress_callback: Optional[Callable] = None, progress_callback: Optional[Callable] = None,
skip_chunks: int = 0, skip_chunks: int = 0,
) -> List[str]: ) -> List[str]:
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表。 """Push chunks to the graph in batches; returns the uuids of all episodes added.
skip_chunks: 跳过已处理的块数用于断点续传"""
Args:
skip_chunks: Number of chunks to skip (used for resume-after-restart).
"""
episode_uuids = [] episode_uuids = []
total_chunks = len(chunks) total_chunks = len(chunks)
@ -279,27 +278,26 @@ class GraphBuilderService:
) )
# 构建episode数据 # Build the per-episode payload structures expected by the client.
episodes = [ episodes = [
type('Episode', (), {'data': chunk, 'type': 'text'})() type('Episode', (), {'data': chunk, 'type': 'text'})()
for chunk in batch_chunks for chunk in batch_chunks
] ]
# 发送到Zep
try: try:
batch_result = self.client.graph.add_batch( batch_result = self.client.graph.add_batch(
graph_id=graph_id, graph_id=graph_id,
episodes=episodes episodes=episodes
) )
# 收集返回的 episode uuid # Collect the uuids returned for each episode.
if batch_result and isinstance(batch_result, list): if batch_result and isinstance(batch_result, list):
for ep in batch_result: for ep in batch_result:
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
if ep_uuid: if ep_uuid:
episode_uuids.append(ep_uuid) episode_uuids.append(ep_uuid)
# 避免请求过快 # Throttle to avoid overwhelming the upstream API.
time.sleep(1) time.sleep(1)
except Exception as e: except Exception as e:
@ -315,7 +313,7 @@ class GraphBuilderService:
progress_callback: Optional[Callable] = None, progress_callback: Optional[Callable] = None,
timeout: int = 600 timeout: int = 600
): ):
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" """Poll each episode until Zep marks it processed, or the timeout expires."""
if not episode_uuids: if not episode_uuids:
if progress_callback: if progress_callback:
progress_callback(t('progress.noEpisodesWait'), 1.0) progress_callback(t('progress.noEpisodesWait'), 1.0)
@ -338,18 +336,18 @@ class GraphBuilderService:
) )
break break
# 检查每个 episode 的处理状态 # Check the processing state of each pending episode.
for ep_uuid in list(pending_episodes): for ep_uuid in list(pending_episodes):
try: try:
episode = self.client.graph.episode.get(uuid_=ep_uuid) episode = self.client.graph.episode.get(uuid_=ep_uuid)
is_processed = getattr(episode, 'processed', False) is_processed = getattr(episode, 'processed', False)
if is_processed: if is_processed:
pending_episodes.remove(ep_uuid) pending_episodes.remove(ep_uuid)
completed_count += 1 completed_count += 1
except Exception as e: except Exception as e:
# 忽略单个查询错误,继续 # Tolerate a single failed query; the next loop iteration retries.
pass pass
elapsed = int(time.time() - start_time) elapsed = int(time.time() - start_time)
@ -360,20 +358,17 @@ class GraphBuilderService:
) )
if pending_episodes: if pending_episodes:
time.sleep(3) # 每3秒检查一次 time.sleep(3) # poll every 3 seconds
if progress_callback: if progress_callback:
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0) progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
def _get_graph_info(self, graph_id: str) -> GraphInfo: def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息""" """Fetch summary info (counts and entity types) for a graph."""
# 获取节点(分页)
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
# 获取边(分页)
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
# 统计实体类型 # Tally distinct entity types across all nodes.
entity_types = set() entity_types = set()
for node in nodes: for node in nodes:
if node.labels: if node.labels:
@ -389,26 +384,24 @@ class GraphBuilderService:
) )
def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]: def get_graph_data(self, graph_id: str, ontology: Optional[Dict] = None) -> Dict[str, Any]:
""" """Return the full graph payload including timestamps, attributes, and edges.
获取完整图谱数据包含详细信息
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
Returns: Returns:
包含nodes和edges的字典包括时间信息属性等详细数据 Dict with ``nodes``, ``edges``, and aggregate counts.
""" """
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
# 创建节点映射用于获取节点名称 # Build a uuid->name map so edge endpoints can be labeled.
node_map = {} node_map = {}
for node in nodes: for node in nodes:
node_map[node.uuid_] = node.name or "" node_map[node.uuid_] = node.name or ""
nodes_data = [] nodes_data = []
for node in nodes: for node in nodes:
# 获取创建时间
created_at = getattr(node, 'created_at', None) created_at = getattr(node, 'created_at', None)
if created_at: if created_at:
created_at = str(created_at) created_at = str(created_at)
@ -429,20 +422,18 @@ class GraphBuilderService:
edges_data = [] edges_data = []
for edge in edges: for edge in edges:
# 获取时间信息
created_at = getattr(edge, 'created_at', None) created_at = getattr(edge, 'created_at', None)
valid_at = getattr(edge, 'valid_at', None) valid_at = getattr(edge, 'valid_at', None)
invalid_at = getattr(edge, 'invalid_at', None) invalid_at = getattr(edge, 'invalid_at', None)
expired_at = getattr(edge, 'expired_at', None) expired_at = getattr(edge, 'expired_at', None)
# 获取 episodes # Normalize the episode list (the field may be missing or a single id).
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
if episodes and not isinstance(episodes, list): if episodes and not isinstance(episodes, list):
episodes = [str(episodes)] episodes = [str(episodes)]
elif episodes: elif episodes:
episodes = [str(e) for e in episodes] episodes = [str(e) for e in episodes]
# 获取 fact_type
fact_type = getattr(edge, 'fact_type', None) or edge.name or "" fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
edges_data.append({ edges_data.append({
@ -471,6 +462,6 @@ class GraphBuilderService:
} }
def delete_graph(self, graph_id: str): def delete_graph(self, graph_id: str):
"""删除图谱""" """Delete a graph by id."""
self.client.graph.delete(graph_id=graph_id) self.client.graph.delete(graph_id=graph_id)
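
Review note: the polling behaviour described by the new docstring ("Poll each episode until Zep marks it processed, or the timeout expires") can be sketched in isolation. This is a minimal illustration, not the service's actual code: `wait_until_processed`, `is_processed`, and the injectable `sleep` parameter are hypothetical names, and `is_processed` stands in for the `client.graph.episode.get(...)` lookup shown in the hunk above. The 600-second timeout and 3-second interval are taken from the diff.

```python
import time
from typing import Callable, Iterable, Set


def wait_until_processed(
    episode_ids: Iterable[str],
    is_processed: Callable[[str], bool],  # hypothetical stand-in for the Zep lookup
    timeout: float = 600.0,
    poll_interval: float = 3.0,
    sleep: Callable[[float], None] = time.sleep,  # injectable so tests need not wait
) -> Set[str]:
    """Poll until every id is processed or the timeout expires.

    Returns the set of ids that were still pending when polling stopped.
    """
    pending = set(episode_ids)
    start = time.monotonic()
    while pending:
        if time.monotonic() - start > timeout:
            break
        for ep_id in list(pending):
            try:
                if is_processed(ep_id):
                    pending.discard(ep_id)
            except Exception:
                # A single failed lookup is tolerated; the id stays pending
                # and is checked again on the next pass.
                continue
        if pending:
            sleep(poll_interval)
    return pending
```

Returning the leftover ids lets a caller distinguish "everything processed" from "timed out with work outstanding", which a bare `break` does not convey on its own.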

File diff suppressed because it is too large

@ -1,6 +1,7 @@
""" """Ontology generation service.
本体生成服务
接口1分析文本内容生成适合社会模拟的实体和关系类型定义 Pipeline step 1: analyze the source text and propose entity and relationship
types that fit a social-media opinion simulation.
""" """
import json import json
@ -14,19 +15,19 @@ logger = logging.getLogger(__name__)
def _to_pascal_case(name: str) -> str: def _to_pascal_case(name: str) -> str:
"""将任意格式的名称转换为 PascalCase'works_for' -> 'WorksFor', 'person' -> 'Person'""" """Convert an arbitrary identifier to PascalCase (e.g. ``works_for`` -> ``WorksFor``)."""
# 按非字母数字字符分割 # Split on non-alphanumeric separators first.
parts = re.split(r'[^a-zA-Z0-9]+', name) parts = re.split(r'[^a-zA-Z0-9]+', name)
# 再按 camelCase 边界分割(如 'camelCase' -> ['camel', 'Case'] # Then split on camelCase boundaries (e.g. ``camelCase`` -> ``['camel', 'Case']``).
words = [] words = []
for part in parts: for part in parts:
words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_')) words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_'))
# 每个词首字母大写,过滤空串 # Title-case each non-empty word and concatenate.
result = ''.join(word.capitalize() for word in words if word) result = ''.join(word.capitalize() for word in words if word)
return result if result else 'Unknown' return result if result else 'Unknown'
# 本体生成的系统提示词 # System prompt template for ontology generation.
ONTOLOGY_SYSTEM_PROMPT = """You are a professional knowledge-graph ontology designer. Your task is to analyze the supplied text and simulation requirement and design entity types and relationship types suitable for a **social-media public-opinion simulation**. ONTOLOGY_SYSTEM_PROMPT = """You are a professional knowledge-graph ontology designer. Your task is to analyze the supplied text and simulation requirement and design entity types and relationship types suitable for a **social-media public-opinion simulation**.
**Important: you must output valid JSON data and nothing else.** **Important: you must output valid JSON data and nothing else.**
@ -174,10 +175,7 @@ B. **Concrete types (8 entries, designed from the text content)**:
class OntologyGenerator: class OntologyGenerator:
""" """Generate an entity- and edge-type ontology from arbitrary input text."""
本体生成器
分析文本内容生成实体和关系类型定义
"""
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm_client = llm_client or LLMClient() self.llm_client = llm_client or LLMClient()
@ -188,18 +186,17 @@ class OntologyGenerator:
simulation_requirement: str, simulation_requirement: str,
additional_context: Optional[str] = None additional_context: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """Generate an ontology definition.
生成本体定义
Args: Args:
document_texts: 文档文本列表 document_texts: Source document text segments.
simulation_requirement: 模拟需求描述 simulation_requirement: Description of the simulation goal.
additional_context: 额外上下文 additional_context: Optional supplemental context.
Returns: Returns:
本体定义entity_types, edge_types等 The ontology dict with ``entity_types``, ``edge_types``, and a summary.
""" """
# 构建用户消息 # Compose the user message that frames the LLM request.
user_message = self._build_user_message( user_message = self._build_user_message(
document_texts, document_texts,
simulation_requirement, simulation_requirement,
@ -213,19 +210,19 @@ class OntologyGenerator:
{"role": "user", "content": user_message} {"role": "user", "content": user_message}
] ]
# 调用LLM # Invoke the LLM.
result = self.llm_client.chat_json( result = self.llm_client.chat_json(
messages=messages, messages=messages,
temperature=0.3, temperature=0.3,
max_tokens=4096 max_tokens=4096
) )
# 验证和后处理 # Validate the LLM response and post-process it.
result = self._validate_and_process(result) result = self._validate_and_process(result)
return result return result
# 传给 LLM 的文本最大长度5万字 # Maximum length of source text passed to the LLM (50k characters).
MAX_TEXT_LENGTH_FOR_LLM = 50000 MAX_TEXT_LENGTH_FOR_LLM = 50000
def _build_user_message( def _build_user_message(
@ -234,13 +231,14 @@ class OntologyGenerator:
simulation_requirement: str, simulation_requirement: str,
additional_context: Optional[str] additional_context: Optional[str]
) -> str: ) -> str:
"""构建用户消息""" """Build the user-message string for the ontology LLM call."""
# 合并文本 # Concatenate the source documents into a single string.
combined_text = "\n\n---\n\n".join(document_texts) combined_text = "\n\n---\n\n".join(document_texts)
original_length = len(combined_text) original_length = len(combined_text)
# 如果文本超过5万字截断仅影响传给LLM的内容不影响图谱构建 # If the combined text exceeds the LLM input cap, truncate it for the
# LLM call only. The full text is still used for graph construction.
if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM: if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM:
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM] combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM]
combined_text += f"\n\n...(original text is {original_length} characters; only the first {self.MAX_TEXT_LENGTH_FOR_LLM} characters were used for ontology analysis)..." combined_text += f"\n\n...(original text is {original_length} characters; only the first {self.MAX_TEXT_LENGTH_FOR_LLM} characters were used for ontology analysis)..."
@ -275,9 +273,9 @@ Based on the content above, design entity types and relationship types suitable
return message return message
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""验证和后处理结果""" """Validate and post-process the LLM-generated ontology dict."""
# 确保必要字段存在 # Ensure required top-level fields exist.
if "entity_types" not in result: if "entity_types" not in result:
result["entity_types"] = [] result["entity_types"] = []
if "edge_types" not in result: if "edge_types" not in result:
@ -285,11 +283,12 @@ Based on the content above, design entity types and relationship types suitable
if "analysis_summary" not in result: if "analysis_summary" not in result:
result["analysis_summary"] = "" result["analysis_summary"] = ""
# 验证实体类型 # Validate entity types.
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用 # Track original-name -> PascalCase mapping so edge source_targets
# references can be fixed up consistently below.
entity_name_map = {} entity_name_map = {}
for entity in result["entity_types"]: for entity in result["entity_types"]:
# 强制将 entity name 转为 PascalCaseZep API 要求) # Force entity names to PascalCase (required by the Zep API).
if "name" in entity: if "name" in entity:
original_name = entity["name"] original_name = entity["name"]
entity["name"] = _to_pascal_case(original_name) entity["name"] = _to_pascal_case(original_name)
@ -300,19 +299,20 @@ Based on the content above, design entity types and relationship types suitable
entity["attributes"] = [] entity["attributes"] = []
if "examples" not in entity: if "examples" not in entity:
entity["examples"] = [] entity["examples"] = []
# 确保description不超过100字符 # Truncate descriptions longer than 100 characters.
if len(entity.get("description", "")) > 100: if len(entity.get("description", "")) > 100:
entity["description"] = entity["description"][:97] + "..." entity["description"] = entity["description"][:97] + "..."
# 验证关系类型 # Validate edge types.
for edge in result["edge_types"]: for edge in result["edge_types"]:
# 强制将 edge name 转为 SCREAMING_SNAKE_CASEZep API 要求) # Upper-case edge names (the Zep API requires SCREAMING_SNAKE_CASE); separators are not rewritten.
if "name" in edge: if "name" in edge:
original_name = edge["name"] original_name = edge["name"]
edge["name"] = original_name.upper() edge["name"] = original_name.upper()
if edge["name"] != original_name: if edge["name"] != original_name:
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'") logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致 # Rewrite source_targets entity-name references to match the
# PascalCase-normalized entity names.
for st in edge.get("source_targets", []): for st in edge.get("source_targets", []):
if st.get("source") in entity_name_map: if st.get("source") in entity_name_map:
st["source"] = entity_name_map[st["source"]] st["source"] = entity_name_map[st["source"]]
@ -325,11 +325,11 @@ Based on the content above, design entity types and relationship types suitable
if len(edge.get("description", "")) > 100: if len(edge.get("description", "")) > 100:
edge["description"] = edge["description"][:97] + "..." edge["description"] = edge["description"][:97] + "..."
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 # Zep API caps: at most 10 custom entity types and 10 custom edge types.
MAX_ENTITY_TYPES = 10 MAX_ENTITY_TYPES = 10
MAX_EDGE_TYPES = 10 MAX_EDGE_TYPES = 10
# 去重:按 name 去重,保留首次出现的 # Deduplicate by name, keeping the first occurrence.
seen_names = set() seen_names = set()
deduped = [] deduped = []
for entity in result["entity_types"]: for entity in result["entity_types"]:
@ -341,7 +341,7 @@ Based on the content above, design entity types and relationship types suitable
logger.warning(f"Duplicate entity type '{name}' removed during validation") logger.warning(f"Duplicate entity type '{name}' removed during validation")
result["entity_types"] = deduped result["entity_types"] = deduped
# 兜底类型定义 # Fallback entity-type definitions used when the LLM omits them.
person_fallback = { person_fallback = {
"name": "Person", "name": "Person",
"description": "Any individual person not fitting other specific person types.", "description": "Any individual person not fitting other specific person types.",
@ -362,33 +362,31 @@ Based on the content above, design entity types and relationship types suitable
"examples": ["small business", "community group"] "examples": ["small business", "community group"]
} }
# 检查是否已有兜底类型 # Check whether the fallback types are already present.
entity_names = {e["name"] for e in result["entity_types"]} entity_names = {e["name"] for e in result["entity_types"]}
has_person = "Person" in entity_names has_person = "Person" in entity_names
has_organization = "Organization" in entity_names has_organization = "Organization" in entity_names
# 需要添加的兜底类型 # Collect missing fallback types to add below.
fallbacks_to_add = [] fallbacks_to_add = []
if not has_person: if not has_person:
fallbacks_to_add.append(person_fallback) fallbacks_to_add.append(person_fallback)
if not has_organization: if not has_organization:
fallbacks_to_add.append(organization_fallback) fallbacks_to_add.append(organization_fallback)
if fallbacks_to_add: if fallbacks_to_add:
current_count = len(result["entity_types"]) current_count = len(result["entity_types"])
needed_slots = len(fallbacks_to_add) needed_slots = len(fallbacks_to_add)
# 如果添加后会超过 10 个,需要移除一些现有类型 # If adding the fallbacks would exceed the cap, drop some existing types.
if current_count + needed_slots > MAX_ENTITY_TYPES: if current_count + needed_slots > MAX_ENTITY_TYPES:
# 计算需要移除多少个
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
# 从末尾移除(保留前面更重要的具体类型) # Drop trailing types first; the more specific types come earlier.
result["entity_types"] = result["entity_types"][:-to_remove] result["entity_types"] = result["entity_types"][:-to_remove]
# 添加兜底类型
result["entity_types"].extend(fallbacks_to_add) result["entity_types"].extend(fallbacks_to_add)
# 最终确保不超过限制(防御性编程) # Defensive cap enforcement: hard-trim if anything slipped through.
if len(result["entity_types"]) > MAX_ENTITY_TYPES: if len(result["entity_types"]) > MAX_ENTITY_TYPES:
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES] result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
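
Review note: the cap-and-fallback steps in this hunk (trim trailing types to make room, append the missing fallbacks, then hard-trim) can be condensed into one small function. This is a sketch under stated assumptions: `ensure_fallbacks` is a hypothetical helper that mirrors the logic shown above, not a function in the codebase, and the fallback dicts here are reduced to their `name` field.

```python
from typing import Any, Dict, List

MAX_ENTITY_TYPES = 10  # cap taken from the diff above


def ensure_fallbacks(
    entity_types: List[Dict[str, Any]],
    fallbacks: List[Dict[str, Any]],
    cap: int = MAX_ENTITY_TYPES,
) -> List[Dict[str, Any]]:
    """Append any missing fallback types while keeping the list within the cap."""
    present = {e["name"] for e in entity_types}
    missing = [f for f in fallbacks if f["name"] not in present]
    kept = list(entity_types)
    overflow = len(kept) + len(missing) - cap
    if missing and overflow > 0:
        # Drop trailing types first, as the original does.
        kept = kept[:-overflow]
    kept.extend(missing)
    # Defensive hard trim, matching the final guard in the original.
    return kept[:cap]


fallbacks = [{"name": "Person"}, {"name": "Organization"}]
full = [{"name": f"T{i}"} for i in range(10)]
names = [e["name"] for e in ensure_fallbacks(full, fallbacks)]
# names == ['T0', ..., 'T7', 'Person', 'Organization']
```

One consequence that may be worth checking in the original: because trimming is positional, a list of ten types that already ends with `Person` but lacks `Organization` loses `Person` when room is made for `Organization`.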
@ -398,14 +396,13 @@ Based on the content above, design entity types and relationship types suitable
return result return result
def generate_python_code(self, ontology: Dict[str, Any]) -> str: def generate_python_code(self, ontology: Dict[str, Any]) -> str:
""" """Render the ontology definition as Python source code.
将本体定义转换为Python代码类似ontology.py
Args: Args:
ontology: 本体定义 ontology: Ontology definition dict.
Returns: Returns:
Python代码字符串 Python source code as a single string.
""" """
code_lines = [ code_lines = [
'"""', '"""',
@ -421,7 +418,7 @@ Based on the content above, design entity types and relationship types suitable
'', '',
] ]
# 生成实体类型 # Emit each entity type as a Python class.
for entity in ontology.get("entity_types", []): for entity in ontology.get("entity_types", []):
name = entity["name"] name = entity["name"]
desc = entity.get("description", f"A {name} entity.") desc = entity.get("description", f"A {name} entity.")
@ -447,10 +444,10 @@ Based on the content above, design entity types and relationship types suitable
code_lines.append('# ============== 关系类型定义 ==============') code_lines.append('# ============== 关系类型定义 ==============')
code_lines.append('') code_lines.append('')
# 生成关系类型 # Emit each edge type as a Python class.
for edge in ontology.get("edge_types", []): for edge in ontology.get("edge_types", []):
name = edge["name"] name = edge["name"]
# 转换为PascalCase类名 # Convert SCREAMING_SNAKE_CASE -> PascalCase for the class name.
class_name = ''.join(word.capitalize() for word in name.split('_')) class_name = ''.join(word.capitalize() for word in name.split('_'))
desc = edge.get("description", f"A {name} relationship.") desc = edge.get("description", f"A {name} relationship.")
@ -472,7 +469,7 @@ Based on the content above, design entity types and relationship types suitable
code_lines.append('') code_lines.append('')
code_lines.append('') code_lines.append('')
# 生成类型字典 # Emit the type registries.
code_lines.append('# ============== 类型配置 ==============') code_lines.append('# ============== 类型配置 ==============')
code_lines.append('') code_lines.append('')
code_lines.append('ENTITY_TYPES = {') code_lines.append('ENTITY_TYPES = {')
@ -489,7 +486,7 @@ Based on the content above, design entity types and relationship types suitable
code_lines.append('}') code_lines.append('}')
code_lines.append('') code_lines.append('')
# 生成边的source_targets映射 # Emit the edge source_targets map.
code_lines.append('EDGE_SOURCE_TARGETS = {') code_lines.append('EDGE_SOURCE_TARGETS = {')
for edge in ontology.get("edge_types", []): for edge in ontology.get("edge_types", []):
name = edge["name"] name = edge["name"]

File diff suppressed because it is too large

@ -1,13 +1,16 @@
""" """
模拟配置智能生成器 Intelligent simulation-configuration generator.
使用LLM根据模拟需求文档内容图谱信息自动生成细致的模拟参数
实现全程自动化无需人工设置参数
采用分步生成策略避免一次性生成过长内容导致失败 Uses an LLM to derive detailed simulation parameters from the simulation
1. 生成时间配置 requirement, document content, and knowledge-graph information, fully
2. 生成事件配置 automating parameter setup without manual intervention.
3. 分批生成Agent配置
4. 生成平台配置 Employs a step-wise generation strategy to avoid failures caused by
producing too much content in a single call:
1. Generate time configuration
2. Generate event configuration
3. Generate agent configurations in batches
4. Generate platform configuration
""" """
import json import json
@ -25,156 +28,156 @@ from .zep_entity_reader import EntityNode, ZepEntityReader
logger = get_logger('mirofish.simulation_config') logger = get_logger('mirofish.simulation_config')
# 中国作息时间配置(北京时间) # Daily-rhythm config for China (Beijing time, UTC+8).
CHINA_TIMEZONE_CONFIG = { CHINA_TIMEZONE_CONFIG = {
# 深夜时段(几乎无人活动) # Late-night hours: almost no activity.
"dead_hours": [0, 1, 2, 3, 4, 5], "dead_hours": [0, 1, 2, 3, 4, 5],
# 早间时段(逐渐醒来) # Morning hours: gradually waking up.
"morning_hours": [6, 7, 8], "morning_hours": [6, 7, 8],
# 工作时段 # Working hours.
"work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18], "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
# 晚间高峰(最活跃) # Evening peak: most active.
"peak_hours": [19, 20, 21, 22], "peak_hours": [19, 20, 21, 22],
# 夜间时段(活跃度下降) # Late-evening hours: activity declining.
"night_hours": [23], "night_hours": [23],
# 活跃度系数 # Activity multipliers.
"activity_multipliers": { "activity_multipliers": {
"dead": 0.05, # 凌晨几乎无人 "dead": 0.05, # Overnight: almost no one online.
"morning": 0.4, # 早间逐渐活跃 "morning": 0.4, # Morning ramp-up.
"work": 0.7, # 工作时段中等 "work": 0.7, # Working hours: moderate activity.
"peak": 1.5, # 晚间高峰 "peak": 1.5, # Evening peak.
"night": 0.5 # 深夜下降 "night": 0.5 # Late-night decline.
} }
} }
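
Review note: the diff shows the hour buckets and multipliers but not how they are consumed. One plausible reading, sketched here as an assumption, is a lookup from an hour of day to its multiplier. `TIMEZONE_CONFIG` repeats the values from `CHINA_TIMEZONE_CONFIG` above; `activity_multiplier` is a hypothetical helper, not part of the module.

```python
from typing import Any, Dict

# Values copied from CHINA_TIMEZONE_CONFIG in the diff above.
TIMEZONE_CONFIG: Dict[str, Any] = {
    "dead_hours": [0, 1, 2, 3, 4, 5],
    "morning_hours": [6, 7, 8],
    "work_hours": [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
    "peak_hours": [19, 20, 21, 22],
    "night_hours": [23],
    "activity_multipliers": {
        "dead": 0.05,
        "morning": 0.4,
        "work": 0.7,
        "peak": 1.5,
        "night": 0.5,
    },
}


def activity_multiplier(hour: int, config: Dict[str, Any] = TIMEZONE_CONFIG) -> float:
    """Hypothetical helper: map an hour (0-23) to its activity multiplier."""
    if not 0 <= hour <= 23:
        raise ValueError(f"hour must be in 0..23, got {hour}")
    for period, multiplier in config["activity_multipliers"].items():
        # Each multiplier key "<period>" pairs with a "<period>_hours" bucket.
        if hour in config[f"{period}_hours"]:
            return multiplier
    return 1.0  # neutral default if a bucket were ever missing


print(activity_multiplier(3))   # 0.05
print(activity_multiplier(20))  # 1.5
```

The five buckets partition all 24 hours, so the neutral default is only reachable if the config is edited and a gap is introduced.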
@dataclass @dataclass
class AgentActivityConfig: class AgentActivityConfig:
"""单个Agent的活动配置""" """Activity configuration for a single agent."""
agent_id: int agent_id: int
entity_uuid: str entity_uuid: str
entity_name: str entity_name: str
entity_type: str entity_type: str
# 活跃度配置 (0.0-1.0) # Activity configuration (0.0-1.0).
activity_level: float = 0.5 # 整体活跃度 activity_level: float = 0.5 # Overall activity level.
# 发言频率(每小时预期发言次数) # Posting frequency (expected posts per hour).
posts_per_hour: float = 1.0 posts_per_hour: float = 1.0
comments_per_hour: float = 2.0 comments_per_hour: float = 2.0
# 活跃时间段24小时制0-23 # Active hours (24-hour clock, 0-23).
active_hours: List[int] = field(default_factory=lambda: list(range(8, 23))) active_hours: List[int] = field(default_factory=lambda: list(range(8, 23)))
# 响应速度(对热点事件的反应延迟,单位:模拟分钟) # Response speed: latency to react to hot events, in simulated minutes.
response_delay_min: int = 5 response_delay_min: int = 5
response_delay_max: int = 60 response_delay_max: int = 60
# 情感倾向 (-1.0到1.0,负面到正面) # Sentiment bias (-1.0 to 1.0, negative to positive).
sentiment_bias: float = 0.0 sentiment_bias: float = 0.0
# 立场(对特定话题的态度) # Stance: attitude toward a given topic.
stance: str = "neutral" # supportive, opposing, neutral, observer stance: str = "neutral" # supportive, opposing, neutral, observer
# 影响力权重决定其发言被其他Agent看到的概率 # Influence weight: determines how likely this agent's posts are to be seen by other agents.
influence_weight: float = 1.0 influence_weight: float = 1.0
@dataclass @dataclass
class TimeSimulationConfig: class TimeSimulationConfig:
"""时间模拟配置(基于中国人作息习惯)""" """Time-simulation configuration (modelled on a Chinese daily rhythm)."""
# 模拟总时长(模拟小时数) # Total simulated duration (simulated hours).
total_simulation_hours: int = 72 # 默认模拟72小时3天 total_simulation_hours: int = 72 # Default: 72 simulated hours (3 days).
# 每轮代表的时间(模拟分钟)- 默认60分钟1小时加快时间流速 # Time represented by each round (simulated minutes); default 60 (1 hour) to speed up the simulated clock.
minutes_per_round: int = 60 minutes_per_round: int = 60
# 每小时激活的Agent数量范围 # Range of agents activated per hour.
agents_per_hour_min: int = 5 agents_per_hour_min: int = 5
agents_per_hour_max: int = 20 agents_per_hour_max: int = 20
# 高峰时段晚间19-22点中国人最活跃的时间 # Peak hours (evenings 19:00-22:00, most active for the modeled audience).
peak_hours: List[int] = field(default_factory=lambda: [19, 20, 21, 22]) peak_hours: List[int] = field(default_factory=lambda: [19, 20, 21, 22])
peak_activity_multiplier: float = 1.5 peak_activity_multiplier: float = 1.5
# 低谷时段凌晨0-5点几乎无人活动 # Off-peak hours (00:00-05:00, almost no activity).
off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5]) off_peak_hours: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5])
off_peak_activity_multiplier: float = 0.05 # 凌晨活跃度极低 off_peak_activity_multiplier: float = 0.05 # Overnight activity is very low.
# 早间时段 # Morning hours.
morning_hours: List[int] = field(default_factory=lambda: [6, 7, 8]) morning_hours: List[int] = field(default_factory=lambda: [6, 7, 8])
morning_activity_multiplier: float = 0.4 morning_activity_multiplier: float = 0.4
# 工作时段 # Working hours.
work_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18]) work_hours: List[int] = field(default_factory=lambda: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18])
work_activity_multiplier: float = 0.7 work_activity_multiplier: float = 0.7
@dataclass @dataclass
class EventConfig: class EventConfig:
"""事件配置""" """Event configuration."""
# 初始事件(模拟开始时的触发事件) # Initial events: triggers fired when the simulation begins.
initial_posts: List[Dict[str, Any]] = field(default_factory=list) initial_posts: List[Dict[str, Any]] = field(default_factory=list)
# 定时事件(在特定时间触发的事件) # Scheduled events: events fired at specific times.
scheduled_events: List[Dict[str, Any]] = field(default_factory=list) scheduled_events: List[Dict[str, Any]] = field(default_factory=list)
# 热点话题关键词 # Hot-topic keywords.
hot_topics: List[str] = field(default_factory=list) hot_topics: List[str] = field(default_factory=list)
# 舆论引导方向 # Narrative direction for public-opinion guidance.
narrative_direction: str = "" narrative_direction: str = ""
@dataclass @dataclass
class PlatformConfig: class PlatformConfig:
"""平台特定配置""" """Platform-specific configuration."""
platform: str # twitter or reddit platform: str # twitter or reddit
# 推荐算法权重 # Recommendation-algorithm weights.
recency_weight: float = 0.4 # 时间新鲜度 recency_weight: float = 0.4 # Recency.
popularity_weight: float = 0.3 # 热度 popularity_weight: float = 0.3 # Popularity.
relevance_weight: float = 0.3 # 相关性 relevance_weight: float = 0.3 # Relevance.
# 病毒传播阈值(达到多少互动后触发扩散) # Viral-spread threshold: number of interactions required to trigger spreading.
viral_threshold: int = 10 viral_threshold: int = 10
# 回声室效应强度(相似观点聚集程度) # Echo-chamber strength: how strongly similar viewpoints cluster together.
echo_chamber_strength: float = 0.5 echo_chamber_strength: float = 0.5
@dataclass @dataclass
class SimulationParameters: class SimulationParameters:
"""完整的模拟参数配置""" """Complete simulation-parameter configuration."""
# 基础信息 # Basic identifiers.
simulation_id: str simulation_id: str
project_id: str project_id: str
graph_id: str graph_id: str
simulation_requirement: str simulation_requirement: str
# 时间配置 # Time configuration.
time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig) time_config: TimeSimulationConfig = field(default_factory=TimeSimulationConfig)
# Agent配置列表 # Agent configuration list.
agent_configs: List[AgentActivityConfig] = field(default_factory=list) agent_configs: List[AgentActivityConfig] = field(default_factory=list)
# 事件配置 # Event configuration.
event_config: EventConfig = field(default_factory=EventConfig) event_config: EventConfig = field(default_factory=EventConfig)
# 平台配置 # Platform configurations.
twitter_config: Optional[PlatformConfig] = None twitter_config: Optional[PlatformConfig] = None
reddit_config: Optional[PlatformConfig] = None reddit_config: Optional[PlatformConfig] = None
# LLM配置 # LLM configuration.
llm_model: str = "" llm_model: str = ""
llm_base_url: str = "" llm_base_url: str = ""
# 生成元数据 # Generation metadata.
generated_at: str = field(default_factory=lambda: datetime.now().isoformat()) generated_at: str = field(default_factory=lambda: datetime.now().isoformat())
generation_reasoning: str = "" # LLM的推理说明 generation_reasoning: str = "" # LLM-provided rationale.
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Return the parameters as a dictionary."""
time_dict = asdict(self.time_config) time_dict = asdict(self.time_config)
return { return {
"simulation_id": self.simulation_id, "simulation_id": self.simulation_id,
@ -193,34 +196,35 @@ class SimulationParameters:
} }
def to_json(self, indent: int = 2) -> str: def to_json(self, indent: int = 2) -> str:
"""转换为JSON字符串""" """Return the parameters as a JSON string."""
return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent) return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
class SimulationConfigGenerator: class SimulationConfigGenerator:
""" """
模拟配置智能生成器 Intelligent simulation-configuration generator.
使用LLM分析模拟需求文档内容图谱实体信息 Uses an LLM to analyze the simulation requirement, document content,
自动生成最佳的模拟参数配置 and graph entity information to automatically derive the best
simulation parameter configuration.
采用分步生成策略
1. 生成时间配置和事件配置轻量级 Step-wise generation strategy:
2. 分批生成Agent配置每批10-20 1. Generate time and event configurations (lightweight).
3. 生成平台配置 2. Generate agent configurations in batches (10-20 per batch).
3. Generate platform configuration.
""" """
# 上下文最大字符数 # Maximum context length (characters).
MAX_CONTEXT_LENGTH = 50000 MAX_CONTEXT_LENGTH = 50000
# 每批生成的Agent数量 # Number of agents generated per batch.
AGENTS_PER_BATCH = 15 AGENTS_PER_BATCH = 15
# 各步骤的上下文截断长度(字符数) # Per-step context truncation lengths (characters).
TIME_CONFIG_CONTEXT_LENGTH = 10000 # 时间配置 TIME_CONFIG_CONTEXT_LENGTH = 10000 # Time configuration.
EVENT_CONFIG_CONTEXT_LENGTH = 8000 # 事件配置 EVENT_CONFIG_CONTEXT_LENGTH = 8000 # Event configuration.
ENTITY_SUMMARY_LENGTH = 300 # 实体摘要 ENTITY_SUMMARY_LENGTH = 300 # Entity summary.
AGENT_SUMMARY_LENGTH = 300 # Agent配置中的实体摘要 AGENT_SUMMARY_LENGTH = 300 # Entity summary used in agent configs.
ENTITIES_PER_TYPE_DISPLAY = 20 # 每类实体显示数量 ENTITIES_PER_TYPE_DISPLAY = 20 # Number of entities displayed per type.
def __init__( def __init__(
self, self,
@ -252,28 +256,27 @@ class SimulationConfigGenerator:
enable_reddit: bool = True, enable_reddit: bool = True,
progress_callback: Optional[Callable[[int, int, str], None]] = None, progress_callback: Optional[Callable[[int, int, str], None]] = None,
) -> SimulationParameters: ) -> SimulationParameters:
""" """Intelligently generate a complete simulation configuration (step-wise).
智能生成完整的模拟配置分步生成
Args: Args:
simulation_id: 模拟ID simulation_id: Simulation ID.
project_id: 项目ID project_id: Project ID.
graph_id: 图谱ID graph_id: Graph ID.
simulation_requirement: 模拟需求描述 simulation_requirement: Description of the simulation requirement.
document_text: 原始文档内容 document_text: Original document content.
entities: 过滤后的实体列表 entities: Filtered list of entities.
enable_twitter: 是否启用Twitter enable_twitter: Whether to enable Twitter.
enable_reddit: 是否启用Reddit enable_reddit: Whether to enable Reddit.
progress_callback: 进度回调函数(current_step, total_steps, message) progress_callback: Progress callback (current_step, total_steps, message).
Returns: Returns:
SimulationParameters: 完整的模拟参数 SimulationParameters: The complete simulation parameters.
""" """
logger.info(t("log.simulation_config.m001", simulation_id=simulation_id, len=len(entities))) logger.info(t("log.simulation_config.m001", simulation_id=simulation_id, len=len(entities)))
# 计算总步骤数 # Compute total step count.
num_batches = math.ceil(len(entities) / self.AGENTS_PER_BATCH) num_batches = math.ceil(len(entities) / self.AGENTS_PER_BATCH)
total_steps = 3 + num_batches # 时间配置 + 事件配置 + N批Agent + 平台配置 total_steps = 3 + num_batches # Time config + event config + N agent batches + platform config.
current_step = 0 current_step = 0
def report_progress(step: int, message: str): def report_progress(step: int, message: str):
@ -283,7 +286,7 @@ class SimulationConfigGenerator:
progress_callback(step, total_steps, message) progress_callback(step, total_steps, message)
logger.info(f"[{step}/{total_steps}] {message}") logger.info(f"[{step}/{total_steps}] {message}")
# 1. 构建基础上下文信息 # 1. Build base context information.
context = self._build_context( context = self._build_context(
simulation_requirement=simulation_requirement, simulation_requirement=simulation_requirement,
document_text=document_text, document_text=document_text,
@ -292,20 +295,20 @@ class SimulationConfigGenerator:
reasoning_parts = [] reasoning_parts = []
# ========== 步骤1: 生成时间配置 ========== # ========== Step 1: generate time configuration ==========
report_progress(1, t('progress.generatingTimeConfig')) report_progress(1, t('progress.generatingTimeConfig'))
num_entities = len(entities) num_entities = len(entities)
time_config_result = self._generate_time_config(context, num_entities) time_config_result = self._generate_time_config(context, num_entities)
time_config = self._parse_time_config(time_config_result, num_entities) time_config = self._parse_time_config(time_config_result, num_entities)
reasoning_parts.append(f"{t('progress.timeConfigLabel')}: {time_config_result.get('reasoning', t('common.success'))}") reasoning_parts.append(f"{t('progress.timeConfigLabel')}: {time_config_result.get('reasoning', t('common.success'))}")
# ========== 步骤2: 生成事件配置 ========== # ========== Step 2: generate event configuration ==========
report_progress(2, t('progress.generatingEventConfig')) report_progress(2, t('progress.generatingEventConfig'))
event_config_result = self._generate_event_config(context, simulation_requirement, entities) event_config_result = self._generate_event_config(context, simulation_requirement, entities)
event_config = self._parse_event_config(event_config_result) event_config = self._parse_event_config(event_config_result)
reasoning_parts.append(f"{t('progress.eventConfigLabel')}: {event_config_result.get('reasoning', t('common.success'))}") reasoning_parts.append(f"{t('progress.eventConfigLabel')}: {event_config_result.get('reasoning', t('common.success'))}")
# ========== 步骤3-N: 分批生成Agent配置 ========== # ========== Steps 3-N: generate agent configurations in batches ==========
all_agent_configs = [] all_agent_configs = []
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
start_idx = batch_idx * self.AGENTS_PER_BATCH start_idx = batch_idx * self.AGENTS_PER_BATCH
@ -327,13 +330,13 @@ class SimulationConfigGenerator:
reasoning_parts.append(t('progress.agentConfigResult', count=len(all_agent_configs))) reasoning_parts.append(t('progress.agentConfigResult', count=len(all_agent_configs)))
# ========== 为初始帖子分配发布者 Agent ========== # ========== Assign poster agents to initial posts ==========
logger.info(t("log.simulation_config.m002")) logger.info(t("log.simulation_config.m002"))
event_config = self._assign_initial_post_agents(event_config, all_agent_configs) event_config = self._assign_initial_post_agents(event_config, all_agent_configs)
assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None]) assigned_count = len([p for p in event_config.initial_posts if p.get("poster_agent_id") is not None])
reasoning_parts.append(t('progress.postAssignResult', count=assigned_count)) reasoning_parts.append(t('progress.postAssignResult', count=assigned_count))
# ========== 最后一步: 生成平台配置 ========== # ========== Final step: generate platform configuration ==========
report_progress(total_steps, t('progress.generatingPlatformConfig')) report_progress(total_steps, t('progress.generatingPlatformConfig'))
twitter_config = None twitter_config = None
reddit_config = None reddit_config = None
@ -358,7 +361,7 @@ class SimulationConfigGenerator:
echo_chamber_strength=0.6 echo_chamber_strength=0.6
) )
# 构建最终参数 # Build final parameters.
params = SimulationParameters( params = SimulationParameters(
simulation_id=simulation_id, simulation_id=simulation_id,
project_id=project_id, project_id=project_id,
@ -384,19 +387,19 @@ class SimulationConfigGenerator:
document_text: str, document_text: str,
entities: List[EntityNode] entities: List[EntityNode]
) -> str: ) -> str:
"""构建LLM上下文截断到最大长度""" """Build the LLM context, truncated to the maximum length."""
# 实体摘要 # Entity summary.
entity_summary = self._summarize_entities(entities) entity_summary = self._summarize_entities(entities)
# 构建上下文 # Build the context.
context_parts = [ context_parts = [
f"## Simulation Requirement\n{simulation_requirement}", f"## Simulation Requirement\n{simulation_requirement}",
f"\n## Entities ({len(entities)})\n{entity_summary}", f"\n## Entities ({len(entities)})\n{entity_summary}",
] ]
current_length = sum(len(p) for p in context_parts) current_length = sum(len(p) for p in context_parts)
remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # 留500字符余量 remaining_length = self.MAX_CONTEXT_LENGTH - current_length - 500 # Reserve 500-char headroom.
if remaining_length > 0 and document_text: if remaining_length > 0 and document_text:
doc_text = document_text[:remaining_length] doc_text = document_text[:remaining_length]
@ -407,10 +410,10 @@ class SimulationConfigGenerator:
return "\n".join(context_parts) return "\n".join(context_parts)
def _summarize_entities(self, entities: List[EntityNode]) -> str: def _summarize_entities(self, entities: List[EntityNode]) -> str:
"""生成实体摘要""" """Generate an entity summary."""
lines = [] lines = []
# 按类型分组 # Group by type.
by_type: Dict[str, List[EntityNode]] = {} by_type: Dict[str, List[EntityNode]] = {}
for e in entities: for e in entities:
t = e.get_entity_type() or "Unknown" t = e.get_entity_type() or "Unknown"
@ -420,7 +423,7 @@ class SimulationConfigGenerator:
for entity_type, type_entities in by_type.items(): for entity_type, type_entities in by_type.items():
lines.append(f"\n### {entity_type} ({len(type_entities)})") lines.append(f"\n### {entity_type} ({len(type_entities)})")
# 使用配置的显示数量和摘要长度 # Use configured display count and summary length.
display_count = self.ENTITIES_PER_TYPE_DISPLAY display_count = self.ENTITIES_PER_TYPE_DISPLAY
summary_len = self.ENTITY_SUMMARY_LENGTH summary_len = self.ENTITY_SUMMARY_LENGTH
for e in type_entities[:display_count]: for e in type_entities[:display_count]:
@ -432,7 +435,7 @@ class SimulationConfigGenerator:
return "\n".join(lines) return "\n".join(lines)
def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]: def _call_llm_with_retry(self, prompt: str, system_prompt: str) -> Dict[str, Any]:
"""带重试的LLM调用包含JSON修复逻辑""" """LLM call with retries, including JSON repair logic."""
import re import re
max_attempts = 3 max_attempts = 3
@ -447,25 +450,25 @@ class SimulationConfigGenerator:
{"role": "user", "content": prompt} {"role": "user", "content": prompt}
], ],
response_format={"type": "json_object"}, response_format={"type": "json_object"},
temperature=0.7 - (attempt * 0.1) # 每次重试降低温度 temperature=0.7 - (attempt * 0.1) # Lower temperature on each retry.
# 不设置max_tokens让LLM自由发挥 # max_tokens is intentionally unset so the LLM can use its full budget.
) )
content = response.choices[0].message.content content = response.choices[0].message.content
finish_reason = response.choices[0].finish_reason finish_reason = response.choices[0].finish_reason
# 检查是否被截断 # Detect truncation.
if finish_reason == 'length': if finish_reason == 'length':
logger.warning(t("log.simulation_config.m004", attempt=attempt + 1)) logger.warning(t("log.simulation_config.m004", attempt=attempt + 1))
content = self._fix_truncated_json(content) content = self._fix_truncated_json(content)
# 尝试解析JSON # Attempt to parse JSON.
try: try:
return json.loads(content) return json.loads(content)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning(t("log.simulation_config.m005", attempt=attempt + 1, str=str(e)[:80])) logger.warning(t("log.simulation_config.m005", attempt=attempt + 1, str=str(e)[:80]))
# 尝试修复JSON # Attempt to repair the JSON.
fixed = self._try_fix_config_json(content) fixed = self._try_fix_config_json(content)
if fixed: if fixed:
return fixed return fixed
@ -481,36 +484,36 @@ class SimulationConfigGenerator:
raise last_error or Exception("LLM调用失败") raise last_error or Exception("LLM调用失败")
def _fix_truncated_json(self, content: str) -> str: def _fix_truncated_json(self, content: str) -> str:
"""修复被截断的JSON""" """Repair truncated JSON."""
content = content.strip() content = content.strip()
# 计算未闭合的括号 # Count unclosed brackets.
open_braces = content.count('{') - content.count('}') open_braces = content.count('{') - content.count('}')
open_brackets = content.count('[') - content.count(']') open_brackets = content.count('[') - content.count(']')
# 检查是否有未闭合的字符串 # Check for an unclosed string.
if content and content[-1] not in '",}]': if content and content[-1] not in '",}]':
content += '"' content += '"'
# 闭合括号 # Close brackets.
content += ']' * open_brackets content += ']' * open_brackets
content += '}' * open_braces content += '}' * open_braces
return content return content
def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]: def _try_fix_config_json(self, content: str) -> Optional[Dict[str, Any]]:
"""尝试修复配置JSON""" """Attempt to repair a configuration JSON payload."""
import re import re
# 修复被截断的情况 # Repair truncation first.
content = self._fix_truncated_json(content) content = self._fix_truncated_json(content)
# 提取JSON部分 # Extract the JSON portion.
json_match = re.search(r'\{[\s\S]*\}', content) json_match = re.search(r'\{[\s\S]*\}', content)
if json_match: if json_match:
json_str = json_match.group() json_str = json_match.group()
# 移除字符串中的换行符 # Remove line breaks from inside strings.
def fix_string(match): def fix_string(match):
s = match.group(0) s = match.group(0)
s = s.replace('\n', ' ').replace('\r', ' ') s = s.replace('\n', ' ').replace('\r', ' ')
@ -522,7 +525,7 @@ class SimulationConfigGenerator:
try: try:
return json.loads(json_str) return json.loads(json_str)
except: except:
# 尝试移除所有控制字符 # Strip all control characters and try again.
json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
json_str = re.sub(r'\s+', ' ', json_str) json_str = re.sub(r'\s+', ' ', json_str)
try: try:
@ -533,11 +536,11 @@ class SimulationConfigGenerator:
return None return None
def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]: def _generate_time_config(self, context: str, num_entities: int) -> Dict[str, Any]:
"""生成时间配置""" """Generate the time configuration."""
# 使用配置的上下文截断长度 # Use the configured context truncation length.
context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH] context_truncated = context[:self.TIME_CONFIG_CONTEXT_LENGTH]
# 计算最大允许值80%的agent数 # Compute the upper bound (90% of the agent count).
max_agents_allowed = max(1, int(num_entities * 0.9)) max_agents_allowed = max(1, int(num_entities * 0.9))
prompt = f"""Based on the simulation requirement below, generate a time-simulation configuration. prompt = f"""Based on the simulation requirement below, generate a time-simulation configuration.
@ -595,10 +598,10 @@ Field guide:
return self._get_default_time_config(num_entities) return self._get_default_time_config(num_entities)
def _get_default_time_config(self, num_entities: int) -> Dict[str, Any]: def _get_default_time_config(self, num_entities: int) -> Dict[str, Any]:
"""获取默认时间配置(中国人作息)""" """Return the default time configuration (Chinese daily rhythm)."""
return { return {
"total_simulation_hours": 72, "total_simulation_hours": 72,
"minutes_per_round": 60, # 每轮1小时加快时间流速 "minutes_per_round": 60, # 1 hour per round to speed up the simulated clock.
"agents_per_hour_min": max(1, num_entities // 15), "agents_per_hour_min": max(1, num_entities // 15),
"agents_per_hour_max": max(5, num_entities // 5), "agents_per_hour_max": max(5, num_entities // 5),
"peak_hours": [19, 20, 21, 22], "peak_hours": [19, 20, 21, 22],
@ -609,12 +612,12 @@ Field guide:
} }
def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig: def _parse_time_config(self, result: Dict[str, Any], num_entities: int) -> TimeSimulationConfig:
"""解析时间配置结果并验证agents_per_hour值不超过总agent数""" """Parse the time-configuration result and ensure agents_per_hour values do not exceed the total agent count."""
# 获取原始值 # Pull raw values.
agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15)) agents_per_hour_min = result.get("agents_per_hour_min", max(1, num_entities // 15))
agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5)) agents_per_hour_max = result.get("agents_per_hour_max", max(5, num_entities // 5))
# 验证并修正确保不超过总agent数 # Validate and correct: ensure values do not exceed the total agent count.
if agents_per_hour_min > num_entities: if agents_per_hour_min > num_entities:
logger.warning(t("log.simulation_config.m008", agents_per_hour_min=agents_per_hour_min, num_entities=num_entities)) logger.warning(t("log.simulation_config.m008", agents_per_hour_min=agents_per_hour_min, num_entities=num_entities))
agents_per_hour_min = max(1, num_entities // 10) agents_per_hour_min = max(1, num_entities // 10)
@ -623,19 +626,19 @@ Field guide:
logger.warning(t("log.simulation_config.m009", agents_per_hour_max=agents_per_hour_max, num_entities=num_entities)) logger.warning(t("log.simulation_config.m009", agents_per_hour_max=agents_per_hour_max, num_entities=num_entities))
agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2) agents_per_hour_max = max(agents_per_hour_min + 1, num_entities // 2)
# 确保 min < max # Ensure min < max.
if agents_per_hour_min >= agents_per_hour_max: if agents_per_hour_min >= agents_per_hour_max:
agents_per_hour_min = max(1, agents_per_hour_max // 2) agents_per_hour_min = max(1, agents_per_hour_max // 2)
logger.warning(t("log.simulation_config.m010", agents_per_hour_min=agents_per_hour_min)) logger.warning(t("log.simulation_config.m010", agents_per_hour_min=agents_per_hour_min))
return TimeSimulationConfig( return TimeSimulationConfig(
total_simulation_hours=result.get("total_simulation_hours", 72), total_simulation_hours=result.get("total_simulation_hours", 72),
minutes_per_round=result.get("minutes_per_round", 60), # 默认每轮1小时 minutes_per_round=result.get("minutes_per_round", 60), # Default: 1 simulated hour per round.
agents_per_hour_min=agents_per_hour_min, agents_per_hour_min=agents_per_hour_min,
agents_per_hour_max=agents_per_hour_max, agents_per_hour_max=agents_per_hour_max,
peak_hours=result.get("peak_hours", [19, 20, 21, 22]), peak_hours=result.get("peak_hours", [19, 20, 21, 22]),
off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]), off_peak_hours=result.get("off_peak_hours", [0, 1, 2, 3, 4, 5]),
off_peak_activity_multiplier=0.05, # 凌晨几乎无人 off_peak_activity_multiplier=0.05, # Overnight: almost no one online.
morning_hours=result.get("morning_hours", [6, 7, 8]), morning_hours=result.get("morning_hours", [6, 7, 8]),
morning_activity_multiplier=0.4, morning_activity_multiplier=0.4,
work_hours=result.get("work_hours", list(range(9, 19))), work_hours=result.get("work_hours", list(range(9, 19))),
@ -649,14 +652,14 @@ Field guide:
simulation_requirement: str, simulation_requirement: str,
entities: List[EntityNode] entities: List[EntityNode]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""生成事件配置""" """Generate the event configuration."""
# 获取可用的实体类型列表,供 LLM 参考 # Build the list of available entity types for the LLM to reference.
entity_types_available = list(set( entity_types_available = list(set(
e.get_entity_type() or "Unknown" for e in entities e.get_entity_type() or "Unknown" for e in entities
)) ))
# 为每种类型列出代表性实体名称 # Collect representative entity names per type.
type_examples = {} type_examples = {}
for e in entities: for e in entities:
etype = e.get_entity_type() or "Unknown" etype = e.get_entity_type() or "Unknown"
@ -670,7 +673,7 @@ Field guide:
for t, examples in type_examples.items() for t, examples in type_examples.items()
]) ])
# 使用配置的上下文截断长度 # Use the configured context truncation length.
context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH] context_truncated = context[:self.EVENT_CONFIG_CONTEXT_LENGTH]
prompt = f"""Based on the simulation requirement below, generate an event configuration. prompt = f"""Based on the simulation requirement below, generate an event configuration.
@ -717,7 +720,7 @@ Return strict JSON (no markdown):
} }
def _parse_event_config(self, result: Dict[str, Any]) -> EventConfig: def _parse_event_config(self, result: Dict[str, Any]) -> EventConfig:
"""解析事件配置结果""" """Parse the event-configuration result."""
return EventConfig( return EventConfig(
initial_posts=result.get("initial_posts", []), initial_posts=result.get("initial_posts", []),
scheduled_events=[], scheduled_events=[],
@ -730,15 +733,15 @@ Return strict JSON (no markdown):
event_config: EventConfig, event_config: EventConfig,
agent_configs: List[AgentActivityConfig] agent_configs: List[AgentActivityConfig]
) -> EventConfig: ) -> EventConfig:
""" """Assign a suitable poster agent to each initial post.
为初始帖子分配合适的发布者 Agent
Matches the most appropriate agent_id for each post based on its
根据每个帖子的 poster_type 匹配最合适的 agent_id poster_type.
""" """
if not event_config.initial_posts: if not event_config.initial_posts:
return event_config return event_config
# 按实体类型建立 agent 索引 # Build an agent index keyed by entity type.
agents_by_type: Dict[str, List[AgentActivityConfig]] = {} agents_by_type: Dict[str, List[AgentActivityConfig]] = {}
for agent in agent_configs: for agent in agent_configs:
etype = agent.entity_type.lower() etype = agent.entity_type.lower()
@ -746,7 +749,7 @@ Return strict JSON (no markdown):
agents_by_type[etype] = [] agents_by_type[etype] = []
agents_by_type[etype].append(agent) agents_by_type[etype].append(agent)
# 类型映射表(处理 LLM 可能输出的不同格式) # Type alias map (handles the different formats the LLM might emit).
type_aliases = { type_aliases = {
"official": ["official", "university", "governmentagency", "government"], "official": ["official", "university", "governmentagency", "government"],
"university": ["university", "official"], "university": ["university", "official"],
@ -758,7 +761,7 @@ Return strict JSON (no markdown):
"person": ["person", "student", "alumni"], "person": ["person", "student", "alumni"],
} }
# 记录每种类型已使用的 agent 索引,避免重复使用同一个 agent # Track the next agent index used per type to avoid reusing the same agent twice.
used_indices: Dict[str, int] = {} used_indices: Dict[str, int] = {}
updated_posts = [] updated_posts = []
@ -766,17 +769,17 @@ Return strict JSON (no markdown):
poster_type = post.get("poster_type", "").lower() poster_type = post.get("poster_type", "").lower()
content = post.get("content", "") content = post.get("content", "")
# 尝试找到匹配的 agent # Try to find a matching agent.
matched_agent_id = None matched_agent_id = None
# 1. 直接匹配 # 1. Direct match.
if poster_type in agents_by_type: if poster_type in agents_by_type:
agents = agents_by_type[poster_type] agents = agents_by_type[poster_type]
idx = used_indices.get(poster_type, 0) % len(agents) idx = used_indices.get(poster_type, 0) % len(agents)
matched_agent_id = agents[idx].agent_id matched_agent_id = agents[idx].agent_id
used_indices[poster_type] = idx + 1 used_indices[poster_type] = idx + 1
else: else:
# 2. 使用别名匹配 # 2. Match via aliases.
for alias_key, aliases in type_aliases.items(): for alias_key, aliases in type_aliases.items():
if poster_type in aliases or alias_key == poster_type: if poster_type in aliases or alias_key == poster_type:
for alias in aliases: for alias in aliases:
@ -789,11 +792,11 @@ Return strict JSON (no markdown):
if matched_agent_id is not None: if matched_agent_id is not None:
break break
# 3. 如果仍未找到,使用影响力最高的 agent # 3. If still unresolved, fall back to the most influential agent.
if matched_agent_id is None: if matched_agent_id is None:
logger.warning(t("log.simulation_config.m012", poster_type=poster_type)) logger.warning(t("log.simulation_config.m012", poster_type=poster_type))
if agent_configs: if agent_configs:
# 按影响力排序,选择影响力最高的 # Sort by influence and pick the highest.
sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True) sorted_agents = sorted(agent_configs, key=lambda a: a.influence_weight, reverse=True)
matched_agent_id = sorted_agents[0].agent_id matched_agent_id = sorted_agents[0].agent_id
else: else:
@ -817,9 +820,9 @@ Return strict JSON (no markdown):
start_idx: int, start_idx: int,
simulation_requirement: str simulation_requirement: str
) -> List[AgentActivityConfig]: ) -> List[AgentActivityConfig]:
"""分批生成Agent配置""" """Generate agent configurations in batches."""
# 构建实体信息(使用配置的摘要长度) # Build entity information (using the configured summary length).
entity_list = [] entity_list = []
summary_len = self.AGENT_SUMMARY_LENGTH summary_len = self.AGENT_SUMMARY_LENGTH
for i, e in enumerate(entities): for i, e in enumerate(entities):
@ -876,13 +879,13 @@ Return strict JSON (no markdown):
logger.warning(t("log.simulation_config.m014", e=e)) logger.warning(t("log.simulation_config.m014", e=e))
llm_configs = {} llm_configs = {}
# 构建AgentActivityConfig对象 # Build AgentActivityConfig objects.
configs = [] configs = []
for i, entity in enumerate(entities): for i, entity in enumerate(entities):
agent_id = start_idx + i agent_id = start_idx + i
cfg = llm_configs.get(agent_id, {}) cfg = llm_configs.get(agent_id, {})
# 如果LLM没有生成使用规则生成 # If the LLM did not produce a config, fall back to rule-based generation.
if not cfg: if not cfg:
cfg = self._generate_agent_config_by_rule(entity) cfg = self._generate_agent_config_by_rule(entity)
@ -906,16 +909,16 @@ Return strict JSON (no markdown):
return configs return configs
def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]: def _generate_agent_config_by_rule(self, entity: EntityNode) -> Dict[str, Any]:
"""基于规则生成单个Agent配置中国人作息""" """Rule-based generation for a single agent's configuration (Chinese daily rhythm)."""
entity_type = (entity.get_entity_type() or "Unknown").lower() entity_type = (entity.get_entity_type() or "Unknown").lower()
if entity_type in ["university", "governmentagency", "ngo"]: if entity_type in ["university", "governmentagency", "ngo"]:
# 官方机构:工作时间活动,低频率,高影响力 # Official institutions: active during working hours, low frequency, high influence.
return { return {
"activity_level": 0.2, "activity_level": 0.2,
"posts_per_hour": 0.1, "posts_per_hour": 0.1,
"comments_per_hour": 0.05, "comments_per_hour": 0.05,
"active_hours": list(range(9, 18)), # 9:00-17:59 "active_hours": list(range(9, 18)), # 09:00-17:59
"response_delay_min": 60, "response_delay_min": 60,
"response_delay_max": 240, "response_delay_max": 240,
"sentiment_bias": 0.0, "sentiment_bias": 0.0,
@ -923,12 +926,12 @@ Return strict JSON (no markdown):
"influence_weight": 3.0 "influence_weight": 3.0
} }
elif entity_type in ["mediaoutlet"]: elif entity_type in ["mediaoutlet"]:
# 媒体:全天活动,中等频率,高影响力 # Media: active throughout the day, medium frequency, high influence.
return { return {
"activity_level": 0.5, "activity_level": 0.5,
"posts_per_hour": 0.8, "posts_per_hour": 0.8,
"comments_per_hour": 0.3, "comments_per_hour": 0.3,
"active_hours": list(range(7, 24)), # 7:00-23:59 "active_hours": list(range(7, 24)), # 07:00-23:59
"response_delay_min": 5, "response_delay_min": 5,
"response_delay_max": 30, "response_delay_max": 30,
"sentiment_bias": 0.0, "sentiment_bias": 0.0,
@ -936,12 +939,12 @@ Return strict JSON (no markdown):
"influence_weight": 2.5 "influence_weight": 2.5
} }
elif entity_type in ["professor", "expert", "official"]: elif entity_type in ["professor", "expert", "official"]:
# 专家/教授:工作+晚间活动,中等频率 # Experts / professors: active during work and evening, medium frequency.
return { return {
"activity_level": 0.4, "activity_level": 0.4,
"posts_per_hour": 0.3, "posts_per_hour": 0.3,
"comments_per_hour": 0.5, "comments_per_hour": 0.5,
"active_hours": list(range(8, 22)), # 8:00-21:59 "active_hours": list(range(8, 22)), # 08:00-21:59
"response_delay_min": 15, "response_delay_min": 15,
"response_delay_max": 90, "response_delay_max": 90,
"sentiment_bias": 0.0, "sentiment_bias": 0.0,
@ -949,12 +952,12 @@ Return strict JSON (no markdown):
"influence_weight": 2.0 "influence_weight": 2.0
} }
elif entity_type in ["student"]: elif entity_type in ["student"]:
# 学生:晚间为主,高频率 # Students: mostly evening, high frequency.
return { return {
"activity_level": 0.8, "activity_level": 0.8,
"posts_per_hour": 0.6, "posts_per_hour": 0.6,
"comments_per_hour": 1.5, "comments_per_hour": 1.5,
"active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 上午+晚间 "active_hours": [8, 9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # Morning + evening.
"response_delay_min": 1, "response_delay_min": 1,
"response_delay_max": 15, "response_delay_max": 15,
"sentiment_bias": 0.0, "sentiment_bias": 0.0,
@ -962,12 +965,12 @@ Return strict JSON (no markdown):
"influence_weight": 0.8 "influence_weight": 0.8
} }
elif entity_type in ["alumni"]: elif entity_type in ["alumni"]:
# 校友:晚间为主 # Alumni: mostly evening.
return { return {
"activity_level": 0.6, "activity_level": 0.6,
"posts_per_hour": 0.4, "posts_per_hour": 0.4,
"comments_per_hour": 0.8, "comments_per_hour": 0.8,
"active_hours": [12, 13, 19, 20, 21, 22, 23], # 午休+晚间 "active_hours": [12, 13, 19, 20, 21, 22, 23], # Lunch break + evening.
"response_delay_min": 5, "response_delay_min": 5,
"response_delay_max": 30, "response_delay_max": 30,
"sentiment_bias": 0.0, "sentiment_bias": 0.0,
@ -975,12 +978,12 @@ Return strict JSON (no markdown):
"influence_weight": 1.0 "influence_weight": 1.0
} }
else: else:
# 普通人:晚间高峰 # General public: evening peak.
return { return {
"activity_level": 0.7, "activity_level": 0.7,
"posts_per_hour": 0.5, "posts_per_hour": 0.5,
"comments_per_hour": 1.2, "comments_per_hour": 1.2,
"active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # 白天+晚间 "active_hours": [9, 10, 11, 12, 13, 18, 19, 20, 21, 22, 23], # Daytime + evening.
"response_delay_min": 2, "response_delay_min": 2,
"response_delay_max": 20, "response_delay_max": 20,
"sentiment_bias": 0.0, "sentiment_bias": 0.0,
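
Note on the JSON repair path touched in this file: the heuristic in `_fix_truncated_json` can be tried in isolation. The following is a minimal standalone sketch, assuming the method body is exactly what the hunk above shows; the function name `close_truncated_json` is hypothetical and is not part of the repository.

```python
import json


def close_truncated_json(content: str) -> str:
    """Best-effort repair of JSON that was cut off mid-stream (hypothetical helper)."""
    content = content.strip()

    # Count unbalanced delimiters. This is a naive character count, so braces
    # and brackets that appear inside string values are counted as well.
    open_braces = content.count("{") - content.count("}")
    open_brackets = content.count("[") - content.count("]")

    # If the text does not end on a quote, comma, or closing delimiter,
    # assume a string value was cut off and terminate it.
    if content and content[-1] not in '",}]':
        content += '"'

    # Append the missing closers: all brackets first, then all braces.
    # A negative count multiplies to an empty string, so nothing is added.
    content += "]" * open_brackets
    content += "}" * open_braces
    return content


repaired = close_truncated_json('{"a": ["x", "y')
print(json.loads(repaired))
```

The heuristic only covers output that stops inside a string or directly after a complete value. A payload cut off after a bare number (for example `{"a": [1, 2`) receives a stray quote and still fails to parse, which is presumably why the caller falls back to `_try_fix_config_json` and then to a retry.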

@ -1,11 +1,12 @@
""" """Simulation IPC module.
模拟IPC通信模块
用于Flask后端和模拟脚本之间的进程间通信
通过文件系统实现简单的命令/响应模式 Inter-process communication between the Flask backend and the simulation
1. Flask写入命令到 commands/ 目录 subprocess. Implements a simple file-system command/response pattern:
2. 模拟脚本轮询命令目录执行命令并写入响应到 responses/ 目录
3. Flask轮询响应目录获取结果 1. Flask writes commands into ``commands/``.
2. The simulation script polls for commands, executes them, and writes
responses into ``responses/``.
3. Flask polls the responses directory for results.
""" """
import os import os
@ -24,14 +25,14 @@ logger = get_logger('mirofish.simulation_ipc')
class CommandType(str, Enum): class CommandType(str, Enum):
"""命令类型""" """IPC command types."""
INTERVIEW = "interview" # 单个Agent采访 INTERVIEW = "interview" # interview a single agent
BATCH_INTERVIEW = "batch_interview" # 批量采访 BATCH_INTERVIEW = "batch_interview" # interview multiple agents at once
CLOSE_ENV = "close_env" # 关闭环境 CLOSE_ENV = "close_env" # tear down the environment
class CommandStatus(str, Enum): class CommandStatus(str, Enum):
"""命令状态""" """IPC command status."""
PENDING = "pending" PENDING = "pending"
PROCESSING = "processing" PROCESSING = "processing"
COMPLETED = "completed" COMPLETED = "completed"
@ -40,12 +41,12 @@ class CommandStatus(str, Enum):
@dataclass @dataclass
class IPCCommand: class IPCCommand:
"""IPC命令""" """A command sent over the IPC channel."""
command_id: str command_id: str
command_type: CommandType command_type: CommandType
args: Dict[str, Any] args: Dict[str, Any]
timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"command_id": self.command_id, "command_id": self.command_id,
@ -53,7 +54,7 @@ class IPCCommand:
"args": self.args, "args": self.args,
"timestamp": self.timestamp "timestamp": self.timestamp
} }
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand': def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
return cls( return cls(
@ -66,13 +67,13 @@ class IPCCommand:
@dataclass @dataclass
class IPCResponse: class IPCResponse:
"""IPC响应""" """A response returned over the IPC channel."""
command_id: str command_id: str
status: CommandStatus status: CommandStatus
result: Optional[Dict[str, Any]] = None result: Optional[Dict[str, Any]] = None
error: Optional[str] = None error: Optional[str] = None
timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"command_id": self.command_id, "command_id": self.command_id,
@ -81,7 +82,7 @@ class IPCResponse:
"error": self.error, "error": self.error,
"timestamp": self.timestamp "timestamp": self.timestamp
} }
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse': def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse':
return cls( return cls(
@ -94,27 +95,25 @@ class IPCResponse:
class SimulationIPCClient: class SimulationIPCClient:
"""IPC client used by the Flask side.
Sends commands to the simulation process and waits for responses.
""" """
模拟IPC客户端Flask端使用
用于向模拟进程发送命令并等待响应
"""
def __init__(self, simulation_dir: str): def __init__(self, simulation_dir: str):
""" """Initialize the IPC client.
初始化IPC客户端
Args: Args:
simulation_dir: 模拟数据目录 simulation_dir: Directory holding the simulation's IPC files.
""" """
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
self.commands_dir = os.path.join(simulation_dir, "ipc_commands") self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
self.responses_dir = os.path.join(simulation_dir, "ipc_responses") self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
# 确保目录存在 # Ensure both directories exist before use.
os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True)
def send_command( def send_command(
self, self,
command_type: CommandType, command_type: CommandType,
@ -122,20 +121,19 @@ class SimulationIPCClient:
timeout: float = 60.0, timeout: float = 60.0,
poll_interval: float = 0.5 poll_interval: float = 0.5
) -> IPCResponse: ) -> IPCResponse:
""" """Send a command and wait for the response.
发送命令并等待响应
Args: Args:
command_type: 命令类型 command_type: Command type to send.
args: 命令参数 args: Command arguments.
timeout: 超时时间 timeout: Timeout in seconds.
poll_interval: 轮询间隔 poll_interval: Polling interval in seconds.
Returns: Returns:
IPCResponse The ``IPCResponse``.
Raises: Raises:
TimeoutError: 等待响应超时 TimeoutError: When no response arrives before ``timeout``.
""" """
command_id = str(uuid.uuid4()) command_id = str(uuid.uuid4())
command = IPCCommand( command = IPCCommand(
@ -143,50 +141,50 @@ class SimulationIPCClient:
command_type=command_type, command_type=command_type,
args=args args=args
) )
# 写入命令文件 # Write the command file.
command_file = os.path.join(self.commands_dir, f"{command_id}.json") command_file = os.path.join(self.commands_dir, f"{command_id}.json")
with open(command_file, 'w', encoding='utf-8') as f: with open(command_file, 'w', encoding='utf-8') as f:
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2) json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
logger.info(t("log.simulation_ipc.m001", command_type=command_type.value, command_id=command_id)) logger.info(t("log.simulation_ipc.m001", command_type=command_type.value, command_id=command_id))
# 等待响应 # Poll for the response file.
response_file = os.path.join(self.responses_dir, f"{command_id}.json") response_file = os.path.join(self.responses_dir, f"{command_id}.json")
start_time = time.time() start_time = time.time()
while time.time() - start_time < timeout: while time.time() - start_time < timeout:
if os.path.exists(response_file): if os.path.exists(response_file):
try: try:
with open(response_file, 'r', encoding='utf-8') as f: with open(response_file, 'r', encoding='utf-8') as f:
response_data = json.load(f) response_data = json.load(f)
response = IPCResponse.from_dict(response_data) response = IPCResponse.from_dict(response_data)
# 清理命令和响应文件 # Clean up command and response files after successful read.
try: try:
os.remove(command_file) os.remove(command_file)
os.remove(response_file) os.remove(response_file)
except OSError: except OSError:
pass pass
logger.info(t("log.simulation_ipc.m002", command_id=command_id, response=response.status.value)) logger.info(t("log.simulation_ipc.m002", command_id=command_id, response=response.status.value))
return response return response
except (json.JSONDecodeError, KeyError) as e: except (json.JSONDecodeError, KeyError) as e:
logger.warning(t("log.simulation_ipc.m003", e=e)) logger.warning(t("log.simulation_ipc.m003", e=e))
time.sleep(poll_interval) time.sleep(poll_interval)
# 超时 # Timed out waiting for the response.
logger.error(t("log.simulation_ipc.m004", command_id=command_id)) logger.error(t("log.simulation_ipc.m004", command_id=command_id))
# 清理命令文件 # Clean up the unanswered command file.
try: try:
os.remove(command_file) os.remove(command_file)
except OSError: except OSError:
pass pass
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)") raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
def send_interview( def send_interview(
self, self,
agent_id: int, agent_id: int,
@ -194,20 +192,19 @@ class SimulationIPCClient:
platform: str = None, platform: str = None,
timeout: float = 60.0 timeout: float = 60.0
) -> IPCResponse: ) -> IPCResponse:
""" """Send a single-agent interview command.
发送单个Agent采访命令
Args: Args:
agent_id: Agent ID agent_id: Agent id to interview.
prompt: 采访问题 prompt: Interview question.
platform: 指定平台可选 platform: Optional platform selector.
- "twitter": 只采访Twitter平台 - ``"twitter"``: interview only on Twitter.
- "reddit": 只采访Reddit平台 - ``"reddit"``: interview only on Reddit.
- None: 双平台模拟时同时采访两个平台单平台模拟时采访该平台 - ``None``: dual-platform if applicable, else the single active platform.
timeout: 超时时间 timeout: Timeout in seconds.
Returns: Returns:
IPCResponseresult字段包含采访结果 ``IPCResponse`` whose ``result`` carries the interview response.
""" """
args = { args = {
"agent_id": agent_id, "agent_id": agent_id,
@ -215,69 +212,66 @@ class SimulationIPCClient:
} }
if platform: if platform:
args["platform"] = platform args["platform"] = platform
return self.send_command( return self.send_command(
command_type=CommandType.INTERVIEW, command_type=CommandType.INTERVIEW,
args=args, args=args,
timeout=timeout timeout=timeout
) )
def send_batch_interview( def send_batch_interview(
self, self,
interviews: List[Dict[str, Any]], interviews: List[Dict[str, Any]],
platform: str = None, platform: str = None,
timeout: float = 120.0 timeout: float = 120.0
) -> IPCResponse: ) -> IPCResponse:
""" """Send a batched interview command.
发送批量采访命令
Args: Args:
interviews: 采访列表每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} interviews: List of items shaped ``{"agent_id": int, "prompt": str, "platform": str?}``.
platform: 默认平台可选会被每个采访项的platform覆盖 platform: Default platform; per-item ``platform`` overrides this.
- "twitter": 默认只采访Twitter平台 - ``"twitter"``: default to Twitter.
- "reddit": 默认只采访Reddit平台 - ``"reddit"``: default to Reddit.
- None: 双平台模拟时每个Agent同时采访两个平台 - ``None``: dual-platform interview when applicable.
timeout: 超时时间 timeout: Timeout in seconds.
Returns: Returns:
IPCResponseresult字段包含所有采访结果 ``IPCResponse`` whose ``result`` carries every interview response.
""" """
args = {"interviews": interviews} args = {"interviews": interviews}
if platform: if platform:
args["platform"] = platform args["platform"] = platform
return self.send_command( return self.send_command(
command_type=CommandType.BATCH_INTERVIEW, command_type=CommandType.BATCH_INTERVIEW,
args=args, args=args,
timeout=timeout timeout=timeout
) )
def send_close_env(self, timeout: float = 30.0) -> IPCResponse: def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
""" """Send a tear-down-environment command.
发送关闭环境命令
Args: Args:
timeout: 超时时间 timeout: Timeout in seconds.
Returns: Returns:
IPCResponse ``IPCResponse``.
""" """
return self.send_command( return self.send_command(
command_type=CommandType.CLOSE_ENV, command_type=CommandType.CLOSE_ENV,
args={}, args={},
timeout=timeout timeout=timeout
) )
def check_env_alive(self) -> bool: def check_env_alive(self) -> bool:
""" """Return ``True`` if the simulation environment reports as alive.
检查模拟环境是否存活
Reads ``env_status.json`` written by the IPC server side.
通过检查 env_status.json 文件来判断
""" """
status_file = os.path.join(self.simulation_dir, "env_status.json") status_file = os.path.join(self.simulation_dir, "env_status.json")
if not os.path.exists(status_file): if not os.path.exists(status_file):
return False return False
try: try:
with open(status_file, 'r', encoding='utf-8') as f: with open(status_file, 'r', encoding='utf-8') as f:
status = json.load(f) status = json.load(f)
@ -287,68 +281,65 @@ class SimulationIPCClient:
class SimulationIPCServer: class SimulationIPCServer:
"""IPC server used by the simulation script.
Polls the commands directory, executes commands, and writes responses.
""" """
模拟IPC服务器模拟脚本端使用
轮询命令目录执行命令并返回响应
"""
def __init__(self, simulation_dir: str): def __init__(self, simulation_dir: str):
""" """Initialize the IPC server.
初始化IPC服务器
Args: Args:
simulation_dir: 模拟数据目录 simulation_dir: Directory holding the simulation's IPC files.
""" """
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
self.commands_dir = os.path.join(simulation_dir, "ipc_commands") self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
self.responses_dir = os.path.join(simulation_dir, "ipc_responses") self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
# 确保目录存在 # Ensure both directories exist before use.
os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True)
# 环境状态 # Server-running flag.
self._running = False self._running = False
def start(self): def start(self):
"""标记服务器为运行状态""" """Mark the server as alive and persist the state."""
self._running = True self._running = True
self._update_env_status("alive") self._update_env_status("alive")
def stop(self): def stop(self):
"""标记服务器为停止状态""" """Mark the server as stopped and persist the state."""
self._running = False self._running = False
self._update_env_status("stopped") self._update_env_status("stopped")
def _update_env_status(self, status: str): def _update_env_status(self, status: str):
"""更新环境状态文件""" """Update the persistent environment-status file."""
status_file = os.path.join(self.simulation_dir, "env_status.json") status_file = os.path.join(self.simulation_dir, "env_status.json")
with open(status_file, 'w', encoding='utf-8') as f: with open(status_file, 'w', encoding='utf-8') as f:
json.dump({ json.dump({
"status": status, "status": status,
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
}, f, ensure_ascii=False, indent=2) }, f, ensure_ascii=False, indent=2)
def poll_commands(self) -> Optional[IPCCommand]: def poll_commands(self) -> Optional[IPCCommand]:
""" """Poll the commands directory and return the next pending command.
轮询命令目录返回第一个待处理的命令
Returns: Returns:
IPCCommand None ``IPCCommand`` or ``None`` if no pending commands remain.
""" """
if not os.path.exists(self.commands_dir): if not os.path.exists(self.commands_dir):
return None return None
# 按时间排序获取命令文件 # Sort by mtime so we process commands in arrival order.
command_files = [] command_files = []
for filename in os.listdir(self.commands_dir): for filename in os.listdir(self.commands_dir):
if filename.endswith('.json'): if filename.endswith('.json'):
filepath = os.path.join(self.commands_dir, filename) filepath = os.path.join(self.commands_dir, filename)
command_files.append((filepath, os.path.getmtime(filepath))) command_files.append((filepath, os.path.getmtime(filepath)))
command_files.sort(key=lambda x: x[1]) command_files.sort(key=lambda x: x[1])
for filepath, _ in command_files: for filepath, _ in command_files:
try: try:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, 'r', encoding='utf-8') as f:
@ -357,37 +348,36 @@ class SimulationIPCServer:
except (json.JSONDecodeError, KeyError, OSError) as e: except (json.JSONDecodeError, KeyError, OSError) as e:
logger.warning(t("log.simulation_ipc.m005", filepath=filepath, e=e)) logger.warning(t("log.simulation_ipc.m005", filepath=filepath, e=e))
continue continue
return None return None
def send_response(self, response: IPCResponse): def send_response(self, response: IPCResponse):
""" """Write a response file.
发送响应
Args: Args:
response: IPC响应 response: The response to send.
""" """
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json") response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
with open(response_file, 'w', encoding='utf-8') as f: with open(response_file, 'w', encoding='utf-8') as f:
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2) json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
# 删除命令文件 # Delete the matching command file.
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json") command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
try: try:
os.remove(command_file) os.remove(command_file)
except OSError: except OSError:
pass pass
def send_success(self, command_id: str, result: Dict[str, Any]): def send_success(self, command_id: str, result: Dict[str, Any]):
"""发送成功响应""" """Send a success response."""
self.send_response(IPCResponse( self.send_response(IPCResponse(
command_id=command_id, command_id=command_id,
status=CommandStatus.COMPLETED, status=CommandStatus.COMPLETED,
result=result result=result
)) ))
def send_error(self, command_id: str, error: str): def send_error(self, command_id: str, error: str):
"""发送错误响应""" """Send a failure response."""
self.send_response(IPCResponse( self.send_response(IPCResponse(
command_id=command_id, command_id=command_id,
status=CommandStatus.FAILED, status=CommandStatus.FAILED,
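
For readers skimming this diff, the client/server exchange in `simulation_ipc.py` can be sketched in isolation. The following is a minimal, self-contained illustration of the same file-based pattern (one JSON file per command in `ipc_commands`, one matching file per response in `ipc_responses`, oldest command served first), not the project's implementation. The helper names `write_json_atomic`, `send_command` (as a free function), and `serve_once` are hypothetical, and the write-then-rename step is an assumption added here to avoid reading half-written files; the code in the diff instead retries on `json.JSONDecodeError`.

```python
import json
import os
import tempfile
import threading
import time
import uuid


def write_json_atomic(path, payload):
    """Write JSON to a temp file, then rename it so readers never see a partial file."""
    tmp_path = path + ".tmp"  # does not end in ".json", so pollers skip it
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)
    os.replace(tmp_path, path)


def _ipc_dirs(base_dir):
    """Return (commands_dir, responses_dir), creating both if missing."""
    commands_dir = os.path.join(base_dir, "ipc_commands")
    responses_dir = os.path.join(base_dir, "ipc_responses")
    os.makedirs(commands_dir, exist_ok=True)
    os.makedirs(responses_dir, exist_ok=True)
    return commands_dir, responses_dir


def send_command(base_dir, command_type, args, timeout=5.0, poll_interval=0.05):
    """Client side: drop a command file, then poll for the matching response file."""
    commands_dir, responses_dir = _ipc_dirs(base_dir)
    command_id = str(uuid.uuid4())
    write_json_atomic(
        os.path.join(commands_dir, f"{command_id}.json"),
        {"command_id": command_id, "command_type": command_type, "args": args},
    )

    response_file = os.path.join(responses_dir, f"{command_id}.json")
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.exists(response_file):
            with open(response_file, "r", encoding="utf-8") as f:
                response = json.load(f)
            os.remove(response_file)
            return response
        time.sleep(poll_interval)
    raise TimeoutError(f"no response for {command_id} within {timeout}s")


def serve_once(base_dir, handler):
    """Server side: answer the oldest pending command; return False if none is pending."""
    commands_dir, responses_dir = _ipc_dirs(base_dir)
    pending = [
        os.path.join(commands_dir, name)
        for name in os.listdir(commands_dir)
        if name.endswith(".json")
    ]
    if not pending:
        return False
    oldest = min(pending, key=os.path.getmtime)  # arrival order by mtime
    with open(oldest, "r", encoding="utf-8") as f:
        command = json.load(f)
    write_json_atomic(
        os.path.join(responses_dir, f"{command['command_id']}.json"),
        {
            "command_id": command["command_id"],
            "status": "completed",
            "result": handler(command),
        },
    )
    os.remove(oldest)  # the command has been answered
    return True


if __name__ == "__main__":
    base = tempfile.mkdtemp()
    stop = threading.Event()

    def serve_loop():
        while not stop.is_set():
            if not serve_once(base, lambda cmd: {"echo": cmd["args"]}):
                time.sleep(0.01)

    worker = threading.Thread(target=serve_loop, daemon=True)
    worker.start()
    print(send_command(base, "interview", {"agent_id": 1, "prompt": "hello"}))
    stop.set()
    worker.join()
```

The sketch uses `time.monotonic()` for the deadline so that wall-clock adjustments cannot shorten or extend the wait; the code in the diff uses `time.time()`, which behaves the same in the common case.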


@ -1,7 +1,7 @@
""" """OASIS simulation manager.
OASIS模拟管理器
管理Twitter和Reddit双平台并行模拟 Drives parallel Twitter + Reddit simulations using preset scripts plus
使用预设脚本 + LLM智能生成配置参数 LLM-generated configuration parameters.
""" """
import os import os
@ -23,60 +23,60 @@ logger = get_logger('mirofish.simulation')
class SimulationStatus(str, Enum): class SimulationStatus(str, Enum):
"""模拟状态""" """Simulation lifecycle status."""
CREATED = "created" CREATED = "created"
PREPARING = "preparing" PREPARING = "preparing"
READY = "ready" READY = "ready"
RUNNING = "running" RUNNING = "running"
PAUSED = "paused" PAUSED = "paused"
STOPPED = "stopped" # 模拟被手动停止 STOPPED = "stopped" # manually stopped
COMPLETED = "completed" # 模拟自然完成 COMPLETED = "completed" # finished naturally
FAILED = "failed" FAILED = "failed"
class PlatformType(str, Enum): class PlatformType(str, Enum):
"""平台类型""" """Simulated platform types."""
TWITTER = "twitter" TWITTER = "twitter"
REDDIT = "reddit" REDDIT = "reddit"
@dataclass @dataclass
class SimulationState: class SimulationState:
"""模拟状态""" """In-memory + persisted state for a single simulation."""
simulation_id: str simulation_id: str
project_id: str project_id: str
graph_id: str graph_id: str
# 平台启用状态 # Per-platform enable flags.
enable_twitter: bool = True enable_twitter: bool = True
enable_reddit: bool = True enable_reddit: bool = True
# 状态 # Lifecycle status.
status: SimulationStatus = SimulationStatus.CREATED status: SimulationStatus = SimulationStatus.CREATED
# 准备阶段数据 # Counters captured during the prepare phase.
entities_count: int = 0 entities_count: int = 0
profiles_count: int = 0 profiles_count: int = 0
entity_types: List[str] = field(default_factory=list) entity_types: List[str] = field(default_factory=list)
# 配置生成信息 # Information about the auto-generated config.
config_generated: bool = False config_generated: bool = False
config_reasoning: str = "" config_reasoning: str = ""
# 运行时数据 # Runtime data.
current_round: int = 0 current_round: int = 0
twitter_status: str = "not_started" twitter_status: str = "not_started"
reddit_status: str = "not_started" reddit_status: str = "not_started"
# 时间戳 # Timestamps.
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
# 错误信息 # Error message when status == FAILED.
error: Optional[str] = None error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""完整状态字典(内部使用)""" """Full state dict (used for persistence and internal callers)."""
return { return {
"simulation_id": self.simulation_id, "simulation_id": self.simulation_id,
"project_id": self.project_id, "project_id": self.project_id,
@ -96,9 +96,9 @@ class SimulationState:
"updated_at": self.updated_at, "updated_at": self.updated_at,
"error": self.error, "error": self.error,
} }
def to_simple_dict(self) -> Dict[str, Any]: def to_simple_dict(self) -> Dict[str, Any]:
"""简化状态字典API返回使用""" """Simplified state dict (used for API responses)."""
return { return {
"simulation_id": self.simulation_id, "simulation_id": self.simulation_id,
"project_id": self.project_id, "project_id": self.project_id,
@ -113,61 +113,60 @@ class SimulationState:
class SimulationManager: class SimulationManager:
"""Simulation manager.
Core responsibilities:
1. Read entities from the Zep graph and filter to the configured types.
2. Generate OASIS agent profiles per entity.
3. Use the LLM to generate simulation configuration parameters.
4. Materialize the files the preset scripts expect.
""" """
模拟管理器
# Root directory for persisted simulation data.
核心功能
1. 从Zep图谱读取实体并过滤
2. 生成OASIS Agent Profile
3. 使用LLM智能生成模拟配置参数
4. 准备预设脚本所需的所有文件
"""
# 模拟数据存储目录
SIMULATION_DATA_DIR = os.path.join( SIMULATION_DATA_DIR = os.path.join(
os.path.dirname(__file__), os.path.dirname(__file__),
'../../uploads/simulations' '../../uploads/simulations'
) )
def __init__(self): def __init__(self):
# 确保目录存在 # Ensure the simulation data directory exists.
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
# 内存中的模拟状态缓存 # In-memory cache of simulation state objects.
self._simulations: Dict[str, SimulationState] = {} self._simulations: Dict[str, SimulationState] = {}
def _get_simulation_dir(self, simulation_id: str) -> str: def _get_simulation_dir(self, simulation_id: str) -> str:
"""获取模拟数据目录""" """Return the on-disk directory for a simulation, creating if missing."""
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id) sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
os.makedirs(sim_dir, exist_ok=True) os.makedirs(sim_dir, exist_ok=True)
return sim_dir return sim_dir
def _save_simulation_state(self, state: SimulationState): def _save_simulation_state(self, state: SimulationState):
"""保存模拟状态到文件""" """Persist a simulation state to disk and update the cache."""
sim_dir = self._get_simulation_dir(state.simulation_id) sim_dir = self._get_simulation_dir(state.simulation_id)
state_file = os.path.join(sim_dir, "state.json") state_file = os.path.join(sim_dir, "state.json")
state.updated_at = datetime.now().isoformat() state.updated_at = datetime.now().isoformat()
with open(state_file, 'w', encoding='utf-8') as f: with open(state_file, 'w', encoding='utf-8') as f:
json.dump(state.to_dict(), f, ensure_ascii=False, indent=2) json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
self._simulations[state.simulation_id] = state self._simulations[state.simulation_id] = state
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]: def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
"""从文件加载模拟状态""" """Load a simulation state from disk (or cache) by id."""
if simulation_id in self._simulations: if simulation_id in self._simulations:
return self._simulations[simulation_id] return self._simulations[simulation_id]
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
state_file = os.path.join(sim_dir, "state.json") state_file = os.path.join(sim_dir, "state.json")
if not os.path.exists(state_file): if not os.path.exists(state_file):
return None return None
with open(state_file, 'r', encoding='utf-8') as f: with open(state_file, 'r', encoding='utf-8') as f:
data = json.load(f) data = json.load(f)
state = SimulationState( state = SimulationState(
simulation_id=simulation_id, simulation_id=simulation_id,
project_id=data.get("project_id", ""), project_id=data.get("project_id", ""),
@ -187,10 +186,10 @@ class SimulationManager:
updated_at=data.get("updated_at", datetime.now().isoformat()), updated_at=data.get("updated_at", datetime.now().isoformat()),
error=data.get("error"), error=data.get("error"),
) )
self._simulations[simulation_id] = state self._simulations[simulation_id] = state
return state return state
def create_simulation( def create_simulation(
self, self,
project_id: str, project_id: str,
@ -198,21 +197,20 @@ class SimulationManager:
enable_twitter: bool = True, enable_twitter: bool = True,
enable_reddit: bool = True, enable_reddit: bool = True,
) -> SimulationState: ) -> SimulationState:
""" """Create a new simulation in the ``CREATED`` state.
创建新的模拟
Args: Args:
project_id: 项目ID project_id: Owning project id.
graph_id: Zep图谱ID graph_id: Source Zep graph id.
enable_twitter: 是否启用Twitter模拟 enable_twitter: When ``True``, the Twitter simulation runs.
enable_reddit: 是否启用Reddit模拟 enable_reddit: When ``True``, the Reddit simulation runs.
Returns: Returns:
SimulationState The created ``SimulationState``.
""" """
import uuid import uuid
simulation_id = f"sim_{uuid.uuid4().hex[:12]}" simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
state = SimulationState( state = SimulationState(
simulation_id=simulation_id, simulation_id=simulation_id,
project_id=project_id, project_id=project_id,
@ -221,12 +219,12 @@ class SimulationManager:
enable_reddit=enable_reddit, enable_reddit=enable_reddit,
status=SimulationStatus.CREATED, status=SimulationStatus.CREATED,
) )
self._save_simulation_state(state) self._save_simulation_state(state)
logger.info(t("log.simulation_manager.m001", simulation_id=simulation_id, project_id=project_id, graph_id=graph_id)) logger.info(t("log.simulation_manager.m001", simulation_id=simulation_id, project_id=project_id, graph_id=graph_id))
return state return state
def prepare_simulation( def prepare_simulation(
self, self,
simulation_id: str, simulation_id: str,
@ -237,56 +235,55 @@ class SimulationManager:
progress_callback: Optional[callable] = None, progress_callback: Optional[callable] = None,
parallel_profile_count: int = 3 parallel_profile_count: int = 3
) -> SimulationState: ) -> SimulationState:
""" """Prepare the simulation environment end-to-end.
准备模拟环境全程自动化
Steps:
步骤 1. Read and filter entities from the graph.
1. 从Zep图谱读取并过滤实体 2. Generate OASIS agent profiles (optional LLM enrichment, parallel-capable).
2. 为每个实体生成OASIS Agent Profile可选LLM增强支持并行 3. Use the LLM to produce simulation parameters (timing, activity, posting frequency).
3. 使用LLM智能生成模拟配置参数时间活跃度发言频率等 4. Save the configuration and profile files.
4. 保存配置文件和Profile文件 5. Copy preset scripts into the simulation directory.
5. 复制预设脚本到模拟目录
Args: Args:
simulation_id: 模拟ID simulation_id: Simulation id.
simulation_requirement: 模拟需求描述用于LLM生成配置 simulation_requirement: Free-text description of the simulation goal.
document_text: 原始文档内容用于LLM理解背景 document_text: Raw source document text passed to the LLM for context.
defined_entity_types: 预定义的实体类型可选 defined_entity_types: Optional list of allowed entity types.
use_llm_for_profiles: 是否使用LLM生成详细人设 use_llm_for_profiles: When ``True``, enrich profiles via the LLM.
progress_callback: 进度回调函数 (stage, progress, message) progress_callback: Optional callback ``(stage, progress, message, **extras)``.
parallel_profile_count: 并行生成人设的数量默认3 parallel_profile_count: Number of profile generations to run in parallel.
Returns: Returns:
SimulationState The updated ``SimulationState``.
""" """
state = self._load_simulation_state(simulation_id) state = self._load_simulation_state(simulation_id)
if not state: if not state:
raise ValueError(f"模拟不存在: {simulation_id}") raise ValueError(f"模拟不存在: {simulation_id}")
try: try:
state.status = SimulationStatus.PREPARING state.status = SimulationStatus.PREPARING
self._save_simulation_state(state) self._save_simulation_state(state)
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
# ========== 阶段1: 读取并过滤实体 ========== # ========== Stage 1: read and filter entities ==========
if progress_callback: if progress_callback:
progress_callback("reading", 0, t('progress.connectingZepGraph')) progress_callback("reading", 0, t('progress.connectingZepGraph'))
reader = ZepEntityReader() reader = ZepEntityReader()
if progress_callback: if progress_callback:
progress_callback("reading", 30, t('progress.readingNodeData')) progress_callback("reading", 30, t('progress.readingNodeData'))
filtered = reader.filter_defined_entities( filtered = reader.filter_defined_entities(
graph_id=state.graph_id, graph_id=state.graph_id,
defined_entity_types=defined_entity_types, defined_entity_types=defined_entity_types,
enrich_with_edges=True enrich_with_edges=True
) )
state.entities_count = filtered.filtered_count state.entities_count = filtered.filtered_count
state.entity_types = list(filtered.entity_types) state.entity_types = list(filtered.entity_types)
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"reading", 100, "reading", 100,
@ -294,16 +291,16 @@ class SimulationManager:
current=filtered.filtered_count, current=filtered.filtered_count,
total=filtered.filtered_count total=filtered.filtered_count
) )
if filtered.filtered_count == 0: if filtered.filtered_count == 0:
state.status = SimulationStatus.FAILED state.status = SimulationStatus.FAILED
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建" state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
self._save_simulation_state(state) self._save_simulation_state(state)
return state return state
# ========== 阶段2: 生成Agent Profile ========== # ========== Stage 2: generate agent profiles ==========
total_entities = len(filtered.entities) total_entities = len(filtered.entities)
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_profiles", 0, "generating_profiles", 0,
@ -311,22 +308,22 @@ class SimulationManager:
current=0, current=0,
total=total_entities total=total_entities
) )
# 传入graph_id以启用Zep检索功能获取更丰富的上下文 # Pass the graph_id so the generator can use Zep retrieval for richer context.
generator = OasisProfileGenerator(graph_id=state.graph_id) generator = OasisProfileGenerator(graph_id=state.graph_id)
def profile_progress(current, total, msg): def profile_progress(current, total, msg):
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_profiles", "generating_profiles",
int(current / total * 100), int(current / total * 100),
msg, msg,
current=current, current=current,
total=total, total=total,
item_name=msg item_name=msg
) )
# 设置实时保存的文件路径(优先使用 Reddit JSON 格式) # Configure the realtime save target (prefer Reddit JSON if Reddit is enabled).
realtime_output_path = None realtime_output_path = None
realtime_platform = "reddit" realtime_platform = "reddit"
if state.enable_reddit: if state.enable_reddit:
@ -335,21 +332,21 @@ class SimulationManager:
elif state.enable_twitter: elif state.enable_twitter:
realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv") realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
realtime_platform = "twitter" realtime_platform = "twitter"
profiles = generator.generate_profiles_from_entities( profiles = generator.generate_profiles_from_entities(
entities=filtered.entities, entities=filtered.entities,
use_llm=use_llm_for_profiles, use_llm=use_llm_for_profiles,
progress_callback=profile_progress, progress_callback=profile_progress,
graph_id=state.graph_id, # 传入graph_id用于Zep检索 graph_id=state.graph_id, # used for Zep retrieval enrichment
parallel_count=parallel_profile_count, # 并行生成数量 parallel_count=parallel_profile_count,
realtime_output_path=realtime_output_path, # 实时保存路径 realtime_output_path=realtime_output_path,
output_platform=realtime_platform # 输出格式 output_platform=realtime_platform
) )
state.profiles_count = len(profiles) state.profiles_count = len(profiles)
# 保存Profile文件注意Twitter使用CSV格式Reddit使用JSON格式 # Save profile files. Reddit also writes JSON during generation; this is
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性 # a final consistency write. Twitter requires CSV per OASIS conventions.
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_profiles", 95, "generating_profiles", 95,
@ -357,22 +354,22 @@ class SimulationManager:
current=total_entities, current=total_entities,
total=total_entities total=total_entities
) )
if state.enable_reddit: if state.enable_reddit:
generator.save_profiles( generator.save_profiles(
profiles=profiles, profiles=profiles,
file_path=os.path.join(sim_dir, "reddit_profiles.json"), file_path=os.path.join(sim_dir, "reddit_profiles.json"),
platform="reddit" platform="reddit"
) )
if state.enable_twitter: if state.enable_twitter:
# Twitter使用CSV格式这是OASIS的要求 # Twitter uses CSV format — required by OASIS.
generator.save_profiles( generator.save_profiles(
profiles=profiles, profiles=profiles,
file_path=os.path.join(sim_dir, "twitter_profiles.csv"), file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
platform="twitter" platform="twitter"
) )
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_profiles", 100, "generating_profiles", 100,
@ -380,8 +377,8 @@ class SimulationManager:
current=len(profiles), current=len(profiles),
total=len(profiles) total=len(profiles)
) )
# ========== 阶段3: LLM智能生成模拟配置 ========== # ========== Stage 3: LLM-driven simulation config ==========
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_config", 0, "generating_config", 0,
@ -389,9 +386,9 @@ class SimulationManager:
current=0, current=0,
total=3 total=3
) )
config_generator = SimulationConfigGenerator() config_generator = SimulationConfigGenerator()
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_config", 30, "generating_config", 30,
@ -399,7 +396,7 @@ class SimulationManager:
current=1, current=1,
total=3 total=3
) )
sim_params = config_generator.generate_config( sim_params = config_generator.generate_config(
simulation_id=simulation_id, simulation_id=simulation_id,
project_id=state.project_id, project_id=state.project_id,
@ -410,7 +407,7 @@ class SimulationManager:
enable_twitter=state.enable_twitter, enable_twitter=state.enable_twitter,
enable_reddit=state.enable_reddit enable_reddit=state.enable_reddit
) )
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_config", 70, "generating_config", 70,
@ -418,15 +415,15 @@ class SimulationManager:
current=2, current=2,
total=3 total=3
) )
# 保存配置文件 # Save the configuration file.
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
with open(config_path, 'w', encoding='utf-8') as f: with open(config_path, 'w', encoding='utf-8') as f:
f.write(sim_params.to_json()) f.write(sim_params.to_json())
state.config_generated = True state.config_generated = True
state.config_reasoning = sim_params.generation_reasoning state.config_reasoning = sim_params.generation_reasoning
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_config", 100, "generating_config", 100,
@ -434,18 +431,17 @@ class SimulationManager:
current=3, current=3,
total=3 total=3
) )
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录 # The runtime scripts now live under backend/scripts/; we no longer copy
# 启动模拟时simulation_runner 会从 scripts/ 目录运行脚本 # them per-simulation. simulation_runner invokes them in place.
# 更新状态
state.status = SimulationStatus.READY state.status = SimulationStatus.READY
self._save_simulation_state(state) self._save_simulation_state(state)
logger.info(t("log.simulation_manager.m002", simulation_id=simulation_id, state=state.entities_count, state_2=state.profiles_count)) logger.info(t("log.simulation_manager.m002", simulation_id=simulation_id, state=state.entities_count, state_2=state.profiles_count))
return state return state
except Exception as e: except Exception as e:
logger.error(t("log.simulation_manager.m003", simulation_id=simulation_id, str=str(e))) logger.error(t("log.simulation_manager.m003", simulation_id=simulation_id, str=str(e)))
import traceback import traceback
@ -454,61 +450,61 @@ class SimulationManager:
state.error = str(e) state.error = str(e)
self._save_simulation_state(state) self._save_simulation_state(state)
raise raise
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
"""获取模拟状态""" """Return the simulation's state, or ``None`` if unknown."""
return self._load_simulation_state(simulation_id) return self._load_simulation_state(simulation_id)
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]: def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
"""列出所有模拟""" """List all simulations, optionally filtered by ``project_id``."""
simulations = [] simulations = []
if os.path.exists(self.SIMULATION_DATA_DIR): if os.path.exists(self.SIMULATION_DATA_DIR):
for sim_id in os.listdir(self.SIMULATION_DATA_DIR): for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
# 跳过隐藏文件(如 .DS_Store和非目录文件 # Skip dotfiles (e.g. .DS_Store) and non-directories.
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id) sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
if sim_id.startswith('.') or not os.path.isdir(sim_path): if sim_id.startswith('.') or not os.path.isdir(sim_path):
continue continue
state = self._load_simulation_state(sim_id) state = self._load_simulation_state(sim_id)
if state: if state:
if project_id is None or state.project_id == project_id: if project_id is None or state.project_id == project_id:
simulations.append(state) simulations.append(state)
return simulations return simulations
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]: def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
"""获取模拟的Agent Profile""" """Return the persisted agent profiles for a platform."""
state = self._load_simulation_state(simulation_id) state = self._load_simulation_state(simulation_id)
if not state: if not state:
raise ValueError(f"模拟不存在: {simulation_id}") raise ValueError(f"模拟不存在: {simulation_id}")
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json") profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
if not os.path.exists(profile_path): if not os.path.exists(profile_path):
return [] return []
with open(profile_path, 'r', encoding='utf-8') as f: with open(profile_path, 'r', encoding='utf-8') as f:
return json.load(f) return json.load(f)
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
"""获取模拟配置""" """Return the persisted simulation config dict, or ``None`` if absent."""
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
if not os.path.exists(config_path): if not os.path.exists(config_path):
return None return None
with open(config_path, 'r', encoding='utf-8') as f: with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f) return json.load(f)
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
"""获取运行说明""" """Return shell commands and instructions to launch the simulation manually."""
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
return { return {
"simulation_dir": sim_dir, "simulation_dir": sim_dir,
"scripts_dir": scripts_dir, "scripts_dir": scripts_dir,

File diff suppressed because it is too large


@ -1,68 +1,64 @@
""" """Text processing service."""
文本处理服务
"""
from typing import List, Optional from typing import List, Optional
from ..utils.file_parser import FileParser, split_text_into_chunks from ..utils.file_parser import FileParser, split_text_into_chunks
class TextProcessor: class TextProcessor:
"""文本处理器""" """Facade for the text-extraction and chunking pipeline."""
@staticmethod @staticmethod
def extract_from_files(file_paths: List[str]) -> str: def extract_from_files(file_paths: List[str]) -> str:
"""从多个文件提取文本""" """Extract and concatenate text from multiple files."""
return FileParser.extract_from_multiple(file_paths) return FileParser.extract_from_multiple(file_paths)
@staticmethod @staticmethod
def split_text( def split_text(
text: str, text: str,
chunk_size: int = 500, chunk_size: int = 500,
overlap: int = 50 overlap: int = 50
) -> List[str]: ) -> List[str]:
""" """Split text into chunks.
分割文本
Args: Args:
text: 原始文本 text: The source text.
chunk_size: 块大小 chunk_size: Target characters per chunk.
overlap: 重叠大小 overlap: Overlap between consecutive chunks.
Returns: Returns:
文本块列表 A list of chunk strings.
""" """
return split_text_into_chunks(text, chunk_size, overlap) return split_text_into_chunks(text, chunk_size, overlap)
@staticmethod @staticmethod
def preprocess_text(text: str) -> str: def preprocess_text(text: str) -> str:
""" """Pre-process text by normalizing whitespace and line endings.
预处理文本
- 移除多余空白 - Collapse runs of blank lines to at most two newlines.
- 标准化换行 - Normalize line endings to ``\\n``.
- Strip leading/trailing whitespace from each line.
Args: Args:
text: 原始文本 text: The source text.
Returns: Returns:
处理后的文本 The cleaned text.
""" """
import re import re
# 标准化换行
text = text.replace('\r\n', '\n').replace('\r', '\n') text = text.replace('\r\n', '\n').replace('\r', '\n')
# 移除连续空行(保留最多两个换行) # Collapse 3+ consecutive newlines down to a blank-line separator.
text = re.sub(r'\n{3,}', '\n\n', text) text = re.sub(r'\n{3,}', '\n\n', text)
# 移除行首行尾空白
lines = [line.strip() for line in text.split('\n')] lines = [line.strip() for line in text.split('\n')]
text = '\n'.join(lines) text = '\n'.join(lines)
return text.strip() return text.strip()
@staticmethod @staticmethod
def get_text_stats(text: str) -> dict: def get_text_stats(text: str) -> dict:
"""获取文本统计信息""" """Return basic text statistics: total chars, lines, and words."""
return { return {
"total_chars": len(text), "total_chars": len(text),
"total_lines": text.count('\n') + 1, "total_lines": text.count('\n') + 1,

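The normalization steps that the translated `preprocess_text` docstring describes can be tried in isolation. The following is a minimal standalone sketch that mirrors the statements visible in the hunk above; the module-level function name `preprocess_text` is used here only for illustration (in the source it is a static method on `TextProcessor`).

```python
import re


def preprocess_text(text: str) -> str:
    """Standalone copy of the steps shown in the diff, for experimentation."""
    # Normalize Windows (\r\n) and old Mac (\r) line endings to \n.
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    # Collapse 3+ consecutive newlines down to a single blank-line separator.
    text = re.sub(r'\n{3,}', '\n\n', text)
    # Strip leading/trailing whitespace from each line.
    lines = [line.strip() for line in text.split('\n')]
    text = '\n'.join(lines)
    return text.strip()


if __name__ == "__main__":
    print(repr(preprocess_text(" a \r\n\r\n\r\n\r\nb\rc  ")))
```

One behavior worth knowing: because the newline collapse runs before the per-line strip, lines that contain only whitespace are emptied afterwards and can still leave three consecutive newlines in the result.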

@ -1,6 +1,7 @@
""" """Zep entity reader and filter service.
Zep实体读取与过滤服务
从Zep图谱中读取节点筛选出符合预定义实体类型的节点 Reads nodes from a Zep graph and filters down to those that match a
predefined ontology of entity types.
""" """
import time import time
@ -16,23 +17,23 @@ from ..utils.locale import t
logger = get_logger('mirofish.zep_entity_reader') logger = get_logger('mirofish.zep_entity_reader')
# 用于泛型返回类型 # Generic return-type variable.
T = TypeVar('T') T = TypeVar('T')
@dataclass @dataclass
class EntityNode: class EntityNode:
"""实体节点数据结构""" """In-memory representation of an entity node from the graph."""
uuid: str uuid: str
name: str name: str
labels: List[str] labels: List[str]
summary: str summary: str
attributes: Dict[str, Any] attributes: Dict[str, Any]
# 相关的边信息 # Edges connected to this entity.
related_edges: List[Dict[str, Any]] = field(default_factory=list) related_edges: List[Dict[str, Any]] = field(default_factory=list)
# 相关的其他节点信息 # Other nodes connected through related edges.
related_nodes: List[Dict[str, Any]] = field(default_factory=list) related_nodes: List[Dict[str, Any]] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"uuid": self.uuid, "uuid": self.uuid,
@ -43,9 +44,9 @@ class EntityNode:
"related_edges": self.related_edges, "related_edges": self.related_edges,
"related_nodes": self.related_nodes, "related_nodes": self.related_nodes,
} }
def get_entity_type(self) -> Optional[str]: def get_entity_type(self) -> Optional[str]:
"""获取实体类型排除默认的Entity标签""" """Return the first non-default label, or ``None`` if only defaults are present."""
for label in self.labels: for label in self.labels:
if label not in ["Entity", "Node"]: if label not in ["Entity", "Node"]:
return label return label
@ -54,12 +55,12 @@ class EntityNode:
@dataclass @dataclass
class FilteredEntities: class FilteredEntities:
"""过滤后的实体集合""" """Result of a filter pass over the graph: matching entities + counts."""
entities: List[EntityNode] entities: List[EntityNode]
entity_types: Set[str] entity_types: Set[str]
total_count: int total_count: int
filtered_count: int filtered_count: int
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
"entities": [e.to_dict() for e in self.entities], "entities": [e.to_dict() for e in self.entities],
@ -70,40 +71,38 @@ class FilteredEntities:
class ZepEntityReader: class ZepEntityReader:
"""Read entities from a Zep graph and filter to ontology-defined types.
Capabilities:
1. Read all nodes from the graph.
2. Keep nodes whose labels include something other than the defaults ``Entity`` and ``Node``.
3. Optionally enrich each entity with its connected edges and neighboring nodes.
""" """
Zep实体读取与过滤服务
主要功能
1. 从Zep图谱读取所有节点
2. 筛选出符合预定义实体类型的节点Labels不只是Entity的节点
3. 获取每个实体的相关边和关联节点信息
"""
def __init__(self, api_key: Optional[str] = None): def __init__(self, api_key: Optional[str] = None):
self.client = GraphitiAdapter() self.client = GraphitiAdapter()
def _call_with_retry( def _call_with_retry(
self, self,
func: Callable[[], T], func: Callable[[], T],
operation_name: str, operation_name: str,
max_retries: int = 3, max_retries: int = 3,
initial_delay: float = 2.0 initial_delay: float = 2.0
) -> T: ) -> T:
""" """Call a Zep API function with retry on failure.
带重试机制的Zep API调用
Args: Args:
func: 要执行的函数无参数的lambda或callable func: A zero-argument callable performing the request.
operation_name: 操作名称用于日志 operation_name: Operation label used in log output.
max_retries: 最大重试次数默认3次即最多尝试3次 max_retries: Maximum number of attempts (default 3, i.e. up to 3 tries in total).
initial_delay: 初始延迟秒数 initial_delay: Delay before the first retry, in seconds; doubled after each failed attempt.
Returns: Returns:
API调用结果 The return value of ``func``.
""" """
last_exception = None last_exception = None
delay = initial_delay delay = initial_delay
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
return func() return func()
@ -114,21 +113,20 @@ class ZepEntityReader:
t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt + 1, str=str(e)[:100], delay=delay) t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt + 1, str=str(e)[:100], delay=delay)
) )
time.sleep(delay) time.sleep(delay)
delay *= 2 # 指数退避 delay *= 2 # exponential backoff
else: else:
logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=max_retries, str=str(e))) logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=max_retries, str=str(e)))
raise last_exception raise last_exception
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
""" """Return every node in the graph (paginated under the hood).
获取图谱的所有节点分页获取
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
Returns: Returns:
节点列表 A list of node dicts.
""" """
logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id)) logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id))
@ -148,14 +146,13 @@ class ZepEntityReader:
return nodes_data return nodes_data
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
""" """Return every edge in the graph (paginated under the hood).
获取图谱的所有边分页获取
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
Returns: Returns:
边列表 A list of edge dicts.
""" """
logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id)) logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id))
@ -174,24 +171,23 @@ class ZepEntityReader:
logger.info(t("log.zep_entity_reader.m006", len=len(edges_data))) logger.info(t("log.zep_entity_reader.m006", len=len(edges_data)))
return edges_data return edges_data
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
""" """Return every edge connected to the given node (with retry).
获取指定节点的所有相关边带重试机制
Args: Args:
node_uuid: 节点UUID node_uuid: Node UUID.
Returns: Returns:
边列表 A list of edge dicts.
""" """
try: try:
# 使用重试机制调用Zep API # Wrap the API call in retry logic.
edges = self._call_with_retry( edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
operation_name=f"获取节点边(node={node_uuid[:8]}...)" operation_name=f"获取节点边(node={node_uuid[:8]}...)"
) )
edges_data = [] edges_data = []
for edge in edges: for edge in edges:
edges_data.append({ edges_data.append({
@ -202,32 +198,31 @@ class ZepEntityReader:
"target_node_uuid": edge.target_node_uuid, "target_node_uuid": edge.target_node_uuid,
"attributes": edge.attributes or {}, "attributes": edge.attributes or {},
}) })
return edges_data return edges_data
except Exception as e: except Exception as e:
logger.warning(t("log.zep_entity_reader.m007", node_uuid=node_uuid, str=str(e))) logger.warning(t("log.zep_entity_reader.m007", node_uuid=node_uuid, str=str(e)))
return [] return []
def filter_defined_entities( def filter_defined_entities(
self, self,
graph_id: str, graph_id: str,
defined_entity_types: Optional[List[str]] = None, defined_entity_types: Optional[List[str]] = None,
enrich_with_edges: bool = True enrich_with_edges: bool = True
) -> FilteredEntities: ) -> FilteredEntities:
""" """Filter nodes down to entities matching the predefined ontology types.
筛选出符合预定义实体类型的节点
Filtering rules:
筛选逻辑 - Skip nodes that carry only the default labels ``Entity`` / ``Node`` (uncategorized).
- 如果节点的Labels只有一个"Entity"说明这个实体不符合我们预定义的类型跳过 - Keep nodes whose labels include anything other than ``Entity`` and ``Node``.
- 如果节点的Labels包含除"Entity""Node"之外的标签说明符合预定义类型保留
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
defined_entity_types: 预定义的实体类型列表可选如果提供则只保留这些类型 defined_entity_types: Optional allow-list; when provided, only matching types are kept.
enrich_with_edges: 是否获取每个实体的相关边信息 enrich_with_edges: When ``True``, populate related_edges and related_nodes.
Returns: Returns:
FilteredEntities: 过滤后的实体集合 A ``FilteredEntities`` summary.
""" """
logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id)) logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id))
@ -243,7 +238,7 @@ class ZepEntityReader:
except Exception: except Exception:
pass pass
# 获取所有节点 # Read every node from the graph.
all_nodes = self.get_all_nodes(graph_id) all_nodes = self.get_all_nodes(graph_id)
total_count = len(all_nodes) total_count = len(all_nodes)
@ -259,27 +254,27 @@ class ZepEntityReader:
if entity_type != "Entity": if entity_type != "Entity":
node["labels"] = [entity_type] + labels node["labels"] = [entity_type] + labels
# 获取所有边(用于后续关联查找) # Read every edge so we can enrich entities later.
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
# 构建节点UUID到节点数据的映射 # uuid -> node-data map for fast lookup.
node_map = {n["uuid"]: n for n in all_nodes} node_map = {n["uuid"]: n for n in all_nodes}
# 筛选符合条件的实体 # Filter to entities that match the criteria.
filtered_entities = [] filtered_entities = []
entity_types_found = set() entity_types_found = set()
for node in all_nodes: for node in all_nodes:
labels = node.get("labels", []) labels = node.get("labels", [])
# 筛选逻辑Labels必须包含除"Entity"和"Node"之外的标签 # Filtering rule: labels must contain something other than the defaults.
custom_labels = [l for l in labels if l not in ["Entity", "Node"]] custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
if not custom_labels: if not custom_labels:
# 只有默认标签,跳过 # Only default labels — skip.
continue continue
# 如果指定了预定义类型,检查是否匹配 # When a predefined-type list is supplied, require a match against it.
if defined_entity_types: if defined_entity_types:
matching_labels = [l for l in custom_labels if l in defined_entity_types] matching_labels = [l for l in custom_labels if l in defined_entity_types]
if not matching_labels: if not matching_labels:
@ -287,10 +282,9 @@ class ZepEntityReader:
entity_type = matching_labels[0] entity_type = matching_labels[0]
else: else:
entity_type = custom_labels[0] entity_type = custom_labels[0]
entity_types_found.add(entity_type) entity_types_found.add(entity_type)
# 创建实体节点对象
entity = EntityNode( entity = EntityNode(
uuid=node["uuid"], uuid=node["uuid"],
name=node["name"], name=node["name"],
@ -298,12 +292,12 @@ class ZepEntityReader:
summary=node["summary"], summary=node["summary"],
attributes=node["attributes"], attributes=node["attributes"],
) )
# 获取相关边和节点 # Enrich with related edges and neighboring nodes.
if enrich_with_edges: if enrich_with_edges:
related_edges = [] related_edges = []
related_node_uuids = set() related_node_uuids = set()
for edge in all_edges: for edge in all_edges:
if edge["source_node_uuid"] == node["uuid"]: if edge["source_node_uuid"] == node["uuid"]:
related_edges.append({ related_edges.append({
@ -321,10 +315,10 @@ class ZepEntityReader:
"source_node_uuid": edge["source_node_uuid"], "source_node_uuid": edge["source_node_uuid"],
}) })
related_node_uuids.add(edge["source_node_uuid"]) related_node_uuids.add(edge["source_node_uuid"])
entity.related_edges = related_edges entity.related_edges = related_edges
# 获取关联节点的基本信息 # Populate basic info for each neighboring node.
related_nodes = [] related_nodes = []
for related_uuid in related_node_uuids: for related_uuid in related_node_uuids:
if related_uuid in node_map: if related_uuid in node_map:
@ -335,56 +329,55 @@ class ZepEntityReader:
"labels": related_node["labels"], "labels": related_node["labels"],
"summary": related_node.get("summary", ""), "summary": related_node.get("summary", ""),
}) })
entity.related_nodes = related_nodes entity.related_nodes = related_nodes
filtered_entities.append(entity) filtered_entities.append(entity)
logger.info(t("log.zep_entity_reader.m009", total_count=total_count, len=len(filtered_entities), entity_types_found=entity_types_found)) logger.info(t("log.zep_entity_reader.m009", total_count=total_count, len=len(filtered_entities), entity_types_found=entity_types_found))
return FilteredEntities( return FilteredEntities(
entities=filtered_entities, entities=filtered_entities,
entity_types=entity_types_found, entity_types=entity_types_found,
total_count=total_count, total_count=total_count,
filtered_count=len(filtered_entities), filtered_count=len(filtered_entities),
) )
def get_entity_with_context( def get_entity_with_context(
self, self,
graph_id: str, graph_id: str,
entity_uuid: str entity_uuid: str
) -> Optional[EntityNode]: ) -> Optional[EntityNode]:
""" """Fetch a single entity with its full context (edges + neighbors), with retry.
获取单个实体及其完整上下文边和关联节点带重试机制
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
entity_uuid: 实体UUID entity_uuid: Entity UUID.
Returns: Returns:
EntityNode或None ``EntityNode`` or ``None`` if not found.
""" """
try: try:
# 使用重试机制获取节点 # Fetch the node with retry.
node = self._call_with_retry( node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid), func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
) )
if not node: if not node:
return None return None
# 获取节点的边 # Edges connected to this node.
edges = self.get_node_edges(entity_uuid) edges = self.get_node_edges(entity_uuid)
# 获取所有节点用于关联查找 # All graph nodes, used for neighbor lookup.
all_nodes = self.get_all_nodes(graph_id) all_nodes = self.get_all_nodes(graph_id)
node_map = {n["uuid"]: n for n in all_nodes} node_map = {n["uuid"]: n for n in all_nodes}
# 处理相关边和节点 # Collect related edges and neighboring uuids.
related_edges = [] related_edges = []
related_node_uuids = set() related_node_uuids = set()
for edge in edges: for edge in edges:
if edge["source_node_uuid"] == entity_uuid: if edge["source_node_uuid"] == entity_uuid:
related_edges.append({ related_edges.append({
@ -402,8 +395,8 @@ class ZepEntityReader:
"source_node_uuid": edge["source_node_uuid"], "source_node_uuid": edge["source_node_uuid"],
}) })
related_node_uuids.add(edge["source_node_uuid"]) related_node_uuids.add(edge["source_node_uuid"])
# 获取关联节点信息 # Populate basic info for each neighboring node.
related_nodes = [] related_nodes = []
for related_uuid in related_node_uuids: for related_uuid in related_node_uuids:
if related_uuid in node_map: if related_uuid in node_map:
@ -414,7 +407,7 @@ class ZepEntityReader:
"labels": related_node["labels"], "labels": related_node["labels"],
"summary": related_node.get("summary", ""), "summary": related_node.get("summary", ""),
}) })
return EntityNode( return EntityNode(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
name=node.name or "", name=node.name or "",
@ -424,27 +417,26 @@ class ZepEntityReader:
related_edges=related_edges, related_edges=related_edges,
related_nodes=related_nodes, related_nodes=related_nodes,
) )
except Exception as e: except Exception as e:
logger.error(t("log.zep_entity_reader.m010", entity_uuid=entity_uuid, str=str(e))) logger.error(t("log.zep_entity_reader.m010", entity_uuid=entity_uuid, str=str(e)))
return None return None
def get_entities_by_type( def get_entities_by_type(
self, self,
graph_id: str, graph_id: str,
entity_type: str, entity_type: str,
enrich_with_edges: bool = True enrich_with_edges: bool = True
) -> List[EntityNode]: ) -> List[EntityNode]:
""" """Return every entity matching the given type.
获取指定类型的所有实体
Args: Args:
graph_id: 图谱ID graph_id: Graph identifier.
entity_type: 实体类型 "Student", "PublicFigure" entity_type: Entity type label (e.g. ``Student``, ``PublicFigure``).
enrich_with_edges: 是否获取相关边信息 enrich_with_edges: When ``True``, populate related edges/nodes.
Returns: Returns:
实体列表 A list of matching ``EntityNode`` instances.
""" """
result = self.filter_defined_entities( result = self.filter_defined_entities(
graph_id=graph_id, graph_id=graph_id,

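The retry loop documented in `_call_with_retry` (a bounded number of attempts, with the delay doubling after each failure) can be sketched as a free function. This is an illustration consistent with the lines visible in the diff, not the project's implementation: logging is omitted, and the `sleep` parameter and the `max_retries < 1` guard are assumptions added here so the behavior can be exercised without real waiting.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def call_with_retry(
    func: Callable[[], T],
    max_retries: int = 3,
    initial_delay: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,  # assumption: injectable for tests
) -> T:
    """Call ``func`` up to ``max_retries`` times with exponential backoff."""
    if max_retries < 1:
        # Assumption: the original does not show this guard.
        raise ValueError("max_retries must be at least 1")

    delay = initial_delay
    last_exception: Optional[Exception] = None

    for attempt in range(max_retries):
        try:
            return func()
        except Exception as exc:
            last_exception = exc
            # Sleep only if another attempt will follow.
            if attempt < max_retries - 1:
                sleep(delay)
                delay *= 2  # exponential backoff

    assert last_exception is not None
    raise last_exception
```

With the defaults, a call that keeps failing waits 2 s and then 4 s between its three attempts before the last exception is re-raised.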

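The label rule used by `filter_defined_entities` in `zep_entity_reader.py` (ignore the default labels, then optionally require a match against an allow-list) can be expressed as a small pure function. The helper name `pick_entity_type` and the `DEFAULT_LABELS` constant are hypothetical and exist only for this sketch; the selection logic follows the list comprehensions visible in the diff.

```python
from typing import List, Optional

# Assumption: named constant introduced here; the source inlines ["Entity", "Node"].
DEFAULT_LABELS = ("Entity", "Node")


def pick_entity_type(
    labels: List[str],
    defined_entity_types: Optional[List[str]] = None,
) -> Optional[str]:
    """Return the entity type for a node, or None if the node should be skipped."""
    # Labels must contain something other than the defaults.
    custom_labels = [label for label in labels if label not in DEFAULT_LABELS]
    if not custom_labels:
        return None

    # When an allow-list is supplied, require a match against it.
    if defined_entity_types:
        matching = [label for label in custom_labels if label in defined_entity_types]
        return matching[0] if matching else None

    return custom_labels[0]
```

Returning `None` here stands in for the `continue` statements in the original loop, which skip the node instead of building an `EntityNode` for it.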
@ -1,6 +1,7 @@
""" """
Zep图谱记忆更新服务 Zep graph memory update service.
将模拟中的Agent活动动态更新到Zep图谱中
Streams agent activity from running simulations into the Zep knowledge graph.
""" """
import os import os
@ -23,7 +24,7 @@ logger = get_logger('mirofish.zep_graph_memory_updater')
@dataclass @dataclass
class AgentActivity: class AgentActivity:
"""Agent活动记录""" """Record of a single agent activity."""
platform: str # twitter / reddit platform: str # twitter / reddit
agent_id: int agent_id: int
agent_name: str agent_name: str
@ -33,13 +34,12 @@ class AgentActivity:
timestamp: str timestamp: str
def to_episode_text(self) -> str: def to_episode_text(self) -> str:
"""Render the activity as a natural-language episode for Zep.
The text uses plain narrative phrasing so Zep can extract entities and
relationships from it. No simulation-specific prefix is prepended, so
the graph update is not biased by framing words.
""" """
将活动转换为可以发送给Zep的文本描述
采用自然语言描述格式让Zep能够从中提取实体和关系
不添加模拟相关的前缀避免误导图谱更新
"""
# 根据不同的动作类型生成不同的描述
action_descriptions = { action_descriptions = {
"CREATE_POST": self._describe_create_post, "CREATE_POST": self._describe_create_post,
"LIKE_POST": self._describe_like_post, "LIKE_POST": self._describe_like_post,
@ -57,8 +57,8 @@ class AgentActivity:
describe_func = action_descriptions.get(self.action_type, self._describe_generic) describe_func = action_descriptions.get(self.action_type, self._describe_generic)
description = describe_func() description = describe_func()
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀 # Return "<agent name>: <activity>" with no simulation prefix.
return f"{self.agent_name}: {description}" return f"{self.agent_name}: {description}"
def _describe_create_post(self) -> str: def _describe_create_post(self) -> str:
@ -68,7 +68,7 @@ class AgentActivity:
return "发布了一条帖子" return "发布了一条帖子"
def _describe_like_post(self) -> str: def _describe_like_post(self) -> str:
"""点赞帖子 - 包含帖子原文和作者信息""" """Like a post — includes the post text and author when available."""
post_content = self.action_args.get("post_content", "") post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "") post_author = self.action_args.get("post_author_name", "")
@ -81,7 +81,7 @@ class AgentActivity:
return "点赞了一条帖子" return "点赞了一条帖子"
def _describe_dislike_post(self) -> str: def _describe_dislike_post(self) -> str:
"""踩帖子 - 包含帖子原文和作者信息""" """Dislike a post — includes the post text and author when available."""
post_content = self.action_args.get("post_content", "") post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "") post_author = self.action_args.get("post_author_name", "")
@ -94,7 +94,7 @@ class AgentActivity:
return "踩了一条帖子" return "踩了一条帖子"
def _describe_repost(self) -> str: def _describe_repost(self) -> str:
"""转发帖子 - 包含原帖内容和作者信息""" """Repost — includes the original post text and author when available."""
original_content = self.action_args.get("original_content", "") original_content = self.action_args.get("original_content", "")
original_author = self.action_args.get("original_author_name", "") original_author = self.action_args.get("original_author_name", "")
@ -107,7 +107,7 @@ class AgentActivity:
return "转发了一条帖子" return "转发了一条帖子"
def _describe_quote_post(self) -> str: def _describe_quote_post(self) -> str:
"""引用帖子 - 包含原帖内容、作者信息和引用评论""" """Quote-post — includes the original post, author, and the quote comment."""
original_content = self.action_args.get("original_content", "") original_content = self.action_args.get("original_content", "")
original_author = self.action_args.get("original_author_name", "") original_author = self.action_args.get("original_author_name", "")
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "") quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
@ -127,7 +127,7 @@ class AgentActivity:
return base return base
def _describe_follow(self) -> str: def _describe_follow(self) -> str:
"""关注用户 - 包含被关注用户的名称""" """Follow a user — includes the followed user's name."""
target_user_name = self.action_args.get("target_user_name", "") target_user_name = self.action_args.get("target_user_name", "")
if target_user_name: if target_user_name:
@ -135,7 +135,7 @@ class AgentActivity:
return "关注了一个用户" return "关注了一个用户"
def _describe_create_comment(self) -> str: def _describe_create_comment(self) -> str:
"""发表评论 - 包含评论内容和所评论的帖子信息""" """Create a comment — includes the comment text and the parent post."""
content = self.action_args.get("content", "") content = self.action_args.get("content", "")
post_content = self.action_args.get("post_content", "") post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "") post_author = self.action_args.get("post_author_name", "")
@ -151,7 +151,7 @@ class AgentActivity:
return "发表了评论" return "发表了评论"
def _describe_like_comment(self) -> str: def _describe_like_comment(self) -> str:
"""点赞评论 - 包含评论内容和作者信息""" """Like a comment — includes the comment text and author when available."""
comment_content = self.action_args.get("comment_content", "") comment_content = self.action_args.get("comment_content", "")
comment_author = self.action_args.get("comment_author_name", "") comment_author = self.action_args.get("comment_author_name", "")
@ -164,7 +164,7 @@ class AgentActivity:
return "点赞了一条评论" return "点赞了一条评论"
def _describe_dislike_comment(self) -> str: def _describe_dislike_comment(self) -> str:
"""踩评论 - 包含评论内容和作者信息""" """Dislike a comment — includes the comment text and author when available."""
comment_content = self.action_args.get("comment_content", "") comment_content = self.action_args.get("comment_content", "")
comment_author = self.action_args.get("comment_author_name", "") comment_author = self.action_args.get("comment_author_name", "")
@ -177,17 +177,17 @@ class AgentActivity:
return "踩了一条评论" return "踩了一条评论"
def _describe_search(self) -> str: def _describe_search(self) -> str:
"""搜索帖子 - 包含搜索关键词""" """Search posts — includes the search query."""
query = self.action_args.get("query", "") or self.action_args.get("keyword", "") query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
return f"搜索了「{query}" if query else "进行了搜索" return f"搜索了「{query}" if query else "进行了搜索"
def _describe_search_user(self) -> str: def _describe_search_user(self) -> str:
"""搜索用户 - 包含搜索关键词""" """Search users — includes the search query."""
query = self.action_args.get("query", "") or self.action_args.get("username", "") query = self.action_args.get("query", "") or self.action_args.get("username", "")
return f"搜索了用户「{query}" if query else "搜索了用户" return f"搜索了用户「{query}" if query else "搜索了用户"
def _describe_mute(self) -> str: def _describe_mute(self) -> str:
"""屏蔽用户 - 包含被屏蔽用户的名称""" """Mute a user — includes the muted user's name."""
target_user_name = self.action_args.get("target_user_name", "") target_user_name = self.action_args.get("target_user_name", "")
if target_user_name: if target_user_name:
@ -195,80 +195,79 @@ class AgentActivity:
return "屏蔽了一个用户" return "屏蔽了一个用户"
def _describe_generic(self) -> str: def _describe_generic(self) -> str:
# 对于未知的动作类型,生成通用描述 # Fallback narration for action types not handled explicitly above.
return f"执行了{self.action_type}操作" return f"执行了{self.action_type}操作"
class ZepGraphMemoryUpdater: class ZepGraphMemoryUpdater:
""" """Zep graph memory updater.
Zep图谱记忆更新器
Watches a simulation's actions log file and streams new agent activity
监控模拟的actions日志文件将新的agent活动实时更新到Zep图谱中 into the Zep knowledge graph in near real time. Activities are grouped
按平台分组每累积BATCH_SIZE条活动后批量发送到Zep by platform; each platform sends a batch once it has accumulated
``BATCH_SIZE`` items.
所有有意义的行为都会被更新到Zepaction_args中会包含完整的上下文信息
- 点赞/踩的帖子原文 Every meaningful action is forwarded to Zep, with full context preserved
- 转发/引用的帖子原文 in ``action_args``:
- 关注/屏蔽的用户名
- 点赞/踩的评论原文 - Original text of liked / disliked posts
- Original text of reposted / quoted posts
- Names of followed / muted users
- Original text of liked / disliked comments
""" """
# 批量发送大小(每个平台累积多少条后发送) # Number of activities to accumulate per platform before sending a batch.
BATCH_SIZE = 5 BATCH_SIZE = 5
# 平台名称映射(用于控制台显示) # Platform display names used for console / log output.
PLATFORM_DISPLAY_NAMES = { PLATFORM_DISPLAY_NAMES = {
'twitter': '世界1', 'twitter': '世界1',
'reddit': '世界2', 'reddit': '世界2',
} }
# 发送间隔(秒),避免请求过快 # Pause between sends (seconds) to avoid hammering the Zep API.
SEND_INTERVAL = 0.5 SEND_INTERVAL = 0.5
# 重试配置
MAX_RETRIES = 3 MAX_RETRIES = 3
RETRY_DELAY = 2 # RETRY_DELAY = 2 # seconds
def __init__(self, graph_id: str, api_key: Optional[str] = None): def __init__(self, graph_id: str, api_key: Optional[str] = None):
""" """Initialize the updater.
初始化更新器
Args: Args:
graph_id: Zep图谱ID graph_id: Zep graph ID.
api_key: Zep API Key可选默认从配置读取 api_key: Optional Zep API key; defaults to the value from config.
""" """
self.graph_id = graph_id self.graph_id = graph_id
self.client = GraphitiAdapter() self.client = GraphitiAdapter()
# 活动队列
self._activity_queue: Queue = Queue() self._activity_queue: Queue = Queue()
# 按平台分组的活动缓冲区每个平台各自累积到BATCH_SIZE后批量发送 # Per-platform buffer; each platform flushes once it reaches BATCH_SIZE.
self._platform_buffers: Dict[str, List[AgentActivity]] = { self._platform_buffers: Dict[str, List[AgentActivity]] = {
'twitter': [], 'twitter': [],
'reddit': [], 'reddit': [],
} }
self._buffer_lock = threading.Lock() self._buffer_lock = threading.Lock()
# 控制标志
self._running = False self._running = False
self._worker_thread: Optional[threading.Thread] = None self._worker_thread: Optional[threading.Thread] = None
# 统计 # Counters
self._total_activities = 0 # 实际添加到队列的活动数 self._total_activities = 0 # activities accepted into the queue
self._total_sent = 0 # 成功发送到Zep的批次数 self._total_sent = 0 # batches successfully sent to Zep
self._total_items_sent = 0 # 成功发送到Zep的活动条数 self._total_items_sent = 0 # individual activities successfully sent to Zep
self._failed_count = 0 # 发送失败的批次数 self._failed_count = 0 # batches that failed to send
self._skipped_count = 0 # 被过滤跳过的活动数DO_NOTHING self._skipped_count = 0 # activities filtered out (e.g. DO_NOTHING)
logger.info(t("log.zep_graph_memory_updater.m001", graph_id=graph_id, self=self.BATCH_SIZE)) logger.info(t("log.zep_graph_memory_updater.m001", graph_id=graph_id, self=self.BATCH_SIZE))
def _get_platform_display_name(self, platform: str) -> str: def _get_platform_display_name(self, platform: str) -> str:
"""获取平台的显示名称""" """Return the human-friendly display name for a platform."""
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform) return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
def start(self): def start(self):
"""启动后台工作线程""" """Start the background worker thread."""
if self._running: if self._running:
return return
@ -286,10 +285,9 @@ class ZepGraphMemoryUpdater:
logger.info(t("log.zep_graph_memory_updater.m002", self=self.graph_id)) logger.info(t("log.zep_graph_memory_updater.m002", self=self.graph_id))
def stop(self): def stop(self):
"""停止后台工作线程""" """Stop the background worker thread and flush pending activity."""
self._running = False self._running = False
# 发送剩余的活动
self._flush_remaining() self._flush_remaining()
if self._worker_thread and self._worker_thread.is_alive(): if self._worker_thread and self._worker_thread.is_alive():
@ -298,27 +296,28 @@ class ZepGraphMemoryUpdater:
logger.info(t("log.zep_graph_memory_updater.m003", self=self.graph_id, self_2=self._total_activities, self_3=self._total_sent, self_4=self._total_items_sent, self_5=self._failed_count, self_6=self._skipped_count)) logger.info(t("log.zep_graph_memory_updater.m003", self=self.graph_id, self_2=self._total_activities, self_3=self._total_sent, self_4=self._total_items_sent, self_5=self._failed_count, self_6=self._skipped_count))
def add_activity(self, activity: AgentActivity): def add_activity(self, activity: AgentActivity):
""" """Enqueue a single agent activity for delivery to Zep.
添加一个agent活动到队列
Every meaningful action is queued, including:
所有有意义的行为都会被添加到队列包括
- CREATE_POST发帖 - CREATE_POST (post)
- CREATE_COMMENT评论 - CREATE_COMMENT (comment)
- QUOTE_POST引用帖子 - QUOTE_POST (quote a post)
- SEARCH_POSTS搜索帖子 - SEARCH_POSTS (search posts)
- SEARCH_USER搜索用户 - SEARCH_USER (search users)
- LIKE_POST/DISLIKE_POST点赞/踩帖子 - LIKE_POST / DISLIKE_POST (like / dislike a post)
- REPOST转发 - REPOST (repost)
- FOLLOW关注 - FOLLOW (follow)
- MUTE屏蔽 - MUTE (mute)
- LIKE_COMMENT/DISLIKE_COMMENT点赞/踩评论 - LIKE_COMMENT / DISLIKE_COMMENT (like / dislike a comment)
action_args中会包含完整的上下文信息如帖子原文用户名等 ``action_args`` carries the full context (e.g. original post text,
user names) so the graph episode is self-contained.
Args: Args:
activity: Agent活动记录 activity: The agent activity record to enqueue.
""" """
# 跳过DO_NOTHING类型的活动 # DO_NOTHING actions carry no information worth indexing.
if activity.action_type == "DO_NOTHING": if activity.action_type == "DO_NOTHING":
self._skipped_count += 1 self._skipped_count += 1
return return
@ -328,14 +327,13 @@ class ZepGraphMemoryUpdater:
logger.debug(t("log.zep_graph_memory_updater.m004", activity=activity.agent_name, activity_2=activity.action_type)) logger.debug(t("log.zep_graph_memory_updater.m004", activity=activity.agent_name, activity_2=activity.action_type))
def add_activity_from_dict(self, data: Dict[str, Any], platform: str): def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
""" """Build an ``AgentActivity`` from a parsed JSON record and enqueue it.
从字典数据添加活动
Args: Args:
data: 从actions.jsonl解析的字典数据 data: A dict parsed from a single ``actions.jsonl`` line.
platform: 平台名称 (twitter/reddit) platform: Source platform name (``twitter`` or ``reddit``).
""" """
# 跳过事件类型的条目 # Event-type rows describe simulation lifecycle, not agent activity.
if "event_type" in data: if "event_type" in data:
return return
@ -352,28 +350,26 @@ class ZepGraphMemoryUpdater:
self.add_activity(activity) self.add_activity(activity)
def _worker_loop(self, locale: str = 'zh'): def _worker_loop(self, locale: str = 'zh'):
"""后台工作循环 - 按平台批量发送活动到Zep""" """Background loop that drains the queue and flushes per-platform batches."""
set_locale(locale) set_locale(locale)
while self._running or not self._activity_queue.empty(): while self._running or not self._activity_queue.empty():
try: try:
# 尝试从队列获取活动超时1秒 # Block briefly so the loop can also notice shutdown requests.
try: try:
activity = self._activity_queue.get(timeout=1) activity = self._activity_queue.get(timeout=1)
# 将活动添加到对应平台的缓冲区
platform = activity.platform.lower() platform = activity.platform.lower()
with self._buffer_lock: with self._buffer_lock:
if platform not in self._platform_buffers: if platform not in self._platform_buffers:
self._platform_buffers[platform] = [] self._platform_buffers[platform] = []
self._platform_buffers[platform].append(activity) self._platform_buffers[platform].append(activity)
# 检查该平台是否达到批量大小
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE: if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
batch = self._platform_buffers[platform][:self.BATCH_SIZE] batch = self._platform_buffers[platform][:self.BATCH_SIZE]
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:] self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
# 释放锁后再发送 # Release the lock before issuing the network call.
self._send_batch_activities(batch, platform) self._send_batch_activities(batch, platform)
# 发送间隔,避免请求过快 # Throttle so we don't hammer the Zep API.
time.sleep(self.SEND_INTERVAL) time.sleep(self.SEND_INTERVAL)
except Empty: except Empty:
@ -384,21 +380,20 @@ class ZepGraphMemoryUpdater:
time.sleep(1) time.sleep(1)
def _send_batch_activities(self, activities: List[AgentActivity], platform: str): def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
""" """Send a batch of activities to the Zep graph as one combined episode.
批量发送活动到Zep图谱合并为一条文本
Args: Args:
activities: Agent活动列表 activities: Agent activity records to send.
platform: 平台名称 platform: Source platform name.
""" """
if not activities: if not activities:
return return
# 将多条活动合并为一条文本,用换行分隔 # Concatenate the per-activity narrations into a single newline-separated episode.
episode_texts = [activity.to_episode_text() for activity in activities] episode_texts = [activity.to_episode_text() for activity in activities]
combined_text = "\n".join(episode_texts) combined_text = "\n".join(episode_texts)
# 带重试的发送 # Retry on failure with linear backoff.
for attempt in range(self.MAX_RETRIES): for attempt in range(self.MAX_RETRIES):
try: try:
self.client.graph.add( self.client.graph.add(
@ -423,8 +418,8 @@ class ZepGraphMemoryUpdater:
self._failed_count += 1 self._failed_count += 1
def _flush_remaining(self): def _flush_remaining(self):
"""发送队列和缓冲区中剩余的活动""" """Drain the queue and flush every platform buffer, even partial ones."""
# 首先处理队列中剩余的活动,添加到缓冲区 # Move anything still in the queue into the per-platform buffers.
while not self._activity_queue.empty(): while not self._activity_queue.empty():
try: try:
activity = self._activity_queue.get_nowait() activity = self._activity_queue.get_nowait()
@ -435,61 +430,55 @@ class ZepGraphMemoryUpdater:
self._platform_buffers[platform].append(activity) self._platform_buffers[platform].append(activity)
except Empty: except Empty:
break break
# 然后发送各平台缓冲区中剩余的活动即使不足BATCH_SIZE条 # Flush each platform buffer regardless of whether it reached BATCH_SIZE.
with self._buffer_lock: with self._buffer_lock:
for platform, buffer in self._platform_buffers.items(): for platform, buffer in self._platform_buffers.items():
if buffer: if buffer:
display_name = self._get_platform_display_name(platform) display_name = self._get_platform_display_name(platform)
logger.info(t("log.zep_graph_memory_updater.m010", display_name=display_name, len=len(buffer))) logger.info(t("log.zep_graph_memory_updater.m010", display_name=display_name, len=len(buffer)))
self._send_batch_activities(buffer, platform) self._send_batch_activities(buffer, platform)
# 清空所有缓冲区
for platform in self._platform_buffers: for platform in self._platform_buffers:
self._platform_buffers[platform] = [] self._platform_buffers[platform] = []
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""获取统计信息""" """Return a snapshot of updater statistics."""
with self._buffer_lock: with self._buffer_lock:
buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()} buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
return { return {
"graph_id": self.graph_id, "graph_id": self.graph_id,
"batch_size": self.BATCH_SIZE, "batch_size": self.BATCH_SIZE,
"total_activities": self._total_activities, # 添加到队列的活动总数 "total_activities": self._total_activities, # activities accepted into the queue
"batches_sent": self._total_sent, # 成功发送的批次数 "batches_sent": self._total_sent, # batches successfully sent
"items_sent": self._total_items_sent, # 成功发送的活动条数 "items_sent": self._total_items_sent, # activities successfully sent
"failed_count": self._failed_count, # 发送失败的批次数 "failed_count": self._failed_count, # batches that failed to send
"skipped_count": self._skipped_count, # 被过滤跳过的活动数DO_NOTHING "skipped_count": self._skipped_count, # activities filtered out (e.g. DO_NOTHING)
"queue_size": self._activity_queue.qsize(), "queue_size": self._activity_queue.qsize(),
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小 "buffer_sizes": buffer_sizes, # per-platform buffer depth
"running": self._running, "running": self._running,
} }
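
Review note (illustrative, not part of this commit): the per-platform buffering described in the `_worker_loop` comments can be sketched in isolation. `PlatformBatcher`, `add`, and `flush_all` are hypothetical names, the `send` callback stands in for `_send_batch_activities`, and the default batch size of 5 is an assumption, since `BATCH_SIZE` is not shown in this diff.

```python
import threading
from collections import defaultdict
from typing import Callable, Dict, List


class PlatformBatcher:
    """Hypothetical stand-in for the updater's per-platform buffers."""

    def __init__(self, send: Callable[[List[str], str], None], batch_size: int = 5):
        self._send = send
        self._batch_size = batch_size
        self._buffers: Dict[str, List[str]] = defaultdict(list)
        self._lock = threading.Lock()

    def add(self, platform: str, item: str) -> None:
        """Buffer one item; hand off a full batch once the platform reaches batch_size."""
        platform = platform.lower()
        batch = None
        with self._lock:
            buf = self._buffers[platform]
            buf.append(item)
            if len(buf) >= self._batch_size:
                batch = buf[:self._batch_size]
                self._buffers[platform] = buf[self._batch_size:]
        # The callback runs outside the lock so a slow send does not block producers.
        if batch is not None:
            self._send(batch, platform)

    def flush_all(self) -> None:
        """Send every non-empty buffer, even partial ones, then reset all buffers."""
        with self._lock:
            pending = {p: b for p, b in self._buffers.items() if b}
            self._buffers = defaultdict(list)
        for platform, batch in pending.items():
            self._send(batch, platform)
```

Swapping the buffers out under the lock and sending afterwards is one way to keep the network call outside the critical section during a final flush; whether that suits this updater depends on how `stop()` interacts with the worker thread.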
class ZepGraphMemoryManager: class ZepGraphMemoryManager:
""" """Registry that owns one ``ZepGraphMemoryUpdater`` per active simulation."""
管理多个模拟的Zep图谱记忆更新器
每个模拟可以有自己的更新器实例
"""
_updaters: Dict[str, ZepGraphMemoryUpdater] = {} _updaters: Dict[str, ZepGraphMemoryUpdater] = {}
_lock = threading.Lock() _lock = threading.Lock()
@classmethod @classmethod
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater: def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
""" """Create (and start) a graph-memory updater for a simulation.
为模拟创建图谱记忆更新器
Args: Args:
simulation_id: 模拟ID simulation_id: Simulation ID.
graph_id: Zep图谱ID graph_id: Zep graph ID.
Returns: Returns:
ZepGraphMemoryUpdater实例 The started ``ZepGraphMemoryUpdater`` instance.
""" """
with cls._lock: with cls._lock:
# 如果已存在,先停止旧的 # An updater already exists for this simulation — stop it first.
if simulation_id in cls._updaters: if simulation_id in cls._updaters:
cls._updaters[simulation_id].stop() cls._updaters[simulation_id].stop()
@ -502,25 +491,24 @@ class ZepGraphMemoryManager:
@classmethod @classmethod
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]: def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
"""获取模拟的更新器""" """Return the updater for a simulation, or ``None`` if absent."""
return cls._updaters.get(simulation_id) return cls._updaters.get(simulation_id)
@classmethod @classmethod
def stop_updater(cls, simulation_id: str): def stop_updater(cls, simulation_id: str):
"""停止并移除模拟的更新器""" """Stop and deregister the updater belonging to a simulation."""
with cls._lock: with cls._lock:
if simulation_id in cls._updaters: if simulation_id in cls._updaters:
cls._updaters[simulation_id].stop() cls._updaters[simulation_id].stop()
del cls._updaters[simulation_id] del cls._updaters[simulation_id]
logger.info(t("log.zep_graph_memory_updater.m012", simulation_id=simulation_id)) logger.info(t("log.zep_graph_memory_updater.m012", simulation_id=simulation_id))
# 防止 stop_all 重复调用的标志 # Idempotency guard so ``stop_all`` only runs once per process lifetime.
_stop_all_done = False _stop_all_done = False
@classmethod @classmethod
def stop_all(cls): def stop_all(cls):
"""停止所有更新器""" """Stop every registered updater (idempotent)."""
# 防止重复调用
if cls._stop_all_done: if cls._stop_all_done:
return return
cls._stop_all_done = True cls._stop_all_done = True
@ -537,7 +525,7 @@ class ZepGraphMemoryManager:
@classmethod @classmethod
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]: def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
"""获取所有更新器的统计信息""" """Return statistics for every registered updater."""
return { return {
sim_id: updater.get_stats() sim_id: updater.get_stats()
for sim_id, updater in cls._updaters.items() for sim_id, updater in cls._updaters.items()

File diff suppressed because it is too large


@ -1,6 +1,4 @@
""" """Backend utilities package."""
工具模块
"""
from .file_parser import FileParser from .file_parser import FileParser
from .llm_client import LLMClient from .llm_client import LLMClient


@ -1,6 +1,6 @@
""" """File parsing utilities.
文件解析工具
支持PDFMarkdownTXT文件的文本提取 Supports text extraction from PDF, Markdown, and plain-text files.
""" """
import os import os
@ -9,30 +9,27 @@ from typing import List, Optional
def _read_text_with_fallback(file_path: str) -> str: def _read_text_with_fallback(file_path: str) -> str:
""" """Read a text file, falling back through encoding detectors when UTF-8 fails.
读取文本文件UTF-8失败时自动探测编码
Multi-stage fallback strategy:
采用多级回退策略 1. Try UTF-8 first.
1. 首先尝试 UTF-8 解码 2. Use ``charset_normalizer`` to detect the encoding.
2. 使用 charset_normalizer 检测编码 3. Fall back to ``chardet``.
3. 回退到 chardet 检测编码 4. Last resort: decode with UTF-8 + ``errors='replace'``.
4. 最终使用 UTF-8 + errors='replace' 兜底
Args: Args:
file_path: 文件路径 file_path: Path to the file to read.
Returns: Returns:
解码后的文本内容 The decoded text content.
""" """
data = Path(file_path).read_bytes() data = Path(file_path).read_bytes()
# 首先尝试 UTF-8
try: try:
return data.decode('utf-8') return data.decode('utf-8')
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
# 尝试使用 charset_normalizer 检测编码
encoding = None encoding = None
try: try:
from charset_normalizer import from_bytes from charset_normalizer import from_bytes
@ -41,8 +38,7 @@ def _read_text_with_fallback(file_path: str) -> str:
encoding = best.encoding encoding = best.encoding
except Exception: except Exception:
pass pass
# 回退到 chardet
if not encoding: if not encoding:
try: try:
import chardet import chardet
@ -50,89 +46,86 @@ def _read_text_with_fallback(file_path: str) -> str:
encoding = result.get('encoding') if result else None encoding = result.get('encoding') if result else None
except Exception: except Exception:
pass pass
# 最终兜底:使用 UTF-8 + replace
if not encoding: if not encoding:
encoding = 'utf-8' encoding = 'utf-8'
return data.decode(encoding, errors='replace') return data.decode(encoding, errors='replace')
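
Review note (illustrative, not part of this commit): the multi-stage fallback in `_read_text_with_fallback` can be exercised with a standard-library-only sketch. This version swaps the `charset_normalizer` and `chardet` detection stages for a fixed list of candidate encodings, which is a simpler and less accurate technique; `decode_with_fallback` and the default candidate list are assumptions made for illustration.

```python
from typing import Iterable


def decode_with_fallback(
    data: bytes,
    candidates: Iterable[str] = ("gb18030", "shift_jis"),
) -> str:
    """Decode bytes: strict UTF-8 first, then each candidate, then lossy UTF-8."""
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError:
        pass
    for encoding in candidates:
        try:
            return data.decode(encoding)
        except UnicodeDecodeError:
            continue
    # Last resort: never raise; undecodable bytes become U+FFFD.
    return data.decode("utf-8", errors="replace")
```

The final stage mirrors the source's `errors='replace'` behaviour, so the caller always gets a string back, at the cost of possible replacement characters.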
class FileParser: class FileParser:
"""文件解析器""" """Parser for the supported document formats."""
SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'} SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}
@classmethod @classmethod
def extract_text(cls, file_path: str) -> str: def extract_text(cls, file_path: str) -> str:
""" """Extract plain text from a single supported file.
从文件中提取文本
Args: Args:
file_path: 文件路径 file_path: Path to the file.
Returns: Returns:
提取的文本内容 The extracted text content.
""" """
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}") raise FileNotFoundError(f"文件不存在: {file_path}")
suffix = path.suffix.lower() suffix = path.suffix.lower()
if suffix not in cls.SUPPORTED_EXTENSIONS: if suffix not in cls.SUPPORTED_EXTENSIONS:
raise ValueError(f"不支持的文件格式: {suffix}") raise ValueError(f"不支持的文件格式: {suffix}")
if suffix == '.pdf': if suffix == '.pdf':
return cls._extract_from_pdf(file_path) return cls._extract_from_pdf(file_path)
elif suffix in {'.md', '.markdown'}: elif suffix in {'.md', '.markdown'}:
return cls._extract_from_md(file_path) return cls._extract_from_md(file_path)
elif suffix == '.txt': elif suffix == '.txt':
return cls._extract_from_txt(file_path) return cls._extract_from_txt(file_path)
raise ValueError(f"无法处理的文件格式: {suffix}") raise ValueError(f"无法处理的文件格式: {suffix}")
@staticmethod @staticmethod
def _extract_from_pdf(file_path: str) -> str: def _extract_from_pdf(file_path: str) -> str:
"""从PDF提取文本""" """Extract text from a PDF file using PyMuPDF."""
try: try:
import fitz # PyMuPDF import fitz # PyMuPDF
except ImportError: except ImportError:
raise ImportError("需要安装PyMuPDF: pip install PyMuPDF") raise ImportError("需要安装PyMuPDF: pip install PyMuPDF")
text_parts = [] text_parts = []
with fitz.open(file_path) as doc: with fitz.open(file_path) as doc:
for page in doc: for page in doc:
text = page.get_text() text = page.get_text()
if text.strip(): if text.strip():
text_parts.append(text) text_parts.append(text)
return "\n\n".join(text_parts) return "\n\n".join(text_parts)
@staticmethod @staticmethod
def _extract_from_md(file_path: str) -> str: def _extract_from_md(file_path: str) -> str:
"""从Markdown提取文本支持自动编码检测""" """Extract text from a Markdown file with automatic encoding detection."""
return _read_text_with_fallback(file_path) return _read_text_with_fallback(file_path)
@staticmethod @staticmethod
def _extract_from_txt(file_path: str) -> str: def _extract_from_txt(file_path: str) -> str:
"""从TXT提取文本支持自动编码检测""" """Extract text from a plain-text file with automatic encoding detection."""
return _read_text_with_fallback(file_path) return _read_text_with_fallback(file_path)
@classmethod @classmethod
def extract_from_multiple(cls, file_paths: List[str]) -> str: def extract_from_multiple(cls, file_paths: List[str]) -> str:
""" """Extract and concatenate text from multiple files.
从多个文件提取文本并合并
Args: Args:
file_paths: 文件路径列表 file_paths: Paths of files to read.
Returns: Returns:
合并后的文本 The merged text, with per-file headers separating each section.
""" """
all_texts = [] all_texts = []
for i, file_path in enumerate(file_paths, 1): for i, file_path in enumerate(file_paths, 1):
try: try:
text = cls.extract_text(file_path) text = cls.extract_text(file_path)
@ -140,50 +133,48 @@ class FileParser:
all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}") all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}")
except Exception as e: except Exception as e:
all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===") all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===")
return "\n\n".join(all_texts) return "\n\n".join(all_texts)
def split_text_into_chunks( def split_text_into_chunks(
text: str, text: str,
chunk_size: int = 500, chunk_size: int = 500,
overlap: int = 50 overlap: int = 50
) -> List[str]: ) -> List[str]:
""" """Split text into overlapping chunks.
将文本分割成小块
Args: Args:
text: 原始文本 text: The source text to split.
chunk_size: 每块的字符数 chunk_size: Target characters per chunk.
overlap: 重叠字符数 overlap: Number of characters overlapping between consecutive chunks.
Returns: Returns:
文本块列表 A list of chunk strings.
""" """
if len(text) <= chunk_size: if len(text) <= chunk_size:
return [text] if text.strip() else [] return [text] if text.strip() else []
chunks = [] chunks = []
start = 0 start = 0
while start < len(text): while start < len(text):
end = start + chunk_size end = start + chunk_size
# 尝试在句子边界处分割 # Prefer splitting on a sentence boundary near the chunk end
if end < len(text): if end < len(text):
# 查找最近的句子结束符
for sep in ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']: for sep in ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']:
last_sep = text[start:end].rfind(sep) last_sep = text[start:end].rfind(sep)
if last_sep != -1 and last_sep > chunk_size * 0.3: if last_sep != -1 and last_sep > chunk_size * 0.3:
end = start + last_sep + len(sep) end = start + last_sep + len(sep)
break break
chunk = text[start:end].strip() chunk = text[start:end].strip()
if chunk: if chunk:
chunks.append(chunk) chunks.append(chunk)
# 下一个块从重叠位置开始 # Next chunk starts at the overlap point
start = end - overlap if end < len(text) else len(text) start = end - overlap if end < len(text) else len(text)
return chunks return chunks
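
Review note (illustrative, not part of this commit): the effect of `overlap` on where each chunk starts is easier to see without the sentence-boundary search. The sketch below is a simplified fixed-window variant; `chunk_fixed` is a hypothetical name, and the explicit `overlap < chunk_size` guard is an addition that the source function does not have.

```python
from typing import List


def chunk_fixed(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split text into fixed-size windows that share `overlap` characters."""
    if overlap >= chunk_size:
        # Without this guard the window would never advance.
        raise ValueError("overlap must be smaller than chunk_size")
    chunks: List[str] = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunks.append(text[start:end])
        if end == len(text):
            break
        # The next window begins `overlap` characters before the previous end.
        start = end - overlap
    return chunks
```

For example, `chunk_fixed("abcdefghij", chunk_size=4, overlap=1)` yields three windows in which each consecutive pair shares exactly one character.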


@ -1,6 +1,6 @@
""" """LLM client wrapper.
LLM客户端封装
统一使用OpenAI格式调用 All providers are called through the OpenAI-compatible API surface.
""" """
import json import json
@ -13,7 +13,7 @@ from ..config import Config
class LLMClient: class LLMClient:
"""LLM客户端""" """Thin wrapper around the OpenAI-compatible chat completions API."""
def __init__( def __init__(
self, self,
@ -37,17 +37,16 @@ class LLMClient:
max_tokens: int = 4096, max_tokens: int = 4096,
response_format: Optional[Dict] = None, response_format: Optional[Dict] = None,
) -> str: ) -> str:
""" """Send a chat completion request.
发送聊天请求
Args: Args:
messages: 消息列表 messages: Chat messages in OpenAI format.
temperature: 温度参数 temperature: Sampling temperature.
max_tokens: 最大token数 max_tokens: Maximum number of tokens to generate.
response_format: 响应格式如JSON模式 response_format: Optional response format hint (e.g. JSON mode).
Returns: Returns:
模型响应文本 The assistant's response text.
""" """
kwargs = { kwargs = {
"model": self.model, "model": self.model,
@ -61,7 +60,7 @@ class LLMClient:
response = self.client.chat.completions.create(**kwargs) response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除 # Some reasoning models (e.g. MiniMax M2.5) embed <think>...</think> blocks; strip them.
content = re.sub(r"<think>[\s\S]*?</think>", "", content).strip() content = re.sub(r"<think>[\s\S]*?</think>", "", content).strip()
return content return content
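
Review note (illustrative, not part of this commit): the `<think>` stripping above and the code-fence cleanup used by the JSON helper can be combined into one small function for testing. `parse_model_json` is a hypothetical name; the opening-fence pattern follows the one visible in this diff, while the closing-fence pattern is an assumption because that part of the source is outside the shown hunk.

```python
import json
import re
from typing import Any

FENCE = "`" * 3  # built programmatically so this example contains no literal fence

_THINK = re.compile(r"<think>[\s\S]*?</think>")
_OPEN = re.compile(r"^" + FENCE + r"(?:json)?\s*\n?", re.IGNORECASE)
_CLOSE = re.compile(r"\n?" + FENCE + r"\s*$")  # assumed counterpart of _OPEN


def parse_model_json(raw: str) -> Any:
    """Drop reasoning blocks and surrounding code fences, then parse JSON."""
    text = _THINK.sub("", raw).strip()
    text = _OPEN.sub("", text)
    text = _CLOSE.sub("", text)
    return json.loads(text.strip())
```

The non-greedy `[\s\S]*?` matters when a response contains more than one `<think>` block: each block is removed separately instead of everything between the first opening tag and the last closing tag.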
@ -79,7 +78,7 @@ class LLMClient:
messages=messages, temperature=temperature, max_tokens=max_tokens messages=messages, temperature=temperature, max_tokens=max_tokens
) )
# 清理markdown代码块标记 # Strip surrounding markdown code-fence markers if present.
cleaned_response = response.strip() cleaned_response = response.strip()
cleaned_response = re.sub( cleaned_response = re.sub(
r"^```(?:json)?\s*\n?", "", cleaned_response, flags=re.IGNORECASE r"^```(?:json)?\s*\n?", "", cleaned_response, flags=re.IGNORECASE


@ -1,6 +1,7 @@
""" """Logger configuration module.
日志配置模块
提供统一的日志管理同时输出到控制台和文件 Provides unified logging that writes simultaneously to the console and a
rotating log file.
""" """
import os import os
@ -11,59 +12,55 @@ from logging.handlers import RotatingFileHandler
def _ensure_utf8_stdout(): def _ensure_utf8_stdout():
""" """Force stdout/stderr to UTF-8.
确保 stdout/stderr 使用 UTF-8 编码
解决 Windows 控制台中文乱码问题 Fixes garbled non-ASCII output on the Windows console.
""" """
if sys.platform == 'win32': if sys.platform == 'win32':
# Windows 下重新配置标准输出为 UTF-8 # On Windows, reconfigure the standard streams to UTF-8.
if hasattr(sys.stdout, 'reconfigure'): if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace') sys.stdout.reconfigure(encoding='utf-8', errors='replace')
if hasattr(sys.stderr, 'reconfigure'): if hasattr(sys.stderr, 'reconfigure'):
sys.stderr.reconfigure(encoding='utf-8', errors='replace') sys.stderr.reconfigure(encoding='utf-8', errors='replace')
# 日志目录 # Directory that holds rotated log files.
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs') LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger: def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
""" """Configure and return a logger.
设置日志器
Args: Args:
name: 日志器名称 name: Logger name.
level: 日志级别 level: Minimum log level for the logger.
Returns: Returns:
配置好的日志器 The configured logger.
""" """
# 确保日志目录存在
os.makedirs(LOG_DIR, exist_ok=True) os.makedirs(LOG_DIR, exist_ok=True)
# 创建日志器
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(level) logger.setLevel(level)
# 阻止日志向上传播到根 logger避免重复输出 # Prevent propagation to the root logger to avoid duplicate output.
logger.propagate = False logger.propagate = False
# 如果已经有处理器,不重复添加 # If handlers are already attached, do not re-add them.
if logger.handlers: if logger.handlers:
return logger return logger
# 日志格式
detailed_formatter = logging.Formatter( detailed_formatter = logging.Formatter(
'[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s', '[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S' datefmt='%Y-%m-%d %H:%M:%S'
) )
simple_formatter = logging.Formatter( simple_formatter = logging.Formatter(
'[%(asctime)s] %(levelname)s: %(message)s', '[%(asctime)s] %(levelname)s: %(message)s',
datefmt='%H:%M:%S' datefmt='%H:%M:%S'
) )
# 1. 文件处理器 - 详细日志(按日期命名,带轮转) # 1. File handler — detailed log, named by date and rotated by size.
log_filename = datetime.now().strftime('%Y-%m-%d') + '.log' log_filename = datetime.now().strftime('%Y-%m-%d') + '.log'
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
os.path.join(LOG_DIR, log_filename), os.path.join(LOG_DIR, log_filename),
@ -73,30 +70,28 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
) )
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(detailed_formatter) file_handler.setFormatter(detailed_formatter)
# 2. 控制台处理器 - 简洁日志INFO及以上 # 2. Console handler — concise log, INFO and above.
# 确保 Windows 下使用 UTF-8 编码,避免中文乱码 # Ensure UTF-8 on Windows so non-ASCII characters render correctly.
_ensure_utf8_stdout() _ensure_utf8_stdout()
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
console_handler.setFormatter(simple_formatter) console_handler.setFormatter(simple_formatter)
# 添加处理器
logger.addHandler(file_handler) logger.addHandler(file_handler)
logger.addHandler(console_handler) logger.addHandler(console_handler)
return logger return logger
def get_logger(name: str = 'mirofish') -> logging.Logger: def get_logger(name: str = 'mirofish') -> logging.Logger:
""" """Return an existing logger by name, creating it lazily if needed.
获取日志器如果不存在则创建
Args: Args:
name: 日志器名称 name: Logger name.
Returns: Returns:
日志器实例 The logger instance.
""" """
logger = logging.getLogger(name) logger = logging.getLogger(name)
if not logger.handlers: if not logger.handlers:
@ -104,11 +99,11 @@ def get_logger(name: str = 'mirofish') -> logging.Logger:
return logger return logger
# 创建默认日志器 # Default module-level logger.
logger = setup_logger() logger = setup_logger()
# 便捷方法 # Convenience module-level helpers.
def debug(msg, *args, **kwargs): def debug(msg, *args, **kwargs):
logger.debug(msg, *args, **kwargs) logger.debug(msg, *args, **kwargs)


@ -1,6 +1,7 @@
""" """API call retry primitives.
API调用重试机制
用于处理LLM等外部API调用的重试逻辑 Helpers for retrying calls to external APIs (LLMs, etc.) with exponential
backoff and jitter.
""" """
import time import time
@ -22,18 +23,17 @@ def retry_with_backoff(
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
on_retry: Optional[Callable[[Exception, int], None]] = None on_retry: Optional[Callable[[Exception, int], None]] = None
): ):
""" """Decorator that retries a callable with exponential backoff.
带指数退避的重试装饰器
Args: Args:
max_retries: 最大重试次数 max_retries: Maximum number of retries before giving up.
initial_delay: 初始延迟 initial_delay: Initial delay in seconds before the first retry.
max_delay: 最大延迟 max_delay: Cap on the delay between retries (seconds).
backoff_factor: 退避因子 backoff_factor: Multiplicative factor applied to the delay each retry.
jitter: 是否添加随机抖动 jitter: When ``True``, randomize the delay to avoid thundering herd.
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry.
on_retry: 重试时的回调函数 (exception, retry_count) on_retry: Optional callback invoked on each retry as ``(exception, retry_count)``.
Usage: Usage:
@retry_with_backoff(max_retries=3) @retry_with_backoff(max_retries=3)
def call_llm_api(): def call_llm_api():
@ -61,7 +61,7 @@ def retry_with_backoff(
)) ))
raise raise
# 计算延迟 # Compute the next delay, capped at ``max_delay``.
current_delay = min(delay, max_delay) current_delay = min(delay, max_delay)
if jitter: if jitter:
current_delay = current_delay * (0.5 + random.random()) current_delay = current_delay * (0.5 + random.random())
@ -92,9 +92,7 @@ def retry_with_backoff_async(
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
on_retry: Optional[Callable[[Exception, int], None]] = None on_retry: Optional[Callable[[Exception, int], None]] = None
): ):
""" """Async variant of :func:`retry_with_backoff`."""
异步版本的重试装饰器
"""
import asyncio import asyncio
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@ -141,9 +139,7 @@ def retry_with_backoff_async(
class RetryableAPIClient: class RetryableAPIClient:
""" """Class-based wrapper around the retry helpers."""
可重试的API客户端封装
"""
def __init__( def __init__(
self, self,
@ -164,17 +160,16 @@ class RetryableAPIClient:
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
**kwargs **kwargs
) -> Any: ) -> Any:
""" """Invoke ``func`` with retry on failure.
执行函数调用并在失败时重试
Args: Args:
func: 要调用的函数 func: Callable to invoke.
*args: 函数参数 *args: Positional arguments forwarded to ``func``.
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry.
**kwargs: 函数关键字参数 **kwargs: Keyword arguments forwarded to ``func``.
Returns: Returns:
函数返回值 The value returned by ``func``.
""" """
last_exception = None last_exception = None
delay = self.initial_delay delay = self.initial_delay
@ -214,17 +209,17 @@ class RetryableAPIClient:
exceptions: Tuple[Type[Exception], ...] = (Exception,), exceptions: Tuple[Type[Exception], ...] = (Exception,),
continue_on_failure: bool = True continue_on_failure: bool = True
) -> Tuple[list, list]: ) -> Tuple[list, list]:
""" """Process ``items`` in sequence, retrying each independently on failure.
批量调用并对每个失败项单独重试
Args: Args:
items: 要处理的项目列表 items: Items to process.
process_func: 处理函数接收单个item作为参数 process_func: Callable invoked once per item.
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry.
continue_on_failure: 单项失败后是否继续处理其他项 continue_on_failure: When ``True``, keep processing remaining items after a failure.
Returns: Returns:
(成功结果列表, 失败项列表) ``(successes, failures)``: a list of successful results and a list
of failure descriptors ``{"index", "item", "error"}``.
""" """
results = [] results = []
failures = [] failures = []
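
Review note (illustrative, not part of this commit): the delay computation in `retry_with_backoff` (cap at `max_delay`, then scale by `0.5 + random.random()` when `jitter` is on) can be checked by generating the whole delay schedule up front. `backoff_delays` is a hypothetical helper, the default values are assumptions because the decorator's defaults fall outside the shown hunks, and multiplying the delay by `backoff_factor` after each attempt is inferred from the docstring, not copied from the source.

```python
import random
from typing import List, Optional


def backoff_delays(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: bool = True,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Return the sleep durations a retry loop would use, one per retry."""
    rng = rng or random.Random()
    delays: List[float] = []
    delay = initial_delay
    for _ in range(max_retries):
        current = min(delay, max_delay)
        if jitter:
            # Same shape as the source: scale by a factor in [0.5, 1.5).
            current *= 0.5 + rng.random()
        delays.append(current)
        delay *= backoff_factor  # assumed growth step
    return delays
```

Note that the jitter is applied after the cap, so a jittered delay can exceed `max_delay` by up to 50 percent; that follows directly from the two lines shown in the diff.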


@ -1,7 +1,8 @@
"""Zep Graph 分页读取工具。 """Zep Graph paging helpers.
Zep node/edge 列表接口使用 UUID cursor 分页 Zep's node/edge list APIs paginate with a UUID cursor. This module wraps the
本模块封装自动翻页逻辑含单页重试对调用方透明地返回完整列表 auto-paging loop (including per-page retry) so callers see the full list
transparently.
""" """
from __future__ import annotations from __future__ import annotations
@ -30,7 +31,7 @@ def _fetch_page_with_retry(
page_description: str = "page", page_description: str = "page",
**kwargs: Any, **kwargs: Any,
) -> list[Any]: ) -> list[Any]:
"""单页请求失败时指数退避重试。自动处理429限速。""" """Fetch one page, retrying with exponential backoff. Handles 429 rate limits."""
if max_retries < 1: if max_retries < 1:
raise ValueError("max_retries must be >= 1") raise ValueError("max_retries must be >= 1")
@ -43,7 +44,7 @@ def _fetch_page_with_retry(
except Exception as e: except Exception as e:
last_exception = e last_exception = e
if attempt < max_retries - 1: if attempt < max_retries - 1:
# 检测429限速使用retry-after头部指定的等待时间 # If a 429 rate limit is detected, prefer the retry-after header for the wait.
wait = delay wait = delay
logger.warning( logger.warning(
f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {wait:.1f}s..." f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {wait:.1f}s..."
@ -65,7 +66,7 @@ def fetch_all_nodes(
max_retries: int = _DEFAULT_MAX_RETRIES, max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY, retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]: ) -> list[Any]:
"""分页获取图谱节点,最多返回 max_items 条(默认 2000。每页请求自带重试。""" """Page through graph nodes; return at most ``max_items`` (default 2000). Each page is retried internally."""
all_nodes: list[Any] = [] all_nodes: list[Any] = []
cursor: str | None = None cursor: str | None = None
page_num = 0 page_num = 0
@ -110,7 +111,7 @@ def fetch_all_edges(
max_retries: int = _DEFAULT_MAX_RETRIES, max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY, retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]: ) -> list[Any]:
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。""" """Page through every graph edge and return the full list. Each page is retried internally."""
all_edges: list[Any] = [] all_edges: list[Any] = []
cursor: str | None = None cursor: str | None = None
page_num = 0 page_num = 0
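
Reviewer sketch (not part of the commit): the docstrings above describe cursor pagination in which each page request is retried with exponential backoff. A minimal, self-contained illustration of that pattern follows. The `limit` and `uuid_cursor` keyword names, the `flaky_fetch` stand-in, and the in-memory `DATA` list are all assumptions made for the example; they are not taken from the Zep client or from `zep_paging.py`.

```python
from __future__ import annotations

import time
from typing import Any, Callable


def fetch_page_with_retry(
    fetch: Callable[..., list[Any]],
    *,
    max_retries: int = 3,
    delay: float = 0.0,
    **kwargs: Any,
) -> list[Any]:
    """Call ``fetch`` until it succeeds, doubling the wait after each failure."""
    if max_retries < 1:
        raise ValueError("max_retries must be >= 1")
    last_exc: Exception | None = None
    for attempt in range(max_retries):
        try:
            return fetch(**kwargs)
        except Exception as exc:
            last_exc = exc
            if attempt < max_retries - 1:
                time.sleep(delay)
                delay *= 2  # exponential backoff
    assert last_exc is not None
    raise last_exc


def fetch_all(fetch: Callable[..., list[Any]], *, page_size: int = 2, max_items: int = 2000) -> list[Any]:
    """Follow the cursor page by page and return at most ``max_items`` items."""
    items: list[Any] = []
    cursor: str | None = None
    while len(items) < max_items:
        kwargs: dict[str, Any] = {"limit": page_size}
        if cursor is not None:
            kwargs["uuid_cursor"] = cursor  # hypothetical parameter name
        page = fetch_page_with_retry(fetch, **kwargs)
        if not page:
            break
        items.extend(page)
        if len(page) < page_size:
            break  # a short page means there is nothing left to fetch
        cursor = page[-1]["uuid"]
    return items[:max_items]


# Hypothetical in-memory backend that fails once to exercise the retry path.
DATA = [{"uuid": f"u{i:02d}"} for i in range(5)]
calls = {"n": 0}


def flaky_fetch(limit: int, uuid_cursor: str | None = None) -> list[dict]:
    calls["n"] += 1
    if calls["n"] == 2:
        raise RuntimeError("simulated 429")
    start = 0
    if uuid_cursor is not None:
        start = next(i for i, d in enumerate(DATA) if d["uuid"] == uuid_cursor) + 1
    return DATA[start:start + limit]


result = fetch_all(flaky_fetch, page_size=2)
```

The caller only ever sees the complete list; the single simulated failure is absorbed inside the per-page retry, which is the "transparent" behaviour the module docstring refers to.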


@ -1,21 +1,20 @@
""" """MiroFish backend entry point."""
MiroFish Backend 启动入口
"""
import os import os
import sys import sys
# 解决 Windows 控制台中文乱码问题:在所有导入之前设置 UTF-8 编码 # Force UTF-8 on Windows console before importing anything that might write to
# stdout/stderr; otherwise non-ASCII characters render as mojibake.
if sys.platform == 'win32': if sys.platform == 'win32':
# 设置环境变量确保 Python 使用 UTF-8 # Make sure Python itself uses UTF-8.
os.environ.setdefault('PYTHONIOENCODING', 'utf-8') os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
# 重新配置标准输出流为 UTF-8 # Reconfigure the standard streams to UTF-8.
if hasattr(sys.stdout, 'reconfigure'): if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace') sys.stdout.reconfigure(encoding='utf-8', errors='replace')
if hasattr(sys.stderr, 'reconfigure'): if hasattr(sys.stderr, 'reconfigure'):
sys.stderr.reconfigure(encoding='utf-8', errors='replace') sys.stderr.reconfigure(encoding='utf-8', errors='replace')
# 添加项目根目录到路径 # Add the project root to sys.path so the ``app`` package resolves.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from app import create_app from app import create_app
@ -23,8 +22,7 @@ from app.config import Config
def main(): def main():
"""主函数""" """Validate configuration and start the Flask development server."""
# 验证配置
errors = Config.validate() errors = Config.validate()
if errors: if errors:
print("配置错误:") print("配置错误:")
@ -32,19 +30,16 @@ def main():
print(f" - {err}") print(f" - {err}")
print("\n请检查 .env 文件中的配置") print("\n请检查 .env 文件中的配置")
sys.exit(1) sys.exit(1)
# 创建应用
app = create_app() app = create_app()
# 获取运行配置 # Resolve runtime host/port from the environment.
host = os.environ.get('FLASK_HOST', '0.0.0.0') host = os.environ.get('FLASK_HOST', '0.0.0.0')
port = int(os.environ.get('FLASK_PORT', 5001)) port = int(os.environ.get('FLASK_PORT', 5001))
debug = Config.DEBUG debug = Config.DEBUG
# 启动服务
app.run(host=host, port=port, debug=debug, threaded=True) app.run(host=host, port=port, debug=debug, threaded=True)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
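
Reviewer sketch (not part of the commit): the translated comments in `run.py` explain that the standard streams are reconfigured to UTF-8 so non-ASCII output does not render as mojibake. The snippet below illustrates the same `reconfigure` call on an in-memory stream, so it can be tried on any platform. The `stream` object is a stand-in for `sys.stdout` and is an assumption of this example.

```python
import io

# Stand-in for a console stream that starts with a non-UTF-8 encoding.
stream = io.TextIOWrapper(io.BytesIO(), encoding="ascii", errors="strict")

# Same guard as in run.py: only call reconfigure() when the stream offers it.
if hasattr(stream, "reconfigure"):
    stream.reconfigure(encoding="utf-8", errors="replace")

# Writing non-ASCII text would raise UnicodeEncodeError without the reconfigure.
stream.write("配置错误")
stream.flush()
raw = stream.buffer.getvalue()
```

Here `errors='replace'` plays the same role as in `run.py`: characters that cannot be encoded are substituted instead of raising an exception.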


@ -1,15 +1,17 @@
""" """Action logger.
动作日志记录器
用于记录OASIS模拟中每个Agent的动作供后端监控使用 Records each agent action during an OASIS simulation so the backend can
monitor progress.
Log layout::
日志结构:
sim_xxx/ sim_xxx/
twitter/ twitter/
actions.jsonl # Twitter 平台动作日志 actions.jsonl # Twitter action log
reddit/ reddit/
actions.jsonl # Reddit 平台动作日志 actions.jsonl # Reddit action log
simulation.log # 主模拟进程日志 simulation.log # main simulation process log
run_state.json # 运行状态API 查询用) run_state.json # run state (queried by the API)
""" """
import json import json
@ -20,26 +22,25 @@ from typing import Dict, Any, Optional
class PlatformActionLogger: class PlatformActionLogger:
"""单平台动作日志记录器""" """Per-platform action logger."""
def __init__(self, platform: str, base_dir: str): def __init__(self, platform: str, base_dir: str):
""" """Initialize the logger.
初始化日志记录器
Args: Args:
platform: 平台名称 (twitter/reddit) platform: Platform name (``twitter`` or ``reddit``).
base_dir: 模拟目录的基础路径 base_dir: Base path of the simulation directory.
""" """
self.platform = platform self.platform = platform
self.base_dir = base_dir self.base_dir = base_dir
self.log_dir = os.path.join(base_dir, platform) self.log_dir = os.path.join(base_dir, platform)
self.log_path = os.path.join(self.log_dir, "actions.jsonl") self.log_path = os.path.join(self.log_dir, "actions.jsonl")
self._ensure_dir() self._ensure_dir()
def _ensure_dir(self): def _ensure_dir(self):
"""确保目录存在""" """Ensure the log directory exists."""
os.makedirs(self.log_dir, exist_ok=True) os.makedirs(self.log_dir, exist_ok=True)
def log_action( def log_action(
self, self,
round_num: int, round_num: int,
@ -50,7 +51,7 @@ class PlatformActionLogger:
result: Optional[str] = None, result: Optional[str] = None,
success: bool = True success: bool = True
): ):
"""记录一个动作""" """Append a single action record."""
entry = { entry = {
"round": round_num, "round": round_num,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
@ -61,36 +62,36 @@ class PlatformActionLogger:
"result": result, "result": result,
"success": success, "success": success,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_round_start(self, round_num: int, simulated_hour: int): def log_round_start(self, round_num: int, simulated_hour: int):
"""记录轮次开始""" """Append a round-start marker."""
entry = { entry = {
"round": round_num, "round": round_num,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"event_type": "round_start", "event_type": "round_start",
"simulated_hour": simulated_hour, "simulated_hour": simulated_hour,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_round_end(self, round_num: int, actions_count: int): def log_round_end(self, round_num: int, actions_count: int):
"""记录轮次结束""" """Append a round-end marker."""
entry = { entry = {
"round": round_num, "round": round_num,
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"event_type": "round_end", "event_type": "round_end",
"actions_count": actions_count, "actions_count": actions_count,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_simulation_start(self, config: Dict[str, Any]): def log_simulation_start(self, config: Dict[str, Any]):
"""记录模拟开始""" """Append a simulation-start marker."""
entry = { entry = {
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"event_type": "simulation_start", "event_type": "simulation_start",
@ -98,12 +99,12 @@ class PlatformActionLogger:
"total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2, "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2,
"agents_count": len(config.get("agent_configs", [])), "agents_count": len(config.get("agent_configs", [])),
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_simulation_end(self, total_rounds: int, total_actions: int): def log_simulation_end(self, total_rounds: int, total_actions: int):
"""记录模拟结束""" """Append a simulation-end marker."""
entry = { entry = {
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"event_type": "simulation_end", "event_type": "simulation_end",
@ -111,42 +112,42 @@ class PlatformActionLogger:
"total_rounds": total_rounds, "total_rounds": total_rounds,
"total_actions": total_actions, "total_actions": total_actions,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
class SimulationLogManager: class SimulationLogManager:
"""Top-level log manager.
Owns and dispatches to the per-platform action loggers, and exposes a
main process logger for non-action messages.
""" """
模拟日志管理器
统一管理所有日志文件按平台分离
"""
def __init__(self, simulation_dir: str): def __init__(self, simulation_dir: str):
""" """Initialize the log manager.
初始化日志管理器
Args: Args:
simulation_dir: 模拟目录路径 simulation_dir: Path to the simulation directory.
""" """
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
self.twitter_logger: Optional[PlatformActionLogger] = None self.twitter_logger: Optional[PlatformActionLogger] = None
self.reddit_logger: Optional[PlatformActionLogger] = None self.reddit_logger: Optional[PlatformActionLogger] = None
self._main_logger: Optional[logging.Logger] = None self._main_logger: Optional[logging.Logger] = None
# 设置主日志 # Configure the main process logger.
self._setup_main_logger() self._setup_main_logger()
def _setup_main_logger(self): def _setup_main_logger(self):
"""设置主模拟日志""" """Configure the main simulation log."""
log_path = os.path.join(self.simulation_dir, "simulation.log") log_path = os.path.join(self.simulation_dir, "simulation.log")
# 创建 logger # Build the logger.
self._main_logger = logging.getLogger(f"simulation.{os.path.basename(self.simulation_dir)}") self._main_logger = logging.getLogger(f"simulation.{os.path.basename(self.simulation_dir)}")
self._main_logger.setLevel(logging.INFO) self._main_logger.setLevel(logging.INFO)
self._main_logger.handlers.clear() self._main_logger.handlers.clear()
# 文件处理器 # File handler.
file_handler = logging.FileHandler(log_path, encoding='utf-8', mode='w') file_handler = logging.FileHandler(log_path, encoding='utf-8', mode='w')
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter( file_handler.setFormatter(logging.Formatter(
@ -154,8 +155,8 @@ class SimulationLogManager:
datefmt='%Y-%m-%d %H:%M:%S' datefmt='%Y-%m-%d %H:%M:%S'
)) ))
self._main_logger.addHandler(file_handler) self._main_logger.addHandler(file_handler)
# 控制台处理器 # Console handler.
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter( console_handler.setFormatter(logging.Formatter(
@ -163,56 +164,56 @@ class SimulationLogManager:
datefmt='%H:%M:%S' datefmt='%H:%M:%S'
)) ))
self._main_logger.addHandler(console_handler) self._main_logger.addHandler(console_handler)
self._main_logger.propagate = False self._main_logger.propagate = False
def get_twitter_logger(self) -> PlatformActionLogger: def get_twitter_logger(self) -> PlatformActionLogger:
"""获取 Twitter 平台日志记录器""" """Lazily construct and return the Twitter platform logger."""
if self.twitter_logger is None: if self.twitter_logger is None:
self.twitter_logger = PlatformActionLogger("twitter", self.simulation_dir) self.twitter_logger = PlatformActionLogger("twitter", self.simulation_dir)
return self.twitter_logger return self.twitter_logger
def get_reddit_logger(self) -> PlatformActionLogger: def get_reddit_logger(self) -> PlatformActionLogger:
"""获取 Reddit 平台日志记录器""" """Lazily construct and return the Reddit platform logger."""
if self.reddit_logger is None: if self.reddit_logger is None:
self.reddit_logger = PlatformActionLogger("reddit", self.simulation_dir) self.reddit_logger = PlatformActionLogger("reddit", self.simulation_dir)
return self.reddit_logger return self.reddit_logger
def log(self, message: str, level: str = "info"): def log(self, message: str, level: str = "info"):
"""记录主日志""" """Forward a message to the main logger at the given level."""
if self._main_logger: if self._main_logger:
getattr(self._main_logger, level.lower(), self._main_logger.info)(message) getattr(self._main_logger, level.lower(), self._main_logger.info)(message)
def info(self, message: str): def info(self, message: str):
self.log(message, "info") self.log(message, "info")
def warning(self, message: str): def warning(self, message: str):
self.log(message, "warning") self.log(message, "warning")
def error(self, message: str): def error(self, message: str):
self.log(message, "error") self.log(message, "error")
def debug(self, message: str): def debug(self, message: str):
self.log(message, "debug") self.log(message, "debug")
# ============ 兼容旧接口 ============ # ============ Legacy interface ============
class ActionLogger: class ActionLogger:
"""Legacy single-platform action logger.
Prefer :class:`SimulationLogManager` for new code.
""" """
动作日志记录器兼容旧接口
建议使用 SimulationLogManager 代替
"""
def __init__(self, log_path: str): def __init__(self, log_path: str):
self.log_path = log_path self.log_path = log_path
self._ensure_dir() self._ensure_dir()
def _ensure_dir(self): def _ensure_dir(self):
log_dir = os.path.dirname(self.log_path) log_dir = os.path.dirname(self.log_path)
if log_dir: if log_dir:
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
def log_action( def log_action(
self, self,
round_num: int, round_num: int,
@ -235,10 +236,10 @@ class ActionLogger:
"result": result, "result": result,
"success": success, "success": success,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_round_start(self, round_num: int, simulated_hour: int, platform: str): def log_round_start(self, round_num: int, simulated_hour: int, platform: str):
entry = { entry = {
"round": round_num, "round": round_num,
@ -247,10 +248,10 @@ class ActionLogger:
"event_type": "round_start", "event_type": "round_start",
"simulated_hour": simulated_hour, "simulated_hour": simulated_hour,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_round_end(self, round_num: int, actions_count: int, platform: str): def log_round_end(self, round_num: int, actions_count: int, platform: str):
entry = { entry = {
"round": round_num, "round": round_num,
@ -259,10 +260,10 @@ class ActionLogger:
"event_type": "round_end", "event_type": "round_end",
"actions_count": actions_count, "actions_count": actions_count,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_simulation_start(self, platform: str, config: Dict[str, Any]): def log_simulation_start(self, platform: str, config: Dict[str, Any]):
entry = { entry = {
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
@ -271,10 +272,10 @@ class ActionLogger:
"total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2, "total_rounds": config.get("time_config", {}).get("total_simulation_hours", 72) * 2,
"agents_count": len(config.get("agent_configs", [])), "agents_count": len(config.get("agent_configs", [])),
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int): def log_simulation_end(self, platform: str, total_rounds: int, total_actions: int):
entry = { entry = {
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
@ -283,23 +284,23 @@ class ActionLogger:
"total_rounds": total_rounds, "total_rounds": total_rounds,
"total_actions": total_actions, "total_actions": total_actions,
} }
with open(self.log_path, 'a', encoding='utf-8') as f: with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(entry, ensure_ascii=False) + '\n') f.write(json.dumps(entry, ensure_ascii=False) + '\n')
# 全局日志实例(兼容旧接口) # Process-wide logger instance, used by the legacy interface.
_global_logger: Optional[ActionLogger] = None _global_logger: Optional[ActionLogger] = None
def get_logger(log_path: Optional[str] = None) -> ActionLogger: def get_logger(log_path: Optional[str] = None) -> ActionLogger:
"""获取全局日志实例(兼容旧接口)""" """Return the process-wide :class:`ActionLogger` (legacy interface)."""
global _global_logger global _global_logger
if log_path: if log_path:
_global_logger = ActionLogger(log_path) _global_logger = ActionLogger(log_path)
if _global_logger is None: if _global_logger is None:
_global_logger = ActionLogger("actions.jsonl") _global_logger = ActionLogger("actions.jsonl")
return _global_logger return _global_logger
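
Reviewer sketch (not part of the commit): every logger method above appends one JSON object per line to `actions.jsonl`. A minimal example of reading such a file back and counting records by type is shown below. The `read_events` and `summarize` helpers are hypothetical, and treating records without an `event_type` key as plain actions is an assumption, since the full action record is not visible in this diff.

```python
import json
import os
import tempfile
from collections import Counter


def read_events(log_path):
    """Parse a JSON Lines file into a list of dicts, skipping blank lines."""
    events = []
    with open(log_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                events.append(json.loads(line))
    return events


def summarize(events):
    # Assumption: records without an "event_type" key are plain action records.
    return Counter(e.get("event_type", "action") for e in events)


# Demo with hand-written records; only fields visible in the diff are used.
sample = [
    {"round": 1, "event_type": "round_start", "simulated_hour": 0},
    {"round": 1, "result": "帖子已发布", "success": True},
    {"round": 1, "event_type": "round_end", "actions_count": 1},
]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "twitter", "actions.jsonl")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "a", encoding="utf-8") as f:
        for entry in sample:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
    events = read_events(path)

summary = summarize(events)
```

Because each record is a complete line, a monitor can tail the file while the simulation is still appending to it.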

File diff suppressed because it is too large


@ -1,16 +1,16 @@
""" """OASIS Reddit simulation preset script.
OASIS Reddit模拟预设脚本
此脚本读取配置文件中的参数来执行模拟实现全程自动化
功能特性: This script reads parameters from a config file and runs the simulation end-to-end automatically.
- 完成模拟后不立即关闭环境进入等待命令模式
- 支持通过IPC接收Interview命令
- 支持单个Agent采访和批量采访
- 支持远程关闭环境命令
使用方式: Features:
- After the simulation finishes, the environment stays alive and enters a command-wait mode.
- Accepts Interview commands over IPC.
- Supports single-agent and batch interviews.
- Supports a remote close-environment command.
Usage:
python run_reddit_simulation.py --config /path/to/simulation_config.json python run_reddit_simulation.py --config /path/to/simulation_config.json
python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done
""" """
import argparse import argparse
@ -25,18 +25,18 @@ import sqlite3
from datetime import datetime from datetime import datetime
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
# 全局变量:用于信号处理 # Globals used by the signal handler.
_shutdown_event = None _shutdown_event = None
_cleanup_done = False _cleanup_done = False
# 添加项目路径 # Add project paths to sys.path so sibling modules import correctly.
_scripts_dir = os.path.dirname(os.path.abspath(__file__)) _scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
sys.path.insert(0, _scripts_dir) sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir) sys.path.insert(0, _backend_dir)
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) # Load the .env file from the project root (contains LLM_API_KEY and related settings).
from dotenv import load_dotenv from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env') _env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file): if os.path.exists(_env_file):
@ -51,7 +51,7 @@ import re
class UnicodeFormatter(logging.Formatter): class UnicodeFormatter(logging.Formatter):
"""自定义格式化器,将 Unicode 转义序列转换为可读字符""" """Custom log formatter that converts Unicode escape sequences into readable characters."""
UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
@ -68,24 +68,23 @@ class UnicodeFormatter(logging.Formatter):
class MaxTokensWarningFilter(logging.Filter): class MaxTokensWarningFilter(logging.Filter):
"""过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens让模型自行决定""" """Suppress camel-ai's max_tokens warning (we intentionally leave max_tokens unset and let the model decide)."""
def filter(self, record): def filter(self, record):
# 过滤掉包含 max_tokens 警告的日志
if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage():
return False return False
return True return True
# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 # Install the filter at module import time so it takes effect before any camel code runs.
logging.getLogger().addFilter(MaxTokensWarningFilter()) logging.getLogger().addFilter(MaxTokensWarningFilter())
def setup_oasis_logging(log_dir: str): def setup_oasis_logging(log_dir: str):
"""配置 OASIS 的日志,使用固定名称的日志文件""" """Configure OASIS logging with fixed log file names."""
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件 # Remove stale log files from previous runs so the new run starts clean.
for f in os.listdir(log_dir): for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f) old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'): if os.path.isfile(old_log) and f.endswith('.log'):
@ -131,20 +130,20 @@ except ImportError as e:
sys.exit(1) sys.exit(1)
# IPC相关常量 # IPC-related constants.
IPC_COMMANDS_DIR = "ipc_commands" IPC_COMMANDS_DIR = "ipc_commands"
IPC_RESPONSES_DIR = "ipc_responses" IPC_RESPONSES_DIR = "ipc_responses"
ENV_STATUS_FILE = "env_status.json" ENV_STATUS_FILE = "env_status.json"
class CommandType: class CommandType:
"""命令类型常量""" """Command type constants."""
INTERVIEW = "interview" INTERVIEW = "interview"
BATCH_INTERVIEW = "batch_interview" BATCH_INTERVIEW = "batch_interview"
CLOSE_ENV = "close_env" CLOSE_ENV = "close_env"
class IPCHandler: class IPCHandler:
"""IPC命令处理器""" """IPC command handler."""
def __init__(self, simulation_dir: str, env, agent_graph): def __init__(self, simulation_dir: str, env, agent_graph):
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
@ -154,13 +153,12 @@ class IPCHandler:
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
self._running = True self._running = True
# 确保目录存在
os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True)
def update_status(self, status: str): def update_status(self, status: str):
"""更新环境状态""" """Update the environment status file."""
with open(self.status_file, 'w', encoding='utf-8') as f: with open(self.status_file, 'w', encoding='utf-8') as f:
json.dump({ json.dump({
"status": status, "status": status,
@ -168,11 +166,11 @@ class IPCHandler:
}, f, ensure_ascii=False, indent=2) }, f, ensure_ascii=False, indent=2)
def poll_command(self) -> Optional[Dict[str, Any]]: def poll_command(self) -> Optional[Dict[str, Any]]:
"""轮询获取待处理命令""" """Poll for pending IPC commands."""
if not os.path.exists(self.commands_dir): if not os.path.exists(self.commands_dir):
return None return None
# 获取命令文件(按时间排序) # Collect command files sorted by modification time so older commands are handled first.
command_files = [] command_files = []
for filename in os.listdir(self.commands_dir): for filename in os.listdir(self.commands_dir):
if filename.endswith('.json'): if filename.endswith('.json'):
@ -191,7 +189,7 @@ class IPCHandler:
return None return None
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
"""发送响应""" """Send an IPC response for a command."""
response = { response = {
"command_id": command_id, "command_id": command_id,
"status": status, "status": status,
@ -203,8 +201,8 @@ class IPCHandler:
response_file = os.path.join(self.responses_dir, f"{command_id}.json") response_file = os.path.join(self.responses_dir, f"{command_id}.json")
with open(response_file, 'w', encoding='utf-8') as f: with open(response_file, 'w', encoding='utf-8') as f:
json.dump(response, f, ensure_ascii=False, indent=2) json.dump(response, f, ensure_ascii=False, indent=2)
# 删除命令文件 # Remove the command file once a response has been written so it isn't re-processed.
command_file = os.path.join(self.commands_dir, f"{command_id}.json") command_file = os.path.join(self.commands_dir, f"{command_id}.json")
try: try:
os.remove(command_file) os.remove(command_file)
@ -212,29 +210,25 @@ class IPCHandler:
pass pass
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
""" """Handle a single-agent interview command.
处理单个Agent采访命令
Returns: Returns:
True 表示成功False 表示失败 True on success, False on failure.
""" """
try: try:
# 获取Agent
agent = self.agent_graph.get_agent(agent_id) agent = self.agent_graph.get_agent(agent_id)
# 创建Interview动作
interview_action = ManualAction( interview_action = ManualAction(
action_type=ActionType.INTERVIEW, action_type=ActionType.INTERVIEW,
action_args={"prompt": prompt} action_args={"prompt": prompt}
) )
# 执行Interview
actions = {agent: interview_action} actions = {agent: interview_action}
await self.env.step(actions) await self.env.step(actions)
# 从数据库获取结果 # Read the interview answer back from the simulation database.
result = self._get_interview_result(agent_id) result = self._get_interview_result(agent_id)
self.send_response(command_id, "completed", result=result) self.send_response(command_id, "completed", result=result)
print(f" Interview完成: agent_id={agent_id}") print(f" Interview完成: agent_id={agent_id}")
return True return True
@ -246,17 +240,15 @@ class IPCHandler:
return False return False
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
""" """Handle a batch interview command.
处理批量采访命令
Args: Args:
interviews: [{"agent_id": int, "prompt": str}, ...] interviews: [{"agent_id": int, "prompt": str}, ...]
""" """
try: try:
# 构建动作字典
actions = {} actions = {}
agent_prompts = {} # 记录每个agent的prompt agent_prompts = {} # Track which prompt was sent to each agent so results can be paired back.
for interview in interviews: for interview in interviews:
agent_id = interview.get("agent_id") agent_id = interview.get("agent_id")
prompt = interview.get("prompt", "") prompt = interview.get("prompt", "")
@ -274,11 +266,9 @@ class IPCHandler:
if not actions: if not actions:
self.send_response(command_id, "failed", error="没有有效的Agent") self.send_response(command_id, "failed", error="没有有效的Agent")
return False return False
# 执行批量Interview
await self.env.step(actions) await self.env.step(actions)
# 获取所有结果
results = {} results = {}
for agent_id in agent_prompts.keys(): for agent_id in agent_prompts.keys():
result = self._get_interview_result(agent_id) result = self._get_interview_result(agent_id)
@ -298,7 +288,7 @@ class IPCHandler:
return False return False
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
"""从数据库获取最新的Interview结果""" """Fetch the most recent interview result for an agent from the database."""
db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") db_path = os.path.join(self.simulation_dir, "reddit_simulation.db")
result = { result = {
@ -313,8 +303,8 @@ class IPCHandler:
try: try:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
# 查询最新的Interview记录 # Query the most recent interview row for this agent.
cursor.execute(""" cursor.execute("""
SELECT user_id, info, created_at SELECT user_id, info, created_at
FROM trace FROM trace
@ -341,11 +331,10 @@ class IPCHandler:
return result return result
async def process_commands(self) -> bool: async def process_commands(self) -> bool:
""" """Process all pending IPC commands.
处理所有待处理命令
Returns: Returns:
True 表示继续运行False 表示应该退出 True to keep running, False if the loop should exit.
""" """
command = self.poll_command() command = self.poll_command()
if not command: if not command:
@ -383,9 +372,9 @@ class IPCHandler:
class RedditSimulationRunner: class RedditSimulationRunner:
"""Reddit模拟运行器""" """Reddit simulation runner."""
# Reddit可用动作不包含INTERVIEWINTERVIEW只能通过ManualAction手动触发 # Available Reddit actions (INTERVIEW is excluded because it can only be triggered via ManualAction).
AVAILABLE_ACTIONS = [ AVAILABLE_ACTIONS = [
ActionType.LIKE_POST, ActionType.LIKE_POST,
ActionType.DISLIKE_POST, ActionType.DISLIKE_POST,
@ -403,12 +392,11 @@ class RedditSimulationRunner:
] ]
def __init__(self, config_path: str, wait_for_commands: bool = True): def __init__(self, config_path: str, wait_for_commands: bool = True):
""" """Initialize the simulation runner.
初始化模拟运行器
Args: Args:
config_path: 配置文件路径 (simulation_config.json) config_path: Path to the configuration file (simulation_config.json).
wait_for_commands: 模拟完成后是否等待命令默认True wait_for_commands: Whether to wait for commands after the simulation finishes (default True).
""" """
self.config_path = config_path self.config_path = config_path
self.config = self._load_config() self.config = self._load_config()
@ -419,37 +407,36 @@ class RedditSimulationRunner:
self.ipc_handler = None self.ipc_handler = None
def _load_config(self) -> Dict[str, Any]: def _load_config(self) -> Dict[str, Any]:
"""加载配置文件""" """Load the configuration file."""
with open(self.config_path, 'r', encoding='utf-8') as f: with open(self.config_path, 'r', encoding='utf-8') as f:
return json.load(f) return json.load(f)
def _get_profile_path(self) -> str: def _get_profile_path(self) -> str:
"""获取Profile文件路径""" """Return the path to the agent profiles file."""
return os.path.join(self.simulation_dir, "reddit_profiles.json") return os.path.join(self.simulation_dir, "reddit_profiles.json")
def _get_db_path(self) -> str: def _get_db_path(self) -> str:
"""获取数据库路径""" """Return the path to the simulation database."""
return os.path.join(self.simulation_dir, "reddit_simulation.db") return os.path.join(self.simulation_dir, "reddit_simulation.db")
def _create_model(self): def _create_model(self):
"""Create the LLM model.
Configuration is sourced from the project-root ``.env`` file (highest priority):
- LLM_API_KEY: API key.
- LLM_BASE_URL: API base URL.
- LLM_MODEL_NAME: Model name.
""" """
创建LLM模型 # Prefer values from .env over the per-simulation config.
统一使用项目根目录 .env 文件中的配置优先级最高
- LLM_API_KEY: API密钥
- LLM_BASE_URL: API基础URL
- LLM_MODEL_NAME: 模型名称
"""
# 优先从 .env 读取配置
llm_api_key = os.environ.get("LLM_API_KEY", "") llm_api_key = os.environ.get("LLM_API_KEY", "")
llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_base_url = os.environ.get("LLM_BASE_URL", "")
llm_model = os.environ.get("LLM_MODEL_NAME", "") llm_model = os.environ.get("LLM_MODEL_NAME", "")
# 如果 .env 中没有,则使用 config 作为备用 # Fall back to the simulation config file if .env did not specify a model.
if not llm_model: if not llm_model:
llm_model = self.config.get("llm_model", "gpt-4o-mini") llm_model = self.config.get("llm_model", "gpt-4o-mini")
# 设置 camel-ai 所需的环境变量 # Export the env vars camel-ai expects.
if llm_api_key: if llm_api_key:
os.environ["OPENAI_API_KEY"] = llm_api_key os.environ["OPENAI_API_KEY"] = llm_api_key
@ -472,9 +459,7 @@ class RedditSimulationRunner:
current_hour: int, current_hour: int,
round_num: int round_num: int
) -> List: ) -> List:
""" """Decide which agents are active for the current round, based on time of day and config."""
根据时间和配置决定本轮激活哪些Agent
"""
time_config = self.config.get("time_config", {}) time_config = self.config.get("time_config", {})
agent_configs = self.config.get("agent_configs", []) agent_configs = self.config.get("agent_configs", [])
@ -521,10 +506,10 @@ class RedditSimulationRunner:
return active_agents return active_agents
async def run(self, max_rounds: int = None): async def run(self, max_rounds: int = None):
"""运行Reddit模拟 """Run the Reddit simulation.
Args: Args:
max_rounds: 最大模拟轮数可选用于截断过长的模拟 max_rounds: Optional cap on the number of simulation rounds (used to truncate overly long runs).
""" """
print("=" * 60) print("=" * 60)
print("OASIS Reddit模拟") print("OASIS Reddit模拟")
@ -538,7 +523,7 @@ class RedditSimulationRunner:
minutes_per_round = time_config.get("minutes_per_round", 30) minutes_per_round = time_config.get("minutes_per_round", 30)
total_rounds = (total_hours * 60) // minutes_per_round total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断 # Truncate if a max_rounds cap was supplied.
if max_rounds is not None and max_rounds > 0: if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds) total_rounds = min(total_rounds, max_rounds)
@ -578,17 +563,16 @@ class RedditSimulationRunner:
agent_graph=self.agent_graph, agent_graph=self.agent_graph,
platform=oasis.DefaultPlatformType.REDDIT, platform=oasis.DefaultPlatformType.REDDIT,
database_path=db_path, database_path=db_path,
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 semaphore=30, # Cap concurrent LLM requests to avoid overloading the API.
) )
await self.env.reset() await self.env.reset()
print("环境初始化完成\n") print("环境初始化完成\n")
# 初始化IPC处理器
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
self.ipc_handler.update_status("running") self.ipc_handler.update_status("running")
# 执行初始事件 # Apply the configured initial events (seed posts) before starting the main loop.
event_config = self.config.get("event_config", {}) event_config = self.config.get("event_config", {})
initial_posts = event_config.get("initial_posts", []) initial_posts = event_config.get("initial_posts", [])
@ -619,7 +603,7 @@ class RedditSimulationRunner:
await self.env.step(initial_actions) await self.env.step(initial_actions)
print(f" 已发布 {len(initial_actions)} 条初始帖子") print(f" 已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环 # Main simulation loop.
print("\n开始模拟循环...") print("\n开始模拟循环...")
start_time = datetime.now() start_time = datetime.now()
@ -655,7 +639,7 @@ class RedditSimulationRunner:
print(f" - 总耗时: {total_elapsed:.1f}") print(f" - 总耗时: {total_elapsed:.1f}")
print(f" - 数据库: {db_path}") print(f" - 数据库: {db_path}")
# 是否进入等待命令模式 # Optionally enter command-wait mode.
if self.wait_for_commands: if self.wait_for_commands:
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("进入等待命令模式 - 环境保持运行") print("进入等待命令模式 - 环境保持运行")
@ -664,7 +648,7 @@ class RedditSimulationRunner:
self.ipc_handler.update_status("alive") self.ipc_handler.update_status("alive")
# 等待命令循环(使用全局 _shutdown_event # Command-wait loop driven by the global _shutdown_event.
try: try:
while not _shutdown_event.is_set(): while not _shutdown_event.is_set():
should_continue = await self.ipc_handler.process_commands() should_continue = await self.ipc_handler.process_commands()
@ -672,7 +656,7 @@ class RedditSimulationRunner:
break break
try: try:
await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
break # 收到退出信号 break # Shutdown signal received.
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
except KeyboardInterrupt: except KeyboardInterrupt:
@ -683,8 +667,7 @@ class RedditSimulationRunner:
print(f"\n命令处理出错: {e}") print(f"\n命令处理出错: {e}")
print("\n关闭环境...") print("\n关闭环境...")
# 关闭环境
self.ipc_handler.update_status("stopped") self.ipc_handler.update_status("stopped")
await self.env.close() await self.env.close()
@ -715,7 +698,7 @@ async def main():
args = parser.parse_args() args = parser.parse_args()
# 在 main 函数开始时创建 shutdown 事件 # Create the shutdown event lazily here so it is bound to the running asyncio loop.
global _shutdown_event global _shutdown_event
_shutdown_event = asyncio.Event() _shutdown_event = asyncio.Event()
@ -723,7 +706,7 @@ async def main():
print(f"错误: 配置文件不存在: {args.config}") print(f"错误: 配置文件不存在: {args.config}")
sys.exit(1) sys.exit(1)
# 初始化日志配置(使用固定文件名,清理旧日志) # Initialize log config with fixed filenames; old logs are cleared inside setup_oasis_logging.
simulation_dir = os.path.dirname(args.config) or "." simulation_dir = os.path.dirname(args.config) or "."
setup_oasis_logging(os.path.join(simulation_dir, "log")) setup_oasis_logging(os.path.join(simulation_dir, "log"))
@ -735,9 +718,9 @@ async def main():
def setup_signal_handlers(): def setup_signal_handlers():
""" """Install signal handlers so SIGTERM/SIGINT trigger a graceful exit.
设置信号处理器确保收到 SIGTERM/SIGINT 时能够正确退出
让程序有机会正常清理资源关闭数据库环境等 This gives the program a chance to clean up resources (close the database, the OASIS environment, etc.).
""" """
def signal_handler(signum, frame): def signal_handler(signum, frame):
global _cleanup_done global _cleanup_done
@ -748,7 +731,7 @@ def setup_signal_handlers():
if _shutdown_event: if _shutdown_event:
_shutdown_event.set() _shutdown_event.set()
else: else:
# 重复收到信号才强制退出 # Force exit only on a repeat signal so the user can still hard-kill if cleanup hangs.
print("强制退出...") print("强制退出...")
sys.exit(1) sys.exit(1)
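
Reviewer note: the command-wait loop and the shutdown event translated in this file follow a common asyncio pattern, which can be sketched in isolation as below. This is a minimal illustration, not the runner's actual code: `command_wait_loop`, `process_commands`, and `poll_interval` are hypothetical names, and the real loop lives inside `RedditSimulationRunner.run`. The idea is to process one batch of commands, then sleep on the shutdown event with a timeout so a signal wakes the loop immediately.

```python
import asyncio


async def command_wait_loop(shutdown_event, process_commands, poll_interval=0.5):
    """Poll for commands until told to stop (hypothetical helper).

    Returns "closed" when process_commands() asks to exit, or "shutdown"
    when the shutdown event is set (for example by a signal handler).
    """
    while not shutdown_event.is_set():
        # process_commands() returns False when the loop should end.
        if not await process_commands():
            return "closed"
        try:
            # Sleep up to poll_interval, but wake early if shutdown is requested.
            await asyncio.wait_for(shutdown_event.wait(), timeout=poll_interval)
            return "shutdown"
        except asyncio.TimeoutError:
            # No shutdown request yet; poll for commands again.
            pass
    return "shutdown"


async def demo():
    """Drive the loop with a fake command source that requests shutdown on call 3."""
    event = asyncio.Event()  # created inside the running loop, as main() does
    calls = []

    async def fake_process_commands():
        calls.append(1)
        if len(calls) == 3:
            event.set()
        return True

    reason = await command_wait_loop(event, fake_process_commands, poll_interval=0.01)
    return reason, len(calls)
```

Waiting on the event with a timeout, rather than calling `asyncio.sleep`, is what lets a SIGTERM/SIGINT end the wait without the poll delay.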


@ -1,16 +1,18 @@
""" """
OASIS Twitter模拟预设脚本 OASIS Twitter simulation preset script.
此脚本读取配置文件中的参数来执行模拟实现全程自动化
功能特性: This script reads parameters from a config file to run a fully automated simulation.
- 完成模拟后不立即关闭环境进入等待命令模式
- 支持通过IPC接收Interview命令
- 支持单个Agent采访和批量采访
- 支持远程关闭环境命令
使用方式: Features:
- Does not close the environment immediately when the simulation finishes; enters
command-wait mode instead.
- Receives Interview commands over IPC.
- Supports both single-agent and batch interviews.
- Supports a remote close-environment command.
Usage:
python run_twitter_simulation.py --config /path/to/simulation_config.json python run_twitter_simulation.py --config /path/to/simulation_config.json
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done
""" """
import argparse import argparse
@ -25,18 +27,17 @@ import sqlite3
from datetime import datetime from datetime import datetime
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
# 全局变量:用于信号处理 # Globals used by the signal handler.
_shutdown_event = None _shutdown_event = None
_cleanup_done = False _cleanup_done = False
# 添加项目路径
_scripts_dir = os.path.dirname(os.path.abspath(__file__)) _scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
sys.path.insert(0, _scripts_dir) sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir) sys.path.insert(0, _backend_dir)
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) # Load the project-root .env (it carries LLM_API_KEY and friends).
from dotenv import load_dotenv from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env') _env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file): if os.path.exists(_env_file):
@ -51,7 +52,7 @@ import re
class UnicodeFormatter(logging.Formatter): class UnicodeFormatter(logging.Formatter):
"""自定义格式化器,将 Unicode 转义序列转换为可读字符""" """Custom formatter that turns Unicode escape sequences into readable characters."""
UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
@ -68,24 +69,23 @@ class UnicodeFormatter(logging.Formatter):
class MaxTokensWarningFilter(logging.Filter): class MaxTokensWarningFilter(logging.Filter):
"""过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens让模型自行决定""" """Suppress camel-ai's max_tokens warning — we intentionally leave it unset and let the model decide."""
def filter(self, record): def filter(self, record):
# 过滤掉包含 max_tokens 警告的日志
if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage():
return False return False
return True return True
# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 # Install the filter at import time so it is active before any camel code runs.
logging.getLogger().addFilter(MaxTokensWarningFilter()) logging.getLogger().addFilter(MaxTokensWarningFilter())
def setup_oasis_logging(log_dir: str): def setup_oasis_logging(log_dir: str):
"""配置 OASIS 的日志,使用固定名称的日志文件""" """Configure OASIS logging with fixed log filenames."""
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
# 清理旧的日志文件 # Wipe stale log files from previous runs.
for f in os.listdir(log_dir): for f in os.listdir(log_dir):
old_log = os.path.join(log_dir, f) old_log = os.path.join(log_dir, f)
if os.path.isfile(old_log) and f.endswith('.log'): if os.path.isfile(old_log) and f.endswith('.log'):
@ -131,21 +131,21 @@ except ImportError as e:
sys.exit(1) sys.exit(1)
# IPC相关常量 # IPC-related constants.
IPC_COMMANDS_DIR = "ipc_commands" IPC_COMMANDS_DIR = "ipc_commands"
IPC_RESPONSES_DIR = "ipc_responses" IPC_RESPONSES_DIR = "ipc_responses"
ENV_STATUS_FILE = "env_status.json" ENV_STATUS_FILE = "env_status.json"
class CommandType: class CommandType:
"""命令类型常量""" """Command type constants."""
INTERVIEW = "interview" INTERVIEW = "interview"
BATCH_INTERVIEW = "batch_interview" BATCH_INTERVIEW = "batch_interview"
CLOSE_ENV = "close_env" CLOSE_ENV = "close_env"
class IPCHandler: class IPCHandler:
"""IPC命令处理器""" """Handles IPC commands directed at the running simulation."""
def __init__(self, simulation_dir: str, env, agent_graph): def __init__(self, simulation_dir: str, env, agent_graph):
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
self.env = env self.env = env
@ -154,13 +154,12 @@ class IPCHandler:
self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
self._running = True self._running = True
# 确保目录存在
os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True)
def update_status(self, status: str): def update_status(self, status: str):
"""更新环境状态""" """Write the current environment status to the status file."""
with open(self.status_file, 'w', encoding='utf-8') as f: with open(self.status_file, 'w', encoding='utf-8') as f:
json.dump({ json.dump({
"status": status, "status": status,
@ -168,11 +167,11 @@ class IPCHandler:
}, f, ensure_ascii=False, indent=2) }, f, ensure_ascii=False, indent=2)
def poll_command(self) -> Optional[Dict[str, Any]]: def poll_command(self) -> Optional[Dict[str, Any]]:
"""轮询获取待处理命令""" """Poll for the next pending command."""
if not os.path.exists(self.commands_dir): if not os.path.exists(self.commands_dir):
return None return None
# 获取命令文件(按时间排序) # Collect command files ordered by mtime.
command_files = [] command_files = []
for filename in os.listdir(self.commands_dir): for filename in os.listdir(self.commands_dir):
if filename.endswith('.json'): if filename.endswith('.json'):
@ -191,7 +190,7 @@ class IPCHandler:
return None return None
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
"""发送响应""" """Send a response for a processed command."""
response = { response = {
"command_id": command_id, "command_id": command_id,
"status": status, "status": status,
@ -203,8 +202,8 @@ class IPCHandler:
response_file = os.path.join(self.responses_dir, f"{command_id}.json") response_file = os.path.join(self.responses_dir, f"{command_id}.json")
with open(response_file, 'w', encoding='utf-8') as f: with open(response_file, 'w', encoding='utf-8') as f:
json.dump(response, f, ensure_ascii=False, indent=2) json.dump(response, f, ensure_ascii=False, indent=2)
# 删除命令文件 # Remove the command file once a response has been written.
command_file = os.path.join(self.commands_dir, f"{command_id}.json") command_file = os.path.join(self.commands_dir, f"{command_id}.json")
try: try:
os.remove(command_file) os.remove(command_file)
@ -212,27 +211,23 @@ class IPCHandler:
pass pass
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
""" """Handle a single-agent interview command.
处理单个Agent采访命令
Returns: Returns:
True 表示成功False 表示失败 True on success, False on failure.
""" """
try: try:
# 获取Agent
agent = self.agent_graph.get_agent(agent_id) agent = self.agent_graph.get_agent(agent_id)
# 创建Interview动作
interview_action = ManualAction( interview_action = ManualAction(
action_type=ActionType.INTERVIEW, action_type=ActionType.INTERVIEW,
action_args={"prompt": prompt} action_args={"prompt": prompt}
) )
# 执行Interview
actions = {agent: interview_action} actions = {agent: interview_action}
await self.env.step(actions) await self.env.step(actions)
# 从数据库获取结果 # Pull the resulting transcript from the simulation database.
result = self._get_interview_result(agent_id) result = self._get_interview_result(agent_id)
self.send_response(command_id, "completed", result=result) self.send_response(command_id, "completed", result=result)
@ -246,17 +241,15 @@ class IPCHandler:
return False return False
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
""" """Handle a batch interview command.
处理批量采访命令
Args: Args:
interviews: [{"agent_id": int, "prompt": str}, ...] interviews: [{"agent_id": int, "prompt": str}, ...]
""" """
try: try:
# 构建动作字典
actions = {} actions = {}
agent_prompts = {} # 记录每个agent的prompt agent_prompts = {} # Track the prompt issued to each agent for later result lookup.
for interview in interviews: for interview in interviews:
agent_id = interview.get("agent_id") agent_id = interview.get("agent_id")
prompt = interview.get("prompt", "") prompt = interview.get("prompt", "")
@ -274,11 +267,10 @@ class IPCHandler:
if not actions: if not actions:
self.send_response(command_id, "failed", error="没有有效的Agent") self.send_response(command_id, "failed", error="没有有效的Agent")
return False return False
# 执行批量Interview
await self.env.step(actions) await self.env.step(actions)
# 获取所有结果 # Collect the per-agent interview results.
results = {} results = {}
for agent_id in agent_prompts.keys(): for agent_id in agent_prompts.keys():
result = self._get_interview_result(agent_id) result = self._get_interview_result(agent_id)
@ -298,7 +290,7 @@ class IPCHandler:
return False return False
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
"""从数据库获取最新的Interview结果""" """Fetch the most recent interview result for an agent from the database."""
db_path = os.path.join(self.simulation_dir, "twitter_simulation.db") db_path = os.path.join(self.simulation_dir, "twitter_simulation.db")
result = { result = {
@ -313,8 +305,8 @@ class IPCHandler:
try: try:
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
cursor = conn.cursor() cursor = conn.cursor()
# 查询最新的Interview记录 # Pull the most recent INTERVIEW trace row for this agent.
cursor.execute(""" cursor.execute("""
SELECT user_id, info, created_at SELECT user_id, info, created_at
FROM trace FROM trace
@ -341,11 +333,10 @@ class IPCHandler:
return result return result
async def process_commands(self) -> bool: async def process_commands(self) -> bool:
""" """Process pending commands.
处理所有待处理命令
Returns: Returns:
True 表示继续运行False 表示应该退出 True if the run loop should continue, False if it should exit.
""" """
command = self.poll_command() command = self.poll_command()
if not command: if not command:
@ -383,9 +374,9 @@ class IPCHandler:
class TwitterSimulationRunner: class TwitterSimulationRunner:
"""Twitter模拟运行器""" """Drives a single Twitter simulation run."""
# Twitter可用动作不包含INTERVIEWINTERVIEW只能通过ManualAction手动触发 # Available Twitter actions. INTERVIEW is intentionally excluded — it can only be triggered via ManualAction.
AVAILABLE_ACTIONS = [ AVAILABLE_ACTIONS = [
ActionType.CREATE_POST, ActionType.CREATE_POST,
ActionType.LIKE_POST, ActionType.LIKE_POST,
@ -396,12 +387,11 @@ class TwitterSimulationRunner:
] ]
def __init__(self, config_path: str, wait_for_commands: bool = True): def __init__(self, config_path: str, wait_for_commands: bool = True):
""" """Initialize the simulation runner.
初始化模拟运行器
Args: Args:
config_path: 配置文件路径 (simulation_config.json) config_path: Path to the config file (simulation_config.json).
wait_for_commands: 模拟完成后是否等待命令默认True wait_for_commands: Whether to wait for IPC commands after the simulation completes (default True).
""" """
self.config_path = config_path self.config_path = config_path
self.config = self._load_config() self.config = self._load_config()
@ -412,37 +402,36 @@ class TwitterSimulationRunner:
self.ipc_handler = None self.ipc_handler = None
def _load_config(self) -> Dict[str, Any]: def _load_config(self) -> Dict[str, Any]:
"""加载配置文件""" """Load the simulation config file."""
with open(self.config_path, 'r', encoding='utf-8') as f: with open(self.config_path, 'r', encoding='utf-8') as f:
return json.load(f) return json.load(f)
def _get_profile_path(self) -> str: def _get_profile_path(self) -> str:
"""获取Profile文件路径OASIS Twitter使用CSV格式""" """Return the agent profile path (OASIS Twitter expects CSV)."""
return os.path.join(self.simulation_dir, "twitter_profiles.csv") return os.path.join(self.simulation_dir, "twitter_profiles.csv")
def _get_db_path(self) -> str: def _get_db_path(self) -> str:
"""获取数据库路径""" """Return the simulation SQLite database path."""
return os.path.join(self.simulation_dir, "twitter_simulation.db") return os.path.join(self.simulation_dir, "twitter_simulation.db")
def _create_model(self): def _create_model(self):
"""Create the LLM model.
Uses the project-root .env file (highest precedence):
- LLM_API_KEY: API key
- LLM_BASE_URL: API base URL
- LLM_MODEL_NAME: model name
""" """
创建LLM模型 # Prefer values from .env.
统一使用项目根目录 .env 文件中的配置优先级最高
- LLM_API_KEY: API密钥
- LLM_BASE_URL: API基础URL
- LLM_MODEL_NAME: 模型名称
"""
# 优先从 .env 读取配置
llm_api_key = os.environ.get("LLM_API_KEY", "") llm_api_key = os.environ.get("LLM_API_KEY", "")
llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_base_url = os.environ.get("LLM_BASE_URL", "")
llm_model = os.environ.get("LLM_MODEL_NAME", "") llm_model = os.environ.get("LLM_MODEL_NAME", "")
# 如果 .env 中没有,则使用 config 作为备用 # Fall back to the simulation config if .env did not provide a model name.
if not llm_model: if not llm_model:
llm_model = self.config.get("llm_model", "gpt-4o-mini") llm_model = self.config.get("llm_model", "gpt-4o-mini")
# 设置 camel-ai 所需的环境变量 # camel-ai reads OPENAI_API_KEY from the environment.
if llm_api_key: if llm_api_key:
os.environ["OPENAI_API_KEY"] = llm_api_key os.environ["OPENAI_API_KEY"] = llm_api_key
@ -465,25 +454,24 @@ class TwitterSimulationRunner:
current_hour: int, current_hour: int,
round_num: int round_num: int
) -> List: ) -> List:
""" """Decide which agents activate this round, based on time and config.
根据时间和配置决定本轮激活哪些Agent
Args: Args:
env: OASIS环境 env: The OASIS environment.
current_hour: 当前模拟小时0-23 current_hour: Current simulated hour (0-23).
round_num: 当前轮数 round_num: Current round number.
Returns: Returns:
激活的Agent列表 The list of agents activated this round.
""" """
time_config = self.config.get("time_config", {}) time_config = self.config.get("time_config", {})
agent_configs = self.config.get("agent_configs", []) agent_configs = self.config.get("agent_configs", [])
# 基础激活数量 # Base activation count per round.
base_min = time_config.get("agents_per_hour_min", 5) base_min = time_config.get("agents_per_hour_min", 5)
base_max = time_config.get("agents_per_hour_max", 20) base_max = time_config.get("agents_per_hour_max", 20)
# 根据时段调整 # Adjust by time-of-day (peak vs. off-peak hours).
peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])
@ -495,29 +483,27 @@ class TwitterSimulationRunner:
multiplier = 1.0 multiplier = 1.0
target_count = int(random.uniform(base_min, base_max) * multiplier) target_count = int(random.uniform(base_min, base_max) * multiplier)
# 根据每个Agent的配置计算激活概率 # Compute activation probability for each configured agent.
candidates = [] candidates = []
for cfg in agent_configs: for cfg in agent_configs:
agent_id = cfg.get("agent_id", 0) agent_id = cfg.get("agent_id", 0)
active_hours = cfg.get("active_hours", list(range(8, 23))) active_hours = cfg.get("active_hours", list(range(8, 23)))
activity_level = cfg.get("activity_level", 0.5) activity_level = cfg.get("activity_level", 0.5)
# 检查是否在活跃时间
if current_hour not in active_hours: if current_hour not in active_hours:
continue continue
# 根据活跃度计算概率
if random.random() < activity_level: if random.random() < activity_level:
candidates.append(agent_id) candidates.append(agent_id)
# 随机选择 # Pick a random subset of the eligible candidates.
selected_ids = random.sample( selected_ids = random.sample(
candidates, candidates,
min(target_count, len(candidates)) min(target_count, len(candidates))
) if candidates else [] ) if candidates else []
# 转换为Agent对象 # Resolve IDs to Agent objects.
active_agents = [] active_agents = []
for agent_id in selected_ids: for agent_id in selected_ids:
try: try:
@ -529,10 +515,10 @@ class TwitterSimulationRunner:
return active_agents return active_agents
async def run(self, max_rounds: int = None): async def run(self, max_rounds: int = None):
"""运行Twitter模拟 """Run the Twitter simulation.
Args: Args:
max_rounds: 最大模拟轮数可选用于截断过长的模拟 max_rounds: Optional cap on the number of rounds, used to truncate overly long simulations.
""" """
print("=" * 60) print("=" * 60)
print("OASIS Twitter模拟") print("OASIS Twitter模拟")
@ -540,16 +526,14 @@ class TwitterSimulationRunner:
print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
print("=" * 60) print("=" * 60)
# 加载时间配置
time_config = self.config.get("time_config", {}) time_config = self.config.get("time_config", {})
total_hours = time_config.get("total_simulation_hours", 72) total_hours = time_config.get("total_simulation_hours", 72)
minutes_per_round = time_config.get("minutes_per_round", 30) minutes_per_round = time_config.get("minutes_per_round", 30)
# 计算总轮数
total_rounds = (total_hours * 60) // minutes_per_round total_rounds = (total_hours * 60) // minutes_per_round
# 如果指定了最大轮数,则截断 # Truncate to max_rounds when one was supplied.
if max_rounds is not None and max_rounds > 0: if max_rounds is not None and max_rounds > 0:
original_rounds = total_rounds original_rounds = total_rounds
total_rounds = min(total_rounds, max_rounds) total_rounds = min(total_rounds, max_rounds)
@ -563,12 +547,11 @@ class TwitterSimulationRunner:
if max_rounds: if max_rounds:
print(f" - 最大轮数限制: {max_rounds}") print(f" - 最大轮数限制: {max_rounds}")
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
# 创建模型
print("\n初始化LLM模型...") print("\n初始化LLM模型...")
model = self._create_model() model = self._create_model()
# 加载Agent图 # Load the agent graph from the profile CSV.
print("加载Agent Profile...") print("加载Agent Profile...")
profile_path = self._get_profile_path() profile_path = self._get_profile_path()
if not os.path.exists(profile_path): if not os.path.exists(profile_path):
@ -581,29 +564,27 @@ class TwitterSimulationRunner:
available_actions=self.AVAILABLE_ACTIONS, available_actions=self.AVAILABLE_ACTIONS,
) )
# 数据库路径 # Reset the simulation database for a clean run.
db_path = self._get_db_path() db_path = self._get_db_path()
if os.path.exists(db_path): if os.path.exists(db_path):
os.remove(db_path) os.remove(db_path)
print(f"已删除旧数据库: {db_path}") print(f"已删除旧数据库: {db_path}")
# 创建环境
print("创建OASIS环境...") print("创建OASIS环境...")
self.env = oasis.make( self.env = oasis.make(
agent_graph=self.agent_graph, agent_graph=self.agent_graph,
platform=oasis.DefaultPlatformType.TWITTER, platform=oasis.DefaultPlatformType.TWITTER,
database_path=db_path, database_path=db_path,
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 semaphore=30, # Cap concurrent LLM requests to avoid API overload.
) )
await self.env.reset() await self.env.reset()
print("环境初始化完成\n") print("环境初始化完成\n")
# 初始化IPC处理器
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
self.ipc_handler.update_status("running") self.ipc_handler.update_status("running")
# 执行初始事件 # Run the initial seeded events (kickoff posts).
event_config = self.config.get("event_config", {}) event_config = self.config.get("event_config", {})
initial_posts = event_config.get("initial_posts", []) initial_posts = event_config.get("initial_posts", [])
@ -625,35 +606,32 @@ class TwitterSimulationRunner:
if initial_actions: if initial_actions:
await self.env.step(initial_actions) await self.env.step(initial_actions)
print(f" 已发布 {len(initial_actions)} 条初始帖子") print(f" 已发布 {len(initial_actions)} 条初始帖子")
# 主模拟循环 # Main simulation loop.
print("\n开始模拟循环...") print("\n开始模拟循环...")
start_time = datetime.now() start_time = datetime.now()
for round_num in range(total_rounds): for round_num in range(total_rounds):
# 计算当前模拟时间 # Map round number to simulated wall-clock time.
simulated_minutes = round_num * minutes_per_round simulated_minutes = round_num * minutes_per_round
simulated_hour = (simulated_minutes // 60) % 24 simulated_hour = (simulated_minutes // 60) % 24
simulated_day = simulated_minutes // (60 * 24) + 1 simulated_day = simulated_minutes // (60 * 24) + 1
# 获取本轮激活的Agent
active_agents = self._get_active_agents_for_round( active_agents = self._get_active_agents_for_round(
self.env, simulated_hour, round_num self.env, simulated_hour, round_num
) )
if not active_agents: if not active_agents:
continue continue
# 构建动作
actions = { actions = {
agent: LLMAction() agent: LLMAction()
for _, agent in active_agents for _, agent in active_agents
} }
# 执行动作
await self.env.step(actions) await self.env.step(actions)
# 打印进度 # Periodic progress log.
if (round_num + 1) % 10 == 0 or round_num == 0: if (round_num + 1) % 10 == 0 or round_num == 0:
elapsed = (datetime.now() - start_time).total_seconds() elapsed = (datetime.now() - start_time).total_seconds()
progress = (round_num + 1) / total_rounds * 100 progress = (round_num + 1) / total_rounds * 100
@ -667,7 +645,7 @@ class TwitterSimulationRunner:
print(f" - 总耗时: {total_elapsed:.1f}") print(f" - 总耗时: {total_elapsed:.1f}")
print(f" - 数据库: {db_path}") print(f" - 数据库: {db_path}")
# 是否进入等待命令模式 # Optionally enter command-wait mode.
if self.wait_for_commands: if self.wait_for_commands:
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("进入等待命令模式 - 环境保持运行") print("进入等待命令模式 - 环境保持运行")
@ -675,8 +653,8 @@ class TwitterSimulationRunner:
print("=" * 60) print("=" * 60)
self.ipc_handler.update_status("alive") self.ipc_handler.update_status("alive")
# 等待命令循环(使用全局 _shutdown_event # Command-wait loop, driven by the global _shutdown_event.
try: try:
while not _shutdown_event.is_set(): while not _shutdown_event.is_set():
should_continue = await self.ipc_handler.process_commands() should_continue = await self.ipc_handler.process_commands()
@ -684,7 +662,7 @@ class TwitterSimulationRunner:
break break
try: try:
await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
break # 收到退出信号 break # Shutdown signal received.
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
except KeyboardInterrupt: except KeyboardInterrupt:
@ -695,8 +673,7 @@ class TwitterSimulationRunner:
print(f"\n命令处理出错: {e}") print(f"\n命令处理出错: {e}")
print("\n关闭环境...") print("\n关闭环境...")
# 关闭环境
self.ipc_handler.update_status("stopped") self.ipc_handler.update_status("stopped")
await self.env.close() await self.env.close()
@ -726,16 +703,16 @@ async def main():
) )
args = parser.parse_args() args = parser.parse_args()
# 在 main 函数开始时创建 shutdown 事件 # Create the shutdown event inside the running event loop.
global _shutdown_event global _shutdown_event
_shutdown_event = asyncio.Event() _shutdown_event = asyncio.Event()
if not os.path.exists(args.config): if not os.path.exists(args.config):
print(f"错误: 配置文件不存在: {args.config}") print(f"错误: 配置文件不存在: {args.config}")
sys.exit(1) sys.exit(1)
# 初始化日志配置(使用固定文件名,清理旧日志) # Initialize logging with fixed filenames; old logs are wiped.
simulation_dir = os.path.dirname(args.config) or "." simulation_dir = os.path.dirname(args.config) or "."
setup_oasis_logging(os.path.join(simulation_dir, "log")) setup_oasis_logging(os.path.join(simulation_dir, "log"))
@ -747,9 +724,11 @@ async def main():
def setup_signal_handlers(): def setup_signal_handlers():
""" """Install signal handlers so SIGTERM/SIGINT trigger an orderly shutdown.
设置信号处理器确保收到 SIGTERM/SIGINT 时能够正确退出
让程序有机会正常清理资源关闭数据库环境等 The handler gives the program a chance to clean up resources properly
(closing the database, the OASIS environment, etc.) on the first signal,
and only force-exits on a repeated signal.
""" """
def signal_handler(signum, frame): def signal_handler(signum, frame):
global _cleanup_done global _cleanup_done
@ -760,7 +739,7 @@ def setup_signal_handlers():
if _shutdown_event: if _shutdown_event:
_shutdown_event.set() _shutdown_event.set()
else: else:
# 重复收到信号才强制退出 # Force exit only on a repeat signal.
print("强制退出...") print("强制退出...")
sys.exit(1) sys.exit(1)
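
Reviewer note: the round arithmetic that the translated comments describe ("Truncate to max_rounds when one was supplied", "Map round number to simulated wall-clock time") can be checked in isolation. The sketch below copies the expressions from the diff into two standalone functions. The function names `total_rounds` and `simulated_clock` are hypothetical; the defaults (72 hours, 30 minutes per round) are the fallback values the runner reads from `time_config`.

```python
from typing import Optional, Tuple


def total_rounds(total_hours: int = 72,
                 minutes_per_round: int = 30,
                 max_rounds: Optional[int] = None) -> int:
    """Number of rounds to run, truncated by an optional positive cap."""
    rounds = (total_hours * 60) // minutes_per_round
    if max_rounds is not None and max_rounds > 0:
        rounds = min(rounds, max_rounds)
    return rounds


def simulated_clock(round_num: int, minutes_per_round: int = 30) -> Tuple[int, int]:
    """Map a zero-based round number to (simulated_day, simulated_hour)."""
    simulated_minutes = round_num * minutes_per_round
    simulated_hour = (simulated_minutes // 60) % 24      # 0-23
    simulated_day = simulated_minutes // (60 * 24) + 1   # days start at 1
    return simulated_day, simulated_hour
```

With the defaults, a run has 144 rounds, and round 47 is the last round that falls on day 1 (hour 23); round 48 starts day 2 at hour 0.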


@ -1,8 +1,8 @@
""" """Profile-format generation tests for OASIS compatibility.
测试Profile格式生成是否符合OASIS要求
验证 Verifies that:
1. Twitter Profile生成CSV格式 1. Twitter profiles serialize to CSV format.
2. Reddit Profile生成JSON详细格式 2. Reddit profiles serialize to detailed JSON format.
""" """
import os import os
@ -11,19 +11,19 @@ import json
import csv import csv
import tempfile import tempfile
# 添加项目路径 # Add the project root to sys.path so the ``app`` package resolves.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile from app.services.oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
def test_profile_formats(): def test_profile_formats():
"""测试Profile格式""" """Exercise both profile-format outputs end-to-end."""
print("=" * 60) print("=" * 60)
print("OASIS Profile格式测试") print("OASIS Profile格式测试")
print("=" * 60) print("=" * 60)
# 创建测试Profile数据 # Build a small set of test profiles.
test_profiles = [ test_profiles = [
OasisAgentProfile( OasisAgentProfile(
user_id=0, user_id=0,
@ -62,18 +62,18 @@ def test_profile_formats():
] ]
generator = OasisProfileGenerator.__new__(OasisProfileGenerator) generator = OasisProfileGenerator.__new__(OasisProfileGenerator)
# 使用临时目录 # Use a temp directory for the test fixtures.
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
twitter_path = os.path.join(temp_dir, "twitter_profiles.csv") twitter_path = os.path.join(temp_dir, "twitter_profiles.csv")
reddit_path = os.path.join(temp_dir, "reddit_profiles.json") reddit_path = os.path.join(temp_dir, "reddit_profiles.json")
# 测试Twitter CSV格式 # Twitter CSV format.
print("\n1. 测试Twitter Profile (CSV格式)") print("\n1. 测试Twitter Profile (CSV格式)")
print("-" * 40) print("-" * 40)
generator._save_twitter_csv(test_profiles, twitter_path) generator._save_twitter_csv(test_profiles, twitter_path)
# 读取并验证CSV # Read back and verify the CSV.
with open(twitter_path, 'r', encoding='utf-8') as f: with open(twitter_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
rows = list(reader) rows = list(reader)
@@ -85,8 +85,8 @@ def test_profile_formats():
for key, value in rows[0].items(): for key, value in rows[0].items():
print(f" {key}: {value}") print(f" {key}: {value}")
# 验证必需字段 # Verify the required fields are present.
required_twitter_fields = ['user_id', 'user_name', 'name', 'bio', required_twitter_fields = ['user_id', 'user_name', 'name', 'bio',
'friend_count', 'follower_count', 'statuses_count', 'created_at'] 'friend_count', 'follower_count', 'statuses_count', 'created_at']
missing = set(required_twitter_fields) - set(rows[0].keys()) missing = set(required_twitter_fields) - set(rows[0].keys())
if missing: if missing:
@@ -94,12 +94,12 @@ def test_profile_formats():
else: else:
print(f"\n [通过] 所有必需字段都存在") print(f"\n [通过] 所有必需字段都存在")
# 测试Reddit JSON格式 # Reddit JSON format.
print("\n2. 测试Reddit Profile (JSON详细格式)") print("\n2. 测试Reddit Profile (JSON详细格式)")
print("-" * 40) print("-" * 40)
generator._save_reddit_json(test_profiles, reddit_path) generator._save_reddit_json(test_profiles, reddit_path)
# 读取并验证JSON # Read back and verify the JSON.
with open(reddit_path, 'r', encoding='utf-8') as f: with open(reddit_path, 'r', encoding='utf-8') as f:
reddit_data = json.load(f) reddit_data = json.load(f)
@@ -109,7 +109,7 @@ def test_profile_formats():
print(f"\n 示例数据 (第1条):") print(f"\n 示例数据 (第1条):")
print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4)) print(json.dumps(reddit_data[0], ensure_ascii=False, indent=4))
# 验证详细格式字段 # Verify the detailed Reddit format fields.
required_reddit_fields = ['realname', 'username', 'bio', 'persona'] required_reddit_fields = ['realname', 'username', 'bio', 'persona']
optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics'] optional_reddit_fields = ['age', 'gender', 'mbti', 'country', 'profession', 'interested_topics']
@@ -128,7 +128,7 @@ def test_profile_formats():
def show_expected_formats(): def show_expected_formats():
"""显示OASIS期望的格式""" """Print the canonical OASIS-expected profile formats for reference."""
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("OASIS 期望的Profile格式参考") print("OASIS 期望的Profile格式参考")
print("=" * 60) print("=" * 60)