package pool_test

import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/sourcegraph/conc/pool"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestResultContextPool(t *testing.T) {
	t.Parallel()

	err1 := errors.New("err1")
	err2 := errors.New("err2")

	t.Run("panics on configuration after init", func(t *testing.T) {
		t.Run("before wait", func(t *testing.T) {
			t.Parallel()
			g := pool.NewWithResults[int]().WithContext(context.Background())
			g.Go(func(context.Context) (int, error) { return 0, nil })
			require.Panics(t, func() { g.WithMaxGoroutines(10) })
		})

		t.Run("after wait", func(t *testing.T) {
			t.Parallel()
			g := pool.NewWithResults[int]().WithContext(context.Background())
			g.Go(func(context.Context) (int, error) { return 0, nil })
			_, _ = g.Wait()
			require.Panics(t, func() { g.WithMaxGoroutines(10) })
		})
	})

	t.Run("behaves the same as ErrorGroup", func(t *testing.T) {
		t.Parallel()
		bgctx := context.Background()
		t.Run("wait returns no error if no errors", func(t *testing.T) {
			t.Parallel()
			g := pool.NewWithResults[int]().WithContext(bgctx)
			g.Go(func(context.Context) (int, error) { return 0, nil })
			res, err := g.Wait()
			require.Len(t, res, 1)
			require.NoError(t, err)
		})

		t.Run("wait error if func returns error", func(t *testing.T) {
			t.Parallel()
			g := pool.NewWithResults[int]().WithContext(bgctx)
			g.Go(func(context.Context) (int, error) { return 0, err1 })
			res, err := g.Wait()
			require.Len(t, res, 0)
			require.ErrorIs(t, err, err1)
		})

		t.Run("wait error is all returned errors", func(t *testing.T) {
			t.Parallel()
			g := pool.NewWithResults[int]().WithErrors().WithContext(bgctx)
			g.Go(func(context.Context) (int, error) { return 0, err1 })
			g.Go(func(context.Context) (int, error) { return 0, nil })
			g.Go(func(context.Context) (int, error) { return 0, err2 })
			res, err := g.Wait()
			require.Len(t, res, 1)
			require.ErrorIs(t, err, err1)
			require.ErrorIs(t, err, err2)
		})
	})

	t.Run("context cancel propagates", func(t *testing.T) {
		t.Parallel()
		ctx, cancel := context.WithCancel(context.Background())
		g := pool.NewWithResults[int]().WithContext(ctx)
		g.Go(func(ctx context.Context) (int, error) {
			<-ctx.Done()
			return 0, ctx.Err()
		})
		cancel()
		res, err := g.Wait()
		require.Len(t, res, 0)
		require.ErrorIs(t, err, context.Canceled)
	})

	t.Run("WithCancelOnError", func(t *testing.T) {
		t.Parallel()
		g := pool.NewWithResults[int]().WithContext(context.Background()).WithCancelOnError()
		g.Go(func(ctx context.Context) (int, error) {
			<-ctx.Done()
			return 0, ctx.Err()
		})
		g.Go(func(ctx context.Context) (int, error) {
			return 0, err1
		})
		res, err := g.Wait()
		require.Len(t, res, 0)
		require.ErrorIs(t, err, context.Canceled)
		require.ErrorIs(t, err, err1)
	})

	t.Run("WithFailFast", func(t *testing.T) {
		t.Parallel()
		p := pool.NewWithResults[int]().WithContext(context.Background()).WithFailFast()
		p.Go(func(ctx context.Context) (int, error) {
			return 0, err1
		})
		p.Go(func(ctx context.Context) (int, error) {
			<-ctx.Done()
			return 1, ctx.Err()
		})
		results, err := p.Wait()
		require.ErrorIs(t, err, err1)
		require.NotErrorIs(t, err, context.Canceled)
		require.Empty(t, results)
	})

	t.Run("WithCancelOnError and panic", func(t *testing.T) {
		t.Parallel()
		p := pool.NewWithResults[int]().
			WithContext(context.Background()).
			WithCancelOnError()
		var cancelledTasks atomic.Int64
		p.Go(func(ctx context.Context) (int, error) {
			<-ctx.Done()
			cancelledTasks.Add(1)
			return 0, ctx.Err()
		})
		p.Go(func(ctx context.Context) (int, error) {
			<-ctx.Done()
			cancelledTasks.Add(1)
			return 0, ctx.Err()
		})
		p.Go(func(ctx context.Context) (int, error) {
			panic("abort!")
		})
		assert.Panics(t, func() { _, _ = p.Wait() })
		assert.EqualValues(t, 2, cancelledTasks.Load())
	})

	t.Run("no WithCancelOnError", func(t *testing.T) {
		t.Parallel()
		g := pool.NewWithResults[int]().WithContext(context.Background())
		g.Go(func(ctx context.Context) (int, error) {
			select {
			case <-ctx.Done():
				return 0, ctx.Err()
			case <-time.After(10 * time.Millisecond):
				return 0, nil
			}
		})
		g.Go(func(ctx context.Context) (int, error) {
			return 0, err1
		})
		res, err := g.Wait()
		require.Len(t, res, 1)
		require.NotErrorIs(t, err, context.Canceled)
		require.ErrorIs(t, err, err1)
	})

	t.Run("WithCollectErrored", func(t *testing.T) {
		t.Parallel()
		g := pool.NewWithResults[int]().WithContext(context.Background()).WithCollectErrored()
		g.Go(func(context.Context) (int, error) { return 0, err1 })
		res, err := g.Wait()
		require.Len(t, res, 1) // errored value is collected
		require.ErrorIs(t, err, err1)
	})

	t.Run("WithFirstError", func(t *testing.T) {
		t.Parallel()
		g := pool.NewWithResults[int]().WithContext(context.Background()).WithFirstError()
		sync := make(chan struct{})
		g.Go(func(ctx context.Context) (int, error) {
			defer close(sync)
			return 0, err1
		})
		g.Go(func(ctx context.Context) (int, error) {
			// This test has a race condition. After the first goroutine
			// completes, this goroutine is woken up because sync is closed.
			// However, this goroutine might be woken up before the error from
			// the first goroutine is registered. To prevent that, we sleep for
			// another 10 milliseconds, giving the other goroutine time to return
			// and register its error before this goroutine returns its error.
			<-sync
			time.Sleep(10 * time.Millisecond)
			return 0, err2
		})
		res, err := g.Wait()
		require.Len(t, res, 0)
		require.ErrorIs(t, err, err1)
		require.NotErrorIs(t, err, err2)
	})

	t.Run("limit", func(t *testing.T) {
		t.Parallel()
		for _, maxConcurrency := range []int{1, 10, 100} {
			t.Run(strconv.Itoa(maxConcurrency), func(t *testing.T) {
				maxConcurrency := maxConcurrency // copy

				t.Parallel()
				ctx := context.Background()
				g := pool.NewWithResults[int]().WithContext(ctx).WithMaxGoroutines(maxConcurrency)

				var currentConcurrent atomic.Int64
				taskCount := maxConcurrency * 10
				expected := make([]int, taskCount)
				for i := 0; i < taskCount; i++ {
					i := i
					expected[i] = i
					g.Go(func(context.Context) (int, error) {
						cur := currentConcurrent.Add(1)
						if cur > int64(maxConcurrency) {
							return 0, fmt.Errorf("expected no more than %d concurrent goroutines", maxConcurrency)
						}
						time.Sleep(time.Millisecond)
						currentConcurrent.Add(-1)
						return i, nil
					})
				}
				res, err := g.Wait()
				require.Equal(t, expected, res)
				require.NoError(t, err)
				require.Equal(t, int64(0), currentConcurrent.Load())
			})
		}
	})

	t.Run("reuse", func(t *testing.T) {
		// Test for https://github.com/sourcegraph/conc/issues/128
		p := pool.NewWithResults[int]().WithContext(context.Background())

		p.Go(func(context.Context) (int, error) { return 1, err1 })
		results1, errs1 := p.Wait()
		require.Empty(t, results1)
		require.ErrorIs(t, errs1, err1)

		p.Go(func(context.Context) (int, error) { return 2, err2 })
		results2, errs2 := p.Wait()
		require.Empty(t, results2)
		require.ErrorIs(t, errs2, err2)
		require.NotErrorIs(t, errs2, err1)
	})
}
