// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package eg

// This file defines the AST rewriting pass.
// Most of it was plundered directly from
// $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).

import (
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"os"
	"reflect"
	"sort"
	"strconv"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
)

// transformItem takes a reflect.Value representing a variable of type ast.Node
// transforms its child elements recursively with apply, and then transforms the
// actual element if it contains an expression.
func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	// don't bother if val is invalid to start with
	if !rv.IsValid() {
		return reflect.Value{}, false, nil
	}

	rv, changed, newEnv := tr.apply(tr.transformItem, rv)

	e := rvToExpr(rv)
	if e == nil {
		return rv, changed, newEnv
	}

	savedEnv := tr.env
	tr.env = make(map[string]ast.Expr) // inefficient!  Use a slice of k/v pairs

	if tr.matchExpr(tr.before, e) {
		if tr.verbose {
			fmt.Fprintf(os.Stderr, "%s matches %s",
				astString(tr.fset, tr.before), astString(tr.fset, e))
			if len(tr.env) > 0 {
				fmt.Fprintf(os.Stderr, " with:")
				for name, ast := range tr.env {
					fmt.Fprintf(os.Stderr, " %s->%s",
						name, astString(tr.fset, ast))
				}
			}
			fmt.Fprintf(os.Stderr, "\n")
		}
		tr.nsubsts++

		// Clone the replacement tree, performing parameter substitution.
		// We update all positions to n.Pos() to aid comment placement.
		rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
			reflect.ValueOf(e.Pos()))
		changed = true
		newEnv = tr.env
	}
	tr.env = savedEnv

	return rv, changed, newEnv
}

// Transform applies the transformation to the specified parsed file,
// whose type information is supplied in info, and returns the number
// of replacements that were made.
//
// It mutates the AST in place (the identity of the root node is
// unchanged), and may add nodes for which no type information is
// available in info.
//
// Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
	if !tr.seenInfos[info] {
		tr.seenInfos[info] = true
		mergeTypeInfo(tr.info, info)
	}
	tr.currentPkg = pkg
	tr.nsubsts = 0

	if tr.verbose {
		fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
		fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
		fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
	}

	o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
	if changed {
		panic("BUG")
	}
	file2 := o.Interface().(*ast.File)

	// By construction, the root node is unchanged.
	if file != file2 {
		panic("BUG")
	}

	// Add any necessary imports.
	// TODO(adonovan): remove no-longer needed imports too.
	if tr.nsubsts > 0 {
		pkgs := make(map[string]*types.Package)
		for obj := range tr.importedObjs {
			pkgs[obj.Pkg().Path()] = obj.Pkg()
		}

		for _, imp := range file.Imports {
			path, _ := strconv.Unquote(imp.Path.Value)
			delete(pkgs, path)
		}
		delete(pkgs, pkg.Path()) // don't import self

		// NB: AddImport may completely replace the AST!
		// It thus renders info and tr.info no longer relevant to file.
		var paths []string
		for path := range pkgs {
			paths = append(paths, path)
		}
		sort.Strings(paths)
		for _, path := range paths {
			astutil.AddImport(tr.fset, file, path)
		}
	}

	tr.currentPkg = nil

	return tr.nsubsts
}

// setValue is a wrapper for x.SetValue(y); it protects
// the caller from panics if x cannot be changed to y.
func setValue(x, y reflect.Value) {
	// don't bother if y is invalid to start with
	if !y.IsValid() {
		return
	}
	defer func() {
		if x := recover(); x != nil {
			if s, ok := x.(string); ok &&
				(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
				// x cannot be set to y - ignore this rewrite
				return
			}
			panic(x)
		}
	}()
	x.Set(y)
}

// Values/types for special cases.
var (
	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))

	identType        = reflect.TypeFor[*ast.Ident]()
	selectorExprType = reflect.TypeFor[*ast.SelectorExpr]()
	objectPtrType    = reflect.TypeFor[*ast.Object]()
	statementType    = reflect.TypeFor[ast.Stmt]()
	positionType     = reflect.TypeFor[token.Pos]()
	scopePtrType     = reflect.TypeFor[*ast.Scope]()
)

// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
// f takes a reflect.Value representing the variable to modify of type ast.Node.
// It returns a reflect.Value containing the transformed value of type ast.Node,
// whether any change was made, and a map of identifiers to ast.Expr (so we can
// do contextually correct substitutions in the parent statements).
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	if !val.IsValid() {
		return reflect.Value{}, false, nil
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead
	if val.Type() == objectPtrType {
		return objectPtrNil, false, nil
	}

	// similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil
	if val.Type() == scopePtrType {
		return scopePtrNil, false, nil
	}

	switch v := reflect.Indirect(val); v.Kind() {
	case reflect.Slice:
		// no possible rewriting of statements.
		if v.Type().Elem() != statementType {
			changed := false
			var envp map[string]ast.Expr
			for i := 0; i < v.Len(); i++ {
				e := v.Index(i)
				o, localchanged, env := f(e)
				if localchanged {
					changed = true
					// we clobber envp here,
					// which means if we have two successive
					// replacements inside the same statement
					// we will only generate the setup for one of them.
					envp = env
				}
				setValue(e, o)
			}
			return val, changed, envp
		}

		// statements are rewritten.
		var out []ast.Stmt
		for i := 0; i < v.Len(); i++ {
			e := v.Index(i)
			o, changed, env := f(e)
			if changed {
				for _, s := range tr.afterStmts {
					t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
					out = append(out, t.(ast.Stmt))
				}
			}
			setValue(e, o)
			out = append(out, e.Interface().(ast.Stmt))
		}
		return reflect.ValueOf(out), false, nil
	case reflect.Struct:
		changed := false
		var envp map[string]ast.Expr
		for i := 0; i < v.NumField(); i++ {
			e := v.Field(i)
			o, localchanged, env := f(e)
			if localchanged {
				changed = true
				envp = env
			}
			setValue(e, o)
		}
		return val, changed, envp
	case reflect.Interface:
		e := v.Elem()
		o, changed, env := f(e)
		setValue(v, o)
		return val, changed, env
	}
	return val, false, nil
}

// subst returns a copy of (replacement) pattern with values from env
// substituted in place of wildcards and pos used as the position of
// tokens from the pattern. If env == nil, subst returns a copy of
// pattern and doesn't change the line number information.
func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
	if !pattern.IsValid() {
		return reflect.Value{}
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead
	if pattern.Type() == objectPtrType {
		return objectPtrNil
	}

	// similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil
	if pattern.Type() == scopePtrType {
		return scopePtrNil
	}

	// Wildcard gets replaced with map value.
	if env != nil && pattern.Type() == identType {
		id := pattern.Interface().(*ast.Ident)
		if old, ok := env[id.Name]; ok {
			return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
		}
	}

	// Emit qualified identifiers in the pattern by appropriate
	// (possibly qualified) identifier in the input.
	//
	// The template cannot contain dot imports, so all identifiers
	// for imported objects are explicitly qualified.
	//
	// We assume (unsoundly) that there are no dot or named
	// imports in the input code, nor are any imported package
	// names shadowed, so the usual normal qualified identifier
	// syntax may be used.
	// TODO(adonovan): fix: avoid this assumption.
	//
	// A refactoring may be applied to a package referenced by the
	// template.  Objects belonging to the current package are
	// denoted by unqualified identifiers.
	//
	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
		if obj != nil {
			if sel, ok := tr.importedObjs[obj]; ok {
				var id ast.Expr
				if obj.Pkg() == tr.currentPkg {
					id = sel.Sel // unqualified
				} else {
					id = sel // pkg-qualified
				}

				// Return a clone of id.
				saved := tr.importedObjs
				tr.importedObjs = nil // break cycle
				r := tr.subst(nil, reflect.ValueOf(id), pos)
				tr.importedObjs = saved
				return r
			}
		}
	}

	if pos.IsValid() && pattern.Type() == positionType {
		// use new position only if old position was valid in the first place
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}

	// Otherwise copy.
	switch p := pattern; p.Kind() {
	case reflect.Slice:
		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
		}
		return v

	case reflect.Struct:
		v := reflect.New(p.Type()).Elem()
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
		}
		return v

	case reflect.Pointer:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos).Addr())
		}

		// Duplicate type information for duplicated ast.Expr.
		// All ast.Node implementations are *structs,
		// so this case catches them all.
		if e := rvToExpr(v); e != nil {
			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
		}
		return v

	case reflect.Interface:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(tr.subst(env, elem, pos))
		}
		return v
	}

	return pattern
}

// -- utilities -------------------------------------------------------

func rvToExpr(rv reflect.Value) ast.Expr {
	if rv.CanInterface() {
		if e, ok := rv.Interface().(ast.Expr); ok {
			return e
		}
	}
	return nil
}

// updateTypeInfo duplicates type information for the existing AST old
// so that it also applies to duplicated AST new.
func updateTypeInfo(info *types.Info, new, old ast.Expr) {
	switch new := new.(type) {
	case *ast.Ident:
		orig := old.(*ast.Ident)
		if obj, ok := info.Defs[orig]; ok {
			info.Defs[new] = obj
		}
		if obj, ok := info.Uses[orig]; ok {
			info.Uses[new] = obj
		}

	case *ast.SelectorExpr:
		orig := old.(*ast.SelectorExpr)
		if sel, ok := info.Selections[orig]; ok {
			info.Selections[new] = sel
		}
	}

	if tv, ok := info.Types[old]; ok {
		info.Types[new] = tv
	}
}
