"""Slides 1–3: cohort summary + headline settings + categorical adoption."""
from __future__ import annotations
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import matplotlib.pyplot as plt

from .style import BC_BLUE, BC_GRAY, apply_style, n_label
from .data import fetch_bank_metric
from cecl_db import connect


def plot_cohort_summary(shortqtr: str, out_dir: Path) -> None:
    """Slide 1: who's in the data — N banks + state spread.

    Queries ``cecl_benchmark_bank`` for *shortqtr*, draws a horizontal bar
    chart of bank counts per state, and writes ``01-cohort-summary.png``.

    Args:
        shortqtr: Quarter key matched against ``cecl_benchmark_bank.shortqtr``;
            also shown upper-cased in the title.
        out_dir: Directory the PNG is written into.

    Note:
        The title's bank count includes banks whose ``state`` is NULL/empty,
        while the bars exclude them, so the bars can sum to less than N.
    """
    apply_style()
    # Release the cursor and connection even if a query or fetch fails.
    conn = connect()
    try:
        cur = conn.cursor(dictionary=True)
        try:
            cur.execute(
                """SELECT COUNT(*) AS n_banks, COUNT(DISTINCT state) AS n_states
                   FROM cecl_benchmark_bank WHERE shortqtr = %s""",
                (shortqtr,),
            )
            summary = cur.fetchone()
            cur.execute(
                """SELECT state, COUNT(*) AS n FROM cecl_benchmark_bank
                   WHERE shortqtr = %s AND state IS NOT NULL AND state <> ''
                   GROUP BY state ORDER BY n DESC""",
                (shortqtr,),
            )
            states = cur.fetchall()
        finally:
            cur.close()
    finally:
        conn.close()

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Rows arrive largest-first; barh draws index 0 at the bottom, so
        # reverse to put the most-populated state at the top.
        names = [s["state"] for s in states][::-1]
        counts = [s["n"] for s in states][::-1]
        ax.barh(names, counts, color=BC_BLUE)
        for i, c in enumerate(counts):
            ax.text(c + 0.5, i, str(c), va="center", fontsize=9, color=BC_GRAY)
        ax.set_xlabel("Banks in cohort")
        ax.set_title(
            f"Cohort — {summary['n_banks']} banks across {summary['n_states']} states · {shortqtr.upper()}",
            fontsize=16,
        )
        plt.tight_layout()
        out = out_dir / "01-cohort-summary.png"
        fig.savefig(out)
    finally:
        # Always release the figure so repeated runs don't accumulate them.
        plt.close(fig)
    print(f"  wrote {out}")


# Panels for slide 2, drawn top-to-bottom in this order. Each entry is
# (metric key passed to fetch_bank_metric, human-readable panel title).
SETTINGS_TO_PLOT = [
    ("lookback", "Lookback window (quarters)"),
    ("lookforward", "Lookforward window (quarters)"),
    ("prepay", "Prepay assumption"),
    ("floor_bps", "Loss-rate floor (bps)"),
    ("peer_percentage", "Peer percentage"),
]


def _percentile_groups(row) -> dict[float, list[str]]:
    """Map each distinct percentile value in *row* to the labels sharing it.

    When most banks use an identical setting (e.g. lookforward=3 for
    everyone), p50..p90 all equal; grouping them lets the caller draw one
    merged label instead of several overlapping ones. Values are rounded to
    4 decimals before grouping. Percentiles that are ``None`` (NULL) are
    skipped instead of crashing ``float()``. Insertion order follows
    p10 → p90, so the result is ordered by first appearance.
    """
    raw = [("p10", row["p10"]), ("p25", row["p25"]), ("p50 (median)", row["p50"]),
           ("p75", row["p75"]), ("p90", row["p90"])]
    groups: dict[float, list[str]] = {}
    for label_text, val in raw:
        if val is None:
            continue
        groups.setdefault(round(float(val), 4), []).append(label_text)
    return groups


def plot_settings_distributions(shortqtr: str, out_dir: Path) -> None:
    """Slide 2: percentile strip for each headline CECL setting.

    Draws one stacked panel per entry in ``SETTINGS_TO_PLOT``, marking the
    p10/p25/p50/p75/p90 values returned by ``fetch_bank_metric``. A panel
    shows "insufficient data" in either of two cases: the metric row is
    missing or covers fewer than 5 banks, or none of its percentiles are
    populated. The result is written to ``02-headline-settings.png``.

    Args:
        shortqtr: Quarter key passed to ``fetch_bank_metric``; shown
            upper-cased in the title.
        out_dir: Directory the PNG is written into.
    """
    apply_style()
    n_panels = len(SETTINGS_TO_PLOT)
    # squeeze=False keeps `axes` 2-D even for a single panel, so the zip
    # below always iterates Axes objects.
    fig, axes = plt.subplots(n_panels, 1, figsize=(10, 3 * n_panels), squeeze=False)
    try:
        fig.suptitle(f"Headline CECL settings — {shortqtr.upper()}", fontsize=18, y=1.0)

        for ax, (metric, label) in zip(axes[:, 0], SETTINGS_TO_PLOT):
            row = fetch_bank_metric(shortqtr, metric)
            groups: dict[float, list[str]] = {}
            if row is not None and row["n_banks"] >= 5:
                groups = _percentile_groups(row)
            if not groups:
                ax.text(0.5, 0.5, f"{label}: insufficient data", ha="center", va="center", transform=ax.transAxes)
                ax.set_axis_off()
                continue

            positions = list(groups)
            ax.plot(positions, [1] * len(positions), "o", color=BC_BLUE, markersize=10)
            for val, labs in groups.items():
                combined = ", ".join(labs)
                ax.annotate(f"{combined}\n{val:.2f}", xy=(val, 1), xytext=(val, 1.25),
                            ha="center", fontsize=10, color=BC_GRAY)
            ax.set_title(label)
            ax.set_ylim(0.5, 2.0)
            ax.set_yticks([])
            ax.text(0.99, 0.05, n_label(row["n_banks"]), transform=ax.transAxes,
                    ha="right", fontsize=10, color=BC_GRAY)

        plt.tight_layout()
        out = out_dir / "02-headline-settings.png"
        fig.savefig(out)
    finally:
        # Release the figure even when a fetch or save raises.
        plt.close(fig)
    print(f"  wrote {out}")


def plot_categorical_usage(shortqtr: str, out_dir: Path) -> None:
    """Slide 3: share of banks adopting each categorical setting.

    Each metric in ``metrics`` contributes one horizontal bar. The bar shows
    ``count_true / count_total`` as a percentage, annotated with N. A metric
    is skipped when its row is missing, when it covers fewer than 5 banks,
    or when ``count_total`` is 0 or NULL. If every metric is skipped, the
    slide shows an "insufficient data" placeholder instead of an empty
    chart. The result is written to ``03-categorical-usage.png``.

    Args:
        shortqtr: Quarter key passed to ``fetch_bank_metric``; shown
            upper-cased in the title.
        out_dir: Directory the PNG is written into.
    """
    apply_style()
    metrics = [
        ("neg_forward_off",         "Banks with no negative forward-looking adjustments"),
        ("has_floor",               "Banks using a loss-rate floor"),
        ("uses_locked_indicators",  "Banks with locked indicators set"),
    ]
    names: list[str] = []
    pcts: list[float] = []
    ns: list = []
    for m, label in metrics:
        row = fetch_bank_metric(shortqtr, m)
        if row is None or row["n_banks"] < 5 or not row["count_total"]:
            continue
        # Treat a NULL count_true as zero adopters rather than crashing.
        pct = 100.0 * (row["count_true"] or 0) / row["count_total"]
        names.append(label)
        pcts.append(pct)
        ns.append(row["count_total"])

    # Create the figure only after all fetches succeed, and always close it.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        if names:
            bars = ax.barh(names, pcts, color=BC_BLUE)
            for bar, pct, n in zip(bars, pcts, ns):
                ax.text(pct + 1, bar.get_y() + bar.get_height() / 2,
                        f"{pct:.0f}%  (N={n})", va="center", fontsize=11, color=BC_GRAY)
        else:
            ax.text(0.5, 0.5, "insufficient data", ha="center", va="center",
                    transform=ax.transAxes, color=BC_GRAY)
        ax.set_xlim(0, 100)
        ax.set_xlabel("% of banks")
        ax.set_title(f"Settings adoption — {shortqtr.upper()}")
        plt.tight_layout()
        out = out_dir / "03-categorical-usage.png"
        fig.savefig(out)
    finally:
        plt.close(fig)
    print(f"  wrote {out}")
