Source file src/cmd/compile/internal/walk/switch.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package walk
     6  
     7  import (
     8  	"go/constant"
     9  	"go/token"
    10  	"sort"
    11  
    12  	"cmd/compile/internal/base"
    13  	"cmd/compile/internal/ir"
    14  	"cmd/compile/internal/typecheck"
    15  	"cmd/compile/internal/types"
    16  	"cmd/internal/src"
    17  )
    18  
    19  // walkSwitch walks a switch statement.
    20  func walkSwitch(sw *ir.SwitchStmt) {
    21  	// Guard against double walk, see #25776.
    22  	if sw.Walked() {
    23  		return // Was fatal, but eliminating every possible source of double-walking is hard
    24  	}
    25  	sw.SetWalked(true)
    26  
    27  	if sw.Tag != nil && sw.Tag.Op() == ir.OTYPESW {
    28  		walkSwitchType(sw)
    29  	} else {
    30  		walkSwitchExpr(sw)
    31  	}
    32  }
    33  
// walkSwitchExpr generates an AST implementing sw. sw is an
// expression switch.
//
// The lowered statements are appended to sw.Compiled in this order:
//   - the case dispatch built by exprSwitch (tests that goto a
//     per-case label),
//   - the jump taken when no case matches (the default case's goto,
//     or a break when there is no default),
//   - the case bodies, each preceded by its label and followed by a
//     break unless the body ends in fallthrough.
//
// sw.Tag and sw.Cases are cleared in the process.
func walkSwitchExpr(sw *ir.SwitchStmt) {
	// lno is restored into base.Pos below, once the tag has been walked.
	lno := ir.SetPos(sw)

	cond := sw.Tag
	sw.Tag = nil

	// convert switch {...} to switch true {...}
	if cond == nil {
		cond = ir.NewBool(true)
		cond = typecheck.Expr(cond)
		cond = typecheck.DefaultLit(cond, nil)
	}

	// Given "switch string(byteslice)",
	// with all cases being side-effect free,
	// use a zero-cost alias of the byte slice.
	// Do this before calling walkExpr on cond,
	// because walkExpr will lower the string
	// conversion into a runtime call.
	// See issue 24937 for more discussion.
	if cond.Op() == ir.OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
		cond := cond.(*ir.ConvExpr)
		cond.SetOp(ir.OBYTES2STRTMP)
	}

	// Walk the tag. Unless it is a literal or nil, copy it via copyExpr
	// (presumably into a temporary appended to sw.Compiled) so that every
	// case comparison below refers to the same value.
	cond = walkExpr(cond, sw.PtrInit())
	if cond.Op() != ir.OLITERAL && cond.Op() != ir.ONIL {
		cond = copyExpr(cond, cond.Type(), &sw.Compiled)
	}

	base.Pos = lno

	s := exprSwitch{
		exprname: cond,
	}

	// defaultGoto is the jump taken when no case matches; body collects
	// the labeled case bodies, which are emitted after all dispatch code.
	var defaultGoto ir.Node
	var body ir.Nodes
	for _, ncase := range sw.Cases {
		// Each case gets a fresh label; its dispatch tests jump there via jmp.
		label := typecheck.AutoLabel(".s")
		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)

		// Process case dispatch.
		// A case with no expressions is the default case.
		if len(ncase.List) == 0 {
			if defaultGoto != nil {
				base.Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for _, n1 := range ncase.List {
			s.Add(ncase.Pos(), n1, jmp)
		}

		// Process body.
		// Unless the body ends in fallthrough, terminate it with a break
		// positioned at the body's last (non-OVARKILL) statement.
		body.Append(ir.NewLabelStmt(ncase.Pos(), label))
		body.Append(ncase.Body...)
		if fall, pos := endsInFallthrough(ncase.Body); !fall {
			br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
			br.SetPos(pos)
			body.Append(br)
		}
	}
	sw.Cases = nil

	// No default case: when nothing matches, break out of the switch.
	// The break's position is marked as not being a statement.
	if defaultGoto == nil {
		br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
		br.SetPos(br.Pos().WithNotStmt())
		defaultGoto = br
	}

	s.Emit(&sw.Compiled)
	sw.Compiled.Append(defaultGoto)
	sw.Compiled.Append(body.Take()...)
	walkStmtList(sw.Compiled)
}
   112  
   113  // An exprSwitch walks an expression switch.
   114  type exprSwitch struct {
   115  	exprname ir.Node // value being switched on
   116  
   117  	done    ir.Nodes
   118  	clauses []exprClause
   119  }
   120  
   121  type exprClause struct {
   122  	pos    src.XPos
   123  	lo, hi ir.Node
   124  	jmp    ir.Node
   125  }
   126  
   127  func (s *exprSwitch) Add(pos src.XPos, expr, jmp ir.Node) {
   128  	c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
   129  	if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL {
   130  		s.clauses = append(s.clauses, c)
   131  		return
   132  	}
   133  
   134  	s.flush()
   135  	s.clauses = append(s.clauses, c)
   136  	s.flush()
   137  }
   138  
   139  func (s *exprSwitch) Emit(out *ir.Nodes) {
   140  	s.flush()
   141  	out.Append(s.done.Take()...)
   142  }
   143  
   144  func (s *exprSwitch) flush() {
   145  	cc := s.clauses
   146  	s.clauses = nil
   147  	if len(cc) == 0 {
   148  		return
   149  	}
   150  
   151  	// Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
   152  	// The code below is structured to implicitly handle this case
   153  	// (e.g., sort.Slice doesn't need to invoke the less function
   154  	// when there's only a single slice element).
   155  
   156  	if s.exprname.Type().IsString() && len(cc) >= 2 {
   157  		// Sort strings by length and then by value. It is
   158  		// much cheaper to compare lengths than values, and
   159  		// all we need here is consistency. We respect this
   160  		// sorting below.
   161  		sort.Slice(cc, func(i, j int) bool {
   162  			si := ir.StringVal(cc[i].lo)
   163  			sj := ir.StringVal(cc[j].lo)
   164  			if len(si) != len(sj) {
   165  				return len(si) < len(sj)
   166  			}
   167  			return si < sj
   168  		})
   169  
   170  		// runLen returns the string length associated with a
   171  		// particular run of exprClauses.
   172  		runLen := func(run []exprClause) int64 { return int64(len(ir.StringVal(run[0].lo))) }
   173  
   174  		// Collapse runs of consecutive strings with the same length.
   175  		var runs [][]exprClause
   176  		start := 0
   177  		for i := 1; i < len(cc); i++ {
   178  			if runLen(cc[start:]) != runLen(cc[i:]) {
   179  				runs = append(runs, cc[start:i])
   180  				start = i
   181  			}
   182  		}
   183  		runs = append(runs, cc[start:])
   184  
   185  		// Perform two-level binary search.
   186  		binarySearch(len(runs), &s.done,
   187  			func(i int) ir.Node {
   188  				return ir.NewBinaryExpr(base.Pos, ir.OLE, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(runs[i-1])))
   189  			},
   190  			func(i int, nif *ir.IfStmt) {
   191  				run := runs[i]
   192  				nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(run)))
   193  				s.search(run, &nif.Body)
   194  			},
   195  		)
   196  		return
   197  	}
   198  
   199  	sort.Slice(cc, func(i, j int) bool {
   200  		return constant.Compare(cc[i].lo.Val(), token.LSS, cc[j].lo.Val())
   201  	})
   202  
   203  	// Merge consecutive integer cases.
   204  	if s.exprname.Type().IsInteger() {
   205  		consecutive := func(last, next constant.Value) bool {
   206  			delta := constant.BinaryOp(next, token.SUB, last)
   207  			return constant.Compare(delta, token.EQL, constant.MakeInt64(1))
   208  		}
   209  
   210  		merged := cc[:1]
   211  		for _, c := range cc[1:] {
   212  			last := &merged[len(merged)-1]
   213  			if last.jmp == c.jmp && consecutive(last.hi.Val(), c.lo.Val()) {
   214  				last.hi = c.lo
   215  			} else {
   216  				merged = append(merged, c)
   217  			}
   218  		}
   219  		cc = merged
   220  	}
   221  
   222  	s.search(cc, &s.done)
   223  }
   224  
   225  func (s *exprSwitch) search(cc []exprClause, out *ir.Nodes) {
   226  	binarySearch(len(cc), out,
   227  		func(i int) ir.Node {
   228  			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.exprname, cc[i-1].hi)
   229  		},
   230  		func(i int, nif *ir.IfStmt) {
   231  			c := &cc[i]
   232  			nif.Cond = c.test(s.exprname)
   233  			nif.Body = []ir.Node{c.jmp}
   234  		},
   235  	)
   236  }
   237  
   238  func (c *exprClause) test(exprname ir.Node) ir.Node {
   239  	// Integer range.
   240  	if c.hi != c.lo {
   241  		low := ir.NewBinaryExpr(c.pos, ir.OGE, exprname, c.lo)
   242  		high := ir.NewBinaryExpr(c.pos, ir.OLE, exprname, c.hi)
   243  		return ir.NewLogicalExpr(c.pos, ir.OANDAND, low, high)
   244  	}
   245  
   246  	// Optimize "switch true { ...}" and "switch false { ... }".
   247  	if ir.IsConst(exprname, constant.Bool) && !c.lo.Type().IsInterface() {
   248  		if ir.BoolVal(exprname) {
   249  			return c.lo
   250  		} else {
   251  			return ir.NewUnaryExpr(c.pos, ir.ONOT, c.lo)
   252  		}
   253  	}
   254  
   255  	return ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo)
   256  }
   257  
   258  func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
   259  	// In theory, we could be more aggressive, allowing any
   260  	// side-effect-free expressions in cases, but it's a bit
   261  	// tricky because some of that information is unavailable due
   262  	// to the introduction of temporaries during order.
   263  	// Restricting to constants is simple and probably powerful
   264  	// enough.
   265  
   266  	for _, ncase := range sw.Cases {
   267  		for _, v := range ncase.List {
   268  			if v.Op() != ir.OLITERAL {
   269  				return false
   270  			}
   271  		}
   272  	}
   273  	return true
   274  }
   275  
   276  // endsInFallthrough reports whether stmts ends with a "fallthrough" statement.
   277  func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
   278  	// Search backwards for the index of the fallthrough
   279  	// statement. Do not assume it'll be in the last
   280  	// position, since in some cases (e.g. when the statement
   281  	// list contains autotmp_ variables), one or more OVARKILL
   282  	// nodes will be at the end of the list.
   283  
   284  	i := len(stmts) - 1
   285  	for i >= 0 && stmts[i].Op() == ir.OVARKILL {
   286  		i--
   287  	}
   288  	if i < 0 {
   289  		return false, src.NoXPos
   290  	}
   291  	return stmts[i].Op() == ir.OFALL, stmts[i].Pos()
   292  }
   293  
// walkSwitchType generates an AST that implements sw, where sw is a
// type switch.
//
// The lowered statements are appended to sw.Compiled in this order:
//   - a test for a nil interface descriptor word, whose body jumps to the
//     "case nil" clause, else to the default clause, else breaks,
//   - a load of the type hash from the descriptor word,
//   - the type-assertion dispatch built by typeSwitch,
//   - the jump taken when no case matches (default's goto, or a break),
//   - the case bodies, each preceded by its label and followed by a break.
//
// sw.Tag and sw.Cases are cleared in the process.
func walkSwitchType(sw *ir.SwitchStmt) {
	var s typeSwitch
	s.facename = sw.Tag.(*ir.TypeSwitchGuard).X
	sw.Tag = nil

	// Walk the switched-on value and copy it via copyExpr (presumably into
	// a temporary appended to sw.Compiled) so every case tests the same value.
	s.facename = walkExpr(s.facename, sw.PtrInit())
	s.facename = copyExpr(s.facename, s.facename.Type(), &sw.Compiled)
	s.okname = typecheck.Temp(types.Types[types.TBOOL])

	// Get interface descriptor word.
	// For empty interfaces this will be the type.
	// For non-empty interfaces this will be the itab.
	itab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.facename)

	// For empty interfaces, do:
	//     if e._type == nil {
	//         do nil case if it exists, otherwise default
	//     }
	//     h := e._type.hash
	// Use a similar strategy for non-empty interfaces.
	ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
	ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, itab, typecheck.NodNil())
	base.Pos = base.Pos.WithNotStmt() // disable statement marks after the first check.
	ifNil.Cond = typecheck.Expr(ifNil.Cond)
	ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)
	// ifNil.Body is assigned at the end, once the nil and default
	// targets are known.
	sw.Compiled.Append(ifNil)

	// Load hash from type or itab.
	dotHash := typeHashFieldOf(base.Pos, itab)
	s.hashname = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)

	// br is shared: it ends every case body and is the fallback target
	// when there is no default case.
	br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
	var defaultGoto, nilGoto ir.Node
	var body ir.Nodes
	for _, ncase := range sw.Cases {
		caseVar := ncase.Var

		// For single-type cases with an interface type,
		// we initialize the case variable as part of the type assertion.
		// In other cases, we initialize it in the body.
		var singleType *types.Type
		if len(ncase.List) == 1 && ncase.List[0].Op() == ir.OTYPE {
			singleType = ncase.List[0].Type()
		}
		caseVarInitialized := false

		// Each case gets a fresh label; its dispatch code jumps there via jmp.
		label := typecheck.AutoLabel(".s")
		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)

		if len(ncase.List) == 0 { // default:
			if defaultGoto != nil {
				base.Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for _, n1 := range ncase.List {
			// "case nil" needs no type assertion; it is reached from ifNil.
			if ir.IsNil(n1) { // case nil:
				if nilGoto != nil {
					base.Fatalf("duplicate nil case not detected during typechecking")
				}
				nilGoto = jmp
				continue
			}

			if singleType != nil && singleType.IsInterface() {
				s.Add(ncase.Pos(), n1, caseVar, jmp)
				caseVarInitialized = true
			} else {
				s.Add(ncase.Pos(), n1, nil, jmp)
			}
		}

		body.Append(ir.NewLabelStmt(ncase.Pos(), label))
		// Bind a case variable that the assertion in Add did not bind.
		// It gets the data word for a single concrete type, the result of a
		// dynamic type assertion for a single ODYNAMICTYPE case, and the
		// interface value itself otherwise.
		if caseVar != nil && !caseVarInitialized {
			val := s.facename
			if singleType != nil {
				// We have a single concrete type. Extract the data.
				if singleType.IsInterface() {
					base.Fatalf("singleType interface should have been handled in Add")
				}
				val = ifaceData(ncase.Pos(), s.facename, singleType)
			}
			if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE {
				dt := ncase.List[0].(*ir.DynamicType)
				x := ir.NewDynamicTypeAssertExpr(ncase.Pos(), ir.ODYNAMICDOTTYPE, val, dt.X)
				if dt.ITab != nil {
					// TODO: make ITab a separate field in DynamicTypeAssertExpr?
					x.T = dt.ITab
				}
				x.SetType(caseVar.Type())
				x.SetTypecheck(1)
				val = x
			}
			l := []ir.Node{
				ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
				ir.NewAssignStmt(ncase.Pos(), caseVar, val),
			}
			typecheck.Stmts(l)
			body.Append(l...)
		}
		// Type switch case bodies always end with a break.
		body.Append(ncase.Body...)
		body.Append(br)
	}
	sw.Cases = nil

	// Fallbacks: no default breaks out of the switch; no "case nil" sends
	// a nil interface to the default target.
	if defaultGoto == nil {
		defaultGoto = br
	}
	if nilGoto == nil {
		nilGoto = defaultGoto
	}
	ifNil.Body = []ir.Node{nilGoto}

	s.Emit(&sw.Compiled)
	sw.Compiled.Append(defaultGoto)
	sw.Compiled.Append(body.Take()...)

	walkStmtList(sw.Compiled)
}
   417  
   418  // typeHashFieldOf returns an expression to select the type hash field
   419  // from an interface's descriptor word (whether a *runtime._type or
   420  // *runtime.itab pointer).
   421  func typeHashFieldOf(pos src.XPos, itab *ir.UnaryExpr) *ir.SelectorExpr {
   422  	if itab.Op() != ir.OITAB {
   423  		base.Fatalf("expected OITAB, got %v", itab.Op())
   424  	}
   425  	var hashField *types.Field
   426  	if itab.X.Type().IsEmptyInterface() {
   427  		// runtime._type's hash field
   428  		if rtypeHashField == nil {
   429  			rtypeHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
   430  		}
   431  		hashField = rtypeHashField
   432  	} else {
   433  		// runtime.itab's hash field
   434  		if itabHashField == nil {
   435  			itabHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
   436  		}
   437  		hashField = itabHashField
   438  	}
   439  	return boundedDotPtr(pos, itab, hashField)
   440  }
   441  
   442  var rtypeHashField, itabHashField *types.Field
   443  
   444  // A typeSwitch walks a type switch.
   445  type typeSwitch struct {
   446  	// Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
   447  	facename ir.Node // value being type-switched on
   448  	hashname ir.Node // type hash of the value being type-switched on
   449  	okname   ir.Node // boolean used for comma-ok type assertions
   450  
   451  	done    ir.Nodes
   452  	clauses []typeClause
   453  }
   454  
   455  type typeClause struct {
   456  	hash uint32
   457  	body ir.Nodes
   458  }
   459  
   460  func (s *typeSwitch) Add(pos src.XPos, n1 ir.Node, caseVar *ir.Name, jmp ir.Node) {
   461  	typ := n1.Type()
   462  	var body ir.Nodes
   463  	if caseVar != nil {
   464  		l := []ir.Node{
   465  			ir.NewDecl(pos, ir.ODCL, caseVar),
   466  			ir.NewAssignStmt(pos, caseVar, nil),
   467  		}
   468  		typecheck.Stmts(l)
   469  		body.Append(l...)
   470  	} else {
   471  		caseVar = ir.BlankNode.(*ir.Name)
   472  	}
   473  
   474  	// cv, ok = iface.(type)
   475  	as := ir.NewAssignListStmt(pos, ir.OAS2, nil, nil)
   476  	as.Lhs = []ir.Node{caseVar, s.okname} // cv, ok =
   477  	switch n1.Op() {
   478  	case ir.OTYPE:
   479  		// Static type assertion (non-generic)
   480  		dot := ir.NewTypeAssertExpr(pos, s.facename, nil)
   481  		dot.SetType(typ) // iface.(type)
   482  		as.Rhs = []ir.Node{dot}
   483  	case ir.ODYNAMICTYPE:
   484  		// Dynamic type assertion (generic)
   485  		dt := n1.(*ir.DynamicType)
   486  		dot := ir.NewDynamicTypeAssertExpr(pos, ir.ODYNAMICDOTTYPE, s.facename, dt.X)
   487  		if dt.ITab != nil {
   488  			dot.T = dt.ITab
   489  		}
   490  		dot.SetType(typ)
   491  		dot.SetTypecheck(1)
   492  		as.Rhs = []ir.Node{dot}
   493  	default:
   494  		base.Fatalf("unhandled type case %s", n1.Op())
   495  	}
   496  	appendWalkStmt(&body, as)
   497  
   498  	// if ok { goto label }
   499  	nif := ir.NewIfStmt(pos, nil, nil, nil)
   500  	nif.Cond = s.okname
   501  	nif.Body = []ir.Node{jmp}
   502  	body.Append(nif)
   503  
   504  	if n1.Op() == ir.OTYPE && !typ.IsInterface() {
   505  		// Defer static, noninterface cases so they can be binary searched by hash.
   506  		s.clauses = append(s.clauses, typeClause{
   507  			hash: types.TypeHash(n1.Type()),
   508  			body: body,
   509  		})
   510  		return
   511  	}
   512  
   513  	s.flush()
   514  	s.done.Append(body.Take()...)
   515  }
   516  
   517  func (s *typeSwitch) Emit(out *ir.Nodes) {
   518  	s.flush()
   519  	out.Append(s.done.Take()...)
   520  }
   521  
   522  func (s *typeSwitch) flush() {
   523  	cc := s.clauses
   524  	s.clauses = nil
   525  	if len(cc) == 0 {
   526  		return
   527  	}
   528  
   529  	sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
   530  
   531  	// Combine adjacent cases with the same hash.
   532  	merged := cc[:1]
   533  	for _, c := range cc[1:] {
   534  		last := &merged[len(merged)-1]
   535  		if last.hash == c.hash {
   536  			last.body.Append(c.body.Take()...)
   537  		} else {
   538  			merged = append(merged, c)
   539  		}
   540  	}
   541  	cc = merged
   542  
   543  	binarySearch(len(cc), &s.done,
   544  		func(i int) ir.Node {
   545  			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(int64(cc[i-1].hash)))
   546  		},
   547  		func(i int, nif *ir.IfStmt) {
   548  			// TODO(mdempsky): Omit hash equality check if
   549  			// there's only one type.
   550  			c := cc[i]
   551  			nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashname, ir.NewInt(int64(c.hash)))
   552  			nif.Body.Append(c.body.Take()...)
   553  		},
   554  	)
   555  }
   556  
   557  // binarySearch constructs a binary search tree for handling n cases,
   558  // and appends it to out. It's used for efficiently implementing
   559  // switch statements.
   560  //
   561  // less(i) should return a boolean expression. If it evaluates true,
   562  // then cases before i will be tested; otherwise, cases i and later.
   563  //
   564  // leaf(i, nif) should setup nif (an OIF node) to test case i. In
   565  // particular, it should set nif.Left and nif.Nbody.
   566  func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i int, nif *ir.IfStmt)) {
   567  	const binarySearchMin = 4 // minimum number of cases for binary search
   568  
   569  	var do func(lo, hi int, out *ir.Nodes)
   570  	do = func(lo, hi int, out *ir.Nodes) {
   571  		n := hi - lo
   572  		if n < binarySearchMin {
   573  			for i := lo; i < hi; i++ {
   574  				nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
   575  				leaf(i, nif)
   576  				base.Pos = base.Pos.WithNotStmt()
   577  				nif.Cond = typecheck.Expr(nif.Cond)
   578  				nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
   579  				out.Append(nif)
   580  				out = &nif.Else
   581  			}
   582  			return
   583  		}
   584  
   585  		half := lo + n/2
   586  		nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
   587  		nif.Cond = less(half)
   588  		base.Pos = base.Pos.WithNotStmt()
   589  		nif.Cond = typecheck.Expr(nif.Cond)
   590  		nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
   591  		do(lo, half, &nif.Body)
   592  		do(half, hi, &nif.Else)
   593  		out.Append(nif)
   594  	}
   595  
   596  	do(0, n, out)
   597  }
   598  

View as plain text