"""
Parse Claude Code JSONL session files from ~/.claude/projects/
Detects project from cwd/file paths, computes active time.
"""

import json
import os
import re
from datetime import datetime, timezone
from pathlib import Path

# Location of Claude Code's local data and its per-project session transcripts.
CLAUDE_DIR = Path.home() / ".claude"
PROJECTS_DIR = CLAUDE_DIR / "projects"
# Gaps between consecutive messages longer than this count as idle time
# and are excluded from "active" time (see compute_active_time).
IDLE_THRESHOLD_SECONDS = 10 * 60  # 10 minutes

# ── Pricing per token (USD) — updated April 2026 from docs.anthropic.com ──
# Keys are pricing "families" resolved by _model_family(); values are USD per
# single token (list price divided by 1M). NOTE(review): re-verify these rates
# against the current pricing page before relying on cost figures.
PRICING = {
    "opus": {
        "input": 5 / 1_000_000, "output": 25 / 1_000_000,
        "cache_read": 0.50 / 1_000_000, "cache_write": 6.25 / 1_000_000,
    },
    "sonnet": {
        "input": 3 / 1_000_000, "output": 15 / 1_000_000,
        "cache_read": 0.30 / 1_000_000, "cache_write": 3.75 / 1_000_000,
    },
    "haiku": {
        "input": 1 / 1_000_000, "output": 5 / 1_000_000,
        "cache_read": 0.10 / 1_000_000, "cache_write": 1.25 / 1_000_000,
    },
}


def _model_family(model_str: str) -> str:
    """Determine pricing family from model ID string."""
    m = (model_str or "").lower()
    if "haiku" in m:
        return "haiku"
    if "sonnet" in m:
        return "sonnet"
    return "opus"  # default — covers opus 4, 4.1, 4.5, 4.6


def estimate_cost_for_tokens(inp, out, cache_r, cache_w, model="opus"):
    """Calculate the USD cost of a single API call from its token counts.

    Rates come from PRICING keyed by the model's family; unknown families
    fall back to opus rates.
    """
    rates = PRICING.get(_model_family(model), PRICING["opus"])
    buckets = (
        (inp, "input"),
        (out, "output"),
        (cache_r, "cache_read"),
        (cache_w, "cache_write"),
    )
    return sum(count * rates[kind] for count, kind in buckets)

# Known project directories -> display names
# Auto-detected from ~/*/  directories
# Maps lowercased absolute path -> folder display name.
# Populated at runtime by detect_projects(); empty until then.
PROJECT_DIRS = {}


# Home-directory folder names (compared lowercased) that should never be
# treated as projects — OS standard folders plus common tool/cache dirs.
EXCLUDE_DIRS = {
    "applications", "desktop", "documents", "downloads", "library",
    "movies", "music", "pictures", "public", "projects", "sites",
    "go", "node_modules", "venv", "opt", "tmp", "bin",
}


def detect_projects():
    """Rebuild PROJECT_DIRS by scanning $HOME and ~/Projects/ for project folders.

    Top-level home folders are filtered against EXCLUDE_DIRS; hidden folders
    are always skipped. Clears any previous detection results first.
    """
    home = Path.home()
    PROJECT_DIRS.clear()

    def _register(entry):
        # Key is the lowercased absolute path so substring matching is
        # case-insensitive later (see detect_project_from_msg).
        PROJECT_DIRS[str(entry).lower()] = entry.name

    for entry in home.iterdir():
        if not entry.is_dir() or entry.name.startswith("."):
            continue
        if entry.name.lower() in EXCLUDE_DIRS:
            continue
        _register(entry)

    # Subdirectories of ~/Projects/ also count (no EXCLUDE_DIRS filter there).
    projects_root = home / "Projects"
    if projects_root.is_dir():
        for entry in projects_root.iterdir():
            if entry.is_dir() and not entry.name.startswith("."):
                _register(entry)


def detect_project_from_msg(msg):
    """Detect which project a message belongs to.

    Checks the message's cwd first, then file paths / commands inside
    tool_use content blocks. Returns (project_path, project_name) for the
    first PROJECT_DIRS entry whose path appears, or (None, None).
    """

    def _match(text):
        # First known project whose lowercased path occurs in text, else None.
        for known_path, display_name in PROJECT_DIRS.items():
            if known_path in text:
                return known_path, display_name
        return None

    # The working directory is the strongest signal.
    hit = _match((msg.get("cwd") or "").lower())
    if hit:
        return hit

    # Fall back to paths mentioned in tool inputs (Write/Edit/Bash etc.).
    content = (msg.get("message") or {}).get("content")
    if isinstance(content, list):
        for block in content:
            if not isinstance(block, dict):
                continue
            tool_input = block.get("input") or {}
            for key in ("file_path", "path", "command"):
                value = tool_input.get(key, "")
                if isinstance(value, str):
                    hit = _match(value.lower())
                    if hit:
                        return hit

    return None, None


def parse_timestamp(ts_str):
    """Parse an ISO-8601 timestamp string into a datetime, or None.

    Accepts a trailing "Z" (converted to "+00:00" for fromisoformat).
    Returns None for non-strings, empty strings, and unparseable values.
    """
    if not isinstance(ts_str, str) or not ts_str:
        return None
    normalized = ts_str.replace("Z", "+00:00")
    try:
        parsed = datetime.fromisoformat(normalized)
    except (ValueError, TypeError):
        return None
    return parsed


def compute_active_time(timestamps, threshold=IDLE_THRESHOLD_SECONDS):
    """Sum the gaps between consecutive sorted timestamps, in seconds.

    Gaps longer than `threshold` are treated as idle and skipped entirely
    (not clamped). Fewer than two timestamps yields 0.
    """
    gaps = (
        (later - earlier).total_seconds()
        for earlier, later in zip(timestamps, timestamps[1:])
    )
    return sum(gap for gap in gaps if gap <= threshold)


def count_lines_from_tools(content_blocks):
    """Count lines of code written via Write/Edit tool_use blocks.

    Write blocks contribute the line count of their `content`; Edit blocks
    contribute the line count of their `new_string`. Anything else — other
    tools, text blocks, non-dict entries, non-list input — contributes 0.
    """
    if not isinstance(content_blocks, list):
        return 0

    def _line_count(text):
        # A non-empty string has one more line than it has newlines.
        return text.count("\n") + 1 if text else 0

    total = 0
    for item in content_blocks:
        if not isinstance(item, dict) or item.get("type") != "tool_use":
            continue
        tool_input = item.get("input") or {}
        tool_name = item.get("name", "")
        if tool_name == "Write":
            total += _line_count(tool_input.get("content", ""))
        elif tool_name == "Edit":
            total += _line_count(tool_input.get("new_string", ""))
    return total


def parse_session_file(filepath):
    """Parse a single JSONL session file. Returns session stats and per-message data.

    Reads the file line by line (one JSON object per line; malformed lines
    are skipped silently). Accumulates:
      - timestamps for wall-clock and active-time computation
      - human prompt count and per-prompt previews (first 200 chars)
      - assistant API call count, token totals, estimated cost, and
        per-model call counts
      - lines of code written via Write/Edit tools
      - "votes" for which project each message belongs to; the most-voted
        project becomes dominant_project.
    NOTE(review): field names (toolUseResult, cache_creation_input_tokens,
    file-history-snapshot) follow Claude Code's JSONL schema — verify against
    current transcripts if parsing starts missing data.
    """
    messages = []
    timestamps = []
    human_count = 0
    api_count = 0
    total_input = 0
    total_output = 0
    total_cache_read = 0
    total_cache_write = 0
    total_lines_written = 0
    total_cost = 0.0
    models = {}  # model id -> number of assistant messages using it
    project_votes = {}  # Track which project appears most

    # errors="replace" so a stray invalid byte can't abort the whole parse.
    with open(filepath, "r", errors="replace") as f:
        for line in f:
            try:
                msg = json.loads(line.strip())
            except (json.JSONDecodeError, ValueError):
                continue

            msg_type = msg.get("type", "")
            # Snapshot records carry no conversation data; skip them.
            if msg_type in ("file-history-snapshot",):
                continue

            ts = parse_timestamp(msg.get("timestamp"))
            if ts:
                timestamps.append(ts)

            inner = msg.get("message") or {}
            role = inner.get("role", "")

            # Detect project
            proj_path, proj_name = detect_project_from_msg(msg)
            if proj_name:
                project_votes[proj_name] = project_votes.get(proj_name, 0) + 1

            # Human prompts (exclude tool results)
            is_human = False
            if msg_type == "user" and not msg.get("toolUseResult"):
                content = inner.get("content", "")
                preview = ""
                # Content can be a bare string or a list of typed blocks;
                # only non-empty text counts as a real human prompt.
                if isinstance(content, str) and content.strip():
                    is_human = True
                    preview = content[:200]
                elif isinstance(content, list):
                    for c in content:
                        if isinstance(c, dict) and c.get("type") == "text" and c.get("text", "").strip():
                            is_human = True
                            preview = c["text"][:200]
                            break

                if is_human:
                    human_count += 1
                    messages.append({
                        "timestamp": ts.isoformat() if ts else None,
                        "type": "human",
                        "model": None,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "cache_read_tokens": 0,
                        "cache_write_tokens": 0,
                        "content_preview": preview,
                        "project_name": proj_name,
                    })

            # Assistant / API calls
            if msg_type == "assistant":
                api_count += 1
                model = inner.get("model", "unknown")
                models[model] = models.get(model, 0) + 1

                usage = inner.get("usage") or {}
                # Usage fields may be present but None; coerce to 0.
                inp = usage.get("input_tokens", 0) or 0
                out = usage.get("output_tokens", 0) or 0
                cr = usage.get("cache_read_input_tokens", 0) or 0
                cw = usage.get("cache_creation_input_tokens", 0) or 0

                total_input += inp
                total_output += out
                total_cache_read += cr
                total_cache_write += cw
                total_cost += estimate_cost_for_tokens(inp, out, cr, cw, model)

                # Count lines of code written via Write/Edit tools
                content_blocks = inner.get("content", [])
                total_lines_written += count_lines_from_tools(content_blocks)

                messages.append({
                    "timestamp": ts.isoformat() if ts else None,
                    "type": "assistant",
                    "model": model,
                    "input_tokens": inp,
                    "output_tokens": out,
                    "cache_read_tokens": cr,
                    "cache_write_tokens": cw,
                    "content_preview": None,
                    "project_name": proj_name,
                })

    # Timestamps may arrive out of order in the file; sort before computing spans.
    timestamps.sort()
    active_seconds = compute_active_time(timestamps)
    wall_seconds = (timestamps[-1] - timestamps[0]).total_seconds() if len(timestamps) >= 2 else 0

    # Dominant project
    dominant_project = max(project_votes, key=project_votes.get) if project_votes else None

    return {
        "first_msg_at": timestamps[0].isoformat() if timestamps else None,
        "last_msg_at": timestamps[-1].isoformat() if timestamps else None,
        "wall_seconds": wall_seconds,
        "active_seconds": active_seconds,
        "human_prompts": human_count,
        "api_calls": api_count,
        "input_tokens": total_input,
        "output_tokens": total_output,
        "cache_read_tokens": total_cache_read,
        "cache_write_tokens": total_cache_write,
        "lines_written": total_lines_written,
        "cost": total_cost,
        "model_usage": models,
        "dominant_project": dominant_project,
        "project_votes": project_votes,
        "messages": messages,
    }


def parse_subagents(session_dir):
    """Parse subagent JSONL files under session_dir/subagents/.

    Aggregates API call counts, token totals, lines written, and estimated
    cost across all subagent transcripts. Returns an all-zero totals dict
    when the session has no subagents directory. Malformed lines and
    non-assistant records are skipped.
    """
    # Single zero-initialized totals dict serves both the early-return and
    # accumulation paths (previously duplicated literals that could drift).
    totals = {"api_calls": 0, "input_tokens": 0, "output_tokens": 0,
              "cache_read_tokens": 0, "cache_write_tokens": 0,
              "lines_written": 0, "cost": 0.0}

    sub_dir = session_dir / "subagents"
    # is_dir() rather than exists(): a stray *file* named "subagents" would
    # make iterdir() raise NotADirectoryError.
    if not sub_dir.is_dir():
        return totals

    for sf in sub_dir.iterdir():
        if sf.suffix != ".jsonl":
            continue
        with open(sf, "r", errors="replace") as f:
            for line in f:
                try:
                    msg = json.loads(line.strip())
                except (json.JSONDecodeError, ValueError):
                    continue
                if msg.get("type") != "assistant":
                    continue
                inner = msg.get("message") or {}
                usage = inner.get("usage") or {}
                model = inner.get("model", "unknown")
                # Usage fields may be present but None; coerce to 0.
                inp = usage.get("input_tokens", 0) or 0
                out = usage.get("output_tokens", 0) or 0
                cr = usage.get("cache_read_input_tokens", 0) or 0
                cw = usage.get("cache_creation_input_tokens", 0) or 0
                totals["api_calls"] += 1
                totals["input_tokens"] += inp
                totals["output_tokens"] += out
                totals["cache_read_tokens"] += cr
                totals["cache_write_tokens"] += cw
                totals["lines_written"] += count_lines_from_tools(inner.get("content", []))
                totals["cost"] += estimate_cost_for_tokens(inp, out, cr, cw, model)

    return totals


def discover_sessions():
    """Find all session JSONL files across all project directories.

    Returns a list of dicts: id (file stem), file_path, session_dir (Path to
    the per-session subagent directory, or None if absent), mtime, and
    size_bytes. Empty and non-JSONL files are skipped. Returns an empty list
    when ~/.claude/projects does not exist (fresh install / Claude Code
    never run) instead of raising FileNotFoundError.
    """
    sessions = []
    if not PROJECTS_DIR.is_dir():
        return sessions
    for proj_dir in PROJECTS_DIR.iterdir():
        if not proj_dir.is_dir():
            continue
        for f in proj_dir.iterdir():
            if f.suffix != ".jsonl":
                continue
            st = f.stat()  # stat once instead of three separate syscalls
            if st.st_size == 0:
                continue
            session_id = f.stem
            # A sibling directory named after the session holds subagent logs.
            session_subdir = proj_dir / session_id
            sessions.append({
                "id": session_id,
                "file_path": str(f),
                "session_dir": session_subdir if session_subdir.exists() else None,
                "mtime": st.st_mtime,
                "size_bytes": st.st_size,
            })
    return sessions


def full_parse(db_conn):
    """Parse all discovered sessions and upsert them into the database.

    For each session file: skips it if the stored file_mtime matches (within
    1 second), otherwise parses the transcript plus any subagent logs,
    upserts the owning project and the session row (tokens/cost include
    subagent totals), and rewrites the per-message `prompts` rows.

    Returns {"parsed": n, "skipped": n, "total": n}.

    NOTE(review): row access by name (existing["file_mtime"]) assumes the
    connection uses a dict-like row factory (e.g. sqlite3.Row) — confirm
    against the database module's connection setup.
    """
    # (Removed unused `from database import get_db` — db_conn is passed in.)
    detect_projects()
    sessions = discover_sessions()

    conn = db_conn
    cursor = conn.cursor()

    parsed = 0
    skipped = 0

    for sess_info in sessions:
        sid = sess_info["id"]
        mtime = sess_info["mtime"]

        # Skip sessions already parsed from an identical file version.
        existing = cursor.execute(
            "SELECT file_mtime FROM sessions WHERE id = ?", (sid,)
        ).fetchone()
        if existing and abs(existing["file_mtime"] - mtime) < 1:
            skipped += 1
            continue

        # Parse the transcript; subagent totals default to zeros when absent.
        stats = parse_session_file(sess_info["file_path"])
        sub_stats = parse_subagents(sess_info["session_dir"]) if sess_info["session_dir"] else {
            "api_calls": 0, "input_tokens": 0, "output_tokens": 0,
            "cache_read_tokens": 0, "cache_write_tokens": 0, "lines_written": 0, "cost": 0.0
        }

        # Map the dominant project name back to its filesystem path.
        proj_name = stats["dominant_project"] or "Other"
        proj_path = ""
        for pp, pn in PROJECT_DIRS.items():
            if pn == proj_name:
                proj_path = pp
                break

        # Use a single "Other" bucket for unassigned sessions
        if not proj_path:
            proj_path = "other"
            proj_name = "Other"

        # Upsert project
        cursor.execute(
            "INSERT INTO projects (path, name) VALUES (?, ?) ON CONFLICT(path) DO UPDATE SET name=excluded.name",
            (proj_path, proj_name),
        )
        proj_id = cursor.execute(
            "SELECT id FROM projects WHERE path = ?", (proj_path,)
        ).fetchone()["id"]

        # Upsert session (token/cost columns include subagent contributions).
        cursor.execute("""
            INSERT INTO sessions (id, project_id, file_path, file_mtime,
                first_msg_at, last_msg_at, wall_seconds, active_seconds,
                human_prompts, api_calls, subagent_api_calls,
                input_tokens, output_tokens, cache_read_tokens, cache_write_tokens,
                lines_written, cost, model_usage, parsed_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(id) DO UPDATE SET
                project_id=excluded.project_id,
                file_mtime=excluded.file_mtime,
                first_msg_at=excluded.first_msg_at,
                last_msg_at=excluded.last_msg_at,
                wall_seconds=excluded.wall_seconds,
                active_seconds=excluded.active_seconds,
                human_prompts=excluded.human_prompts,
                api_calls=excluded.api_calls,
                subagent_api_calls=excluded.subagent_api_calls,
                input_tokens=excluded.input_tokens,
                output_tokens=excluded.output_tokens,
                cache_read_tokens=excluded.cache_read_tokens,
                cache_write_tokens=excluded.cache_write_tokens,
                lines_written=excluded.lines_written,
                cost=excluded.cost,
                model_usage=excluded.model_usage,
                parsed_at=excluded.parsed_at
        """, (
            sid, proj_id, sess_info["file_path"], mtime,
            stats["first_msg_at"], stats["last_msg_at"],
            stats["wall_seconds"], stats["active_seconds"],
            stats["human_prompts"], stats["api_calls"], sub_stats["api_calls"],
            stats["input_tokens"] + sub_stats["input_tokens"],
            stats["output_tokens"] + sub_stats["output_tokens"],
            stats["cache_read_tokens"] + sub_stats["cache_read_tokens"],
            stats["cache_write_tokens"] + sub_stats["cache_write_tokens"],
            stats["lines_written"] + sub_stats["lines_written"],
            stats["cost"] + sub_stats["cost"],
            json.dumps(stats["model_usage"]),
        ))

        # Replace the session's per-message prompt rows wholesale.
        cursor.execute("DELETE FROM prompts WHERE session_id = ?", (sid,))
        for m in stats["messages"]:
            # A message may belong to a different project than the session's
            # dominant one; resolve its id when that project already exists.
            m_proj_name = m.get("project_name") or proj_name
            m_proj_id = proj_id
            if m_proj_name != proj_name:
                cursor.execute(
                    "SELECT id FROM projects WHERE name = ?", (m_proj_name,)
                )
                row = cursor.fetchone()
                if row:
                    m_proj_id = row["id"]

            cursor.execute("""
                INSERT INTO prompts (session_id, project_id, timestamp, type, model,
                    input_tokens, output_tokens, cache_read_tokens, cache_write_tokens,
                    content_preview)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                sid, m_proj_id, m["timestamp"], m["type"], m["model"],
                m["input_tokens"], m["output_tokens"],
                m["cache_read_tokens"], m["cache_write_tokens"],
                m["content_preview"],
            ))

        parsed += 1

    conn.commit()
    return {"parsed": parsed, "skipped": skipped, "total": len(sessions)}
