package kong

import (
	"errors"
	"fmt"
	"os"
	"reflect"
	"sort"
	"strconv"
	"strings"
)

// Path records the nodes and parsed values from the current command-line.
//
// A Context keeps one Path per element matched while tracing the arguments,
// plus any elements later appended by Resolve.
type Path struct {
	// Parent is normally the node under which this element was matched. It is
	// not set for the root application element or for flag elements.
	Parent *Node

	// One of these will be non-nil; it identifies what kind of element this is.
	App        *Application
	Positional *Positional
	Flag       *Flag
	Argument   *Argument
	Command    *Command

	// Flags added by this node (e.g. the flags declared on a command or
	// argument branch that became reachable at this point).
	Flags []*Flag

	// True if this Path element was created as the result of a resolver.
	Resolved bool

	// Remaining tokens after this node: a snapshot of the scanner taken right
	// after this element was matched, exposed via Remainder().
	remainder []Token
}

// Node returns the grammar node this element refers to: the application root,
// an argument branch or a command. Flag and positional elements are not nodes,
// so nil is returned for them.
func (p *Path) Node() *Node {
	if p.App != nil {
		return p.App.Node
	}
	if p.Argument != nil {
		return p.Argument
	}
	if p.Command != nil {
		return p.Command
	}
	return nil
}

// Visitable returns whichever element this path entry holds, checked in the
// order application, argument, command, flag, positional. An entry holding
// nothing yields a nil interface (never a typed nil pointer).
func (p *Path) Visitable() Visitable {
	if p.App != nil {
		return p.App
	}
	if p.Argument != nil {
		return p.Argument
	}
	if p.Command != nil {
		return p.Command
	}
	if p.Flag != nil {
		return p.Flag
	}
	if p.Positional != nil {
		return p.Positional
	}
	return nil
}

// Remainder returns the remaining unparsed args after this Path element,
// rendered back to strings. The result is never nil, even when empty.
func (p *Path) Remainder() []string {
	rendered := make([]string, len(p.remainder))
	for i, tok := range p.remainder {
		rendered[i] = tok.String()
	}
	return rendered
}

// Context contains the current parse context.
//
// It is produced by Trace and carries both the parse trace and the state
// needed to resolve, validate, apply and run it.
type Context struct {
	*Kong
	// A trace through parsed nodes.
	Path []*Path
	// Original command-line arguments.
	Args []string
	// Error that occurred during trace, if any.
	Error error

	// Parsed values keyed by grammar value; entries are created lazily by
	// getValue and copied onto targets by Apply.
	values    map[*Value]reflect.Value // Temporary values during tracing.
	// Context-specific bindings added via Bind, BindTo and the provider helpers.
	bindings  bindings
	resolvers []Resolver // Extra context-specific resolvers.
	// Token scanner over Args, consumed by trace.
	scan      *Scanner
}

// Trace path of "args" through the grammar tree.
//
// The returned Context will include a Path of all commands, arguments, positionals and flags.
// The first Path element is always the application root.
//
// Trace itself never fails: any error hit while tracing is stored in the
// returned Context's Error field and the returned error is always nil.
//
// This just constructs a new trace. To fully apply the trace you must call Reset(), Resolve(),
// Validate() and Apply().
func Trace(k *Kong, args []string) (*Context, error) {
	scanner := Scan(args...).AllowHyphenPrefixedParameters(k.allowHyphenated)
	root := &Path{
		App:       k.Model,
		Flags:     k.Model.Flags,
		remainder: scanner.PeekAll(),
	}
	ctx := &Context{
		Kong:     k,
		Args:     args,
		Path:     []*Path{root},
		values:   map[*Value]reflect.Value{},
		scan:     scanner,
		bindings: bindings{},
	}
	ctx.Error = ctx.trace(ctx.Model.Node)
	return ctx, nil
}

// Bind adds bindings to the Context. Each value becomes injectable by its
// own type into functions invoked through the Context.
func (c *Context) Bind(args ...any) {
	c.bindings.add(args...)
}

// BindTo adds a binding to the Context, registering impl under the interface
// type that iface points to.
//
// This will typically have to be called like so:
//
//	BindTo(impl, (*MyInterface)(nil))
func (c *Context) BindTo(impl, iface any) {
	c.bindings.addTo(impl, iface)
}

// BindToProvider allows binding of provider functions.
//
// This is useful when the Run() function of different commands require different values that may
// not all be initialisable from the main() function.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
// where ... will be recursively injected with bound values.
func (c *Context) BindToProvider(provider any) error {
	const singleton = false
	return c.bindings.addProvider(provider, singleton)
}

// BindSingletonProvider allows binding of provider functions.
// The provider will be called once and the result cached.
//
// "provider" must be a function with the signature func(...) (T, error) or func(...) T,
// where ... will be recursively injected with bound values.
func (c *Context) BindSingletonProvider(provider any) error {
	const singleton = true
	return c.bindings.addProvider(provider, singleton)
}

// Value returns the value parsed for a particular path element. Only
// positional, flag and argument elements carry values; any other element
// causes a panic.
func (c *Context) Value(path *Path) reflect.Value {
	var key *Value
	switch {
	case path.Positional != nil:
		key = path.Positional
	case path.Flag != nil:
		key = path.Flag.Value
	case path.Argument != nil:
		key = path.Argument.Argument
	default:
		panic("can only retrieve value for flag, argument or positional")
	}
	return c.values[key]
}

// Selected returns the last command or argument node in the trace, or nil if
// the trace contains neither. When a single element carries both, the command
// takes precedence.
func (c *Context) Selected() *Node {
	for i := len(c.Path) - 1; i >= 0; i-- {
		p := c.Path[i]
		if p.Command != nil {
			return p.Command
		}
		if p.Argument != nil {
			return p.Argument
		}
	}
	return nil
}

// Empty returns true if there were no arguments provided: every element in
// the trace is either the application root or was supplied by a resolver.
func (c *Context) Empty() bool {
	for _, p := range c.Path {
		if p.App != nil || p.Resolved {
			continue
		}
		return false
	}
	return true
}

// Validate the current context.
//
// The checks below run in order and the first failure is returned:
//
//  1. Enum constraints on every *Value and *Flag visited in the model that is
//     not required, has a default, or has one of its environment variables set.
//  2. Validate() hooks on the target of every traced Path element (see
//     isValidatable); errors are prefixed with the element's summary when it
//     has one.
//  3. Each resolver's own Validate against the model.
//  4. Enum constraints on flags and positionals present in the trace, and
//     missing required flags for every Path element.
//  5. For the selected node (or the application root when nothing is
//     selected): missing child commands/arguments, missing required
//     positionals, xor/and flag group violations, and an unset required
//     argument node.
func (c *Context) Validate() error { //nolint: gocyclo
	err := Visit(c.Model, func(node Visitable, next Next) error {
		switch node := node.(type) {
		case *Value:
			ok := atLeastOneEnvSet(node.Tag.Envs)
			// Required values with no default and no env var are skipped here;
			// presumably they are left to the missing-input checks below
			// (assumption — confirm).
			if node.Enum != "" && (!node.Required || node.HasDefault || (len(node.Tag.Envs) != 0 && ok)) {
				if err := checkEnum(node, node.Target); err != nil {
					return err
				}
			}

		case *Flag:
			ok := atLeastOneEnvSet(node.Tag.Envs)
			// Same eligibility rule as for *Value above, applied to the flag's value.
			if node.Enum != "" && (!node.Required || node.HasDefault || (len(node.Tag.Envs) != 0 && ok)) {
				if err := checkEnum(node.Value, node.Target); err != nil {
					return err
				}
			}
		}
		return next(nil)
	})
	if err != nil {
		return err
	}
	// Run user-defined Validate hooks on the target of each traced element.
	for _, el := range c.Path {
		var (
			value reflect.Value
			desc  string
		)
		switch node := el.Visitable().(type) {
		case *Value:
			value = node.Target
			desc = node.ShortSummary()

		case *Flag:
			value = node.Target
			desc = node.ShortSummary()

		case *Application:
			value = node.Target
			desc = ""

		case *Node:
			value = node.Target
			desc = node.Path()
		}
		if validate := isValidatable(value); validate != nil {
			if err := validate.Validate(c); err != nil {
				if desc != "" {
					return fmt.Errorf("%s: %w", desc, err)
				}
				return err
			}
		}
	}
	// Let each resolver validate itself against the model.
	for _, resolver := range c.combineResolvers() {
		if err := resolver.Validate(c.Model); err != nil {
			return err
		}
	}
	// Check enums of values actually present in the trace, and that every
	// required flag introduced by each element was provided.
	for _, path := range c.Path {
		var value *Value
		switch {
		case path.Flag != nil:
			value = path.Flag.Value

		case path.Positional != nil:
			value = path.Positional
		}
		if value != nil && value.Tag.Enum != "" {
			if err := checkEnum(value, value.Target); err != nil {
				return err
			}
		}
		if err := checkMissingFlags(path.Flags); err != nil {
			return err
		}
	}
	// Check the terminal node.
	node := c.Selected()
	if node == nil {
		node = c.Model.Node
	}

	// Find deepest positional argument so we can check if all required positionals have been provided.
	// NOTE(review): this scans positionals of every node in the trace, while the
	// result is compared against the selected node's positionals only.
	positionals := 0
	for _, path := range c.Path {
		if path.Positional != nil {
			positionals = path.Positional.Position + 1
		}
	}

	if err := checkMissingChildren(node); err != nil {
		return err
	}
	if err := checkMissingPositionals(positionals, node.Positional); err != nil {
		return err
	}
	if err := checkXorDuplicatedAndAndMissing(c.Path); err != nil {
		return err
	}

	// An argument node that was selected must itself have been given a value
	// when it is required.
	if node.Type == ArgumentNode {
		value := node.Argument
		if value.Required && !value.Set {
			return fmt.Errorf("%s is required", node.Summary())
		}
	}
	return nil
}

// Flags returns the accumulated available flags: the Flags of every Path
// element concatenated in trace order. The result is nil when no element
// contributes any flag.
func (c *Context) Flags() (flags []*Flag) {
	for i := range c.Path {
		flags = append(flags, c.Path[i].Flags...)
	}
	return flags
}

// Command returns the full command path: command names, plus "<name>"
// placeholders for positionals and arguments, joined by single spaces.
// Flags and the application root do not appear.
func (c *Context) Command() string {
	parts := make([]string, 0, len(c.Path))
	for _, trace := range c.Path {
		var part string
		switch {
		case trace.Positional != nil:
			part = "<" + trace.Positional.Name + ">"
		case trace.Argument != nil:
			part = "<" + trace.Argument.Name + ">"
		case trace.Command != nil:
			part = trace.Command.Name
		default:
			continue
		}
		parts = append(parts, part)
	}
	return strings.Join(parts, " ")
}

// AddResolver adds a context-specific resolver. Context resolvers are
// consulted after the application-level ones.
//
// This is most useful in the BeforeResolve() hook.
func (c *Context) AddResolver(resolver Resolver) {
	updated := append(c.resolvers, resolver)
	c.resolvers = updated
}

// FlagValue returns the set value of a flag if it was encountered and exists, or its default value.
//
// Only the first trace entry for flag is consulted. If it has no parsed value,
// the flag's target is returned when valid, otherwise its DefaultValue.
func (c *Context) FlagValue(flag *Flag) any {
	for _, trace := range c.Path {
		if trace.Flag != flag {
			continue
		}
		if parsed, found := c.values[trace.Flag.Value]; found {
			return parsed.Interface()
		}
		break
	}
	if !flag.Target.IsValid() {
		return flag.DefaultValue.Interface()
	}
	return flag.Target.Interface()
}

// Reset recursively resets values to defaults (as specified in the grammar) or the zero value.
//
// Every *Value reachable from the model root is reset; the first reset error
// is handed to the visitor's next callback.
func (c *Context) Reset() error {
	return Visit(c.Model.Node, func(node Visitable, next Next) error {
		var resetErr error
		if value, isValue := node.(*Value); isValue {
			resetErr = value.Reset()
		}
		return next(resetErr)
	})
}

// endParsing drains every remaining token from the scanner and pushes them
// back, in their original order, as positional-argument tokens, so that no
// further flag or command interpretation happens.
func (c *Context) endParsing() {
	var rest []string
	for token := c.scan.Pop(); token.Type != EOLToken; token = c.scan.Pop() {
		rest = append(rest, token.String())
	}
	// The scanner is a stack: push the last token first so the first ends up on top.
	for i := len(rest) - 1; i >= 0; i-- {
		c.scan.PushTyped(rest[i], PositionalArgumentToken)
	}
}

// trace walks the scanner's remaining tokens against the grammar rooted at
// node, appending a Path element to c.Path for every positional, flag,
// command and argument it matches. When a command or argument branch
// matches, tracing recurses into that branch and this call returns the
// recursive result.
//
// Untyped tokens are first classified ("-", "--", long flag, short flag or
// plain positional) and pushed back onto the scanner with a concrete type;
// the next loop iteration then handles the typed token. Once input is
// exhausted, maybeSelectDefault may append the node's default command.
//
//nolint:maintidx
func (c *Context) trace(node *Node) (err error) { //nolint: gocyclo
	// Index into node.Positional of the next positional to fill.
	positional := 0
	node.Active = true

	flags := []*Flag{}
	flagNode := node
	if node.DefaultCmd != nil && node.DefaultCmd.Tag.Default == "withargs" {
		// Add flags of the default command if the current node has one
		// and that default command allows args / flags without explicitly
		// naming the command on the CLI.
		flagNode = node.DefaultCmd
	}
	for _, group := range flagNode.AllFlags(false) {
		flags = append(flags, group...)
	}

	// A passthrough node treats everything after it as positional arguments.
	if node.Passthrough {
		c.endParsing()
	}

	for !c.scan.Peek().IsEOL() {
		token := c.scan.Peek()
		switch token.Type {
		case UntypedToken:
			// Classify the raw token and push it back typed; it is consumed on
			// a later iteration.
			switch v := token.Value.(type) {
			case string:

				// Cases are evaluated top to bottom with default last, so a lone
				// "-" is caught before the short-flag prefix check and falls
				// through to be treated as a positional.
				switch {
				case v == "-":
					fallthrough
				default: //nolint
					c.scan.Pop()
					c.scan.PushTyped(token.Value, PositionalArgumentToken)

				// Indicates end of parsing. All remaining arguments are treated as positional arguments only.
				case v == "--":
					// endParsing re-pushes the "--" token itself as a positional too.
					c.endParsing()

					// Pop the -- token unless the next positional argument accepts passthrough arguments.
					if !(positional < len(node.Positional) && node.Positional[positional].Passthrough) {
						c.scan.Pop()
					}

				// Long flag.
				case strings.HasPrefix(v, "--"):
					c.scan.Pop()
					// Parse it and push the tokens.
					// "--name=value" becomes a FlagToken "name" followed by a
					// FlagValueToken "value" (value pushed first; the scanner is a stack).
					parts := strings.SplitN(v[2:], "=", 2)
					if len(parts) > 1 {
						c.scan.PushTyped(parts[1], FlagValueToken)
					}
					c.scan.PushTyped(parts[0], FlagToken)

				// Short flag.
				case strings.HasPrefix(v, "-"):
					c.scan.Pop()
					// Note: tokens must be pushed in reverse order.
					// "-abc" becomes ShortFlagToken "a" followed by a
					// ShortFlagTailToken "bc".
					if tail := v[2:]; tail != "" {
						c.scan.PushTyped(tail, ShortFlagTailToken)
					}
					c.scan.PushTyped(v[1:2], ShortFlagToken)
				}
			default:
				// Non-string untyped values are always positional.
				c.scan.Pop()
				c.scan.PushTyped(token.Value, PositionalArgumentToken)
			}

		case ShortFlagTailToken:
			// Split the next character of a clustered short-flag tail off as its
			// own short flag. Whether the flag consumes the rest as its value is
			// decided by the flag's parser (not visible here).
			c.scan.Pop()
			// Note: tokens must be pushed in reverse order.
			if tail := token.String()[1:]; tail != "" {
				c.scan.PushTyped(tail, ShortFlagTailToken)
			}
			c.scan.PushTyped(token.String()[0:1], ShortFlagToken)

		case FlagToken:
			if err := c.parseFlag(flags, token.String()); err != nil {
				// Unknown flags are handed to a positional in PassThroughModeAll
				// instead of failing.
				if isUnknownFlagError(err) && positional < len(node.Positional) && node.Positional[positional].PassthroughMode == PassThroughModeAll {
					c.scan.Pop()
					c.scan.PushTyped(token.String(), PositionalArgumentToken)
				} else {
					return err
				}
			}

		case ShortFlagToken:
			if err := c.parseFlag(flags, token.String()); err != nil {
				// Same passthrough fallback as for long flags.
				if isUnknownFlagError(err) && positional < len(node.Positional) && node.Positional[positional].PassthroughMode == PassThroughModeAll {
					c.scan.Pop()
					c.scan.PushTyped(token.String(), PositionalArgumentToken)
				} else {
					return err
				}
			}

		case FlagValueToken:
			// A flag value left on top of the stack was not consumed by any flag.
			return fmt.Errorf("unexpected flag argument %q", token.Value)

		case PositionalArgumentToken:
			candidates := []string{}

			// Ensure we've consumed all positional arguments.
			if positional < len(node.Positional) {
				arg := node.Positional[positional]

				if arg.Passthrough {
					c.endParsing()
				}

				arg.Active = true
				err := arg.Parse(c.scan, c.getValue(arg))
				if err != nil {
					return err
				}
				c.Path = append(c.Path, &Path{
					Parent:     node,
					Positional: arg,
					remainder:  c.scan.PeekAll(),
				})
				positional++
				break
			}

			// Assign token value to a branch name if tagged as an alias
			// An alias will be ignored in the case of an existing command
			// NOTE(review): this rewrites only the local token copy used for the
			// command comparison below (assuming Token is a value type), and
			// `break` leaves only the alias loop of the current branch.
			cmds := make(map[string]bool)
			for _, branch := range node.Children {
				if branch.Type == CommandNode {
					cmds[branch.Name] = true
				}
			}
			for _, branch := range node.Children {
				for _, a := range branch.Aliases {
					_, ok := cmds[a]
					if token.Value == a && !ok {
						token.Value = branch.Name
						break
					}
				}
			}

			// After positional arguments have been consumed, check commands next...
			for _, branch := range node.Children {
				// Visible command names feed the "did you mean" suggestions.
				if branch.Type == CommandNode && !branch.Hidden {
					candidates = append(candidates, branch.Name)
				}
				if branch.Type == CommandNode && branch.Name == token.Value {
					c.scan.Pop()
					c.Path = append(c.Path, &Path{
						Parent:    node,
						Command:   branch,
						Flags:     branch.Flags,
						remainder: c.scan.PeekAll(),
					})
					return c.trace(branch)
				}
			}

			// Finally, check arguments.
			// The first argument branch whose value parses successfully wins.
			for _, branch := range node.Children {
				if branch.Type == ArgumentNode {
					arg := branch.Argument
					if err := arg.Parse(c.scan, c.getValue(arg)); err == nil {
						c.Path = append(c.Path, &Path{
							Parent:    node,
							Argument:  branch,
							Flags:     branch.Flags,
							remainder: c.scan.PeekAll(),
						})
						return c.trace(branch)
					}
				}
			}

			// If there is a default command that allows args and nothing else
			// matches, take the branch of the default command
			// (the token is not consumed here; the default command's trace sees it).
			if node.DefaultCmd != nil && node.DefaultCmd.Tag.Default == "withargs" {
				c.Path = append(c.Path, &Path{
					Parent:    node,
					Command:   node.DefaultCmd,
					Flags:     node.DefaultCmd.Flags,
					remainder: c.scan.PeekAll(),
				})
				return c.trace(node.DefaultCmd)
			}

			return findPotentialCandidates(token.String(), candidates, "unexpected argument %s", token)
		default:
			return fmt.Errorf("unexpected token %s", token)
		}
	}
	return c.maybeSelectDefault(flags, node)
}

// IgnoreDefault can be implemented by flags that want to be applied before any default commands.
//
// It is a marker interface. When a flag whose target implements it has been
// set, the default command is not selected at the end of tracing (for
// example, so that help is shown for the current node rather than for its
// default command).
type IgnoreDefault interface {
	IgnoreDefault()
}

// End of the line, check for a default command, but only if we're not displaying help,
// otherwise we'd only ever display the help for the default command.
//
// maybeSelectDefault appends node's default command (if any) to the trace,
// unless one of flags has been set and its target implements IgnoreDefault.
func (c *Context) maybeSelectDefault(flags []*Flag, node *Node) error {
	for _, flag := range flags {
		_, ignoresDefault := flag.Target.Interface().(IgnoreDefault)
		if ignoresDefault && flag.Set {
			return nil
		}
	}
	if node.DefaultCmd == nil {
		return nil
	}
	// NOTE(review): Parent is the default command itself here, whereas the
	// "withargs" path in trace records the enclosing node — confirm intended.
	c.Path = append(c.Path, &Path{
		Parent:    node.DefaultCmd,
		Command:   node.DefaultCmd,
		Flags:     node.DefaultCmd.Flags,
		remainder: c.scan.PeekAll(),
	})
	return nil
}

// Resolve walks through the traced path, applying resolvers to any unset flags.
//
// For each flag attached to each Path element that has no parsed value yet
// (no entry in c.values), every resolver from combineResolvers is consulted
// in order and the last non-nil result wins. That result is parsed as a
// flag-value token into c.values, and a Resolved Path element is recorded.
// All resolved elements are appended to c.Path after the existing ones.
// Resolver errors are prefixed with the flag's short summary.
func (c *Context) Resolve() error {
	resolvers := c.combineResolvers()
	if len(resolvers) == 0 {
		return nil
	}

	inserted := []*Path{}
	for _, path := range c.Path {
		for _, flag := range path.Flags {
			// Flag has already been set on the command-line.
			if _, ok := c.values[flag.Value]; ok {
				continue
			}

			// Pick the last resolved value.
			var selected any
			for _, resolver := range resolvers {
				s, err := resolver.Resolve(c, path, flag)
				if err != nil {
					return fmt.Errorf("%s: %w", flag.ShortSummary(), err)
				}
				if s == nil {
					continue
				}
				selected = s
			}

			if selected == nil {
				continue
			}

			// Feed the resolved value through the flag's normal parser via a
			// throwaway scanner holding a single flag-value token.
			scan := Scan().PushTyped(selected, FlagValueToken)
			// Defensive: the check above already guarantees no entry exists.
			delete(c.values, flag.Value)
			err := flag.Parse(scan, c.getValue(flag.Value))
			if err != nil {
				return err
			}
			inserted = append(inserted, &Path{
				Flag:      flag,
				Resolved:  true,
				remainder: c.scan.PeekAll(),
			})
		}
	}
	c.Path = append(c.Path, inserted...)
	return nil
}

// combineResolvers returns application-level resolvers followed by context
// resolvers, in a freshly allocated (never nil) slice.
func (c *Context) combineResolvers() []Resolver {
	combined := make([]Resolver, 0, len(c.Kong.resolvers)+len(c.resolvers))
	combined = append(combined, c.Kong.resolvers...)
	return append(combined, c.resolvers...)
}

// getValue returns the scratch reflect.Value holding the parsed value for
// value, creating and caching it on first use. New values match the target's
// type; pointers, slices and maps are pre-initialised to non-nil empties.
func (c *Context) getValue(value *Value) reflect.Value {
	if existing, found := c.values[value]; found {
		return existing
	}
	fresh := reflect.New(value.Target.Type()).Elem()
	typ := fresh.Type()
	switch fresh.Kind() {
	case reflect.Ptr:
		fresh.Set(reflect.New(typ.Elem()))
	case reflect.Slice:
		fresh.Set(reflect.MakeSlice(typ, 0, 0))
	case reflect.Map:
		fresh.Set(reflect.MakeMap(typ))
	default:
	}
	c.values[value] = fresh
	return fresh
}

// ApplyDefaults if they are not already set.
//
// Every flag value, argument-node value and standalone *Value reachable from
// the model root gets ApplyDefault called; the first error aborts the walk.
func (c *Context) ApplyDefaults() error {
	return Visit(c.Model.Node, func(node Visitable, next Next) error {
		var target *Value
		switch n := node.(type) {
		case *Flag:
			target = n.Value
		case *Node:
			target = n.Argument
		case *Value:
			target = n
		}
		if target == nil {
			return next(nil)
		}
		if err := target.ApplyDefault(); err != nil {
			return err
		}
		return next(nil)
	})
}

// Apply traced context to the target grammar.
//
// Each parsed value in the trace is copied onto its grammar value. The
// returned string is the command path (command names and "<name>" for
// arguments/positionals) separated by spaces; the error is always nil.
// A Path element holding none of the known kinds causes a panic.
func (c *Context) Apply() (string, error) {
	segments := []string{}
	for _, trace := range c.Path {
		var target *Value
		switch {
		case trace.App != nil:
			// The application root contributes neither a segment nor a value.
			continue
		case trace.Argument != nil:
			segments = append(segments, "<"+trace.Argument.Name+">")
			target = trace.Argument.Argument
		case trace.Command != nil:
			segments = append(segments, trace.Command.Name)
			continue
		case trace.Flag != nil:
			target = trace.Flag.Value
		case trace.Positional != nil:
			segments = append(segments, "<"+trace.Positional.Name+">")
			target = trace.Positional
		default:
			panic("unsupported path ?!")
		}
		if target != nil {
			target.Apply(c.getValue(target))
		}
	}
	return strings.Join(segments, " "), nil
}

// flipBoolValue negates a settable bool in place, following pointers. A nil
// pointer is left alone. Any other kind yields an error naming the type.
func flipBoolValue(value reflect.Value) error {
	switch value.Kind() {
	case reflect.Bool:
		value.SetBool(!value.Bool())
		return nil
	case reflect.Ptr:
		if value.IsNil() {
			return nil
		}
		return flipBoolValue(value.Elem())
	default:
		return fmt.Errorf("cannot negate a value of %s", value.Type().String())
	}
}

// parseFlag matches the flag token match (as written, including its dash
// prefix, e.g. "--name" or "-n") against flags by long name, short name,
// alias, or the negated spelling from negatableFlagName.
//
// On a match it pops the flag token, parses the flag's value from the scanner
// into c.values, flips the parsed bool for negated flags (and applies it), and
// appends a Path element for the flag. Parse errors caused by an unexpected
// flag-like token are extended with a "perhaps try --flag=..." hint.
//
// If nothing matches, an *unknownFlagError is returned. It wraps an error
// suggesting close matches among every spelling seen.
func (c *Context) parseFlag(flags []*Flag, match string) (err error) {
	candidates := []string{}

	for _, flag := range flags {
		long := "--" + flag.Name
		matched := long == match
		candidates = append(candidates, long)
		if flag.Short != 0 {
			short := "-" + string(flag.Short)
			matched = matched || (short == match)
			candidates = append(candidates, short)
		}
		for _, alias := range flag.Aliases {
			alias = "--" + alias
			matched = matched || (alias == match)
			candidates = append(candidates, alias)
		}

		neg := negatableFlagName(flag.Name, flag.Tag.Negatable)
		if !matched && match != neg {
			continue
		}
		// Found a matching flag.
		c.scan.Pop()
		// NOTE(review): Negated is only ever set to true here and never reset,
		// so a later positive occurrence of the same flag still flips — confirm.
		if match == neg && flag.Tag.Negatable != "" {
			flag.Negated = true
		}
		err := flag.Parse(c.scan, c.getValue(flag.Value))
		if err != nil {
			var expected *expectedError
			if errors.As(err, &expected) && expected.token.InferredType().IsAny(FlagToken, ShortFlagToken) {
				return fmt.Errorf("%s; perhaps try %s=%q?", err.Error(), flag.ShortSummary(), expected.token)
			}
			return err
		}
		if flag.Negated {
			value := c.getValue(flag.Value)
			err := flipBoolValue(value)
			if err != nil {
				return err
			}
			flag.Value.Apply(value)
		}
		c.Path = append(c.Path, &Path{
			Flag:      flag,
			remainder: c.scan.PeekAll(),
		})
		return nil
	}
	return &unknownFlagError{Cause: findPotentialCandidates(match, candidates, "unknown flag %s", match)}
}

// isUnknownFlagError reports whether err is, or wraps, an *unknownFlagError.
func isUnknownFlagError(err error) bool {
	target := (*unknownFlagError)(nil)
	return errors.As(err, &target)
}

// unknownFlagError marks a flag-lookup failure so callers can tell it apart
// from other parse errors; its message is that of Cause.
type unknownFlagError struct {
	Cause error
}

// Unwrap exposes the underlying cause to errors.Is / errors.As.
func (e *unknownFlagError) Unwrap() error { return e.Cause }

// Error returns the cause's message unchanged.
func (e *unknownFlagError) Error() string { return e.Cause.Error() }

// Call an arbitrary function filling arguments with bound values.
//
// Injectable values are the Kong-level bindings, binds, the Context itself,
// and bindings registered on this Context, merged in that order.
func (c *Context) Call(fn any, binds ...any) (out []any, err error) {
	injector := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings)
	target := reflect.ValueOf(fn)
	return callAnyFunction(target, injector)
}

// RunNode calls the Run() method on an arbitrary node.
//
// This is useful in conjunction with Visit(), for dynamically running commands.
//
// Any passed values will be bindable to arguments of the target Run() method. Additionally,
// all parent nodes in the command structure will be bound.
//
// Starting at node and walking up through its parents, every node whose target
// has a Run() method contributes one call. Calls happen in that order (node
// first, root last) and stop at the first error. Each call's bindings contain:
//   - the Kong and Context bindings;
//   - binds and the Context itself;
//   - a pointer to the target of that node and of each of its ancestors;
//   - every Provide*() method declared on those targets (value or pointer receiver).
//
// If no Run() method exists anywhere in the hierarchy an error is returned
// naming the selected command, or node itself when nothing is selected.
func (c *Context) RunNode(node *Node, binds ...any) (err error) {
	type targetMethod struct {
		node   *Node
		method reflect.Value
		binds  bindings
	}
	start := node
	methodBinds := c.Kong.bindings.clone().add(binds...).add(c).merge(c.bindings)
	methods := []targetMethod{}
	for ; node != nil; node = node.Parent {
		method := getMethod(node.Target, "Run")
		// Clone so the set already captured for a descendant is not modified by
		// the additions below; this level still starts from everything
		// accumulated so far.
		methodBinds = methodBinds.clone()
		for p := node; p != nil; p = p.Parent {
			methodBinds = methodBinds.add(p.Target.Addr().Interface())
			// Try value and pointer to value.
			for _, recv := range []reflect.Value{p.Target, p.Target.Addr()} {
				t := recv.Type()
				for i := 0; i < recv.NumMethod(); i++ {
					methodt := t.Method(i)
					if !strings.HasPrefix(methodt.Name, "Provide") {
						continue
					}
					provider := recv.Method(i)
					if perr := methodBinds.addProvider(provider.Interface(), false /* singleton */); perr != nil {
						return fmt.Errorf("%s.%s: %w", t.Name(), methodt.Name, perr)
					}
				}
			}
		}
		if method.IsValid() {
			methods = append(methods, targetMethod{node, method, methodBinds})
		}
	}
	if len(methods) == 0 {
		// c.Selected() is nil when RunNode is invoked directly with nothing
		// selected; fall back to the node we were asked to run.
		selected := c.Selected()
		if selected == nil {
			selected = start
		}
		if selected == nil {
			return errors.New("no Run() method found: no command selected")
		}
		return fmt.Errorf("no Run() method found in hierarchy of %s", selected.Summary())
	}
	for _, method := range methods {
		if err = callFunction(method.method, method.binds); err != nil {
			return err
		}
	}
	return nil
}

// Run executes the Run() method on the selected command, which must exist.
//
// When no command or argument is selected, the application root is used
// instead, provided it has a Run() method. The AfterRun hook always runs once
// a target is found, and its error is joined with the Run() error.
//
// Any passed values will be bindable to arguments of the target Run() method. Additionally,
// all parent nodes in the command structure will be bound.
func (c *Context) Run(binds ...any) (err error) {
	node := c.Selected()
	if node == nil && len(c.Path) > 0 {
		if root := c.Path[0].Node(); root.Type == ApplicationNode && getMethod(root.Target, "Run").IsValid() {
			node = root
		}
	}
	if node == nil {
		return fmt.Errorf("no command selected")
	}
	runErr := c.RunNode(node, binds...)
	hookErr := c.Kong.applyHook(c, "AfterRun")
	return errors.Join(runErr, hookErr)
}

// PrintUsage to Kong's stdout.
//
// If summary is true, a summarised version of the help will be output.
// The configured help options are copied, so they are not modified.
func (c *Context) PrintUsage(summary bool) error {
	opts := c.helpOptions
	opts.Summary = summary
	return c.printHelp(opts)
}

// printHelp renders help for this context with the given options, always
// using Kong's configured help value formatter.
func (c *Context) printHelp(options HelpOptions) error {
	formatter := c.Kong.helpFormatter
	options.ValueFormatter = formatter
	return c.help(options, c)
}

// checkMissingFlags returns an error listing the required flags that were
// not set. Entries are sorted, so the message is deterministic.
//
// Flags without xor/and groups are reported individually. Flags in xor groups
// are reported together as "a or b", only when no member of the group was set.
// Flags in and groups are reported together as "a and b". A group is only
// reported when more than one of its required members is missing.
//
// Side effect: for every flag in an "and" group, flag.Required is overwritten
// with whether any member of that group is required (via
// getRequiredAndGroupMap); for a flag in several and groups the last one wins.
func checkMissingFlags(flags []*Flag) error {
	xorGroupSet := map[string]bool{}
	xorGroup := map[string][]string{}
	andGroupSet := map[string]bool{}
	andGroup := map[string][]string{}
	missing := []string{}
	andGroupRequired := getRequiredAndGroupMap(flags)
	for _, flag := range flags {
		for _, and := range flag.And {
			flag.Required = andGroupRequired[and]
		}
		if flag.Set {
			for _, xor := range flag.Xor {
				xorGroupSet[xor] = true
			}
			for _, and := range flag.And {
				andGroupSet[and] = true
			}
		}
		if !flag.Required || flag.Set {
			continue
		}
		if len(flag.Xor) > 0 || len(flag.And) > 0 {
			// xorGroupSet only reflects flags seen so far; the loop after this
			// one re-checks it once every flag has been processed.
			for _, xor := range flag.Xor {
				if xorGroupSet[xor] {
					continue
				}
				xorGroup[xor] = append(xorGroup[xor], flag.Summary())
			}
			for _, and := range flag.And {
				andGroup[and] = append(andGroup[and], flag.Summary())
			}
		} else {
			missing = append(missing, flag.Summary())
		}
	}
	for xor, flags := range xorGroup {
		if !xorGroupSet[xor] && len(flags) > 1 {
			missing = append(missing, strings.Join(flags, " or "))
		}
	}
	for _, flags := range andGroup {
		if len(flags) > 1 {
			missing = append(missing, strings.Join(flags, " and "))
		}
	}

	if len(missing) == 0 {
		return nil
	}

	// Map iteration above is unordered; sort for a stable message.
	sort.Strings(missing)

	return fmt.Errorf("missing flags: %s", strings.Join(missing, ", "))
}

// getRequiredAndGroupMap returns the set of "and" group names that contain at
// least one required flag.
func getRequiredAndGroupMap(flags []*Flag) map[string]bool {
	required := make(map[string]bool)
	for _, flag := range flags {
		if !flag.Required {
			continue
		}
		for _, group := range flag.And {
			required[group] = true
		}
	}
	return required
}

// checkMissingChildren reports what node still expects after tracing stopped
// there. Candidates, each quoted, are listed in this order:
//   - its unset required positionals, joined into one entry;
//   - every visible child command;
//   - every visible required argument child.
// At most five candidates are listed, followed by "..." when there are more.
func checkMissingChildren(node *Node) error {
	var unsetPositionals []string
	for _, arg := range node.Positional {
		if arg.Required && !arg.Set {
			unsetPositionals = append(unsetPositionals, arg.Summary())
		}
	}

	var missing []string
	if len(unsetPositionals) != 0 {
		missing = append(missing, strconv.Quote(strings.Join(unsetPositionals, " ")))
	}

	for _, child := range node.Children {
		switch {
		case child.Hidden:
			// Hidden children are never suggested.
		case child.Argument == nil:
			missing = append(missing, strconv.Quote(child.Name))
		case child.Argument.Required:
			missing = append(missing, strconv.Quote(child.Summary()))
		}
	}

	switch count := len(missing); {
	case count == 0:
		return nil
	case count == 1:
		return fmt.Errorf("expected %s", missing[0])
	case count > 5:
		missing = append(missing[:5], "...")
	}
	return fmt.Errorf("expected one of %s", strings.Join(missing, ", "))
}

// If we're missing any positionals and they're required, return an error.
//
// positional is the number of positionals already supplied. Nothing is
// reported when all were supplied or the first missing one is optional.
// Otherwise every remaining positional is listed as "<name>", except those
// with at least one environment variable set.
func checkMissingPositionals(positional int, values []*Value) error {
	if positional >= len(values) || !values[positional].Required {
		return nil
	}

	var missing []string
	for _, arg := range values[positional:] {
		// TODO(aat): Fix hardcoding of these env checks all over the place :\
		if atLeastOneEnvSet(arg.Tag.Envs) {
			continue
		}
		missing = append(missing, "<"+arg.Name+">")
	}
	if len(missing) == 0 {
		return nil
	}
	return fmt.Errorf("missing positional arguments %s", strings.Join(missing, " "))
}

// checkEnum verifies that target holds one of value's enum options.
//
// Pointers are dereferenced (nil pointers pass), and slices/arrays are
// checked element by element. Maps and structs are rejected. Any other kind
// is compared by its "%v" formatting against value.EnumSlice().
func checkEnum(value *Value, target reflect.Value) error {
	kind := target.Kind()
	if kind == reflect.Ptr {
		if target.IsNil() {
			return nil
		}
		return checkEnum(value, target.Elem())
	}
	if kind == reflect.Map || kind == reflect.Struct {
		return errors.New("enum can only be applied to a slice or value")
	}
	if kind == reflect.Slice || kind == reflect.Array {
		for i, n := 0, target.Len(); i < n; i++ {
			if err := checkEnum(value, target.Index(i)); err != nil {
				return err
			}
		}
		return nil
	}

	allowed := value.EnumSlice()
	got := fmt.Sprintf("%v", target)
	quoted := []string{}
	for _, candidate := range allowed {
		if candidate == got {
			return nil
		}
		quoted = append(quoted, fmt.Sprintf("%q", candidate))
	}
	return fmt.Errorf("%s must be one of %s but got %q", value.ShortSummary(), strings.Join(quoted, ","), fmt.Sprintf("%v", target.Interface()))
}

// checkPassthroughArg reports whether target's type can receive passthrough
// arguments, i.e. it is a slice whose element kind is string.
func checkPassthroughArg(target reflect.Value) bool {
	t := target.Type()
	return t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.String
}

// checkXorDuplicatedAndAndMissing runs the xor-duplicate and and-missing
// checks and combines any failures into one error, xor first, separated by
// ", ".
func checkXorDuplicatedAndAndMissing(paths []*Path) error {
	var messages []string
	for _, check := range []func([]*Path) error{checkXorDuplicates, checkAndMissing} {
		if err := check(paths); err != nil {
			messages = append(messages, err.Error())
		}
	}
	if len(messages) == 0 {
		return nil
	}
	return errors.New(strings.Join(messages, ", "))
}

// checkXorDuplicates fails on the first pair of set flags, within a single
// Path element's Flags, that share an xor group.
func checkXorDuplicates(paths []*Path) error {
	for _, path := range paths {
		owner := map[string]*Flag{}
		for _, flag := range path.Flags {
			if !flag.Set {
				continue
			}
			for _, group := range flag.Xor {
				if first := owner[group]; first != nil {
					return fmt.Errorf("--%s and --%s can't be used together", first.Name, flag.Name)
				}
				owner[group] = flag
			}
		}
	}
	return nil
}

// checkAndMissing reports "and" groups in which some, but not all, member
// flags were set. Each Path element's Flags are checked independently, and
// the first element with an incomplete group produces the error.
//
// Groups are collected in a map, whose iteration order is random, so the
// messages are sorted to keep the error text deterministic (as
// checkMissingFlags does).
func checkAndMissing(paths []*Path) error {
	for _, path := range paths {
		missingMsgs := []string{}
		andGroups := map[string][]*Flag{}
		for _, flag := range path.Flags {
			for _, and := range flag.And {
				andGroups[and] = append(andGroups[and], flag)
			}
		}
		for _, flags := range andGroups {
			oneSet := false
			anyUnset := false
			flagNames := make([]string, 0, len(flags))
			for _, flag := range flags {
				flagNames = append(flagNames, flag.Name)
				if flag.Set {
					oneSet = true
				} else {
					anyUnset = true
				}
			}
			if anyUnset && oneSet {
				missingMsgs = append(missingMsgs, fmt.Sprintf("--%s must be used together", strings.Join(flagNames, " and --")))
			}
		}
		if len(missingMsgs) > 0 {
			sort.Strings(missingMsgs)
			return errors.New(strings.Join(missingMsgs, ", "))
		}
	}
	return nil
}

// findPotentialCandidates formats an error from format/args and, when
// haystack has entries that start with needle or are within edit distance 2
// of it, appends a quoted "did you mean" suggestion list.
func findPotentialCandidates(needle string, haystack []string, format string, args ...any) error {
	if len(haystack) == 0 {
		return fmt.Errorf(format, args...)
	}
	var suggestions []string
	for _, candidate := range haystack {
		similar := strings.HasPrefix(candidate, needle) || levenshtein(candidate, needle) <= 2
		if similar {
			suggestions = append(suggestions, fmt.Sprintf("%q", candidate))
		}
	}
	message := fmt.Sprintf(format, args...)
	switch len(suggestions) {
	case 0:
		return fmt.Errorf("%s", message)
	case 1:
		return fmt.Errorf("%s, did you mean %s?", message, suggestions[0])
	default:
		return fmt.Errorf("%s, did you mean one of %s?", message, strings.Join(suggestions, ", "))
	}
}

// validatable is implemented by targets with a context-free Validate hook.
type validatable interface{ Validate() error }

// extendedValidatable is implemented by targets whose Validate hook receives
// the parse Context; it is also the shape isValidatable returns.
type extendedValidatable interface {
	Validate(kctx *Context) error
}

// Proxy a validatable function to the extendedValidatable interface
// (validatableFunc adapts a plain Validate() method value).
type validatableFunc func() error

// Validate calls f; kctx is ignored.
func (f validatableFunc) Validate(kctx *Context) error { return f() }

// isValidatable returns v's Validate hook, or nil if it has none.
//
// Invalid values and nil pointers/slices/maps yield nil. v itself is checked
// first; if it has no hook and is addressable, a pointer to it is checked, so
// pointer-receiver Validate methods are found on addressable values.
func isValidatable(v reflect.Value) extendedValidatable {
	if !v.IsValid() {
		return nil
	}
	switch v.Kind() {
	case reflect.Ptr, reflect.Slice, reflect.Map:
		if v.IsNil() {
			return nil
		}
	}
	switch impl := v.Interface().(type) {
	case validatable:
		return validatableFunc(impl.Validate)
	case extendedValidatable:
		return impl
	}
	if !v.CanAddr() {
		return nil
	}
	return isValidatable(v.Addr())
}

// atLeastOneEnvSet reports whether any of the named environment variables is
// present in the environment (an empty value still counts as present).
// An empty or nil list yields false.
func atLeastOneEnvSet(envs []string) bool {
	for i := range envs {
		if _, present := os.LookupEnv(envs[i]); present {
			return true
		}
	}
	return false
}
