Source file
src/cmd/gofmt/rewrite.go
1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/parser"
11 "go/token"
12 "os"
13 "reflect"
14 "strings"
15 "unicode"
16 "unicode/utf8"
17 )
18
19 func initRewrite() {
20 if *rewriteRule == "" {
21 rewrite = nil
22 return
23 }
24 f := strings.Split(*rewriteRule, "->")
25 if len(f) != 2 {
26 fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
27 os.Exit(2)
28 }
29 pattern := parseExpr(f[0], "pattern")
30 replace := parseExpr(f[1], "replacement")
31 rewrite = func(fset *token.FileSet, p *ast.File) *ast.File {
32 return rewriteFile(fset, pattern, replace, p)
33 }
34 }
35
36
37
38
39
// parseExpr parses s as a Go expression and returns it.
//
// what names the role of s (for example "pattern" or "replacement") and is
// used only in the diagnostic. If s does not parse, a message is written to
// stderr and the process exits with status 2; parseExpr does not return in
// that case.
func parseExpr(s, what string) ast.Expr {
	expr, err := parser.ParseExpr(s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
	os.Exit(2)
	return nil // unreachable: os.Exit does not return
}
48
49
50
57
58
59 func rewriteFile(fileSet *token.FileSet, pattern, replace ast.Expr, p *ast.File) *ast.File {
60 cmap := ast.NewCommentMap(fileSet, p, p.Comments)
61 m := make(map[string]reflect.Value)
62 pat := reflect.ValueOf(pattern)
63 repl := reflect.ValueOf(replace)
64
65 var rewriteVal func(val reflect.Value) reflect.Value
66 rewriteVal = func(val reflect.Value) reflect.Value {
67
68 if !val.IsValid() {
69 return reflect.Value{}
70 }
71 val = apply(rewriteVal, val)
72 for k := range m {
73 delete(m, k)
74 }
75 if match(m, pat, val) {
76 val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
77 }
78 return val
79 }
80
81 r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
82 r.Comments = cmap.Filter(r).Comments()
83 return r
84 }
85
86
// set performs x.Set(y) while shielding the caller from the panic reflect
// raises when y cannot be stored in x.
//
// It does nothing when x is not settable or y is the zero Value. A panic
// whose string message mentions "type mismatch" or "not assignable" is
// swallowed and leaves x unchanged. Any other panic is re-raised.
func set(x, y reflect.Value) {
	if !y.IsValid() || !x.CanSet() {
		return
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		if isString && (strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable")) {
			return // x cannot hold y: skip this assignment
		}
		panic(r)
	}()
	x.Set(y)
}
104
105
// Reflection types of the AST values that apply, match and subst treat
// specially.
var (
	identType     = reflect.TypeOf((*ast.Ident)(nil))
	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
	positionType  = reflect.TypeOf(token.NoPos)
)

// Typed nil pointers that apply returns in place of *ast.Object and
// *ast.Scope values.
var (
	objectPtrNil = reflect.Zero(objectPtrType)
	scopePtrNil  = reflect.Zero(scopePtrType)
)
116
117
118
119 func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
120 if !val.IsValid() {
121 return reflect.Value{}
122 }
123
124
125
126 if val.Type() == objectPtrType {
127 return objectPtrNil
128 }
129
130
131
132 if val.Type() == scopePtrType {
133 return scopePtrNil
134 }
135
136 switch v := reflect.Indirect(val); v.Kind() {
137 case reflect.Slice:
138 for i := 0; i < v.Len(); i++ {
139 e := v.Index(i)
140 set(e, f(e))
141 }
142 case reflect.Struct:
143 for i := 0; i < v.NumField(); i++ {
144 e := v.Field(i)
145 set(e, f(e))
146 }
147 case reflect.Interface:
148 e := v.Elem()
149 set(v, f(e))
150 }
151 return val
152 }
153
// isWildcard reports whether the identifier name s is a rewrite wildcard:
// exactly one rune long, and that rune lower case. The empty string and
// invalid UTF-8 are never wildcards.
func isWildcard(s string) bool {
	first, width := utf8.DecodeRuneInString(s)
	if width != len(s) {
		return false // more than one rune
	}
	return unicode.IsLower(first)
}
158
159
160
161
162 func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
163
164
165
166 if m != nil && pattern.IsValid() && pattern.Type() == identType {
167 name := pattern.Interface().(*ast.Ident).Name
168 if isWildcard(name) && val.IsValid() {
169
170 if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
171 if old, ok := m[name]; ok {
172 return match(nil, old, val)
173 }
174 m[name] = val
175 return true
176 }
177 }
178 }
179
180
181 if !pattern.IsValid() || !val.IsValid() {
182 return !pattern.IsValid() && !val.IsValid()
183 }
184 if pattern.Type() != val.Type() {
185 return false
186 }
187
188
189 switch pattern.Type() {
190 case identType:
191
192
193
194
195 p := pattern.Interface().(*ast.Ident)
196 v := val.Interface().(*ast.Ident)
197 return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
198 case objectPtrType, positionType:
199
200 return true
201 case callExprType:
202
203
204
205 p := pattern.Interface().(*ast.CallExpr)
206 v := val.Interface().(*ast.CallExpr)
207 if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
208 return false
209 }
210 }
211
212 p := reflect.Indirect(pattern)
213 v := reflect.Indirect(val)
214 if !p.IsValid() || !v.IsValid() {
215 return !p.IsValid() && !v.IsValid()
216 }
217
218 switch p.Kind() {
219 case reflect.Slice:
220 if p.Len() != v.Len() {
221 return false
222 }
223 for i := 0; i < p.Len(); i++ {
224 if !match(m, p.Index(i), v.Index(i)) {
225 return false
226 }
227 }
228 return true
229
230 case reflect.Struct:
231 for i := 0; i < p.NumField(); i++ {
232 if !match(m, p.Field(i), v.Field(i)) {
233 return false
234 }
235 }
236 return true
237
238 case reflect.Interface:
239 return match(m, p.Elem(), v.Elem())
240 }
241
242
243 return p.Interface() == v.Interface()
244 }
245
246
247
248
249
250 func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
251 if !pattern.IsValid() {
252 return reflect.Value{}
253 }
254
255
256 if m != nil && pattern.Type() == identType {
257 name := pattern.Interface().(*ast.Ident).Name
258 if isWildcard(name) {
259 if old, ok := m[name]; ok {
260 return subst(nil, old, reflect.Value{})
261 }
262 }
263 }
264
265 if pos.IsValid() && pattern.Type() == positionType {
266
267 if old := pattern.Interface().(token.Pos); !old.IsValid() {
268 return pattern
269 }
270 return pos
271 }
272
273
274 switch p := pattern; p.Kind() {
275 case reflect.Slice:
276 if p.IsNil() {
277
278
279
280 return reflect.Zero(p.Type())
281 }
282 v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
283 for i := 0; i < p.Len(); i++ {
284 v.Index(i).Set(subst(m, p.Index(i), pos))
285 }
286 return v
287
288 case reflect.Struct:
289 v := reflect.New(p.Type()).Elem()
290 for i := 0; i < p.NumField(); i++ {
291 v.Field(i).Set(subst(m, p.Field(i), pos))
292 }
293 return v
294
295 case reflect.Pointer:
296 v := reflect.New(p.Type()).Elem()
297 if elem := p.Elem(); elem.IsValid() {
298 v.Set(subst(m, elem, pos).Addr())
299 }
300 return v
301
302 case reflect.Interface:
303 v := reflect.New(p.Type()).Elem()
304 if elem := p.Elem(); elem.IsValid() {
305 v.Set(subst(m, elem, pos))
306 }
307 return v
308 }
309
310 return pattern
311 }
312
View as plain text