""" OASIS Agent Profile generator. Converts entities from the Zep graph into the Agent Profile format required by the OASIS simulation platform. Improvements: 1. Call Zep retrieval to further enrich node information. 2. Optimized prompts that produce highly detailed personas. 3. Distinguishes individual entities from abstract group entities. """ import json import random import time from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from datetime import datetime from openai import OpenAI from .graphiti_adapter import GraphitiAdapter from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_language_instruction, get_locale, set_locale, t from .zep_entity_reader import EntityNode, ZepEntityReader logger = get_logger('mirofish.oasis_profile') @dataclass class OasisAgentProfile: """OASIS Agent Profile data structure.""" # Common fields user_id: int user_name: str name: str bio: str persona: str # Optional fields - Reddit style karma: int = 1000 # Optional fields - Twitter style friend_count: int = 100 follower_count: int = 150 statuses_count: int = 500 # Additional persona information age: Optional[int] = None gender: Optional[str] = None mbti: Optional[str] = None country: Optional[str] = None profession: Optional[str] = None interested_topics: List[str] = field(default_factory=list) # Source entity information source_entity_uuid: Optional[str] = None source_entity_type: Optional[str] = None created_at: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) def to_reddit_format(self) -> Dict[str, Any]: """Convert to Reddit platform format.""" profile = { "user_id": self.user_id, "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) "name": self.name, "bio": self.bio, "persona": self.persona, "karma": self.karma, "created_at": self.created_at, } if self.age: profile["age"] = self.age if self.gender: profile["gender"] = self.gender if self.mbti: profile["mbti"] = self.mbti if self.country: profile["country"] = self.country if self.profession: profile["profession"] = self.profession if self.interested_topics: profile["interested_topics"] = self.interested_topics return profile def to_twitter_format(self) -> Dict[str, Any]: """Convert to Twitter platform format.""" profile = { "user_id": self.user_id, "username": self.user_name, # OASIS 库要求字段名为 username(无下划线) "name": self.name, "bio": self.bio, "persona": self.persona, "friend_count": self.friend_count, "follower_count": self.follower_count, "statuses_count": self.statuses_count, "created_at": self.created_at, } if self.age: profile["age"] = self.age if self.gender: profile["gender"] = self.gender if self.mbti: profile["mbti"] = self.mbti if self.country: profile["country"] = self.country if self.profession: profile["profession"] = self.profession if self.interested_topics: profile["interested_topics"] = self.interested_topics return profile def to_dict(self) -> Dict[str, Any]: """Convert to a full dictionary representation.""" return { "user_id": self.user_id, "user_name": self.user_name, "name": self.name, "bio": self.bio, "persona": self.persona, "karma": self.karma, "friend_count": self.friend_count, "follower_count": self.follower_count, "statuses_count": self.statuses_count, "age": self.age, "gender": self.gender, "mbti": self.mbti, "country": self.country, "profession": self.profession, "interested_topics": self.interested_topics, "source_entity_uuid": self.source_entity_uuid, "source_entity_type": self.source_entity_type, "created_at": self.created_at, } class OasisProfileGenerator: """OASIS Profile generator. Converts entities from the Zep graph into the Agent Profiles required by the OASIS simulation. Highlights: 1. Uses Zep graph retrieval to gather richer context. 2. Produces highly detailed personas (basic info, career history, traits, social-media behavior, etc.). 3. Distinguishes individual entities from group/institution entities. """ MBTI_TYPES = [ "INTJ", "INTP", "ENTJ", "ENTP", "INFJ", "INFP", "ENFJ", "ENFP", "ISTJ", "ISFJ", "ESTJ", "ESFJ", "ISTP", "ISFP", "ESTP", "ESFP" ] COUNTRIES = [ "China", "US", "UK", "Japan", "Germany", "France", "Canada", "Australia", "Brazil", "India", "South Korea" ] # Individual entity types — generate a concrete persona for each. INDIVIDUAL_ENTITY_TYPES = [ "student", "alumni", "professor", "person", "publicfigure", "expert", "faculty", "official", "journalist", "activist" ] # Group / institution entity types — generate a representative-account persona. GROUP_ENTITY_TYPES = [ "university", "governmentagency", "organization", "ngo", "mediaoutlet", "company", "institution", "group", "community" ] def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model_name: Optional[str] = None, zep_api_key: Optional[str] = None, graph_id: Optional[str] = None ): self.api_key = api_key or Config.LLM_API_KEY self.base_url = base_url or Config.LLM_BASE_URL self.model_name = model_name or Config.LLM_MODEL_NAME if not self.api_key: raise ValueError("LLM_API_KEY 未配置") self.client = OpenAI( api_key=self.api_key, base_url=self.base_url ) self.zep_client = GraphitiAdapter() self.graph_id = graph_id def generate_profile_from_entity( self, entity: EntityNode, user_id: int, use_llm: bool = True ) -> OasisAgentProfile: """Generate an OASIS Agent Profile from a Zep entity. Args: entity: The Zep entity node. user_id: The OASIS user id to assign. use_llm: Whether to use the LLM to generate a detailed persona. Returns: OasisAgentProfile """ entity_type = entity.get_entity_type() or "Entity" name = entity.name user_name = self._generate_username(name) context = self._build_entity_context(entity) if use_llm: profile_data = self._generate_profile_with_llm( entity_name=name, entity_type=entity_type, entity_summary=entity.summary, entity_attributes=entity.attributes, context=context ) else: profile_data = self._generate_profile_rule_based( entity_name=name, entity_type=entity_type, entity_summary=entity.summary, entity_attributes=entity.attributes ) return OasisAgentProfile( user_id=user_id, user_name=user_name, name=name, bio=profile_data.get("bio", f"{entity_type}: {name}"), persona=profile_data.get("persona", entity.summary or f"A {entity_type} named {name}."), karma=profile_data.get("karma", random.randint(500, 5000)), friend_count=profile_data.get("friend_count", random.randint(50, 500)), follower_count=profile_data.get("follower_count", random.randint(100, 1000)), statuses_count=profile_data.get("statuses_count", random.randint(100, 2000)), age=profile_data.get("age"), gender=profile_data.get("gender"), mbti=profile_data.get("mbti"), country=profile_data.get("country"), profession=profile_data.get("profession"), interested_topics=profile_data.get("interested_topics", []), source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) def _generate_username(self, name: str) -> str: """Generate a username from an entity name.""" # Strip special characters and lowercase the name. username = name.lower().replace(" ", "_") username = ''.join(c for c in username if c.isalnum() or c == '_') # Append a random numeric suffix to avoid collisions. suffix = random.randint(100, 999) return f"{username}_{suffix}" def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]: """Use Zep hybrid graph search to gather rich context for an entity. Zep does not expose a built-in hybrid search endpoint, so we search edges and nodes separately and merge the results. The two searches run in parallel for throughput. Args: entity: The entity node to search around. Returns: A dict with keys ``facts``, ``node_summaries`` and ``context``. """ import concurrent.futures if not self.zep_client: return {"facts": [], "node_summaries": [], "context": ""} entity_name = entity.name results = { "facts": [], "node_summaries": [], "context": "" } # A graph_id is required for any retrieval. if not self.graph_id: logger.debug(t("log.profile_generator.m001")) return results comprehensive_query = t('progress.zepSearchQuery', name=entity_name) def search_edges(): """Search edges (facts / relationships) with retries.""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: return self.zep_client.graph.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, scope="edges", ) except Exception as e: last_exception = e if attempt < max_retries - 1: logger.debug(t("log.profile_generator.m002", attempt=attempt + 1, str=str(e)[:80])) time.sleep(delay) delay *= 2 else: logger.debug(t("log.profile_generator.m003", max_retries=max_retries, e=e)) return None def search_nodes(): """Search nodes (entity summaries) with retries.""" max_retries = 3 last_exception = None delay = 2.0 for attempt in range(max_retries): try: return self.zep_client.graph.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, scope="nodes", ) except Exception as e: last_exception = e if attempt < max_retries - 1: logger.debug(t("log.profile_generator.m004", attempt=attempt + 1, str=str(e)[:80])) time.sleep(delay) delay *= 2 else: logger.debug(t("log.profile_generator.m005", max_retries=max_retries, e=e)) return None try: # Run edge and node searches in parallel. with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: edge_future = executor.submit(search_edges) node_future = executor.submit(search_nodes) edge_result = edge_future.result(timeout=30) node_result = node_future.result(timeout=30) # Process edge-search results. all_facts = set() if edge_result and hasattr(edge_result, 'edges') and edge_result.edges: for edge in edge_result.edges: if hasattr(edge, 'fact') and edge.fact: all_facts.add(edge.fact) results["facts"] = list(all_facts) # Process node-search results. all_summaries = set() if node_result and hasattr(node_result, 'nodes') and node_result.nodes: for node in node_result.nodes: if hasattr(node, 'summary') and node.summary: all_summaries.add(node.summary) if hasattr(node, 'name') and node.name and node.name != entity_name: all_summaries.add(f"相关实体: {node.name}") results["node_summaries"] = list(all_summaries) # Assemble the combined context block. context_parts = [] if results["facts"]: context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20])) if results["node_summaries"]: context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10])) results["context"] = "\n\n".join(context_parts) logger.info(t("log.profile_generator.m006", entity_name=entity_name, len=len(results['facts']), len_2=len(results['node_summaries']))) except concurrent.futures.TimeoutError: logger.warning(t("log.profile_generator.m007", entity_name=entity_name)) except Exception as e: logger.warning(t("log.profile_generator.m008", entity_name=entity_name, e=e)) return results def _build_entity_context(self, entity: EntityNode) -> str: """Assemble the full context block for an entity. Includes: 1. The entity's own edge information (facts). 2. Detailed information about related nodes. 3. Additional context retrieved from Zep hybrid search. """ context_parts = [] # 1. Entity attributes. if entity.attributes: attrs = [] for key, value in entity.attributes.items(): if value and str(value).strip(): attrs.append(f"- {key}: {value}") if attrs: context_parts.append("### 实体属性\n" + "\n".join(attrs)) # 2. Related edges (facts / relationships). existing_facts = set() if entity.related_edges: relationships = [] for edge in entity.related_edges: # No cap on count. fact = edge.get("fact", "") edge_name = edge.get("edge_name", "") direction = edge.get("direction", "") if fact: relationships.append(f"- {fact}") existing_facts.add(fact) elif edge_name: if direction == "outgoing": relationships.append(f"- {entity.name} --[{edge_name}]--> (相关实体)") else: relationships.append(f"- (相关实体) --[{edge_name}]--> {entity.name}") if relationships: context_parts.append("### 相关事实和关系\n" + "\n".join(relationships)) # 3. Detailed information for related nodes. if entity.related_nodes: related_info = [] for node in entity.related_nodes: # No cap on count. node_name = node.get("name", "") node_labels = node.get("labels", []) node_summary = node.get("summary", "") # Drop the default labels added by the graph store. custom_labels = [l for l in node_labels if l not in ["Entity", "Node"]] label_str = f" ({', '.join(custom_labels)})" if custom_labels else "" if node_summary: related_info.append(f"- **{node_name}**{label_str}: {node_summary}") else: related_info.append(f"- **{node_name}**{label_str}") if related_info: context_parts.append("### 关联实体信息\n" + "\n".join(related_info)) # 4. Augment with Zep hybrid retrieval. zep_results = self._search_zep_for_entity(entity) if zep_results.get("facts"): # Deduplicate against already-known facts. new_facts = [f for f in zep_results["facts"] if f not in existing_facts] if new_facts: context_parts.append("### Zep检索到的事实信息\n" + "\n".join(f"- {f}" for f in new_facts[:15])) if zep_results.get("node_summaries"): context_parts.append("### Zep检索到的相关节点\n" + "\n".join(f"- {s}" for s in zep_results["node_summaries"][:10])) return "\n\n".join(context_parts) def _is_individual_entity(self, entity_type: str) -> bool: """Return True if the entity type represents an individual.""" return entity_type.lower() in self.INDIVIDUAL_ENTITY_TYPES def _is_group_entity(self, entity_type: str) -> bool: """Return True if the entity type represents a group or institution.""" return entity_type.lower() in self.GROUP_ENTITY_TYPES def _generate_profile_with_llm( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> Dict[str, Any]: """Generate a highly detailed persona using the LLM. Branches on entity type: - Individual entities: produces a concrete persona for a person. - Group / institution entities: produces a representative-account persona. """ is_individual = self._is_individual_entity(entity_type) if is_individual: prompt = self._build_individual_persona_prompt( entity_name, entity_type, entity_summary, entity_attributes, context ) else: prompt = self._build_group_persona_prompt( entity_name, entity_type, entity_summary, entity_attributes, context ) # Retry generation up to max_attempts times. max_attempts = 3 last_error = None for attempt in range(max_attempts): try: response = self.client.chat.completions.create( model=self.model_name, messages=[ {"role": "system", "content": self._get_system_prompt(is_individual)}, {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, temperature=0.7 - (attempt * 0.1) # Lower the temperature on each retry. # No max_tokens cap so the LLM can produce a full persona. ) content = response.choices[0].message.content # Detect truncation (finish_reason other than 'stop'). finish_reason = response.choices[0].finish_reason if finish_reason == 'length': logger.warning(t("log.profile_generator.m009", attempt=attempt + 1)) content = self._fix_truncated_json(content) # Parse the JSON payload. try: result = json.loads(content) # Backfill required fields when missing. if "bio" not in result or not result["bio"]: result["bio"] = entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}" if "persona" not in result or not result["persona"]: result["persona"] = entity_summary or f"{entity_name}是一个{entity_type}。" return result except json.JSONDecodeError as je: logger.warning(t("log.profile_generator.m010", attempt=attempt + 1, str=str(je)[:80])) # Attempt to repair the JSON. result = self._try_fix_json(content, entity_name, entity_type, entity_summary) if result.get("_fixed"): del result["_fixed"] return result last_error = je except Exception as e: logger.warning(t("log.profile_generator.m011", attempt=attempt + 1, str=str(e)[:80])) last_error = e import time time.sleep(1 * (attempt + 1)) # Exponential backoff. logger.warning(t("log.profile_generator.m012", max_attempts=max_attempts, last_error=last_error)) return self._generate_profile_rule_based( entity_name, entity_type, entity_summary, entity_attributes ) def _fix_truncated_json(self, content: str) -> str: """Repair JSON output truncated by a max_tokens limit.""" import re # Trim whitespace before closing the structure. content = content.strip() # Count unbalanced brackets and braces. open_braces = content.count('{') - content.count('}') open_brackets = content.count('[') - content.count(']') # Heuristic: if the last char is not a quote, comma, or closing bracket, # the trailing string value was likely truncated mid-token. if content and content[-1] not in '",}]': # Close the dangling string. content += '"' # Close any open brackets and braces. content += ']' * open_brackets content += '}' * open_braces return content def _try_fix_json(self, content: str, entity_name: str, entity_type: str, entity_summary: str = "") -> Dict[str, Any]: """Best-effort repair of damaged JSON output.""" import re # 1. Repair truncation first. content = self._fix_truncated_json(content) # 2. Extract the JSON object span. json_match = re.search(r'\{[\s\S]*\}', content) if json_match: json_str = json_match.group() # 3. Fix newlines inside string values. def fix_string_newlines(match): s = match.group(0) # Replace literal newlines inside string values with spaces. s = s.replace('\n', ' ').replace('\r', ' ') # Collapse runs of whitespace. s = re.sub(r'\s+', ' ', s) return s # Match JSON string values. json_str = re.sub(r'"[^"\\]*(?:\\.[^"\\]*)*"', fix_string_newlines, json_str) # 4. Try to parse. try: result = json.loads(json_str) result["_fixed"] = True return result except json.JSONDecodeError as e: # 5. Fall back to a more aggressive repair pass. try: # Strip control characters. json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str) # Collapse all consecutive whitespace. json_str = re.sub(r'\s+', ' ', json_str) result = json.loads(json_str) result["_fixed"] = True return result except: pass # 6. Last resort: scrape partial fields out of the content. bio_match = re.search(r'"bio"\s*:\s*"([^"]*)"', content) persona_match = re.search(r'"persona"\s*:\s*"([^"]*)', content) # May be truncated. bio = bio_match.group(1) if bio_match else (entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}") persona = persona_match.group(1) if persona_match else (entity_summary or f"{entity_name}是一个{entity_type}。") # If we recovered something meaningful, mark the result as fixed. if bio_match or persona_match: logger.info(t("log.profile_generator.m013")) return { "bio": bio, "persona": persona, "_fixed": True } # 7. Total failure: return a minimal fallback structure. logger.warning(t("log.profile_generator.m014")) return { "bio": entity_summary[:200] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name}是一个{entity_type}。" } def _get_system_prompt(self, is_individual: bool) -> str: """Return the system prompt for persona generation.""" base_prompt = "You are an expert in social-media user-persona generation. Produce detailed, realistic personas for opinion simulation that faithfully reflect existing real-world conditions. You MUST return valid JSON; no string value may contain unescaped newlines." return f"{base_prompt}\n\n{get_language_instruction()}" def _build_individual_persona_prompt( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> str: """Build the detailed persona prompt for an individual entity.""" attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" context_str = context[:3000] if context else "No additional context" return f"""Generate a detailed social-media user persona for the entity, faithfully reflecting existing real-world conditions. Entity name: {entity_name} Entity type: {entity_type} Entity summary: {entity_summary} Entity attributes: {attrs_str} Context information: {context_str} Generate JSON with the following fields: 1. bio: social-media biography, ~200 characters 2. persona: detailed persona description (~2000 characters of plain text), covering: - Basic information (age, profession, education, location) - Background (notable experience, association with the event, social ties) - Personality (MBTI type, core traits, emotional expression) - Social-media behavior (posting frequency, content preferences, interaction style, language traits) - Stance (attitudes toward the topic, content likely to anger or move them) - Unique features (catchphrases, special experiences, hobbies) - Personal memory (a key part of the persona: this individual's relation to the event and prior actions/reactions in it) 3. age: age number (MUST be an integer) 4. gender: gender, MUST be one of the English literals: "male" or "female" 5. mbti: MBTI type (e.g. INTJ, ENFP) 6. country: country name 7. profession: profession 8. interested_topics: array of interest topics Important: - All field values MUST be strings or numbers; do not use unescaped newlines. - persona MUST be a single coherent block of text. - {get_language_instruction()} (gender field MUST use the English values "male" or "female") - Content must remain consistent with the entity information. - age MUST be a valid integer; gender MUST be "male" or "female". """ def _build_group_persona_prompt( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any], context: str ) -> str: """Build the detailed persona prompt for a group or institution entity.""" attrs_str = json.dumps(entity_attributes, ensure_ascii=False) if entity_attributes else "None" context_str = context[:3000] if context else "No additional context" return f"""Generate a detailed social-media account profile for the institution/group entity, faithfully reflecting existing real-world conditions. Entity name: {entity_name} Entity type: {entity_type} Entity summary: {entity_summary} Entity attributes: {attrs_str} Context information: {context_str} Generate JSON with the following fields: 1. bio: official-account biography, ~200 characters, professional and appropriate 2. persona: detailed account-profile description (~2000 characters of plain text), covering: - Institutional basics (formal name, institution type, founding background, primary functions) - Account positioning (account type, target audience, core function) - Voice (language traits, common phrasing, taboo topics) - Publishing pattern (content types, publishing frequency, active hours) - Stance (official position on the core topic, controversy-handling style) - Special notes (the group portrait represented, operational habits) - Institutional memory (a key part of the account profile: this institution's relation to the event and prior actions/reactions in it) 3. age: fixed integer 30 (the institutional virtual age) 4. gender: fixed literal "other" (institutional accounts use "other" to indicate non-individual) 5. mbti: MBTI type used to characterize account voice (e.g. ISTJ for strict/conservative) 6. country: country name 7. profession: institutional function description 8. interested_topics: array of focus areas Important: - All field values MUST be strings or numbers; null values are not allowed. - persona MUST be a single coherent block of text without unescaped newlines. - {get_language_instruction()} (gender field MUST use the English value "other") - age MUST be the integer 30; gender MUST be the string "other". - Account voice MUST match the institution's identity positioning.""" def _generate_profile_rule_based( self, entity_name: str, entity_type: str, entity_summary: str, entity_attributes: Dict[str, Any] ) -> Dict[str, Any]: """Rule-based fallback that generates a basic persona.""" # Branch on entity type to pick a persona shape. entity_type_lower = entity_type.lower() if entity_type_lower in ["student", "alumni"]: return { "bio": f"{entity_type} with interests in academics and social issues.", "persona": f"{entity_name} is a {entity_type.lower()} who is actively engaged in academic and social discussions. They enjoy sharing perspectives and connecting with peers.", "age": random.randint(18, 30), "gender": random.choice(["male", "female"]), "mbti": random.choice(self.MBTI_TYPES), "country": random.choice(self.COUNTRIES), "profession": "Student", "interested_topics": ["Education", "Social Issues", "Technology"], } elif entity_type_lower in ["publicfigure", "expert", "faculty"]: return { "bio": f"Expert and thought leader in their field.", "persona": f"{entity_name} is a recognized {entity_type.lower()} who shares insights and opinions on important matters. They are known for their expertise and influence in public discourse.", "age": random.randint(35, 60), "gender": random.choice(["male", "female"]), "mbti": random.choice(["ENTJ", "INTJ", "ENTP", "INTP"]), "country": random.choice(self.COUNTRIES), "profession": entity_attributes.get("occupation", "Expert"), "interested_topics": ["Politics", "Economics", "Culture & Society"], } elif entity_type_lower in ["mediaoutlet", "socialmediaplatform"]: return { "bio": f"Official account for {entity_name}. News and updates.", "persona": f"{entity_name} is a media entity that reports news and facilitates public discourse. The account shares timely updates and engages with the audience on current events.", "age": 30, # 机构虚拟年龄 "gender": "other", # 机构使用other "mbti": "ISTJ", # 机构风格:严谨保守 "country": "中国", "profession": "Media", "interested_topics": ["General News", "Current Events", "Public Affairs"], } elif entity_type_lower in ["university", "governmentagency", "ngo", "organization"]: return { "bio": f"Official account of {entity_name}.", "persona": f"{entity_name} is an institutional entity that communicates official positions, announcements, and engages with stakeholders on relevant matters.", "age": 30, # 机构虚拟年龄 "gender": "other", # 机构使用other "mbti": "ISTJ", # 机构风格:严谨保守 "country": "中国", "profession": entity_type, "interested_topics": ["Public Policy", "Community", "Official Announcements"], } else: # Default persona for unrecognised entity types. return { "bio": entity_summary[:150] if entity_summary else f"{entity_type}: {entity_name}", "persona": entity_summary or f"{entity_name} is a {entity_type.lower()} participating in social discussions.", "age": random.randint(25, 50), "gender": random.choice(["male", "female"]), "mbti": random.choice(self.MBTI_TYPES), "country": random.choice(self.COUNTRIES), "profession": entity_type, "interested_topics": ["General", "Social Issues"], } def set_graph_id(self, graph_id: str): """Set the graph id used for Zep retrieval.""" self.graph_id = graph_id def generate_profiles_from_entities( self, entities: List[EntityNode], use_llm: bool = True, progress_callback: Optional[callable] = None, graph_id: Optional[str] = None, parallel_count: int = 5, realtime_output_path: Optional[str] = None, output_platform: str = "reddit" ) -> List[OasisAgentProfile]: """Batch-generate Agent Profiles from entities (in parallel). Args: entities: The entities to convert. use_llm: Whether to use the LLM to generate detailed personas. progress_callback: Progress callback ``(current, total, message)``. graph_id: Graph id used for Zep retrieval to gather richer context. parallel_count: Number of profiles to generate concurrently (default 5). realtime_output_path: If set, profiles are flushed to this path after each successful generation. output_platform: Output platform format, ``"reddit"`` or ``"twitter"``. Returns: The generated list of Agent Profiles. """ import concurrent.futures from threading import Lock # Set the graph id used for Zep retrieval. if graph_id: self.graph_id = graph_id total = len(entities) profiles = [None] * total # Preallocate to keep insertion order. completed_count = [0] # List wrapper so closures can mutate the count. lock = Lock() def save_profiles_realtime(): """Flush the profiles generated so far to ``realtime_output_path``.""" if not realtime_output_path: return with lock: existing_profiles = [p for p in profiles if p is not None] if not existing_profiles: return try: if output_platform == "reddit": # Reddit JSON format. profiles_data = [p.to_reddit_format() for p in existing_profiles] with open(realtime_output_path, 'w', encoding='utf-8') as f: json.dump(profiles_data, f, ensure_ascii=False, indent=2) else: # Twitter CSV format. import csv profiles_data = [p.to_twitter_format() for p in existing_profiles] if profiles_data: fieldnames = list(profiles_data[0].keys()) with open(realtime_output_path, 'w', encoding='utf-8', newline='') as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(profiles_data) except Exception as e: logger.warning(t("log.profile_generator.m015", e=e)) # Capture locale before spawning thread pool workers current_locale = get_locale() def generate_single_profile(idx: int, entity: EntityNode) -> tuple: """Worker function that generates a single profile.""" set_locale(current_locale) entity_type = entity.get_entity_type() or "Entity" try: profile = self.generate_profile_from_entity( entity=entity, user_id=idx, use_llm=use_llm ) # Stream the generated persona to the console and log. self._print_generated_profile(entity.name, entity_type, profile) return idx, profile, None except Exception as e: logger.error(t("log.profile_generator.m016", entity=entity.name, str=str(e))) # Build a minimal fallback profile. fallback_profile = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), name=entity.name, bio=f"{entity_type}: {entity.name}", persona=entity.summary or f"A participant in social discussions.", source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) return idx, fallback_profile, str(e) logger.info(t("log.profile_generator.m017", total=total, parallel_count=parallel_count)) print(f"\n{'='*60}") print(t("log.profile_generator.m024", total=total, parallel_count=parallel_count)) print(f"{'='*60}\n") # Run generation across a thread pool. with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_count) as executor: future_to_entity = { executor.submit(generate_single_profile, idx, entity): (idx, entity) for idx, entity in enumerate(entities) } for future in concurrent.futures.as_completed(future_to_entity): idx, entity = future_to_entity[future] entity_type = entity.get_entity_type() or "Entity" try: result_idx, profile, error = future.result() profiles[result_idx] = profile with lock: completed_count[0] += 1 current = completed_count[0] # Flush profiles to disk in real time. save_profiles_realtime() if progress_callback: progress_callback( current, total, f"已完成 {current}/{total}: {entity.name}({entity_type})" ) if error: logger.warning(t("log.profile_generator.m018", current=current, total=total, entity=entity.name, error=error)) else: logger.info(t("log.profile_generator.m019", current=current, total=total, entity=entity.name, entity_type=entity_type)) except Exception as e: logger.error(t("log.profile_generator.m020", entity=entity.name, str=str(e))) with lock: completed_count[0] += 1 profiles[idx] = OasisAgentProfile( user_id=idx, user_name=self._generate_username(entity.name), name=entity.name, bio=f"{entity_type}: {entity.name}", persona=entity.summary or "A participant in social discussions.", source_entity_uuid=entity.uuid, source_entity_type=entity_type, ) # Flush profiles to disk even when only the fallback was produced. save_profiles_realtime() print(f"\n{'='*60}") print(t("log.profile_generator.m025", count=len([p for p in profiles if p]))) print(f"{'='*60}\n") return profiles def _print_generated_profile(self, entity_name: str, entity_type: str, profile: OasisAgentProfile): """Stream the generated persona to the console (full content, untruncated).""" separator = "-" * 70 # Assemble the full output (no truncation). topics_str = ', '.join(profile.interested_topics) if profile.interested_topics else '无' output_lines = [ f"\n{separator}", t('progress.profileGenerated', name=entity_name, type=entity_type), f"{separator}", f"用户名: {profile.user_name}", f"", f"【简介】", f"{profile.bio}", f"", f"【详细人设】", f"{profile.persona}", f"", f"【基本属性】", f"年龄: {profile.age} | 性别: {profile.gender} | MBTI: {profile.mbti}", f"职业: {profile.profession} | 国家: {profile.country}", f"兴趣话题: {topics_str}", separator ] output = "\n".join(output_lines) # Print to the console only — the logger no longer emits the full content # to avoid duplicate output. print(output) def save_profiles( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): """Save profiles to a file using the platform-specific format. OASIS format requirements: - Twitter: CSV format. - Reddit: JSON format. Args: profiles: The profiles to save. file_path: Destination file path. platform: Platform type, ``"reddit"`` or ``"twitter"``. """ if platform == "twitter": self._save_twitter_csv(profiles, file_path) else: self._save_reddit_json(profiles, file_path) def _save_twitter_csv(self, profiles: List[OasisAgentProfile], file_path: str): """Save Twitter profiles as CSV (matches OASIS's official format). Required CSV fields for OASIS Twitter: - user_id: User id (zero-indexed by CSV row order). - name: User's real-world display name. - username: System username. - user_char: Detailed persona text injected into the LLM system prompt to drive agent behavior. - description: Short public bio shown on the profile page. ``user_char`` vs ``description``: - user_char: Internal — LLM system prompt that controls how the agent thinks and acts. - description: External — short bio visible to other users. """ import csv # Ensure the file extension is .csv. if not file_path.endswith('.csv'): file_path = file_path.replace('.json', '.csv') with open(file_path, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) # Write the OASIS-required header row. headers = ['user_id', 'name', 'username', 'user_char', 'description'] writer.writerow(headers) for idx, profile in enumerate(profiles): # user_char: full persona (bio + persona), used in the LLM system prompt. user_char = profile.bio if profile.persona and profile.persona != profile.bio: user_char = f"{profile.bio} {profile.persona}" # Replace newlines with spaces for CSV compatibility. user_char = user_char.replace('\n', ' ').replace('\r', ' ') # description: short bio used for external display. description = profile.bio.replace('\n', ' ').replace('\r', ' ') row = [ idx, # user_id: zero-based sequential id profile.name, # name: real-world display name profile.user_name, # username: system username user_char, # user_char: full persona (internal LLM use) description # description: short bio (external display) ] writer.writerow(row) logger.info(t("log.profile_generator.m021", len=len(profiles), file_path=file_path)) def _normalize_gender(self, gender: Optional[str]) -> str: """Normalize the gender field into the English form required by OASIS. OASIS requires one of: ``male``, ``female``, ``other``. """ if not gender: return "other" gender_lower = gender.lower().strip() # Mapping from Chinese values to the English literals. gender_map = { "男": "male", "女": "female", "机构": "other", "其他": "other", # Already in English — pass through. "male": "male", "female": "female", "other": "other", } return gender_map.get(gender_lower, "other") def _save_reddit_json(self, profiles: List[OasisAgentProfile], file_path: str): """Save Reddit profiles as JSON. Uses the same shape as ``to_reddit_format()`` to ensure OASIS can read the file. The ``user_id`` field is mandatory — it is what ``agent_graph.get_agent()`` matches against. Required fields: - user_id: User id (integer; matches ``poster_agent_id`` in ``initial_posts``). - username: System username. - name: Display name. - bio: Short bio. - persona: Detailed persona. - age: Age (integer). - gender: One of ``"male"``, ``"female"``, ``"other"``. - mbti: MBTI type. - country: Country. """ data = [] for idx, profile in enumerate(profiles): # Match the shape of to_reddit_format(). item = { "user_id": profile.user_id if profile.user_id is not None else idx, # Critical: must include user_id. "username": profile.user_name, "name": profile.name, "bio": profile.bio[:150] if profile.bio else f"{profile.name}", "persona": profile.persona or f"{profile.name} is a participant in social discussions.", "karma": profile.karma if profile.karma else 1000, "created_at": profile.created_at, # OASIS-required fields — make sure each has a default. "age": profile.age if profile.age else 30, "gender": self._normalize_gender(profile.gender), "mbti": profile.mbti if profile.mbti else "ISTJ", "country": profile.country if profile.country else "中国", } # Optional fields. if profile.profession: item["profession"] = profile.profession if profile.interested_topics: item["interested_topics"] = profile.interested_topics data.append(item) with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) logger.info(t("log.profile_generator.m022", len=len(profiles), file_path=file_path)) # Retained as an alias for the old method name (backwards compatibility). def save_profiles_to_json( self, profiles: List[OasisAgentProfile], file_path: str, platform: str = "reddit" ): """[Deprecated] Use ``save_profiles()`` instead.""" logger.warning(t("log.profile_generator.m023")) self.save_profiles(profiles, file_path, platform)