181 lines
5.3 KiB
Python
181 lines
5.3 KiB
Python
"""
|
||
OASIS 数据模型
|
||
|
||
包含 OasisAgentProfile 数据类和相关的辅助函数。
|
||
这些组件不依赖 Flask,可以独立使用。
|
||
"""
|
||
|
||
import random
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from typing import Dict, Any, List, Optional
|
||
|
||
|
||
# Entity type classification constants
|
||
INDIVIDUAL_ENTITY_TYPES = [
|
||
"student", "alumni", "professor", "person", "publicfigure",
|
||
"expert", "faculty", "official", "journalist", "activist"
|
||
]
|
||
|
||
GROUP_ENTITY_TYPES = [
|
||
"university", "governmentagency", "organization", "ngo",
|
||
"mediaoutlet", "company", "institution", "group", "community"
|
||
]
|
||
|
||
|
||
@dataclass
|
||
class OasisAgentProfile:
|
||
"""OASIS Agent Profile数据结构"""
|
||
# 通用字段
|
||
user_id: int
|
||
user_name: str
|
||
name: str
|
||
bio: str
|
||
persona: str
|
||
|
||
# 可选字段 - Reddit风格
|
||
karma: int = 1000
|
||
|
||
# 可选字段 - Twitter风格
|
||
friend_count: int = 100
|
||
follower_count: int = 150
|
||
statuses_count: int = 500
|
||
|
||
# 额外人设信息
|
||
age: Optional[int] = None
|
||
gender: Optional[str] = None
|
||
mbti: Optional[str] = None
|
||
country: Optional[str] = None
|
||
profession: Optional[str] = None
|
||
interested_topics: List[str] = field(default_factory=list)
|
||
|
||
# 来源实体信息
|
||
source_entity_uuid: Optional[str] = None
|
||
source_entity_type: Optional[str] = None
|
||
|
||
created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
|
||
|
||
def to_reddit_format(self) -> Dict[str, Any]:
|
||
"""转换为Reddit平台格式"""
|
||
profile = {
|
||
"user_id": self.user_id,
|
||
"username": self.user_name, # OASIS 库要求字段名为 username(无下划线)
|
||
"name": self.name,
|
||
"bio": self.bio,
|
||
"persona": self.persona,
|
||
"karma": self.karma,
|
||
"created_at": self.created_at,
|
||
}
|
||
|
||
# 添加额外人设信息(如果有)
|
||
if self.age:
|
||
profile["age"] = self.age
|
||
if self.gender:
|
||
profile["gender"] = self.gender
|
||
if self.mbti:
|
||
profile["mbti"] = self.mbti
|
||
if self.country:
|
||
profile["country"] = self.country
|
||
if self.profession:
|
||
profile["profession"] = self.profession
|
||
if self.interested_topics:
|
||
profile["interested_topics"] = self.interested_topics
|
||
|
||
return profile
|
||
|
||
def to_twitter_format(self) -> Dict[str, Any]:
|
||
"""转换为Twitter平台格式"""
|
||
profile = {
|
||
"user_id": self.user_id,
|
||
"username": self.user_name, # OASIS 库要求字段名为 username(无下划线)
|
||
"name": self.name,
|
||
"bio": self.bio,
|
||
"persona": self.persona,
|
||
"friend_count": self.friend_count,
|
||
"follower_count": self.follower_count,
|
||
"statuses_count": self.statuses_count,
|
||
"created_at": self.created_at,
|
||
}
|
||
|
||
# 添加额外人设信息
|
||
if self.age:
|
||
profile["age"] = self.age
|
||
if self.gender:
|
||
profile["gender"] = self.gender
|
||
if self.mbti:
|
||
profile["mbti"] = self.mbti
|
||
if self.country:
|
||
profile["country"] = self.country
|
||
if self.profession:
|
||
profile["profession"] = self.profession
|
||
if self.interested_topics:
|
||
profile["interested_topics"] = self.interested_topics
|
||
|
||
return profile
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
"""转换为完整字典格式"""
|
||
return {
|
||
"user_id": self.user_id,
|
||
"user_name": self.user_name,
|
||
"name": self.name,
|
||
"bio": self.bio,
|
||
"persona": self.persona,
|
||
"karma": self.karma,
|
||
"friend_count": self.friend_count,
|
||
"follower_count": self.follower_count,
|
||
"statuses_count": self.statuses_count,
|
||
"age": self.age,
|
||
"gender": self.gender,
|
||
"mbti": self.mbti,
|
||
"country": self.country,
|
||
"profession": self.profession,
|
||
"interested_topics": self.interested_topics,
|
||
"source_entity_uuid": self.source_entity_uuid,
|
||
"source_entity_type": self.source_entity_type,
|
||
"created_at": self.created_at,
|
||
}
|
||
|
||
|
||
def _generate_username(name: str) -> str:
|
||
"""生成用户名"""
|
||
username = name.lower().replace(" ", "_")
|
||
username = ''.join(c for c in username if c.isalnum() or c == '_')
|
||
suffix = random.randint(100, 999)
|
||
return f"{username}_{suffix}"
|
||
|
||
|
||
def _normalize_gender(gender: Optional[str]) -> str:
|
||
"""
|
||
标准化gender字段为OASIS要求的英文格式
|
||
|
||
OASIS要求: male, female, other
|
||
"""
|
||
if not gender:
|
||
return "other"
|
||
|
||
gender_lower = gender.lower().strip()
|
||
|
||
# 中文映射
|
||
gender_map = {
|
||
"男": "male",
|
||
"女": "female",
|
||
"机构": "other",
|
||
"其他": "other",
|
||
# 英文已有
|
||
"male": "male",
|
||
"female": "female",
|
||
"other": "other",
|
||
}
|
||
|
||
return gender_map.get(gender_lower, "other")
|
||
|
||
|
||
def _is_individual_entity(entity_type: str) -> bool:
|
||
"""判断是否是个人类型实体"""
|
||
return entity_type.lower() in INDIVIDUAL_ENTITY_TYPES
|
||
|
||
|
||
def _is_group_entity(entity_type: str) -> bool:
|
||
"""判断是否是群体/机构类型实体"""
|
||
return entity_type.lower() in GROUP_ENTITY_TYPES |