database/sql: allow drivers to override Scan behavior

Implementing RowsColumnScanner allows the driver
to completely control how values are scanned.

Fixes #67546

Change-Id: Id8e7c3a973479c9665e4476fe2d29e1255aee687
GitHub-Last-Rev: ed0cacaec4
GitHub-Pull-Request: golang/go#67648
Reviewed-on: https://go-review.googlesource.com/c/go/+/588435
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Jack Christensen 2025-05-31 15:27:15 +00:00 committed by Sean Liao
parent 2b804abf07
commit 3dbef65bf3
5 changed files with 124 additions and 1 deletions

5
api/next/67546.txt Normal file
View file

@ -0,0 +1,5 @@
pkg database/sql/driver, type RowsColumnScanner interface { Close, Columns, Next, ScanColumn } #67546
pkg database/sql/driver, type RowsColumnScanner interface, Close() error #67546
pkg database/sql/driver, type RowsColumnScanner interface, Columns() []string #67546
pkg database/sql/driver, type RowsColumnScanner interface, Next([]Value) error #67546
pkg database/sql/driver, type RowsColumnScanner interface, ScanColumn(interface{}, int) error #67546

View file

@ -0,0 +1 @@
A database driver may implement [RowsColumnScanner] to entirely override `Scan` behavior.

View file

@ -515,6 +515,18 @@ type RowsColumnTypePrecisionScale interface {
ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
}
// RowsColumnScanner may be implemented by [Rows]. It allows the driver to completely
// take responsibility for how values are scanned and replace the normal [database/sql].
// scanning path. This allows drivers to directly support types that do not implement
// [database/sql.Scanner].
type RowsColumnScanner interface {
Rows
// ScanColumn copies the column in the current row into the value pointed at by
// dest. It returns [ErrSkip] to fall back to the normal [database/sql] scanning path.
ScanColumn(dest any, index int) error
}
// Tx is a transaction.
type Tx interface {
Commit() error

View file

@ -3396,7 +3396,16 @@ func (rs *Rows) scanLocked(dest ...any) error {
}
for i, sv := range rs.lastcols {
err := convertAssignRows(dest[i], sv, rs)
err := driver.ErrSkip
if rcs, ok := rs.rowsi.(driver.RowsColumnScanner); ok {
err = rcs.ScanColumn(dest[i], i)
}
if err == driver.ErrSkip {
err = convertAssignRows(dest[i], sv, rs)
}
if err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}

View file

@ -4201,6 +4201,102 @@ func TestNamedValueCheckerSkip(t *testing.T) {
}
}
type rcsDriver struct {
fakeDriver
}
func (d *rcsDriver) Open(dsn string) (driver.Conn, error) {
c, err := d.fakeDriver.Open(dsn)
fc := c.(*fakeConn)
fc.db.allowAny = true
return &rcsConn{fc}, err
}
type rcsConn struct {
*fakeConn
}
func (c *rcsConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
stmt, err := c.fakeConn.PrepareContext(ctx, q)
if err != nil {
return stmt, err
}
return &rcsStmt{stmt.(*fakeStmt)}, nil
}
type rcsStmt struct {
*fakeStmt
}
func (s *rcsStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
rows, err := s.fakeStmt.QueryContext(ctx, args)
if err != nil {
return rows, err
}
return &rcsRows{rows.(*rowsCursor)}, nil
}
type rcsRows struct {
*rowsCursor
}
func (r *rcsRows) ScanColumn(dest any, index int) error {
switch d := dest.(type) {
case *int64:
*d = 42
return nil
}
return driver.ErrSkip
}
func TestRowsColumnScanner(t *testing.T) {
Register("RowsColumnScanner", &rcsDriver{})
db, err := Open("RowsColumnScanner", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = db.ExecContext(ctx, "CREATE|t|str=string,n=int64")
if err != nil {
t.Fatal("exec create", err)
}
_, err = db.ExecContext(ctx, "INSERT|t|str=?,n=?", "foo", int64(1))
if err != nil {
t.Fatal("exec insert", err)
}
var (
str string
i64 int64
i int
f64 float64
ui uint
)
err = db.QueryRowContext(ctx, "SELECT|t|str,n,n,n,n|").Scan(&str, &i64, &i, &f64, &ui)
if err != nil {
t.Fatal("select", err)
}
list := []struct{ got, want any }{
{str, "foo"},
{i64, int64(42)},
{i, int(1)},
{f64, float64(1)},
{ui, uint(1)},
}
for index, item := range list {
if !reflect.DeepEqual(item.got, item.want) {
t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
}
}
}
func TestOpenConnector(t *testing.T) {
Register("testctx", &fakeDriverCtx{})
db, err := Open("testctx", "people")