Source file src/cmd/fix/main_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"go/ast"
     9  	"go/parser"
    10  	"strings"
    11  	"testing"
    12  
    13  	"cmd/internal/diff"
    14  )
    15  
// testCase describes one rewrite check run by TestRewrite.
type testCase struct {
	Name    string               // subtest name; also the file name given to the parser
	Fn      func(*ast.File) bool // fix to apply; nil means every fix in fixes
	Version int                  // if non-zero, goVersion is set to this and the case runs serially
	In      string               // input source; must already be gofmt-formatted
	Out     string               // expected output; empty means In is expected unchanged
}
    23  
// testCases is the table iterated by TestRewrite. Entries are registered
// through addTestCases.
var testCases []testCase
    25  
    26  func addTestCases(t []testCase, fn func(*ast.File) bool) {
    27  	// Fill in fn to avoid repetition in definitions.
    28  	if fn != nil {
    29  		for i := range t {
    30  			if t[i].Fn == nil {
    31  				t[i].Fn = fn
    32  			}
    33  		}
    34  	}
    35  	testCases = append(testCases, t...)
    36  }
    37  
    38  func fnop(*ast.File) bool { return false }
    39  
    40  func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
    41  	file, err := parser.ParseFile(fset, desc, in, parserMode)
    42  	if err != nil {
    43  		t.Errorf("parsing: %v", err)
    44  		return
    45  	}
    46  
    47  	outb, err := gofmtFile(file)
    48  	if err != nil {
    49  		t.Errorf("printing: %v", err)
    50  		return
    51  	}
    52  	if s := string(outb); in != s && mustBeGofmt {
    53  		t.Errorf("not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
    54  			desc, in, desc, s)
    55  		tdiff(t, in, s)
    56  		return
    57  	}
    58  
    59  	if fn == nil {
    60  		for _, fix := range fixes {
    61  			if fix.f(file) {
    62  				fixed = true
    63  			}
    64  		}
    65  	} else {
    66  		fixed = fn(file)
    67  	}
    68  
    69  	outb, err = gofmtFile(file)
    70  	if err != nil {
    71  		t.Errorf("printing: %v", err)
    72  		return
    73  	}
    74  
    75  	return string(outb), fixed, true
    76  }
    77  
    78  func TestRewrite(t *testing.T) {
    79  	for _, tt := range testCases {
    80  		tt := tt
    81  		t.Run(tt.Name, func(t *testing.T) {
    82  			if tt.Version == 0 {
    83  				t.Parallel()
    84  			} else {
    85  				old := goVersion
    86  				goVersion = tt.Version
    87  				defer func() {
    88  					goVersion = old
    89  				}()
    90  			}
    91  
    92  			// Apply fix: should get tt.Out.
    93  			out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
    94  			if !ok {
    95  				return
    96  			}
    97  
    98  			// reformat to get printing right
    99  			out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
   100  			if !ok {
   101  				return
   102  			}
   103  
   104  			if tt.Out == "" {
   105  				tt.Out = tt.In
   106  			}
   107  			if out != tt.Out {
   108  				t.Errorf("incorrect output.\n")
   109  				if !strings.HasPrefix(tt.Name, "testdata/") {
   110  					t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
   111  				}
   112  				tdiff(t, out, tt.Out)
   113  				return
   114  			}
   115  
   116  			if changed := out != tt.In; changed != fixed {
   117  				t.Errorf("changed=%v != fixed=%v", changed, fixed)
   118  				return
   119  			}
   120  
   121  			// Should not change if run again.
   122  			out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
   123  			if !ok {
   124  				return
   125  			}
   126  
   127  			if fixed2 {
   128  				t.Errorf("applied fixes during second round")
   129  				return
   130  			}
   131  
   132  			if out2 != out {
   133  				t.Errorf("changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
   134  					out, out2)
   135  				tdiff(t, out, out2)
   136  			}
   137  		})
   138  	}
   139  }
   140  
   141  func tdiff(t *testing.T, a, b string) {
   142  	data, err := diff.Diff("go-fix-test", []byte(a), []byte(b))
   143  	if err != nil {
   144  		t.Error(err)
   145  		return
   146  	}
   147  	t.Error(string(data))
   148  }
   149  

View as plain text