package group

import (
	"crypto"
	"crypto/elliptic"
	_ "crypto/sha256"
	_ "crypto/sha512"
	"crypto/subtle"
	"fmt"
	"io"
	"math/big"

	"github.com/cloudflare/circl/ecc/p384"
	"github.com/cloudflare/circl/expander"
)

// Predefined groups over the NIST curves, each wrapping an elliptic.Curve
// implementation in the generic wG adapter.
var (
	// P256 is the group of points on the NIST P-256 curve (crypto/elliptic).
	P256 Group = wG{c: elliptic.P256()}
	// P384 is the group of points on the NIST P-384 curve (circl's p384 package).
	P384 Group = wG{c: p384.P384()}
	// P521 is the group of points on the NIST P-521 curve (crypto/elliptic).
	P521 Group = wG{c: elliptic.P521()}
)

// wG implements Group on top of a short Weierstrass curve exposed through the
// crypto/elliptic Curve interface. All group operations, encodings and
// hash-to-curve parameters are derived from c and c.Params().
type wG struct {
	// c supplies the point arithmetic and the curve constants (P, N, B, Gx, Gy, BitSize).
	c elliptic.Curve
}

// String returns the name reported by the underlying curve parameters.
func (g wG) String() string {
	return g.c.Params().Name
}

// NewElement returns a freshly allocated identity element.
func (g wG) NewElement() Element {
	return g.zeroElement()
}

// NewScalar returns a freshly allocated zero scalar.
func (g wG) NewScalar() Scalar {
	return g.zeroScalar()
}

// Identity returns a freshly allocated identity element, encoded as (0, 0).
func (g wG) Identity() Element {
	return g.zeroElement()
}

// zeroScalar allocates a zero scalar whose byte slice already has the fixed
// scalar width of (BitSize+7)/8 bytes.
func (g wG) zeroScalar() *wScl {
	width := (g.c.Params().BitSize + 7) / 8
	return &wScl{wG: g, k: make([]byte, width)}
}

// zeroElement allocates the identity element with independent zero coordinates.
func (g wG) zeroElement() *wElt {
	return &wElt{wG: g, x: new(big.Int), y: new(big.Int)}
}
// Generator returns a new element holding a copy of the curve's base point.
//
// The coordinates are deliberately copied out of g.c.Params(). Element methods
// such as Set, Neg, CMov and CSelect overwrite e.x and e.y in place. If the
// returned element shared the curve's Gx/Gy big.Ints, those methods would
// silently corrupt the base point seen by every other user of the curve, and
// would race with concurrent readers.
func (g wG) Generator() Element {
	params := g.c.Params()
	return &wElt{g, new(big.Int).Set(params.Gx), new(big.Int).Set(params.Gy)}
}
// RandomElement reads a field-sized seed from rd and hashes it onto the curve
// with HashToElement (no domain separation tag). It panics if rd cannot supply
// the full seed.
func (g wG) RandomElement(rd io.Reader) Element {
	seed := make([]byte, (g.c.Params().BitSize+7)/8)
	// io.ReadFull returns a nil error exactly when the whole seed was filled.
	if _, err := io.ReadFull(rd, seed); err != nil {
		panic(err)
	}
	return g.HashToElement(seed, nil)
}

// RandomScalar reads a field-sized seed from rd and hashes it to a scalar
// with HashToScalar (no domain separation tag). It panics if rd cannot supply
// the full seed.
func (g wG) RandomScalar(rd io.Reader) Scalar {
	seed := make([]byte, (g.c.Params().BitSize+7)/8)
	// io.ReadFull returns a nil error exactly when the whole seed was filled.
	if _, err := io.ReadFull(rd, seed); err != nil {
		panic(err)
	}
	return g.HashToScalar(seed, nil)
}

// RandomNonZeroScalar draws scalars from rd via RandomScalar until one is
// different from zero, and returns it.
func (g wG) RandomNonZeroScalar(rd io.Reader) Scalar {
	zero := g.zeroScalar()
	candidate := g.RandomScalar(rd)
	for candidate.IsEqual(zero) {
		candidate = g.RandomScalar(rd)
	}
	return candidate
}

// cvtElt narrows e to the concrete *wElt type. A nil e yields a fresh identity
// element. It panics with ErrType if e has another dynamic type or belongs to
// a curve with a different bit size.
func (g wG) cvtElt(e Element) *wElt {
	if e == nil {
		return g.zeroElement()
	}
	if ee, ok := e.(*wElt); ok && ee.c.Params().BitSize == g.c.Params().BitSize {
		return ee
	}
	panic(ErrType)
}

// cvtScl narrows s to the concrete *wScl type. A nil s yields a fresh zero
// scalar. It panics with ErrType if s has another dynamic type or belongs to a
// curve with a different bit size.
func (g wG) cvtScl(s Scalar) *wScl {
	if s == nil {
		return g.zeroScalar()
	}
	if ss, ok := s.(*wScl); ok && ss.c.Params().BitSize == g.c.Params().BitSize {
		return ss
	}
	panic(ErrType)
}

// Params reports the encoding sizes for this group, where n = (BitSize+7)/8 is
// the byte length of one field element:
//   - ElementLength: uncompressed point, one prefix byte plus two coordinates;
//   - CompressedElementLength: one prefix byte plus the x-coordinate;
//   - ScalarLength: n bytes.
func (g wG) Params() *Params {
	n := uint((g.c.Params().BitSize + 7) / 8)
	p := new(Params)
	p.ElementLength = 1 + 2*n
	p.CompressedElementLength = 1 + n
	p.ScalarLength = n
	return p
}

// HashToElementNonUniform hashes b, under domain separation tag dst, to a
// single field element and maps it to the curve. Only one field element is
// mapped, so the output is not uniformly distributed. This is the analogue of
// RFC 9380 encode_to_curve; use HashToElement when uniformity is needed.
func (g wG) HashToElementNonUniform(b, dst []byte) Element {
	mapToCurve, hashFn, secLen := g.mapToCurveParams()
	u := make([]big.Int, 1)
	HashToField(u, b, expander.NewExpanderMD(hashFn, dst), g.c.Params().P, secLen)
	return mapToCurve(&u[0])
}

// HashToElement hashes b, under domain separation tag dst, to a curve point.
// Following RFC 9380 hash_to_curve, it derives two field elements, maps each
// to the curve and returns their sum.
func (g wG) HashToElement(b, dst []byte) Element {
	mapToCurve, hashFn, secLen := g.mapToCurveParams()
	var u [2]big.Int
	HashToField(u[:], b, expander.NewExpanderMD(hashFn, dst), g.c.Params().P, secLen)
	q0, q1 := mapToCurve(&u[0]), mapToCurve(&u[1])
	return q0.Add(q0, q1)
}

// HashToScalar hashes b, under domain separation tag dst, to a scalar modulo
// the curve order N. It uses the same hash function and expansion length as
// the curve's hash-to-curve suite.
func (g wG) HashToScalar(b, dst []byte) Scalar {
	_, hashFn, secLen := g.mapToCurveParams()
	var u [1]big.Int
	HashToField(u[:], b, expander.NewExpanderMD(hashFn, dst), g.c.Params().N, secLen)
	out := g.zeroScalar()
	out.fromBig(&u[0])
	return out
}

// wElt is an element of a wG group, held as affine coordinates. The identity
// (point at infinity) is represented by x = y = 0.
type wElt struct {
	wG
	x *big.Int // affine x-coordinate
	y *big.Int // affine y-coordinate
}

// Group returns the group that e belongs to.
func (e *wElt) Group() Group { return e.wG }

// String renders both affine coordinates in hexadecimal, one per line.
func (e *wElt) String() string {
	return fmt.Sprintf("x: 0x%v\ny: 0x%v", e.x.Text(16), e.y.Text(16))
}

// IsIdentity reports whether e is the point at infinity, i.e. (0, 0).
func (e *wElt) IsIdentity() bool {
	return e.x.Sign() == 0 && e.y.Sign() == 0
}

// IsEqual reports whether e and o have identical affine coordinates.
// A nil o is treated as the identity.
func (e *wElt) IsEqual(o Element) bool {
	other := e.cvtElt(o)
	if e.x.Cmp(other.x) != 0 {
		return false
	}
	return e.y.Cmp(other.y) == 0
}

// Set copies the coordinate values of a into e's own big.Ints and returns e.
// A nil a resets e to the identity.
func (e *wElt) Set(a Element) Element {
	src := e.cvtElt(a)
	e.x.Set(src.x)
	e.y.Set(src.y)
	return e
}

// Copy returns a new element with the same coordinates as e.
func (e *wElt) Copy() Element {
	dup := e.wG.zeroElement()
	return dup.Set(e)
}

// CMov overwrites e with a when v == 1 and leaves e unchanged when v == 0.
// The copy is applied with subtle.ConstantTimeCopy on fixed-width coordinate
// encodings. Any other v panics with ErrSelector.
func (e *wElt) CMov(v int, a Element) Element {
	if v != 0 && v != 1 {
		panic(ErrSelector)
	}
	src := e.cvtElt(a)
	width := (e.wG.c.Params().BitSize + 7) / 8
	dstBuf := make([]byte, width)
	srcBuf := make([]byte, width)
	// Process x then y; each pair is {destination, source}.
	for _, pair := range [2][2]*big.Int{{e.x, src.x}, {e.y, src.y}} {
		pair[0].FillBytes(dstBuf)
		pair[1].FillBytes(srcBuf)
		subtle.ConstantTimeCopy(v, dstBuf, srcBuf)
		pair[0].SetBytes(dstBuf)
	}
	return e
}

// CSelect sets e to a when v == 1 and to b when v == 0. Each byte of the
// fixed-width coordinate encodings is chosen with subtle.ConstantTimeSelect.
// Any other v panics with ErrSelector.
func (e *wElt) CSelect(v int, a Element, b Element) Element {
	if v != 0 && v != 1 {
		panic(ErrSelector)
	}
	aa, bb := e.cvtElt(a), e.cvtElt(b)
	width := (e.wG.c.Params().BitSize + 7) / 8
	out := make([]byte, width)
	fromA := make([]byte, width)
	fromB := make([]byte, width)
	// Each row is {destination, candidate for v == 1, candidate for v == 0}.
	rows := [2][3]*big.Int{
		{e.x, aa.x, bb.x},
		{e.y, aa.y, bb.y},
	}
	for _, row := range rows {
		// The destination is encoded first even though every byte of out is
		// overwritten by the selection loop below.
		row[0].FillBytes(out)
		row[1].FillBytes(fromA)
		row[2].FillBytes(fromB)
		for i := range out {
			out[i] = byte(subtle.ConstantTimeSelect(v, int(fromA[i]), int(fromB[i])))
		}
		row[0].SetBytes(out)
	}
	return e
}

// Add sets e to a + b using the curve's affine addition and returns e.
// e receives the new coordinate values returned by the curve.
func (e *wElt) Add(a, b Element) Element {
	p, q := e.cvtElt(a), e.cvtElt(b)
	sumX, sumY := e.c.Add(p.x, p.y, q.x, q.y)
	e.x, e.y = sumX, sumY
	return e
}

// Dbl sets e to 2·a using the curve's point doubling and returns e.
func (e *wElt) Dbl(a Element) Element {
	p := e.cvtElt(a)
	dblX, dblY := e.c.Double(p.x, p.y)
	e.x, e.y = dblX, dblY
	return e
}

// Neg sets e to -a = (x, -y mod p) and returns e. The identity (0, 0) maps
// to itself.
func (e *wElt) Neg(a Element) Element {
	src := e.cvtElt(a)
	e.x.Set(src.x)
	prime := e.c.Params().P
	e.y.Neg(src.y)
	e.y.Mod(e.y, prime)
	return e
}

// Mul sets e to s·a, passing s's big-endian bytes to the curve's ScalarMult,
// and returns e.
func (e *wElt) Mul(a Element, s Scalar) Element {
	base, scalar := e.cvtElt(a), e.cvtScl(s)
	prodX, prodY := e.c.ScalarMult(base.x, base.y, scalar.k)
	e.x, e.y = prodX, prodY
	return e
}

// MulGen sets e to s·G for the curve's base point G and returns e.
func (e *wElt) MulGen(s Scalar) Element {
	scalarBytes := e.cvtScl(s).k
	e.x, e.y = e.c.ScalarBaseMult(scalarBytes)
	return e
}

// MarshalBinary encodes e in uncompressed SEC 1 form (0x04 || x || y) via
// elliptic.Marshal. The identity is encoded as the single byte 0x00.
//
// The coordinates are reduced modulo p into temporaries instead of in place.
// Encoding is a read-only operation, and an element may be shared (for
// example a public key used by several goroutines), so writing to e.x/e.y here
// would be a data race.
func (e *wElt) MarshalBinary() ([]byte, error) {
	if e.IsIdentity() {
		return []byte{0x0}, nil
	}
	p := e.c.Params().P
	x := new(big.Int).Mod(e.x, p)
	y := new(big.Int).Mod(e.y, p)
	return elliptic.Marshal(e.wG.c, x, y), nil
}

// MarshalBinaryCompress encodes e in compressed SEC 1 form (a 0x02 or 0x03
// prefix followed by x) via elliptic.MarshalCompressed. The identity is
// encoded as the single byte 0x00.
//
// As in MarshalBinary, the coordinates are reduced into temporaries so that
// encoding never writes to the receiver and cannot race with concurrent
// readers of a shared element.
func (e *wElt) MarshalBinaryCompress() ([]byte, error) {
	if e.IsIdentity() {
		return []byte{0x0}, nil
	}
	p := e.c.Params().P
	x := new(big.Int).Mod(e.x, p)
	y := new(big.Int).Mod(e.y, p)
	return elliptic.MarshalCompressed(e.wG.c, x, y), nil
}

// UnmarshalBinary decodes b into e. It accepts three encodings:
//   - the single byte 0x00, for the identity;
//   - the compressed form 0x02/0x03 || x;
//   - the uncompressed form 0x04 || x || y.
//
// Point validity is checked by crypto/elliptic's decoders. Any other input
// returns ErrUnmarshal and leaves e unchanged.
func (e *wElt) UnmarshalBinary(b []byte) error {
	size := (e.c.Params().BitSize + 7) / 8
	if len(b) == 0 {
		return ErrUnmarshal
	}
	var x, y *big.Int
	switch prefix, n := b[0], len(b); {
	case n == 1 && prefix == 0x00: // point at infinity
		e.x.SetInt64(0)
		e.y.SetInt64(0)
		return nil
	case n == 1+size && (prefix == 0x02 || prefix == 0x03): // compressed
		x, y = elliptic.UnmarshalCompressed(e.wG.c, b)
	case n == 1+2*size && prefix == 0x04: // uncompressed
		x, y = elliptic.Unmarshal(e.wG.c, b)
	default:
		return ErrUnmarshal
	}
	// Both decoders signal an invalid encoding or off-curve point with nil.
	if x == nil {
		return ErrUnmarshal
	}
	e.x, e.y = x, y
	return nil
}

// wScl is a scalar for a wG group, i.e. an integer modulo the curve order N.
type wScl struct {
	wG
	// k holds the scalar as big-endian bytes, left-padded with zeros to the
	// fixed width (BitSize+7)/8.
	k []byte
}

// Group returns the group that s is a scalar for.
func (s *wScl) Group() Group { return s.wG }

// String renders the fixed-width big-endian bytes of s in hexadecimal.
func (s *wScl) String() string { return fmt.Sprintf("0x%x", s.k) }

// SetUint64 sets s to n reduced modulo N and returns s.
func (s *wScl) SetUint64(n uint64) Scalar {
	v := new(big.Int).SetUint64(n)
	s.fromBig(v)
	return s
}

// SetBigInt sets s to x reduced modulo N (as a non-negative residue) and
// returns s.
func (s *wScl) SetBigInt(x *big.Int) Scalar {
	s.fromBig(x)
	return s
}

// IsZero reports whether every byte of s is zero, comparing in constant time.
func (s *wScl) IsZero() bool {
	zero := make([]byte, (s.wG.c.Params().BitSize+7)/8)
	return subtle.ConstantTimeCompare(s.k, zero) == 1
}

// IsEqual reports, with a constant-time byte comparison, whether s and a hold
// the same encoding. A nil a is treated as zero.
func (s *wScl) IsEqual(a Scalar) bool {
	other := s.cvtScl(a)
	return subtle.ConstantTimeCompare(s.k, other.k) == 1
}

// fromBig sets s to b mod N. big.Int.Mod uses Euclidean division, so negative
// inputs also land in [0, N). It panics if storing the reduced value fails.
func (s *wScl) fromBig(b *big.Int) {
	reduced := new(big.Int)
	reduced.Mod(b, s.c.Params().N)
	err := s.UnmarshalBinary(reduced.Bytes())
	if err != nil {
		panic(err)
	}
}

// Set copies the value of a into a fresh byte slice owned by s and returns s.
// A nil a sets s to zero. It panics if the copy cannot be decoded.
func (s *wScl) Set(a Scalar) Scalar {
	src := s.cvtScl(a)
	err := s.UnmarshalBinary(src.k)
	if err != nil {
		panic(err)
	}
	return s
}

// Copy returns a new scalar with the same value as s.
func (s *wScl) Copy() Scalar {
	dup := s.wG.zeroScalar()
	return dup.Set(s)
}

// CMov overwrites s with a when v == 1 and leaves s unchanged when v == 0,
// using subtle.ConstantTimeCopy. Any other v panics with ErrSelector.
func (s *wScl) CMov(v int, a Scalar) Scalar {
	if v != 0 && v != 1 {
		panic(ErrSelector)
	}
	subtle.ConstantTimeCopy(v, s.k, s.cvtScl(a).k)
	return s
}

// CSelect sets s to a when v == 1 and to b when v == 0, choosing byte by byte
// with subtle.ConstantTimeSelect. Any other v panics with ErrSelector.
func (s *wScl) CSelect(v int, a Scalar, b Scalar) Scalar {
	if v != 0 && v != 1 {
		panic(ErrSelector)
	}
	onOne, onZero := s.cvtScl(a).k, s.cvtScl(b).k
	for i := 0; i < len(s.k); i++ {
		s.k[i] = byte(subtle.ConstantTimeSelect(v, int(onOne[i]), int(onZero[i])))
	}
	return s
}

// Add sets s to (a + b) mod N and returns s.
func (s *wScl) Add(a, b Scalar) Scalar {
	aa, bb := s.cvtScl(a), s.cvtScl(b)
	lhs := new(big.Int).SetBytes(aa.k)
	rhs := new(big.Int).SetBytes(bb.k)
	s.fromBig(lhs.Add(lhs, rhs))
	return s
}

// Sub sets s to (a - b) mod N and returns s. A negative difference is
// brought back into range by fromBig.
func (s *wScl) Sub(a, b Scalar) Scalar {
	aa, bb := s.cvtScl(a), s.cvtScl(b)
	lhs := new(big.Int).SetBytes(aa.k)
	rhs := new(big.Int).SetBytes(bb.k)
	s.fromBig(lhs.Sub(lhs, rhs))
	return s
}

// Mul sets s to (a · b) mod N and returns s.
func (s *wScl) Mul(a, b Scalar) Scalar {
	aa, bb := s.cvtScl(a), s.cvtScl(b)
	lhs := new(big.Int).SetBytes(aa.k)
	rhs := new(big.Int).SetBytes(bb.k)
	s.fromBig(lhs.Mul(lhs, rhs))
	return s
}

// Neg sets s to (-a) mod N and returns s. Zero maps to zero.
func (s *wScl) Neg(a Scalar) Scalar {
	v := new(big.Int).SetBytes(s.cvtScl(a).k)
	s.fromBig(v.Neg(v))
	return s
}

// Inv sets s to a⁻¹ mod N and returns s.
//
// When a has no inverse (a == 0), big.Int.ModInverse returns nil and leaves
// its receiver unchanged, so s ends up as zero.
func (s *wScl) Inv(a Scalar) Scalar {
	src := s.cvtScl(a)
	v := new(big.Int).SetBytes(src.k)
	v.ModInverse(v, s.c.Params().N)
	s.fromBig(v)
	return s
}

// MarshalBinary returns a copy of s as (BitSize+7)/8 big-endian bytes. It
// never fails.
func (s *wScl) MarshalBinary() (data []byte, err error) {
	width := (s.c.Params().BitSize + 7) / 8
	out := make([]byte, width)
	copy(out, s.k)
	return out, nil
}

// UnmarshalBinary decodes b as a big-endian integer and stores it in s, using
// the fixed width (BitSize+7)/8.
//
// Inputs shorter than the fixed width are left-padded with zeros. fromBig
// relies on this when it passes the minimal encoding of a reduced value.
// ErrUnmarshal is returned, and s is left unchanged, when:
//   - b is longer than the fixed width. The previous code computed a negative
//     slice bound here and panicked.
//   - b encodes a value >= N. Accepting it would store a non-canonical scalar,
//     so that, e.g., an encoding of N would compare unequal to zero.
func (s *wScl) UnmarshalBinary(b []byte) error {
	l := (s.c.Params().BitSize + 7) / 8
	if len(b) > l {
		return ErrUnmarshal
	}
	if new(big.Int).SetBytes(b).Cmp(s.c.Params().N) >= 0 {
		return ErrUnmarshal
	}
	k := make([]byte, l)
	copy(k[l-len(b):], b)
	s.k = k
	return nil
}

// mapToCurveParams returns the hash-to-curve parameters for g's curve:
//   - mapping: the simplified SWU map (sswu3mod4Map) specialized with the
//     curve's Z and C2 = sqrt(-Z^3) mod p constants;
//   - h: the hash function handed to expander.NewExpanderMD;
//   - L: bytes expanded per field element, ceil((ceil(log2(p)) + k) / 8) for
//     security level k = 128, 192, 256 (giving 48, 72 and 98).
//
// It panics for curves whose bit size is not 256, 384 or 521.
func (g wG) mapToCurveParams() (mapping func(u *big.Int) *wElt, h crypto.Hash, L uint) {
	z, c2 := new(big.Int), new(big.Int)
	switch g.c.Params().BitSize {
	case 256: // P-256
		z.SetInt64(-10)
		c2.SetString("0x78bc71a02d89ec07214623f6d0f955072c7cc05604a5a6e23ffbf67115fa5301", 0)
		h, L = crypto.SHA256, 48
	case 384: // P-384
		z.SetInt64(-12)
		c2.SetString("0x19877cc1041b7555743c0ae2e3a3e61fb2aaa2e0e87ea557a563d8b598a0940d0a697a9e0b9e92cfaa314f583c9d066", 0)
		h, L = crypto.SHA384, 72
	case 521: // P-521
		z.SetInt64(-4)
		c2.SetInt64(8)
		h, L = crypto.SHA512, 98
	default:
		panic("curve not supported")
	}
	mapping = func(u *big.Int) *wElt { return g.sswu3mod4Map(u, z, c2) }
	return mapping, h, L
}

// sswu3mod4Map maps the field element u to a point on the curve. It uses the
// simplified Shallue-van de Woestijne-Ulas (SSWU) map in the straight-line
// form for fields with q ≡ 3 (mod 4) from RFC 9380, Appendix F.2.1.2. The
// numbered comments below mirror the steps of that listing.
//
// Z is the curve's SSWU constant. C2 must equal sqrt(-Z^3) mod p (see
// mapToCurveParams). The coefficient A is fixed to -3, so the map only applies
// to curves with a = -3; B and p are read from the curve parameters.
//
// The map yields a rational x = xn / xd. It is converted to affine form with a
// modular inverse, and y is reduced into [0, p). The sgn helper reduces its
// argument in place, so u is reduced modulo p as a side effect.
//
// NOTE(review): big.Int arithmetic and the branching cmv helper are not
// constant-time. u comes from hashing caller input (HashToElement and
// HashToElementNonUniform); confirm this is acceptable for the callers'
// threat model.
func (g wG) sswu3mod4Map(u *big.Int, Z, C2 *big.Int) *wElt {
	tv1 := new(big.Int)
	tv2 := new(big.Int)
	tv3 := new(big.Int)
	tv4 := new(big.Int)
	xn := new(big.Int)
	xd := new(big.Int)
	x1n := new(big.Int)
	x2n := new(big.Int)
	gx1 := new(big.Int)
	gxd := new(big.Int)
	y1 := new(big.Int)
	y2 := new(big.Int)
	x := new(big.Int)
	y := new(big.Int)

	// Curve constants: A = -3 is hard-coded; B and the field prime come from
	// the curve parameters.
	A := big.NewInt(-3)
	B := g.c.Params().B
	p := g.c.Params().P
	c1 := new(big.Int)
	c1.Sub(p, big.NewInt(3)).Rsh(c1, 2) // 1.  c1 = (q - 3) / 4

	// Field helpers; every arithmetic result is reduced modulo p.
	// sgn implements sgn0 for a prime field (parity of the reduced value) and
	// reduces its argument in place. cmv picks b when k is true and a otherwise,
	// using a branch rather than a constant-time select.
	add := func(c, a, b *big.Int) { c.Add(a, b).Mod(c, p) }
	mul := func(c, a, b *big.Int) { c.Mul(a, b).Mod(c, p) }
	sqr := func(c, a *big.Int) { c.Mul(a, a).Mod(c, p) }
	exp := func(c, a, b *big.Int) { c.Exp(a, b, p) }
	sgn := func(a *big.Int) uint { a.Mod(a, p); return a.Bit(0) }
	cmv := func(c, a, b *big.Int, k bool) {
		if k {
			c.Set(b)
		} else {
			c.Set(a)
		}
	}

	// RFC 9380 F.2.1.2, steps 1-34. Unnumbered lines compute an operand that
	// the RFC writes inline (e.g. -A or Z * A).
	sqr(tv1, u)                 // 1.  tv1 = u^2
	mul(tv3, Z, tv1)            // 2.  tv3 = Z * tv1
	sqr(tv2, tv3)               // 3.  tv2 = tv3^2
	add(xd, tv2, tv3)           // 4.   xd = tv2 + tv3
	add(x1n, xd, big.NewInt(1)) // 5.  x1n = xd + 1
	mul(x1n, x1n, B)            // 6.  x1n = x1n * B
	tv4.Neg(A)                  //
	mul(xd, tv4, xd)            // 7.   xd = -A * xd
	e1 := xd.Sign() == 0        // 8.   e1 = xd == 0
	mul(tv4, Z, A)              //
	cmv(xd, xd, tv4, e1)        // 9.   xd = CMOV(xd, Z * A, e1)
	sqr(tv2, xd)                // 10. tv2 = xd^2
	mul(gxd, tv2, xd)           // 11. gxd = tv2 * xd
	mul(tv2, A, tv2)            // 12. tv2 = A * tv2
	sqr(gx1, x1n)               // 13. gx1 = x1n^2
	add(gx1, gx1, tv2)          // 14. gx1 = gx1 + tv2
	mul(gx1, gx1, x1n)          // 15. gx1 = gx1 * x1n
	mul(tv2, B, gxd)            // 16. tv2 = B * gxd
	add(gx1, gx1, tv2)          // 17. gx1 = gx1 + tv2
	sqr(tv4, gxd)               // 18. tv4 = gxd^2
	mul(tv2, gx1, gxd)          // 19. tv2 = gx1 * gxd
	mul(tv4, tv4, tv2)          // 20. tv4 = tv4 * tv2
	exp(y1, tv4, c1)            // 21.  y1 = tv4^c1
	mul(y1, y1, tv2)            // 22.  y1 = y1 * tv2
	mul(x2n, tv3, x1n)          // 23. x2n = tv3 * x1n
	mul(y2, y1, C2)             // 24.  y2 = y1 * c2
	mul(y2, y2, tv1)            // 25.  y2 = y2 * tv1
	mul(y2, y2, u)              // 26.  y2 = y2 * u
	sqr(tv2, y1)                // 27. tv2 = y1^2
	mul(tv2, tv2, gxd)          // 28. tv2 = tv2 * gxd
	e2 := tv2.Cmp(gx1) == 0     // 29.  e2 = tv2 == gx1
	cmv(xn, x2n, x1n, e2)       // 30.  xn = CMOV(x2n, x1n, e2)
	cmv(y, y2, y1, e2)          // 31.   y = CMOV(y2, y1, e2)
	e3 := sgn(u) == sgn(y)      // 32.  e3 = sgn0(u) == sgn0(y)
	tv1.Neg(y)                  //
	cmv(y, tv1, y, e3)          // 33.   y = CMOV(-y, y, e3)
	tv1.ModInverse(xd, p)       //
	mul(x, xn, tv1)             // 34. return (xn, xd, y, 1)
	// Step 33 may have left y negative (y = -y); bring it back into [0, p).
	y.Mod(y, p)
	return &wElt{g, x, y}
}
