"""
DEWA PR Tracker — Database layer (SQLite)
"""
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from config import DB_PATH


# ── Schema ────────────────────────────────────────────────────────────────────

# DDL for every table this module reads and writes; init_db() runs it with
# executescript(). Each statement is CREATE TABLE IF NOT EXISTS, so running it
# against an existing database leaves current tables and data untouched (it also
# will not add new columns to a table that already exists).
#   * visitors       - visitor sessions; log_visit() de-duplicates on session_id
#                      in code (there is no UNIQUE constraint on it here).
#   * press_releases - scraped press releases; url is UNIQUE and is the
#                      ON CONFLICT target of upsert_press_release().
#   * podcasts       - scraped episodes, optionally linked to a press release via
#                      pr_id; url is UNIQUE and is the ON CONFLICT target of
#                      upsert_podcast().
#   * scrape_log     - one row per scraper run, written by log_scrape().
# Timestamp columns (visited_at, scraped_at, ran_at) are written by this module
# as datetime.utcnow().isoformat() text.
# NOTE: SQLite enforces REFERENCES clauses only when PRAGMA foreign_keys=ON;
# get_conn() does not enable it, so podcasts.pr_id is not validated by the DB.
SCHEMA = """
CREATE TABLE IF NOT EXISTS visitors (
    id         INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT    NOT NULL,
    visited_at TEXT    NOT NULL,
    user_agent TEXT
);

CREATE TABLE IF NOT EXISTS press_releases (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    title       TEXT    NOT NULL,
    url         TEXT    UNIQUE NOT NULL,
    date        TEXT,                    -- ISO-8601 (YYYY-MM-DD)
    category    TEXT,
    description TEXT,
    scraped_at  TEXT    NOT NULL
);

CREATE TABLE IF NOT EXISTS podcasts (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    pr_id       INTEGER REFERENCES press_releases(id),
    title       TEXT    NOT NULL,
    url         TEXT    UNIQUE NOT NULL,
    audio_url   TEXT,
    date        TEXT,
    duration    TEXT,
    language    TEXT    DEFAULT 'en',
    scraped_at  TEXT    NOT NULL
);

CREATE TABLE IF NOT EXISTS scrape_log (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    ran_at      TEXT NOT NULL,
    prs_found   INTEGER DEFAULT 0,
    pods_found  INTEGER DEFAULT 0,
    status      TEXT,
    notes       TEXT
);
"""


@contextmanager
def get_conn():
    """Open a SQLite connection to ``DB_PATH`` for the duration of a ``with`` block.

    Rows come back as ``sqlite3.Row`` so columns can be read by name. The
    transaction is committed only if the ``with`` body completes normally; the
    connection is always closed afterwards, which discards any uncommitted work
    left behind by a body that raised.
    """
    db = sqlite3.connect(DB_PATH)
    db.row_factory = sqlite3.Row
    try:
        yield db
        # Reached only when the caller's block did not raise.
        db.commit()
    finally:
        db.close()


def init_db():
    """Create any tables from ``SCHEMA`` that are missing from the database."""
    with get_conn() as db:
        db.executescript(SCHEMA)


# ── Upsert helpers ────────────────────────────────────────────────────────────

def upsert_press_release(title: str, url: str, date: str = None,
                         category: str = None, description: str = None) -> int:
    """Insert a press release, or refresh the existing row with the same URL.

    ``url`` is the conflict key. When a row with this URL already exists, its
    title, date, category, description and ``scraped_at`` are overwritten with
    the values passed here, including ``None``.

    Args:
        title: press-release headline (required, NOT NULL in the schema).
        url: canonical URL; unique per press release.
        date: publication date, stored as given (schema comment says YYYY-MM-DD).
        category: optional category label.
        description: optional summary text.

    Returns:
        The ``id`` of the inserted or updated row.
    """
    now = datetime.utcnow().isoformat()
    with get_conn() as conn:
        conn.execute(
            """
            INSERT INTO press_releases (title, url, date, category, description, scraped_at)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT(url) DO UPDATE SET
                title       = excluded.title,
                date        = excluded.date,
                category    = excluded.category,
                description = excluded.description,
                scraped_at  = excluded.scraped_at
            """,
            (title, url, date, category, description, now),
        )
        # Resolve the id by URL rather than trusting cursor.lastrowid. When the
        # upsert takes the DO UPDATE path, SQLite leaves last_insert_rowid
        # unchanged, so lastrowid is only "right" on a brand-new connection,
        # where it is 0. On a reused connection it would be a stale id from an
        # earlier INSERT.
        row = conn.execute(
            "SELECT id FROM press_releases WHERE url = ?", (url,)
        ).fetchone()
        return row["id"]


def upsert_podcast(title: str, url: str, pr_id: int = None,
                   audio_url: str = None, date: str = None, duration: str = None,
                   language: str = "en") -> int:
    """Insert a podcast episode, or refresh the existing row with the same URL.

    ``url`` is the conflict key. On conflict, title, audio_url, date, duration,
    language and ``scraped_at`` are overwritten with the values passed here,
    including ``None``.

    ``pr_id`` is overwritten only when a non-None value is passed. An episode
    first stored without a press-release link can therefore be linked by a later
    scrape, while a call with ``pr_id=None`` keeps any existing link.

    Args:
        title: episode title (required, NOT NULL in the schema).
        url: canonical episode URL; unique per episode.
        pr_id: optional ``press_releases.id`` this episode belongs to.
        audio_url: optional direct audio URL.
        date: optional episode date, stored as given.
        duration: optional duration text, stored as given.
        language: language code; defaults to ``"en"``.

    Returns:
        The ``id`` of the inserted or updated row.
    """
    now = datetime.utcnow().isoformat()
    with get_conn() as conn:
        conn.execute(
            """
            INSERT INTO podcasts (pr_id, title, url, audio_url, date, duration, language, scraped_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(url) DO UPDATE SET
                pr_id      = COALESCE(excluded.pr_id, podcasts.pr_id),
                title      = excluded.title,
                audio_url  = excluded.audio_url,
                date       = excluded.date,
                duration   = excluded.duration,
                language   = excluded.language,
                scraped_at = excluded.scraped_at
            """,
            (pr_id, title, url, audio_url, date, duration, language, now),
        )
        # Resolve the id by URL rather than trusting cursor.lastrowid. On the
        # DO UPDATE path SQLite does not change last_insert_rowid, so lastrowid
        # is only correct by accident on a fresh connection (value 0).
        row = conn.execute(
            "SELECT id FROM podcasts WHERE url = ?", (url,)
        ).fetchone()
        return row["id"]


def log_visit(session_id: str, user_agent: str = "") -> bool:
    """Record the first visit of a session; later calls for it are no-ops.

    Args:
        session_id: identifier of the visitor session.
        user_agent: optional User-Agent string; a falsy value is stored as "".

    Returns:
        True if a new ``visitors`` row was inserted, False if this
        ``session_id`` was already recorded.
    """
    with get_conn() as conn:
        # Do the existence check and the insert in ONE statement. An INSERT
        # takes SQLite's write lock before evaluating its SELECT, so two
        # concurrent calls for the same session cannot both pass the check.
        # That could happen with the earlier separate SELECT-then-INSERT.
        cur = conn.execute(
            "INSERT INTO visitors (session_id, visited_at, user_agent) "
            "SELECT ?, ?, ? "
            "WHERE NOT EXISTS (SELECT 1 FROM visitors WHERE session_id = ?)",
            (session_id, datetime.utcnow().isoformat(), user_agent or "", session_id),
        )
        return cur.rowcount == 1


def get_visitor_stats():
    """Summarise the ``visitors`` table.

    Returns a dict with:
      * ``total``  - number of visitor rows;
      * ``today``  - rows whose ``visited_at`` text sorts on/after SQLite's
        ``date('now')`` (i.e. today's date);
      * ``by_day`` - up to 30 rows of ``(day, visits)``, newest day first;
      * ``recent`` - the 20 most recent ``(visited_at, user_agent)`` rows.
    """
    with get_conn() as conn:
        def scalar(sql):
            return conn.execute(sql).fetchone()[0]

        def rows(sql):
            return conn.execute(sql).fetchall()

        # Dict literals evaluate in order, so the queries run in the same
        # sequence as listed.
        summary = {
            "total": scalar("SELECT COUNT(*) FROM visitors"),
            "today": scalar(
                "SELECT COUNT(*) FROM visitors WHERE visited_at >= date('now')"
            ),
            "by_day": rows(
                "SELECT substr(visited_at,1,10) AS day, COUNT(*) AS visits "
                "FROM visitors GROUP BY day ORDER BY day DESC LIMIT 30"
            ),
            "recent": rows(
                "SELECT visited_at, user_agent FROM visitors ORDER BY visited_at DESC LIMIT 20"
            ),
        }
    return summary


def log_scrape(prs_found: int, pods_found: int, status: str, notes: str = ""):
    """Append one ``scrape_log`` row for a scraper run, stamped with the current UTC time."""
    insert_sql = "INSERT INTO scrape_log (ran_at, prs_found, pods_found, status, notes) VALUES (?,?,?,?,?)"
    with get_conn() as conn:
        ran_at = datetime.utcnow().isoformat()
        conn.execute(insert_sql, (ran_at, prs_found, pods_found, status, notes))


# ── Query helpers ─────────────────────────────────────────────────────────────

def get_all_press_releases():
    """Return every press release with the number of podcasts linked to it.

    Each row carries the press-release columns plus ``podcast_count`` (0 when
    no podcast references it). Rows are ordered by ``date`` then ``id``, both
    descending.
    """
    with get_conn() as conn:
        cursor = conn.execute(
            """
            SELECT
                pr.id,
                pr.title,
                pr.url,
                pr.date,
                pr.category,
                pr.description,
                pr.scraped_at,
                COUNT(p.id) AS podcast_count
            FROM press_releases pr
            LEFT JOIN podcasts p ON p.pr_id = pr.id
            GROUP BY pr.id
            ORDER BY pr.date DESC, pr.id DESC
            """
        )
        releases = cursor.fetchall()
    return releases


def get_all_podcasts(language: str = None):
    """Return podcast rows joined with their press-release title (``pr_title``).

    When ``language`` is truthy, only podcasts with that language are
    returned; otherwise all podcasts are. Rows are ordered by ``date`` then
    ``id``, both descending.
    """
    if language:
        sql = """
                SELECT p.*, pr.title AS pr_title
                FROM podcasts p
                LEFT JOIN press_releases pr ON pr.id = p.pr_id
                WHERE p.language = ?
                ORDER BY p.date DESC, p.id DESC
                """
        params = (language,)
    else:
        sql = """
            SELECT p.*, pr.title AS pr_title
            FROM podcasts p
            LEFT JOIN press_releases pr ON pr.id = p.pr_id
            ORDER BY p.date DESC, p.id DESC
            """
        params = ()
    with get_conn() as conn:
        return conn.execute(sql, params).fetchall()


def get_counts_by_day():
    """Return ``(date, count)`` rows: press releases per non-NULL date, oldest first."""
    sql = "SELECT date, COUNT(*) as count FROM press_releases WHERE date IS NOT NULL GROUP BY date ORDER BY date"
    with get_conn() as conn:
        daily = conn.execute(sql).fetchall()
    return daily


def get_counts_by_month():
    """Return ``(month, count)`` rows, one per ``YYYY-MM`` prefix of ``date``.

    Press releases with a NULL ``date`` are skipped; months are ascending.
    """
    with get_conn() as conn:
        cursor = conn.execute(
            """
            SELECT substr(date,1,7) AS month, COUNT(*) AS count
            FROM press_releases WHERE date IS NOT NULL
            GROUP BY month ORDER BY month
            """
        )
        monthly = cursor.fetchall()
    return monthly


def get_roi_monthly_since(start_date: str):
    """
    Monthly ROI breakdown since start_date.
    Cost unit = per press release (each PR had one podcast @ AED 20,000).
    Podcast episodes (EN + AR) are counted separately for reference.

    Only press releases whose ``date`` compares >= ``start_date`` (as text) are
    included. Each returned row holds ``month`` (first 7 chars of the PR date),
    ``pr_count`` (distinct PRs), ``en_episodes`` / ``ar_episodes`` (linked
    podcasts by language) and ``total_episodes`` (all linked podcasts).
    Rows are ordered by month, ascending.
    """
    params = (start_date,)
    with get_conn() as conn:
        cursor = conn.execute(
            """
            SELECT
                substr(pr.date, 1, 7)                              AS month,
                COUNT(DISTINCT pr.id)                              AS pr_count,
                SUM(CASE WHEN p.language = 'en' THEN 1 ELSE 0 END) AS en_episodes,
                SUM(CASE WHEN p.language = 'ar' THEN 1 ELSE 0 END) AS ar_episodes,
                COUNT(p.id)                                        AS total_episodes
            FROM press_releases pr
            LEFT JOIN podcasts p ON p.pr_id = pr.id
            WHERE pr.date >= ?
            GROUP BY month
            ORDER BY month
            """,
            params,
        )
        breakdown = cursor.fetchall()
    return breakdown


def get_scrape_log():
    """Return the 50 most recent ``scrape_log`` rows, newest ``ran_at`` first."""
    sql = "SELECT * FROM scrape_log ORDER BY ran_at DESC LIMIT 50"
    with get_conn() as conn:
        entries = conn.execute(sql).fetchall()
    return entries


def get_stats():
    """Return headline counters for the dashboard.

    Keys: ``total_prs`` (press-release rows), ``total_podcasts`` (podcast rows)
    and ``last_scraped`` (latest ``press_releases.scraped_at``, or None when
    the table is empty).
    """
    scalar_queries = (
        ("total_prs", "SELECT COUNT(*) FROM press_releases"),
        ("total_podcasts", "SELECT COUNT(*) FROM podcasts"),
        ("last_scraped", "SELECT MAX(scraped_at) FROM press_releases"),
    )
    with get_conn() as conn:
        return {key: conn.execute(sql).fetchone()[0] for key, sql in scalar_queries}
