"""Read-only helpers that pull already-aggregated data from MySQL."""
from __future__ import annotations
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from cecl_db import connect


def fetch_bank_metric(shortqtr: str, metric_name: str) -> dict | None:
    """Return the bank-scope aggregate row for one quarter and metric.

    Queries ``cecl_benchmark_aggregates`` with ``scope='bank'``.

    Args:
        shortqtr: Value matched against the ``shortqtr`` column.
        metric_name: Value matched against the ``metric_name`` column.

    Returns:
        The first matching row, with keys ``n_banks``, ``p10``, ``p25``,
        ``p50``, ``p75``, ``p90``, ``mean``, ``count_true`` and
        ``count_total`` — or ``None`` when no row matches.
    """
    conn = connect()
    try:
        # dictionary=True asks the driver for rows keyed by column name.
        cur = conn.cursor(dictionary=True)
        try:
            cur.execute(
                """SELECT n_banks, p10, p25, p50, p75, p90, mean, count_true, count_total
                   FROM cecl_benchmark_aggregates
                   WHERE shortqtr=%s AND scope='bank' AND metric_name=%s""",
                (shortqtr, metric_name),
            )
            return cur.fetchone()
        finally:
            # Release the cursor even when execute()/fetchone() raises.
            cur.close()
    finally:
        # Always hand the connection back, otherwise a failed query leaks it.
        conn.close()


def fetch_pool_metric(shortqtr: str, metric_name: str) -> list[dict]:
    """Return the per-pool aggregate rows for one quarter and metric.

    Queries ``cecl_benchmark_aggregates`` with ``scope='bank_pool'``.

    Args:
        shortqtr: Value matched against the ``shortqtr`` column.
        metric_name: Value matched against the ``metric_name`` column.

    Returns:
        All matching rows ordered by ``callreportcode``, each with keys
        ``callreportcode``, ``n_banks``, ``p10``, ``p25``, ``p50``, ``p75``,
        ``p90`` and ``mean``. Empty when nothing matches.
    """
    conn = connect()
    try:
        # dictionary=True asks the driver for rows keyed by column name.
        cur = conn.cursor(dictionary=True)
        try:
            cur.execute(
                """SELECT callreportcode, n_banks, p10, p25, p50, p75, p90, mean
                   FROM cecl_benchmark_aggregates
                   WHERE shortqtr=%s AND scope='bank_pool' AND metric_name=%s
                   ORDER BY callreportcode""",
                (shortqtr, metric_name),
            )
            return cur.fetchall()
        finally:
            # Release the cursor even when execute()/fetchall() raises.
            cur.close()
    finally:
        # Always hand the connection back, otherwise a failed query leaks it.
        conn.close()


def fetch_top_locked(shortqtr: str, top_n: int = 10) -> list[dict]:
    """Return the pools with the highest ``count_true`` for ``locked_bank_count``.

    Queries ``cecl_benchmark_aggregates`` with ``scope='pool'`` and
    ``metric_name='locked_bank_count'``.

    Args:
        shortqtr: Value matched against the ``shortqtr`` column.
        top_n: Maximum number of rows to return (passed to ``LIMIT``).

    Returns:
        Up to ``top_n`` rows with keys ``callreportcode``, ``count_true`` and
        ``count_total``, ordered by ``count_true`` descending. Ties are broken
        by ``callreportcode`` so the selection is stable between calls.
    """
    conn = connect()
    try:
        # dictionary=True asks the driver for rows keyed by column name.
        cur = conn.cursor(dictionary=True)
        try:
            # The secondary sort key makes the LIMIT cut-off deterministic
            # when several pools share the same count_true.
            cur.execute(
                """SELECT a.callreportcode, a.count_true, a.count_total
                   FROM cecl_benchmark_aggregates a
                   WHERE a.shortqtr=%s AND a.scope='pool' AND a.metric_name='locked_bank_count'
                   ORDER BY a.count_true DESC, a.callreportcode LIMIT %s""",
                (shortqtr, top_n),
            )
            return cur.fetchall()
        finally:
            # Release the cursor even when execute()/fetchall() raises.
            cur.close()
    finally:
        # Always hand the connection back, otherwise a failed query leaks it.
        conn.close()


def fetch_theme_prevalence(shortqtr: str) -> list[dict]:
    """Return ``theme_prevalence`` aggregates per q-factor category.

    Joins ``cecl_benchmark_aggregates`` (``scope='category'``,
    ``metric_name='theme_prevalence'``) to ``cecl_qfactor_categories`` on
    ``category_id`` to pick up the category name.

    Args:
        shortqtr: Value matched against the ``shortqtr`` column.

    Returns:
        Rows with keys ``name``, ``count_true``, ``count_total`` and
        ``median_amount`` (the aggregate's ``p50`` column), ordered by
        ``count_true`` descending. Empty when nothing matches.
    """
    conn = connect()
    try:
        # dictionary=True asks the driver for rows keyed by column name.
        cur = conn.cursor(dictionary=True)
        try:
            cur.execute(
                """SELECT c.name, a.count_true, a.count_total, a.p50 AS median_amount
                   FROM cecl_benchmark_aggregates a
                   JOIN cecl_qfactor_categories c ON c.id = a.category_id
                   WHERE a.shortqtr=%s AND a.scope='category' AND a.metric_name='theme_prevalence'
                   ORDER BY a.count_true DESC""",
                (shortqtr,),
            )
            return cur.fetchall()
        finally:
            # Release the cursor even when execute()/fetchall() raises.
            cur.close()
    finally:
        # Always hand the connection back, otherwise a failed query leaks it.
        conn.close()


def fetch_raw_pool_values(shortqtr: str, metric_column: str, callreportcode: str) -> list[float]:
    """Pull raw per-bank values for a pool metric — used for box plots.

    Matches by parent-code prefix so sub-pools (e.g. '01E1-Sub Not 1-4 Family')
    roll up into their parent ('01E1'): a row matches when the part of its
    ``callreportcode`` before the first '-' equals ``callreportcode``.

    Args:
        shortqtr: Value matched against the ``shortqtr`` column.
        metric_column: Column of ``cecl_benchmark_bank_pool`` to read; must be
            one of ``qualadj_bps``, ``customfactors_bps``, ``lossrate`` or
            ``confloss``.
        callreportcode: Parent call-report code to match.

    Returns:
        The non-NULL values of ``metric_column`` converted to ``float``.

    Raises:
        ValueError: If ``metric_column`` is not an allowed column name.
    """
    # The column name is interpolated into the SQL text (identifiers cannot be
    # bound as parameters), so it must be checked against a fixed whitelist.
    allowed = {"qualadj_bps", "customfactors_bps", "lossrate", "confloss"}
    if metric_column not in allowed:
        # Sorted so the message is identical from run to run.
        raise ValueError(f"metric_column must be one of {sorted(allowed)}")
    conn = connect()
    try:
        cur = conn.cursor()
        try:
            cur.execute(
                f"""SELECT {metric_column} FROM cecl_benchmark_bank_pool
                    WHERE shortqtr=%s
                      AND SUBSTRING_INDEX(callreportcode,'-',1)=%s
                      AND {metric_column} IS NOT NULL""",
                (shortqtr, callreportcode),
            )
            # Default cursor yields tuples; the single selected column is r[0].
            return [float(r[0]) for r in cur.fetchall()]
        finally:
            # Release the cursor even when execute()/fetchall() raises.
            cur.close()
    finally:
        # Always hand the connection back, otherwise a failed query leaks it.
        conn.close()
