go/types: combine all type inference in a single function

This is a port of CL 306170 to go/types, adjusted for the different
positioning API.

Some of the error positions in tests had to be adjusted, but I think the
new locations are better.

Change-Id: Ib157fbb47d7483e3c6302bd57f5070bd74602a36
Reviewed-on: https://go-review.googlesource.com/c/go/+/312191
Trust: Robert Findley <rfindley@google.com>
Run-TryBot: Robert Findley <rfindley@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
This commit is contained in:
Rob Findley 2021-04-20 22:59:59 -04:00 committed by Robert Findley
parent 6639bb894d
commit 7bedd47798
6 changed files with 176 additions and 99 deletions

View file

@ -13,4 +13,4 @@ type S struct {}
func NewS[T any]() *S func NewS[T any]() *S
func (_ *S /* ERROR S is not a generic type */ [T]) M() func (_ *S /* ERROR S is not a generic type */ [T]) M()

View file

@ -26,54 +26,39 @@ func (check *Checker) funcInst(x *operand, inst *ast.IndexExpr) {
} }
assert(len(targs) == len(xlist)) assert(len(targs) == len(xlist))
// check number of type arguments // check number of type arguments (got) vs number of type parameters (want)
n := len(targs)
sig := x.typ.(*Signature) sig := x.typ.(*Signature)
if n > len(sig.tparams) { got, want := len(targs), len(sig.tparams)
check.errorf(xlist[n-1], _Todo, "got %d type arguments but want %d", n, len(sig.tparams)) if got > want {
check.errorf(xlist[got-1], _Todo, "got %d type arguments but want %d", got, want)
x.mode = invalid x.mode = invalid
x.expr = inst x.expr = inst
return return
} }
// determine argument positions (for error reporting) // if we don't have enough type arguments, try type inference
// TODO(rFindley) use a positioner here? instantiate would need to be inferred := false
// updated accordingly.
poslist := make([]token.Pos, n)
for i, x := range xlist {
poslist[i] = x.Pos()
}
// if we don't have enough type arguments, use constraint type inference if got < want {
var inferred bool targs = check.infer(inst, sig.tparams, targs, nil, nil, true)
if n < len(sig.tparams) {
var failed int
targs, failed = check.inferB(sig.tparams, targs)
if targs == nil { if targs == nil {
// error was already reported // error was already reported
x.mode = invalid x.mode = invalid
x.expr = inst x.expr = inst
return return
} }
if failed >= 0 { got = len(targs)
// at least one type argument couldn't be inferred
assert(targs[failed] == nil)
tpar := sig.tparams[failed]
check.errorf(inNode(inst, inst.Rbrack), 0, "cannot infer %s (%v) (%s)", tpar.name, tpar.pos, targs)
x.mode = invalid
x.expr = inst
return
}
// all type arguments were inferred successfully
if debug {
for _, targ := range targs {
assert(targ != nil)
}
}
n = len(targs)
inferred = true inferred = true
} }
assert(n == len(sig.tparams)) assert(got == want)
// determine argument positions (for error reporting)
// TODO(rFindley) use a positioner here? instantiate would need to be
// updated accordingly.
poslist := make([]token.Pos, len(xlist))
for i, x := range xlist {
poslist[i] = x.Pos()
}
// instantiate function signature // instantiate function signature
res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature) res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature)
@ -307,32 +292,10 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, args []*oper
if len(sig.tparams) > 0 { if len(sig.tparams) > 0 {
// TODO(gri) provide position information for targs so we can feed // TODO(gri) provide position information for targs so we can feed
// it to the instantiate call for better error reporting // it to the instantiate call for better error reporting
targs, failed := check.infer(sig.tparams, sigParams, args) targs := check.infer(call, sig.tparams, nil, sigParams, args, true)
if targs == nil { if targs == nil {
return // error already reported return // error already reported
} }
if failed >= 0 {
// Some type arguments couldn't be inferred. Use
// bounds type inference to try to make progress.
targs, failed = check.inferB(sig.tparams, targs)
if targs == nil {
return // error already reported
}
if failed >= 0 {
// at least one type argument couldn't be inferred
assert(targs[failed] == nil)
tpar := sig.tparams[failed]
ppos := check.fset.Position(tpar.pos).String()
check.errorf(inNode(call, call.Rparen), _Todo, "cannot infer %s (%s) (%s)", tpar.name, ppos, targs)
return
}
}
// all type arguments were inferred successfully
if debug {
for _, targ := range targs {
assert(targ != nil)
}
}
// compute result signature // compute result signature
rsig = check.instantiate(call.Pos(), sig, targs, nil).(*Signature) rsig = check.instantiate(call.Pos(), sig, targs, nil).(*Signature)
@ -544,13 +507,13 @@ func (check *Checker) selector(x *operand, e *ast.SelectorExpr) {
recv = recv.(*Pointer).base recv = recv.(*Pointer).base
} }
} }
// Disable reporting of errors during inference below. If we're unable to infer
// the receiver type arguments here, the receiver must be be otherwise invalid
// and an error has been reported elsewhere.
arg := operand{mode: variable, expr: x.expr, typ: recv} arg := operand{mode: variable, expr: x.expr, typ: recv}
targs, failed := check.infer(sig.rparams, NewTuple(sig.recv), []*operand{&arg}) targs := check.infer(m, sig.rparams, nil, NewTuple(sig.recv), []*operand{&arg}, false /* no error reporting */)
if failed >= 0 { if targs == nil {
// We may reach here if there were other errors (see issue #40056). // We may reach here if there were other errors (see issue #40056).
// check.infer will report a follow-up error.
// TODO(gri) avoid the follow-up error as it is confusing
// (there's no inference in the source code)
goto Error goto Error
} }
// Don't modify m. Instead - for now - make a copy of m and use that instead. // Don't modify m. Instead - for now - make a copy of m and use that instead.

View file

@ -89,4 +89,4 @@ func F26[Z any]() T26 { return F26[] /* ERROR operand */ }
// crash 27 // crash 27
func e27[T any]() interface{ x27 /* ERROR not a type */ } func e27[T any]() interface{ x27 /* ERROR not a type */ }
func x27() { e27() /* ERROR cannot infer T */ } func x27() { e27 /* ERROR cannot infer T */ () }

View file

@ -5,11 +5,11 @@
package p package p
func _() { func _() {
NewS() /* ERROR cannot infer T */ .M() NewS /* ERROR cannot infer T */ ().M()
} }
type S struct {} type S struct {}
func NewS[T any]() *S func NewS[T any]() *S
func (_ *S /* ERROR S is not a generic type */ [T]) M() func (_ *S /* ERROR S is not a generic type */ [T]) M()

View file

@ -12,22 +12,104 @@ import (
"strings" "strings"
) )
// infer returns the list of actual type arguments for the given list of type parameters tparams // infer attempts to infer the complete set of type arguments for generic function instantiation/call
// by inferring them from the actual arguments args for the parameters params. If type inference // based on the given type parameters tparams, type arguments targs, function parameters params, and
// is impossible because unification fails, an error is reported and the resulting types list is // function arguments args, if any. There must be at least one type parameter, no more type arguments
// nil, and index is 0. Otherwise, types is the list of inferred type arguments, and index is // than type parameters, and params and args must match in number (incl. zero).
// the index of the first type argument in that list that couldn't be inferred (and thus is nil). // If successful, infer returns the complete list of type arguments, one for each type parameter.
// If all type arguments were inferred successfully, index is < 0. // Otherwise the result is nil and appropriate errors will be reported unless report is set to false.
func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand) (types []Type, index int) { //
// Inference proceeds in 3 steps:
//
// 1) Start with given type arguments.
// 2) Infer type arguments from typed function arguments.
// 3) Infer type arguments from untyped function arguments.
//
// Constraint type inference is used after each step to expand the set of type arguments.
//
func (check *Checker) infer(posn positioner, tparams []*TypeName, targs []Type, params *Tuple, args []*operand, report bool) (result []Type) {
if debug {
defer func() {
assert(result == nil || len(result) == len(tparams))
for _, targ := range result {
assert(targ != nil)
}
//check.dump("### inferred targs = %s", result)
}()
}
// There must be at least one type parameter, and no more type arguments than type parameters.
n := len(tparams)
assert(n > 0 && len(targs) <= n)
// Function parameters and arguments must match in number.
assert(params.Len() == len(args)) assert(params.Len() == len(args))
// --- 0 ---
// If we already have all type arguments, we're done.
if len(targs) == n {
return targs
}
// len(targs) < n
// --- 1 ---
// Explicitly provided type arguments take precedence over any inferred types;
// and types inferred via constraint type inference take precedence over types
// inferred from function arguments.
// If we have type arguments, see how far we get with constraint type inference.
if len(targs) > 0 {
var index int
targs, index = check.inferB(tparams, targs, report)
if targs == nil || index < 0 {
return targs
}
}
// Continue with the type arguments we have now. Avoid matching generic
// parameters that already have type arguments against function arguments:
// It may fail because matching uses type identity while parameter passing
// uses assignment rules. Instantiate the parameter list with the type
// arguments we have, and continue with that parameter list.
// First, make sure we have a "full" list of type arguments, so of which
// may be nil (unknown).
if len(targs) < n {
targs2 := make([]Type, n)
copy(targs2, targs)
targs = targs2
}
// len(targs) == n
// Substitute type arguments for their respective type parameters in params,
// if any. Note that nil targs entries are ignored by check.subst.
// TODO(gri) Can we avoid this (we're setting known type argumemts below,
// but that doesn't impact the isParameterized check for now).
if params.Len() > 0 {
smap := makeSubstMap(tparams, targs)
params = check.subst(token.NoPos, params, smap).(*Tuple)
}
// --- 2 ---
// Unify parameter and argument types for generic parameters with typed arguments
// and collect the indices of generic parameters with untyped arguments.
// Terminology: generic parameter = function parameter with a type-parameterized type
u := newUnifier(check, false) u := newUnifier(check, false)
u.x.init(tparams) u.x.init(tparams)
// Set the type arguments which we know already.
for i, targ := range targs {
if targ != nil {
u.x.set(i, targ)
}
}
errorf := func(kind string, tpar, targ Type, arg *operand) { errorf := func(kind string, tpar, targ Type, arg *operand) {
if !report {
return
}
// provide a better error message if we can // provide a better error message if we can
targs, failed := u.x.types() targs, index := u.x.types()
if failed == 0 { if index == 0 {
// The first type parameter couldn't be inferred. // The first type parameter couldn't be inferred.
// If none of them could be inferred, don't try // If none of them could be inferred, don't try
// to provide the inferred type in the error msg. // to provide the inferred type in the error msg.
@ -53,16 +135,13 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
} }
} }
// Terminology: generic parameter = function parameter with a type-parameterized type // indices of the generic parameters with untyped arguments - save for later
// 1st pass: Unify parameter and argument types for generic parameters with typed arguments
// and collect the indices of generic parameters with untyped arguments.
var indices []int var indices []int
for i, arg := range args { for i, arg := range args {
par := params.At(i) par := params.At(i)
// If we permit bidirectional unification, this conditional code needs to be // If we permit bidirectional unification, this conditional code needs to be
// executed even if par.typ is not parameterized since the argument may be a // executed even if par.typ is not parameterized since the argument may be a
// generic function (for which we want to infer // its type arguments). // generic function (for which we want to infer its type arguments).
if isParameterized(tparams, par.typ) { if isParameterized(tparams, par.typ) {
if arg.mode == invalid { if arg.mode == invalid {
// An error was reported earlier. Ignore this targ // An error was reported earlier. Ignore this targ
@ -76,7 +155,7 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
// the respective type parameters of targ. // the respective type parameters of targ.
if !u.unify(par.typ, targ) { if !u.unify(par.typ, targ) {
errorf("type", par.typ, targ, arg) errorf("type", par.typ, targ, arg)
return nil, 0 return nil
} }
} else { } else {
indices = append(indices, i) indices = append(indices, i)
@ -84,12 +163,25 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
} }
} }
// Some generic parameters with untyped arguments may have been given a type // If we've got all type arguments, we're done.
// indirectly through another generic parameter with a typed argument; we can var index int
// ignore those now. (This only means that we know the types for those generic targs, index = u.x.types()
// parameters; it doesn't mean untyped arguments can be passed safely. We still if index < 0 {
// need to verify that assignment of those arguments is valid when we check return targs
// function parameter passing external to infer.) }
// See how far we get with constraint type inference.
// Note that even if we don't have any type arguments, constraint type inference
// may produce results for constraints that explicitly specify a type.
targs, index = check.inferB(tparams, targs, report)
if targs == nil || index < 0 {
return targs
}
// --- 3 ---
// Use any untyped arguments to infer additional type arguments.
// Some generic parameters with untyped arguments may have been given
// a type by now, we can ignore them.
j := 0 j := 0
for _, i := range indices { for _, i := range indices {
par := params.At(i) par := params.At(i)
@ -98,14 +190,15 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
// only parameter type it can possibly match against is a *TypeParam. // only parameter type it can possibly match against is a *TypeParam.
// Thus, only keep the indices of generic parameters that are not of // Thus, only keep the indices of generic parameters that are not of
// composite types and which don't have a type inferred yet. // composite types and which don't have a type inferred yet.
if tpar, _ := par.typ.(*_TypeParam); tpar != nil && u.x.at(tpar.index) == nil { if tpar, _ := par.typ.(*_TypeParam); tpar != nil && targs[tpar.index] == nil {
indices[j] = i indices[j] = i
j++ j++
} }
} }
indices = indices[:j] indices = indices[:j]
// 2nd pass: Unify parameter and default argument types for remaining generic parameters. // Unify parameter and default argument types for remaining generic parameters.
// TODO(gri) Rather than iterating again, combine this code with the loop above.
for _, i := range indices { for _, i := range indices {
par := params.At(i) par := params.At(i)
arg := args[i] arg := args[i]
@ -115,11 +208,29 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
// nil by making sure all default argument types are typed. // nil by making sure all default argument types are typed.
if isTyped(targ) && !u.unify(par.typ, targ) { if isTyped(targ) && !u.unify(par.typ, targ) {
errorf("default type", par.typ, targ, arg) errorf("default type", par.typ, targ, arg)
return nil, 0 return nil
} }
} }
return u.x.types() // If we've got all type arguments, we're done.
targs, index = u.x.types()
if index < 0 {
return targs
}
// Again, follow up with constraint type inference.
targs, index = check.inferB(tparams, targs, report)
if targs == nil || index < 0 {
return targs
}
// At least one type argument couldn't be inferred.
assert(targs != nil && index >= 0 && targs[index] == nil)
tpar := tparams[index]
if report {
check.errorf(posn, _Todo, "cannot infer %s (%v) (%v)", tpar.name, tpar.pos, targs)
}
return nil
} }
// typeNamesString produces a string containing all the // typeNamesString produces a string containing all the
@ -268,12 +379,13 @@ func (w *tpWalker) isParameterizedList(list []Type) bool {
// inferB returns the list of actual type arguments inferred from the type parameters' // inferB returns the list of actual type arguments inferred from the type parameters'
// bounds and an initial set of type arguments. If type inference is impossible because // bounds and an initial set of type arguments. If type inference is impossible because
// unification fails, an error is reported, the resulting types list is nil, and index is 0. // unification fails, an error is reported if report is set to true, the resulting types
// list is nil, and index is 0.
// Otherwise, types is the list of inferred type arguments, and index is the index of the // Otherwise, types is the list of inferred type arguments, and index is the index of the
// first type argument in that list that couldn't be inferred (and thus is nil). If all // first type argument in that list that couldn't be inferred (and thus is nil). If all
// type arguments where inferred successfully, index is < 0. The number of type arguments // type arguments were inferred successfully, index is < 0. The number of type arguments
// provided may be less than the number of type parameters, but there must be at least one. // provided may be less than the number of type parameters, but there must be at least one.
func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, index int) { func (check *Checker) inferB(tparams []*TypeName, targs []Type, report bool) (types []Type, index int) {
assert(len(tparams) >= len(targs) && len(targs) > 0) assert(len(tparams) >= len(targs) && len(targs) > 0)
// Setup bidirectional unification between those structural bounds // Setup bidirectional unification between those structural bounds
@ -295,7 +407,9 @@ func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, i
sbound := check.structuralType(typ.bound) sbound := check.structuralType(typ.bound)
if sbound != nil { if sbound != nil {
if !u.unify(typ, sbound) { if !u.unify(typ, sbound) {
check.errorf(tpar, 0, "%s does not match %s", tpar, sbound) if report {
check.errorf(tpar, 0, "%s does not match %s", tpar, sbound)
}
return nil, 0 return nil, 0
} }
} }

View file

@ -178,17 +178,17 @@ func _[T interface{ type string, chan<-int }](x T) {
// type inference checks // type inference checks
var _ = new() /* ERROR cannot infer T */ var _ = new /* ERROR cannot infer T */ ()
func f4[A, B, C any](A, B) C func f4[A, B, C any](A, B) C
var _ = f4(1, 2) /* ERROR cannot infer C */ var _ = f4 /* ERROR cannot infer C */ (1, 2)
var _ = f4[int, float32, complex128](1, 2) var _ = f4[int, float32, complex128](1, 2)
func f5[A, B, C any](A, []*B, struct{f []C}) int func f5[A, B, C any](A, []*B, struct{f []C}) int
var _ = f5[int, float32, complex128](0, nil, struct{f []complex128}{}) var _ = f5[int, float32, complex128](0, nil, struct{f []complex128}{})
var _ = f5(0, nil, struct{f []complex128}{}) // ERROR cannot infer var _ = f5 /* ERROR cannot infer */ (0, nil, struct{f []complex128}{})
var _ = f5(0, []*float32{new[float32]()}, struct{f []complex128}{}) var _ = f5(0, []*float32{new[float32]()}, struct{f []complex128}{})
func f6[A any](A, []A) int func f6[A any](A, []A) int
@ -197,13 +197,13 @@ var _ = f6(0, nil)
func f6nil[A any](A) int func f6nil[A any](A) int
var _ = f6nil(nil) // ERROR cannot infer var _ = f6nil /* ERROR cannot infer */ (nil)
// type inference with variadic functions // type inference with variadic functions
func f7[T any](...T) T func f7[T any](...T) T
var _ int = f7() /* ERROR cannot infer T */ var _ int = f7 /* ERROR cannot infer T */ ()
var _ int = f7(1) var _ int = f7(1)
var _ int = f7(1, 2) var _ int = f7(1, 2)
var _ int = f7([]int{}...) var _ int = f7([]int{}...)