cmd/cgo: recognize unsafe.{StringData,SliceData}

A simple call to unsafe.StringData can't contain any pointers.

When looking for field references, a call to unsafe.StringData or
unsafe.SliceData can be treated as a type conversion.

In order to make unsafe.SliceData useful, recognize slice expressions
when calling C functions.

Fixes #59954

Change-Id: I08a3ace7882073284c1d46a5210582a2521b0b4e
Reviewed-on: https://go-review.googlesource.com/c/go/+/493556
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
Ian Lance Taylor 2023-05-08 12:45:42 -07:00 committed by Gopher Robot
parent 2f1e643229
commit c7aa48eced
2 changed files with 124 additions and 4 deletions

View file

@ -938,7 +938,7 @@ func (p *Package) rewriteCall(f *File, call *Call) (string, bool) {
// constants to the parameter type, to avoid a type mismatch.
ptype := p.rewriteUnsafe(param.Go)
if !p.needsPointerCheck(f, param.Go, args[i]) || param.BadPointer {
if !p.needsPointerCheck(f, param.Go, args[i]) || param.BadPointer || p.checkUnsafeStringData(args[i]) {
if ptype != param.Go {
needsUnsafe = true
}
@ -957,6 +957,11 @@ func (p *Package) rewriteCall(f *File, call *Call) (string, bool) {
continue
}
// Check for a[:].
if p.checkSlice(&sb, &sbCheck, arg, i) {
continue
}
fmt.Fprintf(&sb, "_cgo%d := %s; ", i, gofmtPos(arg, origArg.Pos()))
fmt.Fprintf(&sbCheck, "_cgoCheckPointer(_cgo%d, nil); ", i)
}
@ -1178,7 +1183,10 @@ func (p *Package) checkIndex(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) boo
x := arg
for {
c, ok := x.(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
if !ok || len(c.Args) != 1 {
break
}
if !p.isType(c.Fun) && !p.isUnsafeData(c.Fun, false) {
break
}
x = c.Args[0]
@ -1232,7 +1240,10 @@ func (p *Package) checkAddr(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool
px := &arg
for {
c, ok := (*px).(*ast.CallExpr)
if !ok || len(c.Args) != 1 || !p.isType(c.Fun) {
if !ok || len(c.Args) != 1 {
break
}
if !p.isType(c.Fun) && !p.isUnsafeData(c.Fun, false) {
break
}
px = &c.Args[0]
@ -1255,6 +1266,71 @@ func (p *Package) checkAddr(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool
return true
}
// checkSlice checks whether arg has the form x[i:j], possibly inside
// type conversions. If so, it writes
//
// _cgoSliceNN := x[i:j]
// _cgoNN := _cgoSliceNN // with type conversions, if any
//
// to sb, and writes
//
// _cgoCheckPointer(_cgoSliceNN, true)
//
// to sbCheck, and returns true. This tells _cgoCheckPointer to check
// just the contents of the slice being passed, not any other part
// of the memory allocation.
func (p *Package) checkSlice(sb, sbCheck *bytes.Buffer, arg ast.Expr, i int) bool {
// Strip type conversions.
px := &arg
for {
c, ok := (*px).(*ast.CallExpr)
if !ok || len(c.Args) != 1 {
break
}
if !p.isType(c.Fun) && !p.isUnsafeData(c.Fun, false) {
break
}
px = &c.Args[0]
}
if _, ok := (*px).(*ast.SliceExpr); !ok {
return false
}
fmt.Fprintf(sb, "_cgoSlice%d := %s; ", i, gofmtPos(*px, (*px).Pos()))
origX := *px
*px = ast.NewIdent(fmt.Sprintf("_cgoSlice%d", i))
fmt.Fprintf(sb, "_cgo%d := %s; ", i, gofmtPos(arg, arg.Pos()))
*px = origX
// Use 0 == 0 to do the right thing in the unlikely event
// that "true" is shadowed.
fmt.Fprintf(sbCheck, "_cgoCheckPointer(_cgoSlice%d, 0 == 0); ", i)
return true
}
// checkUnsafeStringData checks for a call to unsafe.StringData.
// The result of that call can't contain a pointer so there is
// no need to call _cgoCheckPointer.
func (p *Package) checkUnsafeStringData(arg ast.Expr) bool {
x := arg
for {
c, ok := x.(*ast.CallExpr)
if !ok || len(c.Args) != 1 {
break
}
if p.isUnsafeData(c.Fun, true) {
return true
}
if !p.isType(c.Fun) {
break
}
x = c.Args[0]
}
return false
}
// isType reports whether the expression is definitely a type.
// This is conservative--it returns false for an unknown identifier.
func (p *Package) isType(t ast.Expr) bool {
@ -1299,6 +1375,28 @@ func (p *Package) isType(t ast.Expr) bool {
return false
}
// isUnsafeData reports whether the expression is unsafe.StringData
// or unsafe.SliceData. We can ignore these when checking for pointers
// because they don't change whether or not their argument contains
// any Go pointers. If onlyStringData is true we only check for StringData.
func (p *Package) isUnsafeData(x ast.Expr, onlyStringData bool) bool {
st, ok := x.(*ast.SelectorExpr)
if !ok {
return false
}
id, ok := st.X.(*ast.Ident)
if !ok {
return false
}
if id.Name != "unsafe" {
return false
}
if !onlyStringData && st.Sel.Name == "SliceData" {
return true
}
return st.Sel.Name == "StringData"
}
// isVariable reports whether x is a variable, possibly with field references.
func (p *Package) isVariable(x ast.Expr) bool {
switch x := x.(type) {

View file

@ -444,6 +444,28 @@ var ptrTests = []ptrTest{
body: `s := &S40{p: new(int)}; C.f40((*C.struct_S40i)(&s.a))`,
fail: false,
},
{
// Test that we handle unsafe.StringData.
name: "stringdata",
c: `void f41(void* p) {}`,
imports: []string{"unsafe"},
body: `s := struct { a [4]byte; p *int }{p: new(int)}; str := unsafe.String(&s.a[0], 4); C.f41(unsafe.Pointer(unsafe.StringData(str)))`,
fail: false,
},
{
name: "slicedata",
c: `void f42(void* p) {}`,
imports: []string{"unsafe"},
body: `s := []*byte{nil, new(byte)}; C.f42(unsafe.Pointer(unsafe.SliceData(s)))`,
fail: true,
},
{
name: "slicedata2",
c: `void f43(void* p) {}`,
imports: []string{"unsafe"},
body: `s := struct { a [4]byte; p *int }{p: new(int)}; C.f43(unsafe.Pointer(unsafe.SliceData(s.a[:])))`,
fail: false,
},
}
func TestPointerChecks(t *testing.T) {
@ -497,7 +519,7 @@ func buildPtrTests(t *testing.T, gopath string, cgocheck2 bool) (exe string) {
if err := os.MkdirAll(src, 0777); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(src, "go.mod"), []byte("module ptrtest"), 0666); err != nil {
if err := os.WriteFile(filepath.Join(src, "go.mod"), []byte("module ptrtest\ngo 1.20"), 0666); err != nil {
t.Fatal(err)
}