Source file
src/cmd/fix/main.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "io"
17 "io/fs"
18 "os"
19 "path/filepath"
20 "sort"
21 "strconv"
22 "strings"
23
24 "cmd/internal/diff"
25 )
26
27 var (
28 fset = token.NewFileSet()
29 exitCode = 0
30 )
31
32 var allowedRewrites = flag.String("r", "",
33 "restrict the rewrites to this comma-separated list")
34
35 var forceRewrites = flag.String("force", "",
36 "force these fixes to run even if the code looks updated")
37
38 var allowed, force map[string]bool
39
40 var (
41 doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
42 goVersionStr = flag.String("go", "", "go language version for files")
43
44 goVersion int
45 )
46
47
48 const debug = false
49
// usage prints the command synopsis, the flag defaults (via
// flag.PrintDefaults), and every rewrite in fixes (sorted with byName,
// each followed by its tab-indented description) to standard error, then
// exits with status 2.
func usage() {
	w := os.Stderr
	fmt.Fprintf(w, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
	flag.PrintDefaults()
	fmt.Fprintf(w, "\nAvailable rewrites are:\n")
	sort.Sort(byName(fixes))
	for _, fx := range fixes {
		switch {
		case fx.disabled:
			fmt.Fprintf(w, "\n%s (disabled)\n", fx.name)
		default:
			fmt.Fprintf(w, "\n%s\n", fx.name)
		}
		// Indent every line of a multi-line description by one tab.
		indented := strings.ReplaceAll(strings.TrimSpace(fx.desc), "\n", "\n\t")
		fmt.Fprintf(w, "\t%s\n", indented)
	}
	os.Exit(2)
}
67
// main parses the flags and then fixes standard input when no paths are
// given. Otherwise it fixes each named file, and every Go file under each
// named directory. It exits with exitCode, which report sets to 2 whenever
// an error is reported.
func main() {
	flag.Usage = usage
	flag.Parse()

	// -go=goMAJOR[.MINOR] is stored in goVersion as MAJOR*100+MINOR.
	// Each component must be an integer in [0, 100).
	if *goVersionStr != "" {
		version, ok := 0, strings.HasPrefix(*goVersionStr, "go")
		if ok {
			majorStr, minorStr, hasMinor := strings.Cut((*goVersionStr)[len("go"):], ".")
			if !hasMinor {
				minorStr = "0"
			}
			major, errMajor := strconv.Atoi(majorStr)
			minor, errMinor := strconv.Atoi(minorStr)
			ok = errMajor == nil && errMinor == nil &&
				0 <= major && major < 100 &&
				0 <= minor && minor < 100
			version = major*100 + minor
		}
		if !ok {
			report(fmt.Errorf("invalid -go=%s", *goVersionStr))
			os.Exit(exitCode)
		}
		goVersion = version
	}

	sort.Sort(byDate(fixes))

	// toSet turns a comma-separated flag value into a membership set.
	// An empty value yields a nil map.
	toSet := func(list string) map[string]bool {
		if list == "" {
			return nil
		}
		set := make(map[string]bool)
		for _, name := range strings.Split(list, ",") {
			set[name] = true
		}
		return set
	}
	allowed = toSet(*allowedRewrites)
	force = toSet(*forceRewrites)

	if flag.NArg() == 0 {
		if err := processFile("standard input", true); err != nil {
			report(err)
		}
		os.Exit(exitCode)
	}

	for _, path := range flag.Args() {
		info, err := os.Stat(path)
		switch {
		case err != nil:
			report(err)
		case info.IsDir():
			walkDir(path)
		default:
			if err := processFile(path, false); err != nil {
				report(err)
			}
		}
	}

	os.Exit(exitCode)
}
131
// parserMode is the go/parser mode used for every parse in processFile.
// Comments are kept in the AST so they are printed back out with the code.
const parserMode parser.Mode = parser.ParseComments
133
134 func gofmtFile(f *ast.File) ([]byte, error) {
135 var buf bytes.Buffer
136 if err := format.Node(&buf, fset, f); err != nil {
137 return nil, err
138 }
139 return buf.Bytes(), nil
140 }
141
// processFile applies the enabled fixes to one Go source file.
//
// The source is read from filename, or from standard input when useStdin
// is set. In that case filename is used only for positions and messages.
// The source is first normalized with gofmt. Then each fix in fixes is run
// in order, skipping fixes not listed in -r (when -r is given) and disabled
// fixes not listed in -force. If no fix reports a change, nothing is
// emitted, even when gofmt alone would have changed the source.
//
// Otherwise the result goes to one of three places:
//   - printed as a diff (-diff);
//   - written to standard output (stdin mode);
//   - written back over the original file.
//
// Any read, parse, format, diff, or write error is returned.
func processFile(filename string, useStdin bool) error {
	var f *os.File
	var err error
	var fixlog bytes.Buffer

	if useStdin {
		f = os.Stdin
	} else {
		f, err = os.Open(filename)
		if err != nil {
			return err
		}
		defer f.Close()
	}

	src, err := io.ReadAll(f)
	if err != nil {
		return err
	}

	file, err := parser.ParseFile(fset, filename, src, parserMode)
	if err != nil {
		return err
	}

	// Put the file in canonical format before fixing it. This "fmt"
	// pseudo-fix always runs and is recorded in fixlog. On its own, it does
	// not set fixed, so a file that only needs gofmt is left alone.
	newSrc, err := gofmtFile(file)
	if err != nil {
		return err
	}
	if !bytes.Equal(newSrc, src) {
		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
		if err != nil {
			return err
		}
		file = newFile
		fmt.Fprintf(&fixlog, " fmt")
	}

	// Apply every enabled fix in turn.
	newFile := file
	fixed := false
	for _, fix := range fixes {
		if allowed != nil && !allowed[fix.name] {
			continue
		}
		if fix.disabled && !force[fix.name] {
			continue
		}
		if fix.f(newFile) {
			fixed = true
			fmt.Fprintf(&fixlog, " %s", fix.name)

			// The fix modified the AST. Print it and parse the result, so
			// later fixes work on a freshly parsed AST with up-to-date
			// positions.
			newSrc, err := gofmtFile(newFile)
			if err != nil {
				return err
			}
			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
			if err != nil {
				if debug {
					fmt.Printf("%s", newSrc)
					report(err)
					os.Exit(exitCode)
				}
				return err
			}
		}
	}
	if !fixed {
		return nil
	}
	// Every fixlog entry starts with a space; drop the leading one.
	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])

	// Print the final AST once more. Inside the loop, each print was of an
	// AST as a fixer left it. Printing the re-parsed AST again presumably
	// makes the output match what gofmt would produce for this source.
	newSrc, err = gofmtFile(newFile)
	if err != nil {
		return err
	}

	if *doDiff {
		data, err := diff.Diff("go-fix", src, newSrc)
		if err != nil {
			return fmt.Errorf("computing diff: %w", err)
		}
		fmt.Printf("diff %s fixed/%s\n", filename, filename)
		_, err = os.Stdout.Write(data)
		return err
	}

	if useStdin {
		_, err = os.Stdout.Write(newSrc)
		return err
	}

	// The file was just opened, so it normally exists and os.WriteFile
	// truncates it without changing its permissions. Perm 0 would only
	// apply if the file had to be created.
	return os.WriteFile(f.Name(), newSrc, 0)
}
247
// gofmt returns n printed with format.Node using the shared fset. If
// printing fails, it returns the error text in angle brackets instead.
func gofmt(n any) string {
	var out bytes.Buffer
	err := format.Node(&out, fset, n)
	if err == nil {
		return out.String()
	}
	return "<" + err.Error() + ">"
}
255
// report marks the run as failed (exitCode = 2) and writes err to
// standard error with scanner.PrintError.
func report(err error) {
	exitCode = 2
	scanner.PrintError(os.Stderr, err)
}
260
// walkDir calls visitFile for path and every entry beneath it, using
// filepath.WalkDir.
func walkDir(path string) {
	// visitFile reports each error itself and always returns nil, so the
	// walk's own result carries no information worth handling here.
	_ = filepath.WalkDir(path, visitFile)
}
264
// visitFile is the filepath.WalkDir callback used by walkDir. It runs
// processFile on each entry accepted by isGoFile. Any error, whether from
// the walk or from processing, is passed to report. It always returns nil,
// so the walk never stops early.
func visitFile(path string, f fs.DirEntry, err error) error {
	if err != nil {
		report(err)
		return nil
	}
	if !isGoFile(f) {
		return nil
	}
	if err := processFile(path, false); err != nil {
		report(err)
	}
	return nil
}
274
// isGoFile reports whether f is a non-directory entry whose name ends in
// ".go" and does not start with "." (so hidden files are skipped).
func isGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	base := f.Name()
	return strings.HasSuffix(base, ".go") && !strings.HasPrefix(base, ".")
}
280
View as plain text