#!/usr/bin/env python3
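"""
Render diagnostic plots for the Skew Normal Regression example.

Usage:
    Place dataset_sn_regression.csv in the same directory as this script,
    then run the script from that directory:
        python render_skewnormal_plots.py

    All input and output paths are relative to the current working directory.
    The script writes skewnormal-regression-fit.png and
    skewnormal-regression-residuals.png, then prints the posterior summaries.

Requires: pyinla, pandas, numpy, matplotlib
"""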

from __future__ import annotations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyinla import pyinla

# ── Fit the model ────────────────────────────────────────────────────────────
# Linear predictor with an intercept and one covariate (x), fitted with
# pyinla's "sn" (skew normal) likelihood. The observation precision gets a
# "pc.prec" prior with param [3, 0.01]. Under INLA's pc.prec convention this
# should read as P(sigma > 3) = 0.01 (confirm against the pyinla docs).

df = pd.read_csv("dataset_sn_regression.csv")

model = dict(response="y", fixed=["1", "x"])

prec_prior = {"prior": "pc.prec", "param": [3, 0.01]}
control = {
    "family": {"hyper": {"prec": prec_prior}},
    # Requests the linear-predictor summaries from pyinla. The code below
    # only reads the fixed-effect and hyperparameter summaries.
    "predictor": {"compute": True},
}

result = pyinla(model=model, family="sn", data=df, control=control)

# ── Extract quantities ───────────────────────────────────────────────────────
# The posterior means of the intercept and slope define the fitted line.
# eta_hat is that line evaluated at every observed x; residuals are the
# observed y minus eta_hat.

fixed_summary = result.summary_fixed
intercept = float(fixed_summary.loc["(Intercept)", "mean"])
slope = float(fixed_summary.loc["x", "mean"])

x_vals = df["x"].to_numpy()
y_vals = df["y"].to_numpy()
eta_hat = intercept + slope * x_vals
residuals = y_vals - eta_hat

# ── Plot 1: Observed y vs fitted line ────────────────────────────────────────
# Raw observations, with the posterior-mean regression line overlaid. The line
# is evaluated on 200 evenly spaced points across the observed x range.

marker_style = dict(s=26, edgecolor="white", linewidth=0.3, alpha=0.55)
x_grid = np.linspace(x_vals.min(), x_vals.max(), 200)
line_vals = intercept + slope * x_grid

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(x_vals, df["y"], color="#d946ef", label="observations", **marker_style)
ax.plot(x_grid, line_vals, color="#f97316", linewidth=2.4, label="posterior mean")
ax.set(
    xlabel="x (covariate)",
    ylabel="y",
    title="Skew Normal regression: data vs posterior mean",
)
ax.legend(loc="best")
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("skewnormal-regression-fit.png", dpi=160)
plt.close(fig)

# ── Plot 2: Residuals vs fitted ─────────────────────────────────────────────
# Residuals plotted against the fitted linear predictor, with a dashed
# reference line at zero.

fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(
    eta_hat,
    residuals,
    color="#38bdf8",
    s=26,
    edgecolor="white",
    linewidth=0.3,
    alpha=0.55,
)
ax.axhline(0.0, color="#f97316", linewidth=1.6, linestyle="--", label="zero residual")
ax.set(
    xlabel="Fitted value (η)",
    ylabel="Residual (y − η)",
    title="Skew Normal regression: residuals vs fitted",
)
ax.legend(loc="upper right")
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("skewnormal-regression-residuals.png", dpi=160)
plt.close(fig)

# Report the files written, the recovered coefficients, and pyinla's
# fixed-effect and hyperparameter summary tables (in that order).
print("Plots saved: skewnormal-regression-fit.png, skewnormal-regression-residuals.png")
print(f"Recovered: intercept = {intercept:.4f}, slope = {slope:.4f}")
for summary_name in ("summary_fixed", "summary_hyperpar"):
    print(getattr(result, summary_name))
