#!/usr/bin/env python3
"""
Render diagnostic plots for the Weibull Survival Analysis example.

Usage:
    Place dataset_weibull_surv_v0.csv in the same directory, then run:
        python render_weibullsurv_plots.py

Requires: pyinla, pandas, numpy, matplotlib
"""

from __future__ import annotations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyinla import pyinla, inla_surv

# ── Fit the model ────────────────────────────────────────────────────────────
# Weibull survival regression (family "weibullsurv", variant 0) with an
# intercept and one covariate ``x``.  DIC, CPO and the marginal likelihood are
# requested, and posterior marginals are returned so they can be plotted.

df = pd.read_csv("dataset_weibull_surv_v0.csv")

# Survival response built from the follow-up time column and the event column.
y_surv = inla_surv(
    time=df["y"].to_numpy(),
    event=df["event"].to_numpy(),
)

compute_opts = dict(dic=True, cpo=True, mlik=True, return_marginals=True)

model = dict(response=y_surv, fixed=["1", "x"])
control = dict(family=dict(variant=0), compute=compute_opts)

result = pyinla(model=model, family="weibullsurv", data=df, control=control)

# ── Extract quantities ───────────────────────────────────────────────────────
# Plug-in point estimates: posterior means of the fixed effects and of the
# first hyperparameter row (assumed to be the Weibull shape alpha).

fixed_means = result.summary_fixed["mean"]
intercept = float(fixed_means.loc["(Intercept)"])
slope = float(fixed_means.loc["x"])
alpha = float(result.summary_hyperpar["mean"].iloc[0])

# Per-subject linear predictor and rate, then the plug-in median survival
# time.  With the variant-0 survival function S(t) = exp(-lam * t**alpha),
# solving S(t) = 1/2 gives t = (log 2 / lam) ** (1 / alpha).
eta_hat = intercept + slope * df["x"].to_numpy()
lam_hat = np.exp(eta_hat)
median_hat = np.power(np.log(2) / lam_hat, 1.0 / alpha)  # Weibull v0 median
residuals = df["y"].to_numpy() - median_hat

# Posterior marginal of alpha; the first two columns are renamed positionally
# to "x" (evaluation grid) and "y" (density value) for plotting.
density = result.marginals_hyperpar["alpha parameter for weibullsurv"]
grid_col, dens_col = density.columns[:2]
density = density.rename(columns={grid_col: "x", dens_col: "y"})

# ── Plot 1: Observed time vs fitted median ─────────────────────────────────
# Scatter of observed time against the plug-in median, with the y = x line.
# A censored time is only a lower bound on the true survival time, so censored
# rows naturally fall below the diagonal.  They are drawn with a separate hollow
# marker so they are not misread as over-prediction.
# NOTE(review): assumes the inla.surv convention that event == 1 marks an
# observed failure and any other code marks a censored row; confirm for the
# dataset.

observed_mask = df["event"].to_numpy() == 1
y_obs = df["y"].to_numpy()

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(median_hat[observed_mask], y_obs[observed_mask], color="#38bdf8",
           s=46, edgecolor="white", linewidth=0.4, alpha=0.7,
           label="observed events")
if (~observed_mask).any():
    # Hollow markers: the plotted time understates the true survival time.
    ax.scatter(median_hat[~observed_mask], y_obs[~observed_mask],
               facecolors="none", edgecolors="#94a3b8", s=46, linewidth=0.9,
               alpha=0.8, label="censored (lower bound)")
max_val = max(median_hat.max(), y_obs.max())
ax.plot([0, max_val], [0, max_val], color="#f97316", linewidth=2.0,
        linestyle="--", label="y = x (perfect fit)")
ax.set_xlabel("Fitted median survival time")
ax.set_ylabel("Observed time (y)")
ax.set_title("Weibull survival: observed vs fitted")
ax.legend(loc="best")
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("weibull-survival-fit.png", dpi=160)
plt.close(fig)

# ── Plot 2: Residuals vs fitted median ─────────────────────────────────────
# Raw residuals (y − plug-in median) against the fitted median.  For censored
# rows, y is only a lower bound, so their residuals are biased downward.  They
# are drawn with a separate hollow marker and should not be read as genuine
# misfit.
# NOTE(review): assumes event == 1 marks an observed failure (inla.surv
# convention); any other code is treated as censored here.

observed_mask = df["event"].to_numpy() == 1

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(median_hat[observed_mask], residuals[observed_mask],
           color="#a78bfa", s=42, edgecolor="white", linewidth=0.35,
           label="observed events")
if (~observed_mask).any():
    ax.scatter(median_hat[~observed_mask], residuals[~observed_mask],
               facecolors="none", edgecolors="#94a3b8", s=42, linewidth=0.9,
               label="censored (biased low)")
ax.axhline(0.0, color="#f97316", linewidth=1.6, linestyle="--",
           label="zero residual")
ax.set_xlabel("Fitted median survival time")
ax.set_ylabel("Residual (y \u2212 median)")
ax.set_title("Weibull survival: residuals")
ax.legend(loc="best")
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("weibull-survival-residuals.png", dpi=160)
plt.close(fig)

# ── Plot 3: Posterior density for alpha ────────────────────────────────────
# Filled curve of the posterior marginal density of the Weibull shape alpha,
# taken from the "x" (grid) and "y" (density) columns of ``density``.

density_color = "#f97316"
alpha_grid, alpha_dens = density["x"], density["y"]

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.plot(alpha_grid, alpha_dens, color=density_color, linewidth=2.0)
ax.fill_between(alpha_grid, alpha_dens, color=density_color, alpha=0.25)
ax.set(
    xlabel="Shape parameter (\u03b1)",
    ylabel="Density",
    title="Weibull survival: posterior density for \u03b1",
)
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("weibull-survival-alpha.png", dpi=160)
plt.close(fig)

# Echo the written files, the plug-in estimates and the full fixed-effect table.
report_lines = (
    "Plots saved: weibull-survival-fit.png, weibull-survival-residuals.png, weibull-survival-alpha.png",
    f"Recovered: intercept = {intercept:.4f}, slope = {slope:.4f}, alpha = {alpha:.4f}",
    result.summary_fixed,
)
for line in report_lines:
    print(line)
