#!/usr/bin/env python3
"""Contextual heart-rate analysis for Chicho's MVP VPT.

Cross-references Health Bridge heart-rate samples with sleep sessions and
activity/exercise windows to produce operational readiness/coaching metrics.
"""
import argparse
import json
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from statistics import mean
from zoneinfo import ZoneInfo

# SQLite database (raw_data table) that this script reads, under the user's home directory.
DB_PATH = Path.home().joinpath("health-bridge", "health.db")
# Local timezone used to interpret --date and to print window boundaries.
TZ = ZoneInfo("America/Argentina/Buenos_Aires")

@dataclass
class Window:
    label: str
    start: datetime
    end: datetime
    meta: dict


def parse_ts(value: str) -> datetime:
    """Parse an ISO-8601 timestamp and return it as an aware UTC datetime.

    A trailing ``Z`` is rewritten to ``+00:00`` so older ``fromisoformat``
    implementations accept it; timestamps without an offset are taken as UTC.
    """
    text = value[:-1] + "+00:00" if value.endswith("Z") else value
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        # Naive input: assume it was already UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def load_json_rows(con, subject, data_type, start_date, end_date):
    """Load and decode ``raw_data`` payloads for one subject/data type.

    Args:
        con: Open sqlite3 connection containing a ``raw_data`` table.
        subject: Value matched against the ``subject`` column.
        data_type: Value matched against the ``data_type`` column.
        start_date: Inclusive lower bound on ``local_date`` (compared as stored).
        end_date: Inclusive upper bound on ``local_date``.

    Returns:
        A list of decoded JSON objects (dicts), ordered by ``local_date`` then
        ``received_at``. Each gets ``local_date`` and ``source`` filled in from
        the row when the payload does not already define them.

    Rows whose payload is NULL, is not valid JSON, or decodes to something
    other than a JSON object are skipped (best-effort loading).
    """
    rows = con.execute(
        """
        select local_date, source, data_json
        from raw_data
        where subject=? and data_type=? and local_date between ? and ?
        order by local_date, received_at
        """,
        (subject, data_type, start_date, end_date),
    ).fetchall()
    out = []
    for local_date, source, payload in rows:
        try:
            obj = json.loads(payload)
        except (TypeError, ValueError):
            # TypeError: NULL/non-text payload; ValueError covers JSONDecodeError
            # and undecodable bytes. Skip the row rather than abort the report.
            continue
        if not isinstance(obj, dict):
            # Arrays/scalars are not records; setdefault below needs a mapping.
            continue
        obj.setdefault("local_date", local_date)
        obj.setdefault("source", source)
        out.append(obj)
    return out


def stats(samples):
    """Summarise the bpm values (second element) of a sequence of samples.

    Returns None when ``samples`` is empty. Otherwise returns a dict with the
    sample count, rounded mean, min, max, nearest-rank p50/p90, and a rough
    duration estimate in minutes.
    """
    bpms = [pair[1] for pair in samples]
    if not bpms:
        return None
    ordered = sorted(bpms)
    top = len(ordered) - 1

    def nearest_rank(fraction):
        # Round to the nearest rank, then clamp into [0, top].
        return ordered[min(top, max(0, round(top * fraction)))]

    return dict(
        samples=len(bpms),
        avg_bpm=round(mean(bpms), 1),
        min_bpm=min(bpms),
        max_bpm=max(bpms),
        p50_bpm=nearest_rank(0.50),
        p90_bpm=nearest_rank(0.90),
        # Assumes ~2 s spacing (30 samples/minute), as Fitbit often gives in dense windows.
        minutes_est=round(len(bpms) / 30, 1),
    )


def window_samples(samples, start, end):
    """Return the (time, bpm) pairs whose time lies in the half-open range [start, end)."""
    selected = []
    for when, value in samples:
        if start <= when < end:
            selected.append((when, value))
    return selected


def overlap(a_start, a_end, b_start, b_end):
    """Return True when the half-open intervals [a_start, a_end) and [b_start, b_end) intersect."""
    latest_start = max(a_start, b_start)
    earliest_end = min(a_end, b_end)
    return latest_start < earliest_end


def _load_samples(hr_records):
    """Flatten heart-rate records into a time-sorted list of (UTC datetime, float bpm).

    Each record's "samples" entries are expected to be dicts with "time" and
    "bpm". Entries that are not dicts, have a null field, or fail to parse are
    skipped so a single bad sample cannot abort the whole report.
    """
    samples = []
    for rec in hr_records:
        for s in rec.get("samples") or []:
            if not isinstance(s, dict) or s.get("time") is None or s.get("bpm") is None:
                continue
            try:
                samples.append((parse_ts(s["time"]), float(s["bpm"])))
            except (AttributeError, TypeError, ValueError):
                # Non-string timestamp, unparsable timestamp, or non-numeric bpm.
                continue
    samples.sort()
    return samples


def _sleep_efficiency(rec):
    """Return ``minutes / duration_minutes`` as a percentage (1 decimal), or None if unusable."""
    minutes = rec.get("minutes")
    duration = rec.get("duration_minutes")
    # Null/missing/non-numeric values mean "unknown", not 0% (and must not crash).
    if not isinstance(minutes, (int, float)) or not isinstance(duration, (int, float)) or not duration:
        return None
    return round((minutes / duration) * 100, 1)


def _build_windows(sleep_records, exercise_records):
    """Build Window objects for sleep sessions, their stages, and exercise records.

    Records lacking either "start" or "end" are ignored.
    """
    windows = []
    for rec in sleep_records:
        if not (rec.get("start") and rec.get("end")):
            continue
        windows.append(Window("sleep_window", parse_ts(rec["start"]), parse_ts(rec["end"]), {
            "source": rec.get("source"),
            "effective_sleep_min": rec.get("minutes"),
            "awake_min": rec.get("awake_minutes"),
            "efficiency": _sleep_efficiency(rec),
        }))
        for st in rec.get("stages") or []:
            if st.get("start") and st.get("end"):
                # `or` also covers an explicit null stage, which broke string concatenation.
                stage_name = st.get("stage") or "unknown"
                windows.append(Window(f"sleep_stage_{stage_name}", parse_ts(st["start"]), parse_ts(st["end"]), {
                    "stage": st.get("stage"),
                }))

    for rec in exercise_records:
        if not (rec.get("start") and rec.get("end")):
            continue
        cls = rec.get("source_class") or rec.get("source")
        label = "exercise_workout" if cls == "explicit_workout" else "exercise_auto"
        windows.append(Window(label, parse_ts(rec["start"]), parse_ts(rec["end"]), {
            "source": rec.get("source"),
            "minutes": rec.get("minutes"),
            "exercise_type": rec.get("exercise_type"),
        }))
    return windows


def _derive_metrics(samples, sleep_intervals, exercise_intervals, day_start, day_end,
                    *, active_bpm=105, rest_bpm=90):
    """Bucket samples into active / awake-resting groups and compare sleep HR to awake rest.

    Within [day_start, day_end), a sample is "active" when it falls in an
    exercise interval or is >= ``active_bpm``. Otherwise it is "awake
    resting-like" when outside every sleep interval and < ``rest_bpm``.

    The sleep-vs-awake delta uses the pooled mean of every sample inside any
    sleep interval, not just the first sleep session. This keeps naps,
    duplicate session records and multi-day ranges from skewing it. The pool is
    not clipped to the day range, matching the per-window statistics.
    """
    awake_rest, active, asleep = [], [], []
    for t, bpm in samples:
        in_sleep = any(s <= t < e for s, e in sleep_intervals)
        if in_sleep:
            asleep.append((t, bpm))
        if not day_start <= t < day_end:
            continue
        in_ex = any(s <= t < e for s, e in exercise_intervals)
        if in_ex or bpm >= active_bpm:
            active.append((t, bpm))
        elif not in_sleep and bpm < rest_bpm:
            awake_rest.append((t, bpm))

    derived = {
        "awake_resting_like": stats(awake_rest),
        "active_or_high_hr": stats(active),
    }
    sleep_pool = stats(asleep)
    awake_rest_stats = derived["awake_resting_like"]
    if sleep_pool and awake_rest_stats:
        derived["sleep_vs_awake_rest_delta_bpm"] = round(
            sleep_pool["avg_bpm"] - awake_rest_stats["avg_bpm"], 1
        )
    return derived


def _print_report(result):
    """Print the human-readable (non --json) form of ``result``."""
    print(f"HR context — {result['subject']} — {result['date']}")
    print(f"All day: {result['all_day'] or 'no HR samples'}")
    for key, val in result["derived"].items():
        print(f"{key}: {val}")
    print("\nWindows:")
    meta_keys = ("source", "effective_sleep_min", "awake_min", "efficiency", "minutes", "exercise_type", "stage")
    for w in result["windows"]:
        meta = [f"{k}={w[k]}" for k in meta_keys if w.get(k) is not None]
        print(f"- {w['label']} {w['start_local']}→{w['end_local']}: avg {w['avg_bpm']} min {w['min_bpm']} max {w['max_bpm']} p90 {w['p90_bpm']} samples {w['samples']} {'; '.join(meta)}")


def main():
    """CLI entry point: report heart-rate context for a subject over a local date range.

    Options:
        --date     Local start date (YYYY-MM-DD); defaults to the latest
                   heart_rate ``local_date`` stored for the subject.
        --days     Number of local days to cover (must be >= 1).
        --subject  Subject key in ``raw_data``.
        --json     Print the result dict as JSON instead of text.

    Raises:
        SystemExit: if the database file is missing, no heart_rate data exists,
            ``--date`` cannot be parsed, or ``--days`` < 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--date", help="Local date YYYY-MM-DD. Defaults to latest HR date.")
    ap.add_argument("--days", type=int, default=1)
    ap.add_argument("--subject", default="chicho")
    ap.add_argument("--json", action="store_true")
    args = ap.parse_args()
    if args.days < 1:
        ap.error("--days must be at least 1")

    # sqlite3.connect would silently create an empty DB file here; fail clearly instead.
    if not DB_PATH.exists():
        raise SystemExit(f"Database not found: {DB_PATH}")

    con = sqlite3.connect(DB_PATH)
    try:
        if not args.date:
            row = con.execute(
                "select max(local_date) from raw_data where subject=? and data_type='heart_rate'",
                (args.subject,),
            ).fetchone()
            args.date = row[0]
        if not args.date:
            raise SystemExit("No heart_rate data found")

        try:
            start_local = datetime.fromisoformat(args.date).replace(tzinfo=TZ)
        except ValueError:
            raise SystemExit(f"Invalid --date {args.date!r}; expected YYYY-MM-DD") from None
        end_local = start_local + timedelta(days=args.days)
        start_date = start_local.date().isoformat()
        # Last local date touched by the half-open range [start_local, end_local).
        end_date = (end_local - timedelta(seconds=1)).date().isoformat()

        hr_records = load_json_rows(con, args.subject, "heart_rate", start_date, end_date)
        sleep_records = load_json_rows(con, args.subject, "sleep", start_date, end_date)
        exercise_records = load_json_rows(con, args.subject, "exercise", start_date, end_date)
    finally:
        # Release the DB handle on every path, including SystemExit above.
        con.close()

    samples = _load_samples(hr_records)
    windows = _build_windows(sleep_records, exercise_records)

    day_start_utc = start_local.astimezone(timezone.utc)
    day_end_utc = end_local.astimezone(timezone.utc)
    sleep_intervals = [(w.start, w.end) for w in windows if w.label == "sleep_window"]
    exercise_intervals = [(w.start, w.end) for w in windows if w.label.startswith("exercise_")]

    all_stats = stats(window_samples(samples, day_start_utc, day_end_utc))
    result = {"subject": args.subject, "date": args.date, "all_day": all_stats, "windows": []}

    for w in windows:
        st = stats(window_samples(samples, w.start, w.end))
        if st:
            result["windows"].append({
                "label": w.label,
                "start_local": w.start.astimezone(TZ).isoformat(timespec="minutes"),
                "end_local": w.end.astimezone(TZ).isoformat(timespec="minutes"),
                **w.meta,
                **st,
            })

    result["derived"] = _derive_metrics(
        samples, sleep_intervals, exercise_intervals, day_start_utc, day_end_utc
    )

    if args.json:
        print(json.dumps(result, ensure_ascii=False, indent=2))
        return

    _print_report(result)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when the module is imported.
    main()
