package ctxwatch_test

import (
	"context"
	"sync/atomic"
	"testing"
	"time"

	"github.com/jackc/pgconn/internal/ctxwatch"
	"github.com/stretchr/testify/require"
)

// TestContextWatcherContextCancelled checks that canceling a watched context
// invokes the cancel func, and that the cleanup func has been invoked by the
// time the following Unwatch returns.
func TestContextWatcherContextCancelled(t *testing.T) {
	var cleanedUp bool
	cancelSignal := make(chan struct{})

	onCancel := func() { cancelSignal <- struct{}{} }
	onCleanup := func() { cleanedUp = true }
	watcher := ctxwatch.NewContextWatcher(onCancel, onCleanup)

	ctx, cancel := context.WithCancel(context.Background())
	watcher.Watch(ctx)
	cancel()

	// Wait up to one second for the cancel func to signal.
	timeout := time.After(time.Second)
	select {
	case <-cancelSignal:
	case <-timeout:
		t.Fatal("Timed out waiting for cancel func to be called")
	}

	watcher.Unwatch()

	require.True(t, cleanedUp, "Cleanup func was not called")
}

// TestContextWatcherUnwatchdBeforeContextCancelled reports an error if either
// callback fires when Unwatch is called before the context is canceled.
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
	failOnCancel := func() { t.Error("cancel func should not have been called") }
	failOnCleanup := func() { t.Error("cleanup func should not have been called") }
	watcher := ctxwatch.NewContextWatcher(failOnCancel, failOnCleanup)

	ctx, cancel := context.WithCancel(context.Background())
	watcher.Watch(ctx)
	watcher.Unwatch()

	// Cancel only after the watch has already been ended.
	cancel()
}

// TestContextWatcherMultipleWatchPanics checks that calling Watch a second
// time without an intervening Unwatch panics.
func TestContextWatcherMultipleWatchPanics(t *testing.T) {
	noop := func() {}
	watcher := ctxwatch.NewContextWatcher(noop, noop)

	firstCtx, cancelFirst := context.WithCancel(context.Background())
	defer cancelFirst()
	watcher.Watch(firstCtx)

	secondCtx, cancelSecond := context.WithCancel(context.Background())
	defer cancelSecond()

	watchAgain := func() { watcher.Watch(secondCtx) }
	require.Panics(t, watchAgain, "Expected panic when Watch called multiple times")
}

// TestContextWatcherUnwatchWhenNotWatchingIsSafe calls Unwatch both before any
// Watch and twice in a row after one; the test passes as long as none of those
// calls panics or blocks.
func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
	noop := func() {}
	watcher := ctxwatch.NewContextWatcher(noop, noop)

	// No Watch has been issued yet.
	watcher.Unwatch()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	watcher.Watch(ctx)

	// The first call follows a Watch; the second is a repeated Unwatch.
	watcher.Unwatch()
	watcher.Unwatch()
}

// TestContextWatcherUnwatchIsConcurrencySafe calls Unwatch from two goroutines
// at once and requires both calls to return. Waiting on explicit completion
// signals (rather than on an unrelated context timeout) means a deadlocked
// Unwatch fails the test instead of leaking a goroutine unnoticed.
func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
	cw := ctxwatch.NewContextWatcher(func() {}, func() {})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	cw.Watch(ctx)

	const callers = 2

	// Buffered so a caller that finishes after a timeout failure never blocks
	// on the send.
	done := make(chan struct{}, callers)
	for i := 0; i < callers; i++ {
		go func() {
			cw.Unwatch()
			done <- struct{}{}
		}()
	}

	// A single deadline shared by all callers.
	timeout := time.After(time.Second)
	for i := 0; i < callers; i++ {
		select {
		case <-done:
		case <-timeout:
			t.Fatal("Timed out waiting for concurrent Unwatch calls to return")
		}
	}
}

// TestContextWatcherStress runs many Watch/Unwatch cycles on one watcher,
// canceling the context before Unwatch on even iterations and after Unwatch on
// odd ones. It then checks that the cancel func ran at least once, ran no more
// often than the number of even iterations, and ran exactly as often as the
// cleanup func.
func TestContextWatcherStress(t *testing.T) {
	var cancelFuncCalls int64
	var cleanupFuncCalls int64

	cw := ctxwatch.NewContextWatcher(func() {
		atomic.AddInt64(&cancelFuncCalls, 1)
	}, func() {
		atomic.AddInt64(&cleanupFuncCalls, 1)
	})

	cycleCount := 100000

	for i := 0; i < cycleCount; i++ {
		ctx, cancel := context.WithCancel(context.Background())
		cw.Watch(ctx)
		if i%2 == 0 {
			cancel()
		}

		// Without time.Sleep, cw.Unwatch will almost always run before the cancel func, which means the cancel func
		// would almost never be called. Sleeping on some iterations gives us a better mix.
		if i%3 == 0 {
			time.Sleep(time.Nanosecond)
		}

		cw.Unwatch()
		if i%2 == 1 {
			cancel()
		}
	}

	actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
	actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)

	if actualCancelFuncCalls == 0 {
		t.Fatal("actualCancelFuncCalls == 0")
	}

	// Only even iterations cancel before Unwatch, so at most half of the
	// cycles can trigger the cancel func.
	maxCancelFuncCalls := int64(cycleCount) / 2
	if actualCancelFuncCalls > maxCancelFuncCalls {
		// Arguments are ordered to match the message: the limit first, then the observed count.
		t.Errorf("cancel func calls should be no more than %d but was %d", maxCancelFuncCalls, actualCancelFuncCalls)
	}

	if actualCancelFuncCalls != actualCleanupFuncCalls {
		t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
	}
}

// BenchmarkContextWatcherUncancellable measures a Watch/Unwatch cycle using
// context.Background() as the watched context.
func BenchmarkContextWatcherUncancellable(b *testing.B) {
	noop := func() {}
	watcher := ctxwatch.NewContextWatcher(noop, noop)

	for n := 0; n < b.N; n++ {
		watcher.Watch(context.Background())
		watcher.Unwatch()
	}
}

// BenchmarkContextWatcherCancelled measures a cycle that creates a fresh
// cancellable context, watches it, cancels it, and then unwatches.
func BenchmarkContextWatcherCancelled(b *testing.B) {
	noop := func() {}
	watcher := ctxwatch.NewContextWatcher(noop, noop)

	for n := 0; n < b.N; n++ {
		watched, cancelWatched := context.WithCancel(context.Background())
		watcher.Watch(watched)
		cancelWatched()
		watcher.Unwatch()
	}
}

// BenchmarkContextWatcherCancellable measures a Watch/Unwatch cycle on a single
// cancellable context that is not canceled until the benchmark returns.
func BenchmarkContextWatcherCancellable(b *testing.B) {
	noop := func() {}
	watcher := ctxwatch.NewContextWatcher(noop, noop)

	shared, cancelShared := context.WithCancel(context.Background())
	defer cancelShared()

	for n := 0; n < b.N; n++ {
		watcher.Watch(shared)
		watcher.Unwatch()
	}
}
