cmd/compile: fix mishandling of unsafe-uintptr arguments in go/defer

Currently, the statement:

	go g(uintptr(f()))

gets rewritten into:

	tmp := f()
	newproc(8, g, uintptr(tmp))
	runtime.KeepAlive(tmp)

which doesn't guarantee that tmp is still alive by time the g call is
scheduled to run.

This CL fixes the issue, by wrapping g call in a closure:

	go func(p unsafe.Pointer) {
		g(uintptr(p))
	}(f())

then this will be rewritten into:

	tmp := f()
	go func(p unsafe.Pointer) {
		g(uintptr(p))
		runtime.KeepAlive(p)
	}(tmp)
	runtime.KeepAlive(tmp)  // superfluous, but harmless

So the unsafe.Pointer p will be kept alive at the time g call runs.

Updates #24491

Change-Id: Ic10821251cbb1b0073daec92b82a866c6ebaf567
Reviewed-on: https://go-review.googlesource.com/c/go/+/253457
Run-TryBot: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Cuong Manh Le 2020-09-08 15:28:43 +07:00
parent 1e6ad65b43
commit bdb480fd62
4 changed files with 117 additions and 24 deletions

View file

@ -502,6 +502,7 @@ func (o *Order) call(n *Node) {
x := o.copyExpr(arg.Left, arg.Left.Type, false)
x.Name.SetKeepalive(true)
arg.Left = x
n.SetNeedsWrapper(true)
}
}

View file

@ -154,6 +154,7 @@ const (
_, nodeHasVal // node.E contains a Val
_, nodeHasOpt // node.E contains an Opt
_, nodeEmbedded // ODCLFIELD embedded type
_, nodeNeedsWrapper // OCALLxxx node that needs to be wrapped
)
func (n *Node) Class() Class { return Class(n.flags.get3(nodeClass)) }
@ -286,6 +287,20 @@ func (n *Node) SetIota(x int64) {
n.Xoffset = x
}
func (n *Node) NeedsWrapper() bool {
return n.flags&nodeNeedsWrapper != 0
}
// SetNeedsWrapper indicates that OCALLxxx node needs to be wrapped by a closure.
func (n *Node) SetNeedsWrapper(b bool) {
switch n.Op {
case OCALLFUNC, OCALLMETH, OCALLINTER:
default:
Fatalf("Node.SetNeedsWrapper %v", n.Op)
}
n.flags.set(nodeNeedsWrapper, b)
}
// mayBeShared reports whether n may occur in multiple places in the AST.
// Extra care must be taken when mutating such a node.
func (n *Node) mayBeShared() bool {

View file

@ -232,8 +232,12 @@ func walkstmt(n *Node) *Node {
n.Left = copyany(n.Left, &n.Ninit, true)
default:
if n.Left.NeedsWrapper() {
n.Left = wrapCall(n.Left, &n.Ninit)
} else {
n.Left = walkexpr(n.Left, &n.Ninit)
}
}
case OFOR, OFORUNTIL:
if n.Left != nil {
@ -3857,6 +3861,14 @@ func candiscard(n *Node) bool {
// builtin(a1, a2, a3)
// }(x, y, z)
// for print, println, and delete.
//
// Rewrite
// go f(x, y, uintptr(unsafe.Pointer(z)))
// into
// go func(a1, a2, a3) {
// builtin(a1, a2, uintptr(a3))
// }(x, y, unsafe.Pointer(z))
// for function contains unsafe-uintptr arguments.
var wrapCall_prgen int
@ -3868,9 +3880,17 @@ func wrapCall(n *Node, init *Nodes) *Node {
init.AppendNodes(&n.Ninit)
}
isBuiltinCall := n.Op != OCALLFUNC && n.Op != OCALLMETH && n.Op != OCALLINTER
// origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion.
origArgs := make([]*Node, n.List.Len())
t := nod(OTFUNC, nil, nil)
for i, arg := range n.List.Slice() {
s := lookupN("a", i)
if !isBuiltinCall && arg.Op == OCONVNOP && arg.Type.Etype == TUINTPTR && arg.Left.Type.Etype == TUNSAFEPTR {
origArgs[i] = arg
arg = arg.Left
n.List.SetIndex(i, arg)
}
t.List.Append(symfield(s, arg.Type))
}
@ -3878,10 +3898,22 @@ func wrapCall(n *Node, init *Nodes) *Node {
sym := lookupN("wrap·", wrapCall_prgen)
fn := dclfunc(sym, t)
a := nod(n.Op, nil, nil)
a.List.Set(paramNnames(t.Type))
a = typecheck(a, ctxStmt)
fn.Nbody.Set1(a)
args := paramNnames(t.Type)
for i, origArg := range origArgs {
if origArg == nil {
continue
}
arg := nod(origArg.Op, args[i], nil)
arg.Type = origArg.Type
args[i] = arg
}
call := nod(n.Op, nil, nil)
if !isBuiltinCall {
call.Op = OCALL
call.Left = n.Left
}
call.List.Set(args)
fn.Nbody.Set1(call)
funcbody()
@ -3889,12 +3921,12 @@ func wrapCall(n *Node, init *Nodes) *Node {
typecheckslice(fn.Nbody.Slice(), ctxStmt)
xtop = append(xtop, fn)
a = nod(OCALL, nil, nil)
a.Left = fn.Func.Nname
a.List.Set(n.List.Slice())
a = typecheck(a, ctxStmt)
a = walkexpr(a, init)
return a
call = nod(OCALL, nil, nil)
call.Left = fn.Func.Nname
call.List.Set(n.List.Slice())
call = typecheck(call, ctxStmt)
call = walkexpr(call, init)
return call
}
// substArgTypes substitutes the given list of types for

View file

@ -0,0 +1,45 @@
// run
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This test makes sure unsafe-uintptr arguments are handled correctly.
package main
import (
"runtime"
"unsafe"
)
var done = make(chan bool, 1)
func setup() unsafe.Pointer {
s := "ok"
runtime.SetFinalizer(&s, func(p *string) { *p = "FAIL" })
return unsafe.Pointer(&s)
}
//go:noinline
//go:uintptrescapes
func test(s string, p uintptr) {
runtime.GC()
if *(*string)(unsafe.Pointer(p)) != "ok" {
panic(s + " return unexpected result")
}
done <- true
}
func main() {
test("normal", uintptr(setup()))
<-done
go test("go", uintptr(setup()))
<-done
func() {
defer test("defer", uintptr(setup()))
}()
<-done
}