go/*,cmd/gofmt: guard AST changes with the typeparams build tag

This CL changes our approach to guarding type parameter functionality
and API. Previously, we guarded type parameter functionality with the
parser.parseTypeParams parser mode, and were in the process of hiding
the type parameter API behind the go1.18 build constraint.

These mechanisms had several limitations:
 + Requiring the parser.parseTypeParams mode to be set meant that
   existing tooling would have to opt-in to type parameters in all
   places where it parses Go files.
 + The parseTypeParams mode value had to be copied in several places.
 + go1.18 is not specific to typeparams, making it difficult to set up
   the builders to run typeparams tests.

This CL addresses the above limitations, and completes the task of
hiding the AST API, by switching to a new 'typeparams' build constraint
and adding a new go/internal/typeparams helper package.

The typeparams build constraint is used to conditionally compile the new
AST changes. The typeparams package provides utilities for accessing and
writing the new AST data, so that we don't have to fragment our parser
or type checker logic across build constraints. The typeparams.Enabled
const is used to guard tests that require type parameter support.

The parseTypeParams parser mode is gone, replaced by a new
typeparams.DisableParsing mode with the opposite sense. Now, type
parameters are only parsed if go/parser is compiled with the typeparams
build constraint set AND typeparams.DisableParsing not set. This new
parser mode allows opting out of type parameter parsing for tests.

How exactly to run tests on builders is left to a follow-up CL.

Updates #44933

Change-Id: I3091e42a2e5e2f23e8b2ae584f415a784b9fbd65
Reviewed-on: https://go-review.googlesource.com/c/go/+/300649
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
This commit is contained in:
Rob Findley 2021-03-05 10:57:48 -05:00 committed by Robert Findley
parent 693859542e
commit efaf75a216
35 changed files with 417 additions and 213 deletions

View file

@ -33,10 +33,6 @@ var (
doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") doDiff = flag.Bool("d", false, "display diffs instead of rewriting files")
allErrors = flag.Bool("e", false, "report all errors (not just the first 10 on different lines)") allErrors = flag.Bool("e", false, "report all errors (not just the first 10 on different lines)")
// allowTypeParams controls whether type parameters are allowed in the code
// being formatted. It is enabled for go1.18 in gofmt_go1.18.go.
allowTypeParams = false
// debugging // debugging
cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file") cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file")
) )
@ -53,10 +49,6 @@ const (
printerNormalizeNumbers = 1 << 30 printerNormalizeNumbers = 1 << 30
) )
// parseTypeParams tells go/parser to parse type parameters. Must be kept in
// sync with go/parser/interface.go.
const parseTypeParams parser.Mode = 1 << 30
var ( var (
fileSet = token.NewFileSet() // per process FileSet fileSet = token.NewFileSet() // per process FileSet
exitCode = 0 exitCode = 0
@ -79,9 +71,6 @@ func initParserMode() {
if *allErrors { if *allErrors {
parserMode |= parser.AllErrors parserMode |= parser.AllErrors
} }
if allowTypeParams {
parserMode |= parseTypeParams
}
} }
func isGoFile(f fs.DirEntry) bool { func isGoFile(f fs.DirEntry) bool {

View file

@ -49,12 +49,13 @@ func gofmtFlags(filename string, maxLines int) string {
case scanner.EOF: case scanner.EOF:
return "" return ""
} }
} }
return "" return ""
} }
var typeParamsEnabled = false
func runTest(t *testing.T, in, out string) { func runTest(t *testing.T, in, out string) {
// process flags // process flags
*simplifyAST = false *simplifyAST = false
@ -78,8 +79,10 @@ func runTest(t *testing.T, in, out string) {
// fake flag - pretend input is from stdin // fake flag - pretend input is from stdin
stdin = true stdin = true
case "-G": case "-G":
// fake flag - allow parsing type parameters // fake flag - test is for generic code
allowTypeParams = true if !typeParamsEnabled {
return
}
default: default:
t.Errorf("unrecognized flag name: %s", name) t.Errorf("unrecognized flag name: %s", name)
} }

View file

@ -2,11 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.18 //go:build typeparams
// +build go1.18 // +build typeparams
package main package main
func init() { func init() {
allowTypeParams = true typeParamsEnabled = true
} }

View file

@ -374,13 +374,6 @@ type (
Rparen token.Pos // position of ")" Rparen token.Pos // position of ")"
} }
// A ListExpr node represents a list of expressions separated by commas.
// ListExpr nodes are used as index in IndexExpr nodes representing type
// or function instantiations with more than one type argument.
ListExpr struct {
ElemList []Expr
}
// A StarExpr node represents an expression of the form "*" Expression. // A StarExpr node represents an expression of the form "*" Expression.
// Semantically it could be a unary "*" expression, or a pointer type. // Semantically it could be a unary "*" expression, or a pointer type.
// //
@ -447,14 +440,6 @@ type (
// Pointer types are represented via StarExpr nodes. // Pointer types are represented via StarExpr nodes.
// A FuncType node represents a function type.
FuncType struct {
Func token.Pos // position of "func" keyword (token.NoPos if there is no "func")
TParams *FieldList // type parameters; or nil
Params *FieldList // (incoming) parameters; non-nil
Results *FieldList // (outgoing) results; or nil
}
// An InterfaceType node represents an interface type. // An InterfaceType node represents an interface type.
InterfaceType struct { InterfaceType struct {
Interface token.Pos // position of "interface" keyword Interface token.Pos // position of "interface" keyword
@ -497,18 +482,12 @@ func (x *IndexExpr) Pos() token.Pos { return x.X.Pos() }
func (x *SliceExpr) Pos() token.Pos { return x.X.Pos() } func (x *SliceExpr) Pos() token.Pos { return x.X.Pos() }
func (x *TypeAssertExpr) Pos() token.Pos { return x.X.Pos() } func (x *TypeAssertExpr) Pos() token.Pos { return x.X.Pos() }
func (x *CallExpr) Pos() token.Pos { return x.Fun.Pos() } func (x *CallExpr) Pos() token.Pos { return x.Fun.Pos() }
func (x *ListExpr) Pos() token.Pos { func (x *StarExpr) Pos() token.Pos { return x.Star }
if len(x.ElemList) > 0 { func (x *UnaryExpr) Pos() token.Pos { return x.OpPos }
return x.ElemList[0].Pos() func (x *BinaryExpr) Pos() token.Pos { return x.X.Pos() }
} func (x *KeyValueExpr) Pos() token.Pos { return x.Key.Pos() }
return token.NoPos func (x *ArrayType) Pos() token.Pos { return x.Lbrack }
} func (x *StructType) Pos() token.Pos { return x.Struct }
func (x *StarExpr) Pos() token.Pos { return x.Star }
func (x *UnaryExpr) Pos() token.Pos { return x.OpPos }
func (x *BinaryExpr) Pos() token.Pos { return x.X.Pos() }
func (x *KeyValueExpr) Pos() token.Pos { return x.Key.Pos() }
func (x *ArrayType) Pos() token.Pos { return x.Lbrack }
func (x *StructType) Pos() token.Pos { return x.Struct }
func (x *FuncType) Pos() token.Pos { func (x *FuncType) Pos() token.Pos {
if x.Func.IsValid() || x.Params == nil { // see issue 3870 if x.Func.IsValid() || x.Params == nil { // see issue 3870
return x.Func return x.Func
@ -536,18 +515,12 @@ func (x *IndexExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *SliceExpr) End() token.Pos { return x.Rbrack + 1 } func (x *SliceExpr) End() token.Pos { return x.Rbrack + 1 }
func (x *TypeAssertExpr) End() token.Pos { return x.Rparen + 1 } func (x *TypeAssertExpr) End() token.Pos { return x.Rparen + 1 }
func (x *CallExpr) End() token.Pos { return x.Rparen + 1 } func (x *CallExpr) End() token.Pos { return x.Rparen + 1 }
func (x *ListExpr) End() token.Pos { func (x *StarExpr) End() token.Pos { return x.X.End() }
if len(x.ElemList) > 0 { func (x *UnaryExpr) End() token.Pos { return x.X.End() }
return x.ElemList[len(x.ElemList)-1].End() func (x *BinaryExpr) End() token.Pos { return x.Y.End() }
} func (x *KeyValueExpr) End() token.Pos { return x.Value.End() }
return token.NoPos func (x *ArrayType) End() token.Pos { return x.Elt.End() }
} func (x *StructType) End() token.Pos { return x.Fields.End() }
func (x *StarExpr) End() token.Pos { return x.X.End() }
func (x *UnaryExpr) End() token.Pos { return x.X.End() }
func (x *BinaryExpr) End() token.Pos { return x.Y.End() }
func (x *KeyValueExpr) End() token.Pos { return x.Value.End() }
func (x *ArrayType) End() token.Pos { return x.Elt.End() }
func (x *StructType) End() token.Pos { return x.Fields.End() }
func (x *FuncType) End() token.Pos { func (x *FuncType) End() token.Pos {
if x.Results != nil { if x.Results != nil {
return x.Results.End() return x.Results.End()
@ -573,7 +546,6 @@ func (*IndexExpr) exprNode() {}
func (*SliceExpr) exprNode() {} func (*SliceExpr) exprNode() {}
func (*TypeAssertExpr) exprNode() {} func (*TypeAssertExpr) exprNode() {}
func (*CallExpr) exprNode() {} func (*CallExpr) exprNode() {}
func (*ListExpr) exprNode() {}
func (*StarExpr) exprNode() {} func (*StarExpr) exprNode() {}
func (*UnaryExpr) exprNode() {} func (*UnaryExpr) exprNode() {}
func (*BinaryExpr) exprNode() {} func (*BinaryExpr) exprNode() {}
@ -920,16 +892,6 @@ type (
Values []Expr // initial values; or nil Values []Expr // initial values; or nil
Comment *CommentGroup // line comments; or nil Comment *CommentGroup // line comments; or nil
} }
// A TypeSpec node represents a type declaration (TypeSpec production).
TypeSpec struct {
Doc *CommentGroup // associated documentation; or nil
Name *Ident // type name
TParams *FieldList // type parameters; or nil
Assign token.Pos // position of '=', if any
Type Expr // *Ident, *ParenExpr, *SelectorExpr, *StarExpr, or any of the *XxxTypes
Comment *CommentGroup // line comments; or nil
}
) )
// Pos and End implementations for spec nodes. // Pos and End implementations for spec nodes.

View file

@ -0,0 +1,28 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !typeparams
// +build !typeparams
package ast
import "go/token"
type (
// A FuncType node represents a function type.
FuncType struct {
Func token.Pos // position of "func" keyword (token.NoPos if there is no "func")
Params *FieldList // (incoming) parameters; non-nil
Results *FieldList // (outgoing) results; or nil
}
// A TypeSpec node represents a type declaration (TypeSpec production).
TypeSpec struct {
Doc *CommentGroup // associated documentation; or nil
Name *Ident // type name
Assign token.Pos // position of '=', if any
Type Expr // *Ident, *ParenExpr, *SelectorExpr, *StarExpr, or any of the *XxxTypes
Comment *CommentGroup // line comments; or nil
}
)

View file

@ -0,0 +1,51 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build typeparams
// +build typeparams
package ast
import "go/token"
type (
// A FuncType node represents a function type.
FuncType struct {
Func token.Pos // position of "func" keyword (token.NoPos if there is no "func")
TParams *FieldList // type parameters; or nil
Params *FieldList // (incoming) parameters; non-nil
Results *FieldList // (outgoing) results; or nil
}
// A TypeSpec node represents a type declaration (TypeSpec production).
TypeSpec struct {
Doc *CommentGroup // associated documentation; or nil
Name *Ident // type name
TParams *FieldList // type parameters; or nil
Assign token.Pos // position of '=', if any
Type Expr // *Ident, *ParenExpr, *SelectorExpr, *StarExpr, or any of the *XxxTypes
Comment *CommentGroup // line comments; or nil
}
// A ListExpr node represents a list of expressions separated by commas.
// ListExpr nodes are used as index in IndexExpr nodes representing type
// or function instantiations with more than one type argument.
ListExpr struct {
ElemList []Expr
}
)
func (*ListExpr) exprNode() {}
func (x *ListExpr) Pos() token.Pos {
if len(x.ElemList) > 0 {
return x.ElemList[0].Pos()
}
return token.NoPos
}
func (x *ListExpr) End() token.Pos {
if len(x.ElemList) > 0 {
return x.ElemList[len(x.ElemList)-1].End()
}
return token.NoPos
}

View file

@ -4,8 +4,6 @@
package ast package ast
import "fmt"
// A Visitor's Visit method is invoked for each node encountered by Walk. // A Visitor's Visit method is invoked for each node encountered by Walk.
// If the result visitor w is not nil, Walk visits each of the children // If the result visitor w is not nil, Walk visits each of the children
// of node with the visitor w, followed by a call of w.Visit(nil). // of node with the visitor w, followed by a call of w.Visit(nil).
@ -116,9 +114,6 @@ func Walk(v Visitor, node Node) {
Walk(v, n.X) Walk(v, n.X)
Walk(v, n.Index) Walk(v, n.Index)
case *ListExpr:
walkExprList(v, n.ElemList)
case *SliceExpr: case *SliceExpr:
Walk(v, n.X) Walk(v, n.X)
if n.Low != nil { if n.Low != nil {
@ -166,9 +161,7 @@ func Walk(v Visitor, node Node) {
Walk(v, n.Fields) Walk(v, n.Fields)
case *FuncType: case *FuncType:
if n.TParams != nil { walkFuncTypeParams(v, n)
Walk(v, n.TParams)
}
if n.Params != nil { if n.Params != nil {
Walk(v, n.Params) Walk(v, n.Params)
} }
@ -323,9 +316,7 @@ func Walk(v Visitor, node Node) {
Walk(v, n.Doc) Walk(v, n.Doc)
} }
Walk(v, n.Name) Walk(v, n.Name)
if n.TParams != nil { walkTypeSpecParams(v, n)
Walk(v, n.TParams)
}
Walk(v, n.Type) Walk(v, n.Type)
if n.Comment != nil { if n.Comment != nil {
Walk(v, n.Comment) Walk(v, n.Comment)
@ -372,7 +363,7 @@ func Walk(v Visitor, node Node) {
} }
default: default:
panic(fmt.Sprintf("ast.Walk: unexpected node type %T", n)) walkOtherNodes(v, n)
} }
v.Visit(nil) v.Visit(nil)

View file

@ -0,0 +1,17 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !typeparams
// +build !typeparams
package ast
import "fmt"
func walkFuncTypeParams(v Visitor, n *FuncType) {}
func walkTypeSpecParams(v Visitor, n *TypeSpec) {}
func walkOtherNodes(v Visitor, n Node) {
panic(fmt.Sprintf("ast.Walk: unexpected node type %T", n))
}

View file

@ -0,0 +1,30 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build typeparams
// +build typeparams
package ast
func walkFuncTypeParams(v Visitor, n *FuncType) {
if n.TParams != nil {
Walk(v, n.TParams)
}
}
func walkTypeSpecParams(v Visitor, n *TypeSpec) {
if n.TParams != nil {
Walk(v, n.TParams)
}
}
func walkOtherNodes(v Visitor, n Node) {
if e, ok := n.(*ast.ListExpr); ok {
if e != nil {
Walk(v, e)
}
} else {
panic(fmt.Sprintf("ast.Walk: unexpected node type %T", n))
}
}

View file

@ -278,6 +278,7 @@ var depsRules = `
< go/token < go/token
< go/scanner < go/scanner
< go/ast < go/ast
< go/internal/typeparams
< go/parser; < go/parser;
FMT FMT

View file

@ -0,0 +1,13 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package typeparams provides functions to work with type parameter data
// stored in the AST, while these AST changes are guarded by a build
// constraint.
package typeparams
// DisallowParsing is the numeric value of a parsing mode that disallows type
// parameters. This only matters if the typeparams experiment is active, and
// may be used for running tests that disallow generics.
const DisallowParsing = 1 << 30

View file

@ -0,0 +1,38 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !typeparams
// +build !typeparams
package typeparams
import (
"go/ast"
)
const Enabled = false
func PackExpr(list []ast.Expr) ast.Expr {
switch len(list) {
case 0:
return nil
case 1:
return list[0]
default:
// The parser should not attempt to pack multiple expressions into an
// IndexExpr if type params are disabled.
panic("multiple index expressions are unsupported without type params")
}
}
func UnpackExpr(expr ast.Expr) []ast.Expr {
return []ast.Expr{expr}
}
func Get(ast.Node) *ast.FieldList {
return nil
}
func Set(node ast.Node, params *ast.FieldList) {
}

View file

@ -0,0 +1,61 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build typeparams
// +build typeparams
package typeparams
import (
"fmt"
"go/ast"
)
const Enabled = true
func PackExpr(list []ast.Expr) ast.Expr {
switch len(list) {
case 0:
return nil
case 1:
return list[0]
default:
return &ast.ListExpr{ElemList: list}
}
}
// TODO(gri) Should find a more efficient solution that doesn't
// require introduction of a new slice for simple
// expressions.
func UnpackExpr(x ast.Expr) []ast.Expr {
if x, _ := x.(*ast.ListExpr); x != nil {
return x.ElemList
}
if x != nil {
return []ast.Expr{x}
}
return nil
}
func Get(n ast.Node) *ast.FieldList {
switch n := n.(type) {
case *ast.TypeSpec:
return n.TParams
case *ast.FuncType:
return n.TParams
default:
panic(fmt.Sprintf("node type %T has no type parameters", n))
}
}
func Set(n ast.Node, params *ast.FieldList) {
switch n := n.(type) {
case *ast.TypeSpec:
n.TParams = params
case *ast.FuncType:
n.TParams = params
default:
panic(fmt.Sprintf("node type %T has no type parameters", n))
}
}

View file

@ -23,6 +23,7 @@
package parser package parser
import ( import (
"go/internal/typeparams"
"go/scanner" "go/scanner"
"go/token" "go/token"
"os" "os"
@ -188,7 +189,11 @@ func TestErrors(t *testing.T) {
if !d.IsDir() && !strings.HasPrefix(name, ".") && (strings.HasSuffix(name, ".src") || strings.HasSuffix(name, ".go2")) { if !d.IsDir() && !strings.HasPrefix(name, ".") && (strings.HasSuffix(name, ".src") || strings.HasSuffix(name, ".go2")) {
mode := DeclarationErrors | AllErrors mode := DeclarationErrors | AllErrors
if strings.HasSuffix(name, ".go2") { if strings.HasSuffix(name, ".go2") {
mode |= parseTypeParams if !typeparams.Enabled {
continue
}
} else {
mode |= typeparams.DisallowParsing
} }
checkErrors(t, filepath.Join(testdata, name), nil, mode, true) checkErrors(t, filepath.Join(testdata, name), nil, mode, true)
} }

View file

@ -56,13 +56,6 @@ const (
DeclarationErrors // report declaration errors DeclarationErrors // report declaration errors
SpuriousErrors // same as AllErrors, for backward-compatibility SpuriousErrors // same as AllErrors, for backward-compatibility
AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines) AllErrors = SpuriousErrors // report all errors (not just the first 10 on different lines)
// parseTypeParams controls the parsing of type parameters. Must be
// kept in sync with:
// go/printer/printer_test.go
// go/types/check_test.go
// cmd/gofmt/gofmt.go
parseTypeParams = 1 << 30
) )
// ParseFile parses the source code of a single Go source file and returns // ParseFile parses the source code of a single Go source file and returns

View file

@ -19,6 +19,7 @@ package parser
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/internal/typeparams"
"go/scanner" "go/scanner"
"go/token" "go/token"
"strconv" "strconv"
@ -75,6 +76,10 @@ func (p *parser) init(fset *token.FileSet, filename string, src []byte, mode Mod
p.next() p.next()
} }
func (p *parser) parseTypeParams() bool {
return typeparams.Enabled && p.mode&typeparams.DisallowParsing == 0
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Parsing support // Parsing support
@ -494,7 +499,7 @@ func (p *parser) parseQualifiedIdent(ident *ast.Ident) ast.Expr {
} }
typ := p.parseTypeName(ident) typ := p.parseTypeName(ident)
if p.tok == token.LBRACK && p.mode&parseTypeParams != 0 { if p.tok == token.LBRACK && p.parseTypeParams() {
typ = p.parseTypeInstance(typ) typ = p.parseTypeInstance(typ)
} }
@ -553,7 +558,7 @@ func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Ex
// TODO(rfindley): consider changing parseRhsOrType so that this function variable // TODO(rfindley): consider changing parseRhsOrType so that this function variable
// is not needed. // is not needed.
argparser := p.parseRhsOrType argparser := p.parseRhsOrType
if p.mode&parseTypeParams == 0 { if !p.parseTypeParams() {
argparser = p.parseRhs argparser = p.parseRhs
} }
if p.tok != token.RBRACK { if p.tok != token.RBRACK {
@ -583,19 +588,19 @@ func (p *parser) parseArrayFieldOrTypeInstance(x *ast.Ident) (*ast.Ident, ast.Ex
// x [P]E // x [P]E
return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt} return x, &ast.ArrayType{Lbrack: lbrack, Len: args[0], Elt: elt}
} }
if p.mode&parseTypeParams == 0 { if !p.parseTypeParams() {
p.error(rbrack, "missing element type in array type expression") p.error(rbrack, "missing element type in array type expression")
return nil, &ast.BadExpr{From: args[0].Pos(), To: args[0].End()} return nil, &ast.BadExpr{From: args[0].Pos(), To: args[0].End()}
} }
} }
if p.mode&parseTypeParams == 0 { if !p.parseTypeParams() {
p.error(firstComma, "expected ']', found ','") p.error(firstComma, "expected ']', found ','")
return x, &ast.BadExpr{From: args[0].Pos(), To: args[len(args)-1].End()} return x, &ast.BadExpr{From: args[0].Pos(), To: args[len(args)-1].End()}
} }
// x[P], x[P1, P2], ... // x[P], x[P1, P2], ...
return nil, &ast.IndexExpr{X: x, Lbrack: lbrack, Index: &ast.ListExpr{ElemList: args}, Rbrack: rbrack} return nil, &ast.IndexExpr{X: x, Lbrack: lbrack, Index: typeparams.PackExpr(args), Rbrack: rbrack}
} }
func (p *parser) parseFieldDecl() *ast.Field { func (p *parser) parseFieldDecl() *ast.Field {
@ -878,7 +883,7 @@ func (p *parser) parseParameters(acceptTParams bool) (tparams, params *ast.Field
defer un(trace(p, "Parameters")) defer un(trace(p, "Parameters"))
} }
if p.mode&parseTypeParams != 0 && acceptTParams && p.tok == token.LBRACK { if p.parseTypeParams() && acceptTParams && p.tok == token.LBRACK {
opening := p.pos opening := p.pos
p.next() p.next()
// [T any](params) syntax // [T any](params) syntax
@ -951,7 +956,7 @@ func (p *parser) parseMethodSpec() *ast.Field {
x := p.parseTypeName(nil) x := p.parseTypeName(nil)
if ident, _ := x.(*ast.Ident); ident != nil { if ident, _ := x.(*ast.Ident); ident != nil {
switch { switch {
case p.tok == token.LBRACK && p.mode&parseTypeParams != 0: case p.tok == token.LBRACK && p.parseTypeParams():
// generic method or embedded instantiated type // generic method or embedded instantiated type
lbrack := p.pos lbrack := p.pos
p.next() p.next()
@ -967,7 +972,8 @@ func (p *parser) parseMethodSpec() *ast.Field {
_, params := p.parseParameters(false) _, params := p.parseParameters(false)
results := p.parseResult() results := p.parseResult()
idents = []*ast.Ident{ident} idents = []*ast.Ident{ident}
typ = &ast.FuncType{Func: token.NoPos, TParams: tparams, Params: params, Results: results} typ = &ast.FuncType{Func: token.NoPos, Params: params, Results: results}
typeparams.Set(typ, tparams)
} else { } else {
// embedded instantiated type // embedded instantiated type
// TODO(rfindley) should resolve all identifiers in x. // TODO(rfindley) should resolve all identifiers in x.
@ -984,7 +990,7 @@ func (p *parser) parseMethodSpec() *ast.Field {
p.exprLev-- p.exprLev--
} }
rbrack := p.expectClosing(token.RBRACK, "type argument list") rbrack := p.expectClosing(token.RBRACK, "type argument list")
typ = &ast.IndexExpr{X: ident, Lbrack: lbrack, Index: &ast.ListExpr{ElemList: list}, Rbrack: rbrack} typ = &ast.IndexExpr{X: ident, Lbrack: lbrack, Index: typeparams.PackExpr(list), Rbrack: rbrack}
} }
case p.tok == token.LPAREN: case p.tok == token.LPAREN:
// ordinary method // ordinary method
@ -1000,7 +1006,7 @@ func (p *parser) parseMethodSpec() *ast.Field {
} else { } else {
// embedded, possibly instantiated type // embedded, possibly instantiated type
typ = x typ = x
if p.tok == token.LBRACK && p.mode&parseTypeParams != 0 { if p.tok == token.LBRACK && p.parseTypeParams() {
// embedded instantiated interface // embedded instantiated interface
typ = p.parseTypeInstance(typ) typ = p.parseTypeInstance(typ)
} }
@ -1020,7 +1026,7 @@ func (p *parser) parseInterfaceType() *ast.InterfaceType {
pos := p.expect(token.INTERFACE) pos := p.expect(token.INTERFACE)
lbrace := p.expect(token.LBRACE) lbrace := p.expect(token.LBRACE)
var list []*ast.Field var list []*ast.Field
for p.tok == token.IDENT || p.mode&parseTypeParams != 0 && p.tok == token.TYPE { for p.tok == token.IDENT || p.parseTypeParams() && p.tok == token.TYPE {
if p.tok == token.IDENT { if p.tok == token.IDENT {
list = append(list, p.parseMethodSpec()) list = append(list, p.parseMethodSpec())
} else { } else {
@ -1108,14 +1114,14 @@ func (p *parser) parseTypeInstance(typ ast.Expr) ast.Expr {
closing := p.expectClosing(token.RBRACK, "type argument list") closing := p.expectClosing(token.RBRACK, "type argument list")
return &ast.IndexExpr{X: typ, Lbrack: opening, Index: &ast.ListExpr{ElemList: list}, Rbrack: closing} return &ast.IndexExpr{X: typ, Lbrack: opening, Index: typeparams.PackExpr(list), Rbrack: closing}
} }
func (p *parser) tryIdentOrType() ast.Expr { func (p *parser) tryIdentOrType() ast.Expr {
switch p.tok { switch p.tok {
case token.IDENT: case token.IDENT:
typ := p.parseTypeName(nil) typ := p.parseTypeName(nil)
if p.tok == token.LBRACK && p.mode&parseTypeParams != 0 { if p.tok == token.LBRACK && p.parseTypeParams() {
typ = p.parseTypeInstance(typ) typ = p.parseTypeInstance(typ)
} }
return typ return typ
@ -1360,13 +1366,13 @@ func (p *parser) parseIndexOrSliceOrInstance(x ast.Expr) ast.Expr {
return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: index[0], Rbrack: rbrack} return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: index[0], Rbrack: rbrack}
} }
if p.mode&parseTypeParams == 0 { if !p.parseTypeParams() {
p.error(firstComma, "expected ']' or ':', found ','") p.error(firstComma, "expected ']' or ':', found ','")
return &ast.BadExpr{From: args[0].Pos(), To: args[len(args)-1].End()} return &ast.BadExpr{From: args[0].Pos(), To: args[len(args)-1].End()}
} }
// instance expression // instance expression
return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: &ast.ListExpr{ElemList: args}, Rbrack: rbrack} return &ast.IndexExpr{X: x, Lbrack: lbrack, Index: typeparams.PackExpr(args), Rbrack: rbrack}
} }
func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr { func (p *parser) parseCallOrConversion(fun ast.Expr) *ast.CallExpr {
@ -2406,7 +2412,7 @@ func (p *parser) parseValueSpec(doc *ast.CommentGroup, _ token.Pos, keyword toke
func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, closeTok token.Token) { func (p *parser) parseGenericType(spec *ast.TypeSpec, openPos token.Pos, name0 *ast.Ident, closeTok token.Token) {
list := p.parseParameterList(name0, closeTok, p.parseParamDecl, true) list := p.parseParameterList(name0, closeTok, p.parseParamDecl, true)
closePos := p.expect(closeTok) closePos := p.expect(closeTok)
spec.TParams = &ast.FieldList{Opening: openPos, List: list, Closing: closePos} typeparams.Set(spec, &ast.FieldList{Opening: openPos, List: list, Closing: closePos})
// Type alias cannot have type parameters. Accept them for robustness but complain. // Type alias cannot have type parameters. Accept them for robustness but complain.
if p.tok == token.ASSIGN { if p.tok == token.ASSIGN {
p.error(p.pos, "generic type cannot be alias") p.error(p.pos, "generic type cannot be alias")
@ -2432,7 +2438,7 @@ func (p *parser) parseTypeSpec(doc *ast.CommentGroup, _ token.Pos, _ token.Token
p.exprLev++ p.exprLev++
x := p.parseExpr() x := p.parseExpr()
p.exprLev-- p.exprLev--
if name0, _ := x.(*ast.Ident); p.mode&parseTypeParams != 0 && name0 != nil && p.tok != token.RBRACK { if name0, _ := x.(*ast.Ident); p.parseTypeParams() && name0 != nil && p.tok != token.RBRACK {
// generic type [T any]; // generic type [T any];
p.parseGenericType(spec, lbrack, name0, token.RBRACK) p.parseGenericType(spec, lbrack, name0, token.RBRACK)
} else { } else {
@ -2537,12 +2543,12 @@ func (p *parser) parseFuncDecl() *ast.FuncDecl {
Name: ident, Name: ident,
Type: &ast.FuncType{ Type: &ast.FuncType{
Func: pos, Func: pos,
TParams: tparams,
Params: params, Params: params,
Results: results, Results: results,
}, },
Body: body, Body: body,
} }
typeparams.Set(decl.Type, tparams)
return decl return decl
} }

View file

@ -7,6 +7,7 @@ package parser
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/internal/typeparams"
"go/token" "go/token"
) )
@ -450,10 +451,10 @@ func (r *resolver) Visit(node ast.Node) ast.Visitor {
// at the identifier in the TypeSpec and ends at the end of the innermost // at the identifier in the TypeSpec and ends at the end of the innermost
// containing block. // containing block.
r.declare(spec, nil, r.topScope, ast.Typ, spec.Name) r.declare(spec, nil, r.topScope, ast.Typ, spec.Name)
if spec.TParams != nil { if tparams := typeparams.Get(spec); tparams != nil {
r.openScope(spec.Pos()) r.openScope(spec.Pos())
defer r.closeScope() defer r.closeScope()
r.walkFieldList(r.topScope, spec.TParams, ast.Typ) r.walkFieldList(r.topScope, tparams, ast.Typ)
} }
ast.Walk(r, spec.Type) ast.Walk(r, spec.Type)
} }
@ -476,7 +477,6 @@ func (r *resolver) Visit(node ast.Node) ast.Visitor {
} }
func (r *resolver) walkFuncType(scope *ast.Scope, typ *ast.FuncType) { func (r *resolver) walkFuncType(scope *ast.Scope, typ *ast.FuncType) {
r.walkFieldList(scope, typ.TParams, ast.Typ)
r.walkFieldList(scope, typ.Params, ast.Var) r.walkFieldList(scope, typ.Params, ast.Var)
r.walkFieldList(scope, typ.Results, ast.Var) r.walkFieldList(scope, typ.Results, ast.Var)
} }

View file

@ -7,6 +7,7 @@ package parser
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/internal/typeparams"
"go/scanner" "go/scanner"
"go/token" "go/token"
"os" "os"
@ -41,7 +42,11 @@ func TestResolution(t *testing.T) {
src := readFile(path) // panics on failure src := readFile(path) // panics on failure
var mode Mode var mode Mode
if strings.HasSuffix(path, ".go2") { if strings.HasSuffix(path, ".go2") {
mode = parseTypeParams if !typeparams.Enabled {
t.Skip("type params are not enabled")
}
} else {
mode |= typeparams.DisallowParsing
} }
file, err := ParseFile(fset, path, src, mode) file, err := ParseFile(fset, path, src, mode)
if err != nil { if err != nil {

View file

@ -6,7 +6,10 @@
package parser package parser
import "testing" import (
"go/internal/typeparams"
"testing"
)
var valids = []string{ var valids = []string{
"package p\n", "package p\n",
@ -130,19 +133,22 @@ func TestValid(t *testing.T) {
} }
}) })
t.Run("tparams", func(t *testing.T) { t.Run("tparams", func(t *testing.T) {
if !typeparams.Enabled {
t.Skip("type params are not enabled")
}
for _, src := range valids { for _, src := range valids {
checkErrors(t, src, src, DeclarationErrors|AllErrors|parseTypeParams, false) checkErrors(t, src, src, DeclarationErrors|AllErrors, false)
} }
for _, src := range validWithTParamsOnly { for _, src := range validWithTParamsOnly {
checkErrors(t, src, src, DeclarationErrors|AllErrors|parseTypeParams, false) checkErrors(t, src, src, DeclarationErrors|AllErrors, false)
} }
}) })
} }
// TestSingle is useful to track down a problem with a single short test program. // TestSingle is useful to track down a problem with a single short test program.
func TestSingle(t *testing.T) { func TestSingle(t *testing.T) {
const src = `package p; var _ = T[P]{}` const src = `package p; var _ = T{}`
checkErrors(t, src, src, DeclarationErrors|AllErrors|parseTypeParams, true) checkErrors(t, src, src, DeclarationErrors|AllErrors, true)
} }
var invalids = []string{ var invalids = []string{
@ -250,21 +256,24 @@ var invalidTParamErrs = []string{
func TestInvalid(t *testing.T) { func TestInvalid(t *testing.T) {
t.Run("no tparams", func(t *testing.T) { t.Run("no tparams", func(t *testing.T) {
for _, src := range invalids { for _, src := range invalids {
checkErrors(t, src, src, DeclarationErrors|AllErrors, true) checkErrors(t, src, src, DeclarationErrors|AllErrors|typeparams.DisallowParsing, true)
} }
for _, src := range validWithTParamsOnly { for _, src := range validWithTParamsOnly {
checkErrors(t, src, src, DeclarationErrors|AllErrors, true) checkErrors(t, src, src, DeclarationErrors|AllErrors|typeparams.DisallowParsing, true)
} }
for _, src := range invalidNoTParamErrs { for _, src := range invalidNoTParamErrs {
checkErrors(t, src, src, DeclarationErrors|AllErrors, true) checkErrors(t, src, src, DeclarationErrors|AllErrors|typeparams.DisallowParsing, true)
} }
}) })
t.Run("tparams", func(t *testing.T) { t.Run("tparams", func(t *testing.T) {
if !typeparams.Enabled {
t.Skip("type params are not enabled")
}
for _, src := range invalids { for _, src := range invalids {
checkErrors(t, src, src, DeclarationErrors|AllErrors|parseTypeParams, true) checkErrors(t, src, src, DeclarationErrors|AllErrors, true)
} }
for _, src := range invalidTParamErrs { for _, src := range invalidTParamErrs {
checkErrors(t, src, src, DeclarationErrors|AllErrors|parseTypeParams, true) checkErrors(t, src, src, DeclarationErrors|AllErrors, true)
} }
}) })
} }

View file

@ -11,6 +11,7 @@ package printer
import ( import (
"bytes" "bytes"
"go/ast" "go/ast"
"go/internal/typeparams"
"go/token" "go/token"
"math" "math"
"strconv" "strconv"
@ -382,8 +383,8 @@ func (p *printer) parameters(fields *ast.FieldList, isTypeParam bool) {
} }
func (p *printer) signature(sig *ast.FuncType) { func (p *printer) signature(sig *ast.FuncType) {
if sig.TParams != nil { if tparams := typeparams.Get(sig); tparams != nil {
p.parameters(sig.TParams, true) p.parameters(tparams, true)
} }
if sig.Params != nil { if sig.Params != nil {
p.parameters(sig.Params, false) p.parameters(sig.Params, false)
@ -870,8 +871,14 @@ func (p *printer) expr1(expr ast.Expr, prec1, depth int) {
// TODO(gri): should treat[] like parentheses and undo one level of depth // TODO(gri): should treat[] like parentheses and undo one level of depth
p.expr1(x.X, token.HighestPrec, 1) p.expr1(x.X, token.HighestPrec, 1)
p.print(x.Lbrack, token.LBRACK) p.print(x.Lbrack, token.LBRACK)
if e, _ := x.Index.(*ast.ListExpr); e != nil { // Note: we're a bit defensive here to handle the case of a ListExpr of
p.exprList(x.Lbrack, e.ElemList, depth+1, commaTerm, x.Rbrack, false) // length 1.
if list := typeparams.UnpackExpr(x.Index); len(list) > 0 {
if len(list) > 1 {
p.exprList(x.Lbrack, list, depth+1, commaTerm, x.Rbrack, false)
} else {
p.expr0(list[0], depth+1)
}
} else { } else {
p.expr0(x.Index, depth+1) p.expr0(x.Index, depth+1)
} }
@ -1628,8 +1635,8 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool) {
case *ast.TypeSpec: case *ast.TypeSpec:
p.setComment(s.Doc) p.setComment(s.Doc)
p.expr(s.Name) p.expr(s.Name)
if s.TParams != nil { if tparams := typeparams.Get(s); tparams != nil {
p.parameters(s.TParams, true) p.parameters(tparams, true)
} }
if n == 1 { if n == 1 {
p.print(blank) p.print(blank)

View file

@ -10,6 +10,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"go/ast" "go/ast"
"go/internal/typeparams"
"go/parser" "go/parser"
"go/token" "go/token"
"io" "io"
@ -19,10 +20,6 @@ import (
"time" "time"
) )
// parseTypeParams tells go/parser to parse type parameters. Must be kept in
// sync with go/parser/interface.go.
const parseTypeParams parser.Mode = 1 << 30
const ( const (
dataDir = "testdata" dataDir = "testdata"
tabwidth = 8 tabwidth = 8
@ -47,11 +44,7 @@ const (
// if any. // if any.
func format(src []byte, mode checkMode) ([]byte, error) { func format(src []byte, mode checkMode) ([]byte, error) {
// parse src // parse src
parseMode := parser.ParseComments f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
if mode&allowTypeParams != 0 {
parseMode |= parseTypeParams
}
f, err := parser.ParseFile(fset, "", src, parseMode)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse: %s\n%s", err, src) return nil, fmt.Errorf("parse: %s\n%s", err, src)
} }
@ -79,7 +72,7 @@ func format(src []byte, mode checkMode) ([]byte, error) {
// make sure formatted output is syntactically correct // make sure formatted output is syntactically correct
res := buf.Bytes() res := buf.Bytes()
if _, err := parser.ParseFile(fset, "", res, parseTypeParams); err != nil { if _, err := parser.ParseFile(fset, "", res, parser.ParseComments); err != nil {
return nil, fmt.Errorf("re-parse: %s\n%s", err, buf.Bytes()) return nil, fmt.Errorf("re-parse: %s\n%s", err, buf.Bytes())
} }
@ -210,7 +203,7 @@ var data = []entry{
{"linebreaks.input", "linebreaks.golden", idempotent}, {"linebreaks.input", "linebreaks.golden", idempotent},
{"expressions.input", "expressions.golden", idempotent}, {"expressions.input", "expressions.golden", idempotent},
{"expressions.input", "expressions.raw", rawFormat | idempotent}, {"expressions.input", "expressions.raw", rawFormat | idempotent},
{"declarations.input", "declarations.golden", allowTypeParams}, {"declarations.input", "declarations.golden", 0},
{"statements.input", "statements.golden", 0}, {"statements.input", "statements.golden", 0},
{"slow.input", "slow.golden", idempotent}, {"slow.input", "slow.golden", idempotent},
{"complit.input", "complit.x", export}, {"complit.input", "complit.x", export},
@ -229,6 +222,9 @@ var data = []entry{
func TestFiles(t *testing.T) { func TestFiles(t *testing.T) {
t.Parallel() t.Parallel()
for _, e := range data { for _, e := range data {
if !typeparams.Enabled && e.mode&allowTypeParams != 0 {
continue
}
source := filepath.Join(dataDir, e.source) source := filepath.Join(dataDir, e.source)
golden := filepath.Join(dataDir, e.golden) golden := filepath.Join(dataDir, e.golden)
mode := e.mode mode := e.mode

View file

@ -942,13 +942,6 @@ type _ interface {
x ...int) x ...int)
} }
// properly format one-line type lists
type _ interface{ type a }
type _ interface {
type a, b, c
}
// omit superfluous parentheses in parameter lists // omit superfluous parentheses in parameter lists
func _(int) func _(int)
func _(int) func _(int)
@ -999,10 +992,6 @@ func _(struct {
y int y int
}) // no extra comma between } and ) }) // no extra comma between } and )
// type parameters
func _[A, B any](a A, b B) int {}
func _[T any](x, y T) T
// alias declarations // alias declarations
type c0 struct{} type c0 struct{}

View file

@ -955,11 +955,6 @@ r string,
x ...int) x ...int)
} }
// properly format one-line type lists
type _ interface { type a }
type _ interface { type a,b,c }
// omit superfluous parentheses in parameter lists // omit superfluous parentheses in parameter lists
func _((int)) func _((int))
func _((((((int)))))) func _((((((int))))))
@ -1010,10 +1005,6 @@ func _(struct {
y int y int
}) // no extra comma between } and ) }) // no extra comma between } and )
// type parameters
func _[A, B any](a A, b B) int {}
func _[T any](x, y T) T
// alias declarations // alias declarations
type c0 struct{} type c0 struct{}

View file

@ -4,6 +4,9 @@
package generics package generics
func _[A, B any](a A, b B) int {}
func _[T any](x, y T) T
type T[P any] struct{} type T[P any] struct{}
type T[P1, P2, P3 any] struct{} type T[P1, P2, P3 any] struct{}
@ -31,3 +34,10 @@ func _() {
var _ []T[P] var _ []T[P]
_ = []T[P]{} _ = []T[P]{}
} }
// properly format one-line type lists
type _ interface{ type a }
type _ interface {
type a, b, c
}

View file

@ -4,6 +4,9 @@
package generics package generics
func _[A, B any](a A, b B) int {}
func _[T any](x, y T) T
type T[P any] struct{} type T[P any] struct{}
type T[P1, P2, P3 any] struct{} type T[P1, P2, P3 any] struct{}
@ -28,3 +31,8 @@ func _() {
var _ []T[P] var _ []T[P]
_ = []T[P]{} _ = []T[P]{}
} }
// properly format one-line type lists
type _ interface { type a }
type _ interface { type a,b,c }

View file

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/importer" "go/importer"
"go/internal/typeparams"
"go/parser" "go/parser"
"go/token" "go/token"
"internal/testenv" "internal/testenv"
@ -48,8 +49,8 @@ func mustTypecheck(t *testing.T, path, source string, info *Info) string {
const genericPkg = "package generic_" const genericPkg = "package generic_"
func modeForSource(src string) parser.Mode { func modeForSource(src string) parser.Mode {
if strings.HasPrefix(src, genericPkg) { if !strings.HasPrefix(src, genericPkg) {
return parseTypeParams return typeparams.DisallowParsing
} }
return 0 return 0
} }
@ -347,6 +348,9 @@ func TestTypesInfo(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
if strings.HasPrefix(test.src, genericPkg) && !typeparams.Enabled {
continue
}
info := Info{Types: make(map[ast.Expr]TypeAndValue)} info := Info{Types: make(map[ast.Expr]TypeAndValue)}
var name string var name string
if strings.HasPrefix(test.src, broken) { if strings.HasPrefix(test.src, broken) {
@ -401,6 +405,9 @@ func TestDefsInfo(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
if strings.HasPrefix(test.src, genericPkg) && !typeparams.Enabled {
continue
}
info := Info{ info := Info{
Defs: make(map[*ast.Ident]Object), Defs: make(map[*ast.Ident]Object),
} }
@ -446,6 +453,9 @@ func TestUsesInfo(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
if strings.HasPrefix(test.src, genericPkg) && !typeparams.Enabled {
continue
}
info := Info{ info := Info{
Uses: make(map[*ast.Ident]Object), Uses: make(map[*ast.Ident]Object),
} }

View file

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.18 //go:build typeparams
// +build go1.18 // +build typeparams
package types package types

View file

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.18 //go:build typeparams
// +build go1.18 // +build typeparams
package types_test package types_test

View file

@ -303,20 +303,6 @@ func (check *Checker) assignVars(lhs, origRHS []ast.Expr) {
} }
} }
// unpack unpacks an *ast.ListExpr into a list of ast.Expr.
// TODO(gri) Should find a more efficient solution that doesn't
// require introduction of a new slice for simple
// expressions.
func unpackExpr(x ast.Expr) []ast.Expr {
if x, _ := x.(*ast.ListExpr); x != nil {
return x.ElemList
}
if x != nil {
return []ast.Expr{x}
}
return nil
}
func (check *Checker) shortVarDecl(pos positioner, lhs, rhs []ast.Expr) { func (check *Checker) shortVarDecl(pos positioner, lhs, rhs []ast.Expr) {
top := len(check.delayed) top := len(check.delayed)
scope := check.scope scope := check.scope

View file

@ -8,6 +8,7 @@ package types
import ( import (
"go/ast" "go/ast"
"go/internal/typeparams"
"go/token" "go/token"
"strings" "strings"
"unicode" "unicode"
@ -16,7 +17,8 @@ import (
// funcInst type-checks a function instantiaton inst and returns the result in x. // funcInst type-checks a function instantiaton inst and returns the result in x.
// The operand x must be the evaluation of inst.X and its type must be a signature. // The operand x must be the evaluation of inst.X and its type must be a signature.
func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) { func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) {
args, ok := check.exprOrTypeList(unpackExpr(inst.Index)) exprs := typeparams.UnpackExpr(inst.Index)
args, ok := check.exprOrTypeList(exprs)
if !ok { if !ok {
x.mode = invalid x.mode = invalid
x.expr = inst x.expr = inst

View file

@ -30,6 +30,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/importer" "go/importer"
"go/internal/typeparams"
"go/parser" "go/parser"
"go/scanner" "go/scanner"
"go/token" "go/token"
@ -43,10 +44,6 @@ import (
. "go/types" . "go/types"
) )
// parseTypeParams tells go/parser to parse type parameters. Must be kept in
// sync with go/parser/interface.go.
const parseTypeParams parser.Mode = 1 << 30
var ( var (
haltOnError = flag.Bool("halt", false, "halt on error") haltOnError = flag.Bool("halt", false, "halt on error")
listErrors = flag.Bool("errlist", false, "list errors") listErrors = flag.Bool("errlist", false, "list errors")
@ -213,7 +210,11 @@ func checkFiles(t *testing.T, goVersion string, filenames []string, srcs [][]byt
mode := parser.AllErrors mode := parser.AllErrors
if strings.HasSuffix(filenames[0], ".go2") { if strings.HasSuffix(filenames[0], ".go2") {
mode |= parseTypeParams if !typeparams.Enabled {
t.Skip("type params are not enabled")
}
} else {
mode |= typeparams.DisallowParsing
} }
// parse files and collect parser errors // parse files and collect parser errors

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/constant" "go/constant"
"go/internal/typeparams"
"go/token" "go/token"
) )
@ -645,7 +646,7 @@ func (check *Checker) typeDecl(obj *TypeName, tdecl *ast.TypeSpec, def *Named) {
}) })
alias := tdecl.Assign.IsValid() alias := tdecl.Assign.IsValid()
if alias && tdecl.TParams != nil { if alias && typeparams.Get(tdecl) != nil {
// The parser will ensure this but we may still get an invalid AST. // The parser will ensure this but we may still get an invalid AST.
// Complain and continue as regular type definition. // Complain and continue as regular type definition.
check.error(atPos(tdecl.Assign), 0, "generic type cannot be alias") check.error(atPos(tdecl.Assign), 0, "generic type cannot be alias")
@ -668,10 +669,10 @@ func (check *Checker) typeDecl(obj *TypeName, tdecl *ast.TypeSpec, def *Named) {
def.setUnderlying(named) def.setUnderlying(named)
obj.typ = named // make sure recursive type declarations terminate obj.typ = named // make sure recursive type declarations terminate
if tdecl.TParams != nil { if tparams := typeparams.Get(tdecl); tparams != nil {
check.openScope(tdecl, "type parameters") check.openScope(tdecl, "type parameters")
defer check.closeScope() defer check.closeScope()
named.tparams = check.collectTypeParams(tdecl.TParams) named.tparams = check.collectTypeParams(tparams)
} }
// determine underlying type of named // determine underlying type of named

View file

@ -10,6 +10,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"go/ast" "go/ast"
"go/internal/typeparams"
) )
// ExprString returns the (possibly shortened) string representation for x. // ExprString returns the (possibly shortened) string representation for x.
@ -69,16 +70,14 @@ func WriteExpr(buf *bytes.Buffer, x ast.Expr) {
case *ast.IndexExpr: case *ast.IndexExpr:
WriteExpr(buf, x.X) WriteExpr(buf, x.X)
buf.WriteByte('[') buf.WriteByte('[')
WriteExpr(buf, x.Index) exprs := typeparams.UnpackExpr(x.Index)
buf.WriteByte(']') for i, e := range exprs {
case *ast.ListExpr:
for i, e := range x.ElemList {
if i > 0 { if i > 0 {
buf.WriteString(", ") buf.WriteString(", ")
} }
WriteExpr(buf, e) WriteExpr(buf, e)
} }
buf.WriteByte(']')
case *ast.SliceExpr: case *ast.SliceExpr:
WriteExpr(buf, x.X) WriteExpr(buf, x.X)

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/constant" "go/constant"
"go/internal/typeparams"
"go/token" "go/token"
"sort" "sort"
"strconv" "strconv"
@ -389,8 +390,8 @@ func (check *Checker) collectObjects() {
if name == "main" { if name == "main" {
code = _InvalidMainDecl code = _InvalidMainDecl
} }
if d.decl.Type.TParams != nil { if tparams := typeparams.Get(d.decl.Type); tparams != nil {
check.softErrorf(d.decl.Type.TParams, code, "func %s must have no type parameters", name) check.softErrorf(tparams, code, "func %s must have no type parameters", name)
} }
if t := d.decl.Type; t.Params.NumFields() != 0 || t.Results != nil { if t := d.decl.Type; t.Params.NumFields() != 0 || t.Results != nil {
// TODO(rFindley) Should this be a hard error? // TODO(rFindley) Should this be a hard error?
@ -497,7 +498,7 @@ L: // unpack receiver type
if ptyp, _ := rtyp.(*ast.IndexExpr); ptyp != nil { if ptyp, _ := rtyp.(*ast.IndexExpr); ptyp != nil {
rtyp = ptyp.X rtyp = ptyp.X
if unpackParams { if unpackParams {
for _, arg := range unpackExpr(ptyp.Index) { for _, arg := range typeparams.UnpackExpr(ptyp.Index) {
var par *ast.Ident var par *ast.Ident
switch arg := arg.(type) { switch arg := arg.(type) {
case *ast.Ident: case *ast.Ident:

View file

@ -10,6 +10,7 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/constant" "go/constant"
"go/internal/typeparams"
"go/token" "go/token"
"sort" "sort"
"strconv" "strconv"
@ -209,27 +210,22 @@ func isubst(x ast.Expr, smap map[*ast.Ident]*ast.Ident) ast.Expr {
return &new return &new
} }
case *ast.IndexExpr: case *ast.IndexExpr:
index := isubst(n.Index, smap) elems := typeparams.UnpackExpr(n.Index)
if index != n.Index { var newElems []ast.Expr
new := *n for i, elem := range elems {
new.Index = index
return &new
}
case *ast.ListExpr:
var elems []ast.Expr
for i, elem := range n.ElemList {
new := isubst(elem, smap) new := isubst(elem, smap)
if new != elem { if new != elem {
if elems == nil { if newElems == nil {
elems = make([]ast.Expr, len(n.ElemList)) newElems = make([]ast.Expr, len(elems))
copy(elems, n.ElemList) copy(newElems, elems)
} }
elems[i] = new newElems[i] = new
} }
} }
if elems != nil { if newElems != nil {
index := typeparams.PackExpr(newElems)
new := *n new := *n
new.ElemList = elems new.Index = index
return &new return &new
} }
case *ast.ParenExpr: case *ast.ParenExpr:
@ -316,13 +312,13 @@ func (check *Checker) funcType(sig *Signature, recvPar *ast.FieldList, ftyp *ast
} }
} }
if ftyp.TParams != nil { if tparams := typeparams.Get(ftyp); tparams != nil {
sig.tparams = check.collectTypeParams(ftyp.TParams) sig.tparams = check.collectTypeParams(tparams)
// Always type-check method type parameters but complain that they are not allowed. // Always type-check method type parameters but complain that they are not allowed.
// (A separate check is needed when type-checking interface method signatures because // (A separate check is needed when type-checking interface method signatures because
// they don't have a receiver specification.) // they don't have a receiver specification.)
if recvPar != nil { if recvPar != nil {
check.errorf(ftyp.TParams, _Todo, "methods cannot have type parameters") check.errorf(tparams, _Todo, "methods cannot have type parameters")
} }
} }
@ -467,7 +463,8 @@ func (check *Checker) typInternal(e0 ast.Expr, def *Named) (T Type) {
} }
case *ast.IndexExpr: case *ast.IndexExpr:
return check.instantiatedType(e.X, unpackExpr(e.Index), def) exprs := typeparams.UnpackExpr(e.Index)
return check.instantiatedType(e.X, exprs, def)
case *ast.ParenExpr: case *ast.ParenExpr:
// Generic types must be instantiated before they can be used in any form. // Generic types must be instantiated before they can be used in any form.
@ -801,7 +798,11 @@ func (check *Checker) interfaceType(ityp *Interface, iface *ast.InterfaceType, d
// (This extra check is needed here because interface method signatures don't have // (This extra check is needed here because interface method signatures don't have
// a receiver specification.) // a receiver specification.)
if sig.tparams != nil { if sig.tparams != nil {
check.errorf(f.Type.(*ast.FuncType).TParams, _Todo, "methods cannot have type parameters") var at positioner = f.Type
if tparams := typeparams.Get(f.Type); tparams != nil {
at = tparams
}
check.errorf(at, _Todo, "methods cannot have type parameters")
} }
// use named receiver type if available (for better error messages) // use named receiver type if available (for better error messages)