Merge 18ba979c8d into 96096ea0ff
This commit is contained in:
commit
fb0b10429f
|
|
@ -20,6 +20,29 @@ from ..models.project import ProjectManager
|
|||
logger = get_logger('mirofish.api.simulation')
|
||||
|
||||
|
||||
def _get_default_platform(simulation_id: str) -> str:
|
||||
"""
|
||||
根据模拟配置返回默认平台
|
||||
|
||||
读取 SimulationState 中的 enable_twitter / enable_reddit 设置,
|
||||
返回该模拟实际使用的平台,而非硬编码 'reddit'。
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
'twitter' 或 'reddit'
|
||||
"""
|
||||
try:
|
||||
manager = SimulationManager()
|
||||
state = manager._load_simulation_state(simulation_id)
|
||||
if state:
|
||||
return state.get_default_platform()
|
||||
except Exception:
|
||||
pass
|
||||
return "reddit"
|
||||
|
||||
|
||||
# Interview prompt 优化前缀
|
||||
# 添加此前缀可以避免Agent调用工具,直接用文本回复
|
||||
INTERVIEW_PROMPT_PREFIX = "结合你的人设、所有的过往记忆与行动,不调用任何工具直接用文本回复我:"
|
||||
|
|
@ -996,8 +1019,8 @@ def get_simulation_profiles(simulation_id: str):
|
|||
platform: 平台类型(reddit/twitter,默认reddit)
|
||||
"""
|
||||
try:
|
||||
platform = request.args.get('platform', 'reddit')
|
||||
|
||||
platform = request.args.get('platform') or _get_default_platform(simulation_id)
|
||||
|
||||
manager = SimulationManager()
|
||||
profiles = manager.get_profiles(simulation_id, platform=platform)
|
||||
|
||||
|
|
@ -1058,8 +1081,8 @@ def get_simulation_profiles_realtime(simulation_id: str):
|
|||
from datetime import datetime
|
||||
|
||||
try:
|
||||
platform = request.args.get('platform', 'reddit')
|
||||
|
||||
platform = request.args.get('platform') or _get_default_platform(simulation_id)
|
||||
|
||||
# 获取模拟目录
|
||||
sim_dir = os.path.join(Config.OASIS_SIMULATION_DATA_DIR, simulation_id)
|
||||
|
||||
|
|
@ -1997,15 +2020,15 @@ def get_simulation_posts(simulation_id: str):
|
|||
返回帖子列表(从SQLite数据库读取)
|
||||
"""
|
||||
try:
|
||||
platform = request.args.get('platform', 'reddit')
|
||||
platform = request.args.get('platform') or _get_default_platform(simulation_id)
|
||||
limit = request.args.get('limit', 50, type=int)
|
||||
offset = request.args.get('offset', 0, type=int)
|
||||
|
||||
|
||||
sim_dir = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
f'../../uploads/simulations/{simulation_id}'
|
||||
)
|
||||
|
||||
|
||||
db_file = f"{platform}_simulation.db"
|
||||
db_path = os.path.join(sim_dir, db_file)
|
||||
|
||||
|
|
@ -2065,24 +2088,26 @@ def get_simulation_posts(simulation_id: str):
|
|||
@simulation_bp.route('/<simulation_id>/comments', methods=['GET'])
|
||||
def get_simulation_comments(simulation_id: str):
|
||||
"""
|
||||
获取模拟中的评论(仅Reddit)
|
||||
|
||||
获取模拟中的评论
|
||||
|
||||
Query参数:
|
||||
platform: 平台类型(twitter/reddit,根据模拟配置自动选择)
|
||||
post_id: 过滤帖子ID(可选)
|
||||
limit: 返回数量
|
||||
offset: 偏移量
|
||||
"""
|
||||
try:
|
||||
platform = request.args.get('platform') or _get_default_platform(simulation_id)
|
||||
post_id = request.args.get('post_id')
|
||||
limit = request.args.get('limit', 50, type=int)
|
||||
offset = request.args.get('offset', 0, type=int)
|
||||
|
||||
|
||||
sim_dir = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
f'../../uploads/simulations/{simulation_id}'
|
||||
)
|
||||
|
||||
db_path = os.path.join(sim_dir, "reddit_simulation.db")
|
||||
db_path = os.path.join(sim_dir, f"{platform}_simulation.db")
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
return jsonify({
|
||||
|
|
|
|||
|
|
@ -97,6 +97,15 @@ class SimulationState:
|
|||
"error": self.error,
|
||||
}
|
||||
|
||||
def get_default_platform(self) -> str:
|
||||
"""根据启用状态返回默认平台"""
|
||||
if self.enable_twitter and self.enable_reddit:
|
||||
return "reddit" # 两者都启用时保持原默认
|
||||
elif self.enable_twitter:
|
||||
return "twitter"
|
||||
else:
|
||||
return "reddit"
|
||||
|
||||
def to_simple_dict(self) -> Dict[str, Any]:
|
||||
"""简化状态字典(API返回使用)"""
|
||||
return {
|
||||
|
|
@ -478,12 +487,15 @@ class SimulationManager:
|
|||
|
||||
return simulations
|
||||
|
||||
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
||||
def get_profiles(self, simulation_id: str, platform: str = None) -> List[Dict[str, Any]]:
|
||||
"""获取模拟的Agent Profile"""
|
||||
state = self._load_simulation_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
|
||||
if platform is None:
|
||||
platform = state.get_default_platform()
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
|
||||
|
||||
|
|
|
|||
|
|
@ -35,19 +35,21 @@ export const getSimulation = (simulationId) => {
|
|||
/**
|
||||
* 获取模拟的 Agent Profiles
|
||||
* @param {string} simulationId
|
||||
* @param {string} platform - 'reddit' | 'twitter'
|
||||
* @param {string} [platform] - 'reddit' | 'twitter'(省略时由后端根据模拟配置自动选择)
|
||||
*/
|
||||
export const getSimulationProfiles = (simulationId, platform = 'reddit') => {
|
||||
return service.get(`/api/simulation/${simulationId}/profiles`, { params: { platform } })
|
||||
export const getSimulationProfiles = (simulationId, platform) => {
|
||||
const params = platform ? { platform } : {}
|
||||
return service.get(`/api/simulation/${simulationId}/profiles`, { params })
|
||||
}
|
||||
|
||||
/**
|
||||
* 实时获取生成中的 Agent Profiles
|
||||
* @param {string} simulationId
|
||||
* @param {string} platform - 'reddit' | 'twitter'
|
||||
* @param {string} [platform] - 'reddit' | 'twitter'(省略时由后端根据模拟配置自动选择)
|
||||
*/
|
||||
export const getSimulationProfilesRealtime = (simulationId, platform = 'reddit') => {
|
||||
return service.get(`/api/simulation/${simulationId}/profiles/realtime`, { params: { platform } })
|
||||
export const getSimulationProfilesRealtime = (simulationId, platform) => {
|
||||
const params = platform ? { platform } : {}
|
||||
return service.get(`/api/simulation/${simulationId}/profiles/realtime`, { params })
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -111,14 +113,14 @@ export const getRunStatusDetail = (simulationId) => {
|
|||
/**
|
||||
* 获取模拟中的帖子
|
||||
* @param {string} simulationId
|
||||
* @param {string} platform - 'reddit' | 'twitter'
|
||||
* @param {string} [platform] - 'reddit' | 'twitter'(省略时由后端根据模拟配置自动选择)
|
||||
* @param {number} limit - 返回数量
|
||||
* @param {number} offset - 偏移量
|
||||
*/
|
||||
export const getSimulationPosts = (simulationId, platform = 'reddit', limit = 50, offset = 0) => {
|
||||
return service.get(`/api/simulation/${simulationId}/posts`, {
|
||||
params: { platform, limit, offset }
|
||||
})
|
||||
export const getSimulationPosts = (simulationId, platform, limit = 50, offset = 0) => {
|
||||
const params = { limit, offset }
|
||||
if (platform) params.platform = platform
|
||||
return service.get(`/api/simulation/${simulationId}/posts`, { params })
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -912,7 +912,7 @@ const fetchProfilesRealtime = async () => {
|
|||
if (!props.simulationId) return
|
||||
|
||||
try {
|
||||
const res = await getSimulationProfilesRealtime(props.simulationId, 'reddit')
|
||||
const res = await getSimulationProfilesRealtime(props.simulationId)
|
||||
|
||||
if (res.success && res.data) {
|
||||
const prevCount = profiles.value.length
|
||||
|
|
|
|||
|
|
@ -918,7 +918,7 @@ const loadProfiles = async () => {
|
|||
if (!props.simulationId) return
|
||||
|
||||
try {
|
||||
const res = await getSimulationProfilesRealtime(props.simulationId, 'reddit')
|
||||
const res = await getSimulationProfilesRealtime(props.simulationId)
|
||||
if (res.success && res.data) {
|
||||
profiles.value = res.data.profiles || []
|
||||
addLog(t('log.loadedProfiles', { count: profiles.value.length }))
|
||||
|
|
|
|||
Loading…
Reference in New Issue