cmd/fix: remove all functionality

The buildtag fixer has been incorporated into the vet analyzer
of the same name; all other fixers were already no-ops since
CL 695855.

Fixes #73605
Updates #71859

Change-Id: I90b6c730849a5ecbac3e6fb6fc0e062b5de74831
Reviewed-on: https://go-review.googlesource.com/c/go/+/706758
Reviewed-by: Michael Matloob <matloob@golang.org>
Reviewed-by: Michael Matloob <matloob@google.com>
Auto-Submit: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Alan Donovan 2025-09-25 12:41:08 -04:00
parent 6dceff8bad
commit 393d91aea0
15 changed files with 12 additions and 2475 deletions

View file

@ -1,52 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"go/ast"
"go/version"
"strings"
)
func init() {
register(buildtagFix)
}
const buildtagGoVersionCutoff = "go1.18"
var buildtagFix = fix{
name: "buildtag",
date: "2021-08-25",
f: buildtag,
desc: `Remove +build comments from modules using Go 1.18 or later`,
}
func buildtag(f *ast.File) bool {
if version.Compare(*goVersion, buildtagGoVersionCutoff) < 0 {
return false
}
// File is already gofmt-ed, so we know that if there are +build lines,
// they are in a comment group that starts with a //go:build line followed
// by a blank line. While we cannot delete comments from an AST and
// expect consistent output in general, this specific case - deleting only
// some lines from a comment block - does format correctly.
fixed := false
for _, g := range f.Comments {
sawGoBuild := false
for i, c := range g.List {
if strings.HasPrefix(c.Text, "//go:build ") {
sawGoBuild = true
}
if sawGoBuild && strings.HasPrefix(c.Text, "// +build ") {
g.List = g.List[:i]
fixed = true
break
}
}
}
return fixed
}

View file

@ -1,34 +0,0 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
addTestCases(buildtagTests, buildtag)
}
var buildtagTests = []testCase{
{
Name: "buildtag.oldGo",
Version: "go1.10",
In: `//go:build yes
// +build yes
package main
`,
},
{
Name: "buildtag.new",
Version: "go1.99",
In: `//go:build yes
// +build yes
package main
`,
Out: `//go:build yes
package main
`,
},
}

View file

@ -1,25 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"go/ast"
)
func init() {
register(cftypeFix)
}
var cftypeFix = fix{
name: "cftype",
date: "2017-09-27",
f: noop,
desc: `Fixes initializers and casts of C.*Ref and JNI types (removed)`,
disabled: false,
}
func noop(f *ast.File) bool {
return false
}

View file

@ -1,17 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
register(contextFix)
}
var contextFix = fix{
name: "context",
date: "2016-09-09",
f: noop,
desc: `Change imports of golang.org/x/net/context to context (removed)`,
disabled: false,
}

View file

@ -9,29 +9,12 @@ the necessary changes to your programs.
Usage: Usage:
go tool fix [-r name,...] [path ...] go tool fix [ignored...]
Without an explicit path, fix reads standard input and writes the This tool is currently in transition. All its historical fixers were
result to standard output. long obsolete and have been removed, so it is currently a no-op. In
due course the tool will integrate with the Go analysis framework
If the named path is a file, fix rewrites the named files in place. (golang.org/x/tools/go/analysis) and run a modern suite of fix
If the named path is a directory, fix rewrites all .go files in that algorithms; see https://go.dev/issue/71859.
directory tree. When fix rewrites a file, it prints a line to standard
error giving the name of the file and the rewrite applied.
If the -diff flag is set, no files are rewritten. Instead fix prints
the differences a rewrite would introduce.
The -r flag restricts the set of rewrites considered to those in the
named list. By default fix considers all known rewrites. Fix's
rewrites are idempotent, so that it is safe to apply fix to updated
or partially updated code even without using the -r flag.
Fix prints the full list of fixes it can apply in its help output;
to see them, run go tool fix -help.
Fix does not make backup copies of the files that it edits.
Instead, use a version control system's diff functionality to inspect
the changes that fix makes before committing them.
*/ */
package main package main

View file

@ -1,26 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
register(eglFixDisplay)
register(eglFixConfig)
}
var eglFixDisplay = fix{
name: "egl",
date: "2018-12-15",
f: noop,
desc: `Fixes initializers of EGLDisplay (removed)`,
disabled: false,
}
var eglFixConfig = fix{
name: "eglconf",
date: "2020-05-30",
f: noop,
desc: `Fixes initializers of EGLConfig (removed)`,
disabled: false,
}

View file

@ -1,552 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/token"
"path"
"strconv"
)
type fix struct {
name string
date string // date that fix was introduced, in YYYY-MM-DD format
f func(*ast.File) bool
desc string
disabled bool // whether this fix should be disabled by default
}
var fixes []fix
func register(f fix) {
fixes = append(fixes, f)
}
// walk traverses the AST x, calling visit(y) for each node y in the tree but
// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
// in a bottom-up traversal.
func walk(x any, visit func(any)) {
walkBeforeAfter(x, nop, visit)
}
func nop(any) {}
// walkBeforeAfter is like walk but calls before(x) before traversing
// x's children and after(x) afterward.
func walkBeforeAfter(x any, before, after func(any)) {
before(x)
switch n := x.(type) {
default:
panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
case nil:
// pointers to interfaces
case *ast.Decl:
walkBeforeAfter(*n, before, after)
case *ast.Expr:
walkBeforeAfter(*n, before, after)
case *ast.Spec:
walkBeforeAfter(*n, before, after)
case *ast.Stmt:
walkBeforeAfter(*n, before, after)
// pointers to struct pointers
case **ast.BlockStmt:
walkBeforeAfter(*n, before, after)
case **ast.CallExpr:
walkBeforeAfter(*n, before, after)
case **ast.FieldList:
walkBeforeAfter(*n, before, after)
case **ast.FuncType:
walkBeforeAfter(*n, before, after)
case **ast.Ident:
walkBeforeAfter(*n, before, after)
case **ast.BasicLit:
walkBeforeAfter(*n, before, after)
// pointers to slices
case *[]ast.Decl:
walkBeforeAfter(*n, before, after)
case *[]ast.Expr:
walkBeforeAfter(*n, before, after)
case *[]*ast.File:
walkBeforeAfter(*n, before, after)
case *[]*ast.Ident:
walkBeforeAfter(*n, before, after)
case *[]ast.Spec:
walkBeforeAfter(*n, before, after)
case *[]ast.Stmt:
walkBeforeAfter(*n, before, after)
// These are ordered and grouped to match ../../go/ast/ast.go
case *ast.Field:
walkBeforeAfter(&n.Names, before, after)
walkBeforeAfter(&n.Type, before, after)
walkBeforeAfter(&n.Tag, before, after)
case *ast.FieldList:
for _, field := range n.List {
walkBeforeAfter(field, before, after)
}
case *ast.BadExpr:
case *ast.Ident:
case *ast.Ellipsis:
walkBeforeAfter(&n.Elt, before, after)
case *ast.BasicLit:
case *ast.FuncLit:
walkBeforeAfter(&n.Type, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.CompositeLit:
walkBeforeAfter(&n.Type, before, after)
walkBeforeAfter(&n.Elts, before, after)
case *ast.ParenExpr:
walkBeforeAfter(&n.X, before, after)
case *ast.SelectorExpr:
walkBeforeAfter(&n.X, before, after)
case *ast.IndexExpr:
walkBeforeAfter(&n.X, before, after)
walkBeforeAfter(&n.Index, before, after)
case *ast.IndexListExpr:
walkBeforeAfter(&n.X, before, after)
walkBeforeAfter(&n.Indices, before, after)
case *ast.SliceExpr:
walkBeforeAfter(&n.X, before, after)
if n.Low != nil {
walkBeforeAfter(&n.Low, before, after)
}
if n.High != nil {
walkBeforeAfter(&n.High, before, after)
}
case *ast.TypeAssertExpr:
walkBeforeAfter(&n.X, before, after)
walkBeforeAfter(&n.Type, before, after)
case *ast.CallExpr:
walkBeforeAfter(&n.Fun, before, after)
walkBeforeAfter(&n.Args, before, after)
case *ast.StarExpr:
walkBeforeAfter(&n.X, before, after)
case *ast.UnaryExpr:
walkBeforeAfter(&n.X, before, after)
case *ast.BinaryExpr:
walkBeforeAfter(&n.X, before, after)
walkBeforeAfter(&n.Y, before, after)
case *ast.KeyValueExpr:
walkBeforeAfter(&n.Key, before, after)
walkBeforeAfter(&n.Value, before, after)
case *ast.ArrayType:
walkBeforeAfter(&n.Len, before, after)
walkBeforeAfter(&n.Elt, before, after)
case *ast.StructType:
walkBeforeAfter(&n.Fields, before, after)
case *ast.FuncType:
if n.TypeParams != nil {
walkBeforeAfter(&n.TypeParams, before, after)
}
walkBeforeAfter(&n.Params, before, after)
if n.Results != nil {
walkBeforeAfter(&n.Results, before, after)
}
case *ast.InterfaceType:
walkBeforeAfter(&n.Methods, before, after)
case *ast.MapType:
walkBeforeAfter(&n.Key, before, after)
walkBeforeAfter(&n.Value, before, after)
case *ast.ChanType:
walkBeforeAfter(&n.Value, before, after)
case *ast.BadStmt:
case *ast.DeclStmt:
walkBeforeAfter(&n.Decl, before, after)
case *ast.EmptyStmt:
case *ast.LabeledStmt:
walkBeforeAfter(&n.Stmt, before, after)
case *ast.ExprStmt:
walkBeforeAfter(&n.X, before, after)
case *ast.SendStmt:
walkBeforeAfter(&n.Chan, before, after)
walkBeforeAfter(&n.Value, before, after)
case *ast.IncDecStmt:
walkBeforeAfter(&n.X, before, after)
case *ast.AssignStmt:
walkBeforeAfter(&n.Lhs, before, after)
walkBeforeAfter(&n.Rhs, before, after)
case *ast.GoStmt:
walkBeforeAfter(&n.Call, before, after)
case *ast.DeferStmt:
walkBeforeAfter(&n.Call, before, after)
case *ast.ReturnStmt:
walkBeforeAfter(&n.Results, before, after)
case *ast.BranchStmt:
case *ast.BlockStmt:
walkBeforeAfter(&n.List, before, after)
case *ast.IfStmt:
walkBeforeAfter(&n.Init, before, after)
walkBeforeAfter(&n.Cond, before, after)
walkBeforeAfter(&n.Body, before, after)
walkBeforeAfter(&n.Else, before, after)
case *ast.CaseClause:
walkBeforeAfter(&n.List, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.SwitchStmt:
walkBeforeAfter(&n.Init, before, after)
walkBeforeAfter(&n.Tag, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.TypeSwitchStmt:
walkBeforeAfter(&n.Init, before, after)
walkBeforeAfter(&n.Assign, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.CommClause:
walkBeforeAfter(&n.Comm, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.SelectStmt:
walkBeforeAfter(&n.Body, before, after)
case *ast.ForStmt:
walkBeforeAfter(&n.Init, before, after)
walkBeforeAfter(&n.Cond, before, after)
walkBeforeAfter(&n.Post, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.RangeStmt:
walkBeforeAfter(&n.Key, before, after)
walkBeforeAfter(&n.Value, before, after)
walkBeforeAfter(&n.X, before, after)
walkBeforeAfter(&n.Body, before, after)
case *ast.ImportSpec:
case *ast.ValueSpec:
walkBeforeAfter(&n.Type, before, after)
walkBeforeAfter(&n.Values, before, after)
walkBeforeAfter(&n.Names, before, after)
case *ast.TypeSpec:
if n.TypeParams != nil {
walkBeforeAfter(&n.TypeParams, before, after)
}
walkBeforeAfter(&n.Type, before, after)
case *ast.BadDecl:
case *ast.GenDecl:
walkBeforeAfter(&n.Specs, before, after)
case *ast.FuncDecl:
if n.Recv != nil {
walkBeforeAfter(&n.Recv, before, after)
}
walkBeforeAfter(&n.Type, before, after)
if n.Body != nil {
walkBeforeAfter(&n.Body, before, after)
}
case *ast.File:
walkBeforeAfter(&n.Decls, before, after)
case *ast.Package:
walkBeforeAfter(&n.Files, before, after)
case []*ast.File:
for i := range n {
walkBeforeAfter(&n[i], before, after)
}
case []ast.Decl:
for i := range n {
walkBeforeAfter(&n[i], before, after)
}
case []ast.Expr:
for i := range n {
walkBeforeAfter(&n[i], before, after)
}
case []*ast.Ident:
for i := range n {
walkBeforeAfter(&n[i], before, after)
}
case []ast.Stmt:
for i := range n {
walkBeforeAfter(&n[i], before, after)
}
case []ast.Spec:
for i := range n {
walkBeforeAfter(&n[i], before, after)
}
}
after(x)
}
// imports reports whether f imports path.
func imports(f *ast.File, path string) bool {
return importSpec(f, path) != nil
}
// importSpec returns the import spec if f imports path,
// or nil otherwise.
func importSpec(f *ast.File, path string) *ast.ImportSpec {
for _, s := range f.Imports {
if importPath(s) == path {
return s
}
}
return nil
}
// importPath returns the unquoted import path of s,
// or "" if the path is not properly quoted.
func importPath(s *ast.ImportSpec) string {
t, err := strconv.Unquote(s.Path.Value)
if err == nil {
return t
}
return ""
}
// declImports reports whether gen contains an import of path.
func declImports(gen *ast.GenDecl, path string) bool {
if gen.Tok != token.IMPORT {
return false
}
for _, spec := range gen.Specs {
impspec := spec.(*ast.ImportSpec)
if importPath(impspec) == path {
return true
}
}
return false
}
// isTopName reports whether n is a top-level unresolved identifier with the given name.
func isTopName(n ast.Expr, name string) bool {
id, ok := n.(*ast.Ident)
return ok && id.Name == name && id.Obj == nil
}
// renameTop renames all references to the top-level name old.
// It reports whether it makes any changes.
func renameTop(f *ast.File, old, new string) bool {
var fixed bool
// Rename any conflicting imports
// (assuming package name is last element of path).
for _, s := range f.Imports {
if s.Name != nil {
if s.Name.Name == old {
s.Name.Name = new
fixed = true
}
} else {
_, thisName := path.Split(importPath(s))
if thisName == old {
s.Name = ast.NewIdent(new)
fixed = true
}
}
}
// Rename any top-level declarations.
for _, d := range f.Decls {
switch d := d.(type) {
case *ast.FuncDecl:
if d.Recv == nil && d.Name.Name == old {
d.Name.Name = new
d.Name.Obj.Name = new
fixed = true
}
case *ast.GenDecl:
for _, s := range d.Specs {
switch s := s.(type) {
case *ast.TypeSpec:
if s.Name.Name == old {
s.Name.Name = new
s.Name.Obj.Name = new
fixed = true
}
case *ast.ValueSpec:
for _, n := range s.Names {
if n.Name == old {
n.Name = new
n.Obj.Name = new
fixed = true
}
}
}
}
}
}
// Rename top-level old to new, both unresolved names
// (probably defined in another file) and names that resolve
// to a declaration we renamed.
walk(f, func(n any) {
id, ok := n.(*ast.Ident)
if ok && isTopName(id, old) {
id.Name = new
fixed = true
}
if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
id.Name = id.Obj.Name
fixed = true
}
})
return fixed
}
// matchLen returns the length of the longest prefix shared by x and y.
func matchLen(x, y string) int {
i := 0
for i < len(x) && i < len(y) && x[i] == y[i] {
i++
}
return i
}
// addImport adds the import path to the file f, if absent.
func addImport(f *ast.File, ipath string) (added bool) {
if imports(f, ipath) {
return false
}
// Determine name of import.
// Assume added imports follow convention of using last element.
_, name := path.Split(ipath)
// Rename any conflicting top-level references from name to name_.
renameTop(f, name, name+"_")
newImport := &ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: strconv.Quote(ipath),
},
}
// Find an import decl to add to.
var (
bestMatch = -1
lastImport = -1
impDecl *ast.GenDecl
impIndex = -1
)
for i, decl := range f.Decls {
gen, ok := decl.(*ast.GenDecl)
if ok && gen.Tok == token.IMPORT {
lastImport = i
// Do not add to import "C", to avoid disrupting the
// association with its doc comment, breaking cgo.
if declImports(gen, "C") {
continue
}
// Compute longest shared prefix with imports in this block.
for j, spec := range gen.Specs {
impspec := spec.(*ast.ImportSpec)
n := matchLen(importPath(impspec), ipath)
if n > bestMatch {
bestMatch = n
impDecl = gen
impIndex = j
}
}
}
}
// If no import decl found, add one after the last import.
if impDecl == nil {
impDecl = &ast.GenDecl{
Tok: token.IMPORT,
}
f.Decls = append(f.Decls, nil)
copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
f.Decls[lastImport+1] = impDecl
}
// Ensure the import decl has parentheses, if needed.
if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
impDecl.Lparen = impDecl.Pos()
}
insertAt := impIndex + 1
if insertAt == 0 {
insertAt = len(impDecl.Specs)
}
impDecl.Specs = append(impDecl.Specs, nil)
copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
impDecl.Specs[insertAt] = newImport
if insertAt > 0 {
// Assign same position as the previous import,
// so that the sorter sees it as being in the same block.
prev := impDecl.Specs[insertAt-1]
newImport.Path.ValuePos = prev.Pos()
newImport.EndPos = prev.Pos()
}
f.Imports = append(f.Imports, newImport)
return true
}
// deleteImport deletes the import path from the file f, if present.
func deleteImport(f *ast.File, path string) (deleted bool) {
oldImport := importSpec(f, path)
// Find the import node that imports path, if any.
for i, decl := range f.Decls {
gen, ok := decl.(*ast.GenDecl)
if !ok || gen.Tok != token.IMPORT {
continue
}
for j, spec := range gen.Specs {
impspec := spec.(*ast.ImportSpec)
if oldImport != impspec {
continue
}
// We found an import spec that imports path.
// Delete it.
deleted = true
copy(gen.Specs[j:], gen.Specs[j+1:])
gen.Specs = gen.Specs[:len(gen.Specs)-1]
// If this was the last import spec in this decl,
// delete the decl, too.
if len(gen.Specs) == 0 {
copy(f.Decls[i:], f.Decls[i+1:])
f.Decls = f.Decls[:len(f.Decls)-1]
} else if len(gen.Specs) == 1 {
gen.Lparen = token.NoPos // drop parens
}
if j > 0 {
// We deleted an entry but now there will be
// a blank line-sized hole where the import was.
// Close the hole by making the previous
// import appear to "end" where this one did.
gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
}
break
}
}
// Delete it from f.Imports.
for i, imp := range f.Imports {
if imp == oldImport {
copy(f.Imports[i:], f.Imports[i+1:])
f.Imports = f.Imports[:len(f.Imports)-1]
break
}
}
return
}
// rewriteImport rewrites any import of path oldPath to path newPath.
func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
for _, imp := range f.Imports {
if importPath(imp) == oldPath {
rewrote = true
// record old End, because the default is to compute
// it using the length of imp.Path.Value.
imp.EndPos = imp.End()
imp.Path.Value = strconv.Quote(newPath)
}
}
return
}

View file

@ -1,16 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
register(gotypesFix)
}
var gotypesFix = fix{
name: "gotypes",
date: "2015-07-16",
f: noop,
desc: `Change imports of golang.org/x/tools/go/{exact,types} to go/{constant,types} (removed)`,
}

View file

@ -1,458 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "go/ast"
func init() {
addTestCases(importTests, nil)
}
var importTests = []testCase{
{
Name: "import.0",
Fn: addImportFn("os"),
In: `package main
import (
"os"
)
`,
Out: `package main
import (
"os"
)
`,
},
{
Name: "import.1",
Fn: addImportFn("os"),
In: `package main
`,
Out: `package main
import "os"
`,
},
{
Name: "import.2",
Fn: addImportFn("os"),
In: `package main
// Comment
import "C"
`,
Out: `package main
// Comment
import "C"
import "os"
`,
},
{
Name: "import.3",
Fn: addImportFn("os"),
In: `package main
// Comment
import "C"
import (
"io"
"utf8"
)
`,
Out: `package main
// Comment
import "C"
import (
"io"
"os"
"utf8"
)
`,
},
{
Name: "import.4",
Fn: deleteImportFn("os"),
In: `package main
import (
"os"
)
`,
Out: `package main
`,
},
{
Name: "import.5",
Fn: deleteImportFn("os"),
In: `package main
// Comment
import "C"
import "os"
`,
Out: `package main
// Comment
import "C"
`,
},
{
Name: "import.6",
Fn: deleteImportFn("os"),
In: `package main
// Comment
import "C"
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
// Comment
import "C"
import (
"io"
"utf8"
)
`,
},
{
Name: "import.7",
Fn: deleteImportFn("io"),
In: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
Out: `package main
import (
// a
"os" // b
"utf8" // c
)
`,
},
{
Name: "import.8",
Fn: deleteImportFn("os"),
In: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
Out: `package main
import (
"io" // a
// b
"utf8" // c
)
`,
},
{
Name: "import.9",
Fn: deleteImportFn("utf8"),
In: `package main
import (
"io" // a
"os" // b
"utf8" // c
)
`,
Out: `package main
import (
"io" // a
"os" // b
// c
)
`,
},
{
Name: "import.10",
Fn: deleteImportFn("io"),
In: `package main
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
import (
"os"
"utf8"
)
`,
},
{
Name: "import.11",
Fn: deleteImportFn("os"),
In: `package main
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
import (
"io"
"utf8"
)
`,
},
{
Name: "import.12",
Fn: deleteImportFn("utf8"),
In: `package main
import (
"io"
"os"
"utf8"
)
`,
Out: `package main
import (
"io"
"os"
)
`,
},
{
Name: "import.13",
Fn: rewriteImportFn("utf8", "encoding/utf8"),
In: `package main
import (
"io"
"os"
"utf8" // thanks ken
)
`,
Out: `package main
import (
"encoding/utf8" // thanks ken
"io"
"os"
)
`,
},
{
Name: "import.14",
Fn: rewriteImportFn("asn1", "encoding/asn1"),
In: `package main
import (
"asn1"
"crypto"
"crypto/rsa"
_ "crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"time"
)
var x = 1
`,
Out: `package main
import (
"crypto"
"crypto/rsa"
_ "crypto/sha1"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"time"
)
var x = 1
`,
},
{
Name: "import.15",
Fn: rewriteImportFn("url", "net/url"),
In: `package main
import (
"bufio"
"net"
"path"
"url"
)
var x = 1 // comment on x, not on url
`,
Out: `package main
import (
"bufio"
"net"
"net/url"
"path"
)
var x = 1 // comment on x, not on url
`,
},
{
Name: "import.16",
Fn: rewriteImportFn("http", "net/http", "template", "text/template"),
In: `package main
import (
"flag"
"http"
"log"
"template"
)
var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
Out: `package main
import (
"flag"
"log"
"net/http"
"text/template"
)
var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
},
{
Name: "import.17",
Fn: addImportFn("x/y/z", "x/a/c"),
In: `package main
// Comment
import "C"
import (
"a"
"b"
"x/w"
"d/f"
)
`,
Out: `package main
// Comment
import "C"
import (
"a"
"b"
"x/a/c"
"x/w"
"x/y/z"
"d/f"
)
`,
},
{
Name: "import.18",
Fn: addDelImportFn("e", "o"),
In: `package main
import (
"f"
"o"
"z"
)
`,
Out: `package main
import (
"e"
"f"
"z"
)
`,
},
}
func addImportFn(path ...string) func(*ast.File) bool {
return func(f *ast.File) bool {
fixed := false
for _, p := range path {
if !imports(f, p) {
addImport(f, p)
fixed = true
}
}
return fixed
}
}
func deleteImportFn(path string) func(*ast.File) bool {
return func(f *ast.File) bool {
if imports(f, path) {
deleteImport(f, path)
return true
}
return false
}
}
func addDelImportFn(p1 string, p2 string) func(*ast.File) bool {
return func(f *ast.File) bool {
fixed := false
if !imports(f, p1) {
addImport(f, p1)
fixed = true
}
if imports(f, p2) {
deleteImport(f, p2)
fixed = true
}
return fixed
}
}
func rewriteImportFn(oldnew ...string) func(*ast.File) bool {
return func(f *ast.File) bool {
fixed := false
for i := 0; i < len(oldnew); i += 2 {
if imports(f, oldnew[i]) {
rewriteImport(f, oldnew[i], oldnew[i+1])
fixed = true
}
}
return fixed
}
}

View file

@ -1,17 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
register(jniFix)
}
var jniFix = fix{
name: "jni",
date: "2017-12-04",
f: noop,
desc: `Fixes initializers of JNI's jobject and subtypes (removed)`,
disabled: false,
}

View file

@ -5,261 +5,27 @@
package main package main
import ( import (
"bytes"
"flag" "flag"
"fmt" "fmt"
"go/ast"
"go/format"
"go/parser"
"go/scanner"
"go/token"
"go/version"
"internal/diff"
"io"
"io/fs"
"os" "os"
"path/filepath"
"slices"
"strings"
"cmd/internal/telemetry/counter"
) )
var ( var (
fset = token.NewFileSet() _ = flag.Bool("diff", false, "obsolete, no effect")
exitCode = 0 _ = flag.String("go", "", "obsolete, no effect")
_ = flag.String("r", "", "obsolete, no effect")
_ = flag.String("force", "", "obsolete, no effect")
) )
var allowedRewrites = flag.String("r", "",
"restrict the rewrites to this comma-separated list")
var forceRewrites = flag.String("force", "",
"force these fixes to run even if the code looks updated")
var allowed, force map[string]bool
var (
doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
goVersion = flag.String("go", "", "go language version for files")
)
// enable for debugging fix failures
const debug = false // display incorrectly reformatted source and exit
func usage() { func usage() {
fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r ignored] [-force ignored] ...\n")
flag.PrintDefaults() flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
slices.SortFunc(fixes, func(a, b fix) int {
return strings.Compare(a.name, b.name)
})
for _, f := range fixes {
if f.disabled {
fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
} else {
fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
}
desc := strings.TrimSpace(f.desc)
desc = strings.ReplaceAll(desc, "\n", "\n\t")
fmt.Fprintf(os.Stderr, "\t%s\n", desc)
}
os.Exit(2) os.Exit(2)
} }
func main() { func main() {
counter.Open()
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
counter.Inc("fix/invocations")
counter.CountFlags("fix/flag:", *flag.CommandLine)
if !version.IsValid(*goVersion) { os.Exit(0)
report(fmt.Errorf("invalid -go=%s", *goVersion))
os.Exit(exitCode)
}
slices.SortFunc(fixes, func(a, b fix) int {
return strings.Compare(a.date, b.date)
})
if *allowedRewrites != "" {
allowed = make(map[string]bool)
for f := range strings.SplitSeq(*allowedRewrites, ",") {
allowed[f] = true
}
}
if *forceRewrites != "" {
force = make(map[string]bool)
for f := range strings.SplitSeq(*forceRewrites, ",") {
force[f] = true
}
}
if flag.NArg() == 0 {
if err := processFile("standard input", true); err != nil {
report(err)
}
os.Exit(exitCode)
}
for i := 0; i < flag.NArg(); i++ {
path := flag.Arg(i)
switch dir, err := os.Stat(path); {
case err != nil:
report(err)
case dir.IsDir():
walkDir(path)
default:
if err := processFile(path, false); err != nil {
report(err)
}
}
}
os.Exit(exitCode)
}
const parserMode = parser.ParseComments
func gofmtFile(f *ast.File) ([]byte, error) {
var buf bytes.Buffer
if err := format.Node(&buf, fset, f); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func processFile(filename string, useStdin bool) error {
var f *os.File
var err error
var fixlog strings.Builder
if useStdin {
f = os.Stdin
} else {
f, err = os.Open(filename)
if err != nil {
return err
}
defer f.Close()
}
src, err := io.ReadAll(f)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, filename, src, parserMode)
if err != nil {
return err
}
// Make sure file is in canonical format.
// This "fmt" pseudo-fix cannot be disabled.
newSrc, err := gofmtFile(file)
if err != nil {
return err
}
if !bytes.Equal(newSrc, src) {
newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
if err != nil {
return err
}
file = newFile
fmt.Fprintf(&fixlog, " fmt")
}
// Apply all fixes to file.
newFile := file
fixed := false
for _, fix := range fixes {
if allowed != nil && !allowed[fix.name] {
continue
}
if fix.disabled && !force[fix.name] {
continue
}
if fix.f(newFile) {
fixed = true
fmt.Fprintf(&fixlog, " %s", fix.name)
// AST changed.
// Print and parse, to update any missing scoping
// or position information for subsequent fixers.
newSrc, err := gofmtFile(newFile)
if err != nil {
return err
}
newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
if err != nil {
if debug {
fmt.Printf("%s", newSrc)
report(err)
os.Exit(exitCode)
}
return err
}
}
}
if !fixed {
return nil
}
fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
// Print AST. We did that after each fix, so this appears
// redundant, but it is necessary to generate gofmt-compatible
// source code in a few cases. The official gofmt style is the
// output of the printer run on a standard AST generated by the parser,
// but the source we generated inside the loop above is the
// output of the printer run on a mangled AST generated by a fixer.
newSrc, err = gofmtFile(newFile)
if err != nil {
return err
}
if *doDiff {
os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
return nil
}
if useStdin {
os.Stdout.Write(newSrc)
return nil
}
return os.WriteFile(f.Name(), newSrc, 0)
}
func gofmt(n any) string {
var gofmtBuf strings.Builder
if err := format.Node(&gofmtBuf, fset, n); err != nil {
return "<" + err.Error() + ">"
}
return gofmtBuf.String()
}
func report(err error) {
scanner.PrintError(os.Stderr, err)
exitCode = 2
}
func walkDir(path string) {
filepath.WalkDir(path, visitFile)
}
func visitFile(path string, f fs.DirEntry, err error) error {
if err == nil && isGoFile(f) {
err = processFile(path, false)
}
if err != nil {
report(err)
}
return nil
}
func isGoFile(f fs.DirEntry) bool {
// ignore non-Go files
name := f.Name()
return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
} }

View file

@ -1,166 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/parser"
"internal/diff"
"internal/testenv"
"strings"
"testing"
)
type testCase struct {
Name string
Fn func(*ast.File) bool
Version string
In string
Out string
}
var testCases []testCase
func addTestCases(t []testCase, fn func(*ast.File) bool) {
// Fill in fn to avoid repetition in definitions.
if fn != nil {
for i := range t {
if t[i].Fn == nil {
t[i].Fn = fn
}
}
}
testCases = append(testCases, t...)
}
func fnop(*ast.File) bool { return false }
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
file, err := parser.ParseFile(fset, desc, in, parserMode)
if err != nil {
t.Errorf("parsing: %v", err)
return
}
outb, err := gofmtFile(file)
if err != nil {
t.Errorf("printing: %v", err)
return
}
if s := string(outb); in != s && mustBeGofmt {
t.Errorf("not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
desc, in, desc, s)
tdiff(t, "want", in, "have", s)
return
}
if fn == nil {
for _, fix := range fixes {
if fix.f(file) {
fixed = true
}
}
} else {
fixed = fn(file)
}
outb, err = gofmtFile(file)
if err != nil {
t.Errorf("printing: %v", err)
return
}
return string(outb), fixed, true
}
func TestRewrite(t *testing.T) {
// If cgo is enabled, enforce that cgo commands invoked by cmd/fix
// do not fail during testing.
if testenv.HasCGO() {
testenv.MustHaveGoBuild(t) // Really just 'go tool cgo', but close enough.
// The reportCgoError hook is global, so we can't set it per-test
// if we want to be able to run those tests in parallel.
// Instead, simply set it to panic on error: the goroutine dump
// from the panic should help us determine which test failed.
prevReportCgoError := reportCgoError
reportCgoError = func(err error) {
panic(fmt.Sprintf("unexpected cgo error: %v", err))
}
t.Cleanup(func() { reportCgoError = prevReportCgoError })
}
for _, tt := range testCases {
tt := tt
t.Run(tt.Name, func(t *testing.T) {
if tt.Version == "" {
if testing.Verbose() {
// Don't run in parallel: cmd/fix sometimes writes directly to stderr,
// and since -v prints which test is currently running we want that
// information to accurately correlate with the stderr output.
} else {
t.Parallel()
}
} else {
old := *goVersion
*goVersion = tt.Version
defer func() {
*goVersion = old
}()
}
// Apply fix: should get tt.Out.
out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
if !ok {
return
}
// reformat to get printing right
out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
if !ok {
return
}
if tt.Out == "" {
tt.Out = tt.In
}
if out != tt.Out {
t.Errorf("incorrect output.\n")
if !strings.HasPrefix(tt.Name, "testdata/") {
t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
}
tdiff(t, "have", out, "want", tt.Out)
return
}
if changed := out != tt.In; changed != fixed {
t.Errorf("changed=%v != fixed=%v", changed, fixed)
return
}
// Should not change if run again.
out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
if !ok {
return
}
if fixed2 {
t.Errorf("applied fixes during second round")
return
}
if out2 != out {
t.Errorf("changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
out, out2)
tdiff(t, "first", out, "second", out2)
}
})
}
}
func tdiff(t *testing.T, aname, a, bname, b string) {
t.Errorf("%s", diff.Diff(aname, []byte(a), bname, []byte(b)))
}

View file

@ -1,19 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
register(netipv6zoneFix)
}
var netipv6zoneFix = fix{
name: "netipv6zone",
date: "2012-11-26",
f: noop,
desc: `Adapt element key to IPAddr, UDPAddr or TCPAddr composite literals (removed).
https://codereview.appspot.com/6849045/
`,
}

View file

@ -1,16 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
func init() {
register(printerconfigFix)
}
var printerconfigFix = fix{
name: "printerconfig",
date: "2012-12-11",
f: noop,
desc: `Add element keys to Config composite literals (removed).`,
}

View file

@ -1,814 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"maps"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
"strings"
)
// Partial type checker.
//
// The fact that it is partial is very important: the input is
// an AST and a description of some type information to
// assume about one or more packages, but not all the
// packages that the program imports. The checker is
// expected to do as much as it can with what it has been
// given. There is not enough information supplied to do
// a full type check, but the type checker is expected to
// apply information that can be derived from variable
// declarations, function and method returns, and type switches
// as far as it can, so that the caller can still tell the types
// of expression relevant to a particular fix.
//
// TODO(rsc,gri): Replace with go/typechecker.
// Doing that could be an interesting test case for go/typechecker:
// the constraints about working with partial information will
// likely exercise it in interesting ways. The ideal interface would
// be to pass typecheck a map from importpath to package API text
// (Go source code), but for now we use data structures (TypeConfig, Type).
//
// The strings mostly use gofmt form.
//
// A Field or FieldList has as its type a comma-separated list
// of the types of the fields. For example, the field list
// x, y, z int
// has type "int, int, int".
// The prefix "type " is the type of a type.
// For example, given
// var x int
// type T int
// x's type is "int" but T's type is "type int".
// mkType inserts the "type " prefix.
// getType removes it.
// isType tests for it.
func mkType(t string) string {
return "type " + t
}
func getType(t string) string {
if !isType(t) {
return ""
}
return t[len("type "):]
}
func isType(t string) bool {
return strings.HasPrefix(t, "type ")
}
// TypeConfig describes the universe of relevant types.
// For ease of creation, the types are all referred to by string
// name (e.g., "reflect.Value"). TypeByName is the only place
// where the strings are resolved.
type TypeConfig struct {
Type map[string]*Type
Var map[string]string
Func map[string]string
// External maps from a name to its type.
// It provides additional typings not present in the Go source itself.
// For now, the only additional typings are those generated by cgo.
External map[string]string
}
// typeof returns the type of the given name, which may be of
// the form "x" or "p.X".
func (cfg *TypeConfig) typeof(name string) string {
if cfg.Var != nil {
if t := cfg.Var[name]; t != "" {
return t
}
}
if cfg.Func != nil {
if t := cfg.Func[name]; t != "" {
return "func()" + t
}
}
return ""
}
// Type describes the Fields and Methods of a type.
// If the field or method cannot be found there, it is next
// looked for in the Embed list.
type Type struct {
Field map[string]string // map field name to type
Method map[string]string // map method name to comma-separated return types (should start with "func ")
Embed []string // list of types this type embeds (for extra methods)
Def string // definition of named type
}
// dot returns the type of "typ.name", making its decision
// using the type information in cfg.
func (typ *Type) dot(cfg *TypeConfig, name string) string {
if typ.Field != nil {
if t := typ.Field[name]; t != "" {
return t
}
}
if typ.Method != nil {
if t := typ.Method[name]; t != "" {
return t
}
}
for _, e := range typ.Embed {
etyp := cfg.Type[e]
if etyp != nil {
if t := etyp.dot(cfg, name); t != "" {
return t
}
}
}
return ""
}
// typecheck type checks the AST f assuming the information in cfg.
// It returns two maps with type information:
// typeof maps AST nodes to type information in gofmt string form.
// assign maps type strings to lists of expressions that were assigned
// to values of another type that were assigned to that type.
func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[any]string, assign map[string][]any) {
typeof = make(map[any]string)
assign = make(map[string][]any)
cfg1 := &TypeConfig{}
*cfg1 = *cfg // make copy so we can add locally
copied := false
// If we import "C", add types of cgo objects.
cfg.External = map[string]string{}
cfg1.External = cfg.External
if imports(f, "C") {
// Run cgo on gofmtFile(f)
// Parse, extract decls from _cgo_gotypes.go
// Map _Ctype_* types to C.* types.
err := func() error {
txt, err := gofmtFile(f)
if err != nil {
return err
}
dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck")
if err != nil {
return err
}
defer os.RemoveAll(dir)
err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600)
if err != nil {
return err
}
goCmd := "go"
if goroot := runtime.GOROOT(); goroot != "" {
goCmd = filepath.Join(goroot, "bin", "go")
}
cmd := exec.Command(goCmd, "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go")
if reportCgoError != nil {
// Since cgo command errors will be reported, also forward the error
// output from the command for debugging.
cmd.Stderr = os.Stderr
}
err = cmd.Run()
if err != nil {
return err
}
out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go"))
if err != nil {
return err
}
cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0)
if err != nil {
return err
}
for _, decl := range cgo.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
if strings.HasPrefix(fn.Name.Name, "_Cfunc_") {
var params, results []string
for _, p := range fn.Type.Params.List {
t := gofmt(p.Type)
t = strings.ReplaceAll(t, "_Ctype_", "C.")
params = append(params, t)
}
for _, r := range fn.Type.Results.List {
t := gofmt(r.Type)
t = strings.ReplaceAll(t, "_Ctype_", "C.")
results = append(results, t)
}
cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results)
}
}
return nil
}()
if err != nil {
if reportCgoError == nil {
fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err)
} else {
reportCgoError(err)
}
}
}
// gather function declarations
for _, decl := range f.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
typecheck1(cfg, fn.Type, typeof, assign)
t := typeof[fn.Type]
if fn.Recv != nil {
// The receiver must be a type.
rcvr := typeof[fn.Recv]
if !isType(rcvr) {
if len(fn.Recv.List) != 1 {
continue
}
rcvr = mkType(gofmt(fn.Recv.List[0].Type))
typeof[fn.Recv.List[0].Type] = rcvr
}
rcvr = getType(rcvr)
if rcvr != "" && rcvr[0] == '*' {
rcvr = rcvr[1:]
}
typeof[rcvr+"."+fn.Name.Name] = t
} else {
if isType(t) {
t = getType(t)
} else {
t = gofmt(fn.Type)
}
typeof[fn.Name] = t
// Record typeof[fn.Name.Obj] for future references to fn.Name.
typeof[fn.Name.Obj] = t
}
}
// gather struct declarations
for _, decl := range f.Decls {
d, ok := decl.(*ast.GenDecl)
if ok {
for _, s := range d.Specs {
switch s := s.(type) {
case *ast.TypeSpec:
if cfg1.Type[s.Name.Name] != nil {
break
}
if !copied {
copied = true
// Copy map lazily: it's time.
cfg1.Type = maps.Clone(cfg.Type)
if cfg1.Type == nil {
cfg1.Type = make(map[string]*Type)
}
}
t := &Type{Field: map[string]string{}}
cfg1.Type[s.Name.Name] = t
switch st := s.Type.(type) {
case *ast.StructType:
for _, f := range st.Fields.List {
for _, n := range f.Names {
t.Field[n.Name] = gofmt(f.Type)
}
}
case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
t.Def = gofmt(st)
}
}
}
}
}
typecheck1(cfg1, f, typeof, assign)
return typeof, assign
}
// reportCgoError, if non-nil, reports a non-nil error from running the "cgo"
// tool. (Set to a non-nil hook during testing if cgo is expected to work.)
var reportCgoError func(err error)
func makeExprList(a []*ast.Ident) []ast.Expr {
var b []ast.Expr
for _, x := range a {
b = append(b, x)
}
return b
}
// typecheck1 is the recursive form of typecheck.
// It is like typecheck but adds to the information in typeof
// instead of allocating a new map.
func typecheck1(cfg *TypeConfig, f any, typeof map[any]string, assign map[string][]any) {
// set sets the type of n to typ.
// If isDecl is true, n is being declared.
set := func(n ast.Expr, typ string, isDecl bool) {
if typeof[n] != "" || typ == "" {
if typeof[n] != typ {
assign[typ] = append(assign[typ], n)
}
return
}
typeof[n] = typ
// If we obtained typ from the declaration of x
// propagate the type to all the uses.
// The !isDecl case is a cheat here, but it makes
// up in some cases for not paying attention to
// struct fields. The real type checker will be
// more accurate so we won't need the cheat.
if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
typeof[id.Obj] = typ
}
}
// Type-check an assignment lhs = rhs.
// If isDecl is true, this is := so we can update
// the types of the objects that lhs refers to.
typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
if len(lhs) > 1 && len(rhs) == 1 {
if _, ok := rhs[0].(*ast.CallExpr); ok {
t := split(typeof[rhs[0]])
// Lists should have same length but may not; pair what can be paired.
for i := 0; i < len(lhs) && i < len(t); i++ {
set(lhs[i], t[i], isDecl)
}
return
}
}
if len(lhs) == 1 && len(rhs) == 2 {
// x = y, ok
rhs = rhs[:1]
} else if len(lhs) == 2 && len(rhs) == 1 {
// x, ok = y
lhs = lhs[:1]
}
// Match as much as we can.
for i := 0; i < len(lhs) && i < len(rhs); i++ {
x, y := lhs[i], rhs[i]
if typeof[y] != "" {
set(x, typeof[y], isDecl)
} else {
set(y, typeof[x], false)
}
}
}
expand := func(s string) string {
typ := cfg.Type[s]
if typ != nil && typ.Def != "" {
return typ.Def
}
return s
}
// The main type check is a recursive algorithm implemented
// by walkBeforeAfter(n, before, after).
// Most of it is bottom-up, but in a few places we need
// to know the type of the function we are checking.
// The before function records that information on
// the curfn stack.
var curfn []*ast.FuncType
before := func(n any) {
// push function type on stack
switch n := n.(type) {
case *ast.FuncDecl:
curfn = append(curfn, n.Type)
case *ast.FuncLit:
curfn = append(curfn, n.Type)
}
}
// After is the real type checker.
after := func(n any) {
if n == nil {
return
}
if false && reflect.TypeOf(n).Kind() == reflect.Pointer { // debugging trace
defer func() {
if t := typeof[n]; t != "" {
pos := fset.Position(n.(ast.Node).Pos())
fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
}
}()
}
switch n := n.(type) {
case *ast.FuncDecl, *ast.FuncLit:
// pop function type off stack
curfn = curfn[:len(curfn)-1]
case *ast.FuncType:
typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
case *ast.FieldList:
// Field list is concatenation of sub-lists.
t := ""
for _, field := range n.List {
if t != "" {
t += ", "
}
t += typeof[field]
}
typeof[n] = t
case *ast.Field:
// Field is one instance of the type per name.
all := ""
t := typeof[n.Type]
if !isType(t) {
// Create a type, because it is typically *T or *p.T
// and we might care about that type.
t = mkType(gofmt(n.Type))
typeof[n.Type] = t
}
t = getType(t)
if len(n.Names) == 0 {
all = t
} else {
for _, id := range n.Names {
if all != "" {
all += ", "
}
all += t
typeof[id.Obj] = t
typeof[id] = t
}
}
typeof[n] = all
case *ast.ValueSpec:
// var declaration. Use type if present.
if n.Type != nil {
t := typeof[n.Type]
if !isType(t) {
t = mkType(gofmt(n.Type))
typeof[n.Type] = t
}
t = getType(t)
for _, id := range n.Names {
set(id, t, true)
}
}
// Now treat same as assignment.
typecheckAssign(makeExprList(n.Names), n.Values, true)
case *ast.AssignStmt:
typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
case *ast.Ident:
// Identifier can take its type from underlying object.
if t := typeof[n.Obj]; t != "" {
typeof[n] = t
}
case *ast.SelectorExpr:
// Field or method.
name := n.Sel.Name
if t := typeof[n.X]; t != "" {
t = strings.TrimPrefix(t, "*") // implicit *
if typ := cfg.Type[t]; typ != nil {
if t := typ.dot(cfg, name); t != "" {
typeof[n] = t
return
}
}
tt := typeof[t+"."+name]
if isType(tt) {
typeof[n] = getType(tt)
return
}
}
// Package selector.
if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
str := x.Name + "." + name
if cfg.Type[str] != nil {
typeof[n] = mkType(str)
return
}
if t := cfg.typeof(x.Name + "." + name); t != "" {
typeof[n] = t
return
}
}
case *ast.CallExpr:
// make(T) has type T.
if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
typeof[n] = gofmt(n.Args[0])
return
}
// new(T) has type *T
if isTopName(n.Fun, "new") && len(n.Args) == 1 {
typeof[n] = "*" + gofmt(n.Args[0])
return
}
// Otherwise, use type of function to determine arguments.
t := typeof[n.Fun]
if t == "" {
t = cfg.External[gofmt(n.Fun)]
}
in, out := splitFunc(t)
if in == nil && out == nil {
return
}
typeof[n] = join(out)
for i, arg := range n.Args {
if i >= len(in) {
break
}
if typeof[arg] == "" {
typeof[arg] = in[i]
}
}
case *ast.TypeAssertExpr:
// x.(type) has type of x.
if n.Type == nil {
typeof[n] = typeof[n.X]
return
}
// x.(T) has type T.
if t := typeof[n.Type]; isType(t) {
typeof[n] = getType(t)
} else {
typeof[n] = gofmt(n.Type)
}
case *ast.SliceExpr:
// x[i:j] has type of x.
typeof[n] = typeof[n.X]
case *ast.IndexExpr:
// x[i] has key type of x's type.
t := expand(typeof[n.X])
if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
// Lazy: assume there are no nested [] in the array
// length or map key type.
if _, elem, ok := strings.Cut(t, "]"); ok {
typeof[n] = elem
}
}
case *ast.StarExpr:
// *x for x of type *T has type T when x is an expr.
// We don't use the result when *x is a type, but
// compute it anyway.
t := expand(typeof[n.X])
if isType(t) {
typeof[n] = "type *" + getType(t)
} else if strings.HasPrefix(t, "*") {
typeof[n] = t[len("*"):]
}
case *ast.UnaryExpr:
// &x for x of type T has type *T.
t := typeof[n.X]
if t != "" && n.Op == token.AND {
typeof[n] = "*" + t
}
case *ast.CompositeLit:
// T{...} has type T.
typeof[n] = gofmt(n.Type)
// Propagate types down to values used in the composite literal.
t := expand(typeof[n])
if strings.HasPrefix(t, "[") { // array or slice
// Lazy: assume there are no nested [] in the array length.
if _, et, ok := strings.Cut(t, "]"); ok {
for _, e := range n.Elts {
if kv, ok := e.(*ast.KeyValueExpr); ok {
e = kv.Value
}
if typeof[e] == "" {
typeof[e] = et
}
}
}
}
if strings.HasPrefix(t, "map[") { // map
// Lazy: assume there are no nested [] in the map key type.
if kt, vt, ok := strings.Cut(t[len("map["):], "]"); ok {
for _, e := range n.Elts {
if kv, ok := e.(*ast.KeyValueExpr); ok {
if typeof[kv.Key] == "" {
typeof[kv.Key] = kt
}
if typeof[kv.Value] == "" {
typeof[kv.Value] = vt
}
}
}
}
}
if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 { // struct
for _, e := range n.Elts {
if kv, ok := e.(*ast.KeyValueExpr); ok {
if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" {
if typeof[kv.Value] == "" {
typeof[kv.Value] = ft
}
}
}
}
}
case *ast.ParenExpr:
// (x) has type of x.
typeof[n] = typeof[n.X]
case *ast.RangeStmt:
t := expand(typeof[n.X])
if t == "" {
return
}
var key, value string
if t == "string" {
key, value = "int", "rune"
} else if strings.HasPrefix(t, "[") {
key = "int"
_, value, _ = strings.Cut(t, "]")
} else if strings.HasPrefix(t, "map[") {
if k, v, ok := strings.Cut(t[len("map["):], "]"); ok {
key, value = k, v
}
}
changed := false
if n.Key != nil && key != "" {
changed = true
set(n.Key, key, n.Tok == token.DEFINE)
}
if n.Value != nil && value != "" {
changed = true
set(n.Value, value, n.Tok == token.DEFINE)
}
// Ugly failure of vision: already type-checked body.
// Do it again now that we have that type info.
if changed {
typecheck1(cfg, n.Body, typeof, assign)
}
case *ast.TypeSwitchStmt:
// Type of variable changes for each case in type switch,
// but go/parser generates just one variable.
// Repeat type check for each case with more precise
// type information.
as, ok := n.Assign.(*ast.AssignStmt)
if !ok {
return
}
varx, ok := as.Lhs[0].(*ast.Ident)
if !ok {
return
}
t := typeof[varx]
for _, cas := range n.Body.List {
cas := cas.(*ast.CaseClause)
if len(cas.List) == 1 {
// Variable has specific type only when there is
// exactly one type in the case list.
if tt := typeof[cas.List[0]]; isType(tt) {
tt = getType(tt)
typeof[varx] = tt
typeof[varx.Obj] = tt
typecheck1(cfg, cas.Body, typeof, assign)
}
}
}
// Restore t.
typeof[varx] = t
typeof[varx.Obj] = t
case *ast.ReturnStmt:
if len(curfn) == 0 {
// Probably can't happen.
return
}
f := curfn[len(curfn)-1]
res := n.Results
if f.Results != nil {
t := split(typeof[f.Results])
for i := 0; i < len(res) && i < len(t); i++ {
set(res[i], t[i], false)
}
}
case *ast.BinaryExpr:
// Propagate types across binary ops that require two args of the same type.
switch n.Op {
case token.EQL, token.NEQ: // TODO: more cases. This is enough for the cftype fix.
if typeof[n.X] != "" && typeof[n.Y] == "" {
typeof[n.Y] = typeof[n.X]
}
if typeof[n.X] == "" && typeof[n.Y] != "" {
typeof[n.X] = typeof[n.Y]
}
}
}
}
walkBeforeAfter(f, before, after)
}
// Convert between function type strings and lists of types.
// Using strings makes this a little harder, but it makes
// a lot of the rest of the code easier. This will all go away
// when we can use go/typechecker directly.
// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
func splitFunc(s string) (in, out []string) {
if !strings.HasPrefix(s, "func(") {
return nil, nil
}
i := len("func(") // index of beginning of 'in' arguments
nparen := 0
for j := i; j < len(s); j++ {
switch s[j] {
case '(':
nparen++
case ')':
nparen--
if nparen < 0 {
// found end of parameter list
out := strings.TrimSpace(s[j+1:])
if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
out = out[1 : len(out)-1]
}
return split(s[i:j]), split(out)
}
}
}
return nil, nil
}
// joinFunc is the inverse of splitFunc.
func joinFunc(in, out []string) string {
outs := ""
if len(out) == 1 {
outs = " " + out[0]
} else if len(out) > 1 {
outs = " (" + join(out) + ")"
}
return "func(" + join(in) + ")" + outs
}
// split splits "int, float" into ["int", "float"] and splits "" into [].
func split(s string) []string {
out := []string{}
i := 0 // current type being scanned is s[i:j].
nparen := 0
for j := 0; j < len(s); j++ {
switch s[j] {
case ' ':
if i == j {
i++
}
case '(':
nparen++
case ')':
nparen--
if nparen < 0 {
// probably can't happen
return nil
}
case ',':
if nparen == 0 {
if i < j {
out = append(out, s[i:j])
}
i = j + 1
}
}
}
if nparen != 0 {
// probably can't happen
return nil
}
if i < len(s) {
out = append(out, s[i:])
}
return out
}
// join is the inverse of split.
func join(x []string) string {
return strings.Join(x, ", ")
}