#!/usr/bin/env python3
"""
Render diagnostic plots for the Gaussian precision-offset example.

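Usage:
    Run the script from the directory that contains
    dataset_gaussian_offset.csv (the file is read relative to the current
    working directory, and the PNG plots are written there as well):
        python render_gaussian_plots.py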

Requires: pyinla, pandas, numpy, matplotlib
"""

from __future__ import annotations

import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pyinla import pyinla

# ── Fit the model ────────────────────────────────────────────────────────────

# Load the observations; the path is resolved against the current working
# directory.
df = pd.read_csv("dataset_gaussian_offset.csv")

# Known measurement variance; its inverse feeds the precision offset below.
var0 = 1.0

# Response column "y" with fixed-effect terms "1" and "x" (reported later as
# "(Intercept)" and "x" in the fixed-effects summary).
model = dict(response="y", fixed=["1", "x"])

# The "precoffset" hyperparameter is given on the log scale,
# log(tau_0) = log(1 / var0), and is marked as fixed.
# NOTE(review): predictor compute=True is presumably what makes the fitted
# value summaries available after the fit; confirm against the pyinla docs.
control = dict(
    family=dict(
        hyper=[
            dict(id="precoffset", initial=math.log(1.0 / var0), fixed=True),
        ],
    ),
    predictor=dict(compute=True),
)

result = pyinla(model=model, family="gaussian", data=df, control=control)

# ── Extract quantities ───────────────────────────────────────────────────────

# Posterior means of the two fixed effects, read in the order
# "(Intercept)" then "x".
intercept, slope = (
    float(result.summary_fixed.loc[term, "mean"])
    for term in ("(Intercept)", "x")
)

# Fitted-value means and the raw residuals (observed minus fitted).
fitted_means = result.summary_fitted_values["mean"].to_numpy()
observed = df["y"].to_numpy()
residuals = observed - fitted_means

# Marginal for the observation precision; its first two columns are renamed
# to "x" and "y" so the plotting code can address them by a stable name.
density = result.marginals_hyperpar["Precision for the Gaussian observations"]
density = density.rename(
    columns={density.columns[pos]: axis for pos, axis in enumerate(("x", "y"))}
)

# Posterior-mean line evaluated on 200 evenly spaced points spanning the
# observed covariate range.
x_grid = np.linspace(start=df["x"].min(), stop=df["x"].max(), num=200)
fitted_line = intercept + slope * x_grid

# ── Plot 1: Observations and posterior mean ──────────────────────────────────

# Scatter of the raw observations with the posterior-mean line on top,
# written to gaussian-offset-fit.png.
fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(
    df["x"],
    df["y"],
    color="#38bdf8",
    s=46,
    edgecolor="white",
    linewidth=0.4,
    label="observations",
)
ax.plot(
    x_grid,
    fitted_line,
    color="#f97316",
    linewidth=2.4,
    label="pyINLA posterior mean",
)
ax.set(
    xlabel="x (covariate)",
    ylabel="y",
    title="Gaussian offset: observed data and posterior mean trend",
)
ax.legend(loc="best")
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("gaussian-offset-fit.png", dpi=160)
# Close the figure once it has been written to disk.
plt.close(fig)

# ── Plot 2: Residuals vs fitted values ───────────────────────────────────────

# Residuals (observed minus fitted) against the fitted means, with a dashed
# reference line at zero; written to gaussian-offset-residuals.png.
fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.scatter(
    fitted_means,
    residuals,
    color="#a78bfa",
    s=42,
    edgecolor="white",
    linewidth=0.35,
)
ax.axhline(
    0.0,
    color="#f97316",
    linewidth=1.6,
    linestyle="--",
    label="zero residual",
)
ax.set(
    xlabel="Fitted mean",
    ylabel="Observed - fitted",
    title="Gaussian offset: residuals vs fitted values",
)
ax.legend(loc="upper right")
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("gaussian-offset-residuals.png", dpi=160)
# Close the figure once it has been written to disk.
plt.close(fig)

# ── Plot 3: Posterior density for precision ──────────────────────────────────

# Marginal density of the observation precision drawn as a line with a
# translucent fill beneath it; written to gaussian-offset-precision.png.
fig, ax = plt.subplots(figsize=(6.2, 4.1))
ax.plot(
    density["x"],
    density["y"],
    color="#f97316",
    linewidth=2.0,
)
ax.fill_between(
    density["x"],
    density["y"],
    color="#f97316",
    alpha=0.25,
)
ax.set(
    xlabel="Precision (τ)",
    ylabel="Density",
    title="Gaussian offset: posterior density for precision τ",
)
ax.grid(alpha=0.25)
fig.tight_layout()
fig.savefig("gaussian-offset-precision.png", dpi=160)
# Close the figure once it has been written to disk.
plt.close(fig)

# Console report: the files written, the recovered coefficients, and the
# fixed-effect and hyperparameter summary tables.
print(
    "Plots saved: gaussian-offset-fit.png, "
    "gaussian-offset-residuals.png, "
    "gaussian-offset-precision.png"
)
print(f"Recovered: intercept = {intercept:.4f}, slope = {slope:.4f}")
# One call with a newline separator emits the same text as two print calls.
print(result.summary_fixed, result.summary_hyperpar, sep="\n")
