// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build generate

//go:generate go run gen_fallback_bundle.go

package main

import (
	"bytes"
	"crypto/sha256"
	"flag"
	"fmt"
	"go/format"
	"io"
	"log"
	"mime"
	"net/http"
	"os"
	"sort"
	"time"

	"golang.org/x/crypto/x509roots/nss"
)

// tmpl is the fixed preamble of the generated Go source. It contains the
// "Code generated ... DO NOT EDIT." marker, the package clause, and the
// opening line of the unparsedCertificates slice literal. main appends one
// composite-literal element per root after it, then the closing brace.
const tmpl = "// Code generated by gen_fallback_bundle.go; DO NOT EDIT.\n" +
	"\n" +
	"package bundle\n" +
	"\n" +
	"var unparsedCertificates = []unparsedCertificate{\n"

// Command-line flags. Input is read from -certdata-path when it is
// non-empty, and otherwise fetched from -certdata-url. -output and
// -deroutput name the two generated files.
var (
	// certDataURL is where certdata.txt is fetched from when
	// -certdata-path is empty.
	certDataURL = flag.String(
		"certdata-url",
		"https://hg.mozilla.org/mozilla-central/raw-file/tip/security/nss/lib/ckfw/builtins/certdata.txt",
		"URL to the raw certdata.txt file to parse (certdata-path overrides this, if provided)",
	)

	// certDataPath, when non-empty, names a local certdata.txt that is read
	// instead of fetching certDataURL.
	certDataPath = flag.String(
		"certdata-path",
		"",
		"Path to the NSS certdata.txt file to parse (this overrides certdata-url, if provided)",
	)

	// output is the path of the generated, gofmt'd Go source file.
	output = flag.String(
		"output",
		"fallback/bundle/bundle.go",
		"Path to file to write output to",
	)

	// derOutput is the path of the file holding the concatenated DER
	// encodings of the bundled roots.
	derOutput = flag.String(
		"deroutput",
		"fallback/bundle/bundle.der",
		"Path to file to write output to (DER certificate bundle)",
	)
)

func main() {
	flag.Parse()

	var certdata io.Reader

	if *certDataPath != "" {
		f, err := os.Open(*certDataPath)
		if err != nil {
			log.Fatalf("unable to open %q: %s", *certDataPath, err)
		}
		defer f.Close()
		certdata = f
	} else {
		resp, err := http.Get(*certDataURL)
		if err != nil {
			log.Fatalf("failed to request %q: %s", *certDataURL, err)
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
			log.Fatalf("got non-200 OK status code: %v body: %q", resp.Status, body)
		} else if ct, want := resp.Header.Get("Content-Type"), `text/plain; charset="UTF-8"`; ct != want {
			if mediaType, _, err := mime.ParseMediaType(ct); err != nil {
				log.Fatalf("bad Content-Type header %q: %v", ct, err)
			} else if mediaType != "text/plain" {
				log.Fatalf("got media type %q, want %q", mediaType, "text/plain")
			}
		}
		certdata = resp.Body
	}

	certs, err := nss.Parse(certdata)
	if err != nil {
		log.Fatalf("failed to parse %q: %s", *certDataPath, err)
	}

	if len(certs) == 0 {
		log.Fatal("certdata.txt appears to contain zero roots")
	}

	sort.Slice(certs, func(i, j int) bool {
		// Sort based on the stringified subject (which may not be unique), and
		// break any ties by just sorting on the raw DER (which will be unique,
		// but is expensive). This should produce a stable sorting, which should
		// be mostly readable by a human looking for a specific root or set of
		// roots.
		subjI, subjJ := certs[i].X509.Subject.String(), certs[j].X509.Subject.String()
		if subjI == subjJ {
			return string(certs[i].X509.Raw) < string(certs[j].X509.Raw)
		}
		return subjI < subjJ
	})

	rawCertsData := new(bytes.Buffer)
	goSrcOut := new(bytes.Buffer)
	goSrcOut.WriteString(tmpl)
	for _, c := range certs {
		var constraints []string
		var skip bool
		for _, constraint := range c.Constraints {
			switch t := constraint.(type) {
			case nss.DistrustAfter:
				constraints = append(constraints, fmt.Sprintf("distrustAfter: \"%s\",", time.Time(t).Format(time.RFC3339)))
			default:
				// If we encounter any constraints we don't support, skip the certificate.
				skip = true
				break
			}
		}
		if skip {
			continue
		}

		off := rawCertsData.Len()
		rawCertsData.Write(c.X509.Raw)

		fmt.Fprintf(goSrcOut, "{\ncn: %q,\nsha256Hash: \"%x\",\ncertStartOff: %v,\ncertLength: %v,\n",
			c.X509.Subject.String(),
			sha256.Sum256(c.X509.Raw),
			off,
			len(c.X509.Raw),
		)
		for _, constraint := range constraints {
			fmt.Fprintln(goSrcOut, constraint)
		}
		fmt.Fprintln(goSrcOut, "},")
	}
	fmt.Fprintln(goSrcOut, "}")

	formatted, err := format.Source(goSrcOut.Bytes())
	if err != nil {
		log.Fatalf("failed to format source: %s", err)
	}

	if err := os.WriteFile(*output, formatted, 0644); err != nil {
		log.Fatalf("failed to write to %q: %s", *output, err)
	}

	if err := os.WriteFile(*derOutput, rawCertsData.Bytes(), 0644); err != nil {
		log.Fatalf("failed to write to %q: %s", *output, err)
	}
}
