Source file
src/cmd/fix/typecheck.go
1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/parser"
11 "go/token"
12 exec "internal/execabs"
13 "os"
14 "path/filepath"
15 "reflect"
16 "runtime"
17 "strings"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// mkType returns the string form used to say "this denotes the type t",
// i.e. t prefixed with "type ".
func mkType(t string) string {
	const prefix = "type "
	return prefix + t
}
60
// getType returns T for a string of the form "type T" and "" for any other
// string.
func getType(t string) string {
	// TrimPrefix shortens t only when the prefix is present.
	if rest := strings.TrimPrefix(t, "type "); len(rest) != len(t) {
		return rest
	}
	return ""
}
67
// isType reports whether t has the "type T" form produced by mkType,
// meaning the node it describes denotes a type rather than a value.
func isType(t string) bool {
	const prefix = "type "
	return len(t) >= len(prefix) && t[:len(prefix)] == prefix
}
71
72
73
74
75
76
// TypeConfig describes the universe of types known to the checker.
// All types are referred to by their string names (for example "reflect.Value").
type TypeConfig struct {
	// Type maps a type name to its description.
	Type map[string]*Type
	// Var maps a variable name, of the form "x" or "p.X", to its type.
	Var map[string]string
	// Func maps a function name, of the form "f" or "p.F", to its result
	// type(s); typeof reports such a name as "func()" followed by that string.
	Func map[string]string

	// External maps a name to its type. It supplies typings not present in
	// the Go source itself; typecheck fills it with the signatures of cgo
	// functions (keyed "C.name") when the file imports "C".
	External map[string]string
}

// typeof returns the type of name, which may be of the form "x" or "p.X".
// A variable in Var takes precedence over a function in Func; a function is
// reported as "func()" followed by its configured result type. It returns ""
// if the name is unknown.
func (cfg *TypeConfig) typeof(name string) string {
	// Indexing a nil map yields "", so no nil checks are needed.
	if t := cfg.Var[name]; t != "" {
		return t
	}
	if t := cfg.Func[name]; t != "" {
		return "func()" + t
	}
	return ""
}

// Type describes the fields and methods of a named type. A field or method
// not found directly is looked up in the types listed in Embed.
type Type struct {
	Field  map[string]string // field name -> field type
	Method map[string]string // method name -> method type (presumably a "func(...)" signature so calls can be typed)
	Embed  []string          // names (keys of TypeConfig.Type) of embedded types, searched in order
	Def    string            // underlying definition such as "[]T" or "map[K]V"; used to expand the name
}

// dot returns the type of typ.name using the type information in cfg.
// Fields are consulted first, then methods, then each embedded type in
// order (recursively). It returns "" if name is not found.
//
// NOTE(review): assumes the Embed graph in cfg is acyclic; a cycle would
// recurse without bound.
func (typ *Type) dot(cfg *TypeConfig, name string) string {
	if t := typ.Field[name]; t != "" {
		return t
	}
	if t := typ.Method[name]; t != "" {
		return t
	}
	for _, e := range typ.Embed {
		if etyp := cfg.Type[e]; etyp != nil {
			if t := etyp.dot(cfg, name); t != "" {
				return t
			}
		}
	}
	return ""
}
139
140
141
142
143
144
145 func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[any]string, assign map[string][]any) {
146 typeof = make(map[any]string)
147 assign = make(map[string][]any)
148 cfg1 := &TypeConfig{}
149 *cfg1 = *cfg
150 copied := false
151
152
153 cfg.External = map[string]string{}
154 cfg1.External = cfg.External
155 if imports(f, "C") {
156
157
158
159 err := func() error {
160 txt, err := gofmtFile(f)
161 if err != nil {
162 return err
163 }
164 dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck")
165 if err != nil {
166 return err
167 }
168 defer os.RemoveAll(dir)
169 err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
170 if err != nil {
171 return err
172 }
173 cmd := exec.Command(filepath.Join(runtime.GOROOT(), "bin", "go"), "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
174 err = cmd.Run()
175 if err != nil {
176 return err
177 }
178 out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
179 if err != nil {
180 return err
181 }
182 cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
183 if err != nil {
184 return err
185 }
186 for _, decl := range cgo.Decls {
187 fn, ok := decl.(*ast.FuncDecl)
188 if !ok {
189 continue
190 }
191 if strings.HasPrefix(fn.Name.Name, "_Cfunc_") {
192 var params, results []string
193 for _, p := range fn.Type.Params.List {
194 t := gofmt(p.Type)
195 t = strings.ReplaceAll(t, "_Ctype_", "C.")
196 params = append(params, t)
197 }
198 for _, r := range fn.Type.Results.List {
199 t := gofmt(r.Type)
200 t = strings.ReplaceAll(t, "_Ctype_", "C.")
201 results = append(results, t)
202 }
203 cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results)
204 }
205 }
206 return nil
207 }()
208 if err != nil {
209 fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err)
210 }
211 }
212
213
214 for _, decl := range f.Decls {
215 fn, ok := decl.(*ast.FuncDecl)
216 if !ok {
217 continue
218 }
219 typecheck1(cfg, fn.Type, typeof, assign)
220 t := typeof[fn.Type]
221 if fn.Recv != nil {
222
223 rcvr := typeof[fn.Recv]
224 if !isType(rcvr) {
225 if len(fn.Recv.List) != 1 {
226 continue
227 }
228 rcvr = mkType(gofmt(fn.Recv.List[0].Type))
229 typeof[fn.Recv.List[0].Type] = rcvr
230 }
231 rcvr = getType(rcvr)
232 if rcvr != "" && rcvr[0] == '*' {
233 rcvr = rcvr[1:]
234 }
235 typeof[rcvr+"."+fn.Name.Name] = t
236 } else {
237 if isType(t) {
238 t = getType(t)
239 } else {
240 t = gofmt(fn.Type)
241 }
242 typeof[fn.Name] = t
243
244
245 typeof[fn.Name.Obj] = t
246 }
247 }
248
249
250 for _, decl := range f.Decls {
251 d, ok := decl.(*ast.GenDecl)
252 if ok {
253 for _, s := range d.Specs {
254 switch s := s.(type) {
255 case *ast.TypeSpec:
256 if cfg1.Type[s.Name.Name] != nil {
257 break
258 }
259 if !copied {
260 copied = true
261
262 cfg1.Type = make(map[string]*Type)
263 for k, v := range cfg.Type {
264 cfg1.Type[k] = v
265 }
266 }
267 t := &Type{Field: map[string]string{}}
268 cfg1.Type[s.Name.Name] = t
269 switch st := s.Type.(type) {
270 case *ast.StructType:
271 for _, f := range st.Fields.List {
272 for _, n := range f.Names {
273 t.Field[n.Name] = gofmt(f.Type)
274 }
275 }
276 case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
277 t.Def = gofmt(st)
278 }
279 }
280 }
281 }
282 }
283
284 typecheck1(cfg1, f, typeof, assign)
285 return typeof, assign
286 }
287
// makeExprList converts a list of identifiers into a list of expressions,
// preserving order. A nil or empty input yields a nil slice.
func makeExprList(a []*ast.Ident) []ast.Expr {
	var exprs []ast.Expr
	for i := range a {
		exprs = append(exprs, a[i])
	}
	return exprs
}
295
296
297
298
299 func typecheck1(cfg *TypeConfig, f any, typeof map[any]string, assign map[string][]any) {
300
301
302 set := func(n ast.Expr, typ string, isDecl bool) {
303 if typeof[n] != "" || typ == "" {
304 if typeof[n] != typ {
305 assign[typ] = append(assign[typ], n)
306 }
307 return
308 }
309 typeof[n] = typ
310
311
312
313
314
315
316
317 if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
318 typeof[id.Obj] = typ
319 }
320 }
321
322
323
324
325 typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
326 if len(lhs) > 1 && len(rhs) == 1 {
327 if _, ok := rhs[0].(*ast.CallExpr); ok {
328 t := split(typeof[rhs[0]])
329
330 for i := 0; i < len(lhs) && i < len(t); i++ {
331 set(lhs[i], t[i], isDecl)
332 }
333 return
334 }
335 }
336 if len(lhs) == 1 && len(rhs) == 2 {
337
338 rhs = rhs[:1]
339 } else if len(lhs) == 2 && len(rhs) == 1 {
340
341 lhs = lhs[:1]
342 }
343
344
345 for i := 0; i < len(lhs) && i < len(rhs); i++ {
346 x, y := lhs[i], rhs[i]
347 if typeof[y] != "" {
348 set(x, typeof[y], isDecl)
349 } else {
350 set(y, typeof[x], false)
351 }
352 }
353 }
354
355 expand := func(s string) string {
356 typ := cfg.Type[s]
357 if typ != nil && typ.Def != "" {
358 return typ.Def
359 }
360 return s
361 }
362
363
364
365
366
367
368
369 var curfn []*ast.FuncType
370
371 before := func(n any) {
372
373 switch n := n.(type) {
374 case *ast.FuncDecl:
375 curfn = append(curfn, n.Type)
376 case *ast.FuncLit:
377 curfn = append(curfn, n.Type)
378 }
379 }
380
381
382 after := func(n any) {
383 if n == nil {
384 return
385 }
386 if false && reflect.TypeOf(n).Kind() == reflect.Pointer {
387 defer func() {
388 if t := typeof[n]; t != "" {
389 pos := fset.Position(n.(ast.Node).Pos())
390 fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
391 }
392 }()
393 }
394
395 switch n := n.(type) {
396 case *ast.FuncDecl, *ast.FuncLit:
397
398 curfn = curfn[:len(curfn)-1]
399
400 case *ast.FuncType:
401 typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
402
403 case *ast.FieldList:
404
405 t := ""
406 for _, field := range n.List {
407 if t != "" {
408 t += ", "
409 }
410 t += typeof[field]
411 }
412 typeof[n] = t
413
414 case *ast.Field:
415
416 all := ""
417 t := typeof[n.Type]
418 if !isType(t) {
419
420
421 t = mkType(gofmt(n.Type))
422 typeof[n.Type] = t
423 }
424 t = getType(t)
425 if len(n.Names) == 0 {
426 all = t
427 } else {
428 for _, id := range n.Names {
429 if all != "" {
430 all += ", "
431 }
432 all += t
433 typeof[id.Obj] = t
434 typeof[id] = t
435 }
436 }
437 typeof[n] = all
438
439 case *ast.ValueSpec:
440
441 if n.Type != nil {
442 t := typeof[n.Type]
443 if !isType(t) {
444 t = mkType(gofmt(n.Type))
445 typeof[n.Type] = t
446 }
447 t = getType(t)
448 for _, id := range n.Names {
449 set(id, t, true)
450 }
451 }
452
453 typecheckAssign(makeExprList(n.Names), n.Values, true)
454
455 case *ast.AssignStmt:
456 typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
457
458 case *ast.Ident:
459
460 if t := typeof[n.Obj]; t != "" {
461 typeof[n] = t
462 }
463
464 case *ast.SelectorExpr:
465
466 name := n.Sel.Name
467 if t := typeof[n.X]; t != "" {
468 t = strings.TrimPrefix(t, "*")
469 if typ := cfg.Type[t]; typ != nil {
470 if t := typ.dot(cfg, name); t != "" {
471 typeof[n] = t
472 return
473 }
474 }
475 tt := typeof[t+"."+name]
476 if isType(tt) {
477 typeof[n] = getType(tt)
478 return
479 }
480 }
481
482 if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
483 str := x.Name + "." + name
484 if cfg.Type[str] != nil {
485 typeof[n] = mkType(str)
486 return
487 }
488 if t := cfg.typeof(x.Name + "." + name); t != "" {
489 typeof[n] = t
490 return
491 }
492 }
493
494 case *ast.CallExpr:
495
496 if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
497 typeof[n] = gofmt(n.Args[0])
498 return
499 }
500
501 if isTopName(n.Fun, "new") && len(n.Args) == 1 {
502 typeof[n] = "*" + gofmt(n.Args[0])
503 return
504 }
505
506 t := typeof[n.Fun]
507 if t == "" {
508 t = cfg.External[gofmt(n.Fun)]
509 }
510 in, out := splitFunc(t)
511 if in == nil && out == nil {
512 return
513 }
514 typeof[n] = join(out)
515 for i, arg := range n.Args {
516 if i >= len(in) {
517 break
518 }
519 if typeof[arg] == "" {
520 typeof[arg] = in[i]
521 }
522 }
523
524 case *ast.TypeAssertExpr:
525
526 if n.Type == nil {
527 typeof[n] = typeof[n.X]
528 return
529 }
530
531 if t := typeof[n.Type]; isType(t) {
532 typeof[n] = getType(t)
533 } else {
534 typeof[n] = gofmt(n.Type)
535 }
536
537 case *ast.SliceExpr:
538
539 typeof[n] = typeof[n.X]
540
541 case *ast.IndexExpr:
542
543 t := expand(typeof[n.X])
544 if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
545
546
547 if _, elem, ok := strings.Cut(t, "]"); ok {
548 typeof[n] = elem
549 }
550 }
551
552 case *ast.StarExpr:
553
554
555
556 t := expand(typeof[n.X])
557 if isType(t) {
558 typeof[n] = "type *" + getType(t)
559 } else if strings.HasPrefix(t, "*") {
560 typeof[n] = t[len("*"):]
561 }
562
563 case *ast.UnaryExpr:
564
565 t := typeof[n.X]
566 if t != "" && n.Op == token.AND {
567 typeof[n] = "*" + t
568 }
569
570 case *ast.CompositeLit:
571
572 typeof[n] = gofmt(n.Type)
573
574
575 t := expand(typeof[n])
576 if strings.HasPrefix(t, "[") {
577
578 if _, et, ok := strings.Cut(t, "]"); ok {
579 for _, e := range n.Elts {
580 if kv, ok := e.(*ast.KeyValueExpr); ok {
581 e = kv.Value
582 }
583 if typeof[e] == "" {
584 typeof[e] = et
585 }
586 }
587 }
588 }
589 if strings.HasPrefix(t, "map[") {
590
591 if kt, vt, ok := strings.Cut(t[len("map["):], "]"); ok {
592 for _, e := range n.Elts {
593 if kv, ok := e.(*ast.KeyValueExpr); ok {
594 if typeof[kv.Key] == "" {
595 typeof[kv.Key] = kt
596 }
597 if typeof[kv.Value] == "" {
598 typeof[kv.Value] = vt
599 }
600 }
601 }
602 }
603 }
604 if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 {
605 for _, e := range n.Elts {
606 if kv, ok := e.(*ast.KeyValueExpr); ok {
607 if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" {
608 if typeof[kv.Value] == "" {
609 typeof[kv.Value] = ft
610 }
611 }
612 }
613 }
614 }
615
616 case *ast.ParenExpr:
617
618 typeof[n] = typeof[n.X]
619
620 case *ast.RangeStmt:
621 t := expand(typeof[n.X])
622 if t == "" {
623 return
624 }
625 var key, value string
626 if t == "string" {
627 key, value = "int", "rune"
628 } else if strings.HasPrefix(t, "[") {
629 key = "int"
630 _, value, _ = strings.Cut(t, "]")
631 } else if strings.HasPrefix(t, "map[") {
632 if k, v, ok := strings.Cut(t[len("map["):], "]"); ok {
633 key, value = k, v
634 }
635 }
636 changed := false
637 if n.Key != nil && key != "" {
638 changed = true
639 set(n.Key, key, n.Tok == token.DEFINE)
640 }
641 if n.Value != nil && value != "" {
642 changed = true
643 set(n.Value, value, n.Tok == token.DEFINE)
644 }
645
646
647 if changed {
648 typecheck1(cfg, n.Body, typeof, assign)
649 }
650
651 case *ast.TypeSwitchStmt:
652
653
654
655
656 as, ok := n.Assign.(*ast.AssignStmt)
657 if !ok {
658 return
659 }
660 varx, ok := as.Lhs[0].(*ast.Ident)
661 if !ok {
662 return
663 }
664 t := typeof[varx]
665 for _, cas := range n.Body.List {
666 cas := cas.(*ast.CaseClause)
667 if len(cas.List) == 1 {
668
669
670 if tt := typeof[cas.List[0]]; isType(tt) {
671 tt = getType(tt)
672 typeof[varx] = tt
673 typeof[varx.Obj] = tt
674 typecheck1(cfg, cas.Body, typeof, assign)
675 }
676 }
677 }
678
679 typeof[varx] = t
680 typeof[varx.Obj] = t
681
682 case *ast.ReturnStmt:
683 if len(curfn) == 0 {
684
685 return
686 }
687 f := curfn[len(curfn)-1]
688 res := n.Results
689 if f.Results != nil {
690 t := split(typeof[f.Results])
691 for i := 0; i < len(res) && i < len(t); i++ {
692 set(res[i], t[i], false)
693 }
694 }
695
696 case *ast.BinaryExpr:
697
698 switch n.Op {
699 case token.EQL, token.NEQ:
700 if typeof[n.X] != "" && typeof[n.Y] == "" {
701 typeof[n.Y] = typeof[n.X]
702 }
703 if typeof[n.X] == "" && typeof[n.Y] != "" {
704 typeof[n.X] = typeof[n.Y]
705 }
706 }
707 }
708 }
709 walkBeforeAfter(f, before, after)
710 }
711
712
713
714
715
716
717
718 func splitFunc(s string) (in, out []string) {
719 if !strings.HasPrefix(s, "func(") {
720 return nil, nil
721 }
722
723 i := len("func(")
724 nparen := 0
725 for j := i; j < len(s); j++ {
726 switch s[j] {
727 case '(':
728 nparen++
729 case ')':
730 nparen--
731 if nparen < 0 {
732
733 out := strings.TrimSpace(s[j+1:])
734 if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
735 out = out[1 : len(out)-1]
736 }
737 return split(s[i:j]), split(out)
738 }
739 }
740 }
741 return nil, nil
742 }
743
744
// joinFunc is the inverse of splitFunc. It builds "func(in) out", writing a
// single result bare and parenthesizing two or more results.
func joinFunc(in, out []string) string {
	var results string
	switch len(out) {
	case 0:
		// No results: nothing follows the parameter list.
	case 1:
		results = " " + out[0]
	default:
		results = " (" + strings.Join(out, ", ") + ")"
	}
	return "func(" + strings.Join(in, ", ") + ")" + results
}
754
755
// split splits a comma-separated list of types, as produced by join, at its
// top-level commas. Commas nested inside (), [] or {} do not split, so
// "func(a, b) int, Pair[K, V]" yields ["func(a, b) int", "Pair[K, V]"].
// Leading spaces of each element are skipped and empty elements are dropped.
// It returns an empty, non-nil slice for "", and nil if the brackets in s are
// unbalanced. Bracket kinds are counted together, not matched against each
// other, so the input is assumed to be well formed.
func split(s string) []string {
	out := []string{}
	i := 0     // start of the current element
	depth := 0 // current bracket nesting depth
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if i == j {
				i++
			}
		case '(', '[', '{':
			depth++
		case ')', ']', '}':
			depth--
			if depth < 0 {
				// Closing bracket without a matching opener.
				return nil
			}
		case ',':
			if depth == 0 {
				if i < j {
					out = append(out, s[i:j])
				}
				i = j + 1
			}
		}
	}
	if depth != 0 {
		// Unclosed bracket.
		return nil
	}
	if i < len(s) {
		out = append(out, s[i:])
	}
	return out
}
792
793
// join joins a list of types with ", "; it is the inverse of split.
func join(x []string) string {
	var b strings.Builder
	for i, s := range x {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(s)
	}
	return b.String()
}
797
View as plain text