diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index f7919f983ca..011df41fdcd 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1830,7 +1830,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { // // To use an existing prepared statement on this transaction, see Tx.Stmt. func (tx *Tx) Prepare(query string) (*Stmt, error) { - return tx.PrepareContext(context.Background(), query) + return tx.PrepareContext(tx.ctx, query) } // StmtContext returns a transaction-specific prepared statement from @@ -1928,7 +1928,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { // The returned statement operates within the transaction and will be closed // when the transaction has been committed or rolled back. func (tx *Tx) Stmt(stmt *Stmt) *Stmt { - return tx.StmtContext(context.Background(), stmt) + return tx.StmtContext(tx.ctx, stmt) } // ExecContext executes a query that doesn't return rows. @@ -1947,7 +1947,7 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{} // Exec executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - return tx.ExecContext(context.Background(), query, args...) + return tx.ExecContext(tx.ctx, query, args...) } // QueryContext executes a query that returns rows, typically a SELECT. @@ -1965,7 +1965,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{ // Query executes a query that returns rows, typically a SELECT. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - return tx.QueryContext(context.Background(), query, args...) + return tx.QueryContext(tx.ctx, query, args...) } // QueryRowContext executes a query that is expected to return at most one row. @@ -1980,7 +1980,7 @@ func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interfa // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - return tx.QueryRowContext(context.Background(), query, args...) + return tx.QueryRowContext(tx.ctx, query, args...) } // connStmt is a prepared statement on a particular connection. diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 8a477edf1ad..06877a60819 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -439,6 +439,35 @@ func TestTxContextWait(t *testing.T) { waitForFree(t, db, 5*time.Second, 0) } +// TestTxUsesContext tests the transaction behavior when the tx was created by context, +// but for query execution used methods without context +func TestTxUsesContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + // Guard against the context being canceled before BeginTx completes. + if err == context.DeadlineExceeded { + t.Skip("tx context canceled prior to first use") + } + t.Fatal(err) + } + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err = tx.Query("WAIT|1s|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + waitForFree(t, db, 5*time.Second, 0) +} + func TestMultiResultSetQuery(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)