1
2
3
4
5
6
7
8
9
10
11 package main
12
13 import (
14 "bytes"
15 "fmt"
16 "go/format"
17 "io/ioutil"
18 "log"
19 "math/big"
20 "sort"
21 )
22
// Bounds of the fixed-width Go integer types, kept as untyped constants so
// they can be passed to big.Int setters and used in constant arithmetic
// (e.g. maxI64 + 1) without overflow.
const (
	// Largest representable values of the unsigned types.
	maxU8  = 1<<8 - 1
	maxU16 = 1<<16 - 1
	maxU32 = 1<<32 - 1
	maxU64 = 1<<64 - 1

	// Largest representable values of the signed types.
	maxI8  = 1<<7 - 1
	maxI16 = 1<<15 - 1
	maxI32 = 1<<31 - 1
	maxI64 = 1<<63 - 1

	// Smallest representable values of the signed types.
	minI8  = -1 << 7
	minI16 = -1 << 15
	minI32 = -1 << 31
	minI64 = -1 << 63
)
39
// cmp reports whether the relation "left op right" holds, where op is one of
// "<", "<=", ">", ">=", "==" or "!=". Any other operator string yields false.
func cmp(left *big.Int, op string, right *big.Int) bool {
	order := left.Cmp(right)
	if order < -1 || order > 1 {
		panic("unexpected comparison value")
	}
	switch op {
	case "<":
		return order < 0
	case "<=":
		return order <= 0
	case ">":
		return order > 0
	case ">=":
		return order >= 0
	case "==":
		return order == 0
	case "!=":
		return order != 0
	}
	return false
}
51
52 func inRange(typ string, val *big.Int) bool {
53 min, max := &big.Int{}, &big.Int{}
54 switch typ {
55 case "uint64":
56 max = max.SetUint64(maxU64)
57 case "uint32":
58 max = max.SetUint64(maxU32)
59 case "uint16":
60 max = max.SetUint64(maxU16)
61 case "uint8":
62 max = max.SetUint64(maxU8)
63 case "int64":
64 min = min.SetInt64(minI64)
65 max = max.SetInt64(maxI64)
66 case "int32":
67 min = min.SetInt64(minI32)
68 max = max.SetInt64(maxI32)
69 case "int16":
70 min = min.SetInt64(minI16)
71 max = max.SetInt64(maxI16)
72 case "int8":
73 min = min.SetInt64(minI8)
74 max = max.SetInt64(maxI8)
75 default:
76 panic("unexpected type")
77 }
78 return cmp(min, "<=", val) && cmp(val, "<=", max)
79 }
80
81 func getValues(typ string) []*big.Int {
82 Uint := func(v uint64) *big.Int { return big.NewInt(0).SetUint64(v) }
83 Int := func(v int64) *big.Int { return big.NewInt(0).SetInt64(v) }
84 values := []*big.Int{
85
86 Uint(maxU64),
87 Uint(maxU64 - 1),
88 Uint(maxI64 + 1),
89 Uint(maxI64),
90 Uint(maxI64 - 1),
91 Uint(maxU32 + 1),
92 Uint(maxU32),
93 Uint(maxU32 - 1),
94 Uint(maxI32 + 1),
95 Uint(maxI32),
96 Uint(maxI32 - 1),
97 Uint(maxU16 + 1),
98 Uint(maxU16),
99 Uint(maxU16 - 1),
100 Uint(maxI16 + 1),
101 Uint(maxI16),
102 Uint(maxI16 - 1),
103 Uint(maxU8 + 1),
104 Uint(maxU8),
105 Uint(maxU8 - 1),
106 Uint(maxI8 + 1),
107 Uint(maxI8),
108 Uint(maxI8 - 1),
109 Uint(0),
110 Int(minI8 + 1),
111 Int(minI8),
112 Int(minI8 - 1),
113 Int(minI16 + 1),
114 Int(minI16),
115 Int(minI16 - 1),
116 Int(minI32 + 1),
117 Int(minI32),
118 Int(minI32 - 1),
119 Int(minI64 + 1),
120 Int(minI64),
121
122
123 Uint(1),
124 Int(-1),
125 Uint(0xff << 56),
126 Uint(0xff << 32),
127 Uint(0xff << 24),
128 }
129 sort.Slice(values, func(i, j int) bool { return values[i].Cmp(values[j]) == -1 })
130 var ret []*big.Int
131 for _, val := range values {
132 if !inRange(typ, val) {
133 continue
134 }
135 ret = append(ret, val)
136 }
137 return ret
138 }
139
// sigString renders v as a fragment usable in a Go identifier. A
// non-negative value is written as its decimal digits. A negative value gets
// a "neg" prefix in place of the minus sign.
func sigString(v *big.Int) string {
	digits := new(big.Int).Abs(v).String()
	if v.Sign() < 0 {
		return "neg" + digits
	}
	return digits
}
148
149 func main() {
150 types := []string{
151 "uint64", "uint32", "uint16", "uint8",
152 "int64", "int32", "int16", "int8",
153 }
154
155 w := new(bytes.Buffer)
156 fmt.Fprintf(w, "// Code generated by gen/cmpConstGen.go. DO NOT EDIT.\n\n")
157 fmt.Fprintf(w, "package main;\n")
158 fmt.Fprintf(w, "import (\"testing\"; \"reflect\"; \"runtime\";)\n")
159 fmt.Fprintf(w, "// results show the expected result for the elements left of, equal to and right of the index.\n")
160 fmt.Fprintf(w, "type result struct{l, e, r bool}\n")
161 fmt.Fprintf(w, "var (\n")
162 fmt.Fprintf(w, " eq = result{l: false, e: true, r: false}\n")
163 fmt.Fprintf(w, " ne = result{l: true, e: false, r: true}\n")
164 fmt.Fprintf(w, " lt = result{l: true, e: false, r: false}\n")
165 fmt.Fprintf(w, " le = result{l: true, e: true, r: false}\n")
166 fmt.Fprintf(w, " gt = result{l: false, e: false, r: true}\n")
167 fmt.Fprintf(w, " ge = result{l: false, e: true, r: true}\n")
168 fmt.Fprintf(w, ")\n")
169
170 operators := []struct{ op, name string }{
171 {"<", "lt"},
172 {"<=", "le"},
173 {">", "gt"},
174 {">=", "ge"},
175 {"==", "eq"},
176 {"!=", "ne"},
177 }
178
179 for _, typ := range types {
180
181 fmt.Fprintf(w, "\n// %v tests\n", typ)
182 values := getValues(typ)
183 fmt.Fprintf(w, "var %v_vals = []%v{\n", typ, typ)
184 for _, val := range values {
185 fmt.Fprintf(w, "%v,\n", val.String())
186 }
187 fmt.Fprintf(w, "}\n")
188
189
190 for _, r := range values {
191
192 sig := sigString(r)
193 for _, op := range operators {
194
195 fmt.Fprintf(w, "func %v_%v_%v(x %v) bool { return x %v %v; }\n", op.name, sig, typ, typ, op.op, r.String())
196 }
197 }
198
199
200 fmt.Fprintf(w, "var %v_tests = []struct{\n", typ)
201 fmt.Fprintf(w, " idx int // index of the constant used\n")
202 fmt.Fprintf(w, " exp result // expected results\n")
203 fmt.Fprintf(w, " fn func(%v) bool\n", typ)
204 fmt.Fprintf(w, "}{\n")
205 for i, r := range values {
206 sig := sigString(r)
207 for _, op := range operators {
208 fmt.Fprintf(w, "{idx: %v,", i)
209 fmt.Fprintf(w, "exp: %v,", op.name)
210 fmt.Fprintf(w, "fn: %v_%v_%v},\n", op.name, sig, typ)
211 }
212 }
213 fmt.Fprintf(w, "}\n")
214 }
215
216
217 fmt.Fprintf(w, "// TestComparisonsConst tests results for comparison operations against constants.\n")
218 fmt.Fprintf(w, "func TestComparisonsConst(t *testing.T) {\n")
219 for _, typ := range types {
220 fmt.Fprintf(w, "for i, test := range %v_tests {\n", typ)
221 fmt.Fprintf(w, " for j, x := range %v_vals {\n", typ)
222 fmt.Fprintf(w, " want := test.exp.l\n")
223 fmt.Fprintf(w, " if j == test.idx {\nwant = test.exp.e\n}")
224 fmt.Fprintf(w, " else if j > test.idx {\nwant = test.exp.r\n}\n")
225 fmt.Fprintf(w, " if test.fn(x) != want {\n")
226 fmt.Fprintf(w, " fn := runtime.FuncForPC(reflect.ValueOf(test.fn).Pointer()).Name()\n")
227 fmt.Fprintf(w, " t.Errorf(\"test failed: %%v(%%v) != %%v [type=%v i=%%v j=%%v idx=%%v]\", fn, x, want, i, j, test.idx)\n", typ)
228 fmt.Fprintf(w, " }\n")
229 fmt.Fprintf(w, " }\n")
230 fmt.Fprintf(w, "}\n")
231 }
232 fmt.Fprintf(w, "}\n")
233
234
235 b := w.Bytes()
236 src, err := format.Source(b)
237 if err != nil {
238 fmt.Printf("%s\n", b)
239 panic(err)
240 }
241
242
243 err = ioutil.WriteFile("../cmpConst_test.go", src, 0666)
244 if err != nil {
245 log.Fatalf("can't write output: %v\n", err)
246 }
247 }
248
View as plain text