"""
DM Pairing System

Code-based approval flow for authorizing new users on messaging platforms.
Instead of static allowlists with user IDs, unknown users receive a one-time
pairing code that the bot owner approves via the CLI.

Security features (based on OWASP + NIST SP 800-63-4 guidance):
  - 8-char codes from 32-char unambiguous alphabet (no 0/O/1/I)
  - Cryptographic randomness via secrets.choice()
  - 1-hour code expiry
  - Max 3 pending codes per platform
  - Rate limiting: 1 request per user per 10 minutes
  - Lockout after 5 failed approval attempts (1 hour)
  - File permissions: chmod 0600 on all data files
  - Codes are never logged to stdout

Storage: ~/.hermes/pairing/
"""

import hashlib
import json
import os
import secrets
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional

from gateway.whatsapp_identity import (
    expand_whatsapp_aliases,
    normalize_whatsapp_identifier,
)
from hermes_constants import get_hermes_dir
from utils import atomic_replace


# Unambiguous alphabet -- excludes 0/O, 1/I to prevent confusion
ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
CODE_LENGTH = 8

# Timing constants
CODE_TTL_SECONDS = 3600             # Codes expire after 1 hour
RATE_LIMIT_SECONDS = 600            # 1 request per user per 10 minutes
LOCKOUT_SECONDS = 3600              # Lockout duration after too many failures

# Limits
MAX_PENDING_PER_PLATFORM = 3        # Max pending codes per platform
MAX_FAILED_ATTEMPTS = 5             # Failed approvals before lockout

PAIRING_DIR = get_hermes_dir("platforms/pairing", "pairing")


def _secure_write(path: Path, data: str) -> None:
    """Write data to file with restrictive permissions (owner read/write only).

    Uses a temp-file + atomic rename so readers always see either the old
    complete file or the new one — never a partial write.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        atomic_replace(tmp_path, path)
        try:
            os.chmod(path, 0o600)
        except OSError:
            pass  # Windows doesn't support chmod the same way
    except BaseException:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


class PairingStore:
    """
    Manages pairing codes and approved user lists.

    Data files per platform:
      - {platform}-pending.json   : pending pairing requests
      - {platform}-approved.json  : approved (paired) users
      - _rate_limits.json         : rate limit tracking
    """

    def __init__(self):
        PAIRING_DIR.mkdir(parents=True, exist_ok=True)
        # Protects all read-modify-write cycles. The gateway runs multiple
        # platform adapters concurrently in threads sharing one PairingStore.
        self._lock = threading.RLock()

    def _pending_path(self, platform: str) -> Path:
        return PAIRING_DIR / f"{platform}-pending.json"

    def _approved_path(self, platform: str) -> Path:
        return PAIRING_DIR / f"{platform}-approved.json"

    def _rate_limit_path(self) -> Path:
        return PAIRING_DIR / "_rate_limits.json"

    def _load_json(self, path: Path) -> dict:
        if path.exists():
            try:
                return json.loads(path.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, OSError):
                return {}
        return {}

    def _save_json(self, path: Path, data: dict) -> None:
        _secure_write(path, json.dumps(data, indent=2, ensure_ascii=False))

    def _normalize_user_id(self, platform: str, user_id: str) -> str:
        """Normalize platform-specific user IDs before persisting them."""
        raw_user_id = str(user_id or "").strip()
        if platform == "whatsapp":
            return normalize_whatsapp_identifier(raw_user_id) or raw_user_id
        return raw_user_id

    def _user_id_aliases(self, platform: str, user_id: str) -> set[str]:
        """Return all known equivalent user IDs for auth/rate-limit checks."""
        raw_user_id = str(user_id or "").strip()
        if not raw_user_id:
            return set()

        aliases = {raw_user_id, self._normalize_user_id(platform, raw_user_id)}
        if platform == "whatsapp":
            aliases.update(expand_whatsapp_aliases(raw_user_id))
        aliases.discard("")
        return aliases

    def _user_ids_match(self, platform: str, left: str, right: str) -> bool:
        """Return True when two user IDs represent the same principal."""
        left_aliases = self._user_id_aliases(platform, left)
        right_aliases = self._user_id_aliases(platform, right)
        return bool(left_aliases and right_aliases and (left_aliases & right_aliases))

    # ----- Approved users -----

    def is_approved(self, platform: str, user_id: str) -> bool:
        """Check if a user is approved (paired) on a platform."""
        approved = self._load_json(self._approved_path(platform))
        for approved_user_id in approved:
            if self._user_ids_match(platform, approved_user_id, user_id):
                return True
        return False

    def list_approved(self, platform: str = None) -> list:
        """List approved users, optionally filtered by platform."""
        results = []
        platforms = [platform] if platform else self._all_platforms("approved")
        for p in platforms:
            approved = self._load_json(self._approved_path(p))
            for uid, info in approved.items():
                results.append({"platform": p, "user_id": uid, **info})
        return results

    def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None:
        """Add a user to the approved list. Must be called under self._lock."""
        approved = self._load_json(self._approved_path(platform))
        normalized_user_id = self._normalize_user_id(platform, user_id)
        duplicate_ids = [
            approved_user_id
            for approved_user_id in approved
            if self._user_ids_match(platform, approved_user_id, normalized_user_id)
        ]
        for approved_user_id in duplicate_ids:
            del approved[approved_user_id]

        approved[normalized_user_id] = {
            "user_name": user_name,
            "approved_at": time.time(),
        }
        self._save_json(self._approved_path(platform), approved)

    def revoke(self, platform: str, user_id: str) -> bool:
        """Remove a user from the approved list. Returns True if found."""
        path = self._approved_path(platform)
        with self._lock:
            approved = self._load_json(path)
            matching_ids = [
                approved_user_id
                for approved_user_id in approved
                if self._user_ids_match(platform, approved_user_id, user_id)
            ]
            if matching_ids:
                for approved_user_id in matching_ids:
                    del approved[approved_user_id]
                self._save_json(path, approved)
                return True
        return False

    # ----- Pending codes -----

    @staticmethod
    def _hash_code(code: str, salt: bytes) -> str:
        """Hash a pairing code with the given salt using SHA-256."""
        return hashlib.sha256(salt + code.encode("utf-8")).hexdigest()

    def generate_code(
        self, platform: str, user_id: str, user_name: str = ""
    ) -> Optional[str]:
        """
        Generate a pairing code for a new user.

        Returns the code string, or None if:
          - User is rate-limited (too recent request)
          - Max pending codes reached for this platform
          - User/platform is in lockout due to failed attempts

        The code is NOT stored in plaintext.  Only a salted SHA-256 hash is
        persisted so that reading the pending file does not reveal codes.
        """
        with self._lock:
            self._cleanup_expired(platform)
            normalized_user_id = self._normalize_user_id(platform, user_id)

            # Check lockout
            if self._is_locked_out(platform):
                return None

            # Check rate limit for this specific user
            if self._is_rate_limited(platform, user_id):
                return None

            # Check max pending
            pending = self._load_json(self._pending_path(platform))
            if len(pending) >= MAX_PENDING_PER_PLATFORM:
                return None

            # Generate cryptographically random code
            code = "".join(secrets.choice(ALPHABET) for _ in range(CODE_LENGTH))

            # Hash the code with a random salt before storing
            salt = os.urandom(16)
            code_hash = self._hash_code(code, salt)

            # Use a unique entry id as the key (not the code itself)
            entry_id = secrets.token_hex(8)

            # Store pending request with hashed code
            pending[entry_id] = {
                "hash": code_hash,
                "salt": salt.hex(),
                "user_id": normalized_user_id,
                "user_name": user_name,
                "created_at": time.time(),
            }
            self._save_json(self._pending_path(platform), pending)

            # Record rate limit
            self._record_rate_limit(platform, user_id)

            return code

    def approve_code(self, platform: str, code: str) -> Optional[dict]:
        """
        Approve a pairing code. Adds the user to the approved list.

        Returns ``{user_id, user_name}`` on success, ``None`` if the code is
        invalid/expired OR the platform is currently locked out after
        ``MAX_FAILED_ATTEMPTS`` failed approvals (#10195). Callers can
        disambiguate with ``_is_locked_out(platform)``.

        Verification: the user-provided code is hashed with each stored
        entry's salt and compared to the stored hash using constant-time
        comparison. Pre-hash entries (legacy plaintext-key format from
        pre-upgrade pending.json files) are silently ignored — they get
        pruned at TTL by ``_cleanup_expired``.
        """
        with self._lock:
            self._cleanup_expired(platform)
            code = code.upper().strip()

            # Lockout check — must run before the pending lookup so a
            # valid code (e.g. one already sitting in pending) cannot be
            # accepted once the lockout fires. Without this, the lockout
            # only blocks `generate_code`, not `approve_code` — nullifying
            # the brute-force protection for any code already issued.
            if self._is_locked_out(platform):
                return None

            pending = self._load_json(self._pending_path(platform))

            # Find the entry whose hash matches the provided code.
            # Tolerate legacy plaintext-key entries (no salt/hash) and
            # malformed entries — skip them rather than KeyError, so an
            # in-place upgrade across an existing pending.json doesn't
            # crash on the first approve call. Legacy entries get pruned
            # at their TTL by _cleanup_expired.
            matched_key = None
            matched_entry = None
            for entry_id, entry in pending.items():
                if not isinstance(entry, dict):
                    continue
                if "salt" not in entry or "hash" not in entry:
                    continue
                try:
                    salt = bytes.fromhex(entry["salt"])
                except ValueError:
                    continue
                candidate_hash = self._hash_code(code, salt)
                if secrets.compare_digest(candidate_hash, entry["hash"]):
                    matched_key = entry_id
                    matched_entry = entry
                    break

            if matched_key is None:
                self._record_failed_attempt(platform)
                return None

            del pending[matched_key]
            self._save_json(self._pending_path(platform), pending)

            # Add to approved list
            self._approve_user(platform, matched_entry["user_id"],
                               matched_entry.get("user_name", ""))

            return {
                "user_id": matched_entry["user_id"],
                "user_name": matched_entry.get("user_name", ""),
            }

    def list_pending(self, platform: str = None) -> list:
        """List pending pairing requests, optionally filtered by platform.

        Codes are stored hashed — the ``code`` field is replaced with the
        first 8 hex characters of the hash so admins can distinguish entries
        without revealing the original code. Legacy plaintext-key entries
        (pre-hash format) are shown with a "legacy" placeholder so admins
        can see them age out without crashing on a missing ``hash`` field.
        """
        results = []
        with self._lock:
            platforms = [platform] if platform else self._all_platforms("pending")
            for p in platforms:
                self._cleanup_expired(p)
                pending = self._load_json(self._pending_path(p))
                for entry_id, info in pending.items():
                    if not isinstance(info, dict):
                        continue
                    created_at = info.get("created_at")
                    if not isinstance(created_at, (int, float)):
                        continue
                    age_min = int((time.time() - created_at) / 60)
                    hash_val = info.get("hash")
                    code_display = hash_val[:8] if isinstance(hash_val, str) else "legacy"
                    results.append({
                        "platform": p,
                        "code": code_display,
                        "user_id": info.get("user_id", ""),
                        "user_name": info.get("user_name", ""),
                        "age_minutes": age_min,
                    })
        return results

    def clear_pending(self, platform: str = None) -> int:
        """Clear all pending requests. Returns count removed."""
        with self._lock:
            count = 0
            platforms = [platform] if platform else self._all_platforms("pending")
            for p in platforms:
                pending = self._load_json(self._pending_path(p))
                count += len(pending)
                self._save_json(self._pending_path(p), {})
        return count

    # ----- Rate limiting and lockout -----

    def _is_rate_limited(self, platform: str, user_id: str) -> bool:
        """Check if a user has requested a code too recently."""
        limits = self._load_json(self._rate_limit_path())
        for alias in self._user_id_aliases(platform, user_id):
            key = f"{platform}:{alias}"
            last_request = limits.get(key, 0)
            if (time.time() - last_request) < RATE_LIMIT_SECONDS:
                return True
        return False

    def _record_rate_limit(self, platform: str, user_id: str) -> None:
        """Record the time of a pairing request for rate limiting."""
        limits = self._load_json(self._rate_limit_path())
        now = time.time()
        for alias in self._user_id_aliases(platform, user_id):
            key = f"{platform}:{alias}"
            limits[key] = now
        self._save_json(self._rate_limit_path(), limits)

    def _is_locked_out(self, platform: str) -> bool:
        """Check if a platform is in lockout due to failed approval attempts."""
        limits = self._load_json(self._rate_limit_path())
        lockout_key = f"_lockout:{platform}"
        lockout_until = limits.get(lockout_key, 0)
        return time.time() < lockout_until

    def _record_failed_attempt(self, platform: str) -> None:
        """Record a failed approval attempt. Triggers lockout after MAX_FAILED_ATTEMPTS."""
        limits = self._load_json(self._rate_limit_path())
        fail_key = f"_failures:{platform}"
        fails = limits.get(fail_key, 0) + 1
        limits[fail_key] = fails
        if fails >= MAX_FAILED_ATTEMPTS:
            lockout_key = f"_lockout:{platform}"
            limits[lockout_key] = time.time() + LOCKOUT_SECONDS
            limits[fail_key] = 0  # Reset counter
            print(f"[pairing] Platform {platform} locked out for {LOCKOUT_SECONDS}s "
                  f"after {MAX_FAILED_ATTEMPTS} failed attempts", flush=True)
        self._save_json(self._rate_limit_path(), limits)

    # ----- Cleanup -----

    def _cleanup_expired(self, platform: str) -> None:
        """Remove expired pending codes.

        Tolerant of malformed / legacy entries — anything without a numeric
        ``created_at`` is treated as expired (it's effectively unusable
        with the new hash-keyed schema anyway).
        """
        path = self._pending_path(platform)
        pending = self._load_json(path)
        now = time.time()
        expired = []
        for entry_id, info in pending.items():
            if not isinstance(info, dict):
                expired.append(entry_id)
                continue
            created_at = info.get("created_at")
            if not isinstance(created_at, (int, float)):
                expired.append(entry_id)
                continue
            if (now - created_at) > CODE_TTL_SECONDS:
                expired.append(entry_id)
        if expired:
            for entry_id in expired:
                del pending[entry_id]
            self._save_json(path, pending)

    def _all_platforms(self, suffix: str) -> list:
        """List all platforms that have data files of a given suffix."""
        platforms = []
        for f in PAIRING_DIR.iterdir():
            if f.name.endswith(f"-{suffix}.json"):
                platform = f.name.replace(f"-{suffix}.json", "")
                if not platform.startswith("_"):
                    platforms.append(platform)
        return platforms
