Source file src/cmd/compile/internal/noder/quirks.go

     1  // UNREVIEWED
     2  
     3  // Copyright 2021 The Go Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  
     7  package noder
     8  
     9  import (
    10  	"fmt"
    11  
    12  	"cmd/compile/internal/base"
    13  	"cmd/compile/internal/ir"
    14  	"cmd/compile/internal/syntax"
    15  	"cmd/compile/internal/types2"
    16  	"cmd/internal/src"
    17  )
    18  
    19  // This file defines helper functions useful for satisfying toolstash
    20  // -cmp when compared against the legacy frontend behavior, but can be
    21  // removed after that's no longer a concern.
    22  
    23  // quirksMode controls whether behavior specific to satisfying
    24  // toolstash -cmp is used.
    25  func quirksMode() bool {
    26  	return base.Debug.UnifiedQuirks != 0
    27  }
    28  
    29  // posBasesOf returns all of the position bases in the source files,
    30  // as seen in a straightforward traversal.
    31  //
    32  // This is necessary to ensure position bases (and thus file names)
    33  // get registered in the same order as noder would visit them.
    34  func posBasesOf(noders []*noder) []*syntax.PosBase {
    35  	seen := make(map[*syntax.PosBase]bool)
    36  	var bases []*syntax.PosBase
    37  
    38  	for _, p := range noders {
    39  		syntax.Crawl(p.file, func(n syntax.Node) bool {
    40  			if b := n.Pos().Base(); !seen[b] {
    41  				bases = append(bases, b)
    42  				seen[b] = true
    43  			}
    44  			return false
    45  		})
    46  	}
    47  
    48  	return bases
    49  }
    50  
// importedObjsOf returns the imported objects (i.e., referenced
// objects not declared by curpkg) from the parsed source files, in
// the order that typecheck used to load their definitions.
//
// This is needed because loading the definitions for imported objects
// can also add file names.
//
// Only objects declared directly in their own package's scope are
// returned, each at most once, in the order they are first reached.
// The traversal runs in five phases over all files; later phases reach
// progressively more of the source (alias types and variables from
// phase 2, function bodies from phase 3, every used identifier in
// phase 5).
func importedObjsOf(curpkg *types2.Package, info *types2.Info, noders []*noder) []types2.Object {
	// This code is complex because it matches the precise order that
	// typecheck recursively and repeatedly traverses the IR. It's meant
	// to be thrown away eventually anyway.

	// seen and objs accumulate the result: objs lists each imported
	// object once, in first-reached order.
	seen := make(map[types2.Object]bool)
	var objs []types2.Object

	// phase is the current traversal phase (1 through 5). resolveNode
	// consults it to decide whether block statements are resolved.
	var phase int

	// decls maps each object declared by a package-level declaration
	// in the source files to that declaration, so references to curpkg
	// objects can be followed to their definitions. assoc asserts that
	// every declared name has an entry in info.Defs.
	decls := make(map[types2.Object]syntax.Decl)
	assoc := func(decl syntax.Decl, names ...*syntax.Name) {
		for _, name := range names {
			obj, ok := info.Defs[name]
			assert(ok)
			decls[obj] = decl
		}
	}

	for _, p := range noders {
		syntax.Crawl(p.file, func(n syntax.Node) bool {
			switch n := n.(type) {
			case *syntax.ConstDecl:
				assoc(n, n.NameList...)
			case *syntax.FuncDecl:
				assoc(n, n.Name)
			case *syntax.TypeDecl:
				assoc(n, n.Name)
			case *syntax.VarDecl:
				assoc(n, n.NameList...)
			case *syntax.BlockStmt:
				// Stop at function bodies: declarations local to a
				// body are not recorded in decls.
				return true
			}
			return false
		})
	}

	// visited records the declarations already resolved during the
	// current phase; it is reset at the start of every phase.
	var visited map[syntax.Decl]bool

	var resolveDecl func(n syntax.Decl)
	var resolveNode func(n syntax.Node, top bool)

	// resolveDecl resolves the references in a curpkg declaration's
	// "header", at most once per phase. It covers constant types and
	// values, function receivers and signatures (not bodies), type
	// definitions, and variable types (or, when a variable has no
	// explicit type, its initializer values).
	resolveDecl = func(n syntax.Decl) {
		if visited[n] {
			return
		}
		visited[n] = true

		switch n := n.(type) {
		case *syntax.ConstDecl:
			resolveNode(n.Type, true)
			resolveNode(n.Values, true)

		case *syntax.FuncDecl:
			// Recv is a pointer; check it explicitly so a nil
			// *Field is not passed on as a non-nil Node.
			if n.Recv != nil {
				resolveNode(n.Recv, true)
			}
			resolveNode(n.Type, true)

		case *syntax.TypeDecl:
			resolveNode(n.Type, true)

		case *syntax.VarDecl:
			if n.Type != nil {
				resolveNode(n.Type, true)
			} else {
				resolveNode(n.Values, true)
			}
		}
	}

	// resolveObj handles one reference to obj:
	//   - builtins (nil package) are ignored;
	//   - curpkg objects with a recorded declaration have that
	//     declaration resolved;
	//   - objects from other packages are appended to objs the first
	//     time they are seen, but only if declared in their package's
	//     scope.
	// pos is currently unused.
	resolveObj := func(pos syntax.Pos, obj types2.Object) {
		switch obj.Pkg() {
		case nil:
			// builtin; nothing to do

		case curpkg:
			if decl, ok := decls[obj]; ok {
				resolveDecl(decl)
			}

		default:
			if obj.Parent() == obj.Pkg().Scope() && !seen[obj] {
				seen[obj] = true
				objs = append(objs, obj)
			}
		}
	}

	// checkdefat resolves the object that name n uses (or, failing
	// that, defines). It skips the blank identifier and names with no
	// non-nil object.
	checkdefat := func(pos syntax.Pos, n *syntax.Name) {
		if n.Value == "_" {
			return
		}
		obj, ok := info.Uses[n]
		if !ok {
			obj, ok = info.Defs[n]
			if !ok {
				return
			}
		}
		if obj == nil {
			return
		}
		resolveObj(pos, obj)
	}
	checkdef := func(n *syntax.Name) { checkdefat(n.Pos(), n) }

	// later queues function literal bodies reached while top is set.
	// They are resolved at the end of the current phase.
	var later []syntax.Node

	// resolveNode crawls n and resolves every object it references.
	// top is true when called for declaration headers, function bodies
	// and deferred function literal bodies. It is false for statements
	// resolved from inside a block. A function literal reached with top
	// set has only its signature resolved now; its body is queued on
	// later. Otherwise the crawl descends into the literal normally.
	resolveNode = func(n syntax.Node, top bool) {
		if n == nil {
			return
		}
		syntax.Crawl(n, func(n syntax.Node) bool {
			switch n := n.(type) {
			case *syntax.Name:
				checkdef(n)

			case *syntax.SelectorExpr:
				// For a qualified identifier pkg.Sel, resolve Sel
				// (reported at the qualifier's position) and do not
				// visit the package name itself.
				if name, ok := n.X.(*syntax.Name); ok {
					if _, isPkg := info.Uses[name].(*types2.PkgName); isPkg {
						checkdefat(n.X.Pos(), n.Sel)
						return true
					}
				}

			case *syntax.AssignStmt:
				// Right-hand side before left-hand side.
				resolveNode(n.Rhs, top)
				resolveNode(n.Lhs, top)
				return true

			case *syntax.VarDecl:
				// Initializer values first; returning false then also
				// crawls the declaration's children.
				resolveNode(n.Values, top)

			case *syntax.FuncLit:
				if top {
					resolveNode(n.Type, top)
					later = append(later, n.Body)
					return true
				}

			case *syntax.BlockStmt:
				// Block contents are only resolved from phase 3 on,
				// statement by statement and never at top level. The
				// crawl itself never descends into a block.
				if phase >= 3 {
					for _, stmt := range n.List {
						resolveNode(stmt, false)
					}
				}
				return true
			}

			return false
		})
	}

	// Each phase re-walks every file's top-level declarations:
	//   - phase 1: constants, function signatures, non-alias types;
	//   - phase 2: adds alias types and variables (initializer values,
	//     then the declaration);
	//   - phase 3 onward: also function bodies (phase 4 repeats
	//     phase 3's work, since no condition distinguishes them);
	//   - phase 5: additionally every identifier use in each file.
	// NOTE(review): the repeated phases presumably mirror typecheck's
	// repeated passes over the IR; confirm against the legacy frontend.
	for phase = 1; phase <= 5; phase++ {
		visited = map[syntax.Decl]bool{}

		for _, p := range noders {
			for _, decl := range p.file.DeclList {
				switch decl := decl.(type) {
				case *syntax.ConstDecl:
					resolveDecl(decl)

				case *syntax.FuncDecl:
					resolveDecl(decl)
					if phase >= 3 && decl.Body != nil {
						resolveNode(decl.Body, true)
					}

				case *syntax.TypeDecl:
					if !decl.Alias || phase >= 2 {
						resolveDecl(decl)
					}

				case *syntax.VarDecl:
					if phase >= 2 {
						resolveNode(decl.Values, true)
						resolveDecl(decl)
					}
				}
			}

			if phase >= 5 {
				syntax.Crawl(p.file, func(n syntax.Node) bool {
					if name, ok := n.(*syntax.Name); ok {
						if obj, ok := info.Uses[name]; ok {
							resolveObj(name.Pos(), obj)
						}
					}
					return false
				})
			}
		}

		// Resolve the function literal bodies queued during this
		// phase. len(later) is re-read on every iteration. Then clear
		// the queue.
		for i := 0; i < len(later); i++ {
			resolveNode(later[i], true)
		}
		later = nil
	}

	return objs
}
   259  
   260  // typeExprEndPos returns the position that noder would leave base.Pos
   261  // after parsing the given type expression.
   262  func typeExprEndPos(expr0 syntax.Expr) syntax.Pos {
   263  	for {
   264  		switch expr := expr0.(type) {
   265  		case *syntax.Name:
   266  			return expr.Pos()
   267  		case *syntax.SelectorExpr:
   268  			return expr.X.Pos()
   269  
   270  		case *syntax.ParenExpr:
   271  			expr0 = expr.X
   272  
   273  		case *syntax.Operation:
   274  			assert(expr.Op == syntax.Mul)
   275  			assert(expr.Y == nil)
   276  			expr0 = expr.X
   277  
   278  		case *syntax.ArrayType:
   279  			expr0 = expr.Elem
   280  		case *syntax.ChanType:
   281  			expr0 = expr.Elem
   282  		case *syntax.DotsType:
   283  			expr0 = expr.Elem
   284  		case *syntax.MapType:
   285  			expr0 = expr.Value
   286  		case *syntax.SliceType:
   287  			expr0 = expr.Elem
   288  
   289  		case *syntax.StructType:
   290  			return expr.Pos()
   291  
   292  		case *syntax.InterfaceType:
   293  			expr0 = lastFieldType(expr.MethodList)
   294  			if expr0 == nil {
   295  				return expr.Pos()
   296  			}
   297  
   298  		case *syntax.FuncType:
   299  			expr0 = lastFieldType(expr.ResultList)
   300  			if expr0 == nil {
   301  				expr0 = lastFieldType(expr.ParamList)
   302  				if expr0 == nil {
   303  					return expr.Pos()
   304  				}
   305  			}
   306  
   307  		case *syntax.IndexExpr: // explicit type instantiation
   308  			targs := unpackListExpr(expr.Index)
   309  			expr0 = targs[len(targs)-1]
   310  
   311  		default:
   312  			panic(fmt.Sprintf("%s: unexpected type expression %v", expr.Pos(), syntax.String(expr)))
   313  		}
   314  	}
   315  }
   316  
   317  func lastFieldType(fields []*syntax.Field) syntax.Expr {
   318  	if len(fields) == 0 {
   319  		return nil
   320  	}
   321  	return fields[len(fields)-1].Type
   322  }
   323  
   324  // sumPos returns the position that noder.sum would produce for
   325  // constant expression x.
   326  func sumPos(x syntax.Expr) syntax.Pos {
   327  	orig := x
   328  	for {
   329  		switch x1 := x.(type) {
   330  		case *syntax.BasicLit:
   331  			assert(x1.Kind == syntax.StringLit)
   332  			return x1.Pos()
   333  		case *syntax.Operation:
   334  			assert(x1.Op == syntax.Add && x1.Y != nil)
   335  			if r, ok := x1.Y.(*syntax.BasicLit); ok {
   336  				assert(r.Kind == syntax.StringLit)
   337  				x = x1.X
   338  				continue
   339  			}
   340  		}
   341  		return orig.Pos()
   342  	}
   343  }
   344  
   345  // funcParamsEndPos returns the value of base.Pos left by noder after
   346  // processing a function signature.
   347  func funcParamsEndPos(fn *ir.Func) src.XPos {
   348  	sig := fn.Nname.Type()
   349  
   350  	fields := sig.Results().FieldSlice()
   351  	if len(fields) == 0 {
   352  		fields = sig.Params().FieldSlice()
   353  		if len(fields) == 0 {
   354  			fields = sig.Recvs().FieldSlice()
   355  			if len(fields) == 0 {
   356  				if fn.OClosure != nil {
   357  					return fn.Nname.Ntype.Pos()
   358  				}
   359  				return fn.Pos()
   360  			}
   361  		}
   362  	}
   363  
   364  	return fields[len(fields)-1].Pos
   365  }
   366  
   367  type dupTypes struct {
   368  	origs map[types2.Type]types2.Type
   369  }
   370  
   371  func (d *dupTypes) orig(t types2.Type) types2.Type {
   372  	if orig, ok := d.origs[t]; ok {
   373  		return orig
   374  	}
   375  	return t
   376  }
   377  
   378  func (d *dupTypes) add(t, orig types2.Type) {
   379  	if t == orig {
   380  		return
   381  	}
   382  
   383  	if d.origs == nil {
   384  		d.origs = make(map[types2.Type]types2.Type)
   385  	}
   386  	assert(d.origs[t] == nil)
   387  	d.origs[t] = orig
   388  
   389  	switch t := t.(type) {
   390  	case *types2.Pointer:
   391  		orig := orig.(*types2.Pointer)
   392  		d.add(t.Elem(), orig.Elem())
   393  
   394  	case *types2.Slice:
   395  		orig := orig.(*types2.Slice)
   396  		d.add(t.Elem(), orig.Elem())
   397  
   398  	case *types2.Map:
   399  		orig := orig.(*types2.Map)
   400  		d.add(t.Key(), orig.Key())
   401  		d.add(t.Elem(), orig.Elem())
   402  
   403  	case *types2.Array:
   404  		orig := orig.(*types2.Array)
   405  		assert(t.Len() == orig.Len())
   406  		d.add(t.Elem(), orig.Elem())
   407  
   408  	case *types2.Chan:
   409  		orig := orig.(*types2.Chan)
   410  		assert(t.Dir() == orig.Dir())
   411  		d.add(t.Elem(), orig.Elem())
   412  
   413  	case *types2.Struct:
   414  		orig := orig.(*types2.Struct)
   415  		assert(t.NumFields() == orig.NumFields())
   416  		for i := 0; i < t.NumFields(); i++ {
   417  			d.add(t.Field(i).Type(), orig.Field(i).Type())
   418  		}
   419  
   420  	case *types2.Interface:
   421  		orig := orig.(*types2.Interface)
   422  		assert(t.NumExplicitMethods() == orig.NumExplicitMethods())
   423  		assert(t.NumEmbeddeds() == orig.NumEmbeddeds())
   424  		for i := 0; i < t.NumExplicitMethods(); i++ {
   425  			d.add(t.ExplicitMethod(i).Type(), orig.ExplicitMethod(i).Type())
   426  		}
   427  		for i := 0; i < t.NumEmbeddeds(); i++ {
   428  			d.add(t.EmbeddedType(i), orig.EmbeddedType(i))
   429  		}
   430  
   431  	case *types2.Signature:
   432  		orig := orig.(*types2.Signature)
   433  		assert((t.Recv() == nil) == (orig.Recv() == nil))
   434  		if t.Recv() != nil {
   435  			d.add(t.Recv().Type(), orig.Recv().Type())
   436  		}
   437  		d.add(t.Params(), orig.Params())
   438  		d.add(t.Results(), orig.Results())
   439  
   440  	case *types2.Tuple:
   441  		orig := orig.(*types2.Tuple)
   442  		assert(t.Len() == orig.Len())
   443  		for i := 0; i < t.Len(); i++ {
   444  			d.add(t.At(i).Type(), orig.At(i).Type())
   445  		}
   446  
   447  	default:
   448  		assert(types2.Identical(t, orig))
   449  	}
   450  }
   451  

View as plain text