package dbus

import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"sync"
	"testing"
	"time"
)

// TestSessionBus checks that closing the shared session bus connection makes
// the next SessionBus call hand out a fresh connection instead of the closed
// one.
func TestSessionBus(t *testing.T) {
	oldConn, err := SessionBus()
	if err != nil {
		// Fatal rather than Error: oldConn is nil here and the Close call
		// below would panic.
		t.Fatal(err)
	}
	if err = oldConn.Close(); err != nil {
		t.Fatal(err)
	}
	if oldConn.Connected() {
		t.Fatal("Should be closed")
	}
	newConn, err := SessionBus()
	if err != nil {
		t.Fatal(err)
	}
	if newConn == oldConn {
		t.Fatal("Should get a new connection")
	}
}

// TestSystemBus checks that closing the shared system bus connection makes
// the next SystemBus call hand out a fresh connection instead of the closed
// one.
func TestSystemBus(t *testing.T) {
	oldConn, err := SystemBus()
	if err != nil {
		// Fatal rather than Error: oldConn is nil here and the Close call
		// below would panic.
		t.Fatal(err)
	}
	if err = oldConn.Close(); err != nil {
		t.Fatal(err)
	}
	if oldConn.Connected() {
		t.Fatal("Should be closed")
	}
	newConn, err := SystemBus()
	if err != nil {
		t.Fatal(err)
	}
	if newConn == oldConn {
		t.Fatal("Should get a new connection")
	}
}

// TestConnectSessionBus opens a private session bus connection and verifies
// that Close succeeds and leaves the connection disconnected.
func TestConnectSessionBus(t *testing.T) {
	bus, connErr := ConnectSessionBus()
	if connErr != nil {
		t.Fatal(connErr)
	}
	closeErr := bus.Close()
	if closeErr != nil {
		t.Fatal(closeErr)
	}
	if stillUp := bus.Connected(); stillUp {
		t.Fatal("Should be closed")
	}
}

// TestConnectSystemBus opens a private system bus connection and verifies
// that Close succeeds and leaves the connection disconnected.
func TestConnectSystemBus(t *testing.T) {
	bus, connErr := ConnectSystemBus()
	if connErr != nil {
		t.Fatal(connErr)
	}
	closeErr := bus.Close()
	if closeErr != nil {
		t.Fatal(closeErr)
	}
	if stillUp := bus.Connected(); stillUp {
		t.Fatal("Should be closed")
	}
}

// TestSend hand-builds an org.freedesktop.DBus.Peer.Ping method call addressed
// to this connection's own unique name, sends it with Send, waits for the
// completion notification and checks that the call reported no error.
func TestSend(t *testing.T) {
	conn, err := ConnectSessionBus()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	headers := map[HeaderField]Variant{
		FieldDestination: MakeVariant(conn.Names()[0]),
		FieldPath:        MakeVariant(ObjectPath("/org/freedesktop/DBus")),
		FieldInterface:   MakeVariant("org.freedesktop.DBus.Peer"),
		FieldMember:      MakeVariant("Ping"),
	}
	completed := make(chan *Call, 1)
	pending := conn.Send(&Message{Type: TypeMethodCall, Flags: 0, Headers: headers}, completed)

	// Block until the call is reported as finished.
	<-completed
	if pending.Err != nil {
		t.Error(pending.Err)
	}
}

// TestFlagNoReplyExpectedSend checks that a method call flagged with
// FlagNoReplyExpected returns within one second, i.e. Call does not sit
// waiting for a reply that will never come.
func TestFlagNoReplyExpectedSend(t *testing.T) {
	conn, err := ConnectSessionBus()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		conn.BusObject().Call("org.freedesktop.DBus.ListNames", FlagNoReplyExpected)
	}()

	select {
	case <-time.After(1 * time.Second):
		t.Error("Failed to announce that the call was done")
	case <-finished:
	}
}

// TestRemoveSignal registers two channels three times each and checks that
// RemoveSignal drops every registration of the removed channel while keeping
// all registrations of the other one.
func TestRemoveSignal(t *testing.T) {
	bus, err := NewConn(nil)
	if err != nil {
		// Fatal rather than Error: bus is nil on failure and the field
		// accesses below would panic.
		t.Fatal(err)
	}
	ch := make(chan *Signal)
	ch2 := make(chan *Signal)
	for _, c := range []chan *Signal{ch, ch2, ch, ch2, ch2, ch} {
		bus.Signal(c)
	}
	signals := bus.signalHandler.(*defaultSignalHandler).signals
	if len(signals) != 6 {
		t.Errorf("remove signal: signals length not equal: got '%d', want '6'", len(signals))
	}

	bus.RemoveSignal(ch)
	signals = bus.signalHandler.(*defaultSignalHandler).signals
	if len(signals) != 3 {
		t.Errorf("remove signal: signals length not equal: got '%d', want '3'", len(signals))
	}
	// Only ch2 registrations may remain.
	for _, scd := range signals {
		if scd.ch != ch2 {
			t.Errorf("remove signal: removed signal present: got '%v', want '%v'", scd.ch, ch2)
		}
	}
}

// rwc combines a separate io.Reader and io.Writer into a value with a Close
// method, so tests can build a Conn on top of in-memory pipes.
type rwc struct {
	io.Reader // bytes the connection reads
	io.Writer // bytes the connection writes
}

// Close does nothing and always succeeds; the embedded reader and writer are
// left open for the test to manage.
func (rwc) Close() error {
	return nil
}

// fakeAuth is a stub Auth mechanism that always reports success: it opens with
// the mechanism name "name" and initial response "resp", and answers any
// further data with an empty response.
type fakeAuth struct{}

// FirstData returns the fixed mechanism name and initial response with AuthOk.
func (fakeAuth) FirstData() (name, resp []byte, status AuthStatus) {
	name = []byte("name")
	resp = []byte("resp")
	status = AuthOk
	return name, resp, status
}

// HandleData ignores its input and returns a nil response with AuthOk.
func (fakeAuth) HandleData(data []byte) (resp []byte, status AuthStatus) {
	status = AuthOk
	return resp, status
}

// TestCloseBeforeSignal authenticates a Conn over an in-memory pipe, registers
// a signal channel, closes the Conn, and only then writes an encoded signal
// message into the pipe. The test fails if any step returns an error (or if
// anything panics). NOTE(review): presumably this guards against delivering a
// signal to a handler that Close has already shut down — confirm against the
// Conn implementation.
func TestCloseBeforeSignal(t *testing.T) {
	// The Conn reads from reader; the test plays the server by writing to
	// pipewriter. Everything the Conn writes is discarded.
	reader, pipewriter := io.Pipe()
	defer pipewriter.Close()
	defer reader.Close()

	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
	if err != nil {
		t.Fatal(err)
	}
	// give ch a buffer so sends won't block
	ch := make(chan *Signal, 1)
	bus.Signal(ch)

	// Canned server replies for the auth handshake: a REJECTED line followed
	// by an OK line. Written from a goroutine because io.Pipe writes block
	// until the data is read (presumably by bus.Auth below).
	go func() {
		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
		if err != nil {
			t.Errorf("error writing to pipe: %v", err)
		}
	}()

	err = bus.Auth([]Auth{fakeAuth{}})
	if err != nil {
		t.Fatal(err)
	}

	err = bus.Close()
	if err != nil {
		t.Fatal(err)
	}

	// A signal arriving after Close. io.Pipe writes block until read, so
	// EncodeTo only returns once something consumes the message —
	// presumably the Conn's reading goroutine, still draining the pipe
	// (rwc.Close does not close the reader).
	msg := &Message{
		Type: TypeSignal,
		Headers: map[HeaderField]Variant{
			FieldInterface: MakeVariant("foo.bar"),
			FieldMember:    MakeVariant("bar"),
			FieldPath:      MakeVariant(ObjectPath("/baz")),
		},
	}
	err = msg.EncodeTo(pipewriter, binary.LittleEndian)
	if err != nil {
		t.Fatal(err)
	}
}

// TestCloseChannelAfterRemoveSignal hands a signal to the connection while
// nothing is receiving on the registered (unbuffered) channel, then
// unregisters the channel and closes it. The test fails if that sequence
// blocks or panics.
func TestCloseChannelAfterRemoveSignal(t *testing.T) {
	conn, err := NewConn(nil)
	if err != nil {
		t.Fatal(err)
	}

	// Unbuffered on purpose: nothing ever reads from it.
	sigCh := make(chan *Signal)
	conn.Signal(sigCh)

	conn.handleSignal(Sequence(1), &Message{
		Type: TypeSignal,
		Headers: map[HeaderField]Variant{
			FieldInterface: MakeVariant("foo.bar"),
			FieldMember:    MakeVariant("bar"),
			FieldPath:      MakeVariant(ObjectPath("/baz")),
		},
	})

	// Unregister first, then close the channel.
	conn.RemoveSignal(sigCh)
	close(sigCh)
}

// TestAddAndRemoveMatchSignalContext exercises the context-aware match rule
// helpers. Calls made with an already-cancelled context must fail and leave
// the subscription unchanged; calls made with a live context must take effect.
func TestAddAndRemoveMatchSignalContext(t *testing.T) {
	conn, err := ConnectSessionBus()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	const sigName = "org.test.CtxTest"
	sigc := make(chan *Signal, 1)
	conn.Signal(sigc)

	subscribe := func(ctx context.Context) error {
		return conn.AddMatchSignalContext(ctx, WithMatchInterface("org.test"), WithMatchMember("CtxTest"))
	}
	unsubscribe := func(ctx context.Context) error {
		return conn.RemoveMatchSignalContext(ctx, WithMatchInterface("org.test"), WithMatchMember("CtxTest"))
	}

	cancelled, cancel := context.WithCancel(context.Background())
	cancel()

	// Subscribing to a made-up signal with a cancelled context must fail.
	if subscribe(cancelled) == nil {
		t.Fatal("call on canceled context did not fail")
	}
	if err = subscribe(context.Background()); err != nil {
		t.Fatal(err)
	}

	// Unsubscribing with a cancelled context must fail too...
	if unsubscribe(cancelled) == nil {
		t.Fatal("call on canceled context did not fail")
	}
	// ...and leave the subscription in place.
	if err = conn.Emit("/", sigName); err != nil {
		t.Fatal(err)
	}
	if waitSignal(sigc, sigName, time.Second) == nil {
		t.Fatal("signal receive timed out")
	}

	// A real unsubscribe stops delivery.
	if err = unsubscribe(context.Background()); err != nil {
		t.Fatal(err)
	}
	if err = conn.Emit("/", sigName); err != nil {
		t.Fatal(err)
	}
	if sig := waitSignal(sigc, sigName, time.Second); sig != nil {
		t.Fatalf("unsubscribed from %q signal, but received %#v", sigName, sig)
	}
}

// TestAddAndRemoveMatchSignal subscribes to a made-up signal and checks that an
// emitted instance is delivered. It then unsubscribes and checks that a second
// emission is no longer delivered.
func TestAddAndRemoveMatchSignal(t *testing.T) {
	conn, err := ConnectSessionBus()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	const sigName = "org.test.Test"
	sigc := make(chan *Signal, 1)
	conn.Signal(sigc)

	// Subscribed: the emitted signal must arrive.
	if err = conn.AddMatchSignal(WithMatchInterface("org.test"), WithMatchMember("Test")); err != nil {
		t.Fatal(err)
	}
	if err = conn.Emit("/", sigName); err != nil {
		t.Fatal(err)
	}
	if waitSignal(sigc, sigName, time.Second) == nil {
		t.Fatal("signal receive timed out")
	}

	// Unsubscribed: the emitted signal must not arrive.
	if err = conn.RemoveMatchSignal(WithMatchInterface("org.test"), WithMatchMember("Test")); err != nil {
		t.Fatal(err)
	}
	if err = conn.Emit("/", sigName); err != nil {
		t.Fatal(err)
	}
	if sig := waitSignal(sigc, sigName, time.Second); sig != nil {
		t.Fatalf("unsubscribed from %q signal, but received %#v", sigName, sig)
	}
}

// waitSignal reads from sigc until a signal named name arrives and returns it.
// Signals with other names are discarded. It returns nil if no matching signal
// arrives within timeout, or if sigc is closed. The timeout is measured from
// the call; unrelated signals do not reset it.
func waitSignal(sigc <-chan *Signal, name string, timeout time.Duration) *Signal {
	// One timer for the whole wait: calling time.After inside the loop would
	// restart the deadline every time an unrelated signal arrived.
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for {
		select {
		case sig, ok := <-sigc:
			if !ok {
				// A closed channel yields nil forever; don't dereference it.
				return nil
			}
			if sig != nil && sig.Name == name {
				return sig
			}
		case <-timer.C:
			return nil
		}
	}
}

const (
	// SCPPInterface is the D-Bus interface of the stateful test object used
	// by TestStateCachingProxyPattern.
	SCPPInterface = "org.godbus.DBus.StatefulTest"

	// SCPPPath is the object path the stateful test object is addressed at.
	SCPPPath = "/org/godbus/DBus/StatefulTest"

	// SCPPChangedSignalName is the member name of the signal emitted on every
	// state change.
	SCPPChangedSignalName = "Changed"

	// SCPPStateMethodName is the member name of the method that returns the
	// current state.
	SCPPStateMethodName = "State"
)

// TestStateCachingProxyPattern runs a server goroutine that keeps emitting
// Changed signals carrying an incrementing counter and answers State() calls
// with the current counter. Alongside it, a client goroutine (using a
// sequential signal handler) subscribes to the signals and then queries
// State(). The client verifies that the queried state plus the signals
// sequenced after the reply form a gap-free run of states.
func TestStateCachingProxyPattern(t *testing.T) {
	srv, err := ConnectSessionBus()
	if err != nil {
		t.Fatal(err)
	}
	// Registered only after the error check: srv is nil on failure, and a
	// deferred Close would otherwise run (and panic) during t.Fatal.
	defer srv.Close()

	conn, err := ConnectSessionBus(WithSignalHandler(NewSequentialSignalHandler()))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	serviceName := srv.Names()[0]
	// message channel should have at least some buffering, to make sure Eavesdrop does not
	// drop the message if nobody is currently trying to read from the channel.
	messages := make(chan *Message, 1)
	srv.Eavesdrop(messages)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		if err := serverProcess(ctx, srv, messages, t); err != nil {
			t.Errorf("error in server process: %v", err)
			cancel()
		}
	}()
	go func() {
		defer wg.Done()
		if err := clientProcess(ctx, conn, serviceName, t); err != nil {
			t.Errorf("error in client process: %v", err)
		}
		// The client is done either way; stop the server.
		cancel()
	}()
	wg.Wait()
}

// clientProcess subscribes conn to the Changed signal and waits briefly. It
// then queries the current state via State() and collects states from Changed
// signals whose sequence number is greater than the State() reply's. It
// returns an error unless the queried state is at least 10 and all 50 collected
// states (the queried one plus 49 from signals) are consecutive integers. It
// returns ctx.Err() if ctx is done before enough signals arrive.
//
// This runs in a goroutine other than the test's, so malformed input is
// reported as an error instead of panicking, which would crash the whole test
// binary.
func clientProcess(ctx context.Context, conn *Conn, serviceName string, t *testing.T) error {
	// Observing at least this many states gives a low probability that a
	// contiguous subsequence was received 'by accident'.
	const wantStates = 50
	changedName := SCPPInterface + "." + SCPPChangedSignalName

	// Subscribe to state changes on the remote object.
	if err := conn.AddMatchSignal(
		WithMatchInterface(SCPPInterface),
		WithMatchMember(SCPPChangedSignalName),
	); err != nil {
		return err
	}
	channel := make(chan *Signal)
	conn.Signal(channel)
	t.Log("Subscribed to signals")

	// Simulate unfavourable OS scheduling leading to a delay between subscription
	// and querying for the current state.
	time.Sleep(30 * time.Millisecond)

	// Call .State() on the remote object and store its current state in observedStates[0].
	obj := conn.Object(serviceName, SCPPPath)
	observedStates := make([]uint64, 1, wantStates)
	call := obj.CallWithContext(ctx, SCPPInterface+"."+SCPPStateMethodName, 0)
	if err := call.Store(&observedStates[0]); err != nil {
		return err
	}
	t.Logf("Queried current state, got %v", observedStates[0])

	// Populate observedStates[1...49] from the state change signals. Signals
	// with a sequence number not greater than call.ResponseSequence are past
	// signals and are ignored.
	signalsProcessed := 0
readSignals:
	for {
		select {
		case signal, ok := <-channel:
			if !ok {
				return fmt.Errorf("signal channel closed after %v signals", signalsProcessed)
			}
			signalsProcessed++
			if signal.Name != changedName || signal.Sequence <= call.ResponseSequence {
				continue
			}
			if len(signal.Body) == 0 {
				return fmt.Errorf("%s signal has an empty body", changedName)
			}
			observedState, ok := signal.Body[0].(uint64)
			if !ok {
				return fmt.Errorf("%s signal carries %T, want uint64", changedName, signal.Body[0])
			}
			observedStates = append(observedStates, observedState)
			if len(observedStates) >= wantStates {
				break readSignals
			}
		case <-ctx.Done():
			t.Logf("Context cancelled, client processed %v signals", signalsProcessed)
			return ctx.Err()
		}
	}
	t.Logf("client processed %v signals", signalsProcessed)

	// Expect that we began observing at least a few states in. This ensures the server was already emitting signals
	// and makes it likely we simulated our race condition.
	if observedStates[0] < 10 {
		return fmt.Errorf("expected first state to be at least 10, got %v", observedStates[0])
	}

	t.Logf("Observed states: %v", observedStates)

	// The observable states of the remote object were [1 ... (infinity)] during this test.
	// This loop asserts that our observed states are a contiguous subrange [n ... n+49] for some n, i.e.
	// that we received a contiguous subsequence of the states of the remote object. For each run of the test, n
	// may be slightly different due to scheduling effects. Index 0 trivially matches itself.
	for i := 1; i < len(observedStates); i++ {
		if expectedState := observedStates[0] + uint64(i); observedStates[i] != expectedState {
			return fmt.Errorf("expected observed state %v to be %v, got %v", i, expectedState, observedStates[i])
		}
	}
	return nil
}

// serverProcess plays the stateful server for TestStateCachingProxyPattern.
// Whenever no eavesdropped message is waiting, it increments a counter and
// emits it as a Changed signal on SCPPPath. When a State() call arrives on
// messages, it sends a method reply carrying the current counter. It returns
// nil when ctx is done or messages is closed, and the Emit error if emitting
// fails.
func serverProcess(ctx context.Context, srv *Conn, messages <-chan *Message, t *testing.T) error {
	state := uint64(0)

	for {
		select {
		case msg, ok := <-messages:
			if !ok {
				t.Log("Message channel closed")
				return nil
			}
			if msg.IsValid() != nil {
				t.Log("Got invalid message, discarding")
				continue
			}
			// Checked assertions: an eavesdropped message without a MEMBER or
			// INTERFACE header (e.g. a method return, which need not carry
			// them) would otherwise panic in this non-test goroutine and
			// crash the test binary. Missing headers yield "" and never match.
			name, _ := msg.Headers[FieldMember].value.(string)
			ifname, _ := msg.Headers[FieldInterface].value.(string)
			if ifname == SCPPInterface && name == SCPPStateMethodName {
				t.Logf("Processing reply to .State(), returning state = %v", state)
				reply := new(Message)
				reply.Type = TypeMethodReply
				reply.Headers = make(map[HeaderField]Variant)
				reply.Headers[FieldDestination] = msg.Headers[FieldSender]
				reply.Headers[FieldReplySerial] = MakeVariant(msg.serial)
				reply.Body = make([]interface{}, 1)
				reply.Body[0] = state
				reply.Headers[FieldSignature] = MakeVariant(SignatureOf(reply.Body...))
				srv.sendMessageAndIfClosed(reply, nil)
			}
		case <-ctx.Done():
			t.Logf("Context cancelled, server emitted %v signals", state)
			return nil
		default:
			state++
			if err := srv.Emit(SCPPPath, SCPPInterface+"."+SCPPChangedSignalName, state); err != nil {
				return err
			}
		}
	}
}

// server is the object exported by the Serve benchmarks.
type server struct{}

// Double returns twice its argument and never fails.
func (server) Double(i int64) (int64, *Error) {
	doubled := i * 2
	return doubled, nil
}

// BenchmarkCall measures synchronous org.freedesktop.DBus.GetNameOwner round
// trips for the connection's own unique name and checks every answer.
func BenchmarkCall(b *testing.B) {
	b.StopTimer()
	b.ReportAllocs()
	var owner string
	conn, err := ConnectSessionBus()
	if err != nil {
		b.Fatal(err)
	}
	defer conn.Close()

	self := conn.Names()[0]
	busObj := conn.BusObject()
	b.StartTimer()
	for n := 0; n < b.N; n++ {
		if callErr := busObj.Call("org.freedesktop.DBus.GetNameOwner", 0, self).Store(&owner); callErr != nil {
			b.Fatal(callErr)
		}
		if owner != self {
			b.Errorf("got %s, wanted %s", owner, self)
		}
	}
}

// BenchmarkCallAsync measures pipelined asynchronous GetNameOwner calls for the
// connection's own unique name. A collector goroutine drains all b.N replies
// and checks each one.
func BenchmarkCallAsync(b *testing.B) {
	b.StopTimer()
	b.ReportAllocs()
	bus, err := ConnectSessionBus()
	if err != nil {
		b.Fatal(err)
	}
	defer bus.Close()

	name := bus.Names()[0]
	obj := bus.BusObject()
	c := make(chan *Call, 50)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for i := 0; i < b.N; i++ {
			v := <-c
			if v.Err != nil {
				// Don't go on to inspect the body of a failed call; report
				// it and keep draining so the benchmark can finish.
				b.Error(v.Err)
				continue
			}
			s := v.Body[0].(string)
			if s != name {
				b.Errorf("got %s, wanted %s", s, name)
			}
		}
	}()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		obj.Go("org.freedesktop.DBus.GetNameOwner", 0, c, name)
	}
	<-done
}

// BenchmarkServe measures synchronous calls from one session bus connection to
// an object exported on a second, separate connection.
func BenchmarkServe(b *testing.B) {
	b.StopTimer()
	exporter, err := ConnectSessionBus()
	if err != nil {
		b.Fatal(err)
	}
	defer exporter.Close()

	caller, err := ConnectSessionBus()
	if err != nil {
		b.Fatal(err)
	}
	defer caller.Close()

	benchmarkServe(b, exporter, caller)
}

// BenchmarkServeAsync measures asynchronous calls from one session bus
// connection to an object exported on a second, separate connection.
func BenchmarkServeAsync(b *testing.B) {
	b.StopTimer()
	exporter, err := ConnectSessionBus()
	if err != nil {
		b.Fatal(err)
	}
	defer exporter.Close()

	caller, err := ConnectSessionBus()
	if err != nil {
		b.Fatal(err)
	}
	defer caller.Close()

	benchmarkServeAsync(b, exporter, caller)
}

// BenchmarkServeSameConn measures synchronous calls to an object exported on
// the very connection that performs the calls.
func BenchmarkServeSameConn(b *testing.B) {
	b.StopTimer()
	conn, connErr := ConnectSessionBus()
	if connErr != nil {
		b.Fatal(connErr)
	}
	defer conn.Close()

	benchmarkServe(b, conn, conn)
}

// BenchmarkServeSameConnAsync measures asynchronous calls to an object exported
// on the very connection that performs the calls.
func BenchmarkServeSameConnAsync(b *testing.B) {
	b.StopTimer()
	conn, connErr := ConnectSessionBus()
	if connErr != nil {
		b.Fatal(connErr)
	}
	defer conn.Close()

	benchmarkServeAsync(b, conn, conn)
}

// benchmarkServe exports a server object on srv, starts the benchmark timer
// (callers stop it beforehand so setup is excluded), and performs b.N
// synchronous Double calls on that object from cli, checking each result.
func benchmarkServe(b *testing.B, srv, cli *Conn) {
	const (
		path  = "/org/guelfey/DBus/Test"
		iface = "org.guelfey.DBus.Test"
	)
	dest := srv.Names()[0]
	// A failed export would otherwise surface later as b.N confusing call errors.
	if err := srv.Export(server{}, path, iface); err != nil {
		b.Fatal(err)
	}
	obj := cli.Object(dest, path)
	var r int64
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		if err := obj.Call(iface+".Double", 0, int64(i)).Store(&r); err != nil {
			b.Fatal(err)
		}
		if r != 2*int64(i) {
			b.Errorf("got %d, wanted %d", r, 2*int64(i))
		}
	}
}

// benchmarkServeAsync exports a server object on srv, starts the benchmark
// timer (callers stop it beforehand so setup is excluded), and issues b.N
// asynchronous Double calls from cli. A collector goroutine checks each reply.
func benchmarkServeAsync(b *testing.B, srv, cli *Conn) {
	const (
		path  = "/org/guelfey/DBus/Test"
		iface = "org.guelfey.DBus.Test"
	)
	dest := srv.Names()[0]
	if err := srv.Export(server{}, path, iface); err != nil {
		b.Fatal(err)
	}
	obj := cli.Object(dest, path)
	c := make(chan *Call, 50)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for n := 0; n < b.N; n++ {
			call := <-c
			if call.Err != nil {
				// b.Fatal must only be called from the benchmark goroutine;
				// report the failure and keep draining instead.
				b.Error(call.Err)
				continue
			}
			arg, got := call.Args[0].(int64), call.Body[0].(int64)
			if want := 2 * arg; got != want {
				b.Errorf("got %d, wanted %d", got, want)
			}
		}
	}()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		obj.Go(iface+".Double", 0, c, int64(i))
	}
	<-done
}

// TestGetKey checks that getKey extracts each value from a comma-separated
// list of key=value pairs.
func TestGetKey(t *testing.T) {
	const keys = "host=1.2.3.4,port=5678,family=ipv4"
	cases := []struct{ key, want string }{
		{"host", "1.2.3.4"},
		{"port", "5678"},
		{"family", "ipv4"},
	}
	for _, tc := range cases {
		if got := getKey(keys, tc.key); got != tc.want {
			t.Error(`Expected "`+tc.want+`", got`, got)
		}
	}
}

func TestInterceptors(t *testing.T) {
	conn, err := ConnectSessionBus(
		WithIncomingInterceptor(func(msg *Message) {
			log.Println("<", msg)
		}),
		WithOutgoingInterceptor(func(msg *Message) {
			log.Println(">", msg)
		}),
	)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
}

// TestCloseCancelsConnectionContext checks that the context returned by
// Conn.Context is still live while the connection is open, and is done within
// five seconds of Close.
func TestCloseCancelsConnectionContext(t *testing.T) {
	conn, err := ConnectSessionBus()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	connCtx := conn.Context()
	// Per the context contract, Err is non-nil exactly when Done is closed.
	if connCtx.Err() != nil {
		t.Fatal("context should not be done")
	}

	if closeErr := conn.Close(); closeErr != nil {
		t.Fatal(closeErr)
	}
	select {
	case <-time.After(5 * time.Second):
		t.Fatal("context is not done after connection closed")
	case <-connCtx.Done():
		// expected
	}
}

// TestDisconnectCancelsConnectionContext authenticates a Conn over an
// in-memory pipe and then closes the pipe's write end. It checks that the
// connection's context is done within five seconds.
func TestDisconnectCancelsConnectionContext(t *testing.T) {
	pr, pw := io.Pipe()
	defer pw.Close()
	defer pr.Close()

	conn, err := NewConn(rwc{Reader: pr, Writer: ioutil.Discard})
	if err != nil {
		t.Fatal(err)
	}

	// Canned auth handshake replies. They are sent from a goroutine because
	// io.Pipe writes block until the data is read.
	go func() {
		if _, werr := pw.Write([]byte("REJECTED name\r\nOK myuuid\r\n")); werr != nil {
			t.Errorf("error writing to pipe: %v", werr)
		}
	}()
	if err = conn.Auth([]Auth{fakeAuth{}}); err != nil {
		t.Fatal(err)
	}

	connCtx := conn.Context()

	// Closing the write end makes further reads from the pipe return EOF,
	// which simulates the peer disconnecting.
	pw.Close()
	select {
	case <-time.After(5 * time.Second):
		t.Fatal("context is not done after connection closed")
	case <-connCtx.Done():
		// expected
	}
}

// TestCancellingContextClosesConnection checks that cancelling the parent
// context given via WithContext closes an authenticated Conn, so that a later
// call fails with ErrClosed.
func TestCancellingContextClosesConnection(t *testing.T) {
	parent, cancel := context.WithCancel(context.Background())
	defer cancel()

	pr, pw := io.Pipe()
	defer pw.Close()
	defer pr.Close()

	conn, err := NewConn(rwc{Reader: pr, Writer: ioutil.Discard}, WithContext(parent))
	if err != nil {
		t.Fatal(err)
	}

	go func() {
		if _, werr := pw.Write([]byte("REJECTED name\r\nOK myuuid\r\n")); werr != nil {
			t.Errorf("error writing to pipe: %v", werr)
		}
	}()
	if err = conn.Auth([]Auth{fakeAuth{}}); err != nil {
		t.Fatal(err)
	}

	// Cancel the parent context, then give the connection's goroutines time
	// to react.
	cancel()
	time.Sleep(50 * time.Millisecond)

	if err = conn.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store(); err != ErrClosed {
		t.Errorf("expected connection to be closed, but got: %v", err)
	}
}

// TestTimeoutContextClosesConnection checks that a Conn created with a context
// whose deadline has already passed ends up closed, so calls fail with
// ErrClosed. Run under -race, it also checks that creating the Conn does not
// race with its automatic closing.
func TestTimeoutContextClosesConnection(t *testing.T) {
	expired, cancel := context.WithTimeout(context.Background(), 0)
	defer cancel()

	conn, err := NewConn(rwc{}, WithContext(expired))
	if err != nil {
		t.Fatal(err)
	}

	// Give the connection time to notice the deadline and shut down.
	time.Sleep(50 * time.Millisecond)

	if pingErr := conn.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store(); pingErr != ErrClosed {
		t.Errorf("expected connection to be closed, but got: %v", pingErr)
	}
}
