#!/usr/bin/env python3
"""
Book deduplication pipeline - content-based approach.
Compares OCR-extracted title/author/publisher to find duplicates.
Uses Mistral AI to validate ambiguous groups.
"""

import sqlite3
import json
import re
import time
import os
from collections import defaultdict, Counter
from difflib import SequenceMatcher
from dotenv import load_dotenv
from mistralai.client import Mistral
from tqdm import tqdm

# Pull settings from a local .env file into the process environment before
# anything reads them (main() looks up MISTRAL_API_KEY via os.getenv).
load_dotenv()

# SQLite database file opened by main(): the book_images table is read from
# it and the books table is rebuilt in it.
DB_PATH = "books.db"

# Common English and Spanish function words (articles, prepositions,
# conjunctions). content_words() drops these so that word-level title
# comparison is driven by the distinctive words only.
STOPWORDS = {'the', 'a', 'an', 'of', 'and', 'in', 'to', 'for', 'on', 'at', 'by', 'from', 'with',
             'el', 'la', 'los', 'las', 'de', 'del', 'en', 'y', 'un', 'una', 'por', 'con'}


# ─── Text similarity ─────────────────────────────────────────────────────────

def normalize(s):
    """Return a canonical, comparison-friendly form of *s*.

    Falsy input (None, '') yields ''. Otherwise the text is lower-cased,
    every character that is not an ASCII letter/digit, one of the listed
    accented Latin letters, or whitespace is deleted (not replaced, so
    "Old-Man" becomes "oldman"), and whitespace runs are collapsed to a
    single space with none left at either end.
    """
    if not s:
        return ''
    kept = re.sub(r'[^a-z0-9àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ\s]', '', s.lower())
    # split() without arguments discards leading/trailing whitespace and
    # treats any whitespace run as one separator.
    return ' '.join(kept.split())

def content_words(s):
    """Return the significant words of *s*, in order of appearance.

    The text is normalized first; single-character tokens and anything in
    STOPWORDS are discarded. Duplicates are kept.
    """
    return [
        token
        for token in normalize(s).split()
        if len(token) > 1 and token not in STOPWORDS
    ]

def title_match_score(t1, t2):
    """Score the similarity of two titles on a 0.0-1.0 scale.

    Returns 0.0 when either title normalizes to nothing and 1.0 when the
    normalized titles are identical. Otherwise two measures are computed:
    the character-level SequenceMatcher ratio of the normalized strings and
    the Jaccard overlap of their content-word sets. If either word set is
    empty only the character ratio is returned; else the result is a blend
    weighted 30% on the lower measure and 70% on the higher one.
    """
    left, right = normalize(t1), normalize(t2)
    if not (left and right):
        return 0.0
    if left == right:
        return 1.0

    char_ratio = SequenceMatcher(None, left, right).ratio()

    left_words = set(content_words(t1))
    right_words = set(content_words(t2))
    if not (left_words and right_words):
        return char_ratio

    word_overlap = len(left_words & right_words) / len(left_words | right_words)

    # The weaker signal still counts for 30%, which pulls down pairs such as
    # "The X of Y" / "The Z of Y" that look alike on only one measure.
    weaker, stronger = sorted((char_ratio, word_overlap))
    return weaker * 0.3 + stronger * 0.7

def author_match_score(a1, a2):
    """Return the character-level similarity (0.0-1.0) of two author names.

    Names are normalized before comparison; if either one normalizes to an
    empty string the score is 0.0.
    """
    first = normalize(a1)
    second = normalize(a2)
    if first and second:
        return SequenceMatcher(None, first, second).ratio()
    return 0.0

def is_unknown(s):
    """Return True when *s* is empty or a placeholder for a missing value."""
    # normalize() deletes '/', so an input of "N/A" reaches this check as
    # 'na'; the literal 'n/a' entry itself can never be hit.
    placeholders = {'', 'unknown', 'none', 'na', 'n/a'}
    return normalize(s) in placeholders


# ─── Clustering ──────────────────────────────────────────────────────────────

def match_to_cluster(img, cluster_rep):
    """Score how well an image record fits a cluster's representative.

    Both arguments are dict-like records read through their 'title',
    'author' and 'publisher' keys. The return value is 0.0 for "no match",
    otherwise a positive score:

    * 0.0 if both titles are unknown, or if either record has neither a
      usable title nor a usable author;
    * 0.0 if both publishers are known and their similarity is below 0.25;
    * the title score itself when it is at least 0.70;
    * 0.6 * title + 0.4 * author when title >= 0.55 and author >= 0.50;
    * 0.0 in every other case.
    """
    img_title = img.get('title', '')
    img_author = img.get('author', '')
    img_publisher = img.get('publisher', '')

    rep_title = cluster_rep.get('title', '')
    rep_author = cluster_rep.get('author', '')
    rep_publisher = cluster_rep.get('publisher', '')

    if is_unknown(img_title):
        # A title-less image is hopeless if the representative has no title
        # either, or if the image has no author to fall back on.
        if is_unknown(rep_title) or is_unknown(img_author):
            return 0.0
    if is_unknown(rep_title) and is_unknown(rep_author):
        return 0.0

    title_score = title_match_score(img_title, rep_title)
    author_score = author_match_score(img_author, rep_author)

    # Clearly different publishers mean a different edition, which is
    # treated as a different book regardless of title/author agreement.
    both_publishers_known = not (is_unknown(img_publisher) or is_unknown(rep_publisher))
    if both_publishers_known:
        publisher_ratio = SequenceMatcher(
            None, normalize(img_publisher), normalize(rep_publisher)
        ).ratio()
        if publisher_ratio < 0.25:
            return 0.0

    # A strong title match is sufficient on its own.
    if title_score >= 0.70:
        return title_score

    # A moderate title match needs the author to back it up.
    if title_score < 0.55 or author_score < 0.50:
        return 0.0
    return title_score * 0.6 + author_score * 0.4

def cluster_images(images):
    """Greedily group image records that appear to show the same book.

    Records are visited in order. Each one is scored against the current
    representative of every existing cluster with match_to_cluster(); it
    joins the best-scoring cluster (the earliest one on ties) when that
    score reaches 0.60, otherwise it starts a new cluster. A joining record
    replaces the representative when its title is longer and not a
    placeholder.

    A summary of group sizes is printed. Returns a list of groups, each a
    list of indices into *images*.
    """
    min_score = 0.60
    representatives = []   # representatives[k] is the record standing for cluster k
    memberships = []       # memberships[k] lists the image indices in cluster k

    for index, record in enumerate(tqdm(images, desc="Clustering")):
        top_score = 0
        top_cluster = -1
        for position, representative in enumerate(representatives):
            candidate = match_to_cluster(record, representative)
            if candidate > top_score:
                top_score, top_cluster = candidate, position

        if top_score < min_score:
            representatives.append(record)
            memberships.append([index])
            continue

        memberships[top_cluster].append(index)

        # Prefer the record with the longer (usable) title as representative.
        incoming_title = record.get('title', '') or ''
        current_title = representatives[top_cluster].get('title', '') or ''
        if len(incoming_title) > len(current_title) and not is_unknown(incoming_title):
            representatives[top_cluster] = record

    groups = memberships
    multi_count = sum(1 for members in groups if len(members) > 1)
    print(f"\n  {len(images)} images -> {len(groups)} groups ({multi_count} multi-image)")

    histogram = Counter(map(len, groups))
    for size, how_many in sorted(histogram.items()):
        print(f"    {size} images: {how_many} groups")

    return groups


# ─── AI validation ───────────────────────────────────────────────────────────

_VALIDATION_PROMPT_HEADER = """Review these groups of book catalog entries. Each group should contain images of ONE physical book.

For each group, determine if all entries are the same book or if some should be split off.
- Minor OCR typos in title/author = SAME book
- Completely different titles = DIFFERENT books = split
- Different publisher/year for same title = DIFFERENT edition = split

Return JSON: {"results": [{"group_id": N, "action": "ok"} or {"group_id": N, "action": "split", "subgroups": [[idx1,idx2],[idx3,idx4]]}]}

GROUPS:
"""


def _build_validation_prompt(images, batch):
    """Render the review prompt for one batch of (group_id, member_indices) pairs."""
    sections = []
    for gi, members in batch:
        entries = []
        for idx in members:
            img = images[idx]
            entries.append(
                f"  idx={idx}: title=\"{img.get('title', '')}\", "
                f"author=\"{img.get('author', '')}\", "
                f"publisher=\"{img.get('publisher', '')}\""
            )
        sections.append(f"\nGroup {gi} ({len(members)} images):\n" + "\n".join(entries))
    return _VALIDATION_PROMPT_HEADER + "".join(sections)


def _request_group_verdicts(client, prompt, max_attempts=5):
    """Ask the model to review one prompt; return its 'results' list, or [] on failure."""
    for attempt in range(max_attempts):
        try:
            resp = client.chat.complete(
                model="mistral-small-latest",
                messages=[{"role": "user", "content": prompt}],
                response_format={"type": "json_object"}
            )
            payload = json.loads(resp.choices[0].message.content)
        except Exception as e:  # API boundary: the SDK's error types are not known here
            if '429' in str(e) or 'rate' in str(e).lower():
                if attempt + 1 < max_attempts:
                    wait = 5 * (attempt + 1)
                    print(f"  Rate limited, waiting {wait}s...")
                    time.sleep(wait)
                continue
            print(f"  AI error: {e}")
            return []
        results = payload.get('results', []) if isinstance(payload, dict) else []
        return results if isinstance(results, list) else []

    print(f"  Still rate limited after {max_attempts} attempts, leaving this batch unchanged")
    return []


def _is_exact_partition(group, subgroups):
    """Return True if *subgroups* splits *group* into 2+ non-empty parts, losing, adding or repeating no index."""
    if not isinstance(subgroups, list) or len(subgroups) < 2:
        return False
    flat = []
    for sub in subgroups:
        if not isinstance(sub, list) or not sub:
            return False
        flat.extend(sub)
    if not all(isinstance(i, int) and not isinstance(i, bool) for i in flat):
        return False
    return sorted(flat) == sorted(group)


def validate_with_mistral(client, images, groups):
    """Use Mistral to validate groups >= 4 images.

    Groups are sent for review in batches of 8. When the model answers
    "split" for a group, the group is replaced by the first proposed
    subgroup and the remaining subgroups are appended as new groups.

    The model's answer is untrusted: a split is applied only if it refers
    to a group that was part of the current batch and its subgroups are an
    exact partition of that group's indices. Anything else is reported and
    ignored, so no image can be dropped, duplicated or moved between
    unrelated groups by a faulty reply.

    Args:
        client: Mistral client exposing ``chat.complete(...)``. Not used if
            no group has four or more members.
        images: sequence of image records (dicts with 'title', 'author',
            'publisher'), indexed by the integers stored in *groups*.
        groups: list of groups, each a list of indices into *images*.

    Returns:
        A new list of groups; the list passed in is left unmodified. Empty
        groups are removed when validation ran.
    """
    groups = list(groups)  # work on a copy; never mutate the caller's list
    to_validate = [(gi, g) for gi, g in enumerate(groups) if len(g) >= 4]
    if not to_validate:
        print("\nNo large groups to validate with AI.")
        return groups

    print(f"\nAI validation of {len(to_validate)} large groups...")

    batch_size = 8
    for batch_start in range(0, len(to_validate), batch_size):
        batch = to_validate[batch_start:batch_start + batch_size]
        # group_id -> members exactly as they were shown to the model
        pending = dict(batch)

        verdicts = _request_group_verdicts(client, _build_validation_prompt(images, batch))
        for verdict in verdicts:
            if not isinstance(verdict, dict) or verdict.get('action') != 'split':
                continue
            gi = verdict.get('group_id')
            if isinstance(gi, bool) or not isinstance(gi, int) or gi not in pending:
                print(f"  Ignoring split for group {gi!r}: not part of this batch")
                continue
            subs = verdict.get('subgroups')
            if not _is_exact_partition(pending[gi], subs):
                print(f"  Ignoring split for group {gi}: subgroups do not match its members")
                continue

            groups[gi] = subs[0]
            groups.extend(subs[1:])
            del pending[gi]  # a repeated verdict for the same group is ignored
            print(f"  Split group {gi} into {len(subs)} subgroups")

        # Pace the requests, but do not stall after the final batch.
        if batch_start + batch_size < len(to_validate):
            time.sleep(2)

    return [g for g in groups if g]


# ─── Merge and save ──────────────────────────────────────────────────────────

def _clean_field(value):
    """Return *value* as stripped text, or '' if it is empty or an unknown placeholder."""
    if not value:
        return ''
    # Coerce first: SQLite may hand back non-string values (e.g. an INTEGER year).
    text = str(value).strip()
    return '' if is_unknown(text) else text


def merge_group(images, group_indices):
    """Merge multiple image records into one book record.

    Args:
        images: sequence of image records (dicts) as loaded from the
            book_images table.
        group_indices: indices into *images* that belong to one book.

    Returns:
        A dict with:
        * one entry per bibliographic field ('title', 'subtitle', 'author',
          ...), holding the longest usable value found in the group (the
          first one on ties) or '' when none is usable;
        * 'notes': the distinct usable 'other_text' values joined by '; ',
          in order of first appearance so the output is reproducible;
        * 'raw_ocr_combined': non-empty 'raw_ocr_text' values joined by a
          '---' separator line;
        * 'source_images': JSON list of {'id', 'folder', 'filename'};
        * 'image_count': number of records merged.

    Raises:
        KeyError: if a record lacks 'id', 'source_folder' or
            'source_filename'.
    """
    group = [images[i] for i in group_indices]

    fields = ('title', 'subtitle', 'author', 'translator', 'publisher',
              'year', 'edition', 'language', 'series', 'volume',
              'condition', 'special_features')

    result = {}
    for field in fields:
        candidates = [text for text in (_clean_field(img.get(field)) for img in group) if text]
        # The longest reading is assumed to be the most complete OCR result.
        result[field] = max(candidates, key=len) if candidates else ''

    # Notes: dict.fromkeys de-duplicates while keeping first-seen order; a
    # set would make the joined order vary from run to run.
    notes = dict.fromkeys(
        text for text in (_clean_field(img.get('other_text')) for img in group) if text
    )
    result['notes'] = '; '.join(notes)

    # Raw OCR: kept verbatim apart from stripping, placeholders included.
    raws = []
    for img in group:
        raw = img.get('raw_ocr_text')
        raw = str(raw).strip() if raw else ''
        if raw:
            raws.append(raw)
    result['raw_ocr_combined'] = '\n---\n'.join(raws)

    # Source images
    sources = [{'id': img['id'], 'folder': img['source_folder'], 'filename': img['source_filename']}
               for img in group]
    result['source_images'] = json.dumps(sources)
    result['image_count'] = len(group)

    return result

def save_books(conn, images, groups):
    """Rebuild the books table from the merged groups.

    Any existing books table is dropped and recreated. Each group is merged
    with merge_group(); a merged record with no title, author or publisher
    is skipped. The inserts are committed once at the end.

    Returns the number of rows written.
    """
    insert_columns = (
        'title', 'subtitle', 'author', 'translator',
        'publisher', 'year', 'edition', 'language',
        'series', 'volume', 'condition', 'special_features',
        'notes', 'raw_ocr_combined', 'source_images', 'image_count',
    )

    cursor = conn.cursor()
    cursor.execute("DROP TABLE IF EXISTS books")
    cursor.execute("""
        CREATE TABLE books (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            title TEXT,
            subtitle TEXT,
            author TEXT,
            translator TEXT,
            publisher TEXT,
            year TEXT,
            edition TEXT,
            language TEXT,
            series TEXT,
            volume TEXT,
            condition TEXT,
            special_features TEXT,
            notes TEXT,
            raw_ocr_combined TEXT,
            source_images TEXT,
            image_count INTEGER
        )
    """)

    written = 0
    for members in groups:
        book = merge_group(images, members)
        # Nothing identifying was extracted for this group: not worth a row.
        if not (book['title'] or book['author'] or book['publisher']):
            continue
        row = tuple(book[column] for column in insert_columns)
        cursor.execute("""
            INSERT INTO books (title, subtitle, author, translator, publisher, year,
                             edition, language, series, volume, condition, special_features,
                             notes, raw_ocr_combined, source_images, image_count)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, row)
        written += 1

    conn.commit()
    return written


# ─── Main ────────────────────────────────────────────────────────────────────

def main():
    """Run the pipeline: load image records, cluster, AI-validate, merge and save.

    Reads every row of book_images from DB_PATH, clusters the rows by
    title/author/publisher similarity, asks Mistral to review the large
    groups when MISTRAL_API_KEY is configured, and rewrites the books table
    with one merged row per group. The database connection is closed even
    if a step fails.
    """
    print("=" * 60)
    print("BOOK DEDUPLICATION PIPELINE")
    print("=" * 60)

    query = """
        SELECT id, title, subtitle, author, translator, publisher, year,
               edition, language, series, volume, condition, special_features,
               other_text, source_folder, source_filename, raw_ocr_text
        FROM book_images
    """

    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        cur.execute(query)
        cols = [d[0] for d in cur.description]
        images = [dict(zip(cols, r)) for r in cur.fetchall()]
        print(f"Loaded {len(images)} image records\n")

        # Step 1: Content-based clustering
        print("Step 1: Clustering by title/author/publisher similarity...")
        groups = cluster_images(images)

        # Step 2: AI validation for large groups
        print("\nStep 2: AI validation...")
        api_key = os.getenv("MISTRAL_API_KEY")
        if api_key:
            client = Mistral(api_key=api_key)
            groups = validate_with_mistral(client, images, groups)
        else:
            # Without a key every request would fail anyway; say so once and
            # carry on with the unvalidated clusters.
            print("  MISTRAL_API_KEY is not set - skipping AI validation.")

        # Step 3: Merge and save
        print("\nStep 3: Merging and saving...")
        count = save_books(conn, images, groups)

        multi = [g for g in groups if len(g) > 1]
        print(f"\n{'=' * 60}")
        print(f"DONE: {len(images)} images -> {count} unique books")
        print(f"  {len(multi)} books had multiple photos")
        if groups:
            print(f"  Largest group: {max(len(g) for g in groups)} images")
        print(f"{'=' * 60}")
    finally:
        conn.close()

if __name__ == "__main__":
    main()
