"""Pass 2: assign every custom-factor row to a cluster (or UNCATEGORIZED).

Usage:
  python cecl_qfactor_assign.py --shortqtr=1q26
"""
from __future__ import annotations
import argparse
import json
import os
import sys

from anthropic import Anthropic
from tqdm import tqdm

from cecl_db import connect, batched

# Model identifier passed as `model=` to every Messages API call in main().
MODEL = "claude-opus-4-7"
# Number of descriptions sent per request; the model refers to them by their
# 0-based position within the batch, so indices restart at 0 for each batch.
BATCH_SIZE = 50

# System prompt sent with every request. It fixes the response contract that
# main() parses: a JSON object with an "assignments" list whose items carry
# "index", "category" (a cluster name or the literal "UNCATEGORIZED") and
# "confidence".
SYSTEM_PROMPT = """You assign CECL Q-factor descriptions to pre-approved theme
clusters. Given a list of clusters (name + description) and a batch of input
descriptions, return JSON mapping each input index to the best cluster name
(or the string "UNCATEGORIZED" if none fit) with a confidence 0.0-1.0.

Return STRICT JSON only. No prose. Schema:
{
  "assignments": [
    {"index": 0, "category": "name-or-UNCATEGORIZED", "confidence": 0.0-1.0}
  ]
}"""


def build_user_prompt(clusters: list[dict], descriptions: list[str]) -> str:
    """Render the user message for one batch.

    The message lists every approved cluster as ``- name: description`` and
    then every input description as ``[i] text``, where ``i`` is the 0-based
    position of the description in *descriptions*.

    Args:
        clusters: Mappings that each provide ``"name"`` and ``"description"``.
        descriptions: The texts to be assigned, in batch order.

    Returns:
        The newline-joined prompt text.
    """
    cluster_section = [
        f"- {entry['name']}: {entry['description']}" for entry in clusters
    ]
    input_section = [
        f"[{position}] {text}" for position, text in enumerate(descriptions)
    ]
    return "\n".join(
        [
            "Approved clusters:",
            *cluster_section,
            # The leading "\n" leaves a blank line between the two sections.
            "\nInput descriptions (assign each by index):",
            *input_section,
        ]
    )


# Single upsert used for every outcome (matched cluster, explicit
# UNCATEGORIZED, or unknown cluster name). On a duplicate key every column is
# overwritten, so a row can never end up flagged uncategorized while still
# pointing at a previously assigned category.
_UPSERT_ASSIGNMENT_SQL = """INSERT INTO cecl_qfactor_categorized
    (customfactor_text_id, category_id, confidence, is_uncategorized)
    VALUES (%s, %s, %s, %s)
    ON DUPLICATE KEY UPDATE category_id=VALUES(category_id),
      confidence=VALUES(confidence), is_uncategorized=VALUES(is_uncategorized),
      assigned_at=CURRENT_TIMESTAMP"""


def _response_text(resp) -> str:
    """Return the text of a Messages API response with any code fence removed.

    All content blocks that expose a ``text`` attribute are concatenated, so an
    empty ``content`` list yields ``""`` instead of raising ``IndexError``.
    """
    text = "".join(
        getattr(block, "text", None) or "" for block in (resp.content or [])
    ).strip()
    if text.startswith("```"):
        # Drop the opening fence line (``` or ```json) and the closing fence.
        newline = text.find("\n")
        text = text[newline + 1:] if newline != -1 else ""
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return text.strip()


def _parse_assignments(raw: str, batch_len: int) -> list[tuple]:
    """Parse a model reply into validated ``(index, category, confidence)`` tuples.

    Items that are not objects, whose ``index`` is not an integer within
    ``[0, batch_len)``, or whose ``category`` is not a string are dropped.
    A ``confidence`` that is not a number is replaced by ``None``.

    Raises:
        ValueError: If *raw* is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass) or has no top-level ``"assignments"`` list.
    """
    parsed = json.loads(raw)
    if not isinstance(parsed, dict) or not isinstance(parsed.get("assignments"), list):
        raise ValueError('response JSON has no "assignments" list')

    valid = []
    for item in parsed["assignments"]:
        if not isinstance(item, dict):
            continue
        idx = item.get("index")
        # bool is a subclass of int; reject it explicitly.
        if isinstance(idx, bool) or not isinstance(idx, int):
            continue
        if not 0 <= idx < batch_len:
            continue
        category = item.get("category")
        if not isinstance(category, str):
            continue
        confidence = item.get("confidence")
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            confidence = None
        valid.append((idx, category, confidence))
    return valid


def _assign_quarter(conn, shortqtr: str) -> None:
    """Assign every not-yet-categorized description of *shortqtr* to a cluster.

    Loads the approved clusters and the unassigned descriptions, asks the model
    to classify the descriptions batch by batch, and upserts one row per
    assignment, committing after each batch. Exits with status 1 if no clusters
    exist for the quarter. The caller owns (and closes) *conn*.
    """
    cur = conn.cursor(dictionary=True)
    try:
        cur.execute(
            "SELECT id, name, description FROM cecl_qfactor_categories WHERE shortqtr = %s",
            (shortqtr,),
        )
        clusters = cur.fetchall()
        if not clusters:
            print(f"ERROR: no clusters loaded for shortqtr={shortqtr}. Run cecl_qfactor_cluster.py --load first.", file=sys.stderr)
            sys.exit(1)

        # The LEFT JOIN / IS NULL filter selects only descriptions that have no
        # row in cecl_qfactor_categorized yet, so reruns resume where they left off.
        cur.execute(
            """SELECT t.id, t.description
               FROM cecl_benchmark_customfactor_text t
               LEFT JOIN cecl_qfactor_categorized qc ON qc.customfactor_text_id = t.id
               WHERE t.shortqtr = %s AND qc.id IS NULL""",
            (shortqtr,),
        )
        rows = cur.fetchall()
    finally:
        cur.close()

    name_to_id = {c["name"]: c["id"] for c in clusters}
    print(f"To assign: {len(rows)} descriptions in batches of {BATCH_SIZE}")

    if not rows:
        print("Nothing to do.")
        return

    client = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
    skipped_batches = 0
    unknown_names = 0

    ins = conn.cursor()
    try:
        # Materialized as a list so tqdm can show a total.
        for batch in tqdm(list(batched(rows, BATCH_SIZE))):
            descs = [r["description"] for r in batch]
            user = build_user_prompt(clusters, descs)
            # NOTE(review): the cache breakpoint is on a block that also holds
            # the per-batch descriptions, so the cached prefix presumably
            # differs on every request — consider sending the cluster list as
            # its own cached block; confirm against the prompt-caching docs.
            resp = client.messages.create(
                model=MODEL,
                max_tokens=4000,
                system=SYSTEM_PROMPT,
                messages=[{"role": "user", "content": [{"type": "text", "text": user, "cache_control": {"type": "ephemeral"}}]}],
            )
            raw = _response_text(resp)
            try:
                assignments = _parse_assignments(raw, len(batch))
            except ValueError:
                # Leave the batch unassigned; a later run will pick it up again.
                print(f"Skipping batch — invalid JSON response:\n{raw}", file=sys.stderr)
                skipped_batches += 1
                continue

            for idx, category, confidence in assignments:
                if category == "UNCATEGORIZED":
                    cat_id = None
                else:
                    cat_id = name_to_id.get(category)
                    if cat_id is None:
                        # Model returned a name that is not an approved
                        # cluster; store the row as uncategorized.
                        unknown_names += 1
                ins.execute(
                    _UPSERT_ASSIGNMENT_SQL,
                    (batch[idx]["id"], cat_id, confidence, int(cat_id is None)),
                )
            conn.commit()
    finally:
        ins.close()

    if skipped_batches or unknown_names:
        print(
            f"WARNING: {skipped_batches} batch(es) skipped, "
            f"{unknown_names} assignment(s) with an unknown cluster name stored as uncategorized.",
            file=sys.stderr,
        )
    print("Done.")


def main() -> None:
    """Command-line entry point: parse ``--shortqtr`` and run the assignment pass.

    The database connection is closed on every exit path, including the
    "nothing to do" return, the missing-clusters exit and unexpected errors.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--shortqtr", required=True)
    args = ap.parse_args()

    conn = connect()
    try:
        _assign_quarter(conn, args.shortqtr)
    finally:
        conn.close()


# Run the assignment pass only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
