mirror of
https://github.com/golang/go.git
synced 2025-10-19 19:13:18 +00:00
[release-branch.go1.25] database/sql: avoid closing Rows while scan is in progress
A database/sql/driver.Rows can return database-owned data from Rows.Next. The driver.Rows documentation doesn't explicitly document the lifetime guarantees for this data, but a reasonable expectation is that the caller of Next should only access it until the next call to Rows.Close or Rows.Next. Avoid violating that constraint when a query is cancelled while a call to database/sql.Rows.Scan (note the difference between the two different Rows types!) is in progress. We previously took care to avoid closing a driver.Rows while the user has access to driver-owned memory via a RawData, but we could still close a driver.Rows while a Scan call was in the process of reading previously-returned driver-owned data. Update the fake DB used in database/sql tests to invalidate returned data to help catch other places we might be incorrectly retaining it. Updates #74831 Fixes #74834 Change-Id: Ice45b5fad51b679c38e3e1d21ef39156b56d6037 Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/2540 Reviewed-by: Roland Shoemaker <bracewell@google.com> Reviewed-by: Neal Patel <nealpatel@google.com> Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/2600 Reviewed-on: https://go-review.googlesource.com/c/go/+/693559 Auto-Submit: Dmitri Shuralyov <dmitshur@google.com> TryBot-Bypass: Dmitri Shuralyov <dmitshur@golang.org> Reviewed-by: Mark Freeman <markfreeman@google.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
parent
ebee011a54
commit
6961c3775f
4 changed files with 92 additions and 51 deletions
|
@ -335,7 +335,6 @@ func convertAssignRows(dest, src any, rows *Rows) error {
|
||||||
if rows == nil {
|
if rows == nil {
|
||||||
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
|
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
|
||||||
}
|
}
|
||||||
rows.closemu.Lock()
|
|
||||||
*d = Rows{
|
*d = Rows{
|
||||||
dc: rows.dc,
|
dc: rows.dc,
|
||||||
releaseConn: func(error) {},
|
releaseConn: func(error) {},
|
||||||
|
@ -351,7 +350,6 @@ func convertAssignRows(dest, src any, rows *Rows) error {
|
||||||
parentCancel()
|
parentCancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rows.closemu.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package sql
|
package sql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -15,7 +16,6 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -91,8 +91,6 @@ func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
|
||||||
type fakeDB struct {
|
type fakeDB struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
useRawBytes atomic.Bool
|
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
tables map[string]*table
|
tables map[string]*table
|
||||||
badConn bool
|
badConn bool
|
||||||
|
@ -684,8 +682,6 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
|
||||||
switch cmd {
|
switch cmd {
|
||||||
case "WIPE":
|
case "WIPE":
|
||||||
// Nothing
|
// Nothing
|
||||||
case "USE_RAWBYTES":
|
|
||||||
c.db.useRawBytes.Store(true)
|
|
||||||
case "SELECT":
|
case "SELECT":
|
||||||
stmt, err = c.prepareSelect(stmt, parts)
|
stmt, err = c.prepareSelect(stmt, parts)
|
||||||
case "CREATE":
|
case "CREATE":
|
||||||
|
@ -789,9 +785,6 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
|
||||||
case "WIPE":
|
case "WIPE":
|
||||||
db.wipe()
|
db.wipe()
|
||||||
return driver.ResultNoRows, nil
|
return driver.ResultNoRows, nil
|
||||||
case "USE_RAWBYTES":
|
|
||||||
s.c.db.useRawBytes.Store(true)
|
|
||||||
return driver.ResultNoRows, nil
|
|
||||||
case "CREATE":
|
case "CREATE":
|
||||||
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
|
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1076,10 +1069,9 @@ type rowsCursor struct {
|
||||||
errPos int
|
errPos int
|
||||||
err error
|
err error
|
||||||
|
|
||||||
// a clone of slices to give out to clients, indexed by the
|
// Data returned to clients.
|
||||||
// original slice's first byte address. we clone them
|
// We clone and stash it here so it can be invalidated by Close and Next.
|
||||||
// just so we're able to corrupt them on close.
|
driverOwnedMemory [][]byte
|
||||||
bytesClone map[*byte][]byte
|
|
||||||
|
|
||||||
// Every operation writes to line to enable the race detector
|
// Every operation writes to line to enable the race detector
|
||||||
// check for data races.
|
// check for data races.
|
||||||
|
@ -1096,9 +1088,19 @@ func (rc *rowsCursor) touchMem() {
|
||||||
rc.line++
|
rc.line++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rc *rowsCursor) invalidateDriverOwnedMemory() {
|
||||||
|
for _, buf := range rc.driverOwnedMemory {
|
||||||
|
for i := range buf {
|
||||||
|
buf[i] = 'x'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rc.driverOwnedMemory = nil
|
||||||
|
}
|
||||||
|
|
||||||
func (rc *rowsCursor) Close() error {
|
func (rc *rowsCursor) Close() error {
|
||||||
rc.touchMem()
|
rc.touchMem()
|
||||||
rc.parentMem.touchMem()
|
rc.parentMem.touchMem()
|
||||||
|
rc.invalidateDriverOwnedMemory()
|
||||||
rc.closed = true
|
rc.closed = true
|
||||||
return rc.closeErr
|
return rc.closeErr
|
||||||
}
|
}
|
||||||
|
@ -1129,6 +1131,8 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
|
||||||
if rc.posRow >= len(rc.rows[rc.posSet]) {
|
if rc.posRow >= len(rc.rows[rc.posSet]) {
|
||||||
return io.EOF // per interface spec
|
return io.EOF // per interface spec
|
||||||
}
|
}
|
||||||
|
// Corrupt any previously returned bytes.
|
||||||
|
rc.invalidateDriverOwnedMemory()
|
||||||
for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
|
for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
|
||||||
// TODO(bradfitz): convert to subset types? naah, I
|
// TODO(bradfitz): convert to subset types? naah, I
|
||||||
// think the subset types should only be input to
|
// think the subset types should only be input to
|
||||||
|
@ -1136,20 +1140,13 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
|
||||||
// a wider range of types coming out of drivers. all
|
// a wider range of types coming out of drivers. all
|
||||||
// for ease of drivers, and to prevent drivers from
|
// for ease of drivers, and to prevent drivers from
|
||||||
// messing up conversions or doing them differently.
|
// messing up conversions or doing them differently.
|
||||||
|
if bs, ok := v.([]byte); ok {
|
||||||
|
// Clone []bytes and stash for later invalidation.
|
||||||
|
bs = bytes.Clone(bs)
|
||||||
|
rc.driverOwnedMemory = append(rc.driverOwnedMemory, bs)
|
||||||
|
v = bs
|
||||||
|
}
|
||||||
dest[i] = v
|
dest[i] = v
|
||||||
|
|
||||||
if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
|
|
||||||
if rc.bytesClone == nil {
|
|
||||||
rc.bytesClone = make(map[*byte][]byte)
|
|
||||||
}
|
|
||||||
clone, ok := rc.bytesClone[&bs[0]]
|
|
||||||
if !ok {
|
|
||||||
clone = make([]byte, len(bs))
|
|
||||||
copy(clone, bs)
|
|
||||||
rc.bytesClone[&bs[0]] = clone
|
|
||||||
}
|
|
||||||
dest[i] = clone
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -3368,38 +3368,36 @@ func (rs *Rows) Scan(dest ...any) error {
|
||||||
// without calling Next.
|
// without calling Next.
|
||||||
return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
|
return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
|
||||||
}
|
}
|
||||||
|
|
||||||
rs.closemu.RLock()
|
rs.closemu.RLock()
|
||||||
|
|
||||||
if rs.lasterr != nil && rs.lasterr != io.EOF {
|
|
||||||
rs.closemu.RUnlock()
|
|
||||||
return rs.lasterr
|
|
||||||
}
|
|
||||||
if rs.closed {
|
|
||||||
err := rs.lasterrOrErrLocked(errRowsClosed)
|
|
||||||
rs.closemu.RUnlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if scanArgsContainRawBytes(dest) {
|
|
||||||
rs.closemuScanHold = true
|
|
||||||
rs.raw = rs.raw[:0]
|
rs.raw = rs.raw[:0]
|
||||||
|
err := rs.scanLocked(dest...)
|
||||||
|
if err == nil && scanArgsContainRawBytes(dest) {
|
||||||
|
rs.closemuScanHold = true
|
||||||
} else {
|
} else {
|
||||||
rs.closemu.RUnlock()
|
rs.closemu.RUnlock()
|
||||||
}
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *Rows) scanLocked(dest ...any) error {
|
||||||
|
if rs.lasterr != nil && rs.lasterr != io.EOF {
|
||||||
|
return rs.lasterr
|
||||||
|
}
|
||||||
|
if rs.closed {
|
||||||
|
return rs.lasterrOrErrLocked(errRowsClosed)
|
||||||
|
}
|
||||||
|
|
||||||
if rs.lastcols == nil {
|
if rs.lastcols == nil {
|
||||||
rs.closemuRUnlockIfHeldByScan()
|
|
||||||
return errors.New("sql: Scan called without calling Next")
|
return errors.New("sql: Scan called without calling Next")
|
||||||
}
|
}
|
||||||
if len(dest) != len(rs.lastcols) {
|
if len(dest) != len(rs.lastcols) {
|
||||||
rs.closemuRUnlockIfHeldByScan()
|
|
||||||
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
|
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, sv := range rs.lastcols {
|
for i, sv := range rs.lastcols {
|
||||||
err := convertAssignRows(dest[i], sv, rs)
|
err := convertAssignRows(dest[i], sv, rs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rs.closemuRUnlockIfHeldByScan()
|
|
||||||
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
|
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package sql
|
package sql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -4434,10 +4435,6 @@ func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
|
||||||
db := newTestDB(t, "people")
|
db := newTestDB(t, "people")
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
||||||
if _, err := db.Exec("USE_RAWBYTES"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cancel used to call close asynchronously.
|
// cancel used to call close asynchronously.
|
||||||
// This test checks that it waits so as not to interfere with RawBytes.
|
// This test checks that it waits so as not to interfere with RawBytes.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -4529,6 +4526,61 @@ func TestContextCancelBetweenNextAndErr(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type testScanner struct {
|
||||||
|
scanf func(src any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts testScanner) Scan(src any) error { return ts.scanf(src) }
|
||||||
|
|
||||||
|
func TestContextCancelDuringScan(t *testing.T) {
|
||||||
|
db := newTestDB(t, "people")
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
scanStart := make(chan any)
|
||||||
|
scanEnd := make(chan error)
|
||||||
|
scanner := &testScanner{
|
||||||
|
scanf: func(src any) error {
|
||||||
|
scanStart <- src
|
||||||
|
return <-scanEnd
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a query, and pause it mid-scan.
|
||||||
|
want := []byte("Alice")
|
||||||
|
r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !r.Next() {
|
||||||
|
t.Fatalf("r.Next() = false, want true")
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
r.Scan(scanner)
|
||||||
|
}()
|
||||||
|
got := <-scanStart
|
||||||
|
defer close(scanEnd)
|
||||||
|
gotBytes, ok := got.([]byte)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("r.Scan returned %T, want []byte", got)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotBytes, want) {
|
||||||
|
t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel the query.
|
||||||
|
// Sleep to give it a chance to finish canceling.
|
||||||
|
cancel()
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Cancelling the query should not have changed the result.
|
||||||
|
if !bytes.Equal(gotBytes, want) {
|
||||||
|
t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNilErrorAfterClose(t *testing.T) {
|
func TestNilErrorAfterClose(t *testing.T) {
|
||||||
db := newTestDB(t, "people")
|
db := newTestDB(t, "people")
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
@ -4562,10 +4614,6 @@ func TestRawBytesReuse(t *testing.T) {
|
||||||
db := newTestDB(t, "people")
|
db := newTestDB(t, "people")
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
||||||
if _, err := db.Exec("USE_RAWBYTES"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var raw RawBytes
|
var raw RawBytes
|
||||||
|
|
||||||
// The RawBytes in this query aliases driver-owned memory.
|
// The RawBytes in this query aliases driver-owned memory.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue