#!/usr/bin/env python3
"""
Render diagnostic plots for the Negative Binomial Variant example.

Usage:
    Place dataset_nbinomial.csv in the current working directory (the CSV is
    read, and both PNG plots are written, relative to it), then run:
        python render_nbinomial_plots.py

Requires: pyinla, pandas, numpy, matplotlib
"""

from __future__ import annotations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyinla import pyinla

# ── Fit the model ────────────────────────────────────────────────────────────
# Binomial family with control_family variant=1. NOTE(review): this example
# presumably relies on the negative-binomial parameterisation in which Ntrials
# is the number of trials needed to observe y successes (so E[Ntrials] = y / p);
# confirm against the pyinla / R-INLA "binomial" family documentation.

df = pd.read_csv("dataset_nbinomial.csv")
Ntrials = df.loc[:, "Ntrials"].to_numpy()

# Intercept plus a single covariate z on the linear predictor.
model = dict(response="y", fixed=["1", "z"])

result = pyinla(
    data=df,
    model=model,
    family="binomial",
    control_family=dict(variant=1),
    Ntrials=Ntrials,
)

# ── Extract quantities ───────────────────────────────────────────────────────

# Posterior means of the fixed effects, used as plug-in point estimates.
intercept = float(result.summary_fixed.loc["(Intercept)", "mean"])
slope = float(result.summary_fixed.loc["z", "mean"])

# Linear predictor (logit scale) evaluated at the posterior means.
eta_hat = intercept + slope * df["z"].to_numpy()
# Inverse logit written so it cannot overflow: np.logaddexp(0, -eta) computes
# log(1 + exp(-eta)) stably, hence p = exp(-log(1 + exp(-eta))) = 1 / (1 + exp(-eta))
# stays in [0, 1] for every finite eta. The naive exp(eta) / (1 + exp(eta))
# evaluates to inf / inf = nan once eta exceeds ~709.
p_hat = np.exp(-np.logaddexp(0.0, -eta_hat))
# E[N] = y / p. This is a plug-in value at the posterior means, not the
# posterior mean of y / p.
expected_n = df["y"].to_numpy() / p_hat
# Observed minus expected total count.
residuals = Ntrials - expected_n

# ── Plot 1: Observed Ntrials vs expected ─────────────────────────────────────
# Observed Ntrials against the plug-in expectation y / p_hat, with a dashed
# identity reference line; points near the line indicate agreement.

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(
    expected_n,
    Ntrials,
    s=46,
    color="#a78bfa",
    edgecolor="white",
    linewidth=0.4,
    alpha=0.7,
)

# The identity line spans from the origin to the largest value on either axis.
max_val = max(expected_n.max(), Ntrials.max())
ax.plot(
    [0, max_val],
    [0, max_val],
    linestyle="--",
    linewidth=2,
    color="#f97316",
    label="y = x",
)

ax.set(
    xlabel="Expected total count (y / p)",
    ylabel="Observed total count (Ntrials)",
    title="Negative binomial variant: observed vs expected",
)
ax.legend(loc="best")
ax.grid(alpha=0.25)

fig.tight_layout()
fig.savefig("binomial-nbinomial-fit.png", dpi=160)
plt.close(fig)

# ── Plot 2: Residuals vs expected ────────────────────────────────────────────
# Residual = observed Ntrials minus the plug-in expectation; residuals scattered
# evenly around the dashed zero line suggest no systematic misfit.

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(
    expected_n,
    residuals,
    s=42,
    color="#38bdf8",
    edgecolor="white",
    linewidth=0.35,
)
ax.axhline(
    0.0,
    linestyle="--",
    linewidth=1.6,
    color="#f97316",
    label="zero residual",
)

ax.set(
    xlabel="Expected total count (y / p)",
    ylabel="Residual (Ntrials − expected)",
    title="Negative binomial variant: residuals vs expected",
)
ax.legend(loc="upper right")
ax.grid(alpha=0.25)

fig.tight_layout()
fig.savefig("binomial-nbinomial-residuals.png", dpi=160)
plt.close(fig)

# Console report: output files, recovered coefficients, full fixed-effect table.
print("Plots saved: binomial-nbinomial-fit.png, binomial-nbinomial-residuals.png")
print(f"Recovered: intercept = {intercept:.4f}, slope = {slope:.4f}")
print(result.summary_fixed)
