package debug

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"log"
	"net"
	"os"

	"github.com/jackc/pgproto3/v2"
	"github.com/jackc/pgx/v4"
	"google.golang.org/grpc/test/bufconn"
)

// Proxy sits between a PostgreSQL client and the real server. Its DialFunc
// opens the upstream connection with dialContext and hands the caller an
// in-memory connection whose traffic is decoded, logged and forwarded.
type Proxy struct {
	// dialContext opens the connection to the real server.
	dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
	// errChan is an error channel allocated with capacity 1 by NewProxy and
	// SetupPGX.
	errChan chan error
}

// NewProxy returns a Proxy that reaches the upstream server through a
// zero-value net.Dialer.
func NewProxy() Proxy {
	var upstream net.Dialer

	proxy := Proxy{errChan: make(chan error, 1)}
	proxy.dialContext = upstream.DialContext
	return proxy
}

// SetupPGX rewires config so that every connection pgx opens with it goes
// through a logging Proxy. The dial function already present on config is
// kept as the proxy's upstream dialer; when config carries no dial function
// the proxy falls back to the plain net.Dialer that NewProxy installs, instead
// of storing a nil function that would panic on the first dial.
//
// The config's TLSConfig is cleared.
// NOTE(review): presumably so the proxy sees plaintext protocol messages it
// can decode — confirm. Only the primary TLSConfig is touched here.
func SetupPGX(config *pgx.ConnConfig) {
	proxy := NewProxy()
	if config.DialFunc != nil {
		proxy.dialContext = config.DialFunc
	}
	config.DialFunc = proxy.DialFunc
	config.TLSConfig = nil
}

// DialFunc opens a connection to the real server with the proxy's upstream
// dial function and returns the client end of an in-memory bufconn
// connection. A background goroutine decodes the PostgreSQL messages flowing
// in both directions, logs them and forwards them to the other side.
//
// The background goroutine ends as soon as either direction stops; it then
// closes both the in-memory connection and the server connection. An error
// other than io.ErrUnexpectedEOF from the first direction to stop causes a
// panic, because there is no caller left to report it to.
//
// It returns an error if the upstream dial fails or if ctx is done before the
// in-memory connection is established; in the latter case the server
// connection is closed again.
func (p *Proxy) DialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
	serverConn, err := p.dialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	const bufSize = 1024 * 1024
	ln := bufconn.Listen(bufSize)
	go func() {
		defer serverConn.Close()
		clientConn, err := ln.Accept()
		if err != nil {
			// The listener was closed because the in-memory dial below was
			// abandoned; there is nothing to proxy.
			return
		}
		defer clientConn.Close()
		// Exactly one connection is ever accepted on this listener.
		defer ln.Close()

		backend := NewBackend(clientConn)
		frontend := NewFrontend(serverConn)

		// A per-connection channel (rather than the shared p.errChan) keeps
		// errors of unrelated connections apart. Capacity 2 lets both
		// forwarders deliver their single final error without blocking, even
		// though only the first one is read.
		errChan := make(chan error, 2)
		go backend.forward(frontend, errChan)
		go frontend.forward(backend, errChan)

		// Since pgx closes connection first, every EOF is seen as unexpected.
		if err := <-errChan; err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
			panic(err)
		}
		// Returning runs the deferred Close calls, which also unblocks the
		// forwarder that is still reading.
	}()

	conn, err := ln.DialContext(ctx)
	if err != nil {
		// Wake the goroutine blocked in Accept so it can release serverConn.
		// The Close error is ignored: the dial error is the one worth
		// reporting.
		_ = ln.Close()
		return nil, err
	}
	return conn, nil
}

// Backend is the proxy's server-side view of the client connection: a
// pgproto3.Backend paired with a logger.
type Backend struct {
	*pgproto3.Backend
	// logger prints the messages that Frontend.forward relays to the client.
	logger *log.Logger
}

// NewBackend builds a Backend that reads from and writes to clientConn and
// logs to standard error with the "PG Recv: " prefix.
func NewBackend(clientConn net.Conn) Backend {
	reader := pgproto3.NewChunkReader(clientConn)
	recvLog := log.New(os.Stderr, "PG Recv: ", log.LstdFlags|log.Lmsgprefix)

	return Backend{
		Backend: pgproto3.NewBackend(reader, clientConn),
		logger:  recvLog,
	}
}

// forward relays client messages to the server: first the startup message,
// then every regular message, each one logged as JSON through the frontend's
// logger before it is sent. It runs until receiving, encoding or sending
// fails, reports that single error on errChan and returns.
func (b *Backend) forward(frontend Frontend, errChan chan error) {
	// relay logs one client message and passes it on to the server.
	relay := func(msg pgproto3.FrontendMessage) error {
		encoded, err := json.Marshal(msg)
		if err != nil {
			return err
		}
		frontend.logger.Println(string(encoded))
		return frontend.Send(msg)
	}

	// The first message on a connection uses the startup framing and needs
	// its own receive call.
	startup, err := b.ReceiveStartupMessage()
	if err == nil {
		err = relay(startup)
	}
	if err != nil {
		errChan <- err
		return
	}

	for {
		next, err := b.Receive()
		if err == nil {
			err = relay(next)
		}
		if err != nil {
			errChan <- err
			return
		}
	}
}

// Frontend is the proxy's client-side view of the server connection: a
// pgproto3.Frontend paired with a logger.
type Frontend struct {
	*pgproto3.Frontend
	// logger prints the messages that Backend.forward relays to the server.
	logger *log.Logger
}

// NewFrontend builds a Frontend that reads from and writes to serverConn and
// logs to standard error with the "PG Send: " prefix.
func NewFrontend(serverConn net.Conn) Frontend {
	reader := pgproto3.NewChunkReader(serverConn)
	sendLog := log.New(os.Stderr, "PG Send: ", log.LstdFlags|log.Lmsgprefix)

	return Frontend{
		Frontend: pgproto3.NewFrontend(reader, serverConn),
		logger:   sendLog,
	}
}

// forward relays server messages to the client, logging each one as JSON
// through the backend's logger before it is sent. It runs until receiving,
// encoding, updating the auth type or sending fails, reports that single
// error on errChan and returns.
func (f *Frontend) forward(backend Backend, errChan chan error) {
	// relayOne handles exactly one server message.
	relayOne := func() error {
		msg, err := f.Receive()
		if err != nil {
			return err
		}

		encoded, err := json.Marshal(msg)
		if err != nil {
			return err
		}
		backend.logger.Println(string(encoded))

		// Set the authentication type so the next backend.Receive() will
		// properly decode the appropriate 'p' message.
		if _, isAuth := msg.(pgproto3.AuthenticationResponseMessage); isAuth {
			if err := backend.SetAuthType(f.GetAuthType()); err != nil {
				return err
			}
		}

		return backend.Send(msg)
	}

	for {
		if err := relayOne(); err != nil {
			errChan <- err
			return
		}
	}
}
