package group_test

import (
	"bytes"
	"crypto/rand"
	"fmt"
	"testing"

	"github.com/cloudflare/circl/group"
	"github.com/cloudflare/circl/internal/test"
)

// allGroups lists the group implementations that TestGroup, BenchmarkElement
// and BenchmarkScalar iterate over.
var allGroups = []group.Group{
	group.P256,
	group.P384,
	group.P521,
	group.Ristretto255,
}

func TestGroup(t *testing.T) {
	const testTimes = 1 << 7
	for _, g := range allGroups {
		n := g.(fmt.Stringer).String()
		t.Run(n+"/Add", func(tt *testing.T) { testAdd(tt, testTimes, g) })
		t.Run(n+"/Neg", func(tt *testing.T) { testNeg(tt, testTimes, g) })
		t.Run(n+"/Mul", func(tt *testing.T) { testMul(tt, testTimes, g) })
		t.Run(n+"/MulGen", func(tt *testing.T) { testMulGen(tt, testTimes, g) })
		t.Run(n+"/CMov", func(tt *testing.T) { testCMov(tt, testTimes, g) })
		t.Run(n+"/CSelect", func(tt *testing.T) { testCSelect(tt, testTimes, g) })
		t.Run(n+"/Order", func(tt *testing.T) { testOrder(tt, testTimes, g) })
		t.Run(n+"/Marshal", func(tt *testing.T) { testMarshal(tt, testTimes, g) })
		t.Run(n+"/Scalar", func(tt *testing.T) { testScalar(tt, testTimes, g) })
	}
}

// testAdd checks, for testTimes random elements P, that four successive
// doublings starting from P agree with adding P to the identity 16 times.
func testAdd(t *testing.T, testTimes int, g group.Group) {
	const additions = 16

	Q := g.NewElement()
	for round := 0; round < testTimes; round++ {
		P := g.RandomElement(rand.Reader)

		// Doubling chain: Q = 2P, then Q is doubled three more times (16P).
		step1 := Q.Dbl(P)
		step2 := step1.Dbl(Q)
		step3 := step2.Dbl(Q)
		got := step3.Dbl(Q)

		// Reference value: P + P + ... + P (16 terms).
		want := g.Identity()
		for added := 0; added < additions; added++ {
			want.Add(want, P)
		}

		if !got.IsEqual(want) {
			test.ReportError(t, got, want, P)
		}
	}
}

// testNeg checks, for testTimes random elements P, that (-P) + P is reported
// as the identity element.
func testNeg(t *testing.T, testTimes int, g group.Group) {
	sum := g.NewElement()
	for round := 0; round < testTimes; round++ {
		P := g.RandomElement(rand.Reader)

		sum.Neg(P)
		sum.Add(sum, P)

		if got, want := sum.IsIdentity(), true; got != want {
			test.ReportError(t, got, want, P)
		}
	}
}

// testMul checks, for testTimes random pairs (P, k), that multiplying P by a
// scalar k and then by the inverse of k yields P again.
func testMul(t *testing.T, testTimes int, g group.Group) {
	Q := g.NewElement()
	kInv := g.NewScalar()
	for i := 0; i < testTimes; i++ {
		P := g.RandomElement(rand.Reader)
		k := g.RandomScalar(rand.Reader)
		kInv.Inv(k)

		// Q = k^-1 * (k * P)
		Q.Mul(P, k)
		Q.Mul(Q, kInv)

		// Q is the computed value and P the expected one, so that a failure
		// report labels them correctly.
		got := Q
		want := P
		if !got.IsEqual(want) {
			test.ReportError(t, got, want, P, k)
		}
	}
}

// testMulGen checks, for testTimes random scalars k, that Mul applied to the
// group generator and MulGen produce equal elements.
func testMulGen(t *testing.T, testTimes int, g group.Group) {
	gen := g.Generator()
	viaMul := g.NewElement()
	viaMulGen := g.NewElement()

	for round := 0; round < testTimes; round++ {
		k := g.RandomScalar(rand.Reader)

		viaMul.Mul(gen, k)
		viaMulGen.MulGen(k)

		if got, want := viaMul, viaMulGen; !got.IsEqual(want) {
			test.ReportError(t, got, want, viaMul, k)
		}
	}
}

// testCMov exercises Element.CMov. It first runs CMov with selectors 0, 1 and
// 2 under test.CheckPanic, then checks on testTimes random pairs that
// selector 0 leaves the receiver equal to its previous value and selector 1
// makes it equal to the argument.
func testCMov(t *testing.T, testTimes int, g group.Group) {
	P := g.RandomElement(rand.Reader)
	Q := g.RandomElement(rand.Reader)

	// Selector validation; the messages state the expectation for each case.
	test.CheckIsErr(t, test.CheckPanic(func() { P.CMov(0, Q) }), "shouldn't fail with 0")
	test.CheckIsErr(t, test.CheckPanic(func() { P.CMov(1, Q) }), "shouldn't fail with 1")
	test.CheckNoErr(t, test.CheckPanic(func() { P.CMov(2, Q) }), "should fail with dif 0,1")

	for round := 0; round < testTimes; round++ {
		dst := g.RandomElement(rand.Reader)
		src := g.RandomElement(rand.Reader)

		// Selector 0: result must equal the receiver's prior value.
		kept := dst.Copy()
		if got := dst.CMov(0, src); !got.IsEqual(kept) {
			test.ReportError(t, got, kept)
		}

		// Selector 1: result must equal the argument.
		moved := src.Copy()
		if got := dst.CMov(1, src); !got.IsEqual(moved) {
			test.ReportError(t, got, moved)
		}
	}
}

// testCSelect exercises Element.CSelect. It first runs CSelect with selectors
// 0, 1 and 2 under test.CheckPanic, then checks on testTimes random triples
// that selector 0 yields the second operand and selector 1 yields the first.
func testCSelect(t *testing.T, testTimes int, g group.Group) {
	P := g.RandomElement(rand.Reader)
	Q := g.RandomElement(rand.Reader)
	R := g.RandomElement(rand.Reader)

	// Selector validation; the messages state the expectation for each case.
	test.CheckIsErr(t, test.CheckPanic(func() { P.CSelect(0, Q, R) }), "shouldn't fail with 0")
	test.CheckIsErr(t, test.CheckPanic(func() { P.CSelect(1, Q, R) }), "shouldn't fail with 1")
	test.CheckNoErr(t, test.CheckPanic(func() { P.CSelect(2, Q, R) }), "should fail with dif 0,1")

	for round := 0; round < testTimes; round++ {
		dst := g.RandomElement(rand.Reader)
		first := g.RandomElement(rand.Reader)
		second := g.RandomElement(rand.Reader)

		// Selector 0 picks the second operand.
		wantSecond := second.Copy()
		if got := dst.CSelect(0, first, second); !got.IsEqual(wantSecond) {
			test.ReportError(t, got, wantSecond)
		}

		// Selector 1 picks the first operand.
		wantFirst := first.Copy()
		if got := dst.CSelect(1, first, second); !got.IsEqual(wantFirst) {
			test.ReportError(t, got, wantFirst)
		}
	}
}

// testOrder checks, for testTimes random elements P, that (-1)*P + P equals
// the identity, where -1 is obtained by negating the scalar 1.
func testOrder(t *testing.T, testTimes int, g group.Group) {
	identity := g.Identity()
	acc := g.NewElement()

	minusOne := g.NewScalar().SetUint64(1)
	minusOne.Neg(minusOne)

	for round := 0; round < testTimes; round++ {
		P := g.RandomElement(rand.Reader)

		acc.Mul(P, minusOne)
		if got, want := acc.Add(acc, P), identity; !got.IsEqual(want) {
			test.ReportError(t, got, want, P)
		}
	}
}

// isZero reports whether every byte of b is 0x00. An empty or nil slice
// yields true.
func isZero(b []byte) bool {
	for _, v := range b {
		if v != 0x00 {
			return false
		}
	}
	return true
}

// testMarshal checks element serialization. For the identity it verifies that
// both encodings are all-zero bytes of length 1 or the length given by the
// group parameters, and that the compressed encoding decodes back to the
// identity. For testTimes random elements it verifies that both encodings
// round-trip through UnmarshalBinary and have the lengths given by the group
// parameters.
func testMarshal(t *testing.T, testTimes int, g group.Group) {
	params := g.Params()
	identity := g.Identity()

	// Identity, uncompressed encoding.
	encID, err := identity.MarshalBinary()
	test.CheckNoErr(t, err, "error on MarshalBinary")
	if !isZero(encID) {
		test.ReportError(t, encID, "Non-zero identity")
	}
	if n := uint(len(encID)); n != 1 && n != params.ElementLength {
		test.ReportError(t, n, params.ElementLength)
	}

	// Identity, compressed encoding.
	encIDCompressed, err := identity.MarshalBinaryCompress()
	test.CheckNoErr(t, err, "error on MarshalBinaryCompress")
	if !isZero(encIDCompressed) {
		test.ReportError(t, encIDCompressed, "Non-zero identity")
	}
	if n := uint(len(encIDCompressed)); n != 1 && n != params.CompressedElementLength {
		test.ReportError(t, n, params.CompressedElementLength)
	}

	// The compressed identity encoding must decode to the identity.
	decodedID := g.NewElement()
	if err = decodedID.UnmarshalBinary(encIDCompressed); err != nil || !identity.IsEqual(decodedID) {
		test.ReportError(t, identity, decodedID)
	}

	fromFull := g.NewElement()
	fromCompressed := g.NewElement()
	for round := 0; round < testTimes; round++ {
		x := g.RandomElement(rand.Reader)

		full, errFull := x.MarshalBinary()
		compressed, errCompressed := x.MarshalBinaryCompress()
		test.CheckNoErr(t, errFull, "error on marshalling")
		test.CheckNoErr(t, errCompressed, "error on marshalling compress")

		errFull = fromFull.UnmarshalBinary(full)
		errCompressed = fromCompressed.UnmarshalBinary(compressed)
		test.CheckNoErr(t, errFull, "error on unmarshalling")
		test.CheckNoErr(t, errCompressed, "error on unmarshalling compress")

		// Round-trip equality for both encodings.
		if !x.IsEqual(fromFull) {
			test.ReportError(t, fromFull, x)
		}
		if !x.IsEqual(fromCompressed) {
			test.ReportError(t, fromCompressed, x)
		}

		// Encoded lengths must match the group parameters exactly.
		if n := uint(len(full)); n != params.ElementLength {
			test.ReportError(t, n, params.ElementLength)
		}
		if n := uint(len(compressed)); n != params.CompressedElementLength {
			test.ReportError(t, n, params.CompressedElementLength)
		}
	}
}

// testScalar exercises the Scalar interface of g. It runs CMov and CSelect
// with selectors 0, 1 and 2 under test.CheckPanic, then for testTimes random
// pairs (a, b) checks the identity (a+b)(a-b)+1 == a^2-b^2+1 via the binary
// encodings, the encoded length, and the results of CMov/CSelect. Finally it
// checks that a^-1 * a - 1 is zero.
func testScalar(t *testing.T, testTimes int, g group.Group) {
	a := g.RandomScalar(rand.Reader)
	b := g.RandomScalar(rand.Reader)
	c := g.NewScalar()
	d := g.NewScalar()
	e := g.NewScalar()
	f := g.NewScalar()
	one := g.NewScalar()
	one.SetUint64(1)
	params := g.Params()

	// CMov selector validation; the messages state the expectation per case.
	err := test.CheckPanic(func() { a.CMov(0, b) })
	test.CheckIsErr(t, err, "shouldn't fail with 0")
	err = test.CheckPanic(func() { a.CMov(1, b) })
	test.CheckIsErr(t, err, "shouldn't fail with 1")
	err = test.CheckPanic(func() { a.CMov(2, b) })
	test.CheckNoErr(t, err, "should fail with dif 0,1")

	// CSelect selector validation, same pattern as above.
	err = test.CheckPanic(func() { a.CSelect(0, b, c) })
	test.CheckIsErr(t, err, "shouldn't fail with 0")
	err = test.CheckPanic(func() { a.CSelect(1, b, c) })
	test.CheckIsErr(t, err, "shouldn't fail with 1")
	err = test.CheckPanic(func() { a.CSelect(2, b, c) })
	test.CheckNoErr(t, err, "should fail with dif 0,1")

	for i := 0; i < testTimes; i++ {
		a = g.RandomScalar(rand.Reader)
		b = g.RandomScalar(rand.Reader)
		// e = (a+b)*(a-b) + 1
		c.Add(a, b)
		d.Sub(a, b)
		e.Mul(c, d)
		e.Add(e, one)

		// f = a^2 + (-(b^2)) + 1; c and d are reused as scratch here.
		c.Mul(a, a)
		d.Mul(b, b)
		d.Neg(d)
		f.Add(c, d)
		f.Add(f, one)
		// Both expressions must serialize to the same bytes.
		enc1, err1 := e.MarshalBinary()
		enc2, err2 := f.MarshalBinary()
		if err1 != nil || err2 != nil || !bytes.Equal(enc1, enc2) {
			test.ReportError(t, enc1, enc2, a, b)
		}
		if l := uint(len(enc1)); l != params.ScalarLength {
			test.ReportError(t, l, params.ScalarLength)
		}

		// CMov with selector 0: result must equal the receiver's prior value.
		want := c.Copy()
		got := c.CMov(0, a)
		if !got.IsEqual(want) {
			test.ReportError(t, got, want)
		}

		// CMov with selector 1: result must equal the argument.
		want = b.Copy()
		got = d.CMov(1, b)
		if !got.IsEqual(want) {
			test.ReportError(t, got, want)
		}

		// CSelect with selector 0: result must equal the second operand.
		want = b.Copy()
		got = e.CSelect(0, a, b)
		if !got.IsEqual(want) {
			test.ReportError(t, got, want)
		}

		// CSelect with selector 1: result must equal the first operand.
		want = a.Copy()
		got = f.CSelect(1, a, b)
		if !got.IsEqual(want) {
			test.ReportError(t, got, want)
		}
	}

	// Inversion check on a as left by the code above: a^-1 * a - 1 == 0.
	c.Inv(a)
	c.Mul(c, a)
	c.Sub(c, one)
	if !c.IsZero() {
		test.ReportError(t, c, one, a)
	}
}

func BenchmarkElement(b *testing.B) {
	for _, g := range allGroups {
		x := g.RandomElement(rand.Reader)
		y := g.RandomElement(rand.Reader)
		n := g.RandomScalar(rand.Reader)
		name := g.(fmt.Stringer).String()
		b.Run(name+"/Add", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				x.Add(x, y)
			}
		})
		b.Run(name+"/Dbl", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				x.Dbl(x)
			}
		})
		b.Run(name+"/Mul", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				y.Mul(x, n)
			}
		})
		b.Run(name+"/MulGen", func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				x.MulGen(n)
			}
		})
	}
}

// BenchmarkScalar measures Add, Mul and Inv on scalars of every group in
// allGroups, as sub-benchmarks named "<group>/<operation>".
func BenchmarkScalar(b *testing.B) {
	for _, grp := range allGroups {
		s := grp.RandomScalar(rand.Reader)
		u := grp.RandomScalar(rand.Reader)
		prefix := grp.(fmt.Stringer).String()

		b.Run(prefix+"/Add", func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				s.Add(s, u)
			}
		})
		b.Run(prefix+"/Mul", func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				s.Mul(s, u)
			}
		})
		b.Run(prefix+"/Inv", func(sb *testing.B) {
			for i := 0; i < sb.N; i++ {
				u.Inv(s)
			}
		})
	}
}
