// Copyright 2015 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// +build go1.8

package config

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	yaml "gopkg.in/yaml.v2"
)

const (
	TLSCAChainPath        = "testdata/tls-ca-chain.pem"
	ServerCertificatePath = "testdata/server.crt"
	ServerKeyPath         = "testdata/server.key"
	ClientCertificatePath = "testdata/client.crt"
	ClientKeyNoPassPath   = "testdata/client-no-pass.key"
	InvalidCA             = "testdata/client-no-pass.key"
	WrongClientCertPath   = "testdata/self-signed-client.crt"
	WrongClientKeyPath    = "testdata/self-signed-client.key"
	EmptyFile             = "testdata/empty"
	MissingCA             = "missing/ca.crt"
	MissingCert           = "missing/cert.crt"
	MissingKey            = "missing/secret.key"

	ExpectedMessage                   = "I'm here to serve you!!!"
	ExpectedError                     = "expected error"
	AuthorizationCredentials          = "theanswertothegreatquestionoflifetheuniverseandeverythingisfortytwo"
	AuthorizationCredentialsFile      = "testdata/bearer.token"
	AuthorizationType                 = "APIKEY"
	BearerToken                       = AuthorizationCredentials
	BearerTokenFile                   = AuthorizationCredentialsFile
	MissingBearerTokenFile            = "missing/bearer.token"
	ExpectedBearer                    = "Bearer " + BearerToken
	ExpectedAuthenticationCredentials = AuthorizationType + " " + BearerToken
	ExpectedUsername                  = "arthurdent"
	ExpectedPassword                  = "42"
	ExpectedAccessToken               = "12345"
)

var invalidHTTPClientConfigs = []struct {
	httpClientConfigFile string
	errMsg               string
}{
	{
		httpClientConfigFile: "testdata/http.conf.bearer-token-and-file-set.bad.yml",
		errMsg:               "at most one of bearer_token & bearer_token_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.empty.bad.yml",
		errMsg:               "at most one of basic_auth, oauth2, bearer_token & bearer_token_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth.too-much.bad.yaml",
		errMsg:               "at most one of basic_auth password & password_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.mix-bearer-and-creds.bad.yaml",
		errMsg:               "authorization is not compatible with bearer_token & bearer_token_file",
	},
	{
		httpClientConfigFile: "testdata/http.conf.auth-creds-and-file-set.too-much.bad.yaml",
		errMsg:               "at most one of authorization credentials & credentials_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth-and-auth-creds.too-much.bad.yaml",
		errMsg:               "at most one of basic_auth, oauth2 & authorization must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.basic-auth-and-oauth2.too-much.bad.yaml",
		errMsg:               "at most one of basic_auth, oauth2 & authorization must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.auth-creds-no-basic.bad.yaml",
		errMsg:               `authorization type cannot be set to "basic", use "basic_auth" instead`,
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-secret-and-file-set.bad.yml",
		errMsg:               "at most one of oauth2 client_secret & client_secret_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-no-client-id.bad.yaml",
		errMsg:               "oauth2 client_id must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-no-client-secret.bad.yaml",
		errMsg:               "either oauth2 client_secret or client_secret_file must be configured",
	},
	{
		httpClientConfigFile: "testdata/http.conf.oauth2-no-token-url.bad.yaml",
		errMsg:               "oauth2 token_url must be configured",
	},
}

func newTestServer(handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, error) {
	testServer := httptest.NewUnstartedServer(http.HandlerFunc(handler))

	tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath)
	if err != nil {
		return nil, fmt.Errorf("Can't read %s", TLSCAChainPath)
	}
	serverCertificate, err := tls.LoadX509KeyPair(ServerCertificatePath, ServerKeyPath)
	if err != nil {
		return nil, fmt.Errorf("Can't load X509 key pair %s - %s", ServerCertificatePath, ServerKeyPath)
	}

	rootCAs := x509.NewCertPool()
	rootCAs.AppendCertsFromPEM(tlsCAChain)

	testServer.TLS = &tls.Config{
		Certificates: make([]tls.Certificate, 1),
		RootCAs:      rootCAs,
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    rootCAs}
	testServer.TLS.Certificates[0] = serverCertificate

	testServer.StartTLS()

	return testServer, nil
}

func TestNewClientFromConfig(t *testing.T) {
	var newClientValidConfig = []struct {
		clientConfig HTTPClientConfig
		handler      func(w http.ResponseWriter, r *http.Request)
	}{
		{
			clientConfig: HTTPClientConfig{
				TLSConfig: TLSConfig{
					CAFile:             "",
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: true},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				fmt.Fprint(w, ExpectedMessage)
			},
		}, {
			clientConfig: HTTPClientConfig{
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				fmt.Fprint(w, ExpectedMessage)
			},
		}, {
			clientConfig: HTTPClientConfig{
				BearerToken: BearerToken,
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				bearer := r.Header.Get("Authorization")
				if bearer != ExpectedBearer {
					fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
						ExpectedBearer, bearer)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				BearerTokenFile: BearerTokenFile,
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				bearer := r.Header.Get("Authorization")
				if bearer != ExpectedBearer {
					fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
						ExpectedBearer, bearer)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				Authorization: &Authorization{Credentials: BearerToken},
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				bearer := r.Header.Get("Authorization")
				if bearer != ExpectedBearer {
					fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
						ExpectedBearer, bearer)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				Authorization: &Authorization{CredentialsFile: AuthorizationCredentialsFile, Type: AuthorizationType},
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				bearer := r.Header.Get("Authorization")
				if bearer != ExpectedAuthenticationCredentials {
					fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
						ExpectedAuthenticationCredentials, bearer)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				Authorization: &Authorization{
					Credentials: AuthorizationCredentials,
					Type:        AuthorizationType,
				},
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				bearer := r.Header.Get("Authorization")
				if bearer != ExpectedAuthenticationCredentials {
					fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
						ExpectedAuthenticationCredentials, bearer)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				Authorization: &Authorization{
					CredentialsFile: BearerTokenFile,
				},
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				bearer := r.Header.Get("Authorization")
				if bearer != ExpectedBearer {
					fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
						ExpectedBearer, bearer)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				BasicAuth: &BasicAuth{
					Username: ExpectedUsername,
					Password: ExpectedPassword,
				},
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				username, password, ok := r.BasicAuth()
				if !ok {
					fmt.Fprintf(w, "The Authorization header wasn't set")
				} else if ExpectedUsername != username {
					fmt.Fprintf(w, "The expected username (%s) differs from the obtained username (%s).", ExpectedUsername, username)
				} else if ExpectedPassword != password {
					fmt.Fprintf(w, "The expected password (%s) differs from the obtained password (%s).", ExpectedPassword, password)
				} else {
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				FollowRedirects: true,
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				switch r.URL.Path {
				case "/redirected":
					fmt.Fprint(w, ExpectedMessage)
				default:
					w.Header().Set("Location", "/redirected")
					w.WriteHeader(http.StatusFound)
					fmt.Fprint(w, "It should follow the redirect.")
				}
			},
		}, {
			clientConfig: HTTPClientConfig{
				FollowRedirects: false,
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				switch r.URL.Path {
				case "/redirected":
					fmt.Fprint(w, "The redirection was followed.")
				default:
					w.Header().Set("Location", "/redirected")
					w.WriteHeader(http.StatusFound)
					fmt.Fprint(w, ExpectedMessage)
				}
			},
		},
		{
			clientConfig: HTTPClientConfig{
				OAuth2: &OAuth2{
					ClientID:     "ExpectedUsername",
					ClientSecret: "ExpectedPassword",
					TLSConfig: TLSConfig{
						CAFile:             TLSCAChainPath,
						CertFile:           ClientCertificatePath,
						KeyFile:            ClientKeyNoPassPath,
						ServerName:         "",
						InsecureSkipVerify: false},
				},
				TLSConfig: TLSConfig{
					CAFile:             TLSCAChainPath,
					CertFile:           ClientCertificatePath,
					KeyFile:            ClientKeyNoPassPath,
					ServerName:         "",
					InsecureSkipVerify: false},
			},
			handler: func(w http.ResponseWriter, r *http.Request) {
				switch r.URL.Path {
				case "/token":
					res, _ := json.Marshal(oauth2TestServerResponse{
						AccessToken: ExpectedAccessToken,
						TokenType:   "Bearer",
					})
					w.Header().Add("Content-Type", "application/json")
					_, _ = w.Write(res)

				default:
					authorization := r.Header.Get("Authorization")
					if authorization != "Bearer "+ExpectedAccessToken {
						fmt.Fprintf(w, "Expected Authorization header %q, got %q", "Bearer "+ExpectedAccessToken, authorization)
					} else {
						fmt.Fprint(w, ExpectedMessage)
					}
				}
			},
		},
	}

	for _, validConfig := range newClientValidConfig {
		testServer, err := newTestServer(validConfig.handler)
		if err != nil {
			t.Fatal(err.Error())
		}
		defer testServer.Close()

		if validConfig.clientConfig.OAuth2 != nil {
			// We don't have access to the test server's URL when configuring the test cases,
			// so it has to be specified here.
			validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
		}

		err = validConfig.clientConfig.Validate()
		if err != nil {
			t.Fatal(err.Error())
		}
		client, err := NewClientFromConfig(validConfig.clientConfig, "test")
		if err != nil {
			t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
			continue
		}

		response, err := client.Get(testServer.URL)
		if err != nil {
			t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
			continue
		}

		message, err := ioutil.ReadAll(response.Body)
		response.Body.Close()
		if err != nil {
			t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig)
			continue
		}

		trimMessage := strings.TrimSpace(string(message))
		if ExpectedMessage != trimMessage {
			t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v",
				ExpectedMessage, trimMessage, validConfig.clientConfig)
		}
	}
}

func TestNewClientFromInvalidConfig(t *testing.T) {
	var newClientInvalidConfig = []struct {
		clientConfig HTTPClientConfig
		errorMsg     string
	}{
		{
			clientConfig: HTTPClientConfig{
				TLSConfig: TLSConfig{
					CAFile:             MissingCA,
					InsecureSkipVerify: true},
			},
			errorMsg: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA),
		},
		{
			clientConfig: HTTPClientConfig{
				TLSConfig: TLSConfig{
					CAFile:             InvalidCA,
					InsecureSkipVerify: true},
			},
			errorMsg: fmt.Sprintf("unable to use specified CA cert %s", InvalidCA),
		},
	}

	for _, invalidConfig := range newClientInvalidConfig {
		client, err := NewClientFromConfig(invalidConfig.clientConfig, "test")
		if client != nil {
			t.Errorf("A client instance was returned instead of nil using this config: %+v", invalidConfig.clientConfig)
		}
		if err == nil {
			t.Errorf("No error was returned using this config: %+v", invalidConfig.clientConfig)
		}
		if !strings.Contains(err.Error(), invalidConfig.errorMsg) {
			t.Errorf("Expected error %q does not contain %q", err.Error(), invalidConfig.errorMsg)
		}
	}
}

func TestCustomDialContextFunc(t *testing.T) {
	dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
		return nil, errors.New(ExpectedError)
	}

	cfg := HTTPClientConfig{}
	client, err := NewClientFromConfig(cfg, "test", WithDialContextFunc(dialFn))
	if err != nil {
		t.Fatalf("Can't create a client from this config: %+v", cfg)
	}

	_, err = client.Get("http://localhost")
	if err == nil || !strings.Contains(err.Error(), ExpectedError) {
		t.Errorf("Expected error %q but got %q", ExpectedError, err)
	}
}

func TestCustomIdleConnTimeout(t *testing.T) {
	timeout := time.Second * 5

	cfg := HTTPClientConfig{}
	rt, err := NewRoundTripperFromConfig(cfg, "test", WithIdleConnTimeout(timeout))
	if err != nil {
		t.Fatalf("Can't create a round-tripper from this config: %+v", cfg)
	}

	transport, ok := rt.(*http.Transport)
	if !ok {
		t.Fatalf("Unexpected transport: %+v", transport)
	}

	if transport.IdleConnTimeout != timeout {
		t.Fatalf("Unexpected idle connection timeout: %+v", timeout)
	}
}

func TestMissingBearerAuthFile(t *testing.T) {
	cfg := HTTPClientConfig{
		BearerTokenFile: MissingBearerTokenFile,
		TLSConfig: TLSConfig{
			CAFile:             TLSCAChainPath,
			CertFile:           ClientCertificatePath,
			KeyFile:            ClientKeyNoPassPath,
			ServerName:         "",
			InsecureSkipVerify: false},
	}
	handler := func(w http.ResponseWriter, r *http.Request) {
		bearer := r.Header.Get("Authorization")
		if bearer != ExpectedBearer {
			fmt.Fprintf(w, "The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
				ExpectedBearer, bearer)
		} else {
			fmt.Fprint(w, ExpectedMessage)
		}
	}

	testServer, err := newTestServer(handler)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer testServer.Close()

	client, err := NewClientFromConfig(cfg, "test")
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.Get(testServer.URL)
	if err == nil {
		t.Fatal("No error is returned here")
	}

	if !strings.Contains(err.Error(), "unable to read authorization credentials file missing/bearer.token: open missing/bearer.token: no such file or directory") {
		t.Fatal("wrong error message being returned")
	}
}

func TestBearerAuthRoundTripper(t *testing.T) {
	const (
		newBearerToken = "goodbyeandthankyouforthefish"
	)

	fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) {
		bearer := req.Header.Get("Authorization")
		if bearer != ExpectedBearer {
			t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
				ExpectedBearer, bearer)
		}
	}, nil, nil)

	// Normal flow.
	bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", BearerToken, fakeRoundTripper)
	request, _ := http.NewRequest("GET", "/hitchhiker", nil)
	request.Header.Set("User-Agent", "Douglas Adams mind")
	_, err := bearerAuthRoundTripper.RoundTrip(request)
	if err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}

	// Should honor already Authorization header set.
	bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", newBearerToken, fakeRoundTripper)
	request, _ = http.NewRequest("GET", "/hitchhiker", nil)
	request.Header.Set("Authorization", ExpectedBearer)
	_, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request)
	if err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}
}

func TestBearerAuthFileRoundTripper(t *testing.T) {
	fakeRoundTripper := NewRoundTripCheckRequest(func(req *http.Request) {
		bearer := req.Header.Get("Authorization")
		if bearer != ExpectedBearer {
			t.Errorf("The expected Bearer Authorization (%s) differs from the obtained Bearer Authorization (%s)",
				ExpectedBearer, bearer)
		}
	}, nil, nil)

	// Normal flow.
	bearerAuthRoundTripper := NewAuthorizationCredentialsFileRoundTripper("Bearer", BearerTokenFile, fakeRoundTripper)
	request, _ := http.NewRequest("GET", "/hitchhiker", nil)
	request.Header.Set("User-Agent", "Douglas Adams mind")
	_, err := bearerAuthRoundTripper.RoundTrip(request)
	if err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}

	// Should honor already Authorization header set.
	bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsFileRoundTripper("Bearer", MissingBearerTokenFile, fakeRoundTripper)
	request, _ = http.NewRequest("GET", "/hitchhiker", nil)
	request.Header.Set("Authorization", ExpectedBearer)
	_, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request)
	if err != nil {
		t.Errorf("unexpected error while executing RoundTrip: %s", err.Error())
	}
}

func TestTLSConfig(t *testing.T) {
	configTLSConfig := TLSConfig{
		CAFile:             TLSCAChainPath,
		CertFile:           ClientCertificatePath,
		KeyFile:            ClientKeyNoPassPath,
		ServerName:         "localhost",
		InsecureSkipVerify: false}

	tlsCAChain, err := ioutil.ReadFile(TLSCAChainPath)
	if err != nil {
		t.Fatalf("Can't read the CA certificate chain (%s)",
			TLSCAChainPath)
	}
	rootCAs := x509.NewCertPool()
	rootCAs.AppendCertsFromPEM(tlsCAChain)

	expectedTLSConfig := &tls.Config{
		RootCAs:            rootCAs,
		ServerName:         configTLSConfig.ServerName,
		InsecureSkipVerify: configTLSConfig.InsecureSkipVerify}

	tlsConfig, err := NewTLSConfig(&configTLSConfig)
	if err != nil {
		t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err)
	}

	clientCertificate, err := tls.LoadX509KeyPair(ClientCertificatePath, ClientKeyNoPassPath)
	if err != nil {
		t.Fatalf("Can't load the client key pair ('%s' and '%s'). Reason: %s",
			ClientCertificatePath, ClientKeyNoPassPath, err)
	}
	cert, err := tlsConfig.GetClientCertificate(nil)
	if err != nil {
		t.Fatalf("unexpected error returned by tlsConfig.GetClientCertificate(): %s", err)
	}
	if !reflect.DeepEqual(cert, &clientCertificate) {
		t.Fatalf("Unexpected client certificate result: \n\n%+v\n expected\n\n%+v", cert, clientCertificate)
	}

	// tlsConfig.rootCAs.LazyCerts contains functions getCert() in go 1.16, which are
	// never equal. Compare the Subjects instead.
	if !reflect.DeepEqual(tlsConfig.RootCAs.Subjects(), expectedTLSConfig.RootCAs.Subjects()) {
		t.Fatalf("Unexpected RootCAs result: \n\n%+v\n expected\n\n%+v", tlsConfig.RootCAs.Subjects(), expectedTLSConfig.RootCAs.Subjects())
	}
	tlsConfig.RootCAs = nil
	expectedTLSConfig.RootCAs = nil

	// Non-nil functions are never equal.
	tlsConfig.GetClientCertificate = nil

	if !reflect.DeepEqual(tlsConfig, expectedTLSConfig) {
		t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", tlsConfig, expectedTLSConfig)
	}
}

func TestTLSConfigEmpty(t *testing.T) {
	configTLSConfig := TLSConfig{
		InsecureSkipVerify: true,
	}

	expectedTLSConfig := &tls.Config{
		InsecureSkipVerify: configTLSConfig.InsecureSkipVerify,
	}

	tlsConfig, err := NewTLSConfig(&configTLSConfig)
	if err != nil {
		t.Fatalf("Can't create a new TLS Config from a configuration (%s).", err)
	}

	if !reflect.DeepEqual(tlsConfig, expectedTLSConfig) {
		t.Fatalf("Unexpected TLS Config result: \n\n%+v\n expected\n\n%+v", tlsConfig, expectedTLSConfig)
	}
}

func TestTLSConfigInvalidCA(t *testing.T) {
	var invalidTLSConfig = []struct {
		configTLSConfig TLSConfig
		errorMessage    string
	}{
		{
			configTLSConfig: TLSConfig{
				CAFile:             MissingCA,
				CertFile:           "",
				KeyFile:            "",
				ServerName:         "",
				InsecureSkipVerify: false},
			errorMessage: fmt.Sprintf("unable to load specified CA cert %s:", MissingCA),
		}, {
			configTLSConfig: TLSConfig{
				CAFile:             "",
				CertFile:           MissingCert,
				KeyFile:            ClientKeyNoPassPath,
				ServerName:         "",
				InsecureSkipVerify: false},
			errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", MissingCert, ClientKeyNoPassPath),
		}, {
			configTLSConfig: TLSConfig{
				CAFile:             "",
				CertFile:           ClientCertificatePath,
				KeyFile:            MissingKey,
				ServerName:         "",
				InsecureSkipVerify: false},
			errorMessage: fmt.Sprintf("unable to use specified client cert (%s) & key (%s):", ClientCertificatePath, MissingKey),
		},
	}

	for _, anInvalididTLSConfig := range invalidTLSConfig {
		tlsConfig, err := NewTLSConfig(&anInvalididTLSConfig.configTLSConfig)
		if tlsConfig != nil && err == nil {
			t.Errorf("The TLS Config could be created even with this %+v", anInvalididTLSConfig.configTLSConfig)
			continue
		}
		if !strings.Contains(err.Error(), anInvalididTLSConfig.errorMessage) {
			t.Errorf("The expected error should contain %s, but got %s", anInvalididTLSConfig.errorMessage, err)
		}
	}
}

func TestBasicAuthNoPassword(t *testing.T) {
	cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.no-password.yaml")
	if err != nil {
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	client, err := NewClientFromConfig(*cfg, "test")
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	rt, ok := client.Transport.(*basicAuthRoundTripper)
	if !ok {
		t.Fatalf("Error casting to basic auth transport, %v", client.Transport)
	}

	if rt.username != "user" {
		t.Errorf("Bad HTTP client username: %s", rt.username)
	}
	if string(rt.password) != "" {
		t.Errorf("Expected empty HTTP client password: %s", rt.password)
	}
	if string(rt.passwordFile) != "" {
		t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile)
	}
}

func TestBasicAuthNoUsername(t *testing.T) {
	cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.no-username.yaml")
	if err != nil {
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	client, err := NewClientFromConfig(*cfg, "test")
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	rt, ok := client.Transport.(*basicAuthRoundTripper)
	if !ok {
		t.Fatalf("Error casting to basic auth transport, %v", client.Transport)
	}

	if rt.username != "" {
		t.Errorf("Got unexpected username: %s", rt.username)
	}
	if string(rt.password) != "secret" {
		t.Errorf("Unexpected HTTP client password: %s", string(rt.password))
	}
	if string(rt.passwordFile) != "" {
		t.Errorf("Expected empty HTTP client passwordFile: %s", rt.passwordFile)
	}
}

func TestBasicAuthPasswordFile(t *testing.T) {
	cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.basic-auth.good.yaml")
	if err != nil {
		t.Fatalf("Error loading HTTP client config: %v", err)
	}
	client, err := NewClientFromConfig(*cfg, "test")
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	rt, ok := client.Transport.(*basicAuthRoundTripper)
	if !ok {
		t.Fatalf("Error casting to basic auth transport, %v", client.Transport)
	}

	if rt.username != "user" {
		t.Errorf("Bad HTTP client username: %s", rt.username)
	}
	if string(rt.password) != "" {
		t.Errorf("Bad HTTP client password: %s", rt.password)
	}
	if string(rt.passwordFile) != "testdata/basic-auth-password" {
		t.Errorf("Bad HTTP client passwordFile: %s", rt.passwordFile)
	}
}

func getCertificateBlobs(t *testing.T) map[string][]byte {
	files := []string{
		TLSCAChainPath,
		ClientCertificatePath,
		ClientKeyNoPassPath,
		ServerCertificatePath,
		ServerKeyPath,
		WrongClientCertPath,
		WrongClientKeyPath,
		EmptyFile,
	}
	bs := make(map[string][]byte, len(files)+1)
	for _, f := range files {
		b, err := ioutil.ReadFile(f)
		if err != nil {
			t.Fatal(err)
		}
		bs[f] = b
	}

	return bs
}

func writeCertificate(bs map[string][]byte, src string, dst string) {
	b, ok := bs[src]
	if !ok {
		panic(fmt.Sprintf("Couldn't find %q in bs", src))
	}
	if err := ioutil.WriteFile(dst, b, 0664); err != nil {
		panic(err)
	}
}

func TestTLSRoundTripper(t *testing.T) {
	bs := getCertificateBlobs(t)

	tmpDir, err := ioutil.TempDir("", "tlsroundtripper")
	if err != nil {
		t.Fatal("Failed to create tmp dir", err)
	}
	defer os.RemoveAll(tmpDir)

	ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")

	handler := func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, ExpectedMessage)
	}
	testServer, err := newTestServer(handler)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer testServer.Close()

	testCases := []struct {
		ca   string
		cert string
		key  string

		errMsg string
	}{
		{
			// Valid certs.
			ca:   TLSCAChainPath,
			cert: ClientCertificatePath,
			key:  ClientKeyNoPassPath,
		},
		{
			// CA not matching.
			ca:   ClientCertificatePath,
			cert: ClientCertificatePath,
			key:  ClientKeyNoPassPath,

			errMsg: "certificate signed by unknown authority",
		},
		{
			// Invalid client cert+key.
			ca:   TLSCAChainPath,
			cert: WrongClientCertPath,
			key:  WrongClientKeyPath,

			errMsg: "remote error: tls",
		},
		{
			// CA file empty
			ca:   EmptyFile,
			cert: ClientCertificatePath,
			key:  ClientKeyNoPassPath,

			errMsg: "unable to use specified CA cert",
		},
		{
			// cert file empty
			ca:   TLSCAChainPath,
			cert: EmptyFile,
			key:  ClientKeyNoPassPath,

			errMsg: "failed to find any PEM data in certificate input",
		},
		{
			// key file empty
			ca:   TLSCAChainPath,
			cert: ClientCertificatePath,
			key:  EmptyFile,

			errMsg: "failed to find any PEM data in key input",
		},
		{
			// Valid certs again.
			ca:   TLSCAChainPath,
			cert: ClientCertificatePath,
			key:  ClientKeyNoPassPath,
		},
	}

	cfg := HTTPClientConfig{
		TLSConfig: TLSConfig{
			CAFile:             ca,
			CertFile:           cert,
			KeyFile:            key,
			InsecureSkipVerify: false},
	}

	var c *http.Client
	for i, tc := range testCases {
		tc := tc
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			writeCertificate(bs, tc.ca, ca)
			writeCertificate(bs, tc.cert, cert)
			writeCertificate(bs, tc.key, key)
			if c == nil {
				c, err = NewClientFromConfig(cfg, "test")
				if err != nil {
					t.Fatalf("Error creating HTTP Client: %v", err)
				}
			}

			req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
			if err != nil {
				t.Fatalf("Error creating HTTP request: %v", err)
			}
			r, err := c.Do(req)
			if len(tc.errMsg) > 0 {
				if err == nil {
					r.Body.Close()
					t.Fatalf("Could connect to the test server.")
				}
				if !strings.Contains(err.Error(), tc.errMsg) {
					t.Fatalf("Expected error message to contain %q, got %q", tc.errMsg, err)
				}
				return
			}

			if err != nil {
				t.Fatalf("Can't connect to the test server")
			}

			b, err := ioutil.ReadAll(r.Body)
			r.Body.Close()
			if err != nil {
				t.Errorf("Can't read the server response body")
			}

			got := strings.TrimSpace(string(b))
			if ExpectedMessage != got {
				t.Errorf("The expected message %q differs from the obtained message %q", ExpectedMessage, got)
			}
		})
	}
}

func TestTLSRoundTripperRaces(t *testing.T) {
	bs := getCertificateBlobs(t)

	tmpDir, err := ioutil.TempDir("", "tlsroundtripper")
	if err != nil {
		t.Fatal("Failed to create tmp dir", err)
	}
	defer os.RemoveAll(tmpDir)

	ca, cert, key := filepath.Join(tmpDir, "ca"), filepath.Join(tmpDir, "cert"), filepath.Join(tmpDir, "key")

	handler := func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, ExpectedMessage)
	}
	testServer, err := newTestServer(handler)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer testServer.Close()

	cfg := HTTPClientConfig{
		TLSConfig: TLSConfig{
			CAFile:             ca,
			CertFile:           cert,
			KeyFile:            key,
			InsecureSkipVerify: false},
	}

	var c *http.Client
	writeCertificate(bs, TLSCAChainPath, ca)
	writeCertificate(bs, ClientCertificatePath, cert)
	writeCertificate(bs, ClientKeyNoPassPath, key)
	c, err = NewClientFromConfig(cfg, "test")
	if err != nil {
		t.Fatalf("Error creating HTTP Client: %v", err)
	}

	var wg sync.WaitGroup
	ch := make(chan struct{})
	var total, ok int64
	// Spawn 10 Go routines polling the server concurrently.
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for {
				select {
				case <-ch:
					return
				default:
					atomic.AddInt64(&total, 1)
					r, err := c.Get(testServer.URL)
					if err == nil {
						r.Body.Close()
						atomic.AddInt64(&ok, 1)
					}
				}
			}
		}()
	}

	// Change the CA file every 10ms for 1 second.
	wg.Add(1)
	go func() {
		defer wg.Done()
		i := 0
		for {
			tick := time.NewTicker(10 * time.Millisecond)
			<-tick.C
			if i%2 == 0 {
				writeCertificate(bs, ClientCertificatePath, ca)
			} else {
				writeCertificate(bs, TLSCAChainPath, ca)
			}
			i++
			if i > 100 {
				close(ch)
				return
			}
		}
	}()

	wg.Wait()
	if ok == total {
		t.Fatalf("Expecting some requests to fail but got %d/%d successful requests", ok, total)
	}
}

func TestHideHTTPClientConfigSecrets(t *testing.T) {
	c, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml")
	if err != nil {
		t.Errorf("Error parsing %s: %s", "testdata/http.conf.good.yml", err)
	}

	// String method must not reveal authentication credentials.
	s := c.String()
	if strings.Contains(s, "mysecret") {
		t.Fatal("http client config's String method reveals authentication credentials.")
	}
}

func TestDefaultFollowRedirect(t *testing.T) {
	cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml")
	if err != nil {
		t.Errorf("Error loading HTTP client config: %v", err)
	}
	if !cfg.FollowRedirects {
		t.Errorf("follow_redirects should be true")
	}
}

func TestValidateHTTPConfig(t *testing.T) {
	cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml")
	if err != nil {
		t.Errorf("Error loading HTTP client config: %v", err)
	}
	err = cfg.Validate()
	if err != nil {
		t.Fatalf("Error validating %s: %s", "testdata/http.conf.good.yml", err)
	}
}

func TestInvalidHTTPConfigs(t *testing.T) {
	for _, ee := range invalidHTTPClientConfigs {
		_, _, err := LoadHTTPConfigFile(ee.httpClientConfigFile)
		if err == nil {
			t.Error("Expected error with config but got none")
			continue
		}
		if !strings.Contains(err.Error(), ee.errMsg) {
			t.Errorf("Expected error for invalid HTTP client configuration to contain %q but got: %s", ee.errMsg, err)
		}
	}
}

// LoadHTTPConfig parses the YAML input s into a HTTPClientConfig.
func LoadHTTPConfig(s string) (*HTTPClientConfig, error) {
	cfg := &HTTPClientConfig{}
	err := yaml.UnmarshalStrict([]byte(s), cfg)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}

// LoadHTTPConfigFile parses the given YAML file into a HTTPClientConfig.
func LoadHTTPConfigFile(filename string) (*HTTPClientConfig, []byte, error) {
	content, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, nil, err
	}
	cfg, err := LoadHTTPConfig(string(content))
	if err != nil {
		return nil, nil, err
	}
	return cfg, content, nil
}

type roundTrip struct {
	theResponse *http.Response
	theError    error
}

func (rt *roundTrip) RoundTrip(r *http.Request) (*http.Response, error) {
	return rt.theResponse, rt.theError
}

type roundTripCheckRequest struct {
	checkRequest func(*http.Request)
	roundTrip
}

func (rt *roundTripCheckRequest) RoundTrip(r *http.Request) (*http.Response, error) {
	rt.checkRequest(r)
	return rt.theResponse, rt.theError
}

// NewRoundTripCheckRequest creates a new instance of a type that implements http.RoundTripper,
// which before returning theResponse and theError, executes checkRequest against a http.Request.
func NewRoundTripCheckRequest(checkRequest func(*http.Request), theResponse *http.Response, theError error) http.RoundTripper {
	return &roundTripCheckRequest{
		checkRequest: checkRequest,
		roundTrip: roundTrip{
			theResponse: theResponse,
			theError:    theError}}
}

type oauth2TestServerResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
}

func TestOAuth2(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		res, _ := json.Marshal(oauth2TestServerResponse{
			AccessToken: "12345",
			TokenType:   "Bearer",
		})
		w.Header().Add("Content-Type", "application/json")
		_, _ = w.Write(res)
	}))
	defer ts.Close()

	var yamlConfig = fmt.Sprintf(`
client_id: 1
client_secret: 2
scopes:
 - A
 - B
token_url: %s
endpoint_params:
 hi: hello
`, ts.URL)
	expectedConfig := OAuth2{
		ClientID:       "1",
		ClientSecret:   "2",
		Scopes:         []string{"A", "B"},
		EndpointParams: map[string]string{"hi": "hello"},
		TokenURL:       ts.URL,
	}

	var unmarshalledConfig OAuth2
	err := yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig)
	if err != nil {
		t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
	}
	if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
		t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
	}

	rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)

	client := http.Client{
		Transport: rt,
	}
	resp, _ := client.Get(ts.URL)

	authorization := resp.Request.Header.Get("Authorization")
	if authorization != "Bearer 12345" {
		t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
	}
}

func TestOAuth2WithFile(t *testing.T) {
	var expectedAuth *string
	var previousAuth string
	tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		auth := r.Header.Get("Authorization")
		if auth != *expectedAuth {
			t.Fatalf("bad auth, expected %s, got %s", *expectedAuth, auth)
		}
		if auth == previousAuth {
			t.Fatal("token endpoint called twice")
		}
		previousAuth = auth
		res, _ := json.Marshal(oauth2TestServerResponse{
			AccessToken: "12345",
			TokenType:   "Bearer",
		})
		w.Header().Add("Content-Type", "application/json")
		_, _ = w.Write(res)
	}))
	defer tokenTS.Close()
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		auth := r.Header.Get("Authorization")
		if auth != "Bearer 12345" {
			t.Fatalf("bad auth, expected %s, got %s", "Bearer 12345", auth)
		}
		fmt.Fprintln(w, "Hello, client")
	}))
	defer ts.Close()

	secretFile, err := ioutil.TempFile("", "oauth2_secret")
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(secretFile.Name())

	var yamlConfig = fmt.Sprintf(`
client_id: 1
client_secret_file: %s
scopes:
 - A
 - B
token_url: %s
endpoint_params:
 hi: hello
`, secretFile.Name(), tokenTS.URL)
	expectedConfig := OAuth2{
		ClientID:         "1",
		ClientSecretFile: secretFile.Name(),
		Scopes:           []string{"A", "B"},
		EndpointParams:   map[string]string{"hi": "hello"},
		TokenURL:         tokenTS.URL,
	}

	var unmarshalledConfig OAuth2
	err = yaml.Unmarshal([]byte(yamlConfig), &unmarshalledConfig)
	if err != nil {
		t.Fatalf("Expected no error unmarshalling yaml, got %v", err)
	}
	if !reflect.DeepEqual(unmarshalledConfig, expectedConfig) {
		t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
	}

	rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)

	client := http.Client{
		Transport: rt,
	}

	tk := "Basic MToxMjM0NTY="
	expectedAuth = &tk
	if _, err := secretFile.Write([]byte("123456")); err != nil {
		t.Fatal(err)
	}
	resp, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	authorization := resp.Request.Header.Get("Authorization")
	if authorization != "Bearer 12345" {
		t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
	}

	// Making a second request with the same file content should not re-call the token API.
	resp, err = client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	tk = "Basic MToxMjM0NTY3"
	expectedAuth = &tk
	if _, err := secretFile.Write([]byte("7")); err != nil {
		t.Fatal(err)
	}

	_, err = client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	// Making a second request with the same file content should not re-call the token API.
	_, err = client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	authorization = resp.Request.Header.Get("Authorization")
	if authorization != "Bearer 12345" {
		t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
	}
}

func TestMarshalURL(t *testing.T) {
	urlp, err := url.Parse("http://example.com/")
	if err != nil {
		t.Fatal(err)
	}
	u := &URL{urlp}

	c, err := json.Marshal(u)
	if err != nil {
		t.Fatal(err)
	}
	if string(c) != "\"http://example.com/\"" {
		t.Fatalf("URL not properly marshaled in JSON got '%s'", string(c))
	}

	c, err = yaml.Marshal(u)
	if err != nil {
		t.Fatal(err)
	}
	if string(c) != "http://example.com/\n" {
		t.Fatalf("URL not properly marshaled in YAML got '%s'", string(c))
	}
}

func TestMarshalURLWrapperWithNilValue(t *testing.T) {
	u := &URL{}

	c, err := json.Marshal(u)
	if err != nil {
		t.Fatal(err)
	}
	if string(c) != "null" {
		t.Fatalf("URL with nil value not properly marshaled into JSON, got %q", c)
	}

	c, err = yaml.Marshal(u)
	if err != nil {
		t.Fatal(err)
	}
	if string(c) != "null\n" {
		t.Fatalf("URL with nil value not properly marshaled into JSON, got %q", c)
	}
}

func TestUnmarshalNullURL(t *testing.T) {
	b := []byte(`null`)

	{
		var u URL
		err := json.Unmarshal(b, &u)
		if err != nil {
			t.Fatal(err)
		}
		if !isEmptyNonNilURL(u.URL) {
			t.Fatalf("`null` literal not properly unmarshaled from JSON as URL, got %#v", u.URL)
		}
	}

	{
		var u URL
		err := yaml.Unmarshal(b, &u)
		if err != nil {
			t.Fatal(err)
		}
		if u.URL != nil { // UnmarshalYAML is not called when parsing null literal.
			t.Fatalf("`null` literal not properly unmarshaled from YAML as URL, got %#v", u.URL)
		}
	}
}

func TestUnmarshalEmptyURL(t *testing.T) {
	b := []byte(`""`)

	{
		var u URL
		err := json.Unmarshal(b, &u)
		if err != nil {
			t.Fatal(err)
		}
		if !isEmptyNonNilURL(u.URL) {
			t.Fatalf("empty string not properly unmarshaled from JSON as URL, got %#v", u.URL)
		}
	}

	{
		var u URL
		err := yaml.Unmarshal(b, &u)
		if err != nil {
			t.Fatal(err)
		}
		if !isEmptyNonNilURL(u.URL) {
			t.Fatalf("empty string not properly unmarshaled from YAML as URL, got %#v", u.URL)
		}
	}
}

// checks if u equals to &url.URL{}
func isEmptyNonNilURL(u *url.URL) bool {
	return u != nil && *u == url.URL{}
}

func TestUnmarshalURL(t *testing.T) {
	b := []byte(`"http://example.com/a b"`)
	var u URL

	err := json.Unmarshal(b, &u)
	if err != nil {
		t.Fatal(err)
	}
	if u.String() != "http://example.com/a%20b" {
		t.Fatalf("URL not properly unmarshaled in JSON, got '%s'", u.String())
	}

	err = yaml.Unmarshal(b, &u)
	if err != nil {
		t.Fatal(err)
	}
	if u.String() != "http://example.com/a%20b" {
		t.Fatalf("URL not properly unmarshaled in YAML, got '%s'", u.String())
	}
}

func TestMarshalURLWithSecret(t *testing.T) {
	var u URL
	err := yaml.Unmarshal([]byte("http://foo:bar@example.com"), &u)
	if err != nil {
		t.Fatal(err)
	}

	b, err := yaml.Marshal(u)
	if err != nil {
		t.Fatal(err)
	}
	if strings.TrimSpace(string(b)) != "http://foo:xxxxx@example.com" {
		t.Fatalf("URL not properly marshaled in YAML, got '%s'", string(b))
	}
}
