database/sql: Close per-tx prepared statements when the associated tx ends

LGTM=bradfitz
R=golang-codereviews, bradfitz, mattn.jp
CC=golang-codereviews
https://golang.org/cl/131650043
This commit is contained in:
Marko Tiikkaja 2014-09-22 09:19:27 -04:00 committed by Brad Fitzpatrick
parent 93e5cc224e
commit 5f739d9dcd
2 changed files with 67 additions and 5 deletions

View file

@ -1043,6 +1043,13 @@ type Tx struct {
// or Rollback. once done, all operations fail with
// ErrTxDone.
done bool
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
stmts struct {
sync.Mutex
v []*Stmt
}
}
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
@ -1064,6 +1071,15 @@ func (tx *Tx) grabConn() (*driverConn, error) {
return tx.dc, nil
}
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
tx.stmts.Unlock()
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
if tx.done {
@ -1071,8 +1087,12 @@ func (tx *Tx) Commit() error {
}
defer tx.close()
tx.dc.Lock()
defer tx.dc.Unlock()
return tx.txi.Commit()
err := tx.txi.Commit()
tx.dc.Unlock()
if err != driver.ErrBadConn {
tx.closePrepared()
}
return err
}
// Rollback aborts the transaction.
@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error {
}
defer tx.close()
tx.dc.Lock()
defer tx.dc.Unlock()
return tx.txi.Rollback()
err := tx.txi.Rollback()
tx.dc.Unlock()
if err != driver.ErrBadConn {
tx.closePrepared()
}
return err
}
// Prepare creates a prepared statement for use within a transaction.
@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
},
query: query,
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
return stmt, nil
}
@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
dc.Lock()
si, err := dc.ci.Prepare(stmt.query)
dc.Unlock()
return &Stmt{
txs := &Stmt{
db: tx.db,
tx: tx,
txsi: &driverStmt{
@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
query: stmt.query,
stickyErr: err,
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx.stmts.Unlock()
return txs
}
// Exec executes a query that doesn't return rows.