package storage

import (
	"fmt"
	"strings"

	"github.com/containers/storage/pkg/idtools"
	"github.com/containers/storage/types"
	"github.com/google/go-intervals/intervalset"
)

// idSet represents a set of integer IDs. It is stored as an ordered set of intervals.
//
// The methods in this file accept a nil *idSet: subtract and union return the other operand or
// the receiver unchanged, and iterator yields nothing.
type idSet struct {
	// set holds the IDs. The set operations in this file (subtract, union, findAvailable)
	// build a new idSet instead of modifying this value.
	set *intervalset.ImmutableSet
}

// newIDSet builds an idSet out of the given intervals. Each interval is added to the set on
// its own, so the input does not have to be sorted before it is passed in.
func newIDSet(intervals []interval) *idSet {
	acc := intervalset.Empty()
	for idx := range intervals {
		single := intervalset.NewSet([]intervalset.Interval{intervals[idx]})
		acc.Add(single)
	}
	return &idSet{set: acc.ImmutableSet()}
}

// getHostIDs collects the host side of every mapping in idMaps into a single idSet. Each
// mapping contributes the range [HostID, HostID+Size).
func getHostIDs(idMaps []idtools.IDMap) *idSet {
	hostRanges := make([]interval, len(idMaps))
	for idx, mapping := range idMaps {
		hostRanges[idx] = interval{
			start: mapping.HostID,
			end:   mapping.HostID + mapping.Size,
		}
	}
	return newIDSet(hostRanges)
}

// getContainerIDs collects the container side of every mapping in idMaps into a single idSet.
// Each mapping contributes the range [ContainerID, ContainerID+Size).
func getContainerIDs(idMaps []idtools.IDMap) *idSet {
	containerRanges := make([]interval, len(idMaps))
	for idx, mapping := range idMaps {
		containerRanges[idx] = interval{
			start: mapping.ContainerID,
			end:   mapping.ContainerID + mapping.Size,
		}
	}
	return newIDSet(containerRanges)
}

// subtract returns the ids of `s` that are not in `t`. Neither operand is modified.
//
// A nil `s` gives nil, and a nil `t` gives `s` itself (not a copy).
func (s *idSet) subtract(t *idSet) *idSet {
	if s == nil {
		return nil
	}
	if t == nil {
		return s
	}
	diff := s.set.Sub(t.set)
	return &idSet{set: diff}
}

// union returns the ids that are in `s`, in `t`, or in both. Neither operand is modified.
//
// When one operand is nil the other one is returned as is (not a copy).
func (s *idSet) union(t *idSet) *idSet {
	switch {
	case s == nil:
		return t
	case t == nil:
		return s
	}
	merged := s.set.Union(t.set)
	return &idSet{set: merged}
}

// Methods to iterate over the intervals of the idSet. intervalset doesn't provide one :-(

// iteratorFn yields the next interval of an idSet on each call. It returns nil once the
// iteration has finished. The returned pointer refers to a private copy, so callers may modify
// it without affecting the set.
type iteratorFn func() *interval

// cancelFn ends an iteration early. Callers should invoke it when they are done with the
// iterator (typically via defer). With the implementation below it is safe to call it more
// than once, and forgetting to call it does not leak anything.
type cancelFn func()

// iterator returns a function that walks the intervals of `s`, in the order reported by the
// underlying interval set, together with a function that cancels the walk. A nil `s` yields an
// iterator that is exhausted from the start.
//
// The intervals are copied into a slice up front rather than being streamed from a goroutine
// over a channel: the underlying set is immutable, so the snapshot cannot go stale, and no
// goroutine is left behind if the caller stops early or never cancels.
func (s *idSet) iterator() (iteratorFn, cancelFn) {
	if s == nil {
		return func() *interval { return nil }, func() {}
	}
	var snapshot []interval
	s.set.Intervals(func(ii intervalset.Interval) bool {
		snapshot = append(snapshot, ii.(interval))
		return true
	})
	next := 0
	iterator := func() *interval {
		if next >= len(snapshot) {
			return nil
		}
		// Hand out a copy so callers can adjust start/end freely.
		i := snapshot[next]
		next++
		return &i
	}
	// After cancellation the iterator reports completion on every subsequent call.
	cancel := func() { snapshot = nil }
	return iterator, cancel
}

// size returns the total number of ids in the ID set, i.e. the sum of the lengths of all of
// its intervals.
func (s *idSet) size() int {
	next, stop := s.iterator()
	defer stop()

	total := 0
	for cur := next(); cur != nil; cur = next() {
		total += cur.length()
	}
	return total
}

// findAvailable picks `n` ids out of `s`. Intervals are taken in iteration order, and the last
// one taken is shortened so that exactly `n` ids are selected in total.
//
// It returns types.ErrNoAvailableIDs when `s` holds fewer than `n` ids. For n <= 0 the result
// is an empty set.
func (s *idSet) findAvailable(n int) (*idSet, error) {
	next, stop := s.iterator()
	defer stop()

	var picked []intervalset.Interval
	remaining := n
	for cur := next(); remaining > 0 && cur != nil; cur = next() {
		// Never take more than what is still missing.
		chunk := interval{start: cur.start, end: min(cur.end, cur.start+remaining)}
		picked = append(picked, chunk)
		remaining -= chunk.length()
	}
	if remaining > 0 {
		return nil, types.ErrNoAvailableIDs
	}
	return &idSet{set: intervalset.NewImmutableSet(picked)}, nil
}

// zip pairs up the ids of `s` (host ids) with the ids of `container` (container ids), both in
// iteration order, and returns the resulting id map. Pairing stops as soon as either side runs
// out of ids, so surplus ids on the other side are ignored.
func (s *idSet) zip(container *idSet) []idtools.IDMap {
	nextHost, stopHost := s.iterator()
	defer stopHost()
	nextContainer, stopContainer := container.iterator()
	defer stopContainer()

	var mappings []idtools.IDMap
	hostRange, containerRange := nextHost(), nextContainer()
	for hostRange != nil && containerRange != nil {
		// Map as many ids as both current ranges can provide, then consume them.
		span := min(hostRange.length(), containerRange.length())
		if span > 0 {
			mappings = append(mappings, idtools.IDMap{
				ContainerID: containerRange.start,
				HostID:      hostRange.start,
				Size:        span,
			})
			hostRange.start += span
			containerRange.start += span
		}
		// Move on to the next range on whichever side has been used up.
		if hostRange.IsZero() {
			hostRange = nextHost()
		}
		if containerRange.IsZero() {
			containerRange = nextContainer()
		}
	}
	return mappings
}

// interval represents an interval of integers [start, end). Note it is allowed to have
// start >= end, in which case it is treated as an empty interval. It implements interface
// intervalset.Interval.
//
// The methods taking an intervalset.Interval type-assert their argument to interval, so they
// panic when handed any other implementation.
type interval struct {
	// Start of the interval (inclusive).
	start int
	// End of the interval (exclusive).
	end int
}

// length returns the number of integers covered by the interval. An interval whose end does
// not lie past its start has length 0.
func (i interval) length() int {
	n := i.end - i.start
	if n < 0 {
		return 0
	}
	return n
}

// Intersect returns the part that `i` and `other` have in common. When the two do not overlap
// the result has start >= end, i.e. it is an empty interval.
func (i interval) Intersect(other intervalset.Interval) intervalset.Interval {
	j := other.(interval)
	lo := max(i.start, j.start)
	hi := min(i.end, j.end)
	return interval{start: lo, end: hi}
}

// Before reports whether `i` ends strictly before `other` starts. It is false whenever either
// interval is empty. Because the comparison is strict, an interval that directly touches
// `other` (i.end == other.start) is not reported as being before it.
func (i interval) Before(other intervalset.Interval) bool {
	j := other.(interval)
	if i.IsZero() || j.IsZero() {
		return false
	}
	return i.end < j.start
}

// IsZero reports whether the interval is empty, i.e. covers no integers at all.
func (i interval) IsZero() bool {
	n := i.length()
	return n <= 0
}

// Bisect removes `other` from `i` and returns what is left of `i` on either side of it: first
// the part below other's start, then the part from other's end onwards. Either part may be
// empty. When `other` is empty, `i` is returned unchanged as the first part.
func (i interval) Bisect(other intervalset.Interval) (intervalset.Interval, intervalset.Interval) {
	cut := other.(interval)
	if cut.IsZero() {
		return i, interval{}
	}
	// Removing [cut.start, cut.end) leaves i ∩ (-inf, cut.start) and i ∩ [cut.end, +inf).
	below := i
	below.end = min(i.end, cut.start)
	above := i
	above.start = max(i.start, cut.end)
	return below, above
}

// Adjoin merges `i` and `other` into one interval when both are non-empty and one begins
// exactly where the other ends. In every other case it returns the empty interval.
func (i interval) Adjoin(other intervalset.Interval) intervalset.Interval {
	j := other.(interval)
	if i.IsZero() || j.IsZero() {
		return interval{}
	}
	if i.end != j.start && j.end != i.start {
		return interval{}
	}
	return interval{start: min(i.start, j.start), end: max(i.end, j.end)}
}

// Encompass returns the smallest interval that covers both `i` and `other`. An empty operand
// is ignored, so the result is then simply the other operand.
func (i interval) Encompass(other intervalset.Interval) intervalset.Interval {
	j := other.(interval)
	if i.IsZero() {
		return j
	}
	if j.IsZero() {
		return i
	}
	lo := min(i.start, j.start)
	hi := max(i.end, j.end)
	return interval{start: lo, end: hi}
}

// hasOverlappingRanges checks that no two mappings share container ids or host ids. Mappings
// are examined in order; a mapping is reported when its container range or its host range
// overlaps the corresponding range of any mapping that precedes it.
//
// It returns nil when there are no conflicts, and otherwise an error wrapping
// ErrInvalidMappings that lists every conflicting mapping as "container:host:size".
func hasOverlappingRanges(mappings []idtools.IDMap) error {
	seenHost := intervalset.Empty()
	seenContainer := intervalset.Empty()

	var conflicts []string
	for _, m := range mappings {
		containerRange := interval{start: m.ContainerID, end: m.ContainerID + m.Size}
		hostRange := interval{start: m.HostID, end: m.HostID + m.Size}

		// One hit is enough to flag the mapping, so stop iterating right away.
		clash := false
		flag := func(intervalset.Interval) bool {
			clash = true
			return false
		}

		seenContainer.IntervalsBetween(containerRange, flag)
		seenContainer.Add(intervalset.NewSet([]intervalset.Interval{containerRange}))

		seenHost.IntervalsBetween(hostRange, flag)
		seenHost.Add(intervalset.NewSet([]intervalset.Interval{hostRange}))

		// Report each mapping at most once, whichever side clashed.
		if clash {
			conflicts = append(conflicts, fmt.Sprintf("%v:%v:%v", m.ContainerID, m.HostID, m.Size))
		}
	}

	switch len(conflicts) {
	case 0:
		return nil
	case 1:
		return fmt.Errorf("the specified UID and/or GID mapping %s conflicts with other mappings: %w", conflicts[0], ErrInvalidMappings)
	default:
		return fmt.Errorf("the specified UID and/or GID mappings %s conflict with other mappings: %w", strings.Join(conflicts, ", "), ErrInvalidMappings)
	}
}
