1
2
3
4
5
6
7 package loopclosure
8
9 import (
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/inspect"
15 "golang.org/x/tools/go/ast/inspector"
16 "golang.org/x/tools/go/types/typeutil"
17 )
18
// Doc is the documentation shown for the loopclosure analyzer. The text
// before the first blank line serves as its one-line title (per the
// analysis.Analyzer documentation); the rest describes what is checked.
const Doc = `check references to loop variables from within nested functions

This analyzer checks for references to loop variables from within a
function literal inside the loop body. It checks only instances where
the function literal is called by a go or defer statement, or is passed
to the Go method of golang.org/x/sync/errgroup.Group, and that statement
is the last statement in the loop body, as otherwise we would need whole
program analysis.

For example:

	for i, v := range s {
		go func() {
			println(i, v) // not what you might expect
		}()
	}

See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
36
37 var Analyzer = &analysis.Analyzer{
38 Name: "loopclosure",
39 Doc: Doc,
40 Requires: []*analysis.Analyzer{inspect.Analyzer},
41 Run: run,
42 }
43
44 func run(pass *analysis.Pass) (interface{}, error) {
45 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
46
47 nodeFilter := []ast.Node{
48 (*ast.RangeStmt)(nil),
49 (*ast.ForStmt)(nil),
50 }
51 inspect.Preorder(nodeFilter, func(n ast.Node) {
52
53 var vars []*ast.Ident
54 addVar := func(expr ast.Expr) {
55 if id, ok := expr.(*ast.Ident); ok {
56 vars = append(vars, id)
57 }
58 }
59 var body *ast.BlockStmt
60 switch n := n.(type) {
61 case *ast.RangeStmt:
62 body = n.Body
63 addVar(n.Key)
64 addVar(n.Value)
65 case *ast.ForStmt:
66 body = n.Body
67 switch post := n.Post.(type) {
68 case *ast.AssignStmt:
69
70 for _, lhs := range post.Lhs {
71 addVar(lhs)
72 }
73 case *ast.IncDecStmt:
74
75 addVar(post.X)
76 }
77 }
78 if vars == nil {
79 return
80 }
81
82
83
84
85
86
87 if len(body.List) == 0 {
88 return
89 }
90
91 var fun ast.Expr
92 switch s := body.List[len(body.List)-1].(type) {
93 case *ast.GoStmt:
94 fun = s.Call.Fun
95 case *ast.DeferStmt:
96 fun = s.Call.Fun
97 case *ast.ExprStmt:
98 if call, ok := s.X.(*ast.CallExpr); ok {
99 fun = goInvokes(pass.TypesInfo, call)
100 }
101 }
102 lit, ok := fun.(*ast.FuncLit)
103 if !ok {
104 return
105 }
106 ast.Inspect(lit.Body, func(n ast.Node) bool {
107 id, ok := n.(*ast.Ident)
108 if !ok || id.Obj == nil {
109 return true
110 }
111 if pass.TypesInfo.Types[id].Type == nil {
112
113 return true
114 }
115 for _, v := range vars {
116 if v.Obj == id.Obj {
117 pass.ReportRangef(id, "loop variable %s captured by func literal",
118 id.Name)
119 }
120 }
121 return true
122 })
123 })
124 return nil, nil
125 }
126
127
128
129
130
131
132
133
134
135
136 func goInvokes(info *types.Info, call *ast.CallExpr) ast.Expr {
137 f := typeutil.StaticCallee(info, call)
138
139 if f == nil || f.Name() != "Go" {
140 return nil
141 }
142 recv := f.Type().(*types.Signature).Recv()
143 if recv == nil {
144 return nil
145 }
146 rtype, ok := recv.Type().(*types.Pointer)
147 if !ok {
148 return nil
149 }
150 named, ok := rtype.Elem().(*types.Named)
151 if !ok {
152 return nil
153 }
154 if named.Obj().Name() != "Group" {
155 return nil
156 }
157 pkg := f.Pkg()
158 if pkg == nil {
159 return nil
160 }
161 if pkg.Path() != "golang.org/x/sync/errgroup" {
162 return nil
163 }
164 return call.Args[0]
165 }
166
View as plain text