#!/usr/bin/env python3
"""rclaude triage helper — Haiku-powered per-session summary + prioritization.

Iterates ~/.claude/projects/, builds a compressed transcript per session, and
runs them through claude-code-batch-sdk (Haiku) with on-disk response
caching. Cache entries are keyed by session UUID + JSONL mtime (whole
seconds), so sessions whose file hasn't been modified since the last run are
served from cache without re-reading or re-sending the transcript.

Output (TSV, one row per session, sorted by priority desc / mtime desc):
    <mtime>\\t<uuid>\\t<cwd>\\t<priority>\\t<status>\\t<summary>\\t<next_action>

Env tuning:
    RCLAUDE_TRIAGE_MODEL       Claude model (default: haiku)
    RCLAUDE_TRIAGE_CONCURRENT  Concurrent claude subprocesses (default: 4)
    RCLAUDE_TRIAGE_BATCH       Sessions per CLI call (default: 8)
    RCLAUDE_TRIAGE_LIMIT       Max sessions to consider (default: 100)
    RCLAUDE_TRIAGE_CTX_BYTES   Max characters (not bytes) of transcript per session (default: 3000)

CLI flags:
    --limit N      override session cap
    --uuids U...   restrict to specific session UUIDs (prefix match ok)
    --refresh      bypass cache for selected sessions
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import re
import signal
import sys
from pathlib import Path

# Allow piped output to be cut short (e.g. `| head`) without a BrokenPipeError
# traceback: restore the default SIGPIPE action so the process simply exits.
# SIGPIPE is Unix-only; on platforms without it (Windows) this is skipped
# instead of raising AttributeError at import time.
if hasattr(signal, "SIGPIPE"):
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

try:
    from claude_code_batch_sdk import (
        ClaudeClient,
        GenerationItem,
        ResponseCache,
        run_batched,
    )
except ImportError as e:
    print(
        f"_claude-triage: claude-code-batch-sdk not installed for this python ({sys.executable}): {e}",
        file=sys.stderr,
    )
    sys.exit(2)

ROOT = Path.home() / ".claude" / "projects"
CACHE_DIR = Path.home() / ".claude" / ".cache" / "rclaude-triage"


def env_int(name: str, default: int, minimum: int | None = None) -> int:
    """Read an integer tuning knob from environment variable *name*.

    Unset or blank -> *default*. A non-integer value is reported on stderr and
    replaced by *default* instead of raising ValueError at import time. When
    *minimum* is given, smaller values are raised to it (with a warning).
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        print(
            f"_claude-triage: ignoring non-integer {name}={raw!r}; using {default}",
            file=sys.stderr,
        )
        return default
    if minimum is not None and value < minimum:
        print(
            f"_claude-triage: {name}={value} is below {minimum}; using {minimum}",
            file=sys.stderr,
        )
        return minimum
    return value


MODEL = os.environ.get("RCLAUDE_TRIAGE_MODEL", "").strip() or "haiku"
# Zero workers / zero sessions per batch could never make progress.
MAX_CONCURRENT = env_int("RCLAUDE_TRIAGE_CONCURRENT", 4, minimum=1)
BATCH_SIZE = env_int("RCLAUDE_TRIAGE_BATCH", 8, minimum=1)
# Character budget for each session's transcript excerpt; kept >= 1 so the
# tail-slice budget always truncates (a zero budget would mean "everything").
CTX_PER_SESSION = env_int("RCLAUDE_TRIAGE_CTX_BYTES", 3000, minimum=1)

# Leading markers of user-role entries that are harness/tooling noise rather
# than something a human typed; such turns are dropped from transcripts.
SYSTEM_PREFIXES = (
    "<command-name>", "<command-message>", "<system-reminder>", "<local-command-",
    "Caveat:", "<bash-input>", "<bash-stdout>", "[task-persistence]", "[tts-state]",
    "This session is being continued", "Please continue", "<task-notification>",
    "[Request interrupted",
)


def is_system_user(text: str) -> bool:
    """Return True when a "user" turn is harness noise rather than human input.

    Empty, whitespace-only or None text counts as noise, as does any text whose
    first non-whitespace characters match an entry of SYSTEM_PREFIXES.
    """
    body = text.lstrip() if text else ""
    if body:
        return body.startswith(SYSTEM_PREFIXES)
    return True


def get_text(entry: dict) -> str:
    """Flatten a transcript entry's ``message.content`` into one string.

    ``content`` may be a plain string (returned as-is) or a list of content
    blocks: ``text`` blocks contribute their text, ``tool_use`` blocks a compact
    ``[tool:<name>]`` marker, and every other block is ignored; the pieces are
    joined with single spaces. Any unexpected shape (non-dict entry or message,
    null text, non-list content) yields "" instead of raising.
    """
    if not isinstance(entry, dict):
        return ""
    msg = entry.get("message")
    if not isinstance(msg, dict):
        # Missing, null, or malformed (e.g. a bare string) message.
        return ""
    content = msg.get("content")
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    parts: list[str] = []
    for block in content:
        if not isinstance(block, dict):
            continue
        kind = block.get("type")
        if kind == "text":
            text = block.get("text")
            if isinstance(text, str):
                parts.append(text)
        elif kind == "tool_use":
            parts.append(f"[tool:{block.get('name') or '?'}]")
    return " ".join(parts)


def compress_session(jsonl: Path) -> tuple[str, str, str]:
    """Summarize one session JSONL file for the triage prompt.

    Returns ``(cwd, first_user_text, transcript_excerpt)``:

    - cwd: the first non-empty string ``cwd`` field in the file ("" if none).
    - first_user_text: the first non-system user turn, whitespace-collapsed,
      capped at 400 characters.
    - transcript_excerpt: the last 20 user/assistant turns, each whitespace-
      collapsed, capped at 300 characters and prefixed with ``[role]``, joined
      by newlines and then cut to the final CTX_PER_SESSION characters.

    A file that cannot be opened/read yields ``("", "", "")``. Lines that are
    not valid JSON objects, or whose ``message`` is not an object, are skipped.
    """
    from collections import deque

    first_user = ""
    cwd = ""
    # Only the most recent 20 turns are ever used, so a bounded deque keeps
    # memory flat for very long sessions.
    turns: deque[str] = deque(maxlen=20)
    try:
        with jsonl.open(encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(entry, dict):
                    continue  # valid JSON, but e.g. a bare list or number
                entry_cwd = entry.get("cwd")
                if not cwd and isinstance(entry_cwd, str) and entry_cwd:
                    cwd = entry_cwd
                msg = entry.get("message")
                if not isinstance(msg, dict):
                    continue  # no extractable content on this line
                role = msg.get("role") or entry.get("type")
                if role not in ("user", "assistant"):
                    continue
                text = get_text(entry)
                if not text.strip():
                    continue
                if role == "user" and is_system_user(text):
                    continue
                flat = re.sub(r"\s+", " ", text).strip()
                if role == "user" and not first_user:
                    first_user = flat[:400]
                turns.append(f"[{role}] {flat[:300]}")
    except OSError:
        return ("", "", "")
    # Keep the END of the transcript (most recent context) by dropping leading
    # characters. Guard the budget: s[-0:] is the whole string, so a
    # non-positive budget would otherwise disable truncation entirely.
    transcript = "\n".join(turns)
    transcript = transcript[-CTX_PER_SESSION:] if CTX_PER_SESSION > 0 else ""
    return (cwd, first_user, transcript)


# System prompt sent with every triage batch. The field names below must stay
# in sync with what validate() requires and what main_async prints, and
# "ref_index" must match the index_key passed to run_batched. Cache lookups
# use only template_id + "<uuid>:<mtime>", so editing this text does NOT
# invalidate existing cached answers; bump template_id if the output contract
# changes.
SYSTEM_PROMPT = """You triage Claude Code coding sessions. For each session in the input batch, emit one JSON object with these fields:
- ref_index: integer matching the input's ref_index
- summary: ONE short sentence describing what's happening
- status: one of done, in_progress, blocked, waiting_on_user, abandoned
- priority: integer 1-5 (5 = critical to resume now, 1 = abandonable)
- next_action: ONE short imperative phrase, or empty string if status is done/abandoned

Output ONLY a JSON array. No markdown, no prose."""


def build_prompt(batch: list[GenerationItem]) -> str:
    """Render one batch of sessions as the user prompt for the model.

    Each session is introduced by a ``=== ref_index=<i> ===`` header, where
    ``i`` is its position in *batch*, followed by its initial request and the
    recent transcript excerpt from ``item.metadata``. A reply-format example
    closes the prompt. All pieces are joined with newlines.
    """
    lines = ["Triage these sessions. Respond with JSON array.\n"]
    for ref_index, item in enumerate(batch):
        meta = item.metadata
        lines.extend((
            f"\n=== ref_index={ref_index} ===",
            f"INITIAL REQUEST: {meta['first_user']}",
            "RECENT TRANSCRIPT:",
            meta["transcript"],
        ))
    lines.append(
        '\n\nReply with: [{"ref_index": 0, "summary": "...", '
        '"status": "...", "priority": N, "next_action": "..."}, ...]'
    )
    return "\n".join(lines)


def validate(result: dict) -> bool:
    """Return True if *result* is a usable triage record from the model.

    Requires a dict containing ``summary``, ``status``, ``priority`` and
    ``next_action``, where ``priority`` converts to an int in 1..5 — the
    range SYSTEM_PROMPT asks for. Out-of-range priorities are rejected so they
    are not passed on (and presumably not cached), where they would dominate
    the priority-desc ordering of the output.
    """
    if not isinstance(result, dict):
        return False
    if any(k not in result for k in ("summary", "status", "priority", "next_action")):
        return False
    try:
        priority = int(result["priority"])
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) — Python's json parses "Infinity".
        return False
    return 1 <= priority <= 5


def enrich(result: dict, item: GenerationItem) -> dict:
    """Return a copy of *result* tagged with the session's uuid, cwd and mtime.

    The values come from ``item.metadata`` and override any same-named keys the
    model may have emitted; *result* itself is not modified.
    """
    meta = item.metadata
    enriched = dict(result)
    enriched.update(uuid=meta["uuid"], cwd=meta["cwd"], mtime=meta["mtime"])
    return enriched


def collect_candidates(limit: int | None, root: Path | None = None) -> list[tuple[int, Path]]:
    """List session transcripts as ``(mtime, path)`` pairs, newest first.

    Scans ``<root>/<project>/*.jsonl`` one directory level deep (*root*
    defaults to ROOT). ``mtime`` is the file's modification time truncated to
    whole seconds. At most *limit* pairs are returned: ``None`` means no cap
    and a negative limit is treated as 0. Unreadable directories and files
    that disappear mid-scan are skipped instead of aborting the whole run.
    """
    base = ROOT if root is None else root
    if not base.is_dir():
        return []
    try:
        project_dirs = list(base.iterdir())
    except OSError:
        return []
    candidates: list[tuple[int, Path]] = []
    for project_dir in project_dirs:
        try:
            if not project_dir.is_dir():
                continue
            jsonls = list(project_dir.glob("*.jsonl"))
        except OSError:
            continue
        for jsonl in jsonls:
            try:
                mtime = int(jsonl.stat().st_mtime)
            except OSError:
                continue
            candidates.append((mtime, jsonl))
    candidates.sort(key=lambda pair: pair[0], reverse=True)
    if limit is None:
        return candidates
    return candidates[: max(limit, 0)]


def _select_candidates(limit: int, uuids: list[str] | None) -> list[tuple[int, Path]]:
    """Pick the ``(mtime, path)`` session files to triage, newest first.

    Without *uuids* this is simply the newest *limit* sessions. With *uuids*,
    every session is scanned and filtered by UUID prefix BEFORE the cap is
    applied, so an explicitly requested session is still found when it is
    older than the newest *limit* sessions.
    """
    if not uuids:
        return collect_candidates(limit)
    wanted = tuple(uuids)  # str.startswith accepts a tuple of prefixes
    # A limit of None asks collect_candidates for every session (no cap).
    matches = [
        (mtime, jsonl)
        for (mtime, jsonl) in collect_candidates(None)
        if jsonl.stem.startswith(wanted)
    ]
    return matches[: max(limit, 0)]


def _lookup_cache(
    cache: ResponseCache,
    template_id: str,
    candidates: list[tuple[int, Path]],
    *,
    refresh: bool,
) -> tuple[list[dict], list[tuple[int, Path, str]]]:
    """Serve candidates from *cache* where possible.

    Returns ``(cached_results, misses)`` where each miss is
    ``(mtime, path, cache_key)``. With *refresh* the cache is not consulted,
    so every candidate is a miss.
    """
    cached_results: list[dict] = []
    misses: list[tuple[int, Path, str]] = []
    for mtime, jsonl in candidates:
        cache_key = f"{jsonl.stem}:{mtime}"
        hit = None if refresh else cache.get(template_id, cache_key)
        if hit is None:
            misses.append((mtime, jsonl, cache_key))
            continue
        # Carry through identifying metadata in case the cached payload was
        # written by an older script revision that didn't enrich.
        cached_results.append({
            **hit,
            "uuid": hit.get("uuid", jsonl.stem),
            "mtime": hit.get("mtime", mtime),
        })
    return cached_results, misses


def _purge_cache_entries(cache: ResponseCache, template_id: str, cache_keys: list[str]) -> None:
    """Best-effort delete of the on-disk cache files for *cache_keys*.

    Used by --refresh so that run_batched (which is handed the same cache and
    presumably consults it) cannot return the old responses.
    NOTE(review): this relies on ResponseCache's private ``_hash_key`` /
    ``_cache_path`` helpers and may break on an SDK upgrade; switch to a
    public invalidation API if the SDK offers one.
    """
    for cache_key in cache_keys:
        path = cache._cache_path(template_id, cache._hash_key(template_id, cache_key))
        if not path.exists():
            continue
        try:
            path.unlink()
        except OSError:
            # Deliberately non-fatal: worst case, this one session may still
            # be answered from its stale cache entry.
            pass


def _make_item(template_id: str, mtime: int, jsonl: Path, cache_key: str) -> GenerationItem | None:
    """Compress one session into a GenerationItem, or None when it has no cwd.

    An empty cwd also covers unreadable files, for which compress_session
    returns empty strings.
    """
    cwd, first_user, transcript = compress_session(jsonl)
    if not cwd:
        return None
    return GenerationItem(
        template_id=template_id,
        cache_key=cache_key,
        metadata={
            "uuid": jsonl.stem,
            "cwd": cwd,
            "mtime": mtime,
            "first_user": first_user,
            "transcript": transcript,
        },
    )


def _sort_key(result: dict) -> tuple[int, int]:
    """Sort key ``(priority, mtime)``; values that don't parse as int sort as 0."""
    try:
        priority = int(result.get("priority", 0))
    except (TypeError, ValueError, OverflowError):
        priority = 0
    try:
        mtime = int(result.get("mtime", 0))
    except (TypeError, ValueError, OverflowError):
        mtime = 0
    return (priority, mtime)


def _tsv_cell(value: object, *, collapse: bool = False, limit: int | None = None) -> str:
    """Render *value* as a single TSV cell.

    Tabs and line breaks are always replaced by a space because they would
    split the cell or the row. With *collapse*, every whitespace run becomes a
    single space (used for free-text model output). *limit* caps the length.
    """
    text = re.sub(r"\s+" if collapse else r"[\t\r\n]+", " ", str(value))
    return text if limit is None else text[:limit]


def _format_row(result: dict) -> str:
    """Format one triage result as a TSV line (column order per module docstring)."""
    return "\t".join([
        _tsv_cell(result.get("mtime", 0)),
        _tsv_cell(result.get("uuid", "")),
        _tsv_cell(result.get("cwd", "")),
        _tsv_cell(result.get("priority", 0)),
        _tsv_cell(result.get("status", ""), collapse=True),
        _tsv_cell(result.get("summary", ""), collapse=True, limit=200),
        _tsv_cell(result.get("next_action", ""), collapse=True, limit=200),
    ])


async def main_async(args: argparse.Namespace) -> None:
    """Triage the selected sessions and print one TSV row per session to stdout.

    Candidates are chosen from ``args.limit`` / ``args.uuids``; unchanged
    sessions are served from the response cache (skipped when
    ``args.refresh``), the rest are compressed and sent through run_batched.
    Rows are printed sorted by priority desc, then mtime desc; cache
    statistics go to stderr.
    """
    candidates = _select_candidates(args.limit, args.uuids)
    if not candidates:
        return

    # Cache key is "<uuid>:<mtime>", which lets us check the cache BEFORE
    # reading the JSONL (the expensive part when most sessions are unchanged).
    # NOTE: mtime is truncated to whole seconds by collect_candidates, so a
    # session written again within the same second as a triage run keeps its
    # key and can be served a stale cached result.
    cache = ResponseCache(CACHE_DIR)
    template_id = "rclaude-triage-v1"
    results, uncached = _lookup_cache(cache, template_id, candidates, refresh=args.refresh)
    if args.refresh:
        _purge_cache_entries(cache, template_id, [key for _, _, key in uncached])

    items: list[GenerationItem] = []
    for mtime, jsonl, cache_key in uncached:
        item = _make_item(template_id, mtime, jsonl, cache_key)
        if item is not None:
            items.append(item)

    if items:
        client = ClaudeClient(model=MODEL, max_concurrent=MAX_CONCURRENT)
        try:
            new_results = await run_batched(
                client=client,
                cache=cache,
                items=items,
                system_prompt=SYSTEM_PROMPT,
                build_batch_prompt=build_prompt,
                validate_result=validate,
                enrich_result=enrich,
                index_key="ref_index",
                batch_size=BATCH_SIZE,
                description="rclaude-triage",
            )
        finally:
            await client.close()
        results.extend(new_results)

    for result in sorted(results, key=_sort_key, reverse=True):
        print(_format_row(result))
    print(f"# cache {cache.stats_summary()}", file=sys.stderr)


def main() -> None:
    """CLI entry point: parse flags, then run the async triage pipeline.

    The --limit default comes from RCLAUDE_TRIAGE_LIMIT (100 when unset or
    blank). A non-integer value there is reported on stderr and ignored,
    instead of crashing with a traceback before argparse (or even --help) can
    run. A negative limit is rejected as a usage error.
    """
    ap = argparse.ArgumentParser(description="Triage Claude Code sessions with Haiku.")
    default_limit = 100
    env_limit = os.environ.get("RCLAUDE_TRIAGE_LIMIT", "").strip()
    if env_limit:
        try:
            default_limit = int(env_limit)
        except ValueError:
            print(
                f"_claude-triage: ignoring non-integer RCLAUDE_TRIAGE_LIMIT={env_limit!r}; using 100",
                file=sys.stderr,
            )
    ap.add_argument(
        "--limit",
        type=int,
        default=default_limit,
        help="Max sessions to consider (default: 100, env: RCLAUDE_TRIAGE_LIMIT)",
    )
    ap.add_argument("--uuids", nargs="*", help="Restrict to specific UUIDs (prefix ok)")
    ap.add_argument("--refresh", action="store_true", help="Bypass cache")
    args = ap.parse_args()
    if args.limit < 0:
        ap.error("--limit (or RCLAUDE_TRIAGE_LIMIT) must be >= 0")
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()
