mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
database/sql: ensure all driver interfaces are called under single lock
Russ pointed out in a previous CL golang.org/cl/65731 that not only was the locking incomplete, previous changes did not correctly lock driver calls in other sections. After inspecting driverConn, driverStmt, driverResult, Tx, and Rows structs where driver interfaces are stored, I discovered a few more places that failed to lock driver calls. The largest of these was the parameter type converter "driverArgs". driverArgs was typically called right before another call to the driver in a locked region, so I made the entire driverArgs expect a locked driver mutex and combined the region. This should not be a problem because the connection is pulled out of the connection pool either way so there shouldn't be contention. Fixes #21117 Change-Id: I88d46f74dca25fb11a30f0bf8e79785a73133d23 Reviewed-on: https://go-review.googlesource.com/71433 Run-TryBot: Daniel Theophanes <kardianos@gmail.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Russ Cox <rsc@golang.org>
This commit is contained in:
parent
986582126a
commit
1126d1483f
5 changed files with 55 additions and 39 deletions
|
|
@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
|
|||
}
|
||||
if ok {
|
||||
var nvdargs []driver.NamedValue
|
||||
nvdargs, err = driverArgs(dc.ci, nil, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resi driver.Result
|
||||
withLock(dc, func() {
|
||||
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
|
||||
})
|
||||
if err != driver.ErrSkip {
|
||||
|
|
@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu
|
|||
queryer, ok = dc.ci.(driver.Queryer)
|
||||
}
|
||||
if ok {
|
||||
nvdargs, err := driverArgs(dc.ci, nil, args)
|
||||
if err != nil {
|
||||
releaseConn(err)
|
||||
return nil, err
|
||||
}
|
||||
var nvdargs []driver.NamedValue
|
||||
var rowsi driver.Rows
|
||||
var err error
|
||||
withLock(dc, func() {
|
||||
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
|
||||
})
|
||||
if err != driver.ErrSkip {
|
||||
|
|
@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
|
|||
stmt.mu.Unlock()
|
||||
|
||||
if si == nil {
|
||||
cs, err := stmt.prepareOnConnLocked(ctx, dc)
|
||||
withLock(dc, func() {
|
||||
var ds *driverStmt
|
||||
ds, err = stmt.prepareOnConnLocked(ctx, dc)
|
||||
si = ds.si
|
||||
})
|
||||
if err != nil {
|
||||
return &Stmt{stickyErr: err}
|
||||
}
|
||||
si = cs.si
|
||||
}
|
||||
parentStmt = stmt
|
||||
}
|
||||
|
|
@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
|||
}
|
||||
|
||||
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
|
||||
dargs, err := driverArgs(ci, ds, args)
|
||||
ds.Lock()
|
||||
defer ds.Unlock()
|
||||
|
||||
dargs, err := driverArgsConnLocked(ci, ds, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ds.Lock()
|
||||
defer ds.Unlock()
|
||||
|
||||
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
|||
}
|
||||
|
||||
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
|
||||
var want int
|
||||
withLock(ds, func() {
|
||||
want = ds.si.NumInput()
|
||||
})
|
||||
ds.Lock()
|
||||
defer ds.Unlock()
|
||||
|
||||
want := ds.si.NumInput()
|
||||
|
||||
// -1 means the driver doesn't know how to count the number of
|
||||
// placeholders, so we won't sanity check input here and instead let the
|
||||
|
|
@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg
|
|||
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
|
||||
}
|
||||
|
||||
dargs, err := driverArgs(ci, ds, args)
|
||||
dargs, err := driverArgsConnLocked(ci, ds, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ds.Lock()
|
||||
defer ds.Unlock()
|
||||
|
||||
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
|
|||
if rs.closed {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// Lock the driver connection before calling the driver interface
|
||||
// rowsi to prevent a Tx from rolling back the connection at the same time.
|
||||
rs.dc.Lock()
|
||||
defer rs.dc.Unlock()
|
||||
|
||||
if rs.lastcols == nil {
|
||||
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
|
||||
}
|
||||
|
||||
rs.lasterr = rs.rowsi.Next(rs.lastcols)
|
||||
if rs.lasterr != nil {
|
||||
// Close the connection if there is a driver error.
|
||||
|
|
@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool {
|
|||
doClose = true
|
||||
return false
|
||||
}
|
||||
|
||||
// Lock the driver connection before calling the driver interface
|
||||
// rowsi to prevent a Tx from rolling back the connection at the same time.
|
||||
rs.dc.Lock()
|
||||
defer rs.dc.Unlock()
|
||||
|
||||
rs.lasterr = nextResultSet.NextResultSet()
|
||||
if rs.lasterr != nil {
|
||||
doClose = true
|
||||
|
|
@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) {
|
|||
if rs.rowsi == nil {
|
||||
return nil, errors.New("sql: no Rows available")
|
||||
}
|
||||
rs.dc.Lock()
|
||||
defer rs.dc.Unlock()
|
||||
|
||||
return rs.rowsi.Columns(), nil
|
||||
}
|
||||
|
||||
|
|
@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
|
|||
if rs.rowsi == nil {
|
||||
return nil, errors.New("sql: no Rows available")
|
||||
}
|
||||
return rowsColumnInfoSetup(rs.rowsi), nil
|
||||
rs.dc.Lock()
|
||||
defer rs.dc.Unlock()
|
||||
|
||||
return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
|
||||
}
|
||||
|
||||
// ColumnType contains the name and type of a column.
|
||||
|
|
@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
|
|||
return ci.databaseType
|
||||
}
|
||||
|
||||
func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
|
||||
func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
|
||||
names := rowsi.Columns()
|
||||
|
||||
list := make([]*ColumnType, len(names))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue