database/sql: ensure all driver Stmt are closed once

Previously  driver.Stmt could could be closed multiple times in
edge cases that drivers may not test for initially. Make their
job easier by ensuring the driver is only closed a single time.

Fixes #16019

Change-Id: I1e4777ef70697a849602e6ef9da73054a8feb4cd
Reviewed-on: https://go-review.googlesource.com/33352
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Daniel Theophanes 2016-11-17 09:33:31 -08:00 committed by Brad Fitzpatrick
parent e0942b76c7
commit 90b8a0ca2d
2 changed files with 88 additions and 77 deletions

View file

@ -330,7 +330,7 @@ type driverConn struct {
ci driver.Conn ci driver.Conn
closed bool closed bool
finalClosed bool // ci.Close has been called finalClosed bool // ci.Close has been called
openStmt map[driver.Stmt]bool openStmt map[*driverStmt]bool
// guarded by db.mu // guarded by db.mu
inUse bool inUse bool
@ -342,10 +342,10 @@ func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err) dc.db.putConn(dc, err)
} }
func (dc *driverConn) removeOpenStmt(si driver.Stmt) { func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock() dc.Lock()
defer dc.Unlock() defer dc.Unlock()
delete(dc.openStmt, si) delete(dc.openStmt, ds)
} }
func (dc *driverConn) expired(timeout time.Duration) bool { func (dc *driverConn) expired(timeout time.Duration) bool {
@ -355,28 +355,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc()) return dc.createdAt.Add(timeout).Before(nowFunc())
} }
func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) { func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query) si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err == nil { if err != nil {
// Track each driverConn's open statements, so we can close them return nil, err
// before closing the conn.
//
// TODO(bradfitz): let drivers opt out of caring about
// stmt closes if the conn is about to close anyway? For now
// do the safe thing, in case stmts need to be closed.
//
// TODO(bradfitz): after Go 1.2, closing driver.Stmts
// should be moved to driverStmt, using unique
// *driverStmts everywhere (including from
// *Stmt.connStmt, instead of returning a
// driver.Stmt), using driverStmt as a pointer
// everywhere, and making it a finalCloser.
if dc.openStmt == nil {
dc.openStmt = make(map[driver.Stmt]bool)
}
dc.openStmt[si] = true
} }
return si, err
// Track each driverConn's open statements, so we can close them
// before closing the conn.
//
// Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
if dc.openStmt == nil {
dc.openStmt = make(map[*driverStmt]bool)
}
ds := &driverStmt{Locker: dc, si: si}
dc.openStmt[ds] = true
return ds, nil
} }
// the dc.db's Mutex is held. // the dc.db's Mutex is held.
@ -409,17 +404,24 @@ func (dc *driverConn) Close() error {
func (dc *driverConn) finalClose() error { func (dc *driverConn) finalClose() error {
var err error var err error
withLock(dc, func() {
defer func() { // In case si.Close panics.
dc.openStmt = nil
dc.finalClosed = true
err = dc.ci.Close()
dc.ci = nil
}()
for si := range dc.openStmt { // Each *driverStmt has a lock to the dc. Copy the list out of the dc
si.Close() // before calling close on each stmt.
var openStmt []*driverStmt
withLock(dc, func() {
openStmt = make([]*driverStmt, 0, len(dc.openStmt))
for ds := range dc.openStmt {
openStmt = append(openStmt, ds)
} }
dc.openStmt = nil
})
for _, ds := range openStmt {
ds.Close()
}
withLock(dc, func() {
dc.finalClosed = true
err = dc.ci.Close()
dc.ci = nil
}) })
dc.db.mu.Lock() dc.db.mu.Lock()
@ -437,12 +439,21 @@ func (dc *driverConn) finalClose() error {
type driverStmt struct { type driverStmt struct {
sync.Locker // the *driverConn sync.Locker // the *driverConn
si driver.Stmt si driver.Stmt
closed bool
closeErr error // return value of previous Close call
} }
// Close ensures dirver.Stmt is only closed once any always returns the same
// result.
func (ds *driverStmt) Close() error { func (ds *driverStmt) Close() error {
ds.Lock() ds.Lock()
defer ds.Unlock() defer ds.Unlock()
return ds.si.Close() if ds.closed {
return ds.closeErr
}
ds.closed = true
ds.closeErr = ds.si.Close()
return ds.closeErr
} }
// depSet is a finalCloser's outstanding dependencies // depSet is a finalCloser's outstanding dependencies
@ -933,21 +944,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
// putConnHook is a hook for testing. // putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn) var putConnHook func(*DB, *driverConn)
// noteUnusedDriverStatement notes that si is no longer used and should // noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is // be closed whenever possible (when c is next not in use), unless c is
// already closed. // already closed.
func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
if c.inUse { if c.inUse {
c.onPut = append(c.onPut, func() { c.onPut = append(c.onPut, func() {
si.Close() ds.Close()
}) })
} else { } else {
c.Lock() c.Lock()
defer c.Unlock() fc := c.finalClosed
if !c.finalClosed { c.Unlock()
si.Close() if !fc {
ds.Close()
} }
} }
} }
@ -1084,9 +1096,9 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
if err != nil { if err != nil {
return nil, err return nil, err
} }
var si driver.Stmt var ds *driverStmt
withLock(dc, func() { withLock(dc, func() {
si, err = dc.prepareLocked(ctx, query) ds, err = dc.prepareLocked(ctx, query)
}) })
if err != nil { if err != nil {
db.putConn(dc, err) db.putConn(dc, err)
@ -1095,7 +1107,7 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
stmt := &Stmt{ stmt := &Stmt{
db: db, db: db,
query: query, query: query,
css: []connStmt{{dc, si}}, css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed), lastNumClosed: atomic.LoadUint64(&db.numClosed),
} }
db.addDep(stmt, stmt) db.addDep(stmt, stmt)
@ -1160,8 +1172,9 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer withLock(dc, func() { si.Close() }) ds := &driverStmt{Locker: dc, si: si}
return resultFromStatement(ctx, driverStmt{dc, si}, args...) defer ds.Close()
return resultFromStatement(ctx, ds, args...)
} }
// QueryContext executes a query that returns rows, typically a SELECT. // QueryContext executes a query that returns rows, typically a SELECT.
@ -1236,12 +1249,10 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
return nil, err return nil, err
} }
ds := driverStmt{dc, si} ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, ds, args...) rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil { if err != nil {
withLock(dc, func() { ds.Close()
si.Close()
})
releaseConn(err) releaseConn(err)
return nil, err return nil, err
} }
@ -1252,7 +1263,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
dc: dc, dc: dc,
releaseConn: releaseConn, releaseConn: releaseConn,
rowsi: rowsi, rowsi: rowsi,
closeStmt: si, closeStmt: ds,
} }
rows.initContextClose(ctx) rows.initContextClose(ctx)
return rows, nil return rows, nil
@ -1476,7 +1487,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
stmt := &Stmt{ stmt := &Stmt{
db: tx.db, db: tx.db,
tx: tx, tx: tx,
txsi: &driverStmt{ txds: &driverStmt{
Locker: dc, Locker: dc,
si: si, si: si,
}, },
@ -1530,7 +1541,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
txs := &Stmt{ txs := &Stmt{
db: tx.db, db: tx.db,
tx: tx, tx: tx,
txsi: &driverStmt{ txds: &driverStmt{
Locker: dc, Locker: dc,
si: si, si: si,
}, },
@ -1591,9 +1602,10 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer withLock(dc, func() { si.Close() }) ds := &driverStmt{Locker: dc, si: si}
defer ds.Close()
return resultFromStatement(ctx, driverStmt{dc, si}, args...) return resultFromStatement(ctx, ds, args...)
} }
// Exec executes a query that doesn't return rows. // Exec executes a query that doesn't return rows.
@ -1635,7 +1647,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
// connStmt is a prepared statement on a particular connection. // connStmt is a prepared statement on a particular connection.
type connStmt struct { type connStmt struct {
dc *driverConn dc *driverConn
si driver.Stmt ds *driverStmt
} }
// Stmt is a prepared statement. // Stmt is a prepared statement.
@ -1650,7 +1662,7 @@ type Stmt struct {
// If in a transaction, else both nil: // If in a transaction, else both nil:
tx *Tx tx *Tx
txsi *driverStmt txds *driverStmt
mu sync.Mutex // protects the rest of the fields mu sync.Mutex // protects the rest of the fields
closed bool closed bool
@ -1674,7 +1686,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
var res Result var res Result
for i := 0; i < maxBadConnRetries; i++ { for i := 0; i < maxBadConnRetries; i++ {
dc, releaseConn, si, err := s.connStmt(ctx) _, releaseConn, ds, err := s.connStmt(ctx)
if err != nil { if err != nil {
if err == driver.ErrBadConn { if err == driver.ErrBadConn {
continue continue
@ -1682,7 +1694,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
return nil, err return nil, err
} }
res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...) res, err = resultFromStatement(ctx, ds, args...)
releaseConn(err) releaseConn(err)
if err != driver.ErrBadConn { if err != driver.ErrBadConn {
return res, err return res, err
@ -1697,13 +1709,13 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return s.ExecContext(context.Background(), args...) return s.ExecContext(context.Background(), args...)
} }
func driverNumInput(ds driverStmt) int { func driverNumInput(ds *driverStmt) int {
ds.Lock() ds.Lock()
defer ds.Unlock() // in case NumInput panics defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput() return ds.si.NumInput()
} }
func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) { func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds) want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
@ -1713,7 +1725,7 @@ func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
} }
dargs, err := driverArgs(&ds, args) dargs, err := driverArgs(ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1757,7 +1769,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the // connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a // statement, a function to call to release the connection, and a
// statement bound to that connection. // statement bound to that connection.
func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil { if err = s.stickyErr; err != nil {
return return
} }
@ -1777,7 +1789,7 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
return return
} }
releaseConn = func(error) {} releaseConn = func(error) {}
return ci, releaseConn, s.txsi.si, nil return ci, releaseConn, s.txds, nil
} }
s.removeClosedStmtLocked() s.removeClosedStmtLocked()
@ -1792,25 +1804,25 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
for _, v := range s.css { for _, v := range s.css {
if v.dc == dc { if v.dc == dc {
s.mu.Unlock() s.mu.Unlock()
return dc, dc.releaseConn, v.si, nil return dc, dc.releaseConn, v.ds, nil
} }
} }
s.mu.Unlock() s.mu.Unlock()
// No luck; we need to prepare the statement on this connection // No luck; we need to prepare the statement on this connection
withLock(dc, func() { withLock(dc, func() {
si, err = dc.prepareLocked(ctx, s.query) ds, err = dc.prepareLocked(ctx, s.query)
}) })
if err != nil { if err != nil {
s.db.putConn(dc, err) s.db.putConn(dc, err)
return nil, nil, nil, err return nil, nil, nil, err
} }
s.mu.Lock() s.mu.Lock()
cs := connStmt{dc, si} cs := connStmt{dc, ds}
s.css = append(s.css, cs) s.css = append(s.css, cs)
s.mu.Unlock() s.mu.Unlock()
return dc, dc.releaseConn, si, nil return dc, dc.releaseConn, ds, nil
} }
// QueryContext executes a prepared query statement with the given arguments // QueryContext executes a prepared query statement with the given arguments
@ -1821,7 +1833,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
var rowsi driver.Rows var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ { for i := 0; i < maxBadConnRetries; i++ {
dc, releaseConn, si, err := s.connStmt(ctx) dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil { if err != nil {
if err == driver.ErrBadConn { if err == driver.ErrBadConn {
continue continue
@ -1829,7 +1841,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
return nil, err return nil, err
} }
rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...) rowsi, err = rowsiFromStatement(ctx, ds, args...)
if err == nil { if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed // Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn. // with releaseConn.
@ -1861,7 +1873,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...) return s.QueryContext(context.Background(), args...)
} }
func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) { func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int var want int
withLock(ds, func() { withLock(ds, func() {
want = ds.si.NumInput() want = ds.si.NumInput()
@ -1874,7 +1886,7 @@ func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{})
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
} }
dargs, err := driverArgs(&ds, args) dargs, err := driverArgs(ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1937,12 +1949,11 @@ func (s *Stmt) Close() error {
return nil return nil
} }
s.closed = true s.closed = true
s.mu.Unlock()
if s.tx != nil { if s.tx != nil {
defer s.mu.Unlock() return s.txds.Close()
return s.txsi.Close()
} }
s.mu.Unlock()
return s.db.removeDep(s, s) return s.db.removeDep(s, s)
} }
@ -1952,8 +1963,8 @@ func (s *Stmt) finalClose() error {
defer s.mu.Unlock() defer s.mu.Unlock()
if s.css != nil { if s.css != nil {
for _, v := range s.css { for _, v := range s.css {
s.db.noteUnusedDriverStatement(v.dc, v.si) s.db.noteUnusedDriverStatement(v.dc, v.ds)
v.dc.removeOpenStmt(v.si) v.dc.removeOpenStmt(v.ds)
} }
s.css = nil s.css = nil
} }
@ -1985,7 +1996,7 @@ type Rows struct {
ctxClose chan struct{} // closed when Rows is closed, may be null. ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value lastcols []driver.Value
lasterr error // non-nil only if closed is true lasterr error // non-nil only if closed is true
closeStmt driver.Stmt // if non-nil, statement to Close on close closeStmt *driverStmt // if non-nil, statement to Close on close
} }
func (rs *Rows) initContextClose(ctx context.Context) { func (rs *Rows) initContextClose(ctx context.Context) {

View file

@ -672,7 +672,7 @@ func TestStatementClose(t *testing.T) {
msg string msg string
}{ }{
{&Stmt{stickyErr: want}, "stickyErr not propagated"}, {&Stmt{stickyErr: want}, "stickyErr not propagated"},
{&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
} }
for _, test := range tests { for _, test := range tests {
if err := test.stmt.Close(); err != want { if err := test.stmt.Close(); err != want {