From 75762ccc186d67f2f37e4e31756ba97fd42d8535 Mon Sep 17 00:00:00 2001 From: Christian Moellmann Date: Sat, 23 May 2026 12:16:21 +0200 Subject: [PATCH] feat(interviews): diversity subagent with Q-sort + 6 Likert axes + PCA/k-means typology Co-Authored-By: Claude Opus 4.7 (1M context) --- backend/app/services/interviews/diversity.py | 133 ++++++++++++++++++ backend/scripts/instruments/diversity_v1.yaml | 36 +++++ backend/tests/interviews/test_diversity.py | 48 +++++++ 3 files changed, 217 insertions(+) create mode 100644 backend/app/services/interviews/diversity.py create mode 100644 backend/scripts/instruments/diversity_v1.yaml create mode 100644 backend/tests/interviews/test_diversity.py diff --git a/backend/app/services/interviews/diversity.py b/backend/app/services/interviews/diversity.py new file mode 100644 index 00000000..96febcf5 --- /dev/null +++ b/backend/app/services/interviews/diversity.py @@ -0,0 +1,133 @@ +from __future__ import annotations +import json +from pathlib import Path +from typing import Optional +import numpy as np +from sklearn.decomposition import PCA +from sklearn.cluster import KMeans +import yaml +from app.models.interview import QSortResponse +from app.services.interviews.base import StakeholderInterviewer, PersonaRecord +from app.services.interviews.instrument_loader import InstrumentValidationError + + +class DiversitySubagent: + def __init__(self, llm, memory, instrument_path: Path, language: str = "de"): + self.instrument = self._load(Path(instrument_path)) + self.interviewer = StakeholderInterviewer(llm=llm, memory=memory, language=language) + self.language = language + + def _load(self, path: Path) -> dict: + with path.open("r", encoding="utf-8") as f: + data = yaml.safe_load(f) + if not isinstance(data, dict) or "statements" not in data or "distribution" not in data: + raise InstrumentValidationError(f"invalid diversity instrument: {path}") + if sum(data["distribution"]) != len(data["statements"]): + raise InstrumentValidationError("distribution sum must equal number of statements") + return data + + def _schema_hint(self) -> str: + return json.dumps({ + "placements": {s["statement_id"]: "" for s in self.instrument["statements"]}, + "likert_axes": {a["axis_id"]: "" for a in self.instrument["likert_axes"]}, + }, ensure_ascii=False) + + def _user_prompt(self) -> str: + dist = self.instrument["distribution"] + buckets = list(range(-3, 4)) + bucket_desc = ", ".join(f"{b}:{n}" for b, n in zip(buckets, dist)) + lines = [ + ("Ordnen Sie jede Aussage genau einer Box von -3 (lehne stark ab) bis +3 (stimme stark zu) zu. " + f"Die Verteilung ist erzwungen: {bucket_desc}.") if self.language == "de" else + ("Place every statement into exactly one box from -3 (strongly disagree) to +3 (strongly agree). " + f"The distribution is forced: {bucket_desc}."), + "", + "Statements:", + ] + for s in self.instrument["statements"]: + txt = s["de"] if self.language == "de" else s["en"] + lines.append(f"- [{s['statement_id']}] {txt}") + lines += ["", "Then rate each axis from 1 to 7:"] + for a in self.instrument["likert_axes"]: + txt = a["de"] if self.language == "de" else a["en"] + lines.append(f"- [{a['axis_id']}] {txt}") + return "\n".join(lines) + + def _validator(self, raw: dict) -> Optional[dict]: + if not isinstance(raw, dict): + return None + placements = raw.get("placements", {}) + axes = raw.get("likert_axes", {}) + statements = {s["statement_id"] for s in self.instrument["statements"]} + if set(placements.keys()) != statements: + return None + dist = self.instrument["distribution"] + target = {b: n for b, n in zip(range(-3, 4), dist)} + got: dict[int, int] = {} + for v in placements.values(): + if not isinstance(v, int) or not -3 <= v <= 3: + return None + got[v] = got.get(v, 0) + 1 + if got != target: + return None + for a in self.instrument["likert_axes"]: + v = axes.get(a["axis_id"]) + if not isinstance(v, int) or not 1 <= v <= 7: + return None + return raw + + def administer(self, persona: PersonaRecord) -> QSortResponse: + raw = self.interviewer.ask_in_character( + persona, + user_prompt=self._user_prompt(), + schema_hint=self._schema_hint(), + validate=self._validator, + ) + return QSortResponse( + agent_id=persona.agent_id, + placements={k: int(v) for k, v in raw["placements"].items()}, + likert_axes={k: int(v) for k, v in raw["likert_axes"].items()}, + ) + + +def _vectorize(r: QSortResponse, statements: list[str], axes: list[str]) -> np.ndarray: + return np.array( + [r.placements.get(s, 0) for s in statements] + + [r.likert_axes.get(a, 4) for a in axes], + dtype=float, + ) + + +def run_typology(responses: list[QSortResponse], n_clusters: int = 4) -> dict: + if not responses: + return {"n": 0, "clusters": [], "pca": {"components": [], "explained_variance": []}} + statements = sorted({k for r in responses for k in r.placements}) + axes = sorted({k for r in responses for k in r.likert_axes}) + X = np.vstack([_vectorize(r, statements, axes) for r in responses]) + n_clusters = min(n_clusters, len(responses)) + pca = PCA(n_components=min(5, X.shape[1], X.shape[0])) + pcs = pca.fit_transform(X) + km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0) + labels = km.fit_predict(X) + clusters = [] + for c in range(n_clusters): + members = [responses[i].agent_id for i in range(len(responses)) if labels[i] == c] + centroid = km.cluster_centers_[c] + clusters.append({ + "cluster_id": int(c), + "n": len(members), + "agent_ids": members, + "top_loadings": { + statements[i] if i < len(statements) else axes[i - len(statements)]: float(centroid[i]) + for i in np.argsort(np.abs(centroid))[::-1][:8].tolist() + }, + }) + return { + "n": len(responses), + "clusters": clusters, + "pca": { + "components": pcs.tolist(), + "explained_variance": pca.explained_variance_ratio_.tolist(), + "agent_ids": [r.agent_id for r in responses], + }, + } diff --git a/backend/scripts/instruments/diversity_v1.yaml b/backend/scripts/instruments/diversity_v1.yaml new file mode 100644 index 00000000..7c47cd96 --- /dev/null +++ b/backend/scripts/instruments/diversity_v1.yaml @@ -0,0 +1,36 @@ +name: diversity_v1 +version: "1.0" +language_default: de +distribution: [2, 3, 4, 6, 4, 3, 2] # buckets from -3 to +3, total 24 +statements: + - {statement_id: st_01, de: "Die Ostsee gehört den Fischern, die hier seit Generationen leben.", en: "The Baltic belongs to fishers who have lived here for generations."} + - {statement_id: st_02, de: "MSC-Zertifizierung schützt vor allem große Konzerne.", en: "MSC certification mainly protects large corporations."} + - {statement_id: st_03, de: "Wissenschaftliche Quoten sind die einzige Grundlage für Politik.", en: "Scientific quotas are the only legitimate basis for policy."} + - {statement_id: st_04, de: "Aquakultur kann Ostseefischerei ersetzen.", en: "Aquaculture can replace Baltic fisheries."} + - {statement_id: st_05, de: "Sportfischer schaden den Beständen mehr als die Berufsfischer.", en: "Recreational anglers harm stocks more than commercial fishers."} + - {statement_id: st_06, de: "Die EU-Fischereipolitik kennt die Ostsee nicht.", en: "EU fisheries policy doesn't understand the Baltic."} + - {statement_id: st_07, de: "Großtechnische Fischerei ist effizienter und damit nachhaltiger.", en: "Industrial fisheries are more efficient and therefore more sustainable."} + - {statement_id: st_08, de: "Wer Fisch isst, sollte mehr dafür bezahlen.", en: "Those who eat fish should pay more for it."} + - {statement_id: st_09, de: "Die Kleinfischerei muss subventioniert werden.", en: "Small-scale fisheries must be subsidised."} + - {statement_id: st_10, de: "Marine Schutzgebiete sind reine Symbolpolitik.", en: "Marine protected areas are mere symbolism."} + - {statement_id: st_11, de: "Russlands Krieg ändert alles in der Ostsee.", en: "Russia's war changes everything in the Baltic."} + - {statement_id: st_12, de: "Nur drastische Reduktion der Fangmengen rettet die Bestände.", en: "Only drastic catch reductions will save the stocks."} + - {statement_id: st_13, de: "NGOs übertreiben die Krise systematisch.", en: "NGOs systematically exaggerate the crisis."} + - {statement_id: st_14, de: "Klimawandel ist das eigentliche Problem, nicht die Fischerei.", en: "Climate change is the real problem, not fisheries."} + - {statement_id: st_15, de: "Tradition zählt mehr als kurzfristige Bestandszahlen.", en: "Tradition matters more than short-term stock numbers."} + - {statement_id: st_16, de: "Verbraucher entscheiden über die Zukunft des Fisches.", en: "Consumers decide the future of fish."} + - {statement_id: st_17, de: "Ohne Generalstreik der Fischer ändert sich nichts.", en: "Without a fishers' general strike, nothing will change."} + - {statement_id: st_18, de: "Die Bundesregierung sollte Kutter aufkaufen und stilllegen.", en: "The federal government should buy out and decommission boats."} + - {statement_id: st_19, de: "Die Dorschkrise ist Folge gescheiterter Politik.", en: "The cod crisis is the result of policy failure."} + - {statement_id: st_20, de: "Ostsee-Aquakultur ist ökologisch problematisch.", en: "Baltic aquaculture is ecologically problematic."} + - {statement_id: st_21, de: "Junge Menschen werden keinen Fischereibetrieb mehr übernehmen.", en: "Young people will no longer take over fishing businesses."} + - {statement_id: st_22, de: "Markt regelt sich selbst, auch beim Fisch.", en: "The market regulates itself, also for fish."} + - {statement_id: st_23, de: "Lokale Genossenschaften sind die Lösung.", en: "Local cooperatives are the solution."} + - {statement_id: st_24, de: "In 20 Jahren gibt es keine deutsche Ostseefischerei mehr.", en: "In 20 years there will be no German Baltic fisheries left."} +likert_axes: + - {axis_id: ax_pres_extr, scale: 7, de: "Bewahrung (1) vs. Nutzung (7)", en: "Preservation (1) vs. Extraction (7)"} + - {axis_id: ax_loc_eu, scale: 7, de: "Lokal (1) vs. EU-zentral (7)", en: "Local (1) vs. EU-central (7)"} + - {axis_id: ax_sci_trad, scale: 7, de: "Wissenschaft (1) vs. Tradition (7)", en: "Science-led (1) vs. Tradition-led (7)"} + - {axis_id: ax_ind_col, scale: 7, de: "Individuum (1) vs. Kollektiv (7)", en: "Individual (1) vs. Collective (7)"} + - {axis_id: ax_short_long,scale: 7, de: "Kurzfristig (1) vs. Langfristig (7)", en: "Short-term (1) vs. Long-term (7)"} + - {axis_id: ax_mkt_reg, scale: 7, de: "Markt (1) vs. Regulierung (7)", en: "Market (1) vs. Regulation (7)"} diff --git a/backend/tests/interviews/test_diversity.py b/backend/tests/interviews/test_diversity.py new file mode 100644 index 00000000..7650fac2 --- /dev/null +++ b/backend/tests/interviews/test_diversity.py @@ -0,0 +1,48 @@ +from pathlib import Path +import numpy as np +from app.services.interviews.base import PersonaRecord, MemoryDigest +from app.services.interviews.diversity import ( + DiversitySubagent, run_typology, +) + +class _Mem: + def get_digest(self, agent_id, max_chars=2000): + return MemoryDigest(text="x", available=True) + +class _CannedLLM: + def chat_json(self, messages, temperature=0.0, max_tokens=None, **kw): + # Place all 24 statements into legal buckets per the forced distribution + placements = {} + buckets = [-3]*2 + [-2]*3 + [-1]*4 + [0]*6 + [1]*4 + [2]*3 + [3]*2 + for i in range(24): + placements[f"st_{i+1:02d}"] = buckets[i] + return { + "placements": placements, + "likert_axes": {"ax_pres_extr": 5, "ax_loc_eu": 3, "ax_sci_trad": 4, + "ax_ind_col": 4, "ax_short_long": 5, "ax_mkt_reg": 3}, + } + +INSTRUMENT = Path(__file__).resolve().parents[2] / "scripts" / "instruments" / "diversity_v1.yaml" + +def test_diversity_administer(): + sub = DiversitySubagent(llm=_CannedLLM(), memory=_Mem(), instrument_path=INSTRUMENT) + persona = PersonaRecord(agent_id=1, name="A", persona="p") + resp = sub.administer(persona) + assert len(resp.placements) == 24 + assert set(resp.likert_axes.keys()) == { + "ax_pres_extr","ax_loc_eu","ax_sci_trad","ax_ind_col","ax_short_long","ax_mkt_reg" + } + +def test_typology_runs_pca_kmeans(): + from app.models.interview import QSortResponse + rng = np.random.default_rng(42) + responses = [] + for aid in range(20): + placements = {f"st_{i+1:02d}": int(rng.integers(-3, 4)) for i in range(24)} + axes = {f"ax_{j}": int(rng.integers(1, 8)) for j in range(6)} + responses.append(QSortResponse(agent_id=aid, placements=placements, likert_axes=axes)) + result = run_typology(responses, n_clusters=3) + assert "clusters" in result + assert len(result["clusters"]) == 3 + assert "pca" in result + assert len(result["pca"]["components"]) >= 2