mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
cmd/compile: fix mishandling of unsafe-uintptr arguments in go/defer
Currently, the statement:
go g(uintptr(f()))
gets rewritten into:
tmp := f()
newproc(8, g, uintptr(tmp))
runtime.KeepAlive(tmp)
which doesn't guarantee that tmp is still alive by time the g call is
scheduled to run.
This CL fixes the issue, by wrapping g call in a closure:
go func(p unsafe.Pointer) {
g(uintptr(p))
}(f())
then this will be rewritten into:
tmp := f()
go func(p unsafe.Pointer) {
g(uintptr(p))
runtime.KeepAlive(p)
}(tmp)
runtime.KeepAlive(tmp) // superfluous, but harmless
So the unsafe.Pointer p will be kept alive at the time g call runs.
Updates #24491
Change-Id: Ic10821251cbb1b0073daec92b82a866c6ebaf567
Reviewed-on: https://go-review.googlesource.com/c/go/+/253457
Run-TryBot: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
1e6ad65b43
commit
bdb480fd62
4 changed files with 117 additions and 24 deletions
|
|
@ -502,6 +502,7 @@ func (o *Order) call(n *Node) {
|
||||||
x := o.copyExpr(arg.Left, arg.Left.Type, false)
|
x := o.copyExpr(arg.Left, arg.Left.Type, false)
|
||||||
x.Name.SetKeepalive(true)
|
x.Name.SetKeepalive(true)
|
||||||
arg.Left = x
|
arg.Left = x
|
||||||
|
n.SetNeedsWrapper(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -141,19 +141,20 @@ const (
|
||||||
nodeInitorder, _ // tracks state during init1; two bits
|
nodeInitorder, _ // tracks state during init1; two bits
|
||||||
_, _ // second nodeInitorder bit
|
_, _ // second nodeInitorder bit
|
||||||
_, nodeHasBreak
|
_, nodeHasBreak
|
||||||
_, nodeNoInline // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
|
_, nodeNoInline // used internally by inliner to indicate that a function call should not be inlined; set for OCALLFUNC and OCALLMETH only
|
||||||
_, nodeImplicit // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
|
_, nodeImplicit // implicit OADDR or ODEREF; ++/-- statement represented as OASOP; or ANDNOT lowered to OAND
|
||||||
_, nodeIsDDD // is the argument variadic
|
_, nodeIsDDD // is the argument variadic
|
||||||
_, nodeDiag // already printed error about this
|
_, nodeDiag // already printed error about this
|
||||||
_, nodeColas // OAS resulting from :=
|
_, nodeColas // OAS resulting from :=
|
||||||
_, nodeNonNil // guaranteed to be non-nil
|
_, nodeNonNil // guaranteed to be non-nil
|
||||||
_, nodeTransient // storage can be reused immediately after this statement
|
_, nodeTransient // storage can be reused immediately after this statement
|
||||||
_, nodeBounded // bounds check unnecessary
|
_, nodeBounded // bounds check unnecessary
|
||||||
_, nodeHasCall // expression contains a function call
|
_, nodeHasCall // expression contains a function call
|
||||||
_, nodeLikely // if statement condition likely
|
_, nodeLikely // if statement condition likely
|
||||||
_, nodeHasVal // node.E contains a Val
|
_, nodeHasVal // node.E contains a Val
|
||||||
_, nodeHasOpt // node.E contains an Opt
|
_, nodeHasOpt // node.E contains an Opt
|
||||||
_, nodeEmbedded // ODCLFIELD embedded type
|
_, nodeEmbedded // ODCLFIELD embedded type
|
||||||
|
_, nodeNeedsWrapper // OCALLxxx node that needs to be wrapped
|
||||||
)
|
)
|
||||||
|
|
||||||
func (n *Node) Class() Class { return Class(n.flags.get3(nodeClass)) }
|
func (n *Node) Class() Class { return Class(n.flags.get3(nodeClass)) }
|
||||||
|
|
@ -286,6 +287,20 @@ func (n *Node) SetIota(x int64) {
|
||||||
n.Xoffset = x
|
n.Xoffset = x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Node) NeedsWrapper() bool {
|
||||||
|
return n.flags&nodeNeedsWrapper != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNeedsWrapper indicates that OCALLxxx node needs to be wrapped by a closure.
|
||||||
|
func (n *Node) SetNeedsWrapper(b bool) {
|
||||||
|
switch n.Op {
|
||||||
|
case OCALLFUNC, OCALLMETH, OCALLINTER:
|
||||||
|
default:
|
||||||
|
Fatalf("Node.SetNeedsWrapper %v", n.Op)
|
||||||
|
}
|
||||||
|
n.flags.set(nodeNeedsWrapper, b)
|
||||||
|
}
|
||||||
|
|
||||||
// mayBeShared reports whether n may occur in multiple places in the AST.
|
// mayBeShared reports whether n may occur in multiple places in the AST.
|
||||||
// Extra care must be taken when mutating such a node.
|
// Extra care must be taken when mutating such a node.
|
||||||
func (n *Node) mayBeShared() bool {
|
func (n *Node) mayBeShared() bool {
|
||||||
|
|
|
||||||
|
|
@ -232,7 +232,11 @@ func walkstmt(n *Node) *Node {
|
||||||
n.Left = copyany(n.Left, &n.Ninit, true)
|
n.Left = copyany(n.Left, &n.Ninit, true)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
n.Left = walkexpr(n.Left, &n.Ninit)
|
if n.Left.NeedsWrapper() {
|
||||||
|
n.Left = wrapCall(n.Left, &n.Ninit)
|
||||||
|
} else {
|
||||||
|
n.Left = walkexpr(n.Left, &n.Ninit)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case OFOR, OFORUNTIL:
|
case OFOR, OFORUNTIL:
|
||||||
|
|
@ -3857,6 +3861,14 @@ func candiscard(n *Node) bool {
|
||||||
// builtin(a1, a2, a3)
|
// builtin(a1, a2, a3)
|
||||||
// }(x, y, z)
|
// }(x, y, z)
|
||||||
// for print, println, and delete.
|
// for print, println, and delete.
|
||||||
|
//
|
||||||
|
// Rewrite
|
||||||
|
// go f(x, y, uintptr(unsafe.Pointer(z)))
|
||||||
|
// into
|
||||||
|
// go func(a1, a2, a3) {
|
||||||
|
// builtin(a1, a2, uintptr(a3))
|
||||||
|
// }(x, y, unsafe.Pointer(z))
|
||||||
|
// for function contains unsafe-uintptr arguments.
|
||||||
|
|
||||||
var wrapCall_prgen int
|
var wrapCall_prgen int
|
||||||
|
|
||||||
|
|
@ -3868,9 +3880,17 @@ func wrapCall(n *Node, init *Nodes) *Node {
|
||||||
init.AppendNodes(&n.Ninit)
|
init.AppendNodes(&n.Ninit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isBuiltinCall := n.Op != OCALLFUNC && n.Op != OCALLMETH && n.Op != OCALLINTER
|
||||||
|
// origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion.
|
||||||
|
origArgs := make([]*Node, n.List.Len())
|
||||||
t := nod(OTFUNC, nil, nil)
|
t := nod(OTFUNC, nil, nil)
|
||||||
for i, arg := range n.List.Slice() {
|
for i, arg := range n.List.Slice() {
|
||||||
s := lookupN("a", i)
|
s := lookupN("a", i)
|
||||||
|
if !isBuiltinCall && arg.Op == OCONVNOP && arg.Type.Etype == TUINTPTR && arg.Left.Type.Etype == TUNSAFEPTR {
|
||||||
|
origArgs[i] = arg
|
||||||
|
arg = arg.Left
|
||||||
|
n.List.SetIndex(i, arg)
|
||||||
|
}
|
||||||
t.List.Append(symfield(s, arg.Type))
|
t.List.Append(symfield(s, arg.Type))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -3878,10 +3898,22 @@ func wrapCall(n *Node, init *Nodes) *Node {
|
||||||
sym := lookupN("wrap·", wrapCall_prgen)
|
sym := lookupN("wrap·", wrapCall_prgen)
|
||||||
fn := dclfunc(sym, t)
|
fn := dclfunc(sym, t)
|
||||||
|
|
||||||
a := nod(n.Op, nil, nil)
|
args := paramNnames(t.Type)
|
||||||
a.List.Set(paramNnames(t.Type))
|
for i, origArg := range origArgs {
|
||||||
a = typecheck(a, ctxStmt)
|
if origArg == nil {
|
||||||
fn.Nbody.Set1(a)
|
continue
|
||||||
|
}
|
||||||
|
arg := nod(origArg.Op, args[i], nil)
|
||||||
|
arg.Type = origArg.Type
|
||||||
|
args[i] = arg
|
||||||
|
}
|
||||||
|
call := nod(n.Op, nil, nil)
|
||||||
|
if !isBuiltinCall {
|
||||||
|
call.Op = OCALL
|
||||||
|
call.Left = n.Left
|
||||||
|
}
|
||||||
|
call.List.Set(args)
|
||||||
|
fn.Nbody.Set1(call)
|
||||||
|
|
||||||
funcbody()
|
funcbody()
|
||||||
|
|
||||||
|
|
@ -3889,12 +3921,12 @@ func wrapCall(n *Node, init *Nodes) *Node {
|
||||||
typecheckslice(fn.Nbody.Slice(), ctxStmt)
|
typecheckslice(fn.Nbody.Slice(), ctxStmt)
|
||||||
xtop = append(xtop, fn)
|
xtop = append(xtop, fn)
|
||||||
|
|
||||||
a = nod(OCALL, nil, nil)
|
call = nod(OCALL, nil, nil)
|
||||||
a.Left = fn.Func.Nname
|
call.Left = fn.Func.Nname
|
||||||
a.List.Set(n.List.Slice())
|
call.List.Set(n.List.Slice())
|
||||||
a = typecheck(a, ctxStmt)
|
call = typecheck(call, ctxStmt)
|
||||||
a = walkexpr(a, init)
|
call = walkexpr(call, init)
|
||||||
return a
|
return call
|
||||||
}
|
}
|
||||||
|
|
||||||
// substArgTypes substitutes the given list of types for
|
// substArgTypes substitutes the given list of types for
|
||||||
|
|
|
||||||
45
test/fixedbugs/issue24491.go
Normal file
45
test/fixedbugs/issue24491.go
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
// run
|
||||||
|
|
||||||
|
// Copyright 2020 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// This test makes sure unsafe-uintptr arguments are handled correctly.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var done = make(chan bool, 1)
|
||||||
|
|
||||||
|
func setup() unsafe.Pointer {
|
||||||
|
s := "ok"
|
||||||
|
runtime.SetFinalizer(&s, func(p *string) { *p = "FAIL" })
|
||||||
|
return unsafe.Pointer(&s)
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:noinline
|
||||||
|
//go:uintptrescapes
|
||||||
|
func test(s string, p uintptr) {
|
||||||
|
runtime.GC()
|
||||||
|
if *(*string)(unsafe.Pointer(p)) != "ok" {
|
||||||
|
panic(s + " return unexpected result")
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
test("normal", uintptr(setup()))
|
||||||
|
<-done
|
||||||
|
|
||||||
|
go test("go", uintptr(setup()))
|
||||||
|
<-done
|
||||||
|
|
||||||
|
func() {
|
||||||
|
defer test("defer", uintptr(setup()))
|
||||||
|
}()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue