package fp448

import (
	"crypto/rand"
	"fmt"
	"math/big"
	"testing"

	"github.com/cloudflare/circl/internal/conv"
	"github.com/cloudflare/circl/internal/test"
)

// testCmov checks a conditional-move implementation f against a big.Int
// reference. For each random pair of elements, f(dst, src, sel) must leave
// dst equal to src's value when sel is 1 and leave dst untouched when sel
// is 0.
func testCmov(t *testing.T, f func(x, y *Elt, n uint)) {
	const numTests = 1 << 9
	var dst, src Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(dst[:])
		_, _ = rand.Read(src[:])
		// The selector is the low bit of the freshly drawn source element.
		sel := uint(src[0] & 0x1)

		// Snapshot the expected outcome before f gets a chance to modify dst.
		var want *big.Int
		if sel == 0 {
			want = conv.BytesLe2BigInt(dst[:])
		} else {
			want = conv.BytesLe2BigInt(src[:])
		}

		f(&dst, &src, sel)

		if got := conv.BytesLe2BigInt(dst[:]); got.Cmp(want) != 0 {
			test.ReportError(t, got, want, dst, src, sel)
		}
	}
}

// testCswap checks a conditional-swap implementation f. For each random pair
// of elements, f(a, b, sel) must exchange the two values when sel is 1 and
// leave both unchanged when sel is 0.
func testCswap(t *testing.T, f func(x, y *Elt, n uint)) {
	const numTests = 1 << 9
	var a, b Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(a[:])
		_, _ = rand.Read(b[:])
		// The selector is the low bit of the second random element.
		sel := uint(b[0] & 0x1)

		// Record the pre-call values, exchanged if a swap is expected.
		wantA := conv.BytesLe2BigInt(a[:])
		wantB := conv.BytesLe2BigInt(b[:])
		if sel == 1 {
			wantA, wantB = wantB, wantA
		}

		f(&a, &b, sel)

		gotA := conv.BytesLe2BigInt(a[:])
		gotB := conv.BytesLe2BigInt(b[:])
		if gotA.Cmp(wantA) != 0 {
			test.ReportError(t, gotA, wantA, a, b, sel)
		}
		if gotB.Cmp(wantB) != 0 {
			test.ReportError(t, gotB, wantB, a, b, sel)
		}
	}
}

// testAdd checks an addition implementation f: for random operands a and b,
// the result of f, once reduced with Modp, must equal (a + b) mod p as
// computed with math/big.
func testAdd(t *testing.T, f func(z, x, y *Elt)) {
	const numTests = 1 << 9
	modulus := P()
	p := conv.BytesLe2BigInt(modulus[:])

	var a, b, sum Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(a[:])
		_, _ = rand.Read(b[:])

		f(&sum, &a, &b)
		Modp(&sum)
		got := conv.BytesLe2BigInt(sum[:])

		// Reference result from arbitrary-precision arithmetic.
		want := conv.BytesLe2BigInt(a[:])
		want.Add(want, conv.BytesLe2BigInt(b[:]))
		want.Mod(want, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, a, b)
		}
	}
}

// testSub checks a subtraction implementation f: for random operands a and
// b, the result of f, once reduced with Modp, must equal (a - b) mod p as
// computed with math/big.
func testSub(t *testing.T, f func(z, x, y *Elt)) {
	const numTests = 1 << 9
	modulus := P()
	p := conv.BytesLe2BigInt(modulus[:])

	var a, b, diff Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(a[:])
		_, _ = rand.Read(b[:])

		f(&diff, &a, &b)
		Modp(&diff)
		got := conv.BytesLe2BigInt(diff[:])

		// Reference result from arbitrary-precision arithmetic.
		want := conv.BytesLe2BigInt(a[:])
		want.Sub(want, conv.BytesLe2BigInt(b[:]))
		want.Mod(want, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, a, b)
		}
	}
}

// testAddSub checks a combined add/sub implementation f that works in place:
// after f(a, b), a (reduced with Modp) must equal (a + b) mod p and b
// (reduced with Modp) must equal (a - b) mod p, using the values a and b
// held before the call.
func testAddSub(t *testing.T, f func(x, y *Elt)) {
	const numTests = 1 << 9
	modulus := P()
	p := conv.BytesLe2BigInt(modulus[:])

	var a, b Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(a[:])
		_, _ = rand.Read(b[:])

		// f overwrites both operands, so the references are computed first.
		bigA := conv.BytesLe2BigInt(a[:])
		bigB := conv.BytesLe2BigInt(b[:])
		wantSum := new(big.Int).Add(bigA, bigB)
		wantSum.Mod(wantSum, p)
		wantDiff := new(big.Int).Sub(bigA, bigB)
		wantDiff.Mod(wantDiff, p)

		f(&a, &b)
		Modp(&a)
		Modp(&b)
		gotSum := conv.BytesLe2BigInt(a[:])
		gotDiff := conv.BytesLe2BigInt(b[:])

		if gotSum.Cmp(wantSum) != 0 {
			test.ReportError(t, gotSum, wantSum, a, b)
		}
		if gotDiff.Cmp(wantDiff) != 0 {
			test.ReportError(t, gotDiff, wantDiff, a, b)
		}
	}
}

// testMul checks a multiplication implementation f: for random operands a
// and b, the result of f, once reduced with Modp, must equal (a * b) mod p
// as computed with math/big.
func testMul(t *testing.T, f func(z, x, y *Elt)) {
	const numTests = 1 << 9
	modulus := P()
	p := conv.BytesLe2BigInt(modulus[:])

	var a, b, prod Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(a[:])
		_, _ = rand.Read(b[:])

		f(&prod, &a, &b)
		Modp(&prod)
		got := conv.BytesLe2BigInt(prod[:])

		// Reference result from arbitrary-precision arithmetic.
		want := conv.BytesLe2BigInt(a[:])
		want.Mul(want, conv.BytesLe2BigInt(b[:]))
		want.Mod(want, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, a, b)
		}
	}
}

// testSqr checks a squaring implementation f: for a random operand a, the
// result of f, once reduced with Modp, must equal (a * a) mod p as computed
// with math/big.
func testSqr(t *testing.T, f func(z, x *Elt)) {
	const numTests = 1 << 9
	modulus := P()
	p := conv.BytesLe2BigInt(modulus[:])

	var a, sq Elt
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(a[:])

		f(&sq, &a)
		Modp(&sq)
		got := conv.BytesLe2BigInt(sq[:])

		// Reference result from arbitrary-precision arithmetic.
		want := conv.BytesLe2BigInt(a[:])
		want.Mul(want, want)
		want.Mod(want, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, a)
		}
	}
}

// TestModp checks that Modp leaves an element holding its value reduced
// modulo p, comparing against math/big.
//
// Inputs are drawn from the whole range an Elt can encode: the sampling bound
// is 2^(8*len(x)), derived from the element size instead of a fixed bit
// width, so the samples are not confined to a narrow low range. A handful of
// boundary values around p are tested explicitly as well, since those are the
// inputs for which reduction actually has to change the encoding.
func TestModp(t *testing.T) {
	const numTests = 1 << 9
	var x Elt
	prime := P()
	p := conv.BytesLe2BigInt(prime[:])

	// bound = 2^(bit length of an Elt): one past the largest encodable value.
	bound := big.NewInt(1)
	bound.Lsh(bound, uint(8*len(x)))

	want := new(big.Int)
	check := func(bigX *big.Int) {
		t.Helper()
		conv.BigInt2BytesLe(x[:], bigX)

		Modp(&x)
		got := conv.BytesLe2BigInt(x[:])

		want.Mod(bigX, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, bigX)
		}
	}

	// Boundary values: 0, p-1, p, p+1 and the largest encodable value.
	one := big.NewInt(1)
	edges := []*big.Int{
		new(big.Int),
		new(big.Int).Sub(p, one),
		new(big.Int).Set(p),
		new(big.Int).Add(p, one),
		new(big.Int).Sub(bound, one),
	}
	for _, edge := range edges {
		check(edge)
	}

	for i := 0; i < numTests; i++ {
		bigX, err := rand.Int(rand.Reader, bound)
		if err != nil {
			// Without this check a failed read would surface as a nil
			// pointer dereference on the next line.
			t.Fatalf("reading random value: %v", err)
		}
		bigX.Add(bigX, p).Mod(bigX, bound)
		check(bigX)
	}
}

// TestIsZero checks IsZero on three fixed inputs: the zero-valued element
// (expected true), the element produced by SetOne (expected false), and the
// element returned by P() (expected true).
func TestIsZero(t *testing.T) {
	var zero, one Elt
	SetOne(&one)

	cases := []struct {
		x    Elt
		want bool
	}{
		{zero, true},
		{one, false},
		{P(), true},
	}
	for _, c := range cases {
		// Work on a local copy so IsZero receives an addressable element.
		x := c.x
		if got := IsZero(&x); got != c.want {
			test.ReportError(t, got, c.want, x)
		}
	}
}

// TestToBytes checks that ToBytes succeeds on a Size-byte buffer and writes
// the little-endian encoding of the element, and that it returns an error
// when handed a buffer of Size+1 bytes.
func TestToBytes(t *testing.T) {
	const numTests = 1 << 9
	var x Elt
	var got [Size]byte
	var want [Size]byte
	for iter := 0; iter < numTests; iter++ {
		_, _ = rand.Read(x[:])
		err := ToBytes(got[:], &x)

		// The reference encoding is derived from x as it stands after the
		// call, round-tripped through a big.Int.
		ref := conv.BytesLe2BigInt(x[:])
		conv.BigInt2BytesLe(want[:], ref)

		if !(err == nil && got == want) {
			test.ReportError(t, got, want, x)
		}
	}

	// A destination that is one byte too long must be rejected.
	var largeSlice [Size + 1]byte
	if err := ToBytes(largeSlice[:], &x); err == nil {
		test.ReportError(t, got, want, largeSlice)
	}
}

// TestString checks that the text produced by fmt.Sprint for an element
// parses back (with big.Int.SetString, base 0) to the same integer as the
// element's little-endian byte encoding.
func TestString(t *testing.T) {
	const numTests = 1 << 9
	var x Elt
	var bigX big.Int
	for i := 0; i < numTests; i++ {
		_, _ = rand.Read(x[:])

		str := fmt.Sprint(x)
		got, ok := bigX.SetString(str, 0)
		if !ok {
			// SetString returns a nil *big.Int on failure; bail out with the
			// offending text instead of panicking in Cmp below.
			t.Fatalf("fmt.Sprint(x) = %q does not parse as an integer", str)
		}
		want := conv.BytesLe2BigInt(x[:])

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, x)
		}
	}
}

// TestNeg checks Neg against math/big: for a random operand x, the result of
// Neg, once reduced with Modp, must equal (-x) mod p.
func TestNeg(t *testing.T) {
	const numTests = 1 << 9
	var x, z Elt
	prime := P()
	p := conv.BytesLe2BigInt(prime[:])
	for i := 0; i < numTests; i++ {
		_, _ = rand.Read(x[:])
		Neg(&z, &x)
		Modp(&z)
		got := conv.BytesLe2BigInt(z[:])

		// The reference is computed in place on its own big.Int, so the
		// operand reported on failure is x itself rather than a value that
		// has already been overwritten with the expected result.
		want := conv.BytesLe2BigInt(x[:])
		want.Neg(want)
		want.Mod(want, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, x)
		}
	}
}

func TestInv(t *testing.T) {
	const numTests = 1 << 9
	var x, z Elt
	prime := P()
	p := conv.BytesLe2BigInt(prime[:])
	for i := 0; i < numTests; i++ {
		_, _ = rand.Read(x[:])
		Inv(&z, &x)
		Modp(&z)
		got := conv.BytesLe2BigInt(z[:])

		xx := conv.BytesLe2BigInt(x[:])
		want := xx.ModInverse(xx, p)

		if got.Cmp(want) != 0 {
			test.ReportError(t, got, want, x)
		}
	}
}

func TestInvSqrt(t *testing.T) {
	const numTests = 1 << 9
	var x, y, z Elt
	prime := P()
	p := conv.BytesLe2BigInt(prime[:])
	exp := big.NewInt(1)
	exp.Add(p, exp).Rsh(exp, 2)
	var frac, root, sqRoot big.Int
	var wantQR bool
	var want *big.Int
	for i := 0; i < numTests; i++ {
		_, _ = rand.Read(x[:])
		_, _ = rand.Read(y[:])

		gotQR := InvSqrt(&z, &x, &y)
		Modp(&z)
		got := conv.BytesLe2BigInt(z[:])

		xx := conv.BytesLe2BigInt(x[:])
		yy := conv.BytesLe2BigInt(y[:])
		frac.ModInverse(yy, p).Mul(&frac, xx).Mod(&frac, p)
		root.Exp(&frac, exp, p)
		sqRoot.Mul(&root, &root).Mod(&sqRoot, p)

		if sqRoot.Cmp(&frac) == 0 {
			want = &root
			wantQR = true
		} else {
			want = big.NewInt(0)
			wantQR = false
		}

		if wantQR {
			if gotQR != wantQR || got.Cmp(want) != 0 {
				test.ReportError(t, got, want, x, y)
			}
		} else {
			if gotQR != wantQR {
				test.ReportError(t, gotQR, wantQR, x, y)
			}
		}
	}
}

// TestGeneric runs the shared arithmetic checks against the *Generic
// implementations, one named subtest per operation.
func TestGeneric(t *testing.T) {
	subtests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"Cmov", func(t *testing.T) { testCmov(t, cmovGeneric) }},
		{"Cswap", func(t *testing.T) { testCswap(t, cswapGeneric) }},
		{"Add", func(t *testing.T) { testAdd(t, addGeneric) }},
		{"Sub", func(t *testing.T) { testSub(t, subGeneric) }},
		{"AddSub", func(t *testing.T) { testAddSub(t, addsubGeneric) }},
		{"Mul", func(t *testing.T) { testMul(t, mulGeneric) }},
		{"Sqr", func(t *testing.T) { testSqr(t, sqrGeneric) }},
	}
	for _, st := range subtests {
		t.Run(st.name, st.run)
	}
}

// TestNative runs the shared arithmetic checks against the package's
// exported operations, one named subtest per operation.
func TestNative(t *testing.T) {
	subtests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{"Cmov", func(t *testing.T) { testCmov(t, Cmov) }},
		{"Cswap", func(t *testing.T) { testCswap(t, Cswap) }},
		{"Add", func(t *testing.T) { testAdd(t, Add) }},
		{"Sub", func(t *testing.T) { testSub(t, Sub) }},
		{"AddSub", func(t *testing.T) { testAddSub(t, AddSub) }},
		{"Mul", func(t *testing.T) { testMul(t, Mul) }},
		{"Sqr", func(t *testing.T) { testSqr(t, Sqr) }},
	}
	for _, st := range subtests {
		t.Run(st.name, st.run)
	}
}

// BenchmarkFp times the exported field operations as sub-benchmarks. The
// three operands are filled with random bytes once, before any timing, and
// are reused (and overwritten) by every sub-benchmark.
func BenchmarkFp(b *testing.B) {
	var x, y, z Elt
	for _, e := range []*Elt{&x, &y, &z} {
		_, _ = rand.Read(e[:])
	}

	// Each timed loop calls the operation directly so that no extra
	// indirection is included in the measurement.
	b.Run("Add", func(b *testing.B) {
		for n := b.N; n > 0; n-- {
			Add(&x, &y, &z)
		}
	})
	b.Run("Sub", func(b *testing.B) {
		for n := b.N; n > 0; n-- {
			Sub(&x, &y, &z)
		}
	})
	b.Run("Mul", func(b *testing.B) {
		for n := b.N; n > 0; n-- {
			Mul(&x, &y, &z)
		}
	})
	b.Run("Sqr", func(b *testing.B) {
		for n := b.N; n > 0; n-- {
			Sqr(&x, &y)
		}
	})
	b.Run("Inv", func(b *testing.B) {
		for n := b.N; n > 0; n-- {
			Inv(&x, &y)
		}
	})
	b.Run("InvSqrt", func(b *testing.B) {
		for n := b.N; n > 0; n-- {
			_ = InvSqrt(&z, &x, &y)
		}
	})
}
