1
2
3
4
5 package fuzz
6
7 import (
8 "bytes"
9 "fmt"
10 "go/ast"
11 "go/parser"
12 "go/token"
13 "math"
14 "strconv"
15 "unicode/utf8"
16 )
17
18
// encVersion1 is the version header of the corpus file format. It is written
// as the first line of every file produced by marshalCorpusFile, and
// unmarshalCorpusFile rejects input whose first line does not match it.
var encVersion1 = "go test fuzz v1"
20
21
22
23 func marshalCorpusFile(vals ...any) []byte {
24 if len(vals) == 0 {
25 panic("must have at least one value to marshal")
26 }
27 b := bytes.NewBuffer([]byte(encVersion1 + "\n"))
28
29
30 for _, val := range vals {
31 switch t := val.(type) {
32 case int, int8, int16, int64, uint, uint16, uint32, uint64, bool:
33 fmt.Fprintf(b, "%T(%v)\n", t, t)
34 case float32:
35 if math.IsNaN(float64(t)) && math.Float32bits(t) != math.Float32bits(float32(math.NaN())) {
36
37
38
39
40
41
42
43
44
45
46
47
48 fmt.Fprintf(b, "math.Float32frombits(0x%x)\n", math.Float32bits(t))
49 } else {
50
51
52
53
54
55
56
57
58
59 fmt.Fprintf(b, "%T(%v)\n", t, t)
60 }
61 case float64:
62 if math.IsNaN(t) && math.Float64bits(t) != math.Float64bits(math.NaN()) {
63 fmt.Fprintf(b, "math.Float64frombits(0x%x)\n", math.Float64bits(t))
64 } else {
65 fmt.Fprintf(b, "%T(%v)\n", t, t)
66 }
67 case string:
68 fmt.Fprintf(b, "string(%q)\n", t)
69 case rune:
70
71
72
73
74
75
76
77
78
79
80
81
82 if utf8.ValidRune(t) {
83 fmt.Fprintf(b, "rune(%q)\n", t)
84 } else {
85 fmt.Fprintf(b, "int32(%v)\n", t)
86 }
87 case byte:
88
89
90 fmt.Fprintf(b, "byte(%q)\n", t)
91 case []byte:
92 fmt.Fprintf(b, "[]byte(%q)\n", t)
93 default:
94 panic(fmt.Sprintf("unsupported type: %T", t))
95 }
96 }
97 return b.Bytes()
98 }
99
100
101 func unmarshalCorpusFile(b []byte) ([]any, error) {
102 if len(b) == 0 {
103 return nil, fmt.Errorf("cannot unmarshal empty string")
104 }
105 lines := bytes.Split(b, []byte("\n"))
106 if len(lines) < 2 {
107 return nil, fmt.Errorf("must include version and at least one value")
108 }
109 if string(lines[0]) != encVersion1 {
110 return nil, fmt.Errorf("unknown encoding version: %s", lines[0])
111 }
112 var vals []any
113 for _, line := range lines[1:] {
114 line = bytes.TrimSpace(line)
115 if len(line) == 0 {
116 continue
117 }
118 v, err := parseCorpusValue(line)
119 if err != nil {
120 return nil, fmt.Errorf("malformed line %q: %v", line, err)
121 }
122 vals = append(vals, v)
123 }
124 return vals, nil
125 }
126
127 func parseCorpusValue(line []byte) (any, error) {
128 fs := token.NewFileSet()
129 expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
130 if err != nil {
131 return nil, err
132 }
133 call, ok := expr.(*ast.CallExpr)
134 if !ok {
135 return nil, fmt.Errorf("expected call expression")
136 }
137 if len(call.Args) != 1 {
138 return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
139 }
140 arg := call.Args[0]
141
142 if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
143 if arrayType.Len != nil {
144 return nil, fmt.Errorf("expected []byte or primitive type")
145 }
146 elt, ok := arrayType.Elt.(*ast.Ident)
147 if !ok || elt.Name != "byte" {
148 return nil, fmt.Errorf("expected []byte")
149 }
150 lit, ok := arg.(*ast.BasicLit)
151 if !ok || lit.Kind != token.STRING {
152 return nil, fmt.Errorf("string literal required for type []byte")
153 }
154 s, err := strconv.Unquote(lit.Value)
155 if err != nil {
156 return nil, err
157 }
158 return []byte(s), nil
159 }
160
161 var idType *ast.Ident
162 if selector, ok := call.Fun.(*ast.SelectorExpr); ok {
163 xIdent, ok := selector.X.(*ast.Ident)
164 if !ok || xIdent.Name != "math" {
165 return nil, fmt.Errorf("invalid selector type")
166 }
167 switch selector.Sel.Name {
168 case "Float64frombits":
169 idType = &ast.Ident{Name: "float64-bits"}
170 case "Float32frombits":
171 idType = &ast.Ident{Name: "float32-bits"}
172 default:
173 return nil, fmt.Errorf("invalid selector type")
174 }
175 } else {
176 idType, ok = call.Fun.(*ast.Ident)
177 if !ok {
178 return nil, fmt.Errorf("expected []byte or primitive type")
179 }
180 if idType.Name == "bool" {
181 id, ok := arg.(*ast.Ident)
182 if !ok {
183 return nil, fmt.Errorf("malformed bool")
184 }
185 if id.Name == "true" {
186 return true, nil
187 } else if id.Name == "false" {
188 return false, nil
189 } else {
190 return nil, fmt.Errorf("true or false required for type bool")
191 }
192 }
193 }
194
195 var (
196 val string
197 kind token.Token
198 )
199 if op, ok := arg.(*ast.UnaryExpr); ok {
200 switch lit := op.X.(type) {
201 case *ast.BasicLit:
202 if op.Op != token.SUB {
203 return nil, fmt.Errorf("unsupported operation on int/float: %v", op.Op)
204 }
205
206 val = op.Op.String() + lit.Value
207 kind = lit.Kind
208 case *ast.Ident:
209 if lit.Name != "Inf" {
210 return nil, fmt.Errorf("expected operation on int or float type")
211 }
212 if op.Op == token.SUB {
213 val = "-Inf"
214 } else {
215 val = "+Inf"
216 }
217 kind = token.FLOAT
218 default:
219 return nil, fmt.Errorf("expected operation on int or float type")
220 }
221 } else {
222 switch lit := arg.(type) {
223 case *ast.BasicLit:
224 val, kind = lit.Value, lit.Kind
225 case *ast.Ident:
226 if lit.Name != "NaN" {
227 return nil, fmt.Errorf("literal value required for primitive type")
228 }
229 val, kind = "NaN", token.FLOAT
230 default:
231 return nil, fmt.Errorf("literal value required for primitive type")
232 }
233 }
234
235 switch typ := idType.Name; typ {
236 case "string":
237 if kind != token.STRING {
238 return nil, fmt.Errorf("string literal value required for type string")
239 }
240 return strconv.Unquote(val)
241 case "byte", "rune":
242 if kind == token.INT {
243 switch typ {
244 case "rune":
245 return parseInt(val, typ)
246 case "byte":
247 return parseUint(val, typ)
248 }
249 }
250 if kind != token.CHAR {
251 return nil, fmt.Errorf("character literal required for byte/rune types")
252 }
253 n := len(val)
254 if n < 2 {
255 return nil, fmt.Errorf("malformed character literal, missing single quotes")
256 }
257 code, _, _, err := strconv.UnquoteChar(val[1:n-1], '\'')
258 if err != nil {
259 return nil, err
260 }
261 if typ == "rune" {
262 return code, nil
263 }
264 if code >= 256 {
265 return nil, fmt.Errorf("can only encode single byte to a byte type")
266 }
267 return byte(code), nil
268 case "int", "int8", "int16", "int32", "int64":
269 if kind != token.INT {
270 return nil, fmt.Errorf("integer literal required for int types")
271 }
272 return parseInt(val, typ)
273 case "uint", "uint8", "uint16", "uint32", "uint64":
274 if kind != token.INT {
275 return nil, fmt.Errorf("integer literal required for uint types")
276 }
277 return parseUint(val, typ)
278 case "float32":
279 if kind != token.FLOAT && kind != token.INT {
280 return nil, fmt.Errorf("float or integer literal required for float32 type")
281 }
282 v, err := strconv.ParseFloat(val, 32)
283 return float32(v), err
284 case "float64":
285 if kind != token.FLOAT && kind != token.INT {
286 return nil, fmt.Errorf("float or integer literal required for float64 type")
287 }
288 return strconv.ParseFloat(val, 64)
289 case "float32-bits":
290 if kind != token.INT {
291 return nil, fmt.Errorf("integer literal required for math.Float32frombits type")
292 }
293 bits, err := parseUint(val, "uint32")
294 if err != nil {
295 return nil, err
296 }
297 return math.Float32frombits(bits.(uint32)), nil
298 case "float64-bits":
299 if kind != token.FLOAT && kind != token.INT {
300 return nil, fmt.Errorf("integer literal required for math.Float64frombits type")
301 }
302 bits, err := parseUint(val, "uint64")
303 if err != nil {
304 return nil, err
305 }
306 return math.Float64frombits(bits.(uint64)), nil
307 default:
308 return nil, fmt.Errorf("expected []byte or primitive type")
309 }
310 }
311
312
// parseInt converts the integer literal val into a value of the signed integer
// type named by typ, which must be one of "int", "int8", "int16", "int32",
// "rune" or "int64"; any other name panics.
//
// The literal is parsed with base 0, so the usual Go prefixes are honored.
// When parsing fails, the error from strconv.ParseInt is returned together
// with the converted value that ParseInt reported.
func parseInt(val, typ string) (any, error) {
	var bitSize int
	switch typ {
	case "int", "int64":
		// "int" is parsed at the full 64-bit width and narrowed by the int()
		// conversion below, so a literal too wide for this platform's int is
		// truncated rather than reported as out of range.
		bitSize = 64
	case "int8":
		bitSize = 8
	case "int16":
		bitSize = 16
	case "int32", "rune":
		bitSize = 32
	default:
		panic("unreachable")
	}

	n, err := strconv.ParseInt(val, 0, bitSize)
	switch typ {
	case "int":
		return int(n), err
	case "int8":
		return int8(n), err
	case "int16":
		return int16(n), err
	case "int32", "rune":
		return int32(n), err
	default: // "int64"
		return n, err
	}
}
338
339
// parseUint converts the integer literal val into a value of the unsigned
// integer type named by typ, which must be one of "uint", "uint8", "byte",
// "uint16", "uint32" or "uint64"; any other name panics.
//
// The literal is parsed with base 0, so the usual Go prefixes are honored.
// When parsing fails, the error from strconv.ParseUint is returned together
// with the converted value that ParseUint reported.
func parseUint(val, typ string) (any, error) {
	var bitSize int
	switch typ {
	case "uint", "uint64":
		// "uint" is parsed at the full 64-bit width and narrowed by the uint()
		// conversion below.
		bitSize = 64
	case "uint8", "byte":
		bitSize = 8
	case "uint16":
		bitSize = 16
	case "uint32":
		bitSize = 32
	default:
		panic("unreachable")
	}

	n, err := strconv.ParseUint(val, 0, bitSize)
	switch typ {
	case "uint":
		return uint(n), err
	case "uint8", "byte":
		return uint8(n), err
	case "uint16":
		return uint16(n), err
	case "uint32":
		return uint32(n), err
	default: // "uint64"
		return n, err
	}
}
360
View as plain text