#!/usr/bin/env python3
"""Host-header-rewriting reverse proxy for Hermes Dashboard behind Tailscale Funnel.

Listens on 127.0.0.1:9118 and forwards to 127.0.0.1:9119,
rewriting the Host header so the dashboard accepts funnel requests.
Supports WebSocket upgrade for the TUI chat terminal.
"""
import select
import socket
import sys
import threading
import time
import urllib.parse


# Upstream Hermes Dashboard that forwarded traffic is sent to.
TARGET_HOST: str = "127.0.0.1"
TARGET_PORT: int = 9119
# Local address this proxy accepts connections on.
LISTEN_HOST: str = "127.0.0.1"
LISTEN_PORT: int = 9118


def handle_client(client_sock, *, target_host=None, target_port=None,
                  idle_timeout=60.0, max_header_bytes=65536):
    """Proxy one client connection to the dashboard, rewriting its Host header.

    Reads the client's request head, replaces every ``Host:`` header line with
    ``Host: <target_host>:<target_port>`` (bytes after the blank line that ends
    the headers are forwarded untouched), sends it upstream, then relays raw
    bytes in both directions. Relaying stops when either side closes, when a
    socket error occurs, or when neither side sends anything for
    ``idle_timeout`` seconds. Both sockets are always closed before returning.

    NOTE: only the first request on the connection is rewritten; later
    requests on a keep-alive connection are relayed verbatim. A WebSocket
    upgrade is one HTTP request followed by frames, so it is unaffected.

    Args:
        client_sock: accepted client connection; owned and closed here.
        target_host: upstream host; ``None`` means ``TARGET_HOST``.
        target_port: upstream port; ``None`` means ``TARGET_PORT``.
        idle_timeout: seconds of silence before the connection is dropped.
            It also bounds the wait for the request head. ``None`` waits
            forever.
        max_header_bytes: connections whose headers grow past this size are
            dropped without contacting the upstream.
    """
    host = TARGET_HOST if target_host is None else target_host
    port = TARGET_PORT if target_port is None else target_port
    target_sock = None
    try:
        # Without a timeout, a client that connects and never sends a request
        # would pin this thread forever.
        client_sock.settimeout(idle_timeout)
        head = _read_request_head(client_sock, max_header_bytes)
        if head is None:
            return
        try:
            target_sock = socket.create_connection((host, port), timeout=10)
        except OSError as exc:
            print(f"proxy: cannot reach {host}:{port}: {exc}",
                  file=sys.stderr, flush=True)
            return
        target_sock.sendall(_rewrite_host_header(head, f"{host}:{port}".encode()))
        _relay(client_sock, target_sock, idle_timeout)
    except OSError:
        # Resets, timeouts and broken pipes are routine for a proxy: drop the
        # connection quietly.
        pass
    finally:
        for sock in (client_sock, target_sock):
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass


def _read_request_head(sock, limit):
    """Read from *sock* until the header terminator arrives.

    Returns everything read, which may include body bytes after the terminator.
    Returns None on EOF or when more than *limit* bytes arrive without it.
    """
    data = b""
    while b"\r\n\r\n" not in data:
        if len(data) > limit:
            return None
        chunk = sock.recv(4096)
        if not chunk:
            return None
        data += chunk
    return data


def _rewrite_host_header(data, host_value):
    """Replace Host header lines in the header section of *data* only."""
    head, sep, rest = data.partition(b"\r\n\r\n")
    lines = [
        b"Host: " + host_value if line.lower().startswith(b"host:") else line
        for line in head.split(b"\r\n")
    ]
    return b"\r\n".join(lines) + sep + rest


def _relay(client_sock, target_sock, idle_timeout):
    """Copy bytes both ways until EOF on either side or *idle_timeout* of silence.

    Socket errors propagate to the caller, which is responsible for closing.
    """
    peer = {client_sock: target_sock, target_sock: client_sock}
    while True:
        readable, _, _ = select.select(list(peer), [], [], idle_timeout)
        if not readable:
            return
        for src in readable:
            data = src.recv(65536)
            if not data:
                return
            peer[src].sendall(data)


def main():
    """Accept connections on LISTEN_HOST:LISTEN_PORT and proxy each in a daemon thread.

    Runs until interrupted. The listening socket is closed on the way out.
    A failing ``accept()`` is logged and retried rather than stopping the
    proxy.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow rebinding the port right after a restart.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((LISTEN_HOST, LISTEN_PORT))
        server.listen(50)
        print(f"Dashboard WebSocket-capable proxy on :{LISTEN_PORT} -> :{TARGET_PORT}", flush=True)

        while True:
            try:
                client_sock, _ = server.accept()
            except OSError as exc:
                # e.g. ECONNABORTED or EMFILE: keep serving. Back off briefly so
                # a persistent error cannot spin the CPU.
                print(f"accept failed: {exc}", file=sys.stderr, flush=True)
                time.sleep(0.1)
                continue
            threading.Thread(target=handle_client, args=(client_sock,), daemon=True).start()


if __name__ == "__main__":
    # Start the proxy only when executed directly, not when imported.
    main()
