mirror of
https://github.com/golang/go.git
synced 2026-06-28 03:40:37 +00:00
cmd/compile: eliminate impossible type assertions in generic functions
When a generic function converts a shape-typed value to an interface
and then type-asserts or type-switches on it, some cases can never
match because the asserted concrete type has a different shape than
the source. For example:
func foo[S string | []byte](x S) {
switch any(x).(type) {
case string: // possible only when S has shape string
case []byte: // possible only when S has shape []uint8
}
}
Since instantiated generic funcs work on shapes, all instantiations
contain the code for all cases even if they will never be hit.
Detect OCONVIFACE of a shape type followed by a concrete type
assertion, and compare the shapes. If they are incompatible, the
assertion can never succeed for that instantiation.
This applies to both type switch cases (which are skipped entirely)
and comma-ok type assertions (which are replaced with zero, false).
The analysis also tracks through intermediate variables using a
pre-walk pass with ReassignOracle, so patterns like
iface := any(x)
v, ok := iface.(string)
are handled as well.
Updates #57072
Change-Id: I837f6089b9e431f856a528463075fd10abe464dc
Reviewed-on: https://go-review.googlesource.com/c/go/+/767640
Reviewed-by: Michael Pratt <mpratt@google.com>
Reviewed-by: Keith Randall <khr@golang.org>
Reviewed-by: Keith Randall <khr@google.com>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Keith Randall <khr@golang.org>
This commit is contained in:
parent
58968c79e7
commit
343fbe2971
7 changed files with 369 additions and 17 deletions
|
|
@ -940,11 +940,11 @@ func (dict *readerDict) mangle(sym *types.Sym) *types.Sym {
|
|||
return sym.Pkg.Lookup(buf.String())
|
||||
}
|
||||
|
||||
// shapify returns the shape type for targ.
|
||||
// Shapify returns the shape type for targ.
|
||||
//
|
||||
// If basic is true, then the type argument is used to instantiate a
|
||||
// type parameter whose constraint is a basic interface.
|
||||
func shapify(targ *types.Type, basic bool) *types.Type {
|
||||
func Shapify(targ *types.Type, basic bool) *types.Type {
|
||||
if targ.Kind() == types.TFORW {
|
||||
if targ.IsFullyInstantiated() {
|
||||
// For recursive instantiated type argument, it may still be a TFORW
|
||||
|
|
@ -1076,7 +1076,7 @@ func (pr *pkgReader) objDictIdx(sym *types.Sym, idx index, implicits, explicits
|
|||
for i, targ := range dict.targs {
|
||||
basic := r.Bool()
|
||||
if dict.shaped {
|
||||
dict.targs[i] = shapify(targ, basic)
|
||||
dict.targs[i] = Shapify(targ, basic)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -120,6 +120,16 @@ func walkAssign(init *ir.Nodes, n ir.Node) ir.Node {
|
|||
// walkAssignDotType walks an OAS2DOTTYPE node.
|
||||
func walkAssignDotType(n *ir.AssignListStmt, init *ir.Nodes) ir.Node {
|
||||
walkExprListSafe(n.Lhs, init)
|
||||
|
||||
if r, ok := n.Rhs[0].(*ir.TypeAssertExpr); ok && r.Op() == ir.ODOTTYPE2 && !r.Type().IsInterface() {
|
||||
if shapeTypeAssertImpossible(r.X, r.Type()) {
|
||||
init.Append(typecheck.Stmt(ir.NewAssignStmt(base.Pos, ir.BlankNode, walkExpr(r.X, init))))
|
||||
init.Append(typecheck.Stmt(ir.NewAssignStmt(base.Pos, n.Lhs[0], ir.NewZero(base.Pos, r.Type()))))
|
||||
init.Append(typecheck.Stmt(ir.NewAssignStmt(base.Pos, n.Lhs[1], ir.NewBool(base.Pos, false))))
|
||||
return ir.NewBlockStmt(base.Pos, nil)
|
||||
}
|
||||
}
|
||||
|
||||
n.Rhs[0] = walkExpr(n.Rhs[0], init)
|
||||
return n
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
|
||||
"cmd/compile/internal/base"
|
||||
"cmd/compile/internal/ir"
|
||||
"cmd/compile/internal/noder"
|
||||
"cmd/compile/internal/objw"
|
||||
"cmd/compile/internal/reflectdata"
|
||||
"cmd/compile/internal/rttype"
|
||||
|
|
@ -762,6 +763,51 @@ func walkDotType(n *ir.TypeAssertExpr, init *ir.Nodes) ir.Node {
|
|||
return n
|
||||
}
|
||||
|
||||
// shapeTypeAssertImpossible reports whether a type assertion from src
|
||||
// to concrete type dst can never succeed because they have
|
||||
// incompatible shape types.
|
||||
func shapeTypeAssertImpossible(src ir.Node, dst *types.Type) bool {
|
||||
if dst.IsInterface() {
|
||||
return false
|
||||
}
|
||||
srcShape := convIfaceShapeType(src)
|
||||
if srcShape == nil {
|
||||
return false
|
||||
}
|
||||
return !types.Identical(srcShape, noder.Shapify(dst, false)) &&
|
||||
!types.Identical(srcShape, noder.Shapify(dst, true))
|
||||
}
|
||||
|
||||
// convIfaceShapeType returns the shape type from which src was
|
||||
// created via OCONVIFACE, or nil.
|
||||
func convIfaceShapeType(src ir.Node) *types.Type {
|
||||
for {
|
||||
switch s := src.(type) {
|
||||
case *ir.ParenExpr:
|
||||
src = s.X
|
||||
continue
|
||||
case *ir.ConvExpr:
|
||||
if s.Op() == ir.OCONVNOP {
|
||||
src = s.X
|
||||
continue
|
||||
}
|
||||
if s.Op() == ir.OCONVIFACE {
|
||||
srcType := s.X.Type()
|
||||
if srcType != nil && !srcType.IsInterface() && srcType.IsShape() {
|
||||
return srcType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if name, ok := src.(*ir.Name); ok && shapeConvSources != nil {
|
||||
return shapeConvSources[name.Canonical()]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeTypeAssertDescriptor(target *types.Type, canFail bool) *obj.LSym {
|
||||
// When converting from an interface to a non-empty interface. Needs a runtime call.
|
||||
// Allocate an internal/abi.TypeAssert descriptor for that call.
|
||||
|
|
|
|||
|
|
@ -703,7 +703,8 @@ func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
|
|||
// type switch.
|
||||
func walkSwitchType(sw *ir.SwitchStmt) {
|
||||
var s typeSwitch
|
||||
s.srcName = sw.Tag.(*ir.TypeSwitchGuard).X
|
||||
origSrc := sw.Tag.(*ir.TypeSwitchGuard).X
|
||||
s.srcName = origSrc
|
||||
s.srcName = walkExpr(s.srcName, sw.PtrInit())
|
||||
s.srcName = copyExpr(s.srcName, s.srcName.Type(), &sw.Compiled)
|
||||
s.okName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
|
||||
|
|
@ -933,6 +934,10 @@ caseLoop:
|
|||
// the dynamic type cases separately, as we do above.
|
||||
}
|
||||
|
||||
if shapeTypeAssertImpossible(origSrc, c.typ.Type()) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.typ.Type().IsInterface() {
|
||||
interfaceCases = append(interfaceCases, c)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ const tmpstringbufsize = 32
|
|||
func Walk(fn *ir.Func) {
|
||||
ir.CurFunc = fn
|
||||
|
||||
// Set and then clear a package-level cache of static values for this fn.
|
||||
// Build pre-walk analysis caches with a single AST traversal.
|
||||
// (At some point, it might be worthwhile to have a walkState structure
|
||||
// that gets passed everywhere where things like this can go.)
|
||||
staticValues = findStaticValues(fn)
|
||||
defer func() { staticValues = nil }()
|
||||
analyzePreWalk(fn)
|
||||
defer func() { staticValues = nil; shapeConvSources = nil }()
|
||||
|
||||
errorsBefore := base.Errors()
|
||||
order(fn)
|
||||
|
|
@ -449,23 +449,44 @@ func staticValue(n ir.Node) ir.Node {
|
|||
// staticValues is a cache of static values for use by staticValue.
|
||||
var staticValues map[ir.Node]ir.Node
|
||||
|
||||
// findStaticValues returns a map of static values for fn.
|
||||
func findStaticValues(fn *ir.Func) map[ir.Node]ir.Node {
|
||||
// We can't use an ir.ReassignOracle or ir.StaticValue in the
|
||||
// middle of walk because they don't currently handle
|
||||
// transformed assignments (e.g., will complain about 'RHS == nil').
|
||||
// So we instead build this map to use in walk.
|
||||
// shapeConvSources maps an *ir.Name (a PAUTO interface variable) to
|
||||
// the shape type of the OCONVIFACE expression that is its single
|
||||
// static value, if any.
|
||||
var shapeConvSources map[*ir.Name]*types.Type
|
||||
|
||||
// analyzePreWalk populates staticValues and shapeConvSources using a
|
||||
// single AST traversal. We can't use an ir.ReassignOracle or
|
||||
// ir.StaticValue in the middle of walk because they don't currently
|
||||
// handle transformed assignments (e.g., will complain about
|
||||
// 'RHS == nil'). So we build these maps before walk begins.
|
||||
func analyzePreWalk(fn *ir.Func) {
|
||||
ro := &ir.ReassignOracle{}
|
||||
ro.Init(fn)
|
||||
m := make(map[ir.Node]ir.Node)
|
||||
sv := make(map[ir.Node]ir.Node)
|
||||
scs := make(map[*ir.Name]*types.Type)
|
||||
ir.Visit(fn, func(n ir.Node) {
|
||||
if n.Op() == ir.OCONVIFACE {
|
||||
switch n.Op() {
|
||||
case ir.OCONVIFACE:
|
||||
x := n.(*ir.ConvExpr).X
|
||||
v := ro.StaticValue(x)
|
||||
if v != nil && v != x {
|
||||
m[x] = v
|
||||
sv[x] = v
|
||||
}
|
||||
case ir.ONAME:
|
||||
name := n.(*ir.Name).Canonical()
|
||||
if name.Class != ir.PAUTO || name.Type() == nil || !name.Type().IsInterface() {
|
||||
return
|
||||
}
|
||||
val := ro.StaticValue(name)
|
||||
if val == nil || val.Op() != ir.OCONVIFACE {
|
||||
return
|
||||
}
|
||||
srcType := val.(*ir.ConvExpr).X.Type()
|
||||
if srcType != nil && !srcType.IsInterface() && srcType.IsShape() {
|
||||
scs[name] = srcType
|
||||
}
|
||||
}
|
||||
})
|
||||
return m
|
||||
staticValues = sv
|
||||
shapeConvSources = scs
|
||||
}
|
||||
|
|
|
|||
112
test/codegen/shape_assert.go
Normal file
112
test/codegen/shape_assert.go
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
// asmcheck
|
||||
|
||||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Test that type assertions and type switch cases that are impossible
|
||||
// based on shape type analysis are eliminated from generated code.
|
||||
|
||||
package codegen
|
||||
|
||||
// -- Type switch elimination --
|
||||
|
||||
func switchStringOrBytes[S string | []byte](x S) string {
|
||||
switch any(x).(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// In the string instantiation, the []byte case is impossible
|
||||
// and should be eliminated.
|
||||
func SwitchStringInst(x string) string {
|
||||
// amd64:-"type:.*uint8"
|
||||
return switchStringOrBytes(x)
|
||||
}
|
||||
|
||||
// In the []byte instantiation, the string case is impossible
|
||||
// and should be eliminated.
|
||||
func SwitchBytesInst(x []byte) string {
|
||||
// amd64:-"type:string"
|
||||
return switchStringOrBytes(x)
|
||||
}
|
||||
|
||||
// -- Comma-ok type assertion elimination --
|
||||
|
||||
func commaOkString[S string | []byte](x S) (string, bool) {
|
||||
v, ok := any(x).(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// In the []byte instantiation, .(string) always fails.
|
||||
// The type comparison against type:string should be eliminated.
|
||||
func CommaOkStringBytesInst(x []byte) (string, bool) {
|
||||
// amd64:-"type:string"
|
||||
return commaOkString(x)
|
||||
}
|
||||
|
||||
// In the string instantiation, the comparison against type:string
|
||||
// is also eliminated because the assertion always succeeds.
|
||||
func CommaOkStringStringInst(x string) (string, bool) {
|
||||
// amd64:-"LEAQ\ttype:string"
|
||||
return commaOkString(x)
|
||||
}
|
||||
|
||||
// -- Intermediate variable: comma-ok --
|
||||
|
||||
func commaOkViaVar[S string | []byte](x S) (string, bool) {
|
||||
iface := any(x)
|
||||
v, ok := iface.(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func CommaOkViaVarBytesInst(x []byte) (string, bool) {
|
||||
// amd64:-"type:string"
|
||||
return commaOkViaVar(x)
|
||||
}
|
||||
|
||||
// -- Intermediate variable: type switch --
|
||||
|
||||
func switchViaVar[S string | []byte](x S) string {
|
||||
iface := any(x)
|
||||
switch iface.(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SwitchViaVarStringInst(x string) string {
|
||||
// amd64:-"type:.*uint8"
|
||||
return switchViaVar(x)
|
||||
}
|
||||
|
||||
func SwitchViaVarBytesInst(x []byte) string {
|
||||
// amd64:-"type:string"
|
||||
return switchViaVar(x)
|
||||
}
|
||||
|
||||
// -- All cases eliminated for one instantiation --
|
||||
|
||||
func switchFallsToDefault[S string | []byte | int](x S) string {
|
||||
switch any(x).(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
}
|
||||
return "other"
|
||||
}
|
||||
|
||||
// int instantiation: both cases are impossible.
|
||||
func SwitchFallsToDefaultIntInst(x int) string {
|
||||
// amd64:-"type:string"
|
||||
// amd64:-"type:.*uint8"
|
||||
return switchFallsToDefault(x)
|
||||
}
|
||||
158
test/typeparam/shape_assert.go
Normal file
158
test/typeparam/shape_assert.go
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
// run
|
||||
|
||||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Test that type assertions and type switches in generic functions
|
||||
// produce correct results when the compiler eliminates impossible
|
||||
// cases based on shape type analysis.
|
||||
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
|
||||
func switchStringOrBytes[S string | []byte](x S) string {
|
||||
switch any(x).(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func switchThree[S string | []byte | int](x S) string {
|
||||
switch any(x).(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
case int:
|
||||
return "int"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type MyString string
|
||||
|
||||
func switchNamed[S string | MyString](x S) string {
|
||||
switch any(x).(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case MyString:
|
||||
return "MyString"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func commaOkString[S string | []byte](x S) (string, bool) {
|
||||
v, ok := any(x).(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func commaOkBytes[S string | []byte](x S) ([]byte, bool) {
|
||||
v, ok := any(x).([]byte)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func commaOkChain[S string | []byte | int](x S) string {
|
||||
if _, ok := any(x).(string); ok {
|
||||
return "string"
|
||||
}
|
||||
if _, ok := any(x).([]byte); ok {
|
||||
return "[]byte"
|
||||
}
|
||||
if _, ok := any(x).(int); ok {
|
||||
return "int"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Intermediate variable tests.
|
||||
func commaOkViaVar[S string | []byte](x S) (string, bool) {
|
||||
iface := any(x)
|
||||
v, ok := iface.(string)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func switchViaVar[S string | []byte](x S) string {
|
||||
iface := any(x)
|
||||
switch iface.(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// When no switch case matches the shape, the default is taken.
|
||||
func switchFallsToDefault[S string | []byte | int](x S) string {
|
||||
switch any(x).(type) {
|
||||
case string:
|
||||
return "string"
|
||||
case []byte:
|
||||
return "[]byte"
|
||||
}
|
||||
return "other"
|
||||
}
|
||||
|
||||
func main() {
|
||||
check("switchStringOrBytes string", switchStringOrBytes("hello"), "string")
|
||||
check("switchStringOrBytes []byte", switchStringOrBytes([]byte("hello")), "[]byte")
|
||||
|
||||
check("switchThree string", switchThree("x"), "string")
|
||||
check("switchThree []byte", switchThree([]byte("x")), "[]byte")
|
||||
check("switchThree int", switchThree(42), "int")
|
||||
|
||||
check("switchNamed string", switchNamed("hi"), "string")
|
||||
check("switchNamed MyString", switchNamed(MyString("hi")), "MyString")
|
||||
|
||||
v1, ok1 := commaOkString("hello")
|
||||
check("commaOkString[string] val", v1, "hello")
|
||||
checkBool("commaOkString[string] ok", ok1, true)
|
||||
v2, ok2 := commaOkString([]byte("hello"))
|
||||
check("commaOkString[[]byte] val", v2, "")
|
||||
checkBool("commaOkString[[]byte] ok", ok2, false)
|
||||
|
||||
v3, ok3 := commaOkBytes([]byte("world"))
|
||||
check("commaOkBytes[[]byte] val", string(v3), "world")
|
||||
checkBool("commaOkBytes[[]byte] ok", ok3, true)
|
||||
v4, ok4 := commaOkBytes("world")
|
||||
check("commaOkBytes[string] val", string(v4), "")
|
||||
checkBool("commaOkBytes[string] ok", ok4, false)
|
||||
|
||||
check("commaOkChain string", commaOkChain("x"), "string")
|
||||
check("commaOkChain []byte", commaOkChain([]byte("x")), "[]byte")
|
||||
check("commaOkChain int", commaOkChain(42), "int")
|
||||
|
||||
// Intermediate variable: comma-ok
|
||||
v5, ok5 := commaOkViaVar("hello")
|
||||
check("commaOkViaVar[string] val", v5, "hello")
|
||||
checkBool("commaOkViaVar[string] ok", ok5, true)
|
||||
v6, ok6 := commaOkViaVar([]byte("hello"))
|
||||
check("commaOkViaVar[[]byte] val", v6, "")
|
||||
checkBool("commaOkViaVar[[]byte] ok", ok6, false)
|
||||
|
||||
// Intermediate variable: type switch
|
||||
check("switchViaVar string", switchViaVar("x"), "string")
|
||||
check("switchViaVar []byte", switchViaVar([]byte("x")), "[]byte")
|
||||
|
||||
// All cases impossible: int instantiation hits default
|
||||
check("switchFallsToDefault string", switchFallsToDefault("x"), "string")
|
||||
check("switchFallsToDefault []byte", switchFallsToDefault([]byte("x")), "[]byte")
|
||||
check("switchFallsToDefault int", switchFallsToDefault(42), "other")
|
||||
}
|
||||
|
||||
func check(name, got, want string) {
|
||||
if got != want {
|
||||
panic(fmt.Sprintf("%s: got %q, want %q", name, got, want))
|
||||
}
|
||||
}
|
||||
|
||||
func checkBool(name string, got, want bool) {
|
||||
if got != want {
|
||||
panic(fmt.Sprintf("%s: got %v, want %v", name, got, want))
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue