diff --git a/.env.example b/.env.example index 78a3b72c..39c74d8c 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,14 @@ +# ===== Security ===== +# Set API_KEY to require X-Api-Key header on every request (recommended for production). +# Leave unset for local development (all endpoints will be publicly accessible). +# API_KEY=your_strong_random_key_here + +# ===== Flask ===== +# Set to False in production (default is False) +# FLASK_DEBUG=False +# Provide a strong SECRET_KEY for session signing (auto-generated if unset) +# SECRET_KEY=your_strong_random_key_here + # LLM API配置(支持 OpenAI SDK 格式的任意 LLM API) # 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/ # 注意消耗较大,可先进行小于40轮的模拟尝试 diff --git a/backend/app/__init__.py b/backend/app/__init__.py index aba624bb..e220ffe9 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -14,6 +14,7 @@ from flask_cors import CORS from .config import Config from .utils.logger import setup_logger, get_logger +from .utils.auth import check_api_key def create_app(config_class=Config): @@ -48,6 +49,15 @@ def create_app(config_class=Config): if should_log_startup: logger.info("已注册模拟进程清理函数") + # API key authentication + app.before_request(check_api_key) + + if not Config.API_KEY: + logger.warning( + "API_KEY is not set — all endpoints are publicly accessible. " + "Set API_KEY in your .env file to enable authentication." + ) + # 请求日志中间件 @app.before_request def log_request(): diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 759ff48b..7c8c198e 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -10,6 +10,7 @@ from flask import request, jsonify from . import graph_bp from ..config import Config +from ..utils.id_validator import validate_safe_id from ..services.ontology_generator import OntologyGenerator from ..services.graph_builder import GraphBuilderService from ..services.text_processor import TextProcessor @@ -38,6 +39,7 @@ def get_project(project_id: str): """ 获取项目详情 """ + validate_safe_id(project_id, "project_id") project = ProjectManager.get_project(project_id) if not project: @@ -72,6 +74,7 @@ def delete_project(project_id: str): """ 删除项目 """ + validate_safe_id(project_id, "project_id") success = ProjectManager.delete_project(project_id) if not success: @@ -91,6 +94,7 @@ def reset_project(project_id: str): """ 重置项目状态(用于重新构建图谱) """ + validate_safe_id(project_id, "project_id") project = ProjectManager.get_project(project_id) if not project: @@ -182,12 +186,13 @@ def generate_ontology(): all_text = "" for file in uploaded_files: - if file and file.filename and allowed_file(file.filename): + safe_filename = os.path.basename(file.filename) if file.filename else '' + if file and safe_filename and allowed_file(safe_filename): # 保存文件到项目目录 file_info = ProjectManager.save_file_to_project( - project.project_id, - file, - file.filename + project.project_id, + file, + safe_filename ) project.files.append({ "filename": file_info["original_filename"], @@ -250,8 +255,7 @@ def generate_ontology(): except Exception as e: return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -304,7 +308,9 @@ def build_graph(): "success": False, "error": t('api.requireProjectId') }), 400 - + + validate_safe_id(project_id, "project_id") + # 获取项目 project = ProjectManager.get_project(project_id) if not project: @@ -524,8 +530,7 @@ def build_graph(): except Exception as e: return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -572,6 +577,7 @@ def get_graph_data(graph_id: str): 获取图谱数据(节点和边) """ try: + validate_safe_id(graph_id, "graph_id") if not Config.ZEP_API_KEY: return jsonify({ "success": False, @@ -589,8 +595,7 @@ def get_graph_data(graph_id: str): except Exception as e: return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -600,6 +605,7 @@ def delete_graph(graph_id: str): 删除Zep图谱 """ try: + validate_safe_id(graph_id, "graph_id") if not Config.ZEP_API_KEY: return jsonify({ "success": False, @@ -617,6 +623,5 @@ def delete_graph(graph_id: str): except Exception as e: return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 diff --git a/backend/app/api/report.py b/backend/app/api/report.py index d7f2a4d0..ab2483b2 100644 --- a/backend/app/api/report.py +++ b/backend/app/api/report.py @@ -10,6 +10,7 @@ from flask import request, jsonify, send_file from . import report_bp from ..config import Config +from ..utils.id_validator import validate_safe_id from ..services.report_agent import ReportAgent, ReportManager, ReportStatus from ..services.simulation_manager import SimulationManager from ..models.project import ProjectManager @@ -195,8 +196,7 @@ def generate_report(): logger.error(f"启动报告生成任务失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -294,6 +294,7 @@ def get_report(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") report = ReportManager.get_report(report_id) if not report: @@ -311,8 +312,7 @@ def get_report(report_id: str): logger.error(f"获取报告失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -331,6 +331,7 @@ def get_report_by_simulation(simulation_id: str): } """ try: + validate_safe_id(simulation_id, "simulation_id") report = ReportManager.get_report_by_simulation(simulation_id) if not report: @@ -350,8 +351,7 @@ def get_report_by_simulation(simulation_id: str): logger.error(f"获取报告失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -390,8 +390,7 @@ def list_reports(): logger.error(f"列出报告失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -403,6 +402,7 @@ def download_report(report_id: str): 返回Markdown文件 """ try: + validate_safe_id(report_id, "report_id") report = ReportManager.get_report(report_id) if not report: @@ -436,8 +436,7 @@ def download_report(report_id: str): logger.error(f"下载报告失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -445,6 +444,7 @@ def download_report(report_id: str): def delete_report(report_id: str): """删除报告""" try: + validate_safe_id(report_id, "report_id") success = ReportManager.delete_report(report_id) if not success: @@ -462,8 +462,7 @@ def delete_report(report_id: str): logger.error(f"删除报告失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -559,8 +558,7 @@ def chat_with_report_agent(): logger.error(f"对话失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -585,6 +583,7 @@ def get_report_progress(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") progress = ReportManager.get_progress(report_id) if not progress: @@ -602,8 +601,7 @@ def get_report_progress(report_id: str): logger.error(f"获取报告进度失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -633,6 +631,7 @@ def get_report_sections(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") sections = ReportManager.get_generated_sections(report_id) # 获取报告状态 @@ -653,8 +652,7 @@ def get_report_sections(report_id: str): logger.error(f"获取章节列表失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -673,6 +671,7 @@ def get_single_section(report_id: str, section_index: int): } """ try: + validate_safe_id(report_id, "report_id") section_path = ReportManager._get_section_path(report_id, section_index) if not os.path.exists(section_path): @@ -697,8 +696,7 @@ def get_single_section(report_id: str, section_index: int): logger.error(f"获取章节内容失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -724,6 +722,7 @@ def check_report_status(simulation_id: str): } """ try: + validate_safe_id(simulation_id, "simulation_id") report = ReportManager.get_report_by_simulation(simulation_id) has_report = report is not None @@ -748,8 +747,7 @@ def check_report_status(simulation_id: str): logger.error(f"检查报告状态失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -796,6 +794,7 @@ def get_agent_log(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") from_line = request.args.get('from_line', 0, type=int) log_data = ReportManager.get_agent_log(report_id, from_line=from_line) @@ -809,8 +808,7 @@ def get_agent_log(report_id: str): logger.error(f"获取Agent日志失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -829,6 +827,7 @@ def stream_agent_log(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") logs = ReportManager.get_agent_log_stream(report_id) return jsonify({ @@ -843,8 +842,7 @@ def stream_agent_log(report_id: str): logger.error(f"获取Agent日志失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -878,6 +876,7 @@ def get_console_log(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") from_line = request.args.get('from_line', 0, type=int) log_data = ReportManager.get_console_log(report_id, from_line=from_line) @@ -891,8 +890,7 @@ def get_console_log(report_id: str): logger.error(f"获取控制台日志失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -911,6 +909,7 @@ def stream_console_log(report_id: str): } """ try: + validate_safe_id(report_id, "report_id") logs = ReportManager.get_console_log_stream(report_id) return jsonify({ @@ -925,8 +924,7 @@ def stream_console_log(report_id: str): logger.error(f"获取控制台日志失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -975,8 +973,7 @@ def search_graph_tool(): logger.error(f"图谱搜索失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1015,6 +1012,5 @@ def get_graph_statistics_tool(): logger.error(f"获取图谱统计失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a8e1e3f..643ab046 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -9,6 +9,7 @@ from flask import request, jsonify, send_file from . import simulation_bp from ..config import Config +from ..utils.id_validator import validate_safe_id, safe_join from ..services.zep_entity_reader import ZepEntityReader from ..services.oasis_profile_generator import OasisProfileGenerator from ..services.simulation_manager import SimulationManager, SimulationStatus @@ -85,8 +86,7 @@ def get_graph_entities(graph_id: str): logger.error(f"获取图谱实体失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -118,8 +118,7 @@ def get_entity_detail(graph_id: str, entity_uuid: str): logger.error(f"获取实体详情失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -155,8 +154,7 @@ def get_entities_by_type(graph_id: str, entity_type: str): logger.error(f"获取实体失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -232,8 +230,7 @@ def create_simulation(): logger.error(f"创建模拟失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -256,8 +253,9 @@ def _check_simulation_prepared(simulation_id: str) -> tuple: import os from ..config import Config - simulation_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) - + validate_safe_id(simulation_id, "simulation_id") + simulation_dir = safe_join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) + # 检查目录是否存在 if not os.path.exists(simulation_dir): return False, {"reason": "模拟目录不存在"} @@ -634,8 +632,7 @@ def prepare_simulation(): logger.error(f"启动准备任务失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -780,8 +777,7 @@ def get_simulation(simulation_id: str): logger.error(f"获取模拟状态失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -809,8 +805,7 @@ def list_simulations(): logger.error(f"列出模拟失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -982,8 +977,7 @@ def get_simulation_history(): logger.error(f"获取历史模拟失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1020,8 +1014,7 @@ def get_simulation_profiles(simulation_id: str): logger.error(f"获取Profile失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1059,16 +1052,17 @@ def get_simulation_profiles_realtime(simulation_id: str): try: platform = request.args.get('platform', 'reddit') - + + validate_safe_id(simulation_id, "simulation_id") # 获取模拟目录 - sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) - + sim_dir = safe_join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) + if not os.path.exists(sim_dir): return jsonify({ "success": False, "error": t('api.simulationNotFound', id=simulation_id) }), 404 - + # 确定文件路径 if platform == "reddit": profiles_file = os.path.join(sim_dir, "reddit_profiles.json") @@ -1130,8 +1124,7 @@ def get_simulation_profiles_realtime(simulation_id: str): logger.error(f"实时获取Profile失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1163,15 +1156,16 @@ def get_simulation_config_realtime(simulation_id: str): from datetime import datetime try: + validate_safe_id(simulation_id, "simulation_id") # 获取模拟目录 - sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) - + sim_dir = safe_join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id) + if not os.path.exists(sim_dir): return jsonify({ "success": False, "error": t('api.simulationNotFound', id=simulation_id) }), 404 - + # 配置文件路径 config_file = os.path.join(sim_dir, "simulation_config.json") @@ -1250,8 +1244,7 @@ def get_simulation_config_realtime(simulation_id: str): logger.error(f"实时获取Config失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1286,8 +1279,7 @@ def get_simulation_config(simulation_id: str): logger.error(f"获取配置失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1315,8 +1307,7 @@ def download_simulation_config(simulation_id: str): logger.error(f"下载配置失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1367,8 +1358,7 @@ def download_simulation_script(script_name: str): logger.error(f"下载脚本失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1441,8 +1431,7 @@ def generate_profiles(): logger.error(f"生成Profile失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1636,8 +1625,7 @@ def start_simulation(): logger.error(f"启动模拟失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1695,8 +1683,7 @@ def stop_simulation(): logger.error(f"停止模拟失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1755,8 +1742,7 @@ def get_run_status(simulation_id: str): logger.error(f"获取运行状态失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1856,8 +1842,7 @@ def get_run_status_detail(simulation_id: str): logger.error(f"获取详细状态失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1910,8 +1895,7 @@ def get_simulation_actions(simulation_id: str): logger.error(f"获取动作历史失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1950,8 +1934,7 @@ def get_simulation_timeline(simulation_id: str): logger.error(f"获取时间线失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -1977,8 +1960,7 @@ def get_agent_stats(simulation_id: str): logger.error(f"获取Agent统计失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2057,8 +2039,7 @@ def get_simulation_posts(simulation_id: str): logger.error(f"获取帖子失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2132,8 +2113,7 @@ def get_simulation_comments(simulation_id: str): logger.error(f"获取评论失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2263,8 +2243,7 @@ def interview_agent(): logger.error(f"Interview失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2401,8 +2380,7 @@ def interview_agents_batch(): logger.error(f"批量Interview失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2504,8 +2482,7 @@ def interview_all_agents(): logger.error(f"全局Interview失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2576,8 +2553,7 @@ def get_interview_history(): logger.error(f"获取Interview历史失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2641,8 +2617,7 @@ def get_env_status(): logger.error(f"获取环境状态失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 @@ -2711,6 +2686,5 @@ def close_simulation_env(): logger.error(f"关闭环境失败: {str(e)}") return jsonify({ "success": False, - "error": str(e), - "traceback": traceback.format_exc() + "error": str(e) }), 500 diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50..ca996ff5 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -21,8 +21,11 @@ class Config: """Flask配置类""" # Flask配置 - SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key') - DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true' + SECRET_KEY = os.environ.get('SECRET_KEY') or os.urandom(32).hex() + DEBUG = os.environ.get('FLASK_DEBUG', 'False').lower() == 'true' + + # API authentication — set API_KEY in .env to require X-Api-Key on every request + API_KEY = os.environ.get('API_KEY') # JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式) JSON_AS_ASCII = False diff --git a/backend/app/utils/auth.py b/backend/app/utils/auth.py new file mode 100644 index 00000000..97ff8f2b --- /dev/null +++ b/backend/app/utils/auth.py @@ -0,0 +1,21 @@ +from flask import request, jsonify +from ..config import Config +from .logger import get_logger + +logger = get_logger('mirofish.auth') + +_SKIP_PATHS = {'/health'} + + +def check_api_key(): + """Flask before_request handler — enforces X-Api-Key when API_KEY is configured.""" + if not Config.API_KEY: + return # API key auth is disabled; log a warning once at startup instead + + if request.path in _SKIP_PATHS: + return + + provided = request.headers.get('X-Api-Key', '') + if not provided or provided != Config.API_KEY: + logger.warning(f"Unauthorized request to {request.method} {request.path}") + return jsonify({"success": False, "error": "Unauthorized"}), 401 diff --git a/backend/app/utils/id_validator.py b/backend/app/utils/id_validator.py new file mode 100644 index 00000000..c378f444 --- /dev/null +++ b/backend/app/utils/id_validator.py @@ -0,0 +1,20 @@ +import os +import re + +_SAFE_ID_RE = re.compile(r'^[a-zA-Z0-9_-]{1,128}$') + + +def validate_safe_id(value: str, name: str = "id") -> str: + """Raise ValueError if value contains path-traversal characters.""" + if not value or not _SAFE_ID_RE.match(value): + raise ValueError(f"Invalid {name}: must contain only alphanumeric characters, underscores, or hyphens") + return value + + +def safe_join(base_dir: str, *parts: str) -> str: + """Join paths and verify the result stays inside base_dir.""" + base = os.path.realpath(base_dir) + joined = os.path.realpath(os.path.join(base_dir, *parts)) + if joined != base and not joined.startswith(base + os.sep): + raise ValueError(f"Path traversal detected: resolved path is outside {base_dir!r}") + return joined