database/sql: correct level of write to same var for race detector

Rather then write to the same variable per fakeConn, write to either
fakeConn or rowsCursor.

Fixes #20646

Change-Id: Ifc79f989bd1606b8e3ebecb1e7844cce3ad06e17
Reviewed-on: https://go-review.googlesource.com/45393
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Daniel Theophanes 2017-06-12 11:28:18 -07:00
parent 15b1e4fb94
commit 5c37397a47
2 changed files with 51 additions and 34 deletions

View file

@ -84,6 +84,11 @@ type row struct {
cols []interface{} // must be same size as its table colname + coltype cols []interface{} // must be same size as its table colname + coltype
} }
type memToucher interface {
// touchMem reads & writes some memory, to help find data races.
touchMem()
}
type fakeConn struct { type fakeConn struct {
db *fakeDB // where to return ourselves to db *fakeDB // where to return ourselves to
@ -104,6 +109,10 @@ type fakeConn struct {
stickyBad bool stickyBad bool
} }
func (c *fakeConn) touchMem() {
c.line++
}
func (c *fakeConn) incrStat(v *int) { func (c *fakeConn) incrStat(v *int) {
c.mu.Lock() c.mu.Lock()
*v++ *v++
@ -121,6 +130,7 @@ type boundCol struct {
} }
type fakeStmt struct { type fakeStmt struct {
memToucher
c *fakeConn c *fakeConn
q string // just for debugging q string // just for debugging
@ -303,7 +313,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
if c.currTx != nil { if c.currTx != nil {
return nil, errors.New("already in a transaction") return nil, errors.New("already in a transaction")
} }
c.line++ c.touchMem()
c.currTx = &fakeTx{c: c} c.currTx = &fakeTx{c: c}
return c.currTx, nil return c.currTx, nil
} }
@ -345,7 +355,7 @@ func (c *fakeConn) Close() (err error) {
drv.mu.Unlock() drv.mu.Unlock()
} }
}() }()
c.line++ c.touchMem()
if c.currTx != nil { if c.currTx != nil {
return errors.New("can't close fakeConn; in a Transaction") return errors.New("can't close fakeConn; in a Transaction")
} }
@ -533,14 +543,14 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
c.line++ c.touchMem()
var firstStmt, prev *fakeStmt var firstStmt, prev *fakeStmt
for _, query := range strings.Split(query, ";") { for _, query := range strings.Split(query, ";") {
parts := strings.Split(query, "|") parts := strings.Split(query, "|")
if len(parts) < 1 { if len(parts) < 1 {
return nil, errf("empty query") return nil, errf("empty query")
} }
stmt := &fakeStmt{q: query, c: c} stmt := &fakeStmt{q: query, c: c, memToucher: c}
if firstStmt == nil { if firstStmt == nil {
firstStmt = stmt firstStmt = stmt
} }
@ -622,7 +632,7 @@ func (s *fakeStmt) Close() error {
if s.c.db == nil { if s.c.db == nil {
panic("in fakeStmt.Close, conn's db is nil (already closed)") panic("in fakeStmt.Close, conn's db is nil (already closed)")
} }
s.c.line++ s.touchMem()
if !s.closed { if !s.closed {
s.c.incrStat(&s.c.stmtsClosed) s.c.incrStat(&s.c.stmtsClosed)
s.closed = true s.closed = true
@ -657,7 +667,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.c.line++ s.touchMem()
if s.wait > 0 { if s.wait > 0 {
time.Sleep(s.wait) time.Sleep(s.wait)
@ -770,7 +780,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
return nil, err return nil, err
} }
s.c.line++ s.touchMem()
db := s.c.db db := s.c.db
if len(args) != s.placeholders { if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct") panic("error in pkg db; should only get here if size is correct")
@ -866,7 +876,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
} }
cursor := &rowsCursor{ cursor := &rowsCursor{
c: s.c, parentMem: s.c,
posRow: -1, posRow: -1,
rows: setMRows, rows: setMRows,
cols: setColumns, cols: setColumns,
@ -891,7 +901,7 @@ func (tx *fakeTx) Commit() error {
if hookCommitBadConn != nil && hookCommitBadConn() { if hookCommitBadConn != nil && hookCommitBadConn() {
return driver.ErrBadConn return driver.ErrBadConn
} }
tx.c.line++ tx.c.touchMem()
return nil return nil
} }
@ -903,12 +913,12 @@ func (tx *fakeTx) Rollback() error {
if hookRollbackBadConn != nil && hookRollbackBadConn() { if hookRollbackBadConn != nil && hookRollbackBadConn() {
return driver.ErrBadConn return driver.ErrBadConn
} }
tx.c.line++ tx.c.touchMem()
return nil return nil
} }
type rowsCursor struct { type rowsCursor struct {
c *fakeConn parentMem memToucher
cols [][]string cols [][]string
colType [][]string colType [][]string
posSet int posSet int
@ -924,6 +934,16 @@ type rowsCursor struct {
// the original slice's first byte address. we clone them // the original slice's first byte address. we clone them
// just so we're able to corrupt them on close. // just so we're able to corrupt them on close.
bytesClone map[*byte][]byte bytesClone map[*byte][]byte
// Every operation writes to line to enable the race detector
// check for data races.
// This is separate from the fakeConn.line to allow for drivers that
// can start multiple queries on the same transaction at the same time.
line int64
}
func (rc *rowsCursor) touchMem() {
rc.line++
} }
func (rc *rowsCursor) Close() error { func (rc *rowsCursor) Close() error {
@ -932,7 +952,8 @@ func (rc *rowsCursor) Close() error {
bs[0] = 255 // first byte corrupted bs[0] = 255 // first byte corrupted
} }
} }
rc.c.line++ rc.touchMem()
rc.parentMem.touchMem()
rc.closed = true rc.closed = true
return nil return nil
} }
@ -955,7 +976,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed { if rc.closed {
return errors.New("fakedb: cursor is closed") return errors.New("fakedb: cursor is closed")
} }
rc.c.line++ rc.touchMem()
rc.posRow++ rc.posRow++
if rc.posRow == rc.errPos { if rc.posRow == rc.errPos {
return rc.err return rc.err
@ -989,12 +1010,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
} }
func (rc *rowsCursor) HasNextResultSet() bool { func (rc *rowsCursor) HasNextResultSet() bool {
rc.c.line++ rc.touchMem()
return rc.posSet < len(rc.rows)-1 return rc.posSet < len(rc.rows)-1
} }
func (rc *rowsCursor) NextResultSet() error { func (rc *rowsCursor) NextResultSet() error {
rc.c.line++ rc.touchMem()
if rc.HasNextResultSet() { if rc.HasNextResultSet() {
rc.posSet++ rc.posSet++
rc.posRow = -1 rc.posRow = -1

View file

@ -2995,9 +2995,8 @@ func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
new(concurrentStmtExecTest), new(concurrentStmtExecTest),
new(concurrentTxQueryTest), new(concurrentTxQueryTest),
new(concurrentTxExecTest), new(concurrentTxExecTest),
// golang.org/issue/20646 new(concurrentTxStmtQueryTest),
// new(concurrentTxStmtQueryTest), new(concurrentTxStmtExecTest),
// new(concurrentTxStmtExecTest),
} }
for _, ct := range c.tests { for _, ct := range c.tests {
ct.init(t, db) ct.init(t, db)
@ -3243,9 +3242,8 @@ func TestConcurrency(t *testing.T) {
{"StmtExec", new(concurrentStmtExecTest)}, {"StmtExec", new(concurrentStmtExecTest)},
{"TxQuery", new(concurrentTxQueryTest)}, {"TxQuery", new(concurrentTxQueryTest)},
{"TxExec", new(concurrentTxExecTest)}, {"TxExec", new(concurrentTxExecTest)},
// golang.org/issue/20646 {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
// {"TxStmtQuery", new(concurrentTxStmtQueryTest)}, {"TxStmtExec", new(concurrentTxStmtExecTest)},
// {"TxStmtExec", new(concurrentTxStmtExecTest)},
{"Random", new(concurrentRandomTest)}, {"Random", new(concurrentRandomTest)},
} }
for _, item := range list { for _, item := range list {
@ -3582,7 +3580,6 @@ func BenchmarkConcurrentTxExec(b *testing.B) {
} }
func BenchmarkConcurrentTxStmtQuery(b *testing.B) { func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
b.Skip("golang.org/issue/20646")
b.ReportAllocs() b.ReportAllocs()
ct := new(concurrentTxStmtQueryTest) ct := new(concurrentTxStmtQueryTest)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -3591,7 +3588,6 @@ func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
} }
func BenchmarkConcurrentTxStmtExec(b *testing.B) { func BenchmarkConcurrentTxStmtExec(b *testing.B) {
b.Skip("golang.org/issue/20646")
b.ReportAllocs() b.ReportAllocs()
ct := new(concurrentTxStmtExecTest) ct := new(concurrentTxStmtExecTest)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {