Source file src/cmd/compile/internal/walk/compare.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package walk
     6  
     7  import (
     8  	"go/constant"
     9  
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/ir"
    12  	"cmd/compile/internal/reflectdata"
    13  	"cmd/compile/internal/ssagen"
    14  	"cmd/compile/internal/typecheck"
    15  	"cmd/compile/internal/types"
    16  )
    17  
// walkCompare lowers the comparison expression n.
//
// Interface-vs-interface and string-vs-string comparisons are delegated to
// walkCompareInterface and walkCompareString. A comparison between an
// interface value and a concrete value is rewritten into a type-word check
// combined (short-circuit) with a data check, with no runtime call. Array and
// struct comparisons are either expanded inline, element by element (or field
// by field), or turned into a call to the equality function chosen by eqFor.
// All other comparisons are returned as n itself after walking its operands;
// when libfuzzer instrumentation is enabled and the operands are integers, a
// comparison-tracing call is also appended to init.
//
// Statements needed to evaluate the result are appended to init. The result
// may be a different node than n, so the caller MUST use the returned node in
// place of n rather than keep using n.
func walkCompare(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
	if n.X.Type().IsInterface() && n.Y.Type().IsInterface() && n.X.Op() != ir.ONIL && n.Y.Op() != ir.ONIL {
		return walkCompareInterface(n, init)
	}

	if n.X.Type().IsString() && n.Y.Type().IsString() {
		return walkCompareString(n, init)
	}

	n.X = walkExpr(n.X, init)
	n.Y = walkExpr(n.Y, init)

	// Given mixed interface/concrete comparison,
	// rewrite into types-equal && data-equal.
	// This is efficient, avoids allocations, and avoids runtime calls.
	if n.X.Type().IsInterface() != n.Y.Type().IsInterface() {
		// Preserve side-effects in case of short-circuiting; see #32187.
		l := cheapExpr(n.X, init)
		r := cheapExpr(n.Y, init)
		// Swap so that l is the interface value and r is the concrete value.
		if n.Y.Type().IsInterface() {
			l, r = r, l
		}

		// Handle both == and !=.
		// For != every sub-comparison below uses the complemented op and
		// the pieces are joined with || instead of && (De Morgan), so the
		// data check is only reached when the type check did not already
		// decide the result.
		eq := n.Op()
		andor := ir.OOROR
		if eq == ir.OEQ {
			andor = ir.OANDAND
		}
		// Check for types equal.
		// For empty interface, this is:
		//   l.tab == type(r)
		// For non-empty interface, this is:
		//   l.tab != nil && l.tab._type == type(r)
		// (and for !=: nil == l.tab || l.tab._type != type(r)).
		var eqtype ir.Node
		tab := ir.NewUnaryExpr(base.Pos, ir.OITAB, l)
		rtyp := reflectdata.TypePtr(r.Type())
		if l.Type().IsEmptyInterface() {
			tab.SetType(types.NewPtr(types.Types[types.TUINT8]))
			tab.SetTypecheck(1)
			eqtype = ir.NewBinaryExpr(base.Pos, eq, tab, rtyp)
		} else {
			nonnil := ir.NewBinaryExpr(base.Pos, brcom(eq), typecheck.NodNil(), tab)
			match := ir.NewBinaryExpr(base.Pos, eq, itabType(tab), rtyp)
			eqtype = ir.NewLogicalExpr(base.Pos, andor, nonnil, match)
		}
		// Check for data equal: view l's data word as a value of r's type
		// and compare it with r directly.
		eqdata := ir.NewBinaryExpr(base.Pos, eq, ifaceData(n.Pos(), l, r.Type()), r)
		// Put it all together.
		expr := ir.NewLogicalExpr(base.Pos, andor, eqtype, eqdata)
		return finishCompare(n, expr, init)
	}

	// Must be comparison of array or struct.
	// Otherwise back end handles it.
	// While we're here, decide whether to
	// inline or call an eq alg.
	t := n.X.Type()
	var inline bool

	// maxcmpsize bounds the total byte size of an array of simple elements
	// that is still compared inline. It is raised on targets that can merge
	// adjacent loads, because there several elements are compared at once.
	maxcmpsize := int64(4)
	unalignedLoad := ssagen.Arch.LinkArch.CanMergeLoads
	if unalignedLoad {
		// Keep this low enough to generate less code than a function call.
		maxcmpsize = 2 * int64(ssagen.Arch.LinkArch.RegSize)
	}

	switch t.Kind() {
	default:
		// Not an array or struct: leave n for the back end, optionally
		// preceded by a libfuzzer trace call for integer comparisons.
		if base.Debug.Libfuzzer != 0 && t.IsInteger() {
			// Operands are used both by the trace call and by n itself.
			n.X = cheapExpr(n.X, init)
			n.Y = cheapExpr(n.Y, init)

			// If exactly one comparison operand is
			// constant, invoke the constcmp functions
			// instead, and arrange for the constant
			// operand to be the first argument.
			l, r := n.X, n.Y
			if r.Op() == ir.OLITERAL {
				l, r = r, l
			}
			constcmp := l.Op() == ir.OLITERAL && r.Op() != ir.OLITERAL

			// Pick the trace function and its unsigned parameter type by
			// operand width.
			var fn string
			var paramType *types.Type
			switch t.Size() {
			case 1:
				fn = "libfuzzerTraceCmp1"
				if constcmp {
					fn = "libfuzzerTraceConstCmp1"
				}
				paramType = types.Types[types.TUINT8]
			case 2:
				fn = "libfuzzerTraceCmp2"
				if constcmp {
					fn = "libfuzzerTraceConstCmp2"
				}
				paramType = types.Types[types.TUINT16]
			case 4:
				fn = "libfuzzerTraceCmp4"
				if constcmp {
					fn = "libfuzzerTraceConstCmp4"
				}
				paramType = types.Types[types.TUINT32]
			case 8:
				fn = "libfuzzerTraceCmp8"
				if constcmp {
					fn = "libfuzzerTraceConstCmp8"
				}
				paramType = types.Types[types.TUINT64]
			default:
				base.Fatalf("unexpected integer size %d for %v", t.Size(), t)
			}
			init.Append(mkcall(fn, nil, init, tracecmpArg(l, paramType, init), tracecmpArg(r, paramType, init)))
		}
		return n
	case types.TARRAY:
		// We can compare several elements at once with 2/4/8 byte integer compares
		inline = t.NumElem() <= 1 || (types.IsSimple[t.Elem().Kind()] && (t.NumElem() <= 4 || t.Elem().Size()*t.NumElem() <= maxcmpsize))
	case types.TSTRUCT:
		inline = t.NumComponents(types.IgnoreBlankFields) <= 4
	}

	// Look through no-op conversions so that the addressability check and
	// the address-of / indexing / selector nodes built below operate on the
	// underlying operands.
	cmpl := n.X
	for cmpl != nil && cmpl.Op() == ir.OCONVNOP {
		cmpl = cmpl.(*ir.ConvExpr).X
	}
	cmpr := n.Y
	for cmpr != nil && cmpr.Op() == ir.OCONVNOP {
		cmpr = cmpr.(*ir.ConvExpr).X
	}

	// Chose not to inline. Call equality function directly.
	if !inline {
		// eq algs take pointers; cmpl and cmpr must be addressable
		if !ir.IsAddressable(cmpl) || !ir.IsAddressable(cmpr) {
			base.Fatalf("arguments of comparison must be lvalues - %v %v", cmpl, cmpr)
		}

		// Build fn(&cmpl, &cmpr[, size]) and negate it for !=.
		fn, needsize := eqFor(t)
		call := ir.NewCallExpr(base.Pos, ir.OCALL, fn, nil)
		call.Args.Append(typecheck.NodAddr(cmpl))
		call.Args.Append(typecheck.NodAddr(cmpr))
		if needsize {
			call.Args.Append(ir.NewInt(t.Size()))
		}
		res := ir.Node(call)
		if n.Op() != ir.OEQ {
			res = ir.NewUnaryExpr(base.Pos, ir.ONOT, res)
		}
		return finishCompare(n, res, init)
	}

	// inline: build boolean expression comparing element by element
	andor := ir.OANDAND
	if n.Op() == ir.ONE {
		andor = ir.OOROR
	}
	// compare appends "el <op> er" to expr, joining it to the previous
	// comparisons with andor (&& for ==, || for !=).
	var expr ir.Node
	compare := func(el, er ir.Node) {
		a := ir.NewBinaryExpr(base.Pos, n.Op(), el, er)
		if expr == nil {
			expr = a
		} else {
			expr = ir.NewLogicalExpr(base.Pos, andor, expr, a)
		}
	}
	// The operands are referenced once per element/field below.
	cmpl = safeExpr(cmpl, init)
	cmpr = safeExpr(cmpr, init)
	if t.IsStruct() {
		for _, f := range t.Fields().Slice() {
			sym := f.Sym
			if sym.IsBlank() {
				continue
			}
			compare(
				ir.NewSelectorExpr(base.Pos, ir.OXDOT, cmpl, sym),
				ir.NewSelectorExpr(base.Pos, ir.OXDOT, cmpr, sym),
			)
		}
	} else {
		// On targets that can merge loads, runs of small integer elements
		// are combined into one wider (16/32/64-bit) value built from
		// shifts and ors, so that several elements are compared at once.
		// remains counts the bytes of the array not yet compared.
		step := int64(1)
		remains := t.NumElem() * t.Elem().Size()
		combine64bit := unalignedLoad && types.RegSize == 8 && t.Elem().Size() <= 4 && t.Elem().IsInteger()
		combine32bit := unalignedLoad && t.Elem().Size() <= 2 && t.Elem().IsInteger()
		combine16bit := unalignedLoad && t.Elem().Size() == 1 && t.Elem().IsInteger()
		for i := int64(0); remains > 0; {
			var convType *types.Type
			switch {
			case remains >= 8 && combine64bit:
				convType = types.Types[types.TINT64]
				step = 8 / t.Elem().Size()
			case remains >= 4 && combine32bit:
				convType = types.Types[types.TUINT32]
				step = 4 / t.Elem().Size()
			case remains >= 2 && combine16bit:
				convType = types.Types[types.TUINT16]
				step = 2 / t.Elem().Size()
			default:
				step = 1
			}
			if step == 1 {
				compare(
					ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(i)),
					ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(i)),
				)
				i++
				remains -= t.Elem().Size()
			} else {
				// Convert to the unsigned element type before widening so
				// that sign extension cannot leak into the higher bits that
				// the following elements are or-ed into.
				elemType := t.Elem().ToUnsigned()
				cmplw := ir.Node(ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(i)))
				cmplw = typecheck.Conv(cmplw, elemType) // convert to unsigned
				cmplw = typecheck.Conv(cmplw, convType) // widen
				cmprw := ir.Node(ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(i)))
				cmprw = typecheck.Conv(cmprw, elemType)
				cmprw = typecheck.Conv(cmprw, convType)
				// For code like this:  uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 ...
				// ssa will generate a single large load.
				for offset := int64(1); offset < step; offset++ {
					lb := ir.Node(ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(i+offset)))
					lb = typecheck.Conv(lb, elemType)
					lb = typecheck.Conv(lb, convType)
					lb = ir.NewBinaryExpr(base.Pos, ir.OLSH, lb, ir.NewInt(8*t.Elem().Size()*offset))
					cmplw = ir.NewBinaryExpr(base.Pos, ir.OOR, cmplw, lb)
					rb := ir.Node(ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(i+offset)))
					rb = typecheck.Conv(rb, elemType)
					rb = typecheck.Conv(rb, convType)
					rb = ir.NewBinaryExpr(base.Pos, ir.OLSH, rb, ir.NewInt(8*t.Elem().Size()*offset))
					cmprw = ir.NewBinaryExpr(base.Pos, ir.OOR, cmprw, rb)
				}
				compare(cmplw, cmprw)
				i += step
				remains -= step * t.Elem().Size()
			}
		}
	}
	if expr == nil {
		// No element or field comparisons were generated (e.g. a
		// zero-length array, zero-size elements, or a struct whose fields
		// are all blank), so the result is the constant true for == and
		// false for !=.
		expr = ir.NewBool(n.Op() == ir.OEQ)
		// We still need to use cmpl and cmpr, in case they contain
		// an expression which might panic. See issue 23837.
		t := typecheck.Temp(cmpl.Type())
		a1 := typecheck.Stmt(ir.NewAssignStmt(base.Pos, t, cmpl))
		a2 := typecheck.Stmt(ir.NewAssignStmt(base.Pos, t, cmpr))
		init.Append(a1, a2)
	}
	return finishCompare(n, expr, init)
}
   268  
   269  func walkCompareInterface(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
   270  	n.Y = cheapExpr(n.Y, init)
   271  	n.X = cheapExpr(n.X, init)
   272  	eqtab, eqdata := reflectdata.EqInterface(n.X, n.Y)
   273  	var cmp ir.Node
   274  	if n.Op() == ir.OEQ {
   275  		cmp = ir.NewLogicalExpr(base.Pos, ir.OANDAND, eqtab, eqdata)
   276  	} else {
   277  		eqtab.SetOp(ir.ONE)
   278  		cmp = ir.NewLogicalExpr(base.Pos, ir.OOROR, eqtab, ir.NewUnaryExpr(base.Pos, ir.ONOT, eqdata))
   279  	}
   280  	return finishCompare(n, cmp, init)
   281  }
   282  
// walkCompareString lowers a comparison between two strings.
//
// When exactly one operand is a constant string whose length is at most
// maxRewriteLen, the comparison is rewritten into a length check followed,
// for == and !=, by comparisons against the constant's bytes (one byte at a
// time, or 2/4/8 bytes at a time where the target can merge loads). For the
// ordered operators (<, <=, >, >=), maxRewriteLen is 0, so only the empty
// constant string is rewritten, into a pure length check.
//
// Otherwise, == and != become a length check combined with the memory
// comparison returned by reflectdata.EqString, and ordered comparisons
// become a cmpstring call (via mkcall) whose result is compared with 0.
//
// Statements needed to evaluate the result are appended to init.
func walkCompareString(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
	// Rewrite comparisons to short constant strings as length+byte-wise comparisons.
	var cs, ncs ir.Node // const string, non-const string
	switch {
	case ir.IsConst(n.X, constant.String) && ir.IsConst(n.Y, constant.String):
		// ignore; will be constant evaluated
	case ir.IsConst(n.X, constant.String):
		cs = n.X
		ncs = n.Y
	case ir.IsConst(n.Y, constant.String):
		cs = n.Y
		ncs = n.X
	}
	if cs != nil {
		cmp := n.Op()
		// Our comparison below assumes that the non-constant string
		// is on the left hand side, so rewrite "" cmp x to x cmp "".
		// See issue 24817.
		// (The multi-byte comparisons further down put the constant on
		// the left, which is fine: they are only generated for == and !=,
		// which are symmetric.)
		if ir.IsConst(n.X, constant.String) {
			cmp = brrev(cmp)
		}

		// maxRewriteLen was chosen empirically.
		// It is the value that minimizes cmd/go file size
		// across most architectures.
		// See the commit description for CL 26758 for details.
		maxRewriteLen := 6
		// Some architectures can load unaligned byte sequence as 1 word.
		// So we can cover longer strings with the same amount of code.
		canCombineLoads := ssagen.Arch.LinkArch.CanMergeLoads
		combine64bit := false
		if canCombineLoads {
			// Keep this low enough to generate less code than a function call.
			maxRewriteLen = 2 * ssagen.Arch.LinkArch.RegSize
			combine64bit = ssagen.Arch.LinkArch.RegSize >= 8
		}

		// and joins the per-chunk checks: && for ==, || for !=.
		var and ir.Op
		switch cmp {
		case ir.OEQ:
			and = ir.OANDAND
		case ir.ONE:
			and = ir.OOROR
		default:
			// Don't do byte-wise comparisons for <, <=, etc.
			// They're fairly complicated.
			// Length-only checks are ok, though.
			maxRewriteLen = 0
		}
		if s := ir.StringVal(cs); len(s) <= maxRewriteLen {
			// ncs is referenced by every byte/word check below.
			if len(s) > 0 {
				ncs = safeExpr(ncs, init)
			}
			// Length check first: it short-circuits the indexing below
			// when the lengths differ.
			r := ir.Node(ir.NewBinaryExpr(base.Pos, cmp, ir.NewUnaryExpr(base.Pos, ir.OLEN, ncs), ir.NewInt(int64(len(s)))))
			remains := len(s)
			for i := 0; remains > 0; {
				if remains == 1 || !canCombineLoads {
					cb := ir.NewInt(int64(s[i]))
					ncb := ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(int64(i)))
					r = ir.NewLogicalExpr(base.Pos, and, r, ir.NewBinaryExpr(base.Pos, cmp, ncb, cb))
					remains--
					i++
					continue
				}
				// The 8-byte chunk uses a signed type, presumably so that the
				// int64 constant built in csubstr below (which may have its
				// top bit set) can be represented directly.
				var step int
				var convType *types.Type
				switch {
				case remains >= 8 && combine64bit:
					convType = types.Types[types.TINT64]
					step = 8
				case remains >= 4:
					convType = types.Types[types.TUINT32]
					step = 4
				case remains >= 2:
					convType = types.Types[types.TUINT16]
					step = 2
				}
				ncsubstr := typecheck.Conv(ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(int64(i))), convType)
				csubstr := int64(s[i])
				// Calculate large constant from bytes as sequence of shifts and ors.
				// Like this:  uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 ...
				// ssa will combine this into a single large load.
				for offset := 1; offset < step; offset++ {
					b := typecheck.Conv(ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(int64(i+offset))), convType)
					b = ir.NewBinaryExpr(base.Pos, ir.OLSH, b, ir.NewInt(int64(8*offset)))
					ncsubstr = ir.NewBinaryExpr(base.Pos, ir.OOR, ncsubstr, b)
					csubstr |= int64(s[i+offset]) << uint8(8*offset)
				}
				csubstrPart := ir.NewInt(csubstr)
				// Compare "step" bytes at once
				r = ir.NewLogicalExpr(base.Pos, and, r, ir.NewBinaryExpr(base.Pos, cmp, csubstrPart, ncsubstr))
				remains -= step
				i += step
			}
			return finishCompare(n, r, init)
		}
	}

	var r ir.Node
	if n.Op() == ir.OEQ || n.Op() == ir.ONE {
		// prepare for rewrite below
		n.X = cheapExpr(n.X, init)
		n.Y = cheapExpr(n.Y, init)
		eqlen, eqmem := reflectdata.EqString(n.X, n.Y)
		// quick check of len before full compare for == or !=.
		// memequal then tests equality up to length len.
		if n.Op() == ir.OEQ {
			// len(left) == len(right) && memequal(left, right, len)
			r = ir.NewLogicalExpr(base.Pos, ir.OANDAND, eqlen, eqmem)
		} else {
			// len(left) != len(right) || !memequal(left, right, len)
			eqlen.SetOp(ir.ONE)
			r = ir.NewLogicalExpr(base.Pos, ir.OOROR, eqlen, ir.NewUnaryExpr(base.Pos, ir.ONOT, eqmem))
		}
	} else {
		// cmpstring(s1, s2) <op> 0
		// Operands are converted to plain string in case they have a
		// named string type.
		r = mkcall("cmpstring", types.Types[types.TINT], init, typecheck.Conv(n.X, types.Types[types.TSTRING]), typecheck.Conv(n.Y, types.Types[types.TSTRING]))
		r = ir.NewBinaryExpr(base.Pos, n.Op(), r, ir.NewInt(0))
	}

	return finishCompare(n, r, init)
}
   405  
   406  // The result of finishCompare MUST be assigned back to n, e.g.
   407  // 	n.Left = finishCompare(n.Left, x, r, init)
   408  func finishCompare(n *ir.BinaryExpr, r ir.Node, init *ir.Nodes) ir.Node {
   409  	r = typecheck.Expr(r)
   410  	r = typecheck.Conv(r, n.Type())
   411  	r = walkExpr(r, init)
   412  	return r
   413  }
   414  
   415  func eqFor(t *types.Type) (n ir.Node, needsize bool) {
   416  	// Should only arrive here with large memory or
   417  	// a struct/array containing a non-memory field/element.
   418  	// Small memory is handled inline, and single non-memory
   419  	// is handled by walkCompare.
   420  	switch a, _ := types.AlgType(t); a {
   421  	case types.AMEM:
   422  		n := typecheck.LookupRuntime("memequal")
   423  		n = typecheck.SubstArgTypes(n, t, t)
   424  		return n, true
   425  	case types.ASPECIAL:
   426  		sym := reflectdata.TypeSymPrefix(".eq", t)
   427  		// TODO(austin): This creates an ir.Name with a nil Func.
   428  		n := typecheck.NewName(sym)
   429  		ir.MarkFunc(n)
   430  		n.SetType(types.NewSignature(types.NoPkg, nil, nil, []*types.Field{
   431  			types.NewField(base.Pos, nil, types.NewPtr(t)),
   432  			types.NewField(base.Pos, nil, types.NewPtr(t)),
   433  		}, []*types.Field{
   434  			types.NewField(base.Pos, nil, types.Types[types.TBOOL]),
   435  		}))
   436  		return n, false
   437  	}
   438  	base.Fatalf("eqFor %v", t)
   439  	return nil, false
   440  }
   441  
   442  // brcom returns !(op).
   443  // For example, brcom(==) is !=.
   444  func brcom(op ir.Op) ir.Op {
   445  	switch op {
   446  	case ir.OEQ:
   447  		return ir.ONE
   448  	case ir.ONE:
   449  		return ir.OEQ
   450  	case ir.OLT:
   451  		return ir.OGE
   452  	case ir.OGT:
   453  		return ir.OLE
   454  	case ir.OLE:
   455  		return ir.OGT
   456  	case ir.OGE:
   457  		return ir.OLT
   458  	}
   459  	base.Fatalf("brcom: no com for %v\n", op)
   460  	return op
   461  }
   462  
   463  // brrev returns reverse(op).
   464  // For example, Brrev(<) is >.
   465  func brrev(op ir.Op) ir.Op {
   466  	switch op {
   467  	case ir.OEQ:
   468  		return ir.OEQ
   469  	case ir.ONE:
   470  		return ir.ONE
   471  	case ir.OLT:
   472  		return ir.OGT
   473  	case ir.OGT:
   474  		return ir.OLT
   475  	case ir.OLE:
   476  		return ir.OGE
   477  	case ir.OGE:
   478  		return ir.OLE
   479  	}
   480  	base.Fatalf("brrev: no rev for %v\n", op)
   481  	return op
   482  }
   483  
   484  func tracecmpArg(n ir.Node, t *types.Type, init *ir.Nodes) ir.Node {
   485  	// Ugly hack to avoid "constant -1 overflows uintptr" errors, etc.
   486  	if n.Op() == ir.OLITERAL && n.Type().IsSigned() && ir.Int64Val(n) < 0 {
   487  		n = copyExpr(n, n.Type(), init)
   488  	}
   489  
   490  	return typecheck.Conv(n, t)
   491  }
   492  

View as plain text