Source file src/cmd/fix/main.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"io"
    17  	"io/fs"
    18  	"os"
    19  	"path/filepath"
    20  	"sort"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"cmd/internal/diff"
    25  )
    26  
    27  var (
    28  	fset     = token.NewFileSet()
    29  	exitCode = 0
    30  )
    31  
    32  var allowedRewrites = flag.String("r", "",
    33  	"restrict the rewrites to this comma-separated list")
    34  
    35  var forceRewrites = flag.String("force", "",
    36  	"force these fixes to run even if the code looks updated")
    37  
    38  var allowed, force map[string]bool
    39  
    40  var (
    41  	doDiff       = flag.Bool("diff", false, "display diffs instead of rewriting files")
    42  	goVersionStr = flag.String("go", "", "go language version for files")
    43  
    44  	goVersion int // 115 for go1.15
    45  )
    46  
    47  // enable for debugging fix failures
    48  const debug = false // display incorrectly reformatted source and exit
    49  
    50  func usage() {
    51  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    52  	flag.PrintDefaults()
    53  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    54  	sort.Sort(byName(fixes))
    55  	for _, f := range fixes {
    56  		if f.disabled {
    57  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    58  		} else {
    59  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    60  		}
    61  		desc := strings.TrimSpace(f.desc)
    62  		desc = strings.ReplaceAll(desc, "\n", "\n\t")
    63  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    64  	}
    65  	os.Exit(2)
    66  }
    67  
    68  func main() {
    69  	flag.Usage = usage
    70  	flag.Parse()
    71  
    72  	if *goVersionStr != "" {
    73  		if !strings.HasPrefix(*goVersionStr, "go") {
    74  			report(fmt.Errorf("invalid -go=%s", *goVersionStr))
    75  			os.Exit(exitCode)
    76  		}
    77  		majorStr := (*goVersionStr)[len("go"):]
    78  		minorStr := "0"
    79  		if i := strings.Index(majorStr, "."); i >= 0 {
    80  			majorStr, minorStr = majorStr[:i], majorStr[i+len("."):]
    81  		}
    82  		major, err1 := strconv.Atoi(majorStr)
    83  		minor, err2 := strconv.Atoi(minorStr)
    84  		if err1 != nil || err2 != nil || major < 0 || major >= 100 || minor < 0 || minor >= 100 {
    85  			report(fmt.Errorf("invalid -go=%s", *goVersionStr))
    86  			os.Exit(exitCode)
    87  		}
    88  
    89  		goVersion = major*100 + minor
    90  	}
    91  
    92  	sort.Sort(byDate(fixes))
    93  
    94  	if *allowedRewrites != "" {
    95  		allowed = make(map[string]bool)
    96  		for _, f := range strings.Split(*allowedRewrites, ",") {
    97  			allowed[f] = true
    98  		}
    99  	}
   100  
   101  	if *forceRewrites != "" {
   102  		force = make(map[string]bool)
   103  		for _, f := range strings.Split(*forceRewrites, ",") {
   104  			force[f] = true
   105  		}
   106  	}
   107  
   108  	if flag.NArg() == 0 {
   109  		if err := processFile("standard input", true); err != nil {
   110  			report(err)
   111  		}
   112  		os.Exit(exitCode)
   113  	}
   114  
   115  	for i := 0; i < flag.NArg(); i++ {
   116  		path := flag.Arg(i)
   117  		switch dir, err := os.Stat(path); {
   118  		case err != nil:
   119  			report(err)
   120  		case dir.IsDir():
   121  			walkDir(path)
   122  		default:
   123  			if err := processFile(path, false); err != nil {
   124  				report(err)
   125  			}
   126  		}
   127  	}
   128  
   129  	os.Exit(exitCode)
   130  }
   131  
   132  const parserMode = parser.ParseComments
   133  
   134  func gofmtFile(f *ast.File) ([]byte, error) {
   135  	var buf bytes.Buffer
   136  	if err := format.Node(&buf, fset, f); err != nil {
   137  		return nil, err
   138  	}
   139  	return buf.Bytes(), nil
   140  }
   141  
   142  func processFile(filename string, useStdin bool) error {
   143  	var f *os.File
   144  	var err error
   145  	var fixlog bytes.Buffer
   146  
   147  	if useStdin {
   148  		f = os.Stdin
   149  	} else {
   150  		f, err = os.Open(filename)
   151  		if err != nil {
   152  			return err
   153  		}
   154  		defer f.Close()
   155  	}
   156  
   157  	src, err := io.ReadAll(f)
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	// Make sure file is in canonical format.
   168  	// This "fmt" pseudo-fix cannot be disabled.
   169  	newSrc, err := gofmtFile(file)
   170  	if err != nil {
   171  		return err
   172  	}
   173  	if !bytes.Equal(newSrc, src) {
   174  		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
   175  		if err != nil {
   176  			return err
   177  		}
   178  		file = newFile
   179  		fmt.Fprintf(&fixlog, " fmt")
   180  	}
   181  
   182  	// Apply all fixes to file.
   183  	newFile := file
   184  	fixed := false
   185  	for _, fix := range fixes {
   186  		if allowed != nil && !allowed[fix.name] {
   187  			continue
   188  		}
   189  		if fix.disabled && !force[fix.name] {
   190  			continue
   191  		}
   192  		if fix.f(newFile) {
   193  			fixed = true
   194  			fmt.Fprintf(&fixlog, " %s", fix.name)
   195  
   196  			// AST changed.
   197  			// Print and parse, to update any missing scoping
   198  			// or position information for subsequent fixers.
   199  			newSrc, err := gofmtFile(newFile)
   200  			if err != nil {
   201  				return err
   202  			}
   203  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   204  			if err != nil {
   205  				if debug {
   206  					fmt.Printf("%s", newSrc)
   207  					report(err)
   208  					os.Exit(exitCode)
   209  				}
   210  				return err
   211  			}
   212  		}
   213  	}
   214  	if !fixed {
   215  		return nil
   216  	}
   217  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   218  
   219  	// Print AST.  We did that after each fix, so this appears
   220  	// redundant, but it is necessary to generate gofmt-compatible
   221  	// source code in a few cases. The official gofmt style is the
   222  	// output of the printer run on a standard AST generated by the parser,
   223  	// but the source we generated inside the loop above is the
   224  	// output of the printer run on a mangled AST generated by a fixer.
   225  	newSrc, err = gofmtFile(newFile)
   226  	if err != nil {
   227  		return err
   228  	}
   229  
   230  	if *doDiff {
   231  		data, err := diff.Diff("go-fix", src, newSrc)
   232  		if err != nil {
   233  			return fmt.Errorf("computing diff: %s", err)
   234  		}
   235  		fmt.Printf("diff %s fixed/%s\n", filename, filename)
   236  		os.Stdout.Write(data)
   237  		return nil
   238  	}
   239  
   240  	if useStdin {
   241  		os.Stdout.Write(newSrc)
   242  		return nil
   243  	}
   244  
   245  	return os.WriteFile(f.Name(), newSrc, 0)
   246  }
   247  
   248  func gofmt(n any) string {
   249  	var gofmtBuf bytes.Buffer
   250  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   251  		return "<" + err.Error() + ">"
   252  	}
   253  	return gofmtBuf.String()
   254  }
   255  
   256  func report(err error) {
   257  	scanner.PrintError(os.Stderr, err)
   258  	exitCode = 2
   259  }
   260  
   261  func walkDir(path string) {
   262  	filepath.WalkDir(path, visitFile)
   263  }
   264  
   265  func visitFile(path string, f fs.DirEntry, err error) error {
   266  	if err == nil && isGoFile(f) {
   267  		err = processFile(path, false)
   268  	}
   269  	if err != nil {
   270  		report(err)
   271  	}
   272  	return nil
   273  }
   274  
   275  func isGoFile(f fs.DirEntry) bool {
   276  	// ignore non-Go files
   277  	name := f.Name()
   278  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   279  }
   280  

View as plain text