feat(interviews): StakeholderInterviewer base with in-character prompting and schema retry
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
eb3c3629c1
commit
289a0cff56
|
|
@ -0,0 +1,72 @@
|
||||||
|
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PersonaRecord:
    """Identity and character description of one simulated stakeholder."""

    # Identifier handed to the memory provider and quoted in schema errors.
    agent_id: int
    # Display name; opens the in-character system prompt.
    name: str
    # Free-text character description placed right after the name in the prompt.
    persona: str
    # Optional extras; not read by the interviewer code in this module.
    profession: Optional[str] = None
    bio: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemoryDigest:
    """Summary of an agent's simulation memory as returned by a provider."""

    # Digest body; embedded in the system prompt when ``available`` is true.
    text: str
    # When false, the interviewer substitutes a "no memory" placeholder.
    available: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProvider(Protocol):
    """Structural interface for anything that can supply a memory digest."""

    def get_digest(self, agent_id: int, max_chars: int = 2000) -> MemoryDigest:
        """Return the memory digest for ``agent_id``.

        ``max_chars`` defaults to 2000; how an implementation applies it is
        not defined by this protocol.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class StakeholderInterviewer:
    """Ask an LLM survey questions while it role-plays a stakeholder persona.

    The system prompt combines the persona description with a digest of the
    agent's simulation memory. When a validator is supplied, a rejected
    answer triggers exactly one corrective retry.
    """

    def __init__(self, llm, memory: MemoryProvider, language: str = "de"):
        """Store the collaborators.

        Args:
            llm: Client exposing ``chat_json(messages=..., temperature=...,
                max_tokens=...)``.
            memory: Provider of per-agent memory digests.
            language: ``"de"`` requests German answers; any other value
                requests English.
        """
        self.llm = llm
        self.memory = memory
        self.language = language

    def _system_prompt(self, persona: PersonaRecord, digest: MemoryDigest, schema_hint: str) -> str:
        """Build the in-character system prompt for ``persona``."""
        memory_block = digest.text if digest.available else "[no simulation memory available]"
        lang_note = "Antworte ausschließlich auf Deutsch." if self.language == "de" else "Answer in English."
        return (
            f"You are {persona.name}. {persona.persona}\n\n"
            "You are answering a survey about the future of German fisheries. "
            "Answer strictly in character based on your background, values, and what you experienced "
            "during the simulated social media discourse summarised below.\n\n"
            f"--- simulation memory digest ---\n{memory_block}\n--- end ---\n\n"
            f"{lang_note} Return JSON ONLY matching this schema:\n{schema_hint}"
        )

    @staticmethod
    def _echo_previous(out: Any) -> str:
        """Render a rejected answer as text for the retry conversation.

        JSON is preferred so the model sees its own answer in the format it
        is being asked to produce; ``str()`` is the fallback for values the
        JSON encoder cannot handle.
        """
        try:
            return json.dumps(out, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(out)

    def ask_in_character(
        self,
        persona: PersonaRecord,
        user_prompt: str,
        schema_hint: str,
        *,
        temperature: float = 0.3,
        max_tokens: Optional[int] = None,
        validate: Optional[Callable[[dict], Optional[dict]]] = None,
    ) -> dict:
        """Put ``user_prompt`` to the LLM in character and return its answer.

        Args:
            persona: The stakeholder to impersonate.
            user_prompt: The question, sent as the user message.
            schema_hint: Description of the expected JSON, embedded in the
                system prompt and repeated in the retry message.
            temperature: Sampling temperature for the first attempt.
            max_tokens: Forwarded unchanged to ``chat_json``.
            validate: Optional callable returning the accepted (possibly
                normalised) dict, or ``None`` to reject the answer.

        Returns:
            The raw LLM output when ``validate`` is ``None``, otherwise the
            value returned by ``validate``.

        Raises:
            ValueError: If the answer is still rejected after the retry.
        """
        digest = self.memory.get_digest(persona.agent_id)
        messages = [
            {"role": "system", "content": self._system_prompt(persona, digest, schema_hint)},
            {"role": "user", "content": user_prompt},
        ]
        out = self.llm.chat_json(messages=messages, temperature=temperature, max_tokens=max_tokens)
        if validate is None:
            return out

        validated = validate(out)
        if validated is not None:
            return validated

        # One corrective round: show the model its rejected answer and restate
        # the schema. The retry runs at temperature 0.0.
        messages.append({"role": "assistant", "content": self._echo_previous(out)})
        messages.append({"role": "user", "content":
                         "Your previous response did not match the required schema. "
                         f"Return ONLY valid JSON matching: {schema_hint}"})
        out = self.llm.chat_json(messages=messages, temperature=0.0, max_tokens=max_tokens)
        validated = validate(out)
        if validated is None:
            raise ValueError(f"agent {persona.agent_id}: schema violation after retry")
        return validated
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from app.services.interviews.base import StakeholderInterviewer, MemoryDigest, PersonaRecord
|
||||||
|
|
||||||
|
class _FakeLLM:
    """Scripted stand-in for the LLM client used by the interviewer tests.

    Each ``chat_json`` call returns the next queued response and records the
    messages it received.
    """

    def __init__(self, responses):
        """Queue ``responses`` to be returned in order, one per call."""
        self.responses = list(responses)
        self.calls = []

    def chat_json(self, messages, temperature=0.0, max_tokens=None, **kw):
        """Record the conversation and return the next scripted response."""
        # Store a snapshot: the interviewer appends retry messages to the same
        # list, which would otherwise rewrite the record of earlier calls.
        self.calls.append(list(messages))
        return self.responses.pop(0)
|
||||||
|
|
||||||
|
class _FakeMemory:
    """Memory provider stub whose digest text encodes the requested agent id."""

    def get_digest(self, agent_id, max_chars=2000):
        """Return an available digest reading ``digest-for-<agent_id>``."""
        digest_text = f"digest-for-{agent_id}"
        return MemoryDigest(text=digest_text, available=True)
|
||||||
|
|
||||||
|
def test_in_character_prompt_includes_persona_and_memory():
    """The system prompt carries both the persona text and the memory digest."""
    fake_llm = _FakeLLM([{"x": 1}])
    interviewer = StakeholderInterviewer(llm=fake_llm, memory=_FakeMemory())
    fisher = PersonaRecord(agent_id=7, name="A", persona="I am a small-scale Baltic fisher.")

    answer = interviewer.ask_in_character(fisher, user_prompt="Q?", schema_hint="{...}")

    # Without a validator the raw LLM output is returned untouched.
    assert answer == {"x": 1}
    first_call_messages = fake_llm.calls[0]
    system_prompt = first_call_messages[0]["content"]
    assert "small-scale Baltic fisher" in system_prompt
    assert "digest-for-7" in system_prompt
|
||||||
|
|
||||||
|
def test_schema_retry_on_first_failure():
    """A rejected first answer triggers one retry whose answer is returned."""
    accepted = {"responses": {"a": 3}}
    fake_llm = _FakeLLM([{}, accepted])
    interviewer = StakeholderInterviewer(llm=fake_llm, memory=_FakeMemory())
    subject = PersonaRecord(agent_id=1, name="A", persona="p")

    def requires_responses(payload):
        # Accept only payloads that carry a "responses" key.
        if "responses" in payload:
            return payload
        return None

    result = interviewer.ask_in_character(
        subject, user_prompt="Q?", schema_hint="x", validate=requires_responses
    )

    assert result == {"responses": {"a": 3}}
    assert len(fake_llm.calls) == 2
|
||||||
|
|
||||||
|
def test_two_failures_raise():
    """Two consecutive schema violations surface as a ValueError."""
    llm = _FakeLLM([{}, {}])
    mem = _FakeMemory()
    interviewer = StakeholderInterviewer(llm=llm, memory=mem)
    persona = PersonaRecord(agent_id=1, name="A", persona="p")

    # Pin the message so an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="schema violation after retry"):
        interviewer.ask_in_character(persona, user_prompt="Q?", schema_hint="x",
                                     validate=lambda d: d if "responses" in d else None)

    # The error must come after exactly one retry, not on the first failure.
    assert len(llm.calls) == 2
|
||||||
Loading…
Reference in New Issue