This commit is contained in:
Karesansui 2026-05-28 17:27:44 -04:00 committed by GitHub
commit fb0b10429f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 65 additions and 26 deletions

View File

@ -20,6 +20,29 @@ from ..models.project import ProjectManager
logger = get_logger('mirofish.api.simulation')
def _get_default_platform(simulation_id: str) -> str:
"""
根据模拟配置返回默认平台
读取 SimulationState 中的 enable_twitter / enable_reddit 设置
返回该模拟实际使用的平台而非硬编码 'reddit'
Args:
simulation_id: 模拟ID
Returns:
'twitter' 'reddit'
"""
try:
manager = SimulationManager()
state = manager._load_simulation_state(simulation_id)
if state:
return state.get_default_platform()
except Exception:
pass
return "reddit"
# Interview prompt 优化前缀
# 添加此前缀可以避免Agent调用工具直接用文本回复
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"
@ -996,7 +1019,7 @@ def get_simulation_profiles(simulation_id: str):
platform: 平台类型reddit/twitter默认reddit
"""
try:
platform = request.args.get('platform', 'reddit')
platform = request.args.get('platform') or _get_default_platform(simulation_id)
manager = SimulationManager()
profiles = manager.get_profiles(simulation_id, platform=platform)
@ -1058,7 +1081,7 @@ def get_simulation_profiles_realtime(simulation_id: str):
from datetime import datetime
try:
platform = request.args.get('platform', 'reddit')
platform = request.args.get('platform') or _get_default_platform(simulation_id)
# 获取模拟目录
sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
@ -1997,7 +2020,7 @@ def get_simulation_posts(simulation_id: str):
返回帖子列表从SQLite数据库读取
"""
try:
platform = request.args.get('platform', 'reddit')
platform = request.args.get('platform') or _get_default_platform(simulation_id)
limit = request.args.get('limit', 50, type=int)
offset = request.args.get('offset', 0, type=int)
@ -2065,14 +2088,16 @@ def get_simulation_posts(simulation_id: str):
@simulation_bp.route('/<simulation_id>/comments', methods=['GET'])
def get_simulation_comments(simulation_id: str):
"""
获取模拟中的评论仅Reddit
获取模拟中的评论
Query参数
platform: 平台类型twitter/reddit根据模拟配置自动选择
post_id: 过滤帖子ID可选
limit: 返回数量
offset: 偏移量
"""
try:
platform = request.args.get('platform') or _get_default_platform(simulation_id)
post_id = request.args.get('post_id')
limit = request.args.get('limit', 50, type=int)
offset = request.args.get('offset', 0, type=int)
@ -2082,7 +2107,7 @@ def get_simulation_comments(simulation_id: str):
f'../../uploads/simulations/{simulation_id}'
)
db_path = os.path.join(sim_dir, "reddit_simulation.db")
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
if not os.path.exists(db_path):
return jsonify({

View File

@ -97,6 +97,15 @@ class SimulationState:
"error": self.error,
}
def get_default_platform(self) -> str:
"""根据启用状态返回默认平台"""
if self.enable_twitter and self.enable_reddit:
return "reddit" # 两者都启用时保持原默认
elif self.enable_twitter:
return "twitter"
else:
return "reddit"
def to_simple_dict(self) -> Dict[str, Any]:
"""简化状态字典API返回使用"""
return {
@ -478,12 +487,15 @@ class SimulationManager:
return simulations
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
def get_profiles(self, simulation_id: str, platform: str = None) -> List[Dict[str, Any]]:
"""获取模拟的Agent Profile"""
state = self._load_simulation_state(simulation_id)
if not state:
raise ValueError(f"模拟不存在: {simulation_id}")
if platform is None:
platform = state.get_default_platform()
sim_dir = self._get_simulation_dir(simulation_id)
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")

View File

@ -35,19 +35,21 @@ export const getSimulation = (simulationId) => {
/**
* 获取模拟的 Agent Profiles
* @param {string} simulationId
* @param {string} platform - 'reddit' | 'twitter'
* @param {string} [platform] - 'reddit' | 'twitter'省略时由后端根据模拟配置自动选择
*/
export const getSimulationProfiles = (simulationId, platform = 'reddit') => {
return service.get(`/api/simulation/${simulationId}/profiles`, { params: { platform } })
export const getSimulationProfiles = (simulationId, platform) => {
const params = platform ? { platform } : {}
return service.get(`/api/simulation/${simulationId}/profiles`, { params })
}
/**
* 实时获取生成中的 Agent Profiles
* @param {string} simulationId
* @param {string} platform - 'reddit' | 'twitter'
* @param {string} [platform] - 'reddit' | 'twitter'省略时由后端根据模拟配置自动选择
*/
export const getSimulationProfilesRealtime = (simulationId, platform = 'reddit') => {
return service.get(`/api/simulation/${simulationId}/profiles/realtime`, { params: { platform } })
export const getSimulationProfilesRealtime = (simulationId, platform) => {
const params = platform ? { platform } : {}
return service.get(`/api/simulation/${simulationId}/profiles/realtime`, { params })
}
/**
@ -111,14 +113,14 @@ export const getRunStatusDetail = (simulationId) => {
/**
* 获取模拟中的帖子
* @param {string} simulationId
* @param {string} platform - 'reddit' | 'twitter'
* @param {string} [platform] - 'reddit' | 'twitter'省略时由后端根据模拟配置自动选择
* @param {number} limit - 返回数量
* @param {number} offset - 偏移量
*/
export const getSimulationPosts = (simulationId, platform = 'reddit', limit = 50, offset = 0) => {
return service.get(`/api/simulation/${simulationId}/posts`, {
params: { platform, limit, offset }
})
export const getSimulationPosts = (simulationId, platform, limit = 50, offset = 0) => {
const params = { limit, offset }
if (platform) params.platform = platform
return service.get(`/api/simulation/${simulationId}/posts`, { params })
}
/**

View File

@ -912,7 +912,7 @@ const fetchProfilesRealtime = async () => {
if (!props.simulationId) return
try {
const res = await getSimulationProfilesRealtime(props.simulationId, 'reddit')
const res = await getSimulationProfilesRealtime(props.simulationId)
if (res.success && res.data) {
const prevCount = profiles.value.length

View File

@ -918,7 +918,7 @@ const loadProfiles = async () => {
if (!props.simulationId) return
try {
const res = await getSimulationProfilesRealtime(props.simulationId, 'reddit')
const res = await getSimulationProfilesRealtime(props.simulationId)
if (res.success && res.data) {
profiles.value = res.data.profiles || []
addLog(t('log.loadedProfiles', { count: profiles.value.length }))