// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package packet

import (
	"bytes"
	"crypto/rand"
	"crypto/sha1"
	"encoding/hex"
	goerrors "errors"
	"io"
	"testing"

	"github.com/ProtonMail/go-crypto/openpgp/errors"
)

// testReader is an io.Reader over a fixed byte slice that hands out at most
// stride bytes per Read call, so tests can feed a reader input split into
// chunks of a chosen size.
type testReader struct {
	data   []byte // bytes not yet returned to the caller
	stride int    // upper bound on the bytes returned by a single Read
}

// Read copies up to stride bytes (further limited by len(buf) and by the
// bytes remaining) into buf and advances past them. It returns io.EOF
// together with the final chunk, or immediately when no data remains.
func (t *testReader) Read(buf []byte) (n int, err error) {
	chunk := t.data
	if len(chunk) > t.stride {
		chunk = chunk[:t.stride]
	}
	n = copy(buf, chunk)
	t.data = t.data[n:]
	if len(t.data) != 0 {
		return n, nil
	}
	return n, io.EOF
}

// mdcPlaintextHex is a hex-encoded plaintext stream whose final 22 bytes are
// the modification detection code that seMDCReader consumes instead of
// returning (TestMDCReader expects the body to be everything but those bytes).
const mdcPlaintextHex = "cb1362000000000048656c6c6f2c20776f726c6421d314c23d643f478a9a2098811fcb191e7b24b80966a1"

// TestMDCReader checks three things:
//   - For every read stride, seMDCReader returns the stream minus its 22-byte
//     MDC trailer.
//   - Close succeeds on intact input.
//   - After a byte of the body is flipped, reading still succeeds but Close
//     reports errors.ErrMDCHashMismatch.
func TestMDCReader(t *testing.T) {
	mdcPlaintext, err := hex.DecodeString(mdcPlaintextHex)
	if err != nil {
		t.Fatalf("error decoding mdcPlaintextHex: %s", err)
	}

	// Feed the reader in every chunk size from 1 byte up to half the input,
	// so trailer detection is exercised across read boundaries.
	for stride := 1; stride < len(mdcPlaintext)/2; stride++ {
		r := &testReader{data: mdcPlaintext, stride: stride}
		mdcReader := &seMDCReader{in: r, h: sha1.New()}
		body, err := io.ReadAll(mdcReader)
		if err != nil {
			t.Errorf("stride: %d, error: %s", stride, err)
			continue
		}
		if !bytes.Equal(body, mdcPlaintext[:len(mdcPlaintext)-22]) {
			t.Errorf("stride: %d: bad contents %x", stride, body)
			continue
		}

		err = mdcReader.Close()
		if err != nil {
			t.Errorf("stride: %d, error on Close: %s", stride, err)
		}
	}

	// Corrupt a byte inside the body (not the trailer) so the computed hash
	// no longer matches; the test expects the mismatch to surface on Close.
	mdcPlaintext[15] ^= 80

	r := &testReader{data: mdcPlaintext, stride: 2}
	mdcReader := &seMDCReader{in: r, h: sha1.New()}
	_, err = io.ReadAll(mdcReader)
	if err != nil {
		t.Errorf("corruption test, error: %s", err)
		return
	}
	err = mdcReader.Close()
	if err == nil {
		t.Error("corruption: no error")
	} else if !goerrors.Is(err, errors.ErrMDCHashMismatch) {
		t.Errorf("corruption: expected ErrMDCHashMismatch, got: %s", err)
	}
}

// TestSerializeMdc round-trips a short message through
// SerializeSymmetricallyEncrypted with the AEAD flag set to false (per the
// test name, the MDC-protected format). It parses the result back with Read
// and decrypts it with the same all-zero AES-128 key. It then checks that
// closing both the writer and the decrypting reader succeeds and that the
// plaintext matches.
func TestSerializeMdc(t *testing.T) {
	buf := bytes.NewBuffer(nil)
	c := CipherAES128
	key := make([]byte, c.KeySize())

	// NOTE(review): the cipher suite is passed even though the AEAD flag is
	// false; presumably it is ignored in that case — confirm.
	cipherSuite := CipherSuite{
		Cipher: c,
		Mode:   AEADModeOCB,
	}

	w, err := SerializeSymmetricallyEncrypted(buf, c, false, cipherSuite, key, nil)
	if err != nil {
		t.Fatalf("error from SerializeSymmetricallyEncrypted: %s", err)
	}

	contents := []byte("hello world\n")

	if _, err := w.Write(contents); err != nil {
		t.Fatal(err)
	}
	// Check Close's error before parsing buf; otherwise a failure while
	// finishing the encrypted stream would be silently lost.
	if err := w.Close(); err != nil {
		t.Fatalf("error closing encrypting writer: %s", err)
	}

	p, err := Read(buf)
	if err != nil {
		t.Fatalf("error from Read: %s", err)
	}

	se, ok := p.(*SymmetricallyEncrypted)
	if !ok {
		t.Fatal("didn't read a *SymmetricallyEncrypted")
	}

	r, err := se.Decrypt(c, key)
	if err != nil {
		t.Fatalf("error from Decrypt: %s", err)
	}

	contentsCopy := bytes.NewBuffer(nil)
	if _, err := io.Copy(contentsCopy, r); err != nil {
		t.Fatalf("error from io.Copy: %s", err)
	}
	// An MDC hash mismatch is reported by Close rather than during reads
	// (see the corruption check in TestMDCReader), so Close is what
	// actually validates integrity here.
	if err := r.Close(); err != nil {
		t.Errorf("error from Close: %s", err)
	}
	if !bytes.Equal(contentsCopy.Bytes(), contents) {
		t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
	}
}

// AEAD test vector checked by TestAeadRfcVector. It has four parts:
//   - aeadHexKey: a 16-byte key.
//   - aeadHexSeipd: a hex-encoded symmetrically encrypted packet.
//   - aeadHexPlainText: the bytes the packet must decrypt to.
//   - aeadExpectedSalt: the salt expected in the packet header.
//
// Per the test name this is presumably the RFC sample vector — confirm the
// source before editing.
const (
	aeadHexKey       = "1936fc8568980274bb900d8319360c77"
	aeadHexSeipd     = "d26902070306fcb94490bcb98bbdc9d106c6090266940f72e89edc21b5596b1576b101ed0f9ffc6fc6d65bbfd24dcd0790966e6d1e85a30053784cb1d8b6a0699ef12155a7b2ad6258531b57651fd7777912fa95e35d9b40216f69a4c248db28ff4331f1632907399e6ff9"
	aeadHexPlainText = "cb1362000000000048656c6c6f2c20776f726c6421d50e1ce2269a9eddef81032172b7ed7c"
	aeadExpectedSalt = "fcb94490bcb98bbdc9d106c6090266940f72e89edc21b5596b1576b101ed0f9f"
)

// TestAeadRfcVector parses the fixed AEAD packet above. It checks every
// header field (version, cipher, mode, salt, chunk size byte), then decrypts
// the packet and compares the result with the expected plaintext.
func TestAeadRfcVector(t *testing.T) {
	// decodeHex aborts the test on malformed input, since every later check
	// depends on these bytes.
	decodeHex := func(what, s string) []byte {
		t.Helper()
		b, err := hex.DecodeString(s)
		if err != nil {
			t.Fatalf("error in decoding %s: %s", what, err)
		}
		return b
	}

	key := decodeHex("key", aeadHexKey)
	packet := decodeHex("packet", aeadHexSeipd)
	plainText := decodeHex("plaintext", aeadHexPlainText)
	expectedSalt := decodeHex("salt", aeadExpectedSalt)

	p, err := Read(bytes.NewBuffer(packet))
	if err != nil {
		t.Fatalf("error from Read: %s", err)
	}

	se, ok := p.(*SymmetricallyEncrypted)
	if !ok {
		t.Fatal("didn't read a *SymmetricallyEncrypted")
	}

	if se.Version != symmetricallyEncryptedVersionAead {
		t.Errorf("found wrong version, want: %d, got: %d", symmetricallyEncryptedVersionAead, se.Version)
	}

	if se.Cipher != CipherAES128 {
		t.Errorf("found wrong cipher, want: %d, got: %d", CipherAES128, se.Cipher)
	}

	if se.Mode != AEADModeGCM {
		t.Errorf("found wrong mode, want: %d, got: %d", AEADModeGCM, se.Mode)
	}

	if !bytes.Equal(se.Salt[:], expectedSalt) {
		t.Errorf("found wrong salt, want: %x, got: %x", expectedSalt, se.Salt)
	}

	if se.ChunkSizeByte != 0x06 {
		t.Errorf("found wrong chunk size byte, want: %d, got: %d", 0x06, se.ChunkSizeByte)
	}

	// The cipher argument is zero; presumably the AEAD version takes the
	// cipher from the packet header checked above — TODO confirm.
	aeadReader, err := se.Decrypt(CipherFunction(0), key)
	if err != nil {
		t.Fatalf("error from Decrypt: %s", err)
	}

	decrypted, err := io.ReadAll(aeadReader)
	if err != nil {
		t.Fatalf("error when reading: %s", err)
	}

	if err := aeadReader.Close(); err != nil {
		t.Fatalf("error when closing reader: %s", err)
	}

	if !bytes.Equal(decrypted, plainText) {
		t.Errorf("contents not equal got: %x want: %x", decrypted, plainText)
	}
}

// TestAeadEncryptDecrypt runs testSerializeAead once for every combination
// of AES key size and AEAD mode. Each combination is a nested subtest named
// <cipher>/<mode>.
func TestAeadEncryptDecrypt(t *testing.T) {
	cipherByName := map[string]CipherFunction{
		"AES128": CipherAES128,
		"AES192": CipherAES192,
		"AES256": CipherAES256,
	}

	modeByName := map[string]AEADMode{
		"EAX": AEADModeEAX,
		"OCB": AEADModeOCB,
		"GCM": AEADModeGCM,
	}

	for cName, cf := range cipherByName {
		t.Run(cName, func(t *testing.T) {
			for mName, m := range modeByName {
				// A fresh variable per iteration, captured by the subtest.
				suite := CipherSuite{Cipher: cf, Mode: m}
				t.Run(mName, func(t *testing.T) {
					testSerializeAead(t, suite)
				})
			}
		})
	}
}

// testSerializeAead round-trips a short message through an AEAD-protected
// symmetrically encrypted packet built from cipherSuite and a random key.
// It checks the following:
//   - Every step of writing succeeds.
//   - The parsed packet reports the AEAD version and the requested cipher
//     and mode.
//   - The decrypted plaintext matches the input.
//   - Closing the decrypting reader succeeds.
func testSerializeAead(t *testing.T, cipherSuite CipherSuite) {
	buf := bytes.NewBuffer(nil)
	key := make([]byte, cipherSuite.Cipher.KeySize())
	if _, err := rand.Read(key); err != nil {
		t.Fatalf("error generating random key: %s", err)
	}

	// The cipher argument is zero; the version/cipher checks below expect
	// the packet to use cipherSuite instead.
	w, err := SerializeSymmetricallyEncrypted(buf, CipherFunction(0), true, cipherSuite, key, &Config{AEADConfig: &AEADConfig{}})
	if err != nil {
		t.Fatalf("error from SerializeSymmetricallyEncrypted: %s", err)
	}

	contents := []byte("hello world\n")

	if _, err := w.Write(contents); err != nil {
		t.Fatalf("error from Write: %s", err)
	}
	// Check Close's error before parsing buf; otherwise a failure while
	// finishing the encrypted stream would be silently lost.
	if err := w.Close(); err != nil {
		t.Fatalf("error from Close: %s", err)
	}

	p, err := Read(buf)
	if err != nil {
		t.Fatalf("error from Read: %s", err)
	}

	se, ok := p.(*SymmetricallyEncrypted)
	if !ok {
		t.Fatal("didn't read a *SymmetricallyEncrypted")
	}

	if se.Version != symmetricallyEncryptedVersionAead {
		t.Errorf("found wrong version, want: %d, got: %d", symmetricallyEncryptedVersionAead, se.Version)
	}

	if se.Cipher != cipherSuite.Cipher {
		t.Errorf("found wrong cipher, want: %d, got: %d", cipherSuite.Cipher, se.Cipher)
	}

	if se.Mode != cipherSuite.Mode {
		t.Errorf("found wrong mode, want: %d, got: %d", cipherSuite.Mode, se.Mode)
	}

	r, err := se.Decrypt(CipherFunction(0), key)
	if err != nil {
		t.Fatalf("error from Decrypt: %s", err)
	}

	contentsCopy := bytes.NewBuffer(nil)
	if _, err := io.Copy(contentsCopy, r); err != nil {
		t.Fatalf("error from io.Copy: %s", err)
	}
	// Mirror TestAeadRfcVector, which also requires Close to succeed once
	// the decrypted stream has been fully read.
	if err := r.Close(); err != nil {
		t.Errorf("error from reader Close: %s", err)
	}
	if !bytes.Equal(contentsCopy.Bytes(), contents) {
		t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
	}
}
