// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package vnet

import (
	"math"
	"sync"
	"time"

	"github.com/pion/logging"
)

// Bit-rate units. All three are untyped constants expressed in bits, so they
// can be combined with both integer and floating point values.
const (
	// Bit is the base unit: one bit.
	Bit = 1
	// KBit is one kilobit, i.e. 1000 bits.
	KBit = Bit * 1000
	// MBit is one megabit, i.e. 1000 kilobits.
	MBit = KBit * 1000
)

// TokenBucketFilter implements a token bucket rate limit algorithm.
//
// Chunks handed to onInboundChunk are passed over channel c to a single
// goroutine (run), which queues them and forwards them to the embedded NIC as
// long as the bucket holds enough tokens. One token pays for one byte of a
// chunk's user data.
type TokenBucketFilter struct {
	// NIC is the wrapped interface; released chunks are delivered to its
	// onInboundChunk method.
	NIC
	// currentTokensInBucket is the number of tokens (bytes) that can currently
	// be spent. It is increased by refillTokens, capped at maxBurst, and
	// decreased by drainQueue.
	currentTokensInBucket float64
	// c is an unbuffered channel carrying chunks from onInboundChunk to run.
	c                     chan Chunk
	// queue holds chunks that are waiting for tokens. It is created by
	// NewTokenBucketFilter using queueSize.
	queue                 *chunkQueue
	queueSize             int // in bytes

	// mutex is held whenever rate and maxBurst are read or written, so that
	// TBFRate and TBFMaxBurst can be applied while run is refilling.
	mutex             sync.Mutex
	// rate is the refill rate in bits per second; refillTokens divides by 8
	// to convert it to tokens.
	rate              int
	// maxBurst is the upper bound for currentTokensInBucket.
	maxBurst          int
	// minRefillDuration is the minimum time that has to pass between two
	// refills in run; it is also the interval credited by the initial refill.
	minRefillDuration time.Duration

	// wg tracks the run goroutine so that Close can wait for it.
	wg   sync.WaitGroup
	// done is closed by Close to stop run and to unblock onInboundChunk.
	done chan struct{}

	log logging.LeveledLogger
}

// TBFOption is the option type to configure a TokenBucketFilter. Applying an
// option changes one setting and returns a new option that holds the value
// the setting had before, so the change can be undone by applying the result.
type TBFOption func(*TokenBucketFilter) TBFOption

// TBFQueueSizeInBytes configures how many bytes may wait in the queue. The
// value is read when NewTokenBucketFilter creates the queue, so it has to be
// passed to the constructor; it can not be changed once the TBF is in use.
// The returned option restores the previous queue size.
func TBFQueueSizeInBytes(bytes int) TBFOption {
	return func(tbf *TokenBucketFilter) TBFOption {
		var old int
		old, tbf.queueSize = tbf.queueSize, bytes

		return TBFQueueSizeInBytes(old)
	}
}

// TBFRate configures the bit rate of a TokenBucketFilter. The field is
// updated while holding the filter's mutex. The returned option restores the
// previous rate.
func TBFRate(rate int) TBFOption {
	return func(tbf *TokenBucketFilter) TBFOption {
		tbf.mutex.Lock()
		old := tbf.rate
		tbf.rate = rate
		tbf.mutex.Unlock()

		return TBFRate(old)
	}
}

// TBFMaxBurst configures the bucket size of the token bucket filter, i.e. the
// largest amount of data that can leave the filter at once when the bucket is
// completely full. The field is updated while holding the filter's mutex. The
// returned option restores the previous bucket size.
func TBFMaxBurst(size int) TBFOption {
	return func(tbf *TokenBucketFilter) TBFOption {
		tbf.mutex.Lock()
		old := tbf.maxBurst
		tbf.maxBurst = size
		tbf.mutex.Unlock()

		return TBFMaxBurst(old)
	}
}

// Set applies the given options to the token bucket filter in order. It
// returns the undo option produced by the last applied option, or nil if no
// options were given.
func (t *TokenBucketFilter) Set(opts ...TBFOption) (previous TBFOption) {
	for i := range opts {
		apply := opts[i]
		previous = apply(t)
	}

	return previous
}

// NewTokenBucketFilter creates a TokenBucketFilter in front of n, applies
// opts on top of the defaults and starts the goroutine that processes
// chunks. Call Close to stop it again. The returned error is always nil.
func NewTokenBucketFilter(n NIC, opts ...TBFOption) (*TokenBucketFilter, error) {
	// Defaults used for every setting that is not overridden by an option.
	const (
		defaultQueueSize = 50000
		defaultRate      = 1 * MBit
		defaultMaxBurst  = 8 * KBit
		defaultMinRefill = 100 * time.Millisecond
	)

	logger := logging.NewDefaultLoggerFactory().NewLogger("tbf")

	filter := &TokenBucketFilter{
		NIC:                   n,
		currentTokensInBucket: 0,
		c:                     make(chan Chunk),
		queue:                 nil, // created below, after the options are applied
		queueSize:             defaultQueueSize,
		mutex:                 sync.Mutex{},
		rate:                  defaultRate,
		maxBurst:              defaultMaxBurst,
		minRefillDuration:     defaultMinRefill,
		wg:                    sync.WaitGroup{},
		done:                  make(chan struct{}),
		log:                   logger,
	}

	// The options may change queueSize, so the queue is only created now.
	filter.Set(opts...)
	filter.queue = newChunkQueue(0, filter.queueSize)

	filter.wg.Add(1)
	go filter.run()

	return filter, nil
}

// onInboundChunk hands c over to the run goroutine. It blocks until run has
// received the chunk or until the filter is closed, in which case the chunk
// is discarded.
func (t *TokenBucketFilter) onInboundChunk(c Chunk) {
	select {
	case <-t.done:
		// Filter was closed; nobody will read from t.c anymore.
	case t.c <- c:
	}
}

// run is the processing loop of the filter. It is started by
// NewTokenBucketFilter and returns after Close closed the done channel.
//
// The bucket starts with the tokens of one minRefillDuration. For every
// incoming chunk, the bucket is refilled if more than minRefillDuration has
// passed since the last refill, the chunk is queued, and the queue is drained
// as far as the tokens allow. Note that refilling and draining are only
// triggered by an incoming chunk (or by Close), not by a timer.
//
// On shutdown the queue is drained one last time with the tokens that are
// left; chunks that can not be paid for stay in the queue.
func (t *TokenBucketFilter) run() {
	defer t.wg.Done()

	t.refillTokens(t.minRefillDuration)
	lastRefill := time.Now()

	for {
		select {
		case <-t.done:
			t.drainQueue()

			return
		case chunk := <-t.c:
			// Read the clock once, so that the interval credited to the
			// bucket ends exactly where the next interval starts. Otherwise
			// the time spent inside refillTokens would never be accounted for.
			now := time.Now()
			if elapsed := now.Sub(lastRefill); elapsed > t.minRefillDuration {
				t.refillTokens(elapsed)
				lastRefill = now
			}
			t.queue.push(chunk)
			t.drainQueue()
		}
	}
}

// refillTokens adds the tokens that were earned during dt to the bucket. The
// configured rate is in bits per second and one token is one byte, so dt
// contributes rate * dt / 8 tokens. The bucket never holds more than maxBurst
// tokens.
//
// The mutex is held because rate and maxBurst may be changed concurrently
// through TBFRate and TBFMaxBurst.
func (t *TokenBucketFilter) refillTokens(dt time.Duration) {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	// Use the full resolution of dt. Truncating to whole milliseconds would
	// drop the sub-millisecond part of every interval.
	seconds := dt.Seconds()
	add := float64(t.rate) * seconds / 8.0
	// m is the number of dt-sized intervals per second; it is only used to
	// keep the trace output in its established format (+Inf for dt == 0).
	m := 1.0 / seconds
	t.currentTokensInBucket = math.Min(float64(t.maxBurst), t.currentTokensInBucket+add)
	t.log.Tracef(
		"add=(%v / %v) / 8 = %v, currentTokensInBucket=%v, maxBurst=%v",
		t.rate,
		m,
		add,
		t.currentTokensInBucket,
		t.maxBurst,
	)
}

// drainQueue forwards queued chunks to the wrapped NIC, in queue order, for
// as long as the bucket holds enough tokens. Every chunk costs one token per
// byte of its user data. Draining stops at the first chunk that can not be
// paid for; that chunk and everything behind it stay in the queue.
func (t *TokenBucketFilter) drainQueue() {
	for head := t.queue.peek(); head != nil; head = t.queue.peek() {
		cost := float64(len(head.UserData()))
		if cost > t.currentTokensInBucket {
			t.log.Tracef("currentTokensInBucket=%v, tokens=%v, stop drain", t.currentTokensInBucket, cost)

			return
		}

		t.log.Tracef("currentTokensInBucket=%v, tokens=%v, pop chunk", t.currentTokensInBucket, cost)
		t.queue.pop()
		t.NIC.onInboundChunk(head)
		t.currentTokensInBucket -= cost
	}
}

// Close stops the token bucket filter and waits until its processing
// goroutine has returned. It is safe to call Close more than once; later
// calls only wait for the goroutine, which has then already finished. The
// returned error is always nil.
func (t *TokenBucketFilter) Close() error {
	// Closing an already closed channel panics, so check first. The mutex
	// makes the check and the close a single step for concurrent callers.
	t.mutex.Lock()
	select {
	case <-t.done:
		// Already closed by an earlier call.
	default:
		close(t.done)
	}
	t.mutex.Unlock()

	t.wg.Wait()

	return nil
}
