"""
verify_results.py - public verification script for the Yardstick grounded-pipeline audit.

Recomputes every headline number in WHITEPAPER.pdf from LEDGER.json in this folder
and prints each computed value next to the value the paper states, with PASS or FAIL.

Requirements: Python 3.9+, standard library only. No network access, no model calls.
Usage:       python verify_results.py
Exit code:   0 if every check passes, 1 otherwise.

The ledger has one row per (vendor x model-combination). Column definitions are in
DATA-DICTIONARY.md. Vendors are anonymized as cohort-g<gold>; model names are real.
"""
import json, math, pathlib, random, statistics, sys

# On Windows, switch stdout to UTF-8 and replace unencodable characters
# instead of raising while printing.
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")

# Folder containing this script; LEDGER.json is read from the same folder.
HERE = pathlib.Path(__file__).parent
# Parsed ledger. The context manager closes the file handle as soon as the
# JSON has been read instead of leaving it to the garbage collector.
with open(HERE / "LEDGER.json", encoding="utf-8") as _ledger_file:
    L = json.load(_ledger_file)

# Labels of every failed check, appended to as the script runs.
FAILURES = []

def short(m):
    """Return the part of *m* after its last "/" (all of *m* if it has none)."""
    _, _, tail = m.rpartition("/")
    return tail

def check(label, got, want, tol):
    """Compare a computed value with the paper's value and print the outcome.

    The check passes when ``abs(got - want) <= tol``. On failure *label* is
    appended to FAILURES. One PASS/FAIL line is printed either way.
    """
    passed = abs(got - want) <= tol
    verdict = "PASS" if passed else "FAIL"
    if not passed:
        FAILURES.append(label)
    print(f"  [{verdict}] {label}: computed {got:.4f}  paper {want}  (tol {tol})")

# ---- group rows by (01 / 02 / 03) configuration -----------------------------
# Key: the three short model names joined by " / ". Value: the ledger rows
# that share that combination, in ledger order.
configs = {}
for entry in L:
    steps = (short(entry["model_01"]), short(entry["model_02"]), short(entry["model_03"]))
    configs.setdefault(" / ".join(steps), []).append(entry)

def find(m1, m2, m3):
    """Return the ledger rows of configuration (m1, m2, m3), or None if absent.

    The three short model names must match exactly. Substring matching is
    unsafe: 'gpt-5' is a prefix of 'gpt-5-nano' and 'gpt-5-mini'.
    """
    wanted = (m1, m2, m3)
    matches = (
        rows
        for key, rows in configs.items()
        if tuple(part.strip() for part in key.split(" / ")) == wanted
    )
    return next(matches, None)

# Short name of the Mistral model; passed as a model argument to find() in the
# configuration lookups below.
MS = "mistral-small-3.2-24b-instruct"

def stat(rows):
    """Aggregate the ledger rows of one configuration into summary figures.

    Returns a dict with:
      v      number of rows
      c      total claims
      h      total hallucinations
      m      total mistags
      rate   hallucinations per 100 claims
      mrate  mistags per 100 claims
      rerun  fraction of rows with at least one hallucination
      base   mean cost_usd per row
      sdelta median of the non-None score_delta values (NaN if there are none)

    A ratio whose denominator is zero (no claims, or no rows) is returned as
    NaN rather than raising, so a degenerate configuration shows up as a
    failed check instead of aborting the whole run.
    """
    nan = float("nan")
    v = len(rows)
    c = sum(x["claims"] for x in rows)
    h = sum(x["hallucinations"] for x in rows)
    m = sum(x["mistags"] for x in rows)
    redo = sum(1 for x in rows if x["hallucinations"] > 0)
    # statistics.mean raises on empty input, so guard the no-rows case.
    base = statistics.mean(x["cost_usd"] for x in rows) if v else nan
    deltas = [x["score_delta"] for x in rows if x["score_delta"] is not None]
    med = statistics.median(deltas) if deltas else nan
    return dict(v=v, c=c, h=h, m=m,
                rate=100 * h / c if c else nan,
                mrate=100 * m / c if c else nan,
                rerun=redo / v if v else nan,
                base=base, sdelta=med)

# Modeled cost of redoing a hallucinated dossier on the flagship (paper 3.6).
OVERWRITE = 0.49

def eff(s):
    """Return the effective cost: base cost plus the rerun share times OVERWRITE."""
    redo_cost = s["rerun"] * OVERWRITE
    return s["base"] + redo_cost

# ---- 0. ledger shape ---------------------------------------------------------
print("=== 0. LEDGER SHAPE ===")
# Distinct anonymized vendor identifiers across all ledger rows.
vendors = set()
for entry in L:
    vendors.add(entry["vendor_anon"])
n_rows, n_configs, n_vendors = len(L), len(configs), len(vendors)
print(f"  rows={n_rows} (paper: 454)   configurations={n_configs} (paper: 31)   vendors={n_vendors} (paper: 30)")
# Each count must equal the paper's value exactly (tolerance 0).
for shape_label, observed, published_count in (
    ("ledger rows", n_rows, 454),
    ("distinct configurations", n_configs, 31),
    ("distinct vendors", n_vendors, 30),
):
    check(shape_label, observed, published_count, 0)

# ---- 1. headline configurations (paper 5.1 / 5.3) ----------------------------
print("\n=== 1. HEADLINE CONFIGURATIONS (paper 5.1 / 5.3) ===")
ctrl = find("claude-sonnet-4-6", "claude-opus-4-8", "claude-opus-4-8")
win  = find("gpt-5-mini", "grok-4.3", MS)
qwen = find("gpt-5-mini", "qwen3-235b-a22b-2507", MS)
mist = find("gpt-5-mini", MS, MS)

# (display name, ledger rows, paper claims, paper hallucinations, paper rate %)
for name, rows, claims_w, h_w, rate_w in [
    ("CONTROL  sonnet/opus/opus",         ctrl, 1376, 6, 0.44),
    ("WINNER   gpt5mini/grok/mistral",    win,   804, 3, 0.37),
    ("ALT-mist gpt5mini/mistral/mistral", mist,  846, 6, 0.71),
    ("ALT-qwen gpt5mini/qwen/mistral",    qwen,  759, 9, 1.19),
]:
    if rows is None:
        FAILURES.append(name)
        print(f"  [FAIL] {name}: configuration not found")
        continue
    s = stat(rows)
    print(f"\n {name}: vendors={s['v']} claims={s['c']} halluc={s['h']} mistags={s['m']}"
          f" base=${s['base']:.4f} eff=${eff(s):.4f} median-score-delta={s['sdelta']}")
    # Prefix every label with the configuration tag ("control", "winner", ...)
    # so an entry in the final failure summary says which configuration failed.
    tag = name.split()[0].lower()
    check(f"{tag} claims", s["c"], claims_w, 0)
    check(f"{tag} hallucinations", s["h"], h_w, 0)
    check(f"{tag} hallucination rate %", s["rate"], rate_w, 0.02)

# The remaining comparisons need both the control and the winner rows. If
# either is missing (already recorded in FAILURES by the loop above), report
# and stop here instead of crashing with a TypeError in stat(None).
if ctrl is None or win is None:
    print("\n" + "=" * 60)
    print(f"RESULT: {len(FAILURES)} CHECK(S) FAILED:")
    for failed in FAILURES:
        print(f"  - {failed}")
    sys.exit(1)

sc, sw = stat(ctrl), stat(win)
check("control mistags", sc["m"], 131, 0)
check("winner mistags", sw["m"], 66, 0)
check("control mistag rate %", sc["mrate"], 9.5, 0.1)
check("winner mistag rate %", sw["mrate"], 8.2, 0.1)
check("control median score delta", sc["sdelta"], -4.5, 0.01)
check("winner median score delta", sw["sdelta"], -3.625, 0.01)

# ---- 2. costs and ratios (paper 5.5) -----------------------------------------
print("\n=== 2. COSTS AND RATIOS (paper 5.5) ===")
# (label, computed value, paper value, tolerance) for the per-vendor costs
for cost_label, cost_value, cost_paper, cost_tol in (
    ("control base $/vendor", sc["base"], 0.4263, 0.002),
    ("winner base $/vendor", sw["base"], 0.0373, 0.002),
    ("control effective $/vendor", eff(sc), 0.5080, 0.003),
    ("winner effective $/vendor", eff(sw), 0.0863, 0.003),
):
    check(cost_label, cost_value, cost_paper, cost_tol)
# How many times cheaper the winner is than the control, before and after
# the modeled rerun cost.
base_ratio = sc["base"] / sw["base"]
check("base ratio (x cheaper)", base_ratio, 11.4, 0.4)
effective_ratio = eff(sc) / eff(sw)
check("effective ratio (x cheaper)", effective_ratio, 5.9, 0.3)
# Informational only: total cost_usd over rows whose arm is "new".
new_spend = sum(entry["cost_usd"] for entry in L if entry["arm"] == "new")
spend_note = "(paper: $14.73 full run + $4.6 extension; control + retrieval are subscription-marginal)"
print(f"  [INFO] summed new-arm cost_usd = ${new_spend:.2f} " + spend_note)

# ---- 3. containment (paper 5.2 / 6.1) -----------------------------------------
print("\n=== 3. CONTAINMENT BEHIND THE GATE (paper 5.2 / 6.1) ===")

def _containment(rows):
    """Return (gate-blocked rows, summed post-gate escapes, hallucinations in unblocked rows)."""
    blocked = sum(1 for r in rows if r["gate_blocked"])
    escapes = sum(r["postgate_escapes"] for r in rows)
    published = sum(r["hallucinations"] for r in rows if not r["gate_blocked"])
    return blocked, escapes, published

win_blocked, win_escapes, win_h_pub = _containment(win)
ctl_blocked, ctl_escapes, ctl_h_pub = _containment(ctrl)
print(f"  winner : {win_blocked}/30 dossiers gate-blocked, {win_escapes} escapes would publish,"
      f" {win_h_pub} hallucinations would publish")
print(f"  control: {ctl_blocked}/30 dossiers gate-blocked, {ctl_escapes} escapes would publish,"
      f" {ctl_h_pub} hallucination(s) would publish")
# (label, computed count, paper count); counts must match exactly.
for gate_label, gate_value, gate_paper in (
    ("winner dossiers blocked", win_blocked, 30),
    ("winner published escapes", win_escapes, 0),
    ("winner published hallucinations", win_h_pub, 0),
    ("control dossiers blocked", ctl_blocked, 22),
    ("control published escapes", ctl_escapes, 26),
    ("control published hallucinations", ctl_h_pub, 1),
):
    check(gate_label, gate_value, gate_paper, 0)

# ---- 4. step-02 aggregates + reliability, full-pool side (paper 5.2 / 5.4) ----
print("\n=== 4. STEP-02 AGGREGATES + RELIABILITY, FULL-POOL SIDE (paper 5.2 / 5.4) ===")
# Ledger rows grouped by the short name of their step-02 model.
by02 = {}
for r in L:
    by02.setdefault(short(r["model_02"]), []).append(r)
grok_all = by02.get("grok-4.3", [])
gc = sum(x["claims"] for x in grok_all)
gh = sum(x["hallucinations"] for x in grok_all)
print(f"  grok-4.3 all rolls: {gh}/{gc} claims")
if gc:
    check("grok-4.3 all-rolls rate %", 100 * gh / gc, 0.64, 0.02)
else:
    # No grok-4.3 claims at all: record a failed check rather than dividing by zero.
    FAILURES.append("grok-4.3 all-rolls rate %")
    print("  [FAIL] grok-4.3 all-rolls rate %: no grok-4.3 claims in the ledger")

# full-pool single-pass rates per confirmed configuration (panel-15 values are
# historical, recorded before the confirmation extension; see paper 9.4)
for label, m1, m2, m3, want in [
    ("grok-4.3 (winner)",  "gpt-5-mini", "grok-4.3", MS, 0.37),
    ("opus (control)",     "claude-sonnet-4-6", "claude-opus-4-8", "claude-opus-4-8", 0.44),
    ("deepseek-v3.1",      "gpt-5-mini", "deepseek-chat-v3.1", "o4-mini", 0.56),
    ("gemini-2.5-pro",     "gpt-5-mini", "gemini-2.5-pro", MS, 0.63),
    ("gpt-5",              "gpt-5-mini", "gpt-5", MS, 0.66),
    ("mistral-small",      "gpt-5-mini", MS, MS, 0.71),
    ("qwen3-235b",         "gpt-5-mini", "qwen3-235b-a22b-2507", MS, 1.19),
    ("qwen3-32b",          "gpt-5-mini", "qwen3-32b", MS, 1.27),
]:
    rows = find(m1, m2, m3)
    if rows is None:
        FAILURES.append(label)
        print(f"  [FAIL] {label}: configuration not found")
        continue
    s = stat(rows)
    check(f"{label} full-pool rate %", s["rate"], want, 0.02)

# ---- 4b. every published combination row vs the ledger ------------------------
print("\n=== 4b. ALL 31 COMBINATION ROWS (results-per-combination.csv vs ledger) ===")
import csv as _csv
csv_path = HERE / "results-per-combination.csv"
mismatches = 0
published = 0  # number of data rows read from the CSV
# newline="" is how the csv module documentation says to open files for reading.
with open(csv_path, encoding="utf-8", newline="") as f:
    for row in _csv.DictReader(f):
        published += 1
        # Full three-step name so each failure entry identifies one combination;
        # several combinations can share the same step_02 model.
        combo = f"{row['step_01']}/{row['step_02']}/{row['step_03']}"
        rows = find(row["step_01"].strip(), row["step_02"].strip(), row["step_03"].strip())
        if rows is None:
            FAILURES.append(f"combination missing: {combo}")
            mismatches += 1
            continue
        s = stat(rows)
        ok = (s["v"] == int(row["vendors"]) and s["c"] == int(row["claims"])
              and s["h"] == int(row["hallucinations"])
              and abs(s["rate"] - float(row["hallucination_rate_pct"])) <= 0.02)
        if not ok:
            FAILURES.append(f"combination mismatch: {combo}")
            mismatches += 1
            print(f"  [FAIL] {row['step_01']} / {row['step_02']} / {row['step_03']}:"
                  f" ledger {s['v']}v {s['c']}cl {s['h']}h {s['rate']:.2f}% vs csv row")
# Without this count check an empty or truncated CSV would pass vacuously:
# zero rows read means zero mismatches.
check("published combination rows", published, 31, 0)
all_rows_ok = mismatches == 0 and published == 31
print(f"  {'[PASS]' if all_rows_ok else '[FAIL]'} all 31 published combination rows"
      f" match the ledger ({mismatches} mismatches)")

# ---- 5. 95% vendor-level cluster bootstrap (paper 5.3, 20k resamples) ---------
print("\n=== 5. 95% VENDOR-LEVEL CLUSTER-BOOTSTRAP INTERVALS (paper 5.3) ===")
def boot_ci(rows, n=20000, seed=20260608):
    """Return a (low, high) bootstrap interval for the hallucination rate in %.

    Each of the *n* resamples draws len(rows) rows with replacement and
    computes 100 * hallucinations / claims over the draw (0.0 if the draw has
    no claims). The result is the pair of values at positions int(0.025 * n)
    and int(0.975 * n) of the sorted resampled rates.

    A private ``random.Random(seed)`` is used, so the draws are the same as
    after ``random.seed(seed)`` but the module-level random state shared with
    the rest of the process is left untouched.
    """
    rng = random.Random(seed)
    per = [(r["claims"], r["hallucinations"]) for r in rows]
    k = len(per)
    out = []
    for _ in range(n):
        s = [per[rng.randrange(k)] for _ in range(k)]
        c = sum(a for a, _ in s)
        h = sum(b for _, b in s)
        out.append(100 * h / c if c else 0.0)
    out.sort()
    return out[int(0.025 * n)], out[int(0.975 * n)]

# (label, ledger rows, paper lower bound, paper upper bound)
for label, rows, lo_w, hi_w in [
    ("winner",  win,  0.00, 0.83),
    ("control", ctrl, 0.14, 0.85),
    ("alt-mistral", mist, 0.18, 1.38),
    ("alt-qwen",    qwen, 0.27, 2.20),
]:
    if rows is None:
        # Configuration absent from the ledger: record a failed check, as the
        # other sections do, instead of crashing inside boot_ci(None).
        FAILURES.append(f"bootstrap CI {label}")
        print(f"  [FAIL] {label}: configuration not found")
        continue
    lo, hi = boot_ci(rows)
    # Separate tolerances for the lower and the upper bound.
    okl = abs(lo - lo_w) <= 0.12; okh = abs(hi - hi_w) <= 0.20
    if not (okl and okh):
        FAILURES.append(f"bootstrap CI {label}")
    print(f"  [{'PASS' if (okl and okh) else 'FAIL'}] {label}: computed [{lo:.2f}, {hi:.2f}]"
          f"   paper [{lo_w:.2f}, {hi_w:.2f}]  (resampling noise tolerance 0.12/0.20)")

# ---- 6. Fisher exact, winner vs control (paper 5.3) ----------------------------
print("\n=== 6. FISHER EXACT TEST, WINNER VS CONTROL (paper 5.3) ===")
def fisher_two_sided(a, b, c, d):
    """Two-sided Fisher exact p-value for the 2x2 table [[a, b], [c, d]].

    Sums the hypergeometric probabilities of every outcome that is no more
    likely than the observed one (standard two-sided Fisher). A relative
    slack of 1e-9 keeps outcomes tied with the observed one in the sum.
    """
    row1 = a + b
    col1 = a + c
    n = a + b + c + d
    denom = math.comb(n, row1)

    def pmf(k):
        return (math.comb(col1, k) * math.comb(n - col1, row1 - k)) / denom

    cutoff = pmf(a) * (1 + 1e-9)
    k_lo = max(0, row1 + col1 - n)
    k_hi = min(row1, col1)
    probs = [pmf(k) for k in range(k_lo, k_hi + 1)]
    return sum(q for q in probs if q <= cutoff)

# 2x2 table: one row per configuration, columns are hallucinated / other claims.
winner_row = (sw["h"], sw["c"] - sw["h"])
control_row = (sc["h"], sc["c"] - sc["h"])
p = fisher_two_sided(*winner_row, *control_row)
check("Fisher p (winner vs control)", p, 1.000, 0.005)
print("  parity: the difference is not detectable at this sample size,"
      " the desired outcome for an equivalence claim")

# ---- 7. 99% Clopper-Pearson upper bound, winner (abstract / paper 7) -----------
print("\n=== 7. WINNER 99% CLOPPER-PEARSON UPPER BOUND (abstract / paper 7) ===")
def binom_cdf(k, n, p):
    """Return P(X <= k) for X ~ Binomial(n, p).

    Terms are computed in log space to survive n=804. The degenerate
    probabilities p <= 0 and p >= 1 are answered directly, because log(0)
    would otherwise raise a math domain error.
    """
    if p <= 0.0:
        # All mass sits on X = 0.
        return 1.0 if k >= 0 else 0.0
    if p >= 1.0:
        # All mass sits on X = n.
        return 1.0 if k >= n else 0.0
    # Loop-invariant pieces of each log-term.
    log_n_fact = math.lgamma(n + 1)
    log_p = math.log(p)
    log_q = math.log1p(-p)
    total = 0.0
    for i in range(k + 1):
        total += math.exp(log_n_fact - math.lgamma(i + 1) - math.lgamma(n - i + 1)
                          + i * log_p + (n - i) * log_q)
    return total

def cp_upper(x, n, conf=0.99):
    """Return the upper limit p solving binom_cdf(x, n, p) = 1 - conf.

    Found by 100 bisection steps on [x / n, 1]. Returns 1.0 when x == n.
    """
    lo, hi = x / n, 1.0
    for _ in range(100):
        mid = (lo + hi) / 2
        # The CDF falls as p grows: still above the target means p is too small.
        if binom_cdf(x, n, mid) > 1 - conf:
            lo = mid
        else:
            hi = mid
    return hi

# Upper limit for the winner's hallucination rate, converted to percent.
winner_upper = cp_upper(sw["h"], sw["c"])
cp = 100 * winner_upper
check("winner 99% CP upper bound %", cp, 1.24, 0.02)
print("  the paper reports this limit instead of claiming the 0.5% bound;"
      " the stated 37-or-more-vendor requirement is a best-case floor (paper 7)")

# ---- 8. power statement (paper 3.9) --------------------------------------------
print("\n=== 8. POWER STATEMENT (paper 3.9) ===")
# Two proportions to tell apart (0.44% vs 0.22%) and the two z constants
# used in the sample-size formula below.
p1, p2 = 0.0044, 0.0022
za, zb = 1.959964, 0.841621
pbar = (p1 + p2) / 2
pooled_term = za * math.sqrt(2 * pbar * (1 - pbar))
unpooled_term = zb * math.sqrt(p1 * (1 - p1) + p2 * (1 - p2))
n_arm = ((pooled_term + unpooled_term) ** 2) / ((p1 - p2) ** 2)
check("claims/arm to resolve 0.44% vs 0.22%", n_arm, 10667, 60)

# ---- verdict -------------------------------------------------------------------
print("\n" + "=" * 60)
if not FAILURES:
    print("RESULT: ALL CHECKS PASS - every recomputed value matches the paper.")
else:
    # List every failed label, then exit with a non-zero status.
    print(f"RESULT: {len(FAILURES)} CHECK(S) FAILED:")
    print("\n".join(f"  - {failed}" for failed in FAILURES))
    sys.exit(1)
