import argparse
import json
import sqlite3
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
from PIL import Image, ImageOps


# Directory one level above this script's folder; DB and outputs live there.
_PROJECT_ROOT = Path(__file__).resolve().parents[1]

# SQLite catalogue containing the ``book_images`` table.
DEFAULT_DB = _PROJECT_ROOT / "books.db"
# Machine-specific location of the source photos; override with ``batch --images``.
DEFAULT_IMAGES = Path(r"C:\Users\ignac\OneDrive\Pictures\selected_libros")
# Where the cropped view/thumb JPEGs and the crop report are written.
DEFAULT_OUTPUT = _PROJECT_ROOT / "optimized_images"

# Longest side (px) of the downscaled copy used for detection.
ANALYSIS_MAX = 640
# Extra margin added around a detected box, as a fraction of its width/height.
PADDING_RATIO = 0.03
# Minimum confidence for a detector's result to be accepted outright.
CONFIDENCE_THRESHOLD = 0.45

# ``book_images.id`` values that are always kept full-frame (no smart crop).
UNCROPPED_IMAGE_IDS = {
    81, 318, 492, 522, 555, 563, 592, 598, 602, 606,
    607, 608, 610, 611, 634, 652, 653, 658, 659,
}


@dataclass
class CropResult:
    """Outcome of one crop attempt, as produced by the ``method_*`` functions.

    Instances are serialized field-by-field with ``asdict`` into the JSON reports
    written by the ``batch`` and ``test`` commands.
    """

    # Crop rectangle (x1, y1, x2, y2) in full-resolution pixels, passed to ``Image.crop``.
    box: tuple[int, int, int, int]
    # Strategy that produced the box: "hsv", "edges", "grabcut", "center" or "override_full".
    method: str
    # Combined quality score; compared against the confidence threshold to accept a result.
    confidence: float
    # For detector results: fraction of the analysis image covered by the unpadded crop.
    crop_ratio: float
    # For detector results: fraction of foreground pixels inside the crop.
    density: float
    # For detector results: 1.0 when the ring around the crop has no foreground, down to 0.0.
    border_cleanliness: float
    # Penalty for crops that are very small (<6% of the image) or nearly full-frame (>93%).
    size_score: float


def pil_to_bgr(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-style BGR ``uint8`` array.

    EXIF orientation is applied first, and the image is forced to RGB so that
    palette, greyscale and RGBA inputs all yield three channels.
    """
    upright = ImageOps.exif_transpose(img)
    rgb_array = np.array(upright.convert("RGB"))
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)


def source_path(images_dir: Path, folder: str, filename: str) -> Path:
    """Return the on-disk location of a catalogued photo.

    The special folder name ``"root"`` means the file sits directly inside
    ``images_dir``; any other value names a sub-directory of it.
    """
    if folder == "root":
        return images_dir / filename
    return images_dir / folder / filename


def resize_for_analysis(bgr: np.ndarray, max_side: int = ANALYSIS_MAX) -> tuple[np.ndarray, float]:
    """Downscale ``bgr`` so that its longest side is at most ``max_side`` pixels.

    Args:
        bgr: Image array whose first two axes are (height, width).
        max_side: Upper bound for the longest side of the returned image.

    Returns:
        ``(image, scale)``. ``scale`` (<= 1.0) is the factor applied to the original
        size; callers divide analysis coordinates by it to map back. Images that
        are already small enough are returned as a copy with ``scale == 1.0``.
    """
    height, width = bgr.shape[:2]
    scale = min(1.0, max_side / max(height, width))
    if scale == 1.0:
        return bgr.copy(), 1.0
    # Clamp to at least 1 px: for very elongated images the short side can
    # truncate to 0, which cv2.resize rejects with an error.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = cv2.resize(bgr, new_size, interpolation=cv2.INTER_AREA)
    return resized, scale


def border_mask(height: int, width: int, ratio: float = 0.06) -> np.ndarray:
    """Return a ``uint8`` mask that is 1 on a frame along the image edges and 0 inside.

    The frame thickness is ``ratio`` of the shorter side, but never less than 3 px.
    When the frame is thicker than half the image, the whole mask is 1.
    """
    thickness = max(3, int(min(height, width) * ratio))
    frame = np.ones((height, width), dtype=np.uint8)
    # Clear the interior; the slice is empty when the image is too small for one.
    frame[thickness:-thickness, thickness:-thickness] = 0
    return frame


def dominant_border_ranges(hsv: np.ndarray) -> list[tuple[np.ndarray, np.ndarray]]:
    """Estimate HSV colour ranges that describe the background around the subject.

    The outer frame of the image (see ``border_mask``) is sampled. Any preset
    backdrop colour (two pink/red hue bands, green, dark) that covers at least 8%
    of that frame is returned. Otherwise the frame pixels are clustered with
    k-means (k=2) and a window around the larger cluster's centre is returned.

    Args:
        hsv: Image in OpenCV 8-bit HSV (H 0..179, S/V 0..255), as produced by
            ``cv2.cvtColor(..., cv2.COLOR_BGR2HSV)``.

    Returns:
        Inclusive ``(low, high)`` bounds suitable for ``cv2.inRange``; the mask of
        the background is the union of all ranges. Empty if there are no frame pixels.
    """
    h, w = hsv.shape[:2]
    edge = border_mask(h, w).astype(bool)
    border_pixels = hsv[edge]
    if border_pixels.size == 0:
        return []

    presets = [
        ("pink_a", np.array([130, 35, 40]), np.array([179, 255, 255])),
        ("pink_b", np.array([0, 35, 40]), np.array([12, 255, 255])),
        ("green", np.array([28, 18, 25]), np.array([95, 255, 255])),
        ("dark", np.array([0, 0, 0]), np.array([179, 90, 105])),
    ]

    ranges = []
    for _, low, high in presets:
        hits = np.all((border_pixels >= low) & (border_pixels <= high), axis=1).mean()
        if hits >= 0.08:
            ranges.append((low, high))

    if ranges:
        return ranges

    # Subsample evenly so k-means stays cheap on large frames.
    sample = border_pixels
    if len(sample) > 2500:
        idx = np.linspace(0, len(sample) - 1, 2500).astype(np.int32)
        sample = sample[idx]
    sample = np.float32(sample)
    _, labels, centers = cv2.kmeans(
        sample,
        2,
        None,
        (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 18, 1.0),
        3,
        cv2.KMEANS_PP_CENTERS,
    )
    counts = np.bincount(labels.flatten())
    center = centers[int(np.argmax(counts))]
    hue = int(center[0])
    sat = int(center[1])
    val = int(center[2])
    # Dull, darkish backdrop: hue is unreliable, so match on low S/V only.
    if sat < 35 and val < 140:
        return [(np.array([0, 0, 0]), np.array([179, min(120, sat + 55), min(150, val + 60)]))]

    hue_pad = 16
    sat_lo, sat_hi = max(0, sat - 70), min(255, sat + 70)
    val_lo, val_hi = max(0, val - 65), min(255, val + 65)
    ranges = [
        (np.array([max(0, hue - hue_pad), sat_lo, val_lo]), np.array([min(179, hue + hue_pad), sat_hi, val_hi])),
    ]
    # OpenCV hue is circular (0..179). If the window runs past either end, add
    # the part that wraps around; otherwise half of a reddish backdrop is missed.
    if hue - hue_pad < 0:
        ranges.append((np.array([180 + hue - hue_pad, sat_lo, val_lo]), np.array([179, sat_hi, val_hi])))
    if hue + hue_pad > 179:
        ranges.append((np.array([0, sat_lo, val_lo]), np.array([hue + hue_pad - 180, sat_hi, val_hi])))
    return ranges


def mask_from_ranges(hsv: np.ndarray, ranges: list[tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Return a 0/255 ``uint8`` mask of pixels inside ANY of the given HSV ranges."""
    combined = np.zeros(hsv.shape[:2], dtype=np.uint8)
    for lower, upper in ranges:
        combined |= cv2.inRange(hsv, lower, upper)
    return combined


def connected_to_edge(mask: np.ndarray) -> np.ndarray:
    """Keep only the connected components of ``mask`` that reach the image border.

    Args:
        mask: 2-D array; any nonzero pixel counts as set.

    Returns:
        Array shaped and typed like ``mask``: 255 on pixels of components that
        touch the first/last row or column, 0 elsewhere.
    """
    num_labels, labels = cv2.connectedComponents((mask > 0).astype(np.uint8))
    if num_labels <= 1:
        return np.zeros_like(mask)
    # Build a per-label lookup table of components seen on the outermost
    # rows/columns, then gather it once over the image. This replaces one
    # full-image comparison per edge label.
    touches_edge = np.zeros(num_labels, dtype=bool)
    touches_edge[labels[0, :]] = True
    touches_edge[labels[-1, :]] = True
    touches_edge[labels[:, 0]] = True
    touches_edge[labels[:, -1]] = True
    touches_edge[0] = False  # label 0 is the unset background
    result = np.zeros_like(mask)
    result[touches_edge[labels]] = 255
    return result


def detect_edge_skin(hsv: np.ndarray, foreground: np.ndarray) -> np.ndarray:
    """Find skin-toned regions connected to the image border.

    These are presumably hands or fingers holding the book; that is a reading of
    the intent, not something the code checks. Pixels with hue in 0..25 or
    160..179 and moderate saturation are opened with a 5x5 kernel. Only the
    blobs touching the border are kept.

    The detection is discarded (an all-zero mask is returned) when it looks too
    large to be a hand: more than 12% of the image, or more than a quarter of
    the ``foreground`` area.

    Returns:
        0/255 ``uint8`` mask of the accepted skin regions.
    """
    low_hues = cv2.inRange(hsv, np.array([0, 18, 45]), np.array([25, 210, 255]))
    high_hues = cv2.inRange(hsv, np.array([160, 18, 45]), np.array([179, 210, 255]))
    opened = cv2.morphologyEx(cv2.bitwise_or(low_hues, high_hues), cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
    skin = connected_to_edge(opened)
    if not skin.any():
        return skin

    pixel_count = max(1, skin.shape[0] * skin.shape[1])
    skin_ratio = float((skin > 0).sum()) / pixel_count
    fg_ratio = float((foreground > 0).sum()) / pixel_count
    implausible = skin_ratio > 0.12 or (fg_ratio > 0 and skin_ratio > fg_ratio * 0.25)
    return np.zeros_like(skin) if implausible else skin


def contour_score(contour: np.ndarray, shape: tuple[int, int]) -> tuple[float, dict]:
    """Score a contour as a subject candidate.

    The score is ``area * max(0.25, solidity) * centrality * edge_penalty``:
    - centrality favours bounding boxes near the image centre;
    - edge_penalty is 0.62 for boxes touching the border, and ×0.08 more when they
      also span at least 72% of the width or height.

    Args:
        contour: An OpenCV contour.
        shape: ``(height, width)`` of the mask the contour came from.

    Returns:
        ``(score, meta)``. ``meta`` holds the bounding box (x, y, w, h), area,
        solidity and score. A degenerate contour returns ``(0.0, {})``.
    """
    img_h, img_w = shape
    area = cv2.contourArea(contour)
    if area <= 0:
        return 0.0, {}

    x, y, w, h = cv2.boundingRect(contour)
    solidity = float(area) / max(1.0, cv2.contourArea(cv2.convexHull(contour)))

    off_x = abs((x + w / 2) - img_w / 2) / max(1.0, img_w / 2)
    off_y = abs((y + h / 2) - img_h / 2) / max(1.0, img_h / 2)
    centrality = max(0.2, 1.55 - (off_x * 0.85 + off_y * 1.1))

    touches_edge = x <= 1 or y <= 1 or x + w >= img_w - 1 or y + h >= img_h - 1
    if not touches_edge:
        edge_penalty = 1.0
    elif w >= img_w * 0.72 or h >= img_h * 0.72:
        edge_penalty = 0.62 * 0.08
    else:
        edge_penalty = 0.62

    score = area * max(0.25, solidity) * centrality * edge_penalty
    meta = dict(x=x, y=y, w=w, h=h, area=area, solidity=solidity, score=score)
    return score, meta


def merge_boxes(primary: dict, boxes: list[dict]) -> dict:
    """Grow ``primary``'s bounding box to absorb nearby, reasonably large boxes.

    A box is absorbed when its area is at least 8% of ``primary``'s area and it
    lies within a margin of the box merged so far. The margin is 18% of
    ``primary``'s width/height, at least 10 px. Boxes are visited in order, so an
    absorbed box can bring later ones into reach.

    Returns:
        ``{"x", "y", "w", "h"}`` of the merged box.
    """
    left, top = primary["x"], primary["y"]
    right, bottom = left + primary["w"], top + primary["h"]
    reach_x = max(10, int(primary["w"] * 0.18))
    reach_y = max(10, int(primary["h"] * 0.18))
    min_area = primary["area"] * 0.08

    for other in boxes:
        if other is primary or other["area"] < min_area:
            continue
        ox1, oy1 = other["x"], other["y"]
        ox2, oy2 = ox1 + other["w"], oy1 + other["h"]
        near_x = ox1 <= right + reach_x and ox2 >= left - reach_x
        near_y = oy1 <= bottom + reach_y and oy2 >= top - reach_y
        if near_x and near_y:
            left, top = min(left, ox1), min(top, oy1)
            right, bottom = max(right, ox2), max(bottom, oy2)

    return {"x": left, "y": top, "w": right - left, "h": bottom - top}


def confidence_for_box(box: tuple[int, int, int, int], fg_mask: np.ndarray) -> tuple[float, float, float, float]:
    """Rate how plausible ``box`` is as the subject crop within ``fg_mask``.

    The confidence is a weighted sum:
    - 35% foreground density inside the box;
    - 30% size score: 1.0 for 6%..93% of the image, falling off linearly outside;
    - 35% border cleanliness: 1.0 when a ring of 6% (at least 4 px) around the box
      holds no foreground, reaching 0.0 at 28% foreground.

    Args:
        box: ``(x1, y1, x2, y2)`` in ``fg_mask`` pixels; clipped to the mask.
        fg_mask: 2-D mask where nonzero means foreground.

    Returns:
        ``(confidence, crop_ratio, density, border_cleanliness)``.
    """
    rows, cols = fg_mask.shape
    bx1, by1, bx2, by2 = box
    bx1 = max(0, min(cols - 1, bx1))
    by1 = max(0, min(rows - 1, by1))
    bx2 = max(bx1 + 1, min(cols, bx2))
    by2 = max(by1 + 1, min(rows, by2))

    inside = fg_mask[by1:by2, bx1:bx2]
    crop_ratio = float(inside.size) / max(1, cols * rows)
    density = float((inside > 0).mean()) if inside.size else 0.0

    if crop_ratio < 0.06:
        size_score = max(0.0, crop_ratio / 0.06)
    elif crop_ratio <= 0.93:
        size_score = 1.0
    else:
        size_score = max(0.0, 1.0 - ((crop_ratio - 0.93) / 0.07))

    margin_x = max(4, int((bx2 - bx1) * 0.06))
    margin_y = max(4, int((by2 - by1) * 0.06))
    ring = np.zeros(fg_mask.shape, dtype=bool)
    ring[max(0, by1 - margin_y):min(rows, by2 + margin_y), max(0, bx1 - margin_x):min(cols, bx2 + margin_x)] = True
    ring[by1:by2, bx1:bx2] = False
    ring_pixels = fg_mask[ring]
    ring_fg = float((ring_pixels > 0).mean()) if ring_pixels.size else 0.0

    border_cleanliness = max(0.0, 1.0 - min(1.0, ring_fg / 0.28))
    confidence = 0.35 * density + 0.30 * size_score + 0.35 * border_cleanliness
    return confidence, crop_ratio, density, border_cleanliness


def scale_box(box: tuple[int, int, int, int], scale: float, original_shape: tuple[int, int]) -> tuple[int, int, int, int]:
    """Map a box from analysis coordinates back to full resolution.

    Each coordinate is divided by ``scale`` and rounded. Then x1/y1 are clipped
    to ``[0, size - 1]`` and x2/y2 to ``[1, size]``. ``original_shape`` is
    ``(height, width)``.
    """
    full_h, full_w = original_shape

    def clamp(value: int, lowest: int, highest: int) -> int:
        return max(lowest, min(highest, value))

    left, top, right, bottom = (int(round(coord / scale)) for coord in box)
    return (
        clamp(left, 0, full_w - 1),
        clamp(top, 0, full_h - 1),
        clamp(right, 1, full_w),
        clamp(bottom, 1, full_h),
    )


def pad_box(box: tuple[int, int, int, int], shape: tuple[int, int], ratio: float = PADDING_RATIO) -> tuple[int, int, int, int]:
    """Grow ``box`` by ``ratio`` of its width/height on each side (at least 4 px).

    The result is clipped to the image bounds given by ``shape`` = ``(height, width)``.
    """
    limit_h, limit_w = shape
    left, top, right, bottom = box
    grow_x = max(4, int((right - left) * ratio))
    grow_y = max(4, int((bottom - top) * ratio))
    padded_left = max(0, left - grow_x)
    padded_top = max(0, top - grow_y)
    padded_right = min(limit_w, right + grow_x)
    padded_bottom = min(limit_h, bottom + grow_y)
    return padded_left, padded_top, padded_right, padded_bottom


def contours_to_result(mask: np.ndarray, method: str, original_shape: tuple[int, int], scale: float) -> Optional[CropResult]:
    """Turn a foreground mask (analysis resolution) into a padded full-resolution crop.

    Steps:
    1. Score the outer contours; ignore those under 1% of the mask (at least 30 px).
    2. Take the best-scoring one and merge nearby boxes into it.
    3. Rate the merged box on the mask.
    4. Scale it to ``original_shape`` and add ``PADDING_RATIO`` padding.

    Returns:
        A ``CropResult`` tagged with ``method``, or ``None`` when no contour qualifies.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    min_area = max(30.0, mask.shape[0] * mask.shape[1] * 0.01)
    candidates = []
    for contour in contours:
        _, meta = contour_score(contour, mask.shape)
        if meta and meta["area"] >= min_area:
            candidates.append(meta)
    if not candidates:
        return None

    primary = max(candidates, key=lambda meta: meta["score"])
    merged = merge_boxes(primary, candidates)
    analysis_box = (merged["x"], merged["y"], merged["x"] + merged["w"], merged["y"] + merged["h"])
    confidence, crop_ratio, density, border_cleanliness = confidence_for_box(analysis_box, mask)

    # Same piecewise size score as confidence_for_box, clamped to [0, 1] for the report.
    if 0.06 <= crop_ratio <= 0.93:
        raw_size = 1.0
    elif crop_ratio < 0.06:
        raw_size = crop_ratio / 0.06
    else:
        raw_size = 1.0 - ((crop_ratio - 0.93) / 0.07)

    full_box = scale_box(analysis_box, scale, original_shape)
    final_box = pad_box(full_box, original_shape, PADDING_RATIO)
    return CropResult(
        box=final_box,
        method=method,
        confidence=round(confidence, 3),
        crop_ratio=round(crop_ratio, 3),
        density=round(density, 3),
        border_cleanliness=round(border_cleanliness, 3),
        size_score=round(max(0.0, min(1.0, raw_size)), 3),
    )


def method_hsv(bgr: np.ndarray) -> Optional[CropResult]:
    """Detect the subject by subtracting a border-coloured background in HSV space.

    Background = pixels matching ``dominant_border_ranges`` that connect to the
    image edge. Returns ``None`` when no background colour is found or when the
    background covers 88% or more of the image. Border-touching skin regions are
    removed from the foreground before contour analysis.
    """
    small, scale = resize_for_analysis(bgr)
    hsv = cv2.cvtColor(small, cv2.COLOR_BGR2HSV)
    ranges = dominant_border_ranges(hsv)
    if not ranges:
        return None

    raw_background = mask_from_ranges(hsv, ranges)
    closed = cv2.morphologyEx(raw_background, cv2.MORPH_CLOSE, np.ones((7, 7), np.uint8))
    background = connected_to_edge(closed)
    if float((background > 0).mean()) >= 0.88:
        return None

    foreground = cv2.bitwise_not(background)
    for operation, size in ((cv2.MORPH_OPEN, 5), (cv2.MORPH_CLOSE, 9)):
        foreground = cv2.morphologyEx(foreground, operation, np.ones((size, size), np.uint8))

    skin = detect_edge_skin(hsv, foreground)
    if skin.any():
        foreground[skin > 0] = 0
        foreground = cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, np.ones((7, 7), np.uint8))
    return contours_to_result(foreground, "hsv", bgr.shape[:2], scale)


def method_edges(bgr: np.ndarray) -> Optional[CropResult]:
    """Detect the subject from Canny edges, thickened into closed regions.

    Canny thresholds adapt to the median brightness of the blurred grey image:
    low = max(25, 0.66·median) and high = clamp(1.33·median, 90, 255).
    """
    small, scale = resize_for_analysis(bgr)
    smoothed = cv2.GaussianBlur(cv2.cvtColor(small, cv2.COLOR_BGR2GRAY), (5, 5), 0)
    mid = float(np.median(smoothed))
    lower = int(max(25, 0.66 * mid))
    upper = int(min(255, max(90, 1.33 * mid)))
    outline = cv2.Canny(smoothed, lower, upper)
    outline = cv2.dilate(outline, np.ones((5, 5), np.uint8), iterations=1)
    outline = cv2.morphologyEx(outline, cv2.MORPH_CLOSE, np.ones((11, 11), np.uint8))
    return contours_to_result(outline, "edges", bgr.shape[:2], scale)


def method_grabcut(bgr: np.ndarray, hint: Optional[tuple[int, int, int, int]]) -> Optional[CropResult]:
    """Segment the subject with OpenCV GrabCut, seeded by a rectangle.

    Args:
        bgr: Full-resolution BGR image.
        hint: Optional ``(x1, y1, x2, y2)`` box in full-resolution pixels, e.g. an
            earlier detector's result. It is scaled to analysis size and grown by
            12% (at least 8 px) per side to form the GrabCut rectangle. Without a
            hint, the rectangle excludes a 9% margin on every side.

    Returns:
        A ``"grabcut"`` ``CropResult``, or ``None`` if GrabCut raises
        ``cv2.error`` or no usable contour remains.
    """
    small, scale = resize_for_analysis(bgr)
    rows, cols = small.shape[:2]
    if hint:
        hx1, hy1, hx2, hy2 = (int(coord * scale) for coord in hint)
        grow_x = max(8, int((hx2 - hx1) * 0.12))
        grow_y = max(8, int((hy2 - hy1) * 0.12))
        left = max(1, hx1 - grow_x)
        top = max(1, hy1 - grow_y)
        right = min(cols - 2, hx2 + grow_x)
        bottom = min(rows - 2, hy2 + grow_y)
        rect = (left, top, max(10, right - left), max(10, bottom - top))
    else:
        inset_x = int(cols * 0.09)
        inset_y = int(rows * 0.09)
        rect = (inset_x, inset_y, cols - 2 * inset_x, rows - 2 * inset_y)

    gc_mask = np.zeros(small.shape[:2], np.uint8)
    background_model = np.zeros((1, 65), np.float64)
    foreground_model = np.zeros((1, 65), np.float64)
    try:
        cv2.grabCut(small, gc_mask, rect, background_model, foreground_model, 3, cv2.GC_INIT_WITH_RECT)
    except cv2.error:
        return None

    keep = (gc_mask == cv2.GC_FGD) | (gc_mask == cv2.GC_PR_FGD)
    foreground = np.where(keep, 255, 0).astype(np.uint8)
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
    foreground = cv2.morphologyEx(foreground, cv2.MORPH_CLOSE, np.ones((9, 9), np.uint8))
    return contours_to_result(foreground, "grabcut", bgr.shape[:2], scale)


def method_center(bgr: np.ndarray) -> CropResult:
    """Last-resort crop: trim a fixed 9% margin from every side of the image.

    ``confidence`` is fixed at 0.30, and density/border cleanliness are reported
    as 0.0 because nothing was measured. ``crop_ratio`` is the area fraction of
    the box, like every detector reports. ``size_score`` uses the same piecewise
    rule as ``confidence_for_box``, so the report matches the box actually applied.
    (Hard-coding 0.82 reported the side ratio, not the area ratio of about 0.672.)
    """
    h, w = bgr.shape[:2]
    margin_x = int(w * 0.09)
    margin_y = int(h * 0.09)
    box = (margin_x, margin_y, w - margin_x, h - margin_y)

    crop_ratio = ((w - 2 * margin_x) * (h - 2 * margin_y)) / max(1, w * h)
    if 0.06 <= crop_ratio <= 0.93:
        size_score = 1.0
    elif crop_ratio < 0.06:
        size_score = crop_ratio / 0.06
    else:
        size_score = 1.0 - ((crop_ratio - 0.93) / 0.07)

    return CropResult(
        box=box,
        method="center",
        confidence=0.30,
        crop_ratio=round(crop_ratio, 3),
        density=0.0,
        border_cleanliness=0.0,
        size_score=round(max(0.0, min(1.0, size_score)), 3),
    )


def method_uncropped(bgr: np.ndarray) -> CropResult:
    """Return a full-frame result, used for ids listed in ``UNCROPPED_IMAGE_IDS``."""
    height, width = bgr.shape[:2]
    return CropResult(
        box=(0, 0, width, height),
        method="override_full",
        confidence=1.0,
        crop_ratio=1.0,
        density=1.0,
        border_cleanliness=1.0,
        size_score=0.0,
    )


def smart_crop_cv(bgr: np.ndarray, threshold: float = CONFIDENCE_THRESHOLD) -> CropResult:
    """Pick a crop for ``bgr`` by trying detectors in order of cost.

    The HSV, edge and GrabCut methods are tried in turn. The first result with
    ``confidence >= threshold`` wins. GrabCut is seeded with the HSV box, or the
    edge box if HSV found nothing. If no method meets the threshold, the most
    confident result is used when it reaches 0.30. Otherwise a fixed centre crop
    is returned.
    """
    def accepted(result: Optional[CropResult]) -> bool:
        return result is not None and result.confidence >= threshold

    hsv_result = method_hsv(bgr)
    if accepted(hsv_result):
        return hsv_result

    edge_result = method_edges(bgr)
    if accepted(edge_result):
        return edge_result

    if hsv_result is not None:
        hint = hsv_result.box
    elif edge_result is not None:
        hint = edge_result.box
    else:
        hint = None
    grabcut_result = method_grabcut(bgr, hint)
    if accepted(grabcut_result):
        return grabcut_result

    fallbacks = [result for result in (hsv_result, edge_result, grabcut_result) if result is not None]
    best = max(fallbacks, key=lambda result: result.confidence, default=None)
    if best is not None and best.confidence >= 0.30:
        return best
    return method_center(bgr)


def smart_crop_pil(img: Image.Image, threshold: float = CONFIDENCE_THRESHOLD) -> CropResult:
    """Run ``smart_crop_cv`` on a PIL image (EXIF orientation applied during conversion)."""
    bgr = pil_to_bgr(img)
    return smart_crop_cv(bgr, threshold=threshold)


def find_crop_box_cv(img: Image.Image | np.ndarray, threshold: float = CONFIDENCE_THRESHOLD) -> tuple[int, int, int, int]:
    """Return only the ``(x1, y1, x2, y2)`` crop box for a PIL image or BGR array.

    PIL images are converted with ``pil_to_bgr``; arrays are used as-is.
    """
    bgr = pil_to_bgr(img) if isinstance(img, Image.Image) else img
    result = smart_crop_cv(bgr, threshold=threshold)
    return result.box


def build_variant(
    src: Path,
    dest: Path,
    max_side: int,
    quality: int,
    threshold: float,
    image_id: Optional[int] = None,
) -> CropResult:
    """Crop ``src``, shrink it to fit ``max_side`` and save it as a progressive JPEG at ``dest``.

    Images whose ``image_id`` is in ``UNCROPPED_IMAGE_IDS`` keep the full frame;
    all others go through ``smart_crop_pil``. The parent directory of ``dest`` is
    created if needed.

    Returns:
        The ``CropResult`` that determined the crop.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    with Image.open(src) as source:
        upright = ImageOps.exif_transpose(source).convert("RGB")
        keep_full_frame = image_id in UNCROPPED_IMAGE_IDS
        if keep_full_frame:
            result = method_uncropped(pil_to_bgr(upright))
        else:
            result = smart_crop_pil(upright, threshold=threshold)
        variant = upright.crop(result.box)
        variant.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
        variant.save(dest, format="JPEG", quality=quality, optimize=True, progressive=True)
    return result


def ensure_columns(conn: sqlite3.Connection):
    """Make sure ``book_images`` has the ``storage_path`` and ``thumb_storage_path`` columns.

    Missing columns are added as ``TEXT`` with an empty-string default. Existing
    columns are left untouched, so repeated calls are harmless. Commits before
    returning.

    Args:
        conn: Open connection to the catalogue database.
    """
    existing = {info[1] for info in conn.execute("PRAGMA table_info(book_images)").fetchall()}
    # Use a single-quoted SQL string literal. SQLite accepts DEFAULT "" only
    # through its legacy double-quoted-string fallback, which builds compiled
    # with SQLITE_DQS=0 reject.
    for column in ("storage_path", "thumb_storage_path"):
        if column not in existing:
            conn.execute(f"ALTER TABLE book_images ADD COLUMN {column} TEXT DEFAULT ''")
    conn.commit()


def _render_batch_image(src: Path, image_id: int, args, view_dest: Path, thumb_dest: Path) -> CropResult:
    """Choose a crop for one catalogue image and, unless ``args.dry_run``, save its view and thumb JPEGs.

    Raises:
        OSError: Pillow reports unreadable or corrupt files this way
            (``PIL.UnidentifiedImageError`` subclasses it), as well as failed writes.
    """
    with Image.open(src) as img:
        oriented = ImageOps.exif_transpose(img).convert("RGB")
        if image_id in UNCROPPED_IMAGE_IDS:
            result = method_uncropped(pil_to_bgr(oriented))
        else:
            result = smart_crop_pil(oriented, threshold=args.threshold)
        if not args.dry_run:
            view = oriented.crop(result.box)
            thumb = view.copy()  # thumbnail() resizes in place, so each size needs its own image
            view.thumbnail((args.view_max, args.view_max), Image.Resampling.LANCZOS)
            thumb.thumbnail((args.thumb_max, args.thumb_max), Image.Resampling.LANCZOS)
            view_dest.parent.mkdir(parents=True, exist_ok=True)
            thumb_dest.parent.mkdir(parents=True, exist_ok=True)
            view.save(view_dest, format="JPEG", quality=args.quality, optimize=True, progressive=True)
            thumb.save(thumb_dest, format="JPEG", quality=args.quality, optimize=True, progressive=True)
    return result


def run_batch(args):
    """Crop every image in ``book_images`` and write view/thumb JPEGs under ``args.out``.

    Output files are named ``view/NNNN.jpg`` and ``thumb/NNNN.jpg``. Unless
    ``--dry-run`` is given, each row's ``storage_path``/``thumb_storage_path`` is
    updated and all updates are committed at the end.

    The following are listed under ``flagged_for_review`` instead of aborting the run:
    - rows whose source file is missing;
    - rows whose image cannot be processed;
    - results below ``args.threshold``.

    The summary plus per-image metrics are written to ``<out>/crop_report.json``,
    and the summary is printed.
    """
    args.out.mkdir(parents=True, exist_ok=True)
    report = []
    flagged = []
    conn = sqlite3.connect(str(args.db))
    try:
        conn.row_factory = sqlite3.Row
        ensure_columns(conn)
        rows = conn.execute("SELECT id, source_folder, source_filename FROM book_images ORDER BY id").fetchall()

        for row in rows:
            image_id = int(row["id"])
            filename = row["source_filename"]
            src = source_path(args.images, row["source_folder"], filename)
            if not src.exists():
                flagged.append({"id": image_id, "filename": filename, "reason": "missing source file"})
                continue

            view_path = f"view/{image_id:04d}.jpg"
            thumb_path = f"thumb/{image_id:04d}.jpg"
            try:
                result = _render_batch_image(src, image_id, args, args.out / view_path, args.out / thumb_path)
            except OSError as exc:
                # One corrupt or unreadable photo must not abort the whole catalogue run.
                flagged.append({"id": image_id, "filename": filename, "reason": f"image processing failed: {exc}"})
                continue

            if not args.dry_run:
                conn.execute(
                    "UPDATE book_images SET storage_path = ?, thumb_storage_path = ? WHERE id = ?",
                    (view_path, thumb_path, image_id),
                )

            report.append({
                "id": image_id,
                "source_folder": row["source_folder"],
                "source_filename": filename,
                **asdict(result),
            })
            if result.confidence < args.threshold:
                flagged.append({"id": image_id, "filename": filename, "method": result.method, "confidence": result.confidence})

        if not args.dry_run:
            conn.commit()
    finally:
        # Release the database handle even if an unexpected error escapes.
        conn.close()

    output = {
        "generated": len(report),
        "threshold": args.threshold,
        "flagged_for_review": flagged,
    }
    report_path = args.out / "crop_report.json"
    report_path.write_text(json.dumps({"summary": output, "items": report}, indent=2), encoding="utf-8")
    print(json.dumps({**output, "report": str(report_path)}, indent=2))


def run_test(args):
    """Crop each image in ``args.images`` into ``args.out_dir`` for visual inspection.

    Every image is smart-cropped, shrunk to ``args.thumb_max`` and saved as
    ``<stem>_cropped.jpg``. A JSON list of per-image crop metrics is printed.
    """
    out_dir = args.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    summaries = []
    for raw_path in args.images:
        src = Path(raw_path)
        destination = out_dir / f"{src.stem}_cropped.jpg"
        with Image.open(src) as img:
            upright = ImageOps.exif_transpose(img).convert("RGB")
            result = smart_crop_pil(upright, threshold=args.threshold)
            preview = upright.crop(result.box)
            preview.thumbnail((args.thumb_max, args.thumb_max), Image.Resampling.LANCZOS)
            preview.save(destination, format="JPEG", quality=args.quality, optimize=True, progressive=True)
        summaries.append({"image": str(src), "output": str(destination), **asdict(result)})
    print(json.dumps(summaries, indent=2))


def main():
    """Build the command-line interface and dispatch to the selected sub-command."""
    cli = argparse.ArgumentParser(description="Smart crop and optimize Bibliothek book photos")
    commands = cli.add_subparsers(dest="command", required=True)

    # `test`: crop explicitly listed images into a debug folder.
    test_cmd = commands.add_parser("test", help="Crop specific test images")
    test_cmd.add_argument("images", nargs="+", type=Path)
    for flag, kind, default in (
        ("--out-dir", Path, DEFAULT_OUTPUT / "debug"),
        ("--quality", int, 84),
        ("--thumb-max", int, 420),
        ("--threshold", float, CONFIDENCE_THRESHOLD),
    ):
        test_cmd.add_argument(flag, type=kind, default=default)
    test_cmd.set_defaults(func=run_test)

    # `batch`: process every row of book_images; the DB is updated unless --dry-run.
    batch_cmd = commands.add_parser("batch", help="Process the full catalog")
    for flag, kind, default in (
        ("--db", Path, DEFAULT_DB),
        ("--images", Path, DEFAULT_IMAGES),
        ("--out", Path, DEFAULT_OUTPUT),
        ("--view-max", int, 1600),
        ("--thumb-max", int, 420),
        ("--quality", int, 84),
        ("--threshold", float, CONFIDENCE_THRESHOLD),
    ):
        batch_cmd.add_argument(flag, type=kind, default=default)
    batch_cmd.add_argument("--dry-run", action="store_true")
    batch_cmd.set_defaults(func=run_batch)

    parsed = cli.parse_args()
    parsed.func(parsed)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
