Source file src/cmd/compile/internal/types2/unify.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file implements type unification.
     6  
     7  package types2
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"strings"
    13  )
    14  
    15  // The unifier maintains two separate sets of type parameters x and y
    16  // which are used to resolve type parameters in the x and y arguments
    17  // provided to the unify call. For unidirectional unification, only
    18  // one of these sets (say x) is provided, and then type parameters are
    19  // only resolved for the x argument passed to unify, not the y argument
    20  // (even if that also contains possibly the same type parameters). This
    21  // is crucial to infer the type parameters of self-recursive calls:
    22  //
    23  //	func f[P any](a P) { f(a) }
    24  //
    25  // For the call f(a) we want to infer that the type argument for P is P.
    26  // During unification, the parameter type P must be resolved to the type
    27  // parameter P ("x" side), but the argument type P must be left alone so
    28  // that unification resolves the type parameter P to P.
    29  //
    30  // For bidirectional unification, both sets are provided. This enables
    31  // unification to go from argument to parameter type and vice versa.
    32  // For constraint type inference, we use bidirectional unification
    33  // where both the x and y type parameters are identical. This is done
    34  // by setting up one of them (using init) and then assigning its value
    35  // to the other.
    36  
// Compile-time knobs controlling the behavior of the unifier.
const (
	// unificationDepthLimit is the upper limit for the recursion depth of
	// nify. It is used to catch infinite recursions due to implementation
	// issues (e.g., see issues #48619, #48656). When the limit is reached,
	// unification of the current pair of types fails.
	unificationDepthLimit = 50

	// panicAtUnificationDepthLimit reports whether to panic (rather than
	// just fail) when unificationDepthLimit is reached. Turn on when
	// investigating infinite recursion.
	panicAtUnificationDepthLimit = false

	// If enableCoreTypeUnification is set, unification will consider
	// the core types, if any, of non-local (unbound) type parameters.
	enableCoreTypeUnification = true

	// If traceInference is set, unification will print a trace of its operation.
	// Each trace line is indented according to the current recursion depth.
	// Interpretation of trace:
	//   x ≡ y    attempt to unify types x and y
	//   p ➞ y    type parameter p is set to type y (p is inferred to be y)
	//   p ⇄ q    type parameters p and q match (p is inferred to be q and vice versa)
	//   x ≢ y    types x and y cannot be unified
	//   [p, q, ...] ➞ [x, y, ...]    mapping from type parameters to types
	traceInference = false
)
    59  
    60  // A unifier maintains the current type parameters for x and y
    61  // and the respective types inferred for each type parameter.
    62  // A unifier is created by calling newUnifier.
    63  type unifier struct {
    64  	exact bool
    65  	x, y  tparamsList // x and y must initialized via tparamsList.init
    66  	types []Type      // inferred types, shared by x and y
    67  	depth int         // recursion depth during unification
    68  }
    69  
    70  // newUnifier returns a new unifier.
    71  // If exact is set, unification requires unified types to match
    72  // exactly. If exact is not set, a named type's underlying type
    73  // is considered if unification would fail otherwise, and the
    74  // direction of channels is ignored.
    75  // TODO(gri) exact is not set anymore by a caller. Consider removing it.
    76  func newUnifier(exact bool) *unifier {
    77  	u := &unifier{exact: exact}
    78  	u.x.unifier = u
    79  	u.y.unifier = u
    80  	return u
    81  }
    82  
    83  // unify attempts to unify x and y and reports whether it succeeded.
    84  func (u *unifier) unify(x, y Type) bool {
    85  	return u.nify(x, y, nil)
    86  }
    87  
    88  func (u *unifier) tracef(format string, args ...interface{}) {
    89  	fmt.Println(strings.Repeat(".  ", u.depth) + sprintf(nil, true, format, args...))
    90  }
    91  
// A tparamsList describes a list of type parameters and the types inferred for them.
// The inferred types themselves are stored in the owning unifier's types slice;
// the list only records, per type parameter, which slot (if any) it uses.
type tparamsList struct {
	unifier *unifier     // owning unifier; provides the shared types slice
	tparams []*TypeParam // the type parameters described by this list
	// For each tparams element, there is a corresponding type slot index in indices.
	// index  < 0: unifier.types[-index-1] == nil
	// index == 0: no type slot allocated yet
	// index  > 0: unifier.types[index-1] == typ
	// Joined tparams elements share the same type slot and thus have the same index.
	// By using a negative index for nil types we don't need to check unifier.types
	// to see if we have a type or not.
	indices []int // len(d.indices) == len(d.tparams)
}
   105  
   106  // String returns a string representation for a tparamsList. For debugging.
   107  func (d *tparamsList) String() string {
   108  	var buf bytes.Buffer
   109  	w := newTypeWriter(&buf, nil)
   110  	w.byte('[')
   111  	for i, tpar := range d.tparams {
   112  		if i > 0 {
   113  			w.string(", ")
   114  		}
   115  		w.typ(tpar)
   116  		w.string(": ")
   117  		w.typ(d.at(i))
   118  	}
   119  	w.byte(']')
   120  	return buf.String()
   121  }
   122  
   123  // init initializes d with the given type parameters.
   124  // The type parameters must be in the order in which they appear in their declaration
   125  // (this ensures that the tparams indices match the respective type parameter index).
   126  func (d *tparamsList) init(tparams []*TypeParam) {
   127  	if len(tparams) == 0 {
   128  		return
   129  	}
   130  	if debug {
   131  		for i, tpar := range tparams {
   132  			assert(i == tpar.index)
   133  		}
   134  	}
   135  	d.tparams = tparams
   136  	d.indices = make([]int, len(tparams))
   137  }
   138  
   139  // join unifies the i'th type parameter of x with the j'th type parameter of y.
   140  // If both type parameters already have a type associated with them and they are
   141  // not joined, join fails and returns false.
   142  func (u *unifier) join(i, j int) bool {
   143  	if traceInference {
   144  		u.tracef("%s ⇄ %s", u.x.tparams[i], u.y.tparams[j])
   145  	}
   146  	ti := u.x.indices[i]
   147  	tj := u.y.indices[j]
   148  	switch {
   149  	case ti == 0 && tj == 0:
   150  		// Neither type parameter has a type slot associated with them.
   151  		// Allocate a new joined nil type slot (negative index).
   152  		u.types = append(u.types, nil)
   153  		u.x.indices[i] = -len(u.types)
   154  		u.y.indices[j] = -len(u.types)
   155  	case ti == 0:
   156  		// The type parameter for x has no type slot yet. Use slot of y.
   157  		u.x.indices[i] = tj
   158  	case tj == 0:
   159  		// The type parameter for y has no type slot yet. Use slot of x.
   160  		u.y.indices[j] = ti
   161  
   162  	// Both type parameters have a slot: ti != 0 && tj != 0.
   163  	case ti == tj:
   164  		// Both type parameters already share the same slot. Nothing to do.
   165  		break
   166  	case ti > 0 && tj > 0:
   167  		// Both type parameters have (possibly different) inferred types. Cannot join.
   168  		// TODO(gri) Should we check if types are identical? Investigate.
   169  		return false
   170  	case ti > 0:
   171  		// Only the type parameter for x has an inferred type. Use x slot for y.
   172  		u.y.setIndex(j, ti)
   173  	// This case is handled like the default case.
   174  	// case tj > 0:
   175  	// 	// Only the type parameter for y has an inferred type. Use y slot for x.
   176  	// 	u.x.setIndex(i, tj)
   177  	default:
   178  		// Neither type parameter has an inferred type. Use y slot for x
   179  		// (or x slot for y, it doesn't matter).
   180  		u.x.setIndex(i, tj)
   181  	}
   182  	return true
   183  }
   184  
   185  // If typ is a type parameter of d, index returns the type parameter index.
   186  // Otherwise, the result is < 0.
   187  func (d *tparamsList) index(typ Type) int {
   188  	if tpar, ok := typ.(*TypeParam); ok {
   189  		return tparamIndex(d.tparams, tpar)
   190  	}
   191  	return -1
   192  }
   193  
   194  // If tpar is a type parameter in list, tparamIndex returns the type parameter index.
   195  // Otherwise, the result is < 0. tpar must not be nil.
   196  func tparamIndex(list []*TypeParam, tpar *TypeParam) int {
   197  	// Once a type parameter is bound its index is >= 0. However, there are some
   198  	// code paths (namely tracing and type hashing) by which it is possible to
   199  	// arrive here with a type parameter that has not been bound, hence the check
   200  	// for 0 <= i below.
   201  	// TODO(rfindley): investigate a better approach for guarding against using
   202  	// unbound type parameters.
   203  	if i := tpar.index; 0 <= i && i < len(list) && list[i] == tpar {
   204  		return i
   205  	}
   206  	return -1
   207  }
   208  
   209  // setIndex sets the type slot index for the i'th type parameter
   210  // (and all its joined parameters) to tj. The type parameter
   211  // must have a (possibly nil) type slot associated with it.
   212  func (d *tparamsList) setIndex(i, tj int) {
   213  	ti := d.indices[i]
   214  	assert(ti != 0 && tj != 0)
   215  	for k, tk := range d.indices {
   216  		if tk == ti {
   217  			d.indices[k] = tj
   218  		}
   219  	}
   220  }
   221  
   222  // at returns the type set for the i'th type parameter; or nil.
   223  func (d *tparamsList) at(i int) Type {
   224  	if ti := d.indices[i]; ti > 0 {
   225  		return d.unifier.types[ti-1]
   226  	}
   227  	return nil
   228  }
   229  
   230  // set sets the type typ for the i'th type parameter;
   231  // typ must not be nil and it must not have been set before.
   232  func (d *tparamsList) set(i int, typ Type) {
   233  	assert(typ != nil)
   234  	u := d.unifier
   235  	if traceInference {
   236  		u.tracef("%s ➞ %s", d.tparams[i], typ)
   237  	}
   238  	switch ti := d.indices[i]; {
   239  	case ti < 0:
   240  		u.types[-ti-1] = typ
   241  		d.setIndex(i, -ti)
   242  	case ti == 0:
   243  		u.types = append(u.types, typ)
   244  		d.indices[i] = len(u.types)
   245  	default:
   246  		panic("type already set")
   247  	}
   248  }
   249  
   250  // unknowns returns the number of type parameters for which no type has been set yet.
   251  func (d *tparamsList) unknowns() int {
   252  	n := 0
   253  	for _, ti := range d.indices {
   254  		if ti <= 0 {
   255  			n++
   256  		}
   257  	}
   258  	return n
   259  }
   260  
   261  // types returns the list of inferred types (via unification) for the type parameters
   262  // described by d, and an index. If all types were inferred, the returned index is < 0.
   263  // Otherwise, it is the index of the first type parameter which couldn't be inferred;
   264  // i.e., for which list[index] is nil.
   265  func (d *tparamsList) types() (list []Type, index int) {
   266  	list = make([]Type, len(d.tparams))
   267  	index = -1
   268  	for i := range d.tparams {
   269  		t := d.at(i)
   270  		list[i] = t
   271  		if index < 0 && t == nil {
   272  			index = i
   273  		}
   274  	}
   275  	return
   276  }
   277  
   278  func (u *unifier) nifyEq(x, y Type, p *ifacePair) bool {
   279  	return x == y || u.nify(x, y, p)
   280  }
   281  
// nify implements the core unification algorithm which is an
// adapted version of Checker.identical. For changes to that
// code the corresponding changes should be made here.
// Must not be called directly from outside the unifier.
//
// nify attempts to unify the types x and y and reports whether it
// succeeded. As a side effect it may record inferred types for, or
// join, type parameters of u.x and u.y. p is the stack of interface
// pairs currently being compared; it is used to detect cycles through
// anonymous interfaces (see the *Interface case below).
//
// The result is named so that the deferred tracing code can tell
// whether unification failed.
func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) {
	if traceInference {
		u.tracef("%s ≡ %s", x, y)
	}

	// Stop gap for cases where unification fails.
	// Note that this check happens before the depth is incremented and
	// before the deferred function is installed, so hitting the limit
	// does not produce a "≢" trace line for this call.
	if u.depth >= unificationDepthLimit {
		if traceInference {
			u.tracef("depth %d >= %d", u.depth, unificationDepthLimit)
		}
		if panicAtUnificationDepthLimit {
			panic("unification reached recursion depth limit")
		}
		return false
	}
	u.depth++
	defer func() {
		// Restore the depth first so that the failure trace (if any)
		// is indented like the corresponding "≡" line above.
		u.depth--
		if traceInference && !result {
			u.tracef("%s ≢ %s", x, y)
		}
	}()

	if !u.exact {
		// If exact unification is known to fail because we attempt to
		// match a type name against an unnamed type literal, consider
		// the underlying type of the named type.
		// (We use !hasName to exclude any type with a name, including
		// basic types and type parameters; the rest are unnamed types.)
		if nx, _ := x.(*Named); nx != nil && !hasName(y) {
			if traceInference {
				u.tracef("under %s ≡ %s", nx, y)
			}
			return u.nify(nx.under(), y, p)
		} else if ny, _ := y.(*Named); ny != nil && !hasName(x) {
			if traceInference {
				u.tracef("%s ≡ under %s", x, ny)
			}
			return u.nify(x, ny.under(), p)
		}
	}

	// Cases where at least one of x or y is a type parameter.
	// i and j are the indices of x and y in the respective type
	// parameter lists, or < 0 if they are not in those lists.
	switch i, j := u.x.index(x), u.y.index(y); {
	case i >= 0 && j >= 0:
		// both x and y are type parameters
		if u.join(i, j) {
			return true
		}
		// both x and y have an inferred type - they must match
		return u.nifyEq(u.x.at(i), u.y.at(j), p)

	case i >= 0:
		// x is a type parameter, y is not
		if tx := u.x.at(i); tx != nil {
			return u.nifyEq(tx, y, p)
		}
		// otherwise, infer type from y
		u.x.set(i, y)
		return true

	case j >= 0:
		// y is a type parameter, x is not
		if ty := u.y.at(j); ty != nil {
			return u.nifyEq(x, ty, p)
		}
		// otherwise, infer type from x
		u.y.set(j, x)
		return true
	}

	// If we get here and x or y is a type parameter, they are type parameters
	// from outside our declaration list. Try to unify their core types, if any
	// (see issue #50755 for a test case).
	if enableCoreTypeUnification && !u.exact {
		if isTypeParam(x) && !hasName(y) {
			// When considering the type parameter for unification
			// we look at the adjusted core term (adjusted core type
			// with tilde information).
			// If the adjusted core type is a named type N; the
			// corresponding core type is under(N). Since !u.exact
			// and y doesn't have a name, unification will end up
			// comparing under(N) to y, so we can just use the core
			// type instead. And we can ignore the tilde because we
			// already look at the underlying types on both sides
			// and we have known types on both sides.
			// Optimization.
			if cx := coreType(x); cx != nil {
				if traceInference {
					u.tracef("core %s ≡ %s", x, y)
				}
				return u.nify(cx, y, p)
			}
		} else if isTypeParam(y) && !hasName(x) {
			// see comment above
			if cy := coreType(y); cy != nil {
				if traceInference {
					u.tracef("%s ≡ core %s", x, y)
				}
				return u.nify(x, cy, p)
			}
		}
	}

	// For type unification, do not shortcut (x == y) for identical
	// types. Instead keep comparing them element-wise to unify the
	// matching (and equal type parameter types). A simple test case
	// where this matters is: func f[P any](a P) { f(a) } .

	// Structural comparison. Each case handles one kind of x; if y is
	// not of the same kind (or a size/length check fails), control falls
	// out of the switch and the function reports false at the end.
	switch x := x.(type) {
	case *Basic:
		// Basic types are singletons except for the rune and byte
		// aliases, thus we cannot rely on pointer identity (x == y)
		// alone and compare the kinds instead. See also comment in
		// TypeName.IsAlias.
		if y, ok := y.(*Basic); ok {
			return x.kind == y.kind
		}

	case *Array:
		// Two array types are identical if they have identical element types
		// and the same array length.
		if y, ok := y.(*Array); ok {
			// If one or both array lengths are unknown (< 0) due to some error,
			// assume they are the same to avoid spurious follow-on errors.
			return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, p)
		}

	case *Slice:
		// Two slice types are identical if they have identical element types.
		if y, ok := y.(*Slice); ok {
			return u.nify(x.elem, y.elem, p)
		}

	case *Struct:
		// Two struct types are identical if they have the same sequence of fields,
		// and if corresponding fields have the same names, and identical types,
		// and identical tags. Two embedded fields are considered to have the same
		// name. Lower-case field names from different packages are always different.
		if y, ok := y.(*Struct); ok {
			if x.NumFields() == y.NumFields() {
				for i, f := range x.fields {
					g := y.fields[i]
					if f.embedded != g.embedded ||
						x.Tag(i) != y.Tag(i) ||
						!f.sameId(g.pkg, g.name) ||
						!u.nify(f.typ, g.typ, p) {
						return false
					}
				}
				return true
			}
		}

	case *Pointer:
		// Two pointer types are identical if they have identical base types.
		if y, ok := y.(*Pointer); ok {
			return u.nify(x.base, y.base, p)
		}

	case *Tuple:
		// Two tuples types are identical if they have the same number of elements
		// and corresponding elements have identical types.
		if y, ok := y.(*Tuple); ok {
			if x.Len() == y.Len() {
				// x may be a nil *Tuple; only access its vars if it is not.
				if x != nil {
					for i, v := range x.vars {
						w := y.vars[i]
						if !u.nify(v.typ, w.typ, p) {
							return false
						}
					}
				}
				return true
			}
		}

	case *Signature:
		// Two function types are identical if they have the same number of parameters
		// and result values, corresponding parameter and result types are identical,
		// and either both functions are variadic or neither is. Parameter and result
		// names are not required to match.
		// TODO(gri) handle type parameters or document why we can ignore them.
		if y, ok := y.(*Signature); ok {
			return x.variadic == y.variadic &&
				u.nify(x.params, y.params, p) &&
				u.nify(x.results, y.results, p)
		}

	case *Interface:
		// Two interface types are identical if they have the same set of methods with
		// the same names and identical function types. Lower-case method names from
		// different packages are always different. The order of the methods is irrelevant.
		// In addition, the comparable flags and the terms of the two type sets
		// must agree.
		if y, ok := y.(*Interface); ok {
			xset := x.typeSet()
			yset := y.typeSet()
			if xset.comparable != yset.comparable {
				return false
			}
			if !xset.terms.equal(yset.terms) {
				return false
			}
			a := xset.methods
			b := yset.methods
			if len(a) == len(b) {
				// Interface types are the only types where cycles can occur
				// that are not "terminated" via named types; and such cycles
				// can only be created via method parameter types that are
				// anonymous interfaces (directly or indirectly) embedding
				// the current interface. Example:
				//
				//    type T interface {
				//        m() interface{T}
				//    }
				//
				// If two such (differently named) interfaces are compared,
				// endless recursion occurs if the cycle is not detected.
				//
				// If x and y were compared before, they must be equal
				// (if they were not, the recursion would have stopped);
				// search the ifacePair stack for the same pair.
				//
				// This is a quadratic algorithm, but in practice these stacks
				// are extremely short (bounded by the nesting depth of interface
				// type declarations that recur via parameter types, an extremely
				// rare occurrence). An alternative implementation might use a
				// "visited" map, but that is probably less efficient overall.
				//
				// q extends the stack with the current pair; it is created
				// before p is advanced by the search loop below, so it links
				// to the full stack.
				q := &ifacePair{x, y, p}
				for p != nil {
					if p.identical(q) {
						return true // same pair was compared before
					}
					p = p.prev
				}
				if debug {
					assertSortedMethods(a)
					assertSortedMethods(b)
				}
				// Compare methods pairwise, passing the extended stack q.
				for i, f := range a {
					g := b[i]
					if f.Id() != g.Id() || !u.nify(f.typ, g.typ, q) {
						return false
					}
				}
				return true
			}
		}

	case *Map:
		// Two map types are identical if they have identical key and value types.
		if y, ok := y.(*Map); ok {
			return u.nify(x.key, y.key, p) && u.nify(x.elem, y.elem, p)
		}

	case *Chan:
		// Two channel types are identical if they have identical value types.
		// The channel direction is only compared for exact unification.
		if y, ok := y.(*Chan); ok {
			return (!u.exact || x.dir == y.dir) && u.nify(x.elem, y.elem, p)
		}

	case *Named:
		// TODO(gri) This code differs now from the parallel code in Checker.identical. Investigate.
		if y, ok := y.(*Named); ok {
			xargs := x.targs.list()
			yargs := y.targs.list()

			if len(xargs) != len(yargs) {
				return false
			}

			// TODO(gri) This is not always correct: two types may have the same names
			//           in the same package if one of them is nested in a function.
			//           Extremely unlikely but we need an always correct solution.
			if x.obj.pkg == y.obj.pkg && x.obj.name == y.obj.name {
				// Same package and name: unify the type arguments pairwise.
				for i, x := range xargs {
					if !u.nify(x, yargs[i], p) {
						return false
					}
				}
				return true
			}
		}

	case *TypeParam:
		// Two type parameters (which are not part of the type parameters of the
		// enclosing type as those are handled in the beginning of this function)
		// are identical if they originate in the same declaration.
		return x == y

	case nil:
		// avoid a crash in case of nil type

	default:
		// Any other type kind is unexpected here.
		panic(sprintf(nil, true, "u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams))
	}

	return false
}
   583  

View as plain text