MicroFish/backend/app/services/zep_tools.py

1730 lines
66 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Zep retrieval tool service.
Wraps graph search, node reads, and edge queries as tools consumed by the Report Agent.
Core retrieval tools (optimized):
1. InsightForge (deep insight search) - the most powerful hybrid retrieval; auto-decomposes the
query into sub-questions and searches across multiple dimensions.
2. PanoramaSearch (breadth search) - returns the full picture including expired content.
3. QuickSearch (simple search) - lightweight, fast retrieval.
"""
import time
import json
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from .graphiti_adapter import GraphitiAdapter
from ..config import Config
from ..utils.logger import get_logger
from ..utils.llm_client import LLMClient
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from ..utils.locale import t
logger = get_logger('mirofish.zep_tools')
@dataclass
class SearchResult:
"""Search result."""
facts: List[str]
edges: List[Dict[str, Any]]
nodes: List[Dict[str, Any]]
query: str
total_count: int
def to_dict(self) -> Dict[str, Any]:
return {
"facts": self.facts,
"edges": self.edges,
"nodes": self.nodes,
"query": self.query,
"total_count": self.total_count
}
def to_text(self) -> str:
"""Render to text format for LLM consumption."""
text_parts = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"]
if self.facts:
text_parts.append("\n### 相关事实:")
for i, fact in enumerate(self.facts, 1):
text_parts.append(f"{i}. {fact}")
return "\n".join(text_parts)
@dataclass
class NodeInfo:
"""Node information."""
uuid: str
name: str
labels: List[str]
summary: str
attributes: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
return {
"uuid": self.uuid,
"name": self.name,
"labels": self.labels,
"summary": self.summary,
"attributes": self.attributes
}
def to_text(self) -> str:
"""Render to text format."""
entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型")
return f"实体: {self.name} (类型: {entity_type})\n摘要: {self.summary}"
@dataclass
class EdgeInfo:
"""Edge information."""
uuid: str
name: str
fact: str
source_node_uuid: str
target_node_uuid: str
source_node_name: Optional[str] = None
target_node_name: Optional[str] = None
created_at: Optional[str] = None
valid_at: Optional[str] = None
invalid_at: Optional[str] = None
expired_at: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"uuid": self.uuid,
"name": self.name,
"fact": self.fact,
"source_node_uuid": self.source_node_uuid,
"target_node_uuid": self.target_node_uuid,
"source_node_name": self.source_node_name,
"target_node_name": self.target_node_name,
"created_at": self.created_at,
"valid_at": self.valid_at,
"invalid_at": self.invalid_at,
"expired_at": self.expired_at
}
def to_text(self, include_temporal: bool = False) -> str:
"""Render to text format."""
source = self.source_node_name or self.source_node_uuid[:8]
target = self.target_node_name or self.target_node_uuid[:8]
base_text = f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}"
if include_temporal:
valid_at = self.valid_at or "未知"
invalid_at = self.invalid_at or "至今"
base_text += f"\n时效: {valid_at} - {invalid_at}"
if self.expired_at:
base_text += f" (已过期: {self.expired_at})"
return base_text
@property
def is_expired(self) -> bool:
"""Whether this edge has expired."""
return self.expired_at is not None
@property
def is_invalid(self) -> bool:
"""Whether this edge has been invalidated."""
return self.invalid_at is not None
@dataclass
class InsightForgeResult:
"""Deep-insight retrieval result (InsightForge).
Holds the retrieval results from multiple sub-queries plus the synthesized analysis.
"""
query: str
simulation_requirement: str
sub_queries: List[str]
# Retrieval results across multiple dimensions.
semantic_facts: List[str] = field(default_factory=list)
entity_insights: List[Dict[str, Any]] = field(default_factory=list)
relationship_chains: List[str] = field(default_factory=list)
total_facts: int = 0
total_entities: int = 0
total_relationships: int = 0
def to_dict(self) -> Dict[str, Any]:
return {
"query": self.query,
"simulation_requirement": self.simulation_requirement,
"sub_queries": self.sub_queries,
"semantic_facts": self.semantic_facts,
"entity_insights": self.entity_insights,
"relationship_chains": self.relationship_chains,
"total_facts": self.total_facts,
"total_entities": self.total_entities,
"total_relationships": self.total_relationships
}
def to_text(self) -> str:
"""Render a detailed text representation for the LLM."""
text_parts = [
f"## 未来预测深度分析",
f"分析问题: {self.query}",
f"预测场景: {self.simulation_requirement}",
f"\n### 预测数据统计",
f"- 相关预测事实: {self.total_facts}",
f"- 涉及实体: {self.total_entities}",
f"- 关系链: {self.total_relationships}"
]
if self.sub_queries:
text_parts.append(f"\n### 分析的子问题")
for i, sq in enumerate(self.sub_queries, 1):
text_parts.append(f"{i}. {sq}")
if self.semantic_facts:
text_parts.append(f"\n### 【关键事实】(请在报告中引用这些原文)")
for i, fact in enumerate(self.semantic_facts, 1):
text_parts.append(f"{i}. \"{fact}\"")
if self.entity_insights:
text_parts.append(f"\n### 【核心实体】")
for entity in self.entity_insights:
text_parts.append(f"- **{entity.get('name', '未知')}** ({entity.get('type', '实体')})")
if entity.get('summary'):
text_parts.append(f" 摘要: \"{entity.get('summary')}\"")
if entity.get('related_facts'):
text_parts.append(f" 相关事实: {len(entity.get('related_facts', []))}")
if self.relationship_chains:
text_parts.append(f"\n### 【关系链】")
for chain in self.relationship_chains:
text_parts.append(f"- {chain}")
return "\n".join(text_parts)
@dataclass
class PanoramaResult:
"""Breadth-search result (Panorama).
Contains every piece of related information, including expired content.
"""
query: str
all_nodes: List[NodeInfo] = field(default_factory=list)
# All edges, including expired ones.
all_edges: List[EdgeInfo] = field(default_factory=list)
# Currently active facts.
active_facts: List[str] = field(default_factory=list)
# Expired or invalidated facts (historical record).
historical_facts: List[str] = field(default_factory=list)
total_nodes: int = 0
total_edges: int = 0
active_count: int = 0
historical_count: int = 0
def to_dict(self) -> Dict[str, Any]:
return {
"query": self.query,
"all_nodes": [n.to_dict() for n in self.all_nodes],
"all_edges": [e.to_dict() for e in self.all_edges],
"active_facts": self.active_facts,
"historical_facts": self.historical_facts,
"total_nodes": self.total_nodes,
"total_edges": self.total_edges,
"active_count": self.active_count,
"historical_count": self.historical_count
}
def to_text(self) -> str:
"""Render the full text format (no truncation)."""
text_parts = [
f"## 广度搜索结果(未来全景视图)",
f"查询: {self.query}",
f"\n### 统计信息",
f"- 总节点数: {self.total_nodes}",
f"- 总边数: {self.total_edges}",
f"- 当前有效事实: {self.active_count}",
f"- 历史/过期事实: {self.historical_count}"
]
# Currently active facts (emit in full, no truncation).
if self.active_facts:
text_parts.append(f"\n### 【当前有效事实】(模拟结果原文)")
for i, fact in enumerate(self.active_facts, 1):
text_parts.append(f"{i}. \"{fact}\"")
# Historical / expired facts (emit in full, no truncation).
if self.historical_facts:
text_parts.append(f"\n### 【历史/过期事实】(演变过程记录)")
for i, fact in enumerate(self.historical_facts, 1):
text_parts.append(f"{i}. \"{fact}\"")
# Key entities (emit in full, no truncation).
if self.all_nodes:
text_parts.append(f"\n### 【涉及实体】")
for node in self.all_nodes:
entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体")
text_parts.append(f"- **{node.name}** ({entity_type})")
return "\n".join(text_parts)
@dataclass
class AgentInterview:
"""Interview result for a single agent."""
agent_name: str
agent_role: str
agent_bio: str
question: str
response: str
key_quotes: List[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
return {
"agent_name": self.agent_name,
"agent_role": self.agent_role,
"agent_bio": self.agent_bio,
"question": self.question,
"response": self.response,
"key_quotes": self.key_quotes
}
def to_text(self) -> str:
text = f"**{self.agent_name}** ({self.agent_role})\n"
# Render the full agent_bio without truncation.
text += f"_简介: {self.agent_bio}_\n\n"
text += f"**Q:** {self.question}\n\n"
text += f"**A:** {self.response}\n"
if self.key_quotes:
text += "\n**关键引言:**\n"
for quote in self.key_quotes:
# Strip the various quote characters (curly quotes and CJK corner brackets).
clean_quote = quote.replace('\u201c', '').replace('\u201d', '').replace('"', '')
clean_quote = clean_quote.replace('\u300c', '').replace('\u300d', '')
clean_quote = clean_quote.strip()
# Drop any leading punctuation.
while clean_quote and clean_quote[0] in ',;:、。!?\n\r\t ':
clean_quote = clean_quote[1:]
# Skip junk content that contains a question-number label (e.g. labels 1-9).
skip = False
for d in '123456789':
if f'\u95ee\u9898{d}' in clean_quote:
skip = True
break
if skip:
continue
# Truncate over-long quotes at the next period rather than a hard cut.
if len(clean_quote) > 150:
dot_pos = clean_quote.find('\u3002', 80)
if dot_pos > 0:
clean_quote = clean_quote[:dot_pos + 1]
else:
clean_quote = clean_quote[:147] + "..."
if clean_quote and len(clean_quote) >= 10:
text += f'> "{clean_quote}"\n'
return text
@dataclass
class InterviewResult:
"""Interview result.
Aggregates the responses from multiple simulated agents.
"""
interview_topic: str
interview_questions: List[str]
# Agents chosen for the interview.
selected_agents: List[Dict[str, Any]] = field(default_factory=list)
# Per-agent interview responses.
interviews: List[AgentInterview] = field(default_factory=list)
# Reasoning for the agent selection.
selection_reasoning: str = ""
# Synthesized interview summary.
summary: str = ""
total_agents: int = 0
interviewed_count: int = 0
def to_dict(self) -> Dict[str, Any]:
return {
"interview_topic": self.interview_topic,
"interview_questions": self.interview_questions,
"selected_agents": self.selected_agents,
"interviews": [i.to_dict() for i in self.interviews],
"selection_reasoning": self.selection_reasoning,
"summary": self.summary,
"total_agents": self.total_agents,
"interviewed_count": self.interviewed_count
}
def to_text(self) -> str:
"""Render a detailed text representation for the LLM and report citations."""
text_parts = [
"## 深度采访报告",
f"**采访主题:** {self.interview_topic}",
f"**采访人数:** {self.interviewed_count} / {self.total_agents} 位模拟Agent",
"\n### 采访对象选择理由",
self.selection_reasoning or "(自动选择)",
"\n---",
"\n### 采访实录",
]
if self.interviews:
for i, interview in enumerate(self.interviews, 1):
text_parts.append(f"\n#### 采访 #{i}: {interview.agent_name}")
text_parts.append(interview.to_text())
text_parts.append("\n---")
else:
text_parts.append("(无采访记录)\n\n---")
text_parts.append("\n### 采访摘要与核心观点")
text_parts.append(self.summary or "(无摘要)")
return "\n".join(text_parts)
class ZepToolsService:
"""Zep retrieval tool service.
Core retrieval tools (optimized):
1. insight_forge - deep-insight search (most powerful; auto-generates sub-questions
and searches across multiple dimensions).
2. panorama_search - breadth search (full picture including expired content).
3. quick_search - simple, fast retrieval.
4. interview_agents - deep interview (interviews simulated agents and gathers
perspectives from multiple roles).
Basic tools:
- search_graph - semantic graph search.
- get_all_nodes - fetch every node in the graph.
- get_all_edges - fetch every edge in the graph (with temporal info).
- get_node_detail - fetch a single node's details.
- get_node_edges - fetch edges incident to a node.
- get_entities_by_type - fetch entities filtered by type.
- get_entity_summary - fetch a relationship summary for an entity.
"""
# Retry configuration.
MAX_RETRIES = 3
RETRY_DELAY = 2.0
def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None):
self.client = GraphitiAdapter()
# LLM client used by InsightForge to generate sub-questions.
self._llm_client = llm_client
logger.info(t("log.zep_tools.m001"))
@property
def llm(self) -> LLMClient:
"""Lazily initialize the LLM client."""
if self._llm_client is None:
self._llm_client = LLMClient()
return self._llm_client
def _call_with_retry(self, func, operation_name: str, max_retries: int = None):
"""API call with retry (auto-handles HTTP 429 rate limiting)."""
max_retries = max_retries or self.MAX_RETRIES
last_exception = None
delay = self.RETRY_DELAY
for attempt in range(max_retries):
try:
return func()
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
# On HTTP 429 rate-limit errors, honour the retry-after header.
wait = delay
if hasattr(e, 'status_code') and e.status_code == 429:
retry_after = None
if hasattr(e, 'headers') and e.headers:
retry_after = e.headers.get('retry-after')
wait = float(retry_after) + 1 if retry_after else 65.0
logger.warning(
t("log.zep_tools.m002", operation_name=operation_name, wait=wait, attempt=attempt + 1, max_retries=max_retries - 1)
)
else:
logger.warning(
t("log.zep_tools.m003", operation_name=operation_name, attempt=attempt + 1, str=str(e)[:100], wait=wait)
)
time.sleep(wait)
delay *= 2
else:
logger.error(t("log.zep_tools.m004", operation_name=operation_name, max_retries=max_retries, str=str(e)))
raise last_exception
def search_graph(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges"
) -> SearchResult:
"""Semantic graph search.
Performs a hybrid search (semantic + BM25) over the graph. If the Zep Cloud search
API is unavailable, falls back to local keyword matching.
Args:
graph_id: Graph identifier (Standalone Graph).
query: Search query.
limit: Maximum number of results to return.
scope: Search scope, either "edges" or "nodes".
Returns:
SearchResult: The search result.
"""
logger.info(t("log.zep_tools.m005", graph_id=graph_id, query=query[:50]))
# Try the Zep Cloud Search API first.
try:
search_results = self._call_with_retry(
func=lambda: self.client.graph.search(
graph_id=graph_id,
query=query,
limit=limit,
scope=scope,
),
operation_name=f"图谱搜索(graph={graph_id})"
)
facts = []
edges = []
nodes = []
# Parse edge search results.
if hasattr(search_results, 'edges') and search_results.edges:
for edge in search_results.edges:
if hasattr(edge, 'fact') and edge.fact:
facts.append(edge.fact)
edges.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": getattr(edge, 'name', ''),
"fact": getattr(edge, 'fact', ''),
"source_node_uuid": getattr(edge, 'source_node_uuid', ''),
"target_node_uuid": getattr(edge, 'target_node_uuid', ''),
})
# Parse node search results.
if hasattr(search_results, 'nodes') and search_results.nodes:
for node in search_results.nodes:
nodes.append({
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
"name": getattr(node, 'name', ''),
"labels": getattr(node, 'labels', []),
"summary": getattr(node, 'summary', ''),
})
# Treat node summaries as facts too.
if hasattr(node, 'summary') and node.summary:
facts.append(f"[{node.name}]: {node.summary}")
logger.info(t("log.zep_tools.m006", len=len(facts)))
return SearchResult(
facts=facts,
edges=edges,
nodes=nodes,
query=query,
total_count=len(facts)
)
except Exception as e:
logger.warning(t("log.zep_tools.m007", str=str(e)))
# Fallback: local keyword-matching search.
return self._local_search(graph_id, query, limit, scope)
def _local_search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges"
) -> SearchResult:
"""Local keyword-matching search (fallback for the Zep Search API).
Loads all edges/nodes and matches them locally on the query keywords.
Args:
graph_id: Graph identifier.
query: Search query.
limit: Maximum number of results to return.
scope: Search scope.
Returns:
SearchResult: The search result.
"""
logger.info(t("log.zep_tools.m008", query=query[:30]))
facts = []
edges_result = []
nodes_result = []
# Extract query keywords with naive whitespace tokenization.
query_lower = query.lower()
keywords = [w.strip() for w in query_lower.replace(',', ' ').replace('', ' ').split() if len(w.strip()) > 1]
def match_score(text: str) -> int:
"""Compute the match score between the text and the query."""
if not text:
return 0
text_lower = text.lower()
# Exact match against the full query.
if query_lower in text_lower:
return 100
# Per-keyword match.
score = 0
for keyword in keywords:
if keyword in text_lower:
score += 10
return score
try:
if scope in ["edges", "both"]:
# Fetch every edge and score each one.
all_edges = self.get_all_edges(graph_id)
scored_edges = []
for edge in all_edges:
score = match_score(edge.fact) + match_score(edge.name)
if score > 0:
scored_edges.append((score, edge))
# Sort by score descending.
scored_edges.sort(key=lambda x: x[0], reverse=True)
for score, edge in scored_edges[:limit]:
if edge.fact:
facts.append(edge.fact)
edges_result.append({
"uuid": edge.uuid,
"name": edge.name,
"fact": edge.fact,
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
})
if scope in ["nodes", "both"]:
# Fetch every node and score each one.
all_nodes = self.get_all_nodes(graph_id)
scored_nodes = []
for node in all_nodes:
score = match_score(node.name) + match_score(node.summary)
if score > 0:
scored_nodes.append((score, node))
scored_nodes.sort(key=lambda x: x[0], reverse=True)
for score, node in scored_nodes[:limit]:
nodes_result.append({
"uuid": node.uuid,
"name": node.name,
"labels": node.labels,
"summary": node.summary,
})
if node.summary:
facts.append(f"[{node.name}]: {node.summary}")
logger.info(t("log.zep_tools.m009", len=len(facts)))
except Exception as e:
logger.error(t("log.zep_tools.m010", str=str(e)))
return SearchResult(
facts=facts,
edges=edges_result,
nodes=nodes_result,
query=query,
total_count=len(facts)
)
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
"""Fetch every node in the graph (with pagination).
Args:
graph_id: Graph identifier.
Returns:
List of nodes.
"""
logger.info(t("log.zep_tools.m011", graph_id=graph_id))
nodes = fetch_all_nodes(self.client, graph_id)
result = []
for node in nodes:
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
result.append(NodeInfo(
uuid=str(node_uuid) if node_uuid else "",
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
))
logger.info(t("log.zep_tools.m012", len=len(result)))
return result
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
"""Fetch every edge in the graph (with pagination), including temporal info.
Args:
graph_id: Graph identifier.
include_temporal: Whether to include temporal fields (default True).
Returns:
List of edges, including created_at, valid_at, invalid_at, and expired_at.
"""
logger.info(t("log.zep_tools.m013", graph_id=graph_id))
edges = fetch_all_edges(self.client, graph_id)
result = []
for edge in edges:
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
edge_info = EdgeInfo(
uuid=str(edge_uuid) if edge_uuid else "",
name=edge.name or "",
fact=edge.fact or "",
source_node_uuid=edge.source_node_uuid or "",
target_node_uuid=edge.target_node_uuid or ""
)
# Attach temporal info.
if include_temporal:
edge_info.created_at = getattr(edge, 'created_at', None)
edge_info.valid_at = getattr(edge, 'valid_at', None)
edge_info.invalid_at = getattr(edge, 'invalid_at', None)
edge_info.expired_at = getattr(edge, 'expired_at', None)
result.append(edge_info)
logger.info(t("log.zep_tools.m014", len=len(result)))
return result
def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]:
"""Fetch the details of a single node.
Args:
node_uuid: Node UUID.
Returns:
Node info, or None if not found.
"""
logger.info(t("log.zep_tools.m015", node_uuid=node_uuid[:8]))
try:
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=node_uuid),
operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)"
)
if not node:
return None
return NodeInfo(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
)
except Exception as e:
logger.error(t("log.zep_tools.m016", str=str(e)))
return None
def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]:
"""Fetch all edges incident to a node.
Loads every edge in the graph and filters to those connected to the given node.
Args:
graph_id: Graph identifier.
node_uuid: Node UUID.
Returns:
List of edges incident to the node.
"""
logger.info(t("log.zep_tools.m017", node_uuid=node_uuid[:8]))
try:
# Load every edge in the graph, then filter.
all_edges = self.get_all_edges(graph_id)
result = []
for edge in all_edges:
# Keep the edge if it is incident to this node (as source or target).
if edge.source_node_uuid == node_uuid or edge.target_node_uuid == node_uuid:
result.append(edge)
logger.info(t("log.zep_tools.m018", len=len(result)))
return result
except Exception as e:
logger.warning(t("log.zep_tools.m019", str=str(e)))
return []
def get_entities_by_type(
self,
graph_id: str,
entity_type: str
) -> List[NodeInfo]:
"""Fetch entities filtered by type.
Args:
graph_id: Graph identifier.
entity_type: Entity type (e.g. Student, PublicFigure).
Returns:
Entities matching the requested type.
"""
logger.info(t("log.zep_tools.m020", entity_type=entity_type))
all_nodes = self.get_all_nodes(graph_id)
filtered = []
for node in all_nodes:
# Keep the node if its labels include the requested type.
if entity_type in node.labels:
filtered.append(node)
logger.info(t("log.zep_tools.m021", len=len(filtered), entity_type=entity_type))
return filtered
def get_entity_summary(
self,
graph_id: str,
entity_name: str
) -> Dict[str, Any]:
"""Fetch the relationship summary for an entity.
Searches for everything related to the entity and assembles a summary.
Args:
graph_id: Graph identifier.
entity_name: Entity name.
Returns:
Entity summary information.
"""
logger.info(t("log.zep_tools.m022", entity_name=entity_name))
# First, search for information about this entity.
search_result = self.search_graph(
graph_id=graph_id,
query=entity_name,
limit=20
)
# Try to locate the entity in the full node list.
all_nodes = self.get_all_nodes(graph_id)
entity_node = None
for node in all_nodes:
if node.name.lower() == entity_name.lower():
entity_node = node
break
related_edges = []
if entity_node:
# Pass through the graph_id parameter.
related_edges = self.get_node_edges(graph_id, entity_node.uuid)
return {
"entity_name": entity_name,
"entity_info": entity_node.to_dict() if entity_node else None,
"related_facts": search_result.facts,
"related_edges": [e.to_dict() for e in related_edges],
"total_relations": len(related_edges)
}
def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]:
"""Fetch statistics about the graph.
Args:
graph_id: Graph identifier.
Returns:
Statistics dictionary.
"""
logger.info(t("log.zep_tools.m023", graph_id=graph_id))
nodes = self.get_all_nodes(graph_id)
edges = self.get_all_edges(graph_id)
# Tally entity type distribution.
entity_types = {}
for node in nodes:
for label in node.labels:
if label not in ["Entity", "Node"]:
entity_types[label] = entity_types.get(label, 0) + 1
# Tally relation type distribution.
relation_types = {}
for edge in edges:
relation_types[edge.name] = relation_types.get(edge.name, 0) + 1
return {
"graph_id": graph_id,
"total_nodes": len(nodes),
"total_edges": len(edges),
"entity_types": entity_types,
"relation_types": relation_types
}
def get_simulation_context(
self,
graph_id: str,
simulation_requirement: str,
limit: int = 30
) -> Dict[str, Any]:
"""Fetch simulation-related context.
Combines a search over the simulation requirement with graph statistics and entities.
Args:
graph_id: Graph identifier.
simulation_requirement: Description of the simulation requirement.
limit: Per-category result limit.
Returns:
Simulation context information.
"""
logger.info(t("log.zep_tools.m024", simulation_requirement=simulation_requirement[:50]))
# Search for information related to the simulation requirement.
search_result = self.search_graph(
graph_id=graph_id,
query=simulation_requirement,
limit=limit
)
# Pull graph statistics.
stats = self.get_graph_statistics(graph_id)
# Load every entity node.
all_nodes = self.get_all_nodes(graph_id)
# Keep entities that have a concrete type (skip plain Entity nodes).
entities = []
for node in all_nodes:
custom_labels = [l for l in node.labels if l not in ["Entity", "Node"]]
if custom_labels:
entities.append({
"name": node.name,
"type": custom_labels[0],
"summary": node.summary
})
return {
"simulation_requirement": simulation_requirement,
"related_facts": search_result.facts,
"graph_statistics": stats,
"entities": entities[:limit], # Cap entity count.
"total_entities": len(entities)
}
# ========== Core retrieval tools (optimized) ==========
def insight_forge(
self,
graph_id: str,
query: str,
simulation_requirement: str,
report_context: str = "",
max_sub_queries: int = 5
) -> InsightForgeResult:
"""InsightForge - deep-insight retrieval.
Most powerful hybrid retrieval. Auto-decomposes the user question and searches across
multiple dimensions:
1. Uses an LLM to decompose the question into sub-questions.
2. Runs a semantic search for each sub-question.
3. Extracts related entities and fetches their details.
4. Traces relationship chains.
5. Synthesises everything into a deep-insight payload.
Args:
graph_id: Graph identifier.
query: The user's question.
simulation_requirement: Description of the simulation requirement.
report_context: Report context (optional; used to ground sub-question generation).
max_sub_queries: Maximum number of sub-questions to generate.
Returns:
InsightForgeResult: The deep-insight retrieval result.
"""
logger.info(t("log.zep_tools.m025", query=query[:50]))
result = InsightForgeResult(
query=query,
simulation_requirement=simulation_requirement,
sub_queries=[]
)
# Step 1: Use the LLM to generate sub-questions.
sub_queries = self._generate_sub_queries(
query=query,
simulation_requirement=simulation_requirement,
report_context=report_context,
max_queries=max_sub_queries
)
result.sub_queries = sub_queries
logger.info(t("log.zep_tools.m026", len=len(sub_queries)))
# Step 2: Run a semantic search for each sub-question.
all_facts = []
all_edges = []
seen_facts = set()
for sub_query in sub_queries:
search_result = self.search_graph(
graph_id=graph_id,
query=sub_query,
limit=15,
scope="edges"
)
for fact in search_result.facts:
if fact not in seen_facts:
all_facts.append(fact)
seen_facts.add(fact)
all_edges.extend(search_result.edges)
# Also search using the original question.
main_search = self.search_graph(
graph_id=graph_id,
query=query,
limit=20,
scope="edges"
)
for fact in main_search.facts:
if fact not in seen_facts:
all_facts.append(fact)
seen_facts.add(fact)
result.semantic_facts = all_facts
result.total_facts = len(all_facts)
# Step 3: Pull related entity UUIDs from the edges and only fetch those nodes
# (rather than every node in the graph).
entity_uuids = set()
for edge_data in all_edges:
if isinstance(edge_data, dict):
source_uuid = edge_data.get('source_node_uuid', '')
target_uuid = edge_data.get('target_node_uuid', '')
if source_uuid:
entity_uuids.add(source_uuid)
if target_uuid:
entity_uuids.add(target_uuid)
# Fetch details for every related entity (no cap, emit in full).
entity_insights = []
node_map = {} # Cached for relationship-chain assembly below.
for uuid in list(entity_uuids): # Walk every related entity, no truncation.
if not uuid:
continue
try:
# Fetch each related node individually.
node = self.get_node_detail(uuid)
if node:
node_map[uuid] = node
entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体")
# Collect every fact related to this entity (no truncation).
related_facts = [
f for f in all_facts
if node.name.lower() in f.lower()
]
entity_insights.append({
"uuid": node.uuid,
"name": node.name,
"type": entity_type,
"summary": node.summary,
"related_facts": related_facts
})
except Exception as e:
logger.debug(t("log.zep_tools.m027", uuid=uuid, e=e))
continue
result.entity_insights = entity_insights
result.total_entities = len(entity_insights)
# Step 4: Assemble every relationship chain (no cap).
relationship_chains = []
for edge_data in all_edges: # Walk every edge, no truncation.
if isinstance(edge_data, dict):
source_uuid = edge_data.get('source_node_uuid', '')
target_uuid = edge_data.get('target_node_uuid', '')
relation_name = edge_data.get('name', '')
source_name = node_map.get(source_uuid, NodeInfo('', '', [], '', {})).name or source_uuid[:8]
target_name = node_map.get(target_uuid, NodeInfo('', '', [], '', {})).name or target_uuid[:8]
chain = f"{source_name} --[{relation_name}]--> {target_name}"
if chain not in relationship_chains:
relationship_chains.append(chain)
result.relationship_chains = relationship_chains
result.total_relationships = len(relationship_chains)
logger.info(t("log.zep_tools.m028", result=result.total_facts, result_2=result.total_entities, result_3=result.total_relationships))
return result
def _generate_sub_queries(
self,
query: str,
simulation_requirement: str,
report_context: str = "",
max_queries: int = 5
) -> List[str]:
"""Use the LLM to generate sub-questions.
Decomposes a complex question into multiple sub-questions that can be retrieved
independently.
"""
system_prompt = """你是一个专业的问题分析专家。你的任务是将一个复杂问题分解为多个可以在模拟世界中独立观察的子问题。
要求:
1. 每个子问题应该足够具体可以在模拟世界中找到相关的Agent行为或事件
2. 子问题应该覆盖原问题的不同维度(如:谁、什么、为什么、怎么样、何时、何地)
3. 子问题应该与模拟场景相关
4. 返回JSON格式{"sub_queries": ["子问题1", "子问题2", ...]}"""
user_prompt = f"""模拟需求背景:
{simulation_requirement}
{f"报告上下文:{report_context[:500]}" if report_context else ""}
请将以下问题分解为{max_queries}个子问题:
{query}
返回JSON格式的子问题列表。"""
try:
response = self.llm.chat_json(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3
)
sub_queries = response.get("sub_queries", [])
# Coerce to a list of strings.
return [str(sq) for sq in sub_queries[:max_queries]]
except Exception as e:
logger.warning(t("log.zep_tools.m029", str=str(e)))
# Fallback: return variants of the original question.
return [
query,
f"{query} 的主要参与者",
f"{query} 的原因和影响",
f"{query} 的发展过程"
][:max_queries]
def panorama_search(
self,
graph_id: str,
query: str,
include_expired: bool = True,
limit: int = 50
) -> PanoramaResult:
"""PanoramaSearch - breadth search.
Returns the full picture, including all related content and historical/expired info:
1. Fetches every related node.
2. Fetches every edge (including expired/invalidated ones).
3. Sorts the facts into currently-active and historical buckets.
Use this tool when callers need to understand the full event landscape or trace how
something evolved over time.
Args:
graph_id: Graph identifier.
query: Search query (used for relevance ranking).
include_expired: Whether to include expired content (default True).
limit: Maximum number of results to return.
Returns:
PanoramaResult: The breadth-search result.
"""
logger.info(t("log.zep_tools.m030", query=query[:50]))
result = PanoramaResult(query=query)
# Fetch every node.
all_nodes = self.get_all_nodes(graph_id)
node_map = {n.uuid: n for n in all_nodes}
result.all_nodes = all_nodes
result.total_nodes = len(all_nodes)
# Fetch every edge (with temporal info).
all_edges = self.get_all_edges(graph_id, include_temporal=True)
result.all_edges = all_edges
result.total_edges = len(all_edges)
# Bucket facts into active vs. historical.
active_facts = []
historical_facts = []
for edge in all_edges:
if not edge.fact:
continue
# Attach entity names to the fact.
source_name = node_map.get(edge.source_node_uuid, NodeInfo('', '', [], '', {})).name or edge.source_node_uuid[:8]
target_name = node_map.get(edge.target_node_uuid, NodeInfo('', '', [], '', {})).name or edge.target_node_uuid[:8]
# Decide whether the edge is historical (expired or invalidated).
is_historical = edge.is_expired or edge.is_invalid
if is_historical:
# Historical/expired fact, prepend a time marker.
valid_at = edge.valid_at or "未知"
invalid_at = edge.invalid_at or edge.expired_at or "未知"
fact_with_time = f"[{valid_at} - {invalid_at}] {edge.fact}"
historical_facts.append(fact_with_time)
else:
# Currently active fact.
active_facts.append(edge.fact)
# Relevance-rank against the query.
query_lower = query.lower()
keywords = [w.strip() for w in query_lower.replace(',', ' ').replace('', ' ').split() if len(w.strip()) > 1]
def relevance_score(fact: str) -> int:
fact_lower = fact.lower()
score = 0
if query_lower in fact_lower:
score += 100
for kw in keywords:
if kw in fact_lower:
score += 10
return score
# Sort and apply the result limit.
active_facts.sort(key=relevance_score, reverse=True)
historical_facts.sort(key=relevance_score, reverse=True)
result.active_facts = active_facts[:limit]
result.historical_facts = historical_facts[:limit] if include_expired else []
result.active_count = len(active_facts)
result.historical_count = len(historical_facts)
logger.info(t("log.zep_tools.m031", result=result.active_count, result_2=result.historical_count))
return result
def quick_search(
self,
graph_id: str,
query: str,
limit: int = 10
) -> SearchResult:
"""QuickSearch - simple, lightweight retrieval.
Calls Zep's semantic search directly and returns the most relevant results. Use this
for simple, straightforward retrieval needs.
Args:
graph_id: Graph identifier.
query: Search query.
limit: Maximum number of results to return.
Returns:
SearchResult: The search result.
"""
logger.info(t("log.zep_tools.m032", query=query[:50]))
# Delegate to the existing search_graph implementation.
result = self.search_graph(
graph_id=graph_id,
query=query,
limit=limit,
scope="edges"
)
logger.info(t("log.zep_tools.m033", result=result.total_count))
return result
def interview_agents(
self,
simulation_id: str,
interview_requirement: str,
simulation_requirement: str = "",
max_agents: int = 5,
custom_questions: List[str] = None
) -> InterviewResult:
"""InterviewAgents - deep interview.
Calls the real OASIS interview API and interviews agents that are currently running
in the simulation:
1. Reads the agent persona file to learn the available simulated agents.
2. Uses an LLM to analyse the interview requirement and pick the most relevant
agents.
3. Uses an LLM to generate interview questions.
4. Calls /api/simulation/interview/batch to run the real interview (across both
Twitter and Reddit platforms simultaneously).
5. Aggregates the interview responses into a report.
Important: this requires the simulation environment to be running (the OASIS
environment must not be torn down).
Use cases:
- Understanding how different roles view an event.
- Collecting opinions from multiple sides.
- Getting genuine responses from simulated agents (rather than LLM-only
simulation).
Args:
simulation_id: Simulation identifier (used to locate persona files and call the
interview API).
interview_requirement: Free-form interview brief (e.g. "understand how students
view the event").
simulation_requirement: Background context for the simulation (optional).
max_agents: Maximum number of agents to interview.
custom_questions: Custom interview questions (optional; auto-generated if absent).
Returns:
InterviewResult: The interview result.
"""
from .simulation_runner import SimulationRunner
logger.info(t("log.zep_tools.m034", interview_requirement=interview_requirement[:50]))
result = InterviewResult(
interview_topic=interview_requirement,
interview_questions=custom_questions or []
)
# Step 1: Load the persona file.
profiles = self._load_agent_profiles(simulation_id)
if not profiles:
logger.warning(t("log.zep_tools.m035", simulation_id=simulation_id))
result.summary = "未找到可采访的Agent人设文件"
return result
result.total_agents = len(profiles)
logger.info(t("log.zep_tools.m036", len=len(profiles)))
# Step 2: Use the LLM to pick interview targets (returns a list of agent IDs).
selected_agents, selected_indices, selection_reasoning = self._select_agents_for_interview(
profiles=profiles,
interview_requirement=interview_requirement,
simulation_requirement=simulation_requirement,
max_agents=max_agents
)
result.selected_agents = selected_agents
result.selection_reasoning = selection_reasoning
logger.info(t("log.zep_tools.m037", len=len(selected_agents), selected_indices=selected_indices))
# Step 3: Generate interview questions (if none were supplied).
if not result.interview_questions:
result.interview_questions = self._generate_interview_questions(
interview_requirement=interview_requirement,
simulation_requirement=simulation_requirement,
selected_agents=selected_agents
)
logger.info(t("log.zep_tools.m038", len=len(result.interview_questions)))
# Merge the questions into a single interview prompt.
combined_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(result.interview_questions)])
# Prepend an optimised prefix that constrains the agent's reply format.
INTERVIEW_PROMPT_PREFIX = (
"你正在接受一次采访。请结合你的人设、所有的过往记忆与行动,"
"以纯文本方式直接回答以下问题。\n"
"回复要求:\n"
"1. 直接用自然语言回答,不要调用任何工具\n"
"2. 不要返回JSON格式或工具调用格式\n"
"3. 不要使用Markdown标题如#、##、###\n"
"4. 按问题编号逐一回答每个回答以「问题X」开头X为问题编号\n"
"5. 每个问题的回答之间用空行分隔\n"
"6. 回答要有实质内容每个问题至少回答2-3句话\n\n"
)
optimized_prompt = f"{INTERVIEW_PROMPT_PREFIX}{combined_prompt}"
# Step 4: Call the real interview API. We omit the platform field so the API
# interviews on both Twitter and Reddit by default.
try:
# Build the batch-interview list (no platform => both platforms).
interviews_request = []
for agent_idx in selected_indices:
interviews_request.append({
"agent_id": agent_idx,
"prompt": optimized_prompt
# Omitting platform asks the API to interview on both Twitter and Reddit.
})
logger.info(t("log.zep_tools.m039", len=len(interviews_request)))
# Call SimulationRunner's batch interview helper (no platform => both platforms).
api_result = SimulationRunner.interview_agents_batch(
simulation_id=simulation_id,
interviews=interviews_request,
platform=None, # Omitting platform interviews both Twitter and Reddit.
timeout=180.0 # Dual-platform mode needs a longer timeout.
)
logger.info(t("log.zep_tools.m040", api_result=api_result.get('interviews_count', 0), api_result_2=api_result.get('success')))
# Check whether the API call succeeded.
if not api_result.get("success", False):
error_msg = api_result.get("error", "未知错误")
logger.warning(t("log.zep_tools.m041", error_msg=error_msg))
result.summary = f"采访API调用失败{error_msg}。请检查OASIS模拟环境状态。"
return result
# Step 5: Parse the API response and build AgentInterview objects.
# Dual-platform shape: {"twitter_0": {...}, "reddit_0": {...}, "twitter_1": {...}, ...}
api_data = api_result.get("result", {})
results_dict = api_data.get("results", {}) if isinstance(api_data, dict) else {}
for i, agent_idx in enumerate(selected_indices):
agent = selected_agents[i]
agent_name = agent.get("realname", agent.get("username", f"Agent_{agent_idx}"))
agent_role = agent.get("profession", "未知")
agent_bio = agent.get("bio", "")
# Fetch this agent's responses from both platforms.
twitter_result = results_dict.get(f"twitter_{agent_idx}", {})
reddit_result = results_dict.get(f"reddit_{agent_idx}", {})
twitter_response = twitter_result.get("response", "")
reddit_response = reddit_result.get("response", "")
# Strip any tool-call JSON wrapper from the agent's reply.
twitter_response = self._clean_tool_call_response(twitter_response)
reddit_response = self._clean_tool_call_response(reddit_response)
# Always emit both platform headers, even when one platform is empty.
twitter_text = twitter_response if twitter_response else "(该平台未获得回复)"
reddit_text = reddit_response if reddit_response else "(该平台未获得回复)"
response_text = f"【Twitter平台回答】\n{twitter_text}\n\n【Reddit平台回答】\n{reddit_text}"
# Extract key quotes from the responses on both platforms.
import re
combined_responses = f"{twitter_response} {reddit_response}"
# Clean up the response text: drop markers, numbering, Markdown noise.
clean_text = re.sub(r'#{1,6}\s+', '', combined_responses)
clean_text = re.sub(r'\{[^}]*tool_name[^}]*\}', '', clean_text)
clean_text = re.sub(r'[*_`|>~\-]{2,}', '', clean_text)
clean_text = re.sub(r'问题\d+[:]\s*', '', clean_text)
clean_text = re.sub(r'【[^】]+】', '', clean_text)
# Primary strategy: extract complete sentences with substantive content.
sentences = re.split(r'[。!?]', clean_text)
meaningful = [
s.strip() for s in sentences
if 20 <= len(s.strip()) <= 150
and not re.match(r'^[\s\W,;:、]+', s.strip())
and not s.strip().startswith(('{', '问题'))
]
meaningful.sort(key=len, reverse=True)
key_quotes = [s + "" for s in meaningful[:3]]
# Fallback strategy: long text inside properly paired CJK quotation marks「」.
if not key_quotes:
paired = re.findall(r'\u201c([^\u201c\u201d]{15,100})\u201d', clean_text)
paired += re.findall(r'\u300c([^\u300c\u300d]{15,100})\u300d', clean_text)
key_quotes = [q for q in paired if not re.match(r'^[,;:、]', q)][:3]
interview = AgentInterview(
agent_name=agent_name,
agent_role=agent_role,
agent_bio=agent_bio[:1000], # Allow a longer bio than the default limit.
question=combined_prompt,
response=response_text,
key_quotes=key_quotes[:5]
)
result.interviews.append(interview)
result.interviewed_count = len(result.interviews)
except ValueError as e:
# Simulation environment is not running.
logger.warning(t("log.zep_tools.m042", e=e))
result.summary = f"采访失败:{str(e)}。模拟环境可能已关闭请确保OASIS环境正在运行。"
return result
except Exception as e:
logger.error(t("log.zep_tools.m043", e=e))
import traceback
logger.error(traceback.format_exc())
result.summary = f"采访过程发生错误:{str(e)}"
return result
# Step 6: Generate the interview summary.
if result.interviews:
result.summary = self._generate_interview_summary(
interviews=result.interviews,
interview_requirement=interview_requirement
)
logger.info(t("log.zep_tools.m044", result=result.interviewed_count))
return result
@staticmethod
def _clean_tool_call_response(response: str) -> str:
"""Strip the JSON tool-call wrapper from an agent reply and return the inner content."""
if not response or not response.strip().startswith('{'):
return response
text = response.strip()
if 'tool_name' not in text[:80]:
return response
import re as _re
try:
data = json.loads(text)
if isinstance(data, dict) and 'arguments' in data:
for key in ('content', 'text', 'body', 'message', 'reply'):
if key in data['arguments']:
return str(data['arguments'][key])
except (json.JSONDecodeError, KeyError, TypeError):
match = _re.search(r'"content"\s*:\s*"((?:[^"\\]|\\.)*)"', text)
if match:
return match.group(1).replace('\\n', '\n').replace('\\"', '"')
return response
def _load_agent_profiles(self, simulation_id: str) -> List[Dict[str, Any]]:
"""Load the agent persona file for a simulation."""
import os
import csv
# Build the persona file path.
sim_dir = os.path.join(
os.path.dirname(__file__),
f'../../uploads/simulations/{simulation_id}'
)
profiles = []
# Prefer the Reddit JSON profile if it exists.
reddit_profile_path = os.path.join(sim_dir, "reddit_profiles.json")
if os.path.exists(reddit_profile_path):
try:
with open(reddit_profile_path, 'r', encoding='utf-8') as f:
profiles = json.load(f)
logger.info(t("log.zep_tools.m045", len=len(profiles)))
return profiles
except Exception as e:
logger.warning(t("log.zep_tools.m046", e=e))
# Otherwise fall back to the Twitter CSV profile.
twitter_profile_path = os.path.join(sim_dir, "twitter_profiles.csv")
if os.path.exists(twitter_profile_path):
try:
with open(twitter_profile_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
# Convert each CSV row into the unified profile shape.
profiles.append({
"realname": row.get("name", ""),
"username": row.get("username", ""),
"bio": row.get("description", ""),
"persona": row.get("user_char", ""),
"profession": "未知"
})
logger.info(t("log.zep_tools.m047", len=len(profiles)))
return profiles
except Exception as e:
logger.warning(t("log.zep_tools.m048", e=e))
return profiles
def _select_agents_for_interview(
self,
profiles: List[Dict[str, Any]],
interview_requirement: str,
simulation_requirement: str,
max_agents: int
) -> tuple:
"""Use the LLM to choose which agents to interview.
Returns:
tuple: ``(selected_agents, selected_indices, reasoning)`` where
- ``selected_agents`` is the full profile list for the chosen agents,
- ``selected_indices`` is the list of indices to pass to the API,
- ``reasoning`` explains why those agents were chosen.
"""
# Build a compact summary list of every candidate agent.
agent_summaries = []
for i, profile in enumerate(profiles):
summary = {
"index": i,
"name": profile.get("realname", profile.get("username", f"Agent_{i}")),
"profession": profile.get("profession", "未知"),
"bio": profile.get("bio", "")[:200],
"interested_topics": profile.get("interested_topics", [])
}
agent_summaries.append(summary)
system_prompt = """你是一个专业的采访策划专家。你的任务是根据采访需求从模拟Agent列表中选择最适合采访的对象。
选择标准:
1. Agent的身份/职业与采访主题相关
2. Agent可能持有独特或有价值的观点
3. 选择多样化的视角(如:支持方、反对方、中立方、专业人士等)
4. 优先选择与事件直接相关的角色
返回JSON格式
{
"selected_indices": [选中Agent的索引列表],
"reasoning": "选择理由说明"
}"""
user_prompt = f"""采访需求:
{interview_requirement}
模拟背景:
{simulation_requirement if simulation_requirement else "未提供"}
可选择的Agent列表{len(agent_summaries)}个):
{json.dumps(agent_summaries, ensure_ascii=False, indent=2)}
请选择最多{max_agents}个最适合采访的Agent并说明选择理由。"""
try:
response = self.llm.chat_json(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3
)
selected_indices = response.get("selected_indices", [])[:max_agents]
reasoning = response.get("reasoning", "基于相关性自动选择")
# Pull the full profile for each chosen agent.
selected_agents = []
valid_indices = []
for idx in selected_indices:
if 0 <= idx < len(profiles):
selected_agents.append(profiles[idx])
valid_indices.append(idx)
return selected_agents, valid_indices, reasoning
except Exception as e:
logger.warning(t("log.zep_tools.m049", e=e))
# Fallback: pick the first N profiles.
selected = profiles[:max_agents]
indices = list(range(min(max_agents, len(profiles))))
return selected, indices, "使用默认选择策略"
def _generate_interview_questions(
self,
interview_requirement: str,
simulation_requirement: str,
selected_agents: List[Dict[str, Any]]
) -> List[str]:
"""Use the LLM to generate interview questions."""
agent_roles = [a.get("profession", "未知") for a in selected_agents]
system_prompt = """你是一个专业的记者/采访者。根据采访需求生成3-5个深度采访问题。
问题要求:
1. 开放性问题,鼓励详细回答
2. 针对不同角色可能有不同答案
3. 涵盖事实、观点、感受等多个维度
4. 语言自然,像真实采访一样
5. 每个问题控制在50字以内简洁明了
6. 直接提问,不要包含背景说明或前缀
返回JSON格式{"questions": ["问题1", "问题2", ...]}"""
user_prompt = f"""采访需求:{interview_requirement}
模拟背景:{simulation_requirement if simulation_requirement else "未提供"}
采访对象角色:{', '.join(agent_roles)}
请生成3-5个采访问题。"""
try:
response = self.llm.chat_json(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.5
)
return response.get("questions", [f"关于{interview_requirement},您有什么看法?"])
except Exception as e:
logger.warning(t("log.zep_tools.m050", e=e))
return [
f"关于{interview_requirement},您的观点是什么?",
"这件事对您或您所代表的群体有什么影响?",
"您认为应该如何解决或改进这个问题?"
]
def _generate_interview_summary(
self,
interviews: List[AgentInterview],
interview_requirement: str
) -> str:
"""Generate the interview summary."""
if not interviews:
return "未完成任何采访"
# Gather every interview excerpt.
interview_texts = []
for interview in interviews:
interview_texts.append(f"{interview.agent_name}{interview.agent_role})】\n{interview.response[:500]}")
system_prompt = """你是一个专业的新闻编辑。请根据多位受访者的回答,生成一份采访摘要。
摘要要求:
1. 提炼各方主要观点
2. 指出观点的共识和分歧
3. 突出有价值的引言
4. 客观中立,不偏袒任何一方
5. 控制在1000字内
格式约束(必须遵守):
- 使用纯文本段落,用空行分隔不同部分
- 不要使用Markdown标题如#、##、###
- 不要使用分割线(如---、***
- 引用受访者原话时使用中文引号「」
- 可以使用**加粗**标记关键词但不要使用其他Markdown语法"""
user_prompt = f"""采访主题:{interview_requirement}
采访内容:
{"".join(interview_texts)}
请生成采访摘要。"""
try:
summary = self.llm.chat(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3,
max_tokens=800
)
return summary
except Exception as e:
logger.warning(t("log.zep_tools.m051", e=e))
# Fallback: simple concatenation of agent names.
return f"共采访了{len(interviews)}位受访者,包括:" + "".join([i.agent_name for i in interviews])