1
2
3
4
5 package astutil
6
7 import (
8 "fmt"
9 "go/ast"
10 "reflect"
11 "sort"
12
13 "golang.org/x/tools/internal/typeparams"
14 )
15
16
17
18
19
20
21
// An ApplyFunc is the callback type accepted by Apply for its pre and post
// arguments. It is invoked with a Cursor describing the node currently being
// visited, and its boolean result controls how the traversal proceeds: a pre
// function returning false skips the node's children and its post call, and
// a post function returning false stops the whole traversal.
type ApplyFunc func(*Cursor) bool
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45 func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
46 parent := &struct{ ast.Node }{root}
47 defer func() {
48 if r := recover(); r != nil && r != abort {
49 panic(r)
50 }
51 result = parent.Node
52 }()
53 a := &application{pre: pre, post: post}
54 a.apply(parent, "Node", nil, root)
55 return
56 }
57
// abort is the sentinel value that apply panics with when a post function
// returns false. Apply recovers it and treats it as a normal early exit
// instead of re-panicking. A freshly allocated pointer is used so that no
// other panic value can compare equal to it.
var abort = new(int)
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// A Cursor describes the node currently being visited by Apply, together
// with the place that node occupies inside its parent. It is what the pre and
// post functions receive, and it offers methods to replace, delete, or insert
// around the current node.
type Cursor struct {
	// parent is the node that contains the current node.
	parent ast.Node
	// name is the name of the parent's field that holds the current node;
	// for a file inside an *ast.Package it is the file's key in Files.
	name string
	// iter is non-nil only while the current node is an element of a slice
	// field of parent; it then carries the element's index.
	iter *iterator
	// node is the current node; it may be nil.
	node ast.Node
}
79
80
// Node returns the current node, which may be nil.
func (c *Cursor) Node() ast.Node { return c.node }
82
83
// Parent returns the node that contains the current node.
func (c *Cursor) Parent() ast.Node { return c.parent }
85
86
87
88
// Name returns the name of the parent's field that holds the current node.
// For a file that belongs to an *ast.Package, it is instead the key under
// which the file is stored in the package's Files map.
func (c *Cursor) Name() string { return c.name }
90
91
92
93
94
95 func (c *Cursor) Index() int {
96 if c.iter != nil {
97 return c.iter.index
98 }
99 return -1
100 }
101
102
103 func (c *Cursor) field() reflect.Value {
104 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
105 }
106
107
108
109 func (c *Cursor) Replace(n ast.Node) {
110 if _, ok := c.node.(*ast.File); ok {
111 file, ok := n.(*ast.File)
112 if !ok {
113 panic("attempt to replace *ast.File with non-*ast.File")
114 }
115 c.parent.(*ast.Package).Files[c.name] = file
116 return
117 }
118
119 v := c.field()
120 if i := c.Index(); i >= 0 {
121 v = v.Index(i)
122 }
123 v.Set(reflect.ValueOf(n))
124 }
125
126
127
128
129
130 func (c *Cursor) Delete() {
131 if _, ok := c.node.(*ast.File); ok {
132 delete(c.parent.(*ast.Package).Files, c.name)
133 return
134 }
135
136 i := c.Index()
137 if i < 0 {
138 panic("Delete node not contained in slice")
139 }
140 v := c.field()
141 l := v.Len()
142 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
143 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
144 v.SetLen(l - 1)
145 c.iter.step--
146 }
147
148
149
150
151 func (c *Cursor) InsertAfter(n ast.Node) {
152 i := c.Index()
153 if i < 0 {
154 panic("InsertAfter node not contained in slice")
155 }
156 v := c.field()
157 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
158 l := v.Len()
159 reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
160 v.Index(i + 1).Set(reflect.ValueOf(n))
161 c.iter.step++
162 }
163
164
165
166
167 func (c *Cursor) InsertBefore(n ast.Node) {
168 i := c.Index()
169 if i < 0 {
170 panic("InsertBefore node not contained in slice")
171 }
172 v := c.field()
173 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
174 l := v.Len()
175 reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
176 v.Index(i).Set(reflect.ValueOf(n))
177 c.iter.index++
178 }
179
180
181 type application struct {
182 pre, post ApplyFunc
183 cursor Cursor
184 iter iterator
185 }
186
187 func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
188
189 if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
190 n = nil
191 }
192
193
194 saved := a.cursor
195 a.cursor.parent = parent
196 a.cursor.name = name
197 a.cursor.iter = iter
198 a.cursor.node = n
199
200 if a.pre != nil && !a.pre(&a.cursor) {
201 a.cursor = saved
202 return
203 }
204
205
206
207 switch n := n.(type) {
208 case nil:
209
210
211
212 case *ast.Comment:
213
214
215 case *ast.CommentGroup:
216 if n != nil {
217 a.applyList(n, "List")
218 }
219
220 case *ast.Field:
221 a.apply(n, "Doc", nil, n.Doc)
222 a.applyList(n, "Names")
223 a.apply(n, "Type", nil, n.Type)
224 a.apply(n, "Tag", nil, n.Tag)
225 a.apply(n, "Comment", nil, n.Comment)
226
227 case *ast.FieldList:
228 a.applyList(n, "List")
229
230
231 case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
232
233
234 case *ast.Ellipsis:
235 a.apply(n, "Elt", nil, n.Elt)
236
237 case *ast.FuncLit:
238 a.apply(n, "Type", nil, n.Type)
239 a.apply(n, "Body", nil, n.Body)
240
241 case *ast.CompositeLit:
242 a.apply(n, "Type", nil, n.Type)
243 a.applyList(n, "Elts")
244
245 case *ast.ParenExpr:
246 a.apply(n, "X", nil, n.X)
247
248 case *ast.SelectorExpr:
249 a.apply(n, "X", nil, n.X)
250 a.apply(n, "Sel", nil, n.Sel)
251
252 case *ast.IndexExpr:
253 a.apply(n, "X", nil, n.X)
254 a.apply(n, "Index", nil, n.Index)
255
256 case *typeparams.IndexListExpr:
257 a.apply(n, "X", nil, n.X)
258 a.applyList(n, "Indices")
259
260 case *ast.SliceExpr:
261 a.apply(n, "X", nil, n.X)
262 a.apply(n, "Low", nil, n.Low)
263 a.apply(n, "High", nil, n.High)
264 a.apply(n, "Max", nil, n.Max)
265
266 case *ast.TypeAssertExpr:
267 a.apply(n, "X", nil, n.X)
268 a.apply(n, "Type", nil, n.Type)
269
270 case *ast.CallExpr:
271 a.apply(n, "Fun", nil, n.Fun)
272 a.applyList(n, "Args")
273
274 case *ast.StarExpr:
275 a.apply(n, "X", nil, n.X)
276
277 case *ast.UnaryExpr:
278 a.apply(n, "X", nil, n.X)
279
280 case *ast.BinaryExpr:
281 a.apply(n, "X", nil, n.X)
282 a.apply(n, "Y", nil, n.Y)
283
284 case *ast.KeyValueExpr:
285 a.apply(n, "Key", nil, n.Key)
286 a.apply(n, "Value", nil, n.Value)
287
288
289 case *ast.ArrayType:
290 a.apply(n, "Len", nil, n.Len)
291 a.apply(n, "Elt", nil, n.Elt)
292
293 case *ast.StructType:
294 a.apply(n, "Fields", nil, n.Fields)
295
296 case *ast.FuncType:
297 a.apply(n, "Params", nil, n.Params)
298 a.apply(n, "Results", nil, n.Results)
299
300 case *ast.InterfaceType:
301 a.apply(n, "Methods", nil, n.Methods)
302
303 case *ast.MapType:
304 a.apply(n, "Key", nil, n.Key)
305 a.apply(n, "Value", nil, n.Value)
306
307 case *ast.ChanType:
308 a.apply(n, "Value", nil, n.Value)
309
310
311 case *ast.BadStmt:
312
313
314 case *ast.DeclStmt:
315 a.apply(n, "Decl", nil, n.Decl)
316
317 case *ast.EmptyStmt:
318
319
320 case *ast.LabeledStmt:
321 a.apply(n, "Label", nil, n.Label)
322 a.apply(n, "Stmt", nil, n.Stmt)
323
324 case *ast.ExprStmt:
325 a.apply(n, "X", nil, n.X)
326
327 case *ast.SendStmt:
328 a.apply(n, "Chan", nil, n.Chan)
329 a.apply(n, "Value", nil, n.Value)
330
331 case *ast.IncDecStmt:
332 a.apply(n, "X", nil, n.X)
333
334 case *ast.AssignStmt:
335 a.applyList(n, "Lhs")
336 a.applyList(n, "Rhs")
337
338 case *ast.GoStmt:
339 a.apply(n, "Call", nil, n.Call)
340
341 case *ast.DeferStmt:
342 a.apply(n, "Call", nil, n.Call)
343
344 case *ast.ReturnStmt:
345 a.applyList(n, "Results")
346
347 case *ast.BranchStmt:
348 a.apply(n, "Label", nil, n.Label)
349
350 case *ast.BlockStmt:
351 a.applyList(n, "List")
352
353 case *ast.IfStmt:
354 a.apply(n, "Init", nil, n.Init)
355 a.apply(n, "Cond", nil, n.Cond)
356 a.apply(n, "Body", nil, n.Body)
357 a.apply(n, "Else", nil, n.Else)
358
359 case *ast.CaseClause:
360 a.applyList(n, "List")
361 a.applyList(n, "Body")
362
363 case *ast.SwitchStmt:
364 a.apply(n, "Init", nil, n.Init)
365 a.apply(n, "Tag", nil, n.Tag)
366 a.apply(n, "Body", nil, n.Body)
367
368 case *ast.TypeSwitchStmt:
369 a.apply(n, "Init", nil, n.Init)
370 a.apply(n, "Assign", nil, n.Assign)
371 a.apply(n, "Body", nil, n.Body)
372
373 case *ast.CommClause:
374 a.apply(n, "Comm", nil, n.Comm)
375 a.applyList(n, "Body")
376
377 case *ast.SelectStmt:
378 a.apply(n, "Body", nil, n.Body)
379
380 case *ast.ForStmt:
381 a.apply(n, "Init", nil, n.Init)
382 a.apply(n, "Cond", nil, n.Cond)
383 a.apply(n, "Post", nil, n.Post)
384 a.apply(n, "Body", nil, n.Body)
385
386 case *ast.RangeStmt:
387 a.apply(n, "Key", nil, n.Key)
388 a.apply(n, "Value", nil, n.Value)
389 a.apply(n, "X", nil, n.X)
390 a.apply(n, "Body", nil, n.Body)
391
392
393 case *ast.ImportSpec:
394 a.apply(n, "Doc", nil, n.Doc)
395 a.apply(n, "Name", nil, n.Name)
396 a.apply(n, "Path", nil, n.Path)
397 a.apply(n, "Comment", nil, n.Comment)
398
399 case *ast.ValueSpec:
400 a.apply(n, "Doc", nil, n.Doc)
401 a.applyList(n, "Names")
402 a.apply(n, "Type", nil, n.Type)
403 a.applyList(n, "Values")
404 a.apply(n, "Comment", nil, n.Comment)
405
406 case *ast.TypeSpec:
407 a.apply(n, "Doc", nil, n.Doc)
408 a.apply(n, "Name", nil, n.Name)
409 a.apply(n, "Type", nil, n.Type)
410 a.apply(n, "Comment", nil, n.Comment)
411
412 case *ast.BadDecl:
413
414
415 case *ast.GenDecl:
416 a.apply(n, "Doc", nil, n.Doc)
417 a.applyList(n, "Specs")
418
419 case *ast.FuncDecl:
420 a.apply(n, "Doc", nil, n.Doc)
421 a.apply(n, "Recv", nil, n.Recv)
422 a.apply(n, "Name", nil, n.Name)
423 a.apply(n, "Type", nil, n.Type)
424 a.apply(n, "Body", nil, n.Body)
425
426
427 case *ast.File:
428 a.apply(n, "Doc", nil, n.Doc)
429 a.apply(n, "Name", nil, n.Name)
430 a.applyList(n, "Decls")
431
432
433
434 case *ast.Package:
435
436 var names []string
437 for name := range n.Files {
438 names = append(names, name)
439 }
440 sort.Strings(names)
441 for _, name := range names {
442 a.apply(n, name, nil, n.Files[name])
443 }
444
445 default:
446 panic(fmt.Sprintf("Apply: unexpected node type %T", n))
447 }
448
449 if a.post != nil && !a.post(&a.cursor) {
450 panic(abort)
451 }
452
453 a.cursor = saved
454 }
455
456
// iterator tracks the progress of a walk over a slice of nodes.
type iterator struct {
	// index is the position of the element currently being visited.
	index int
	// step is the amount added to index to reach the next element to
	// visit; cursor operations that add or remove elements adjust it.
	step int
}
460
461 func (a *application) applyList(parent ast.Node, name string) {
462
463 saved := a.iter
464 a.iter.index = 0
465 for {
466
467 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
468 if a.iter.index >= v.Len() {
469 break
470 }
471
472
473 var x ast.Node
474 if e := v.Index(a.iter.index); e.IsValid() {
475 x = e.Interface().(ast.Node)
476 }
477
478 a.iter.step = 1
479 a.apply(parent, name, &a.iter, x)
480 a.iter.index += a.iter.step
481 }
482 a.iter = saved
483 }
484
View as plain text