// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package main import ( "bytes" "flag" "fmt" "go/ast" "go/format" "go/parser" "go/scanner" "go/token" "io" "io/fs" "os" "path/filepath" "sort" "strconv" "strings" "cmd/internal/diff" ) var ( fset = token.NewFileSet() exitCode = 0 ) var allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list") var forceRewrites = flag.String("force", "", "force these fixes to run even if the code looks updated") var allowed, force map[string]bool var ( doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") goVersionStr = flag.String("go", "", "go language version for files") goVersion int // 115 for go1.15 ) // enable for debugging fix failures const debug = false // display incorrectly reformatted source and exit func usage() { fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") sort.Sort(byName(fixes)) for _, f := range fixes { if f.disabled { fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name) } else { fmt.Fprintf(os.Stderr, "\n%s\n", f.name) } desc := strings.TrimSpace(f.desc) desc = strings.ReplaceAll(desc, "\n", "\n\t") fmt.Fprintf(os.Stderr, "\t%s\n", desc) } os.Exit(2) } func main() { flag.Usage = usage flag.Parse() if *goVersionStr != "" { if !strings.HasPrefix(*goVersionStr, "go") { report(fmt.Errorf("invalid -go=%s", *goVersionStr)) os.Exit(exitCode) } majorStr := (*goVersionStr)[len("go"):] minorStr := "0" if i := strings.Index(majorStr, "."); i >= 0 { majorStr, minorStr = majorStr[:i], majorStr[i+len("."):] } major, err1 := strconv.Atoi(majorStr) minor, err2 := strconv.Atoi(minorStr) if err1 != nil || err2 != nil || major < 0 || major >= 100 || minor < 0 || minor >= 100 { report(fmt.Errorf("invalid -go=%s", *goVersionStr)) os.Exit(exitCode) } goVersion = major*100 + minor } sort.Sort(byDate(fixes)) if *allowedRewrites != "" { allowed = make(map[string]bool) for _, f := range strings.Split(*allowedRewrites, ",") { allowed[f] = true } } if *forceRewrites != "" { force = make(map[string]bool) for _, f := range strings.Split(*forceRewrites, ",") { force[f] = true } } if flag.NArg() == 0 { if err := processFile("standard input", true); err != nil { report(err) } os.Exit(exitCode) } for i := 0; i < flag.NArg(); i++ { path := flag.Arg(i) switch dir, err := os.Stat(path); { case err != nil: report(err) case dir.IsDir(): walkDir(path) default: if err := processFile(path, false); err != nil { report(err) } } } os.Exit(exitCode) } const parserMode = parser.ParseComments func gofmtFile(f *ast.File) ([]byte, error) { var buf bytes.Buffer if err := format.Node(&buf, fset, f); err != nil { return nil, err } return buf.Bytes(), nil } func processFile(filename string, useStdin bool) error { var f *os.File var err error var fixlog bytes.Buffer if useStdin { f = os.Stdin } else { f, err = os.Open(filename) if err != nil { return err } defer f.Close() } src, err := io.ReadAll(f) if err != nil { return err } file, err := parser.ParseFile(fset, filename, src, parserMode) if err != nil { return err } // Make sure file is in canonical format. // This "fmt" pseudo-fix cannot be disabled. newSrc, err := gofmtFile(file) if err != nil { return err } if !bytes.Equal(newSrc, src) { newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode) if err != nil { return err } file = newFile fmt.Fprintf(&fixlog, " fmt") } // Apply all fixes to file. newFile := file fixed := false for _, fix := range fixes { if allowed != nil && !allowed[fix.name] { continue } if fix.disabled && !force[fix.name] { continue } if fix.f(newFile) { fixed = true fmt.Fprintf(&fixlog, " %s", fix.name) // AST changed. // Print and parse, to update any missing scoping // or position information for subsequent fixers. newSrc, err := gofmtFile(newFile) if err != nil { return err } newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) if err != nil { if debug { fmt.Printf("%s", newSrc) report(err) os.Exit(exitCode) } return err } } } if !fixed { return nil } fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) // Print AST. We did that after each fix, so this appears // redundant, but it is necessary to generate gofmt-compatible // source code in a few cases. The official gofmt style is the // output of the printer run on a standard AST generated by the parser, // but the source we generated inside the loop above is the // output of the printer run on a mangled AST generated by a fixer. newSrc, err = gofmtFile(newFile) if err != nil { return err } if *doDiff { data, err := diff.Diff("go-fix", src, newSrc) if err != nil { return fmt.Errorf("computing diff: %s", err) } fmt.Printf("diff %s fixed/%s\n", filename, filename) os.Stdout.Write(data) return nil } if useStdin { os.Stdout.Write(newSrc) return nil } return os.WriteFile(f.Name(), newSrc, 0) } func gofmt(n any) string { var gofmtBuf bytes.Buffer if err := format.Node(&gofmtBuf, fset, n); err != nil { return "<" + err.Error() + ">" } return gofmtBuf.String() } func report(err error) { scanner.PrintError(os.Stderr, err) exitCode = 2 } func walkDir(path string) { filepath.WalkDir(path, visitFile) } func visitFile(path string, f fs.DirEntry, err error) error { if err == nil && isGoFile(f) { err = processFile(path, false) } if err != nil { report(err) } return nil } func isGoFile(f fs.DirEntry) bool { // ignore non-Go files name := f.Name() return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") }