cmd/compile: deal with closures in generic functions and instantiated function values

- Deal with closures in generic functions by fixing the stenciling code

 - Deal with instantiated function values (instantiated generic
   functions that are not immediately called) during stenciling. This
   requires changing the OFUNCINST node to an ONAME node for the
   appropriately instantiated function. We do this in a second pass,
   since this is uncommon, but requires editing the tree at multiple
   levels.

 - Check global assignments (as well as functions) for generic function
   instantiations.

 - Fix a bug in (*subst).typ where a generic type in a generic function
   may definitely not use all the type args of the function, so we need
   to translate the rparams of the type based on the tparams/targs of
   the function.

 - Added new test combine.go that tests out closures in generic
   functions and instantiated function values.

 - Added one new variant to the settable test.

 - Enabling inlining functions with closures for -G=3. (For now, set
   Ntype on closures in -G=3 mode to keep compatibility with later parts
   of compiler, and allow inlining of functions with closures.)

Change-Id: Iea63d5704c322e42e2f750a83adc8b44f911d4ec
Reviewed-on: https://go-review.googlesource.com/c/go/+/296269
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Dan Scales <danscales@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Dan Scales <danscales@google.com>
This commit is contained in:
Dan Scales 2021-02-21 10:54:38 -08:00
parent 19f96e73bf
commit d8e33d558e
5 changed files with 236 additions and 51 deletions

View file

@ -354,7 +354,7 @@ func (v *hairyVisitor) doNode(n ir.Node) bool {
return true return true
case ir.OCLOSURE: case ir.OCLOSURE:
if base.Debug.InlFuncsWithClosures == 0 || base.Flag.G > 0 { if base.Debug.InlFuncsWithClosures == 0 {
v.reason = "not inlining functions with closures" v.reason = "not inlining functions with closures"
return true return true
} }

View file

@ -325,19 +325,22 @@ func (g *irgen) compLit(typ types2.Type, lit *syntax.CompositeLit) ir.Node {
return typecheck.Expr(ir.NewCompLitExpr(g.pos(lit), ir.OCOMPLIT, ir.TypeNode(g.typ(typ)), exprs)) return typecheck.Expr(ir.NewCompLitExpr(g.pos(lit), ir.OCOMPLIT, ir.TypeNode(g.typ(typ)), exprs))
} }
func (g *irgen) funcLit(typ types2.Type, expr *syntax.FuncLit) ir.Node { func (g *irgen) funcLit(typ2 types2.Type, expr *syntax.FuncLit) ir.Node {
fn := ir.NewFunc(g.pos(expr)) fn := ir.NewFunc(g.pos(expr))
fn.SetIsHiddenClosure(ir.CurFunc != nil) fn.SetIsHiddenClosure(ir.CurFunc != nil)
fn.Nname = ir.NewNameAt(g.pos(expr), typecheck.ClosureName(ir.CurFunc)) fn.Nname = ir.NewNameAt(g.pos(expr), typecheck.ClosureName(ir.CurFunc))
ir.MarkFunc(fn.Nname) ir.MarkFunc(fn.Nname)
fn.Nname.SetType(g.typ(typ)) typ := g.typ(typ2)
fn.Nname.Func = fn fn.Nname.Func = fn
fn.Nname.Defn = fn fn.Nname.Defn = fn
// Set Ntype for now to be compatible with later parts of compile, remove later.
fn.Nname.Ntype = ir.TypeNode(typ)
typed(typ, fn.Nname)
fn.SetTypecheck(1)
fn.OClosure = ir.NewClosureExpr(g.pos(expr), fn) fn.OClosure = ir.NewClosureExpr(g.pos(expr), fn)
fn.OClosure.SetType(fn.Nname.Type()) typed(typ, fn.OClosure)
fn.OClosure.SetTypecheck(1)
g.funcBody(fn, nil, expr.Type, expr.Body) g.funcBody(fn, nil, expr.Type, expr.Body)

View file

@ -27,19 +27,37 @@ func (g *irgen) stencil() {
// functions calling other generic functions. // functions calling other generic functions.
for i := 0; i < len(g.target.Decls); i++ { for i := 0; i < len(g.target.Decls); i++ {
decl := g.target.Decls[i] decl := g.target.Decls[i]
if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 {
// Skip any non-function declarations and skip generic functions // Look for function instantiations in bodies of non-generic
// functions or in global assignments (ignore global type and
// constant declarations).
switch decl.Op() {
case ir.ODCLFUNC:
if decl.Type().HasTParam() {
// Skip any generic functions
continue
}
case ir.OAS:
case ir.OAS2:
default:
continue continue
} }
// For each non-generic function, search for any function calls using // For all non-generic code, search for any function calls using
// generic function instantiations. (We don't yet handle generic // generic function instantiations. Then create the needed
// function instantiations that are not immediately called.) // instantiated function if it hasn't been created yet, and change
// Then create the needed instantiated function if it hasn't been // to calling that function directly.
// created yet, and change to calling that function directly.
f := decl.(*ir.Func)
modified := false modified := false
ir.VisitList(f.Body, func(n ir.Node) { foundFuncInst := false
ir.Visit(decl, func(n ir.Node) {
if n.Op() == ir.OFUNCINST {
// We found a function instantiation that is not
// immediately called.
foundFuncInst = true
}
if n.Op() != ir.OCALLFUNC || n.(*ir.CallExpr).X.Op() != ir.OFUNCINST { if n.Op() != ir.OCALLFUNC || n.(*ir.CallExpr).X.Op() != ir.OFUNCINST {
return return
} }
@ -47,19 +65,7 @@ func (g *irgen) stencil() {
// instantiation. // instantiation.
call := n.(*ir.CallExpr) call := n.(*ir.CallExpr)
inst := call.X.(*ir.InstExpr) inst := call.X.(*ir.InstExpr)
sym := makeInstName(inst) st := g.getInstantiation(inst)
//fmt.Printf("Found generic func call in %v to %v\n", f, s)
st := g.target.Stencils[sym]
if st == nil {
// If instantiation doesn't exist yet, create it and add
// to the list of decls.
st = genericSubst(sym, inst)
g.target.Stencils[sym] = st
g.target.Decls = append(g.target.Decls, st)
if base.Flag.W > 1 {
ir.Dump(fmt.Sprintf("\nstenciled %v", st), st)
}
}
// Replace the OFUNCINST with a direct reference to the // Replace the OFUNCINST with a direct reference to the
// new stenciled function // new stenciled function
call.X = st.Nname call.X = st.Nname
@ -76,6 +82,26 @@ func (g *irgen) stencil() {
} }
modified = true modified = true
}) })
// If we found an OFUNCINST without a corresponding call in the
// above decl, then traverse the nodes of decl again (with
// EditChildren rather than Visit), where we actually change the
// OFUNCINST node to an ONAME for the instantiated function.
// EditChildren is more expensive than Visit, so we only do this
// in the infrequent case of an OFUNCINSt without a corresponding
// call.
if foundFuncInst {
var edit func(ir.Node) ir.Node
edit = func(x ir.Node) ir.Node {
if x.Op() == ir.OFUNCINST {
st := g.getInstantiation(x.(*ir.InstExpr))
return st.Nname
}
ir.EditChildren(x, edit)
return x
}
edit(decl)
}
if base.Flag.W > 1 && modified { if base.Flag.W > 1 && modified {
ir.Dump(fmt.Sprintf("\nmodified %v", decl), decl) ir.Dump(fmt.Sprintf("\nmodified %v", decl), decl)
} }
@ -83,18 +109,39 @@ func (g *irgen) stencil() {
} }
// makeInstName makes the unique name for a stenciled generic function, based on // getInstantiation gets the instantiated function corresponding to inst. If the
// the name of the function and the types of the type params. // instantiated function is not already cached, then it calls genericStub to
func makeInstName(inst *ir.InstExpr) *types.Sym { // create the new instantiation.
b := bytes.NewBufferString("#") func (g *irgen) getInstantiation(inst *ir.InstExpr) *ir.Func {
var sym *types.Sym
if meth, ok := inst.X.(*ir.SelectorExpr); ok { if meth, ok := inst.X.(*ir.SelectorExpr); ok {
// Write the name of the generic method, including receiver type // Write the name of the generic method, including receiver type
b.WriteString(meth.Selection.Nname.Sym().Name) sym = makeInstName(meth.Selection.Nname.Sym(), inst.Targs)
} else { } else {
b.WriteString(inst.X.(*ir.Name).Name().Sym().Name) sym = makeInstName(inst.X.(*ir.Name).Name().Sym(), inst.Targs)
} }
//fmt.Printf("Found generic func call in %v to %v\n", f, s)
st := g.target.Stencils[sym]
if st == nil {
// If instantiation doesn't exist yet, create it and add
// to the list of decls.
st = g.genericSubst(sym, inst)
g.target.Stencils[sym] = st
g.target.Decls = append(g.target.Decls, st)
if base.Flag.W > 1 {
ir.Dump(fmt.Sprintf("\nstenciled %v", st), st)
}
}
return st
}
// makeInstName makes the unique name for a stenciled generic function, based on
// the name of the function and the targs.
func makeInstName(fnsym *types.Sym, targs []ir.Node) *types.Sym {
b := bytes.NewBufferString("#")
b.WriteString(fnsym.Name)
b.WriteString("[") b.WriteString("[")
for i, targ := range inst.Targs { for i, targ := range targs {
if i > 0 { if i > 0 {
b.WriteString(",") b.WriteString(",")
} }
@ -107,6 +154,7 @@ func makeInstName(inst *ir.InstExpr) *types.Sym {
// Struct containing info needed for doing the substitution as we create the // Struct containing info needed for doing the substitution as we create the
// instantiation of a generic function with specified type arguments. // instantiation of a generic function with specified type arguments.
type subster struct { type subster struct {
g *irgen
newf *ir.Func // Func node for the new stenciled function newf *ir.Func // Func node for the new stenciled function
tparams []*types.Field tparams []*types.Field
targs []ir.Node targs []ir.Node
@ -121,7 +169,7 @@ type subster struct {
// inst. For a method with a generic receiver, it returns an instantiated function // inst. For a method with a generic receiver, it returns an instantiated function
// type where the receiver becomes the first parameter. Otherwise the instantiated // type where the receiver becomes the first parameter. Otherwise the instantiated
// method would still need to be transformed by later compiler phases. // method would still need to be transformed by later compiler phases.
func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func { func (g *irgen) genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func {
var nameNode *ir.Name var nameNode *ir.Name
var tparams []*types.Field var tparams []*types.Field
if selExpr, ok := inst.X.(*ir.SelectorExpr); ok { if selExpr, ok := inst.X.(*ir.SelectorExpr); ok {
@ -148,6 +196,7 @@ func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func {
name.Def = newf.Nname name.Def = newf.Nname
subst := &subster{ subst := &subster{
g: g,
newf: newf, newf: newf,
tparams: tparams, tparams: tparams,
targs: inst.Targs, targs: inst.Targs,
@ -198,6 +247,9 @@ func (subst *subster) node(n ir.Node) ir.Node {
return v return v
} }
m := ir.NewNameAt(name.Pos(), name.Sym()) m := ir.NewNameAt(name.Pos(), name.Sym())
if name.IsClosureVar() {
m.SetIsClosureVar(true)
}
t := x.Type() t := x.Type()
newt := subst.typ(t) newt := subst.typ(t)
m.SetType(newt) m.SetType(newt)
@ -219,10 +271,12 @@ func (subst *subster) node(n ir.Node) ir.Node {
// t can be nil only if this is a call that has no // t can be nil only if this is a call that has no
// return values, so allow that and otherwise give // return values, so allow that and otherwise give
// an error. // an error.
if _, isCallExpr := m.(*ir.CallExpr); !isCallExpr { _, isCallExpr := m.(*ir.CallExpr)
_, isStructKeyExpr := m.(*ir.StructKeyExpr)
if !isCallExpr && !isStructKeyExpr {
base.Fatalf(fmt.Sprintf("Nil type for %v", x)) base.Fatalf(fmt.Sprintf("Nil type for %v", x))
} }
} else { } else if x.Op() != ir.OCLOSURE {
m.SetType(subst.typ(x.Type())) m.SetType(subst.typ(x.Type()))
} }
} }
@ -270,14 +324,27 @@ func (subst *subster) node(n ir.Node) ir.Node {
if oldfn.ClosureCalled() { if oldfn.ClosureCalled() {
newfn.SetClosureCalled(true) newfn.SetClosureCalled(true)
} }
newfn.SetIsHiddenClosure(true)
m.(*ir.ClosureExpr).Func = newfn m.(*ir.ClosureExpr).Func = newfn
newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), oldfn.Nname.Sym()) newsym := makeInstName(oldfn.Nname.Sym(), subst.targs)
newfn.Nname.SetType(oldfn.Nname.Type()) newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), newsym)
newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype) newfn.Nname.Func = newfn
newfn.Nname.Defn = newfn
ir.MarkFunc(newfn.Nname)
newfn.OClosure = m.(*ir.ClosureExpr)
saveNewf := subst.newf
subst.newf = newfn
newfn.Dcl = subst.namelist(oldfn.Dcl)
newfn.ClosureVars = subst.namelist(oldfn.ClosureVars)
newfn.Body = subst.list(oldfn.Body) newfn.Body = subst.list(oldfn.Body)
// Make shallow copy of the Dcl and ClosureVar slices subst.newf = saveNewf
newfn.Dcl = append([]*ir.Name(nil), oldfn.Dcl...)
newfn.ClosureVars = append([]*ir.Name(nil), oldfn.ClosureVars...) // Set Ntype for now to be compatible with later parts of compiler
newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype)
typed(subst.typ(oldfn.Nname.Type()), newfn.Nname)
newfn.SetTypecheck(1)
subst.g.target.Decls = append(subst.g.target.Decls, newfn)
} }
return m return m
} }
@ -285,6 +352,20 @@ func (subst *subster) node(n ir.Node) ir.Node {
return edit(n) return edit(n)
} }
func (subst *subster) namelist(l []*ir.Name) []*ir.Name {
s := make([]*ir.Name, len(l))
for i, n := range l {
s[i] = subst.node(n).(*ir.Name)
if n.Defn != nil {
s[i].Defn = subst.node(n.Defn)
}
if n.Outer != nil {
s[i].Outer = subst.node(n.Outer).(*ir.Name)
}
}
return s
}
func (subst *subster) list(l []ir.Node) []ir.Node { func (subst *subster) list(l []ir.Node) []ir.Node {
s := make([]ir.Node, len(l)) s := make([]ir.Node, len(l))
for i, n := range l { for i, n := range l {
@ -293,7 +374,9 @@ func (subst *subster) list(l []ir.Node) []ir.Node {
return s return s
} }
// tstruct substitutes type params in a structure type // tstruct substitutes type params in types of the fields of a structure type. For
// each field, if Nname is set, tstruct also translates the Nname using subst.vars, if
// Nname is in subst.vars.
func (subst *subster) tstruct(t *types.Type) *types.Type { func (subst *subster) tstruct(t *types.Type) *types.Type {
if t.NumFields() == 0 { if t.NumFields() == 0 {
return t return t
@ -301,7 +384,7 @@ func (subst *subster) tstruct(t *types.Type) *types.Type {
var newfields []*types.Field var newfields []*types.Field
for i, f := range t.Fields().Slice() { for i, f := range t.Fields().Slice() {
t2 := subst.typ(f.Type) t2 := subst.typ(f.Type)
if t2 != f.Type && newfields == nil { if (t2 != f.Type || f.Nname != nil) && newfields == nil {
newfields = make([]*types.Field, t.NumFields()) newfields = make([]*types.Field, t.NumFields())
for j := 0; j < i; j++ { for j := 0; j < i; j++ {
newfields[j] = t.Field(j) newfields[j] = t.Field(j)
@ -309,6 +392,12 @@ func (subst *subster) tstruct(t *types.Type) *types.Type {
} }
if newfields != nil { if newfields != nil {
newfields[i] = types.NewField(f.Pos, f.Sym, t2) newfields[i] = types.NewField(f.Pos, f.Sym, t2)
if f.Nname != nil {
// f.Nname may not be in subst.vars[] if this is
// a function name or a function instantiation type
// that we are translating
newfields[i].Nname = subst.vars[f.Nname.(*ir.Name)]
}
} }
} }
if newfields != nil { if newfields != nil {
@ -319,14 +408,14 @@ func (subst *subster) tstruct(t *types.Type) *types.Type {
} }
// instTypeName creates a name for an instantiated type, based on the type args // instTypeName creates a name for an instantiated type, based on the type args
func instTypeName(name string, targs []ir.Node) string { func instTypeName(name string, targs []*types.Type) string {
b := bytes.NewBufferString(name) b := bytes.NewBufferString(name)
b.WriteByte('[') b.WriteByte('[')
for i, targ := range targs { for i, targ := range targs {
if i > 0 { if i > 0 {
b.WriteByte(',') b.WriteByte(',')
} }
b.WriteString(targ.Type().String()) b.WriteString(targ.String())
} }
b.WriteByte(']') b.WriteByte(']')
return b.String() return b.String()
@ -415,10 +504,17 @@ func (subst *subster) typ(t *types.Type) *types.Type {
// Since we've substituted types, we also need to change // Since we've substituted types, we also need to change
// the defined name of the type, by removing the old types // the defined name of the type, by removing the old types
// (in brackets) from the name, and adding the new types. // (in brackets) from the name, and adding the new types.
// Translate the type params for this type according to
// the tparam/targs mapping of the function.
neededTargs := make([]*types.Type, len(t.RParams))
for i, rparam := range t.RParams {
neededTargs[i] = subst.typ(rparam)
}
oldname := t.Sym().Name oldname := t.Sym().Name
i := strings.Index(oldname, "[") i := strings.Index(oldname, "[")
oldname = oldname[:i] oldname = oldname[:i]
sym := t.Sym().Pkg.Lookup(instTypeName(oldname, subst.targs)) sym := t.Sym().Pkg.Lookup(instTypeName(oldname, neededTargs))
if sym.Def != nil { if sym.Def != nil {
// We've already created this instantiated defined type. // We've already created this instantiated defined type.
return sym.Def.Type() return sym.Def.Type()

65
test/typeparam/combine.go Normal file
View file

@ -0,0 +1,65 @@
// run -gcflags=-G=3
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
)
type _Gen[A any] func() (A, bool)
func combine[T1, T2, T any](g1 _Gen[T1], g2 _Gen[T2], join func(T1, T2) T) _Gen[T] {
return func() (T, bool) {
var t T
t1, ok := g1()
if !ok {
return t, false
}
t2, ok := g2()
if !ok {
return t, false
}
return join(t1, t2), true
}
}
type _Pair[A, B any] struct {
A A
B B
}
func _NewPair[A, B any](a A, b B) _Pair[A, B] {
return _Pair[A, B]{a, b}
}
func _Combine2[A, B any](ga _Gen[A], gb _Gen[B]) _Gen[_Pair[A, B]] {
return combine(ga, gb, _NewPair[A, B])
}
func main() {
var g1 _Gen[int] = func() (int, bool) { return 3, true }
var g2 _Gen[string] = func() (string, bool) { return "x", false }
var g3 _Gen[string] = func() (string, bool) { return "y", true }
gc := combine(g1, g2, _NewPair[int, string])
if got, ok := gc(); ok {
panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok))
}
gc2 := _Combine2(g1, g2)
if got, ok := gc2(); ok {
panic(fmt.Sprintf("got %v, %v, wanted -/false", got, ok))
}
gc3 := combine(g1, g3, _NewPair[int, string])
if got, ok := gc3(); !ok || got.A != 3 || got.B != "y" {
panic(fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok))
}
gc4 := _Combine2(g1, g3)
if got, ok := gc4(); !ok || got.A != 3 || got.B != "y" {
panic (fmt.Sprintf("got %v, %v, wanted {3, y}, true", got, ok))
}
}

View file

@ -11,7 +11,24 @@ import (
"strconv" "strconv"
) )
func fromStrings3[T any](s []string, set func(*T, string)) []T { type Setter[B any] interface {
Set(string)
type *B
}
func fromStrings1[T any, PT Setter[T]](s []string) []T {
result := make([]T, len(s))
for i, v := range s {
// The type of &result[i] is *T which is in the type list
// of Setter, so we can convert it to PT.
p := PT(&result[i])
// PT has a Set method.
p.Set(v)
}
return result
}
func fromStrings2[T any](s []string, set func(*T, string)) []T {
results := make([]T, len(s)) results := make([]T, len(s))
for i, v := range s { for i, v := range s {
set(&results[i], v) set(&results[i], v)
@ -30,8 +47,12 @@ func (p *Settable) Set(s string) {
} }
func main() { func main() {
s := fromStrings3([]string{"1"}, s := fromStrings1[Settable, *Settable]([]string{"1"})
func(p *Settable, s string) { p.Set(s) }) if len(s) != 1 || s[0] != 1 {
panic(fmt.Sprintf("got %v, want %v", s, []int{1}))
}
s = fromStrings2([]string{"1"}, func(p *Settable, s string) { p.Set(s) })
if len(s) != 1 || s[0] != 1 { if len(s) != 1 || s[0] != 1 {
panic(fmt.Sprintf("got %v, want %v", s, []int{1})) panic(fmt.Sprintf("got %v, want %v", s, []int{1}))
} }