feat(interviews): diversity subagent with Q-sort + 6 Likert axes + PCA/k-means typology
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
0fcb815cde
commit
75762ccc18
|
|
@ -0,0 +1,133 @@
|
|||
from __future__ import annotations
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.cluster import KMeans
|
||||
import yaml
|
||||
from app.models.interview import QSortResponse
|
||||
from app.services.interviews.base import StakeholderInterviewer, PersonaRecord
|
||||
from app.services.interviews.instrument_loader import InstrumentValidationError
|
||||
|
||||
|
||||
class DiversitySubagent:
|
||||
def __init__(self, llm, memory, instrument_path: Path, language: str = "de"):
|
||||
self.instrument = self._load(Path(instrument_path))
|
||||
self.interviewer = StakeholderInterviewer(llm=llm, memory=memory, language=language)
|
||||
self.language = language
|
||||
|
||||
def _load(self, path: Path) -> dict:
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict) or "statements" not in data or "distribution" not in data:
|
||||
raise InstrumentValidationError(f"invalid diversity instrument: {path}")
|
||||
if sum(data["distribution"]) != len(data["statements"]):
|
||||
raise InstrumentValidationError("distribution sum must equal number of statements")
|
||||
return data
|
||||
|
||||
def _schema_hint(self) -> str:
|
||||
return json.dumps({
|
||||
"placements": {s["statement_id"]: "<int in -3..+3>" for s in self.instrument["statements"]},
|
||||
"likert_axes": {a["axis_id"]: "<int 1-7>" for a in self.instrument["likert_axes"]},
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def _user_prompt(self) -> str:
|
||||
dist = self.instrument["distribution"]
|
||||
buckets = list(range(-3, 4))
|
||||
bucket_desc = ", ".join(f"{b}:{n}" for b, n in zip(buckets, dist))
|
||||
lines = [
|
||||
("Ordnen Sie jede Aussage genau einer Box von -3 (lehne stark ab) bis +3 (stimme stark zu) zu. "
|
||||
f"Die Verteilung ist erzwungen: {bucket_desc}.") if self.language == "de" else
|
||||
("Place every statement into exactly one box from -3 (strongly disagree) to +3 (strongly agree). "
|
||||
f"The distribution is forced: {bucket_desc}."),
|
||||
"",
|
||||
"Statements:",
|
||||
]
|
||||
for s in self.instrument["statements"]:
|
||||
txt = s["de"] if self.language == "de" else s["en"]
|
||||
lines.append(f"- [{s['statement_id']}] {txt}")
|
||||
lines += ["", "Then rate each axis from 1 to 7:"]
|
||||
for a in self.instrument["likert_axes"]:
|
||||
txt = a["de"] if self.language == "de" else a["en"]
|
||||
lines.append(f"- [{a['axis_id']}] {txt}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _validator(self, raw: dict) -> Optional[dict]:
|
||||
if not isinstance(raw, dict):
|
||||
return None
|
||||
placements = raw.get("placements", {})
|
||||
axes = raw.get("likert_axes", {})
|
||||
statements = {s["statement_id"] for s in self.instrument["statements"]}
|
||||
if set(placements.keys()) != statements:
|
||||
return None
|
||||
dist = self.instrument["distribution"]
|
||||
target = {b: n for b, n in zip(range(-3, 4), dist)}
|
||||
got: dict[int, int] = {}
|
||||
for v in placements.values():
|
||||
if not isinstance(v, int) or not -3 <= v <= 3:
|
||||
return None
|
||||
got[v] = got.get(v, 0) + 1
|
||||
if got != target:
|
||||
return None
|
||||
for a in self.instrument["likert_axes"]:
|
||||
v = axes.get(a["axis_id"])
|
||||
if not isinstance(v, int) or not 1 <= v <= 7:
|
||||
return None
|
||||
return raw
|
||||
|
||||
def administer(self, persona: PersonaRecord) -> QSortResponse:
|
||||
raw = self.interviewer.ask_in_character(
|
||||
persona,
|
||||
user_prompt=self._user_prompt(),
|
||||
schema_hint=self._schema_hint(),
|
||||
validate=self._validator,
|
||||
)
|
||||
return QSortResponse(
|
||||
agent_id=persona.agent_id,
|
||||
placements={k: int(v) for k, v in raw["placements"].items()},
|
||||
likert_axes={k: int(v) for k, v in raw["likert_axes"].items()},
|
||||
)
|
||||
|
||||
|
||||
def _vectorize(r: QSortResponse, statements: list[str], axes: list[str]) -> np.ndarray:
|
||||
return np.array(
|
||||
[r.placements.get(s, 0) for s in statements] +
|
||||
[r.likert_axes.get(a, 4) for a in axes],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
|
||||
def run_typology(responses: list[QSortResponse], n_clusters: int = 4) -> dict:
|
||||
if not responses:
|
||||
return {"n": 0, "clusters": [], "pca": {"components": [], "explained_variance": []}}
|
||||
statements = sorted({k for r in responses for k in r.placements})
|
||||
axes = sorted({k for r in responses for k in r.likert_axes})
|
||||
X = np.vstack([_vectorize(r, statements, axes) for r in responses])
|
||||
n_clusters = min(n_clusters, len(responses))
|
||||
pca = PCA(n_components=min(5, X.shape[1], X.shape[0]))
|
||||
pcs = pca.fit_transform(X)
|
||||
km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
|
||||
labels = km.fit_predict(X)
|
||||
clusters = []
|
||||
for c in range(n_clusters):
|
||||
members = [responses[i].agent_id for i in range(len(responses)) if labels[i] == c]
|
||||
centroid = km.cluster_centers_[c]
|
||||
clusters.append({
|
||||
"cluster_id": int(c),
|
||||
"n": len(members),
|
||||
"agent_ids": members,
|
||||
"top_loadings": {
|
||||
statements[i] if i < len(statements) else axes[i - len(statements)]: float(centroid[i])
|
||||
for i in np.argsort(np.abs(centroid))[::-1][:8].tolist()
|
||||
},
|
||||
})
|
||||
return {
|
||||
"n": len(responses),
|
||||
"clusters": clusters,
|
||||
"pca": {
|
||||
"components": pcs.tolist(),
|
||||
"explained_variance": pca.explained_variance_ratio_.tolist(),
|
||||
"agent_ids": [r.agent_id for r in responses],
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
name: diversity_v1
|
||||
version: "1.0"
|
||||
language_default: de
|
||||
distribution: [2, 3, 4, 6, 4, 3, 2] # buckets from -3 to +3, total 24
|
||||
statements:
|
||||
- {statement_id: st_01, de: "Die Ostsee gehört den Fischern, die hier seit Generationen leben.", en: "The Baltic belongs to fishers who have lived here for generations."}
|
||||
- {statement_id: st_02, de: "MSC-Zertifizierung schützt vor allem große Konzerne.", en: "MSC certification mainly protects large corporations."}
|
||||
- {statement_id: st_03, de: "Wissenschaftliche Quoten sind die einzige Grundlage für Politik.", en: "Scientific quotas are the only legitimate basis for policy."}
|
||||
- {statement_id: st_04, de: "Aquakultur kann Ostseefischerei ersetzen.", en: "Aquaculture can replace Baltic fisheries."}
|
||||
- {statement_id: st_05, de: "Sportfischer schaden den Beständen mehr als die Berufsfischer.", en: "Recreational anglers harm stocks more than commercial fishers."}
|
||||
- {statement_id: st_06, de: "Die EU-Fischereipolitik kennt die Ostsee nicht.", en: "EU fisheries policy doesn't understand the Baltic."}
|
||||
- {statement_id: st_07, de: "Großtechnische Fischerei ist effizienter und damit nachhaltiger.", en: "Industrial fisheries are more efficient and therefore more sustainable."}
|
||||
- {statement_id: st_08, de: "Wer Fisch isst, sollte mehr dafür bezahlen.", en: "Those who eat fish should pay more for it."}
|
||||
- {statement_id: st_09, de: "Die Kleinfischerei muss subventioniert werden.", en: "Small-scale fisheries must be subsidised."}
|
||||
- {statement_id: st_10, de: "Marine Schutzgebiete sind reine Symbolpolitik.", en: "Marine protected areas are mere symbolism."}
|
||||
- {statement_id: st_11, de: "Russlands Krieg ändert alles in der Ostsee.", en: "Russia's war changes everything in the Baltic."}
|
||||
- {statement_id: st_12, de: "Nur drastische Reduktion der Fangmengen rettet die Bestände.", en: "Only drastic catch reductions will save the stocks."}
|
||||
- {statement_id: st_13, de: "NGOs übertreiben die Krise systematisch.", en: "NGOs systematically exaggerate the crisis."}
|
||||
- {statement_id: st_14, de: "Klimawandel ist das eigentliche Problem, nicht die Fischerei.", en: "Climate change is the real problem, not fisheries."}
|
||||
- {statement_id: st_15, de: "Tradition zählt mehr als kurzfristige Bestandszahlen.", en: "Tradition matters more than short-term stock numbers."}
|
||||
- {statement_id: st_16, de: "Verbraucher entscheiden über die Zukunft des Fisches.", en: "Consumers decide the future of fish."}
|
||||
- {statement_id: st_17, de: "Ohne Generalstreik der Fischer ändert sich nichts.", en: "Without a fishers' general strike, nothing will change."}
|
||||
- {statement_id: st_18, de: "Die Bundesregierung sollte Kutter aufkaufen und stilllegen.", en: "The federal government should buy out and decommission boats."}
|
||||
- {statement_id: st_19, de: "Die Dorschkrise ist Folge gescheiterter Politik.", en: "The cod crisis is the result of policy failure."}
|
||||
- {statement_id: st_20, de: "Ostsee-Aquakultur ist ökologisch problematisch.", en: "Baltic aquaculture is ecologically problematic."}
|
||||
- {statement_id: st_21, de: "Junge Menschen werden keinen Fischereibetrieb mehr übernehmen.", en: "Young people will no longer take over fishing businesses."}
|
||||
- {statement_id: st_22, de: "Markt regelt sich selbst, auch beim Fisch.", en: "The market regulates itself, also for fish."}
|
||||
- {statement_id: st_23, de: "Lokale Genossenschaften sind die Lösung.", en: "Local cooperatives are the solution."}
|
||||
- {statement_id: st_24, de: "In 20 Jahren gibt es keine deutsche Ostseefischerei mehr.", en: "In 20 years there will be no German Baltic fisheries left."}
|
||||
likert_axes:
|
||||
- {axis_id: ax_pres_extr, scale: 7, de: "Bewahrung (1) vs. Nutzung (7)", en: "Preservation (1) vs. Extraction (7)"}
|
||||
- {axis_id: ax_loc_eu, scale: 7, de: "Lokal (1) vs. EU-zentral (7)", en: "Local (1) vs. EU-central (7)"}
|
||||
- {axis_id: ax_sci_trad, scale: 7, de: "Wissenschaft (1) vs. Tradition (7)", en: "Science-led (1) vs. Tradition-led (7)"}
|
||||
- {axis_id: ax_ind_col, scale: 7, de: "Individuum (1) vs. Kollektiv (7)", en: "Individual (1) vs. Collective (7)"}
|
||||
- {axis_id: ax_short_long,scale: 7, de: "Kurzfristig (1) vs. Langfristig (7)", en: "Short-term (1) vs. Long-term (7)"}
|
||||
- {axis_id: ax_mkt_reg, scale: 7, de: "Markt (1) vs. Regulierung (7)", en: "Market (1) vs. Regulation (7)"}
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
from pathlib import Path
|
||||
import numpy as np
|
||||
from app.services.interviews.base import PersonaRecord, MemoryDigest
|
||||
from app.services.interviews.diversity import (
|
||||
DiversitySubagent, run_typology,
|
||||
)
|
||||
|
||||
class _Mem:
|
||||
def get_digest(self, agent_id, max_chars=2000):
|
||||
return MemoryDigest(text="x", available=True)
|
||||
|
||||
class _CannedLLM:
|
||||
def chat_json(self, messages, temperature=0.0, max_tokens=None, **kw):
|
||||
# Place all 24 statements into legal buckets per the forced distribution
|
||||
placements = {}
|
||||
buckets = [-3]*2 + [-2]*3 + [-1]*4 + [0]*6 + [1]*4 + [2]*3 + [3]*2
|
||||
for i in range(24):
|
||||
placements[f"st_{i+1:02d}"] = buckets[i]
|
||||
return {
|
||||
"placements": placements,
|
||||
"likert_axes": {"ax_pres_extr": 5, "ax_loc_eu": 3, "ax_sci_trad": 4,
|
||||
"ax_ind_col": 4, "ax_short_long": 5, "ax_mkt_reg": 3},
|
||||
}
|
||||
|
||||
INSTRUMENT = Path(__file__).resolve().parents[2] / "scripts" / "instruments" / "diversity_v1.yaml"
|
||||
|
||||
def test_diversity_administer():
|
||||
sub = DiversitySubagent(llm=_CannedLLM(), memory=_Mem(), instrument_path=INSTRUMENT)
|
||||
persona = PersonaRecord(agent_id=1, name="A", persona="p")
|
||||
resp = sub.administer(persona)
|
||||
assert len(resp.placements) == 24
|
||||
assert set(resp.likert_axes.keys()) == {
|
||||
"ax_pres_extr","ax_loc_eu","ax_sci_trad","ax_ind_col","ax_short_long","ax_mkt_reg"
|
||||
}
|
||||
|
||||
def test_typology_runs_pca_kmeans():
|
||||
from app.models.interview import QSortResponse
|
||||
rng = np.random.default_rng(42)
|
||||
responses = []
|
||||
for aid in range(20):
|
||||
placements = {f"st_{i+1:02d}": int(rng.integers(-3, 4)) for i in range(24)}
|
||||
axes = {f"ax_{j}": int(rng.integers(1, 8)) for j in range(6)}
|
||||
responses.append(QSortResponse(agent_id=aid, placements=placements, likert_axes=axes))
|
||||
result = run_typology(responses, n_clusters=3)
|
||||
assert "clusters" in result
|
||||
assert len(result["clusters"]) == 3
|
||||
assert "pca" in result
|
||||
assert len(result["pca"]["components"]) >= 2
|
||||
Loading…
Reference in New Issue