Source file src/cmd/compile/internal/test/testdata/gen/constFoldGen.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// This program generates a test to verify that the standard arithmetic
// operators properly handle constant folding. The test file should be
// generated with a known working version of Go.
// Launch with `go run constFoldGen.go`; a file called constFold_test.go
// will be written into the grandparent directory containing the tests.
    10  
    11  package main
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  	"go/format"
    17  	"io/ioutil"
    18  	"log"
    19  )
    20  
    21  type op struct {
    22  	name, symbol string
    23  }
    24  type szD struct {
    25  	name string
    26  	sn   string
    27  	u    []uint64
    28  	i    []int64
    29  }
    30  
    31  var szs []szD = []szD{
    32  	szD{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0xffffFFFFffffFFFF}},
    33  	szD{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF,
    34  		-4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}},
    35  
    36  	szD{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}},
    37  	szD{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0,
    38  		1, 0x7FFFFFFF}},
    39  
    40  	szD{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}},
    41  	szD{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}},
    42  
    43  	szD{name: "uint8", sn: "8", u: []uint64{0, 1, 255}},
    44  	szD{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}},
    45  }
    46  
    47  var ops = []op{
    48  	op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"},
    49  	op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"},
    50  }
    51  
    52  // compute the result of i op j, cast as type t.
    53  func ansU(i, j uint64, t, op string) string {
    54  	var ans uint64
    55  	switch op {
    56  	case "+":
    57  		ans = i + j
    58  	case "-":
    59  		ans = i - j
    60  	case "*":
    61  		ans = i * j
    62  	case "/":
    63  		if j != 0 {
    64  			ans = i / j
    65  		}
    66  	case "%":
    67  		if j != 0 {
    68  			ans = i % j
    69  		}
    70  	case "<<":
    71  		ans = i << j
    72  	case ">>":
    73  		ans = i >> j
    74  	}
    75  	switch t {
    76  	case "uint32":
    77  		ans = uint64(uint32(ans))
    78  	case "uint16":
    79  		ans = uint64(uint16(ans))
    80  	case "uint8":
    81  		ans = uint64(uint8(ans))
    82  	}
    83  	return fmt.Sprintf("%d", ans)
    84  }
    85  
    86  // compute the result of i op j, cast as type t.
    87  func ansS(i, j int64, t, op string) string {
    88  	var ans int64
    89  	switch op {
    90  	case "+":
    91  		ans = i + j
    92  	case "-":
    93  		ans = i - j
    94  	case "*":
    95  		ans = i * j
    96  	case "/":
    97  		if j != 0 {
    98  			ans = i / j
    99  		}
   100  	case "%":
   101  		if j != 0 {
   102  			ans = i % j
   103  		}
   104  	case "<<":
   105  		ans = i << uint64(j)
   106  	case ">>":
   107  		ans = i >> uint64(j)
   108  	}
   109  	switch t {
   110  	case "int32":
   111  		ans = int64(int32(ans))
   112  	case "int16":
   113  		ans = int64(int16(ans))
   114  	case "int8":
   115  		ans = int64(int8(ans))
   116  	}
   117  	return fmt.Sprintf("%d", ans)
   118  }
   119  
   120  func main() {
   121  	w := new(bytes.Buffer)
   122  	fmt.Fprintf(w, "// run\n")
   123  	fmt.Fprintf(w, "// Code generated by gen/constFoldGen.go. DO NOT EDIT.\n\n")
   124  	fmt.Fprintf(w, "package gc\n")
   125  	fmt.Fprintf(w, "import \"testing\"\n")
   126  
   127  	for _, s := range szs {
   128  		for _, o := range ops {
   129  			if o.symbol == "<<" || o.symbol == ">>" {
   130  				// shifts handled separately below, as they can have
   131  				// different types on the LHS and RHS.
   132  				continue
   133  			}
   134  			fmt.Fprintf(w, "func TestConstFold%s%s(t *testing.T) {\n", s.name, o.name)
   135  			fmt.Fprintf(w, "\tvar x, y, r %s\n", s.name)
   136  			// unsigned test cases
   137  			for _, c := range s.u {
   138  				fmt.Fprintf(w, "\tx = %d\n", c)
   139  				for _, d := range s.u {
   140  					if d == 0 && (o.symbol == "/" || o.symbol == "%") {
   141  						continue
   142  					}
   143  					fmt.Fprintf(w, "\ty = %d\n", d)
   144  					fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
   145  					want := ansU(c, d, s.name, o.symbol)
   146  					fmt.Fprintf(w, "\tif r != %s {\n", want)
   147  					fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
   148  					fmt.Fprintf(w, "\t}\n")
   149  				}
   150  			}
   151  			// signed test cases
   152  			for _, c := range s.i {
   153  				fmt.Fprintf(w, "\tx = %d\n", c)
   154  				for _, d := range s.i {
   155  					if d == 0 && (o.symbol == "/" || o.symbol == "%") {
   156  						continue
   157  					}
   158  					fmt.Fprintf(w, "\ty = %d\n", d)
   159  					fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
   160  					want := ansS(c, d, s.name, o.symbol)
   161  					fmt.Fprintf(w, "\tif r != %s {\n", want)
   162  					fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
   163  					fmt.Fprintf(w, "\t}\n")
   164  				}
   165  			}
   166  			fmt.Fprintf(w, "}\n")
   167  		}
   168  	}
   169  
   170  	// Special signed/unsigned cases for shifts
   171  	for _, ls := range szs {
   172  		for _, rs := range szs {
   173  			if rs.name[0] != 'u' {
   174  				continue
   175  			}
   176  			for _, o := range ops {
   177  				if o.symbol != "<<" && o.symbol != ">>" {
   178  					continue
   179  				}
   180  				fmt.Fprintf(w, "func TestConstFold%s%s%s(t *testing.T) {\n", ls.name, rs.name, o.name)
   181  				fmt.Fprintf(w, "\tvar x, r %s\n", ls.name)
   182  				fmt.Fprintf(w, "\tvar y %s\n", rs.name)
   183  				// unsigned LHS
   184  				for _, c := range ls.u {
   185  					fmt.Fprintf(w, "\tx = %d\n", c)
   186  					for _, d := range rs.u {
   187  						fmt.Fprintf(w, "\ty = %d\n", d)
   188  						fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
   189  						want := ansU(c, d, ls.name, o.symbol)
   190  						fmt.Fprintf(w, "\tif r != %s {\n", want)
   191  						fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
   192  						fmt.Fprintf(w, "\t}\n")
   193  					}
   194  				}
   195  				// signed LHS
   196  				for _, c := range ls.i {
   197  					fmt.Fprintf(w, "\tx = %d\n", c)
   198  					for _, d := range rs.u {
   199  						fmt.Fprintf(w, "\ty = %d\n", d)
   200  						fmt.Fprintf(w, "\tr = x %s y\n", o.symbol)
   201  						want := ansS(c, int64(d), ls.name, o.symbol)
   202  						fmt.Fprintf(w, "\tif r != %s {\n", want)
   203  						fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol)
   204  						fmt.Fprintf(w, "\t}\n")
   205  					}
   206  				}
   207  				fmt.Fprintf(w, "}\n")
   208  			}
   209  		}
   210  	}
   211  
   212  	// Constant folding for comparisons
   213  	for _, s := range szs {
   214  		fmt.Fprintf(w, "func TestConstFoldCompare%s(t *testing.T) {\n", s.name)
   215  		for _, x := range s.i {
   216  			for _, y := range s.i {
   217  				fmt.Fprintf(w, "\t{\n")
   218  				fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x)
   219  				fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y)
   220  				if x == y {
   221  					fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n")
   222  				} else {
   223  					fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n")
   224  				}
   225  				if x != y {
   226  					fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n")
   227  				} else {
   228  					fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n")
   229  				}
   230  				if x < y {
   231  					fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n")
   232  				} else {
   233  					fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n")
   234  				}
   235  				if x > y {
   236  					fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n")
   237  				} else {
   238  					fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n")
   239  				}
   240  				if x <= y {
   241  					fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n")
   242  				} else {
   243  					fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n")
   244  				}
   245  				if x >= y {
   246  					fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n")
   247  				} else {
   248  					fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n")
   249  				}
   250  				fmt.Fprintf(w, "\t}\n")
   251  			}
   252  		}
   253  		for _, x := range s.u {
   254  			for _, y := range s.u {
   255  				fmt.Fprintf(w, "\t{\n")
   256  				fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x)
   257  				fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y)
   258  				if x == y {
   259  					fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n")
   260  				} else {
   261  					fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n")
   262  				}
   263  				if x != y {
   264  					fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n")
   265  				} else {
   266  					fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n")
   267  				}
   268  				if x < y {
   269  					fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n")
   270  				} else {
   271  					fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n")
   272  				}
   273  				if x > y {
   274  					fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n")
   275  				} else {
   276  					fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n")
   277  				}
   278  				if x <= y {
   279  					fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n")
   280  				} else {
   281  					fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n")
   282  				}
   283  				if x >= y {
   284  					fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n")
   285  				} else {
   286  					fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n")
   287  				}
   288  				fmt.Fprintf(w, "\t}\n")
   289  			}
   290  		}
   291  		fmt.Fprintf(w, "}\n")
   292  	}
   293  
   294  	// gofmt result
   295  	b := w.Bytes()
   296  	src, err := format.Source(b)
   297  	if err != nil {
   298  		fmt.Printf("%s\n", b)
   299  		panic(err)
   300  	}
   301  
   302  	// write to file
   303  	err = ioutil.WriteFile("../../constFold_test.go", src, 0666)
   304  	if err != nil {
   305  		log.Fatalf("can't write output: %v\n", err)
   306  	}
   307  }
   308  

View as plain text