database/sql: add RowsColumnScanner, expose ConvertAssign

Add a new optional interface for driver.Rows which permits
the driver to scan directly into the user-provided []any.

Export the convertAssign function, to permit drivers
to fall back to the old assignment path.

Using the prior Rows interface, a driver provides values
with a function like:

  func (rows *Rows) Next(dest []Value) error {
    dest[0] = "some value"
    return nil
  }

Using the new RowsColumnScanner interface, the
equivalent is:

  func (rows *Rows) ScanColumn(scanCtx driver.ScanContext, index int, dest any) error {
    return sql.ConvertAssign(scanCtx, dest, "some value")
  }

Fixes #67546

Change-Id: I421f5639a12c78c76d377534b5a82f846a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/777961
Reviewed-by: Austin Clements <austin@google.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Damien Neil 2026-03-24 09:41:49 -07:00
parent 4dde0f6c36
commit 4a38094e42
10 changed files with 328 additions and 15 deletions

8
api/next/67546.txt Normal file
View file

@ -0,0 +1,8 @@
pkg database/sql, func ConvertAssign(driver.ScanContext, interface{}, driver.Value) error #67546
pkg database/sql/driver, type RowsColumnScanner interface { Close, Columns, Next, NextRow, ScanColumn } #67546
pkg database/sql/driver, type RowsColumnScanner interface, Close() error #67546
pkg database/sql/driver, type RowsColumnScanner interface, Columns() []string #67546
pkg database/sql/driver, type RowsColumnScanner interface, Next([]Value) error #67546
pkg database/sql/driver, type RowsColumnScanner interface, NextRow() error #67546
pkg database/sql/driver, type RowsColumnScanner interface, ScanColumn(ScanContext, int, interface{}) error #67546
pkg database/sql/driver, type ScanContext struct #67546

View file

@ -0,0 +1,2 @@
The new [ConvertAssign] function gives database drivers access
to the type conversions performed by [Rows.Scan].

View file

@ -0,0 +1,2 @@
Drivers may implement the new [RowsColumnScanner] interface
to scan directly into user-provided destinations.

View file

@ -9,6 +9,7 @@ package sql
import (
"bytes"
"database/sql/driver"
"database/sql/internal"
"errors"
"fmt"
"reflect"
@ -219,10 +220,25 @@ func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.
// See go.dev/issue/67401.
//
//go:linkname convertAssign
func convertAssign(dest, src any) error {
func convertAssign(dest any, src any) error {
return convertAssignRows(dest, src, nil)
}
// ConvertAssign copies the value in src to the value pointed at by dest.
// See the documentation on [Rows.Scan] for details on conversions.
// dest must be a pointer or must implement [Scanner].
//
// Implementations of [driver.RowsColumnScanner] should pass through
// their [driver.ScanContext] parameter.
// In other cases, pass driver.ScanContext{} as the context.
//
// ConvertAssign is intended for use by driver implementations.
// Most users should not need to use it directly.
func ConvertAssign(scanCtx driver.ScanContext, dest any, src driver.Value) error {
rows, _ := internal.ScanContextValue(internal.ScanContext(scanCtx)).(*Rows)
return convertAssignRows(dest, src, rows)
}
// convertAssignRows copies to dest the value in src, converting it if possible.
// An error is returned if the copy would result in loss of information.
// dest should be a pointer type. If rows is passed in, the rows will
@ -353,8 +369,8 @@ func convertAssignRows(dest, src any, rows *Rows) error {
}
// The driver is returning a cursor the client may iterate over.
case driver.Rows:
switch d := dest.(type) {
case *Rows:
d, ok := dest.(*Rows)
if ok {
if d == nil {
return errNilPtr
}
@ -387,7 +403,7 @@ func convertAssignRows(dest, src any, rows *Rows) error {
parentCancel := rows.cancel
rows.cancel = func() {
// When Rows.cancel is called, the closemu will be locked as well.
// So we can access rs.lasterr.
// So we can access rows.lasterr.
d.close(rows.lasterr)
if parentCancel != nil {
parentCancel()

View file

@ -624,3 +624,14 @@ func TestDecimal(t *testing.T) {
})
}
}
func TestConvertAssignNoContext(t *testing.T) {
const want = 42
var got int64
if err := ConvertAssign(driver.ScanContext{}, &got, want); err != nil {
t.Fatalf("ConvertAssign: %v", err)
}
if got != int64(want) {
t.Errorf("after ConvertAssign: got %v, want %v", got, want)
}
}

View file

@ -40,6 +40,7 @@ package driver
import (
"context"
"database/sql/internal"
"errors"
"reflect"
)
@ -443,6 +444,28 @@ type Rows interface {
Next(dest []Value) error
}
// ScanContext carries state related to the current query
// through a [RowsColumnScanner.ScanColumn] function to [database/sql.ConvertAssign].
type ScanContext internal.ScanContext
// RowsColumnScanner extends the [Rows] interface by providing a way for the driver
// to scan directly into the user-provided destination.
//
// RowsColumnScanner supersedes the [Rows.Next] method.
type RowsColumnScanner interface {
Rows
// NextRow advances to the next row of data.
// It should return io.EOF when there are no more rows.
NextRow() error
// ScanColumn copies the column at the given index in the current row
// into the value pointed to by dest.
//
// The driver may assign a driver.Value to dest using [database/sql.ConvertAssign].
ScanColumn(scanCtx ScanContext, index int, dest any) error
}
// RowsNextResultSet extends the [Rows] interface by providing a way to signal
// the driver to advance to the next result set.
type RowsNextResultSet interface {

View file

@ -79,6 +79,16 @@ func (c *fakeConn) getFakeConn() *fakeConn {
return c
}
func getRowsCursor(rs *Rows) *rowsCursor {
return rs.rowsi.(interface {
getRowsCursor() *rowsCursor
}).getRowsCursor()
}
func (rc *rowsCursor) getRowsCursor() *rowsCursor {
return rc
}
func (c *fakeConnector) Driver() driver.Driver {
return fdriver
}
@ -1242,6 +1252,8 @@ func converterForType(typ string) driver.ValueConverter {
return driver.NotNull{Converter: driver.DefaultParameterConverter}
case "nulluuid":
return driver.Null{Converter: driver.DefaultParameterConverter}
case "nulltable":
return driver.Null{Converter: driver.DefaultParameterConverter}
case "any":
return anyTypeConverter{}
}

View file

@ -0,0 +1,22 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package internal contains internal symbols shared between
// database/sql and database/sql/driver.
package internal
// ScanContext is database/sql/driver.ScanContext.
// We define it here so driver.ScanContext can be opaque to users but
// visible to database/sql.
type ScanContext struct {
v any
}
func NewScanContext(v any) ScanContext {
return ScanContext{v}
}
func ScanContextValue(c ScanContext) any {
return c.v
}

View file

@ -18,6 +18,7 @@ package sql
import (
"context"
"database/sql/driver"
"database/sql/internal"
"errors"
"fmt"
"io"
@ -2963,10 +2964,16 @@ type Rows struct {
// expected not to be called concurrently.
hitEOF bool
// nextCalled is set by the first call to Next.
nextCalled bool
// lastcols is only used in Scan, Next, and NextResultSet which are expected
// not to be called concurrently.
lastcols []driver.Value
// numCols is the number of columns, and is initialized by the first Next call.
numCols int
// raw is a buffer for RawBytes that persists between Scan calls.
// This is used when the driver returns a mismatched type that requires
// a cloning allocation. For example, if the driver returns a *string and
@ -3065,11 +3072,20 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
rs.dc.Lock()
defer rs.dc.Unlock()
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
if !rs.nextCalled {
rs.numCols = len(rs.rowsi.Columns())
rs.nextCalled = true
}
if rscan, ok := rs.rowsi.(driver.RowsColumnScanner); ok {
rs.lasterr = rscan.NextRow()
} else {
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, rs.numCols)
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
}
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
if rs.lasterr != io.EOF {
@ -3117,6 +3133,7 @@ func (rs *Rows) NextResultSet() bool {
return false
}
rs.nextCalled = false
rs.lastcols = nil
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
@ -3386,6 +3403,11 @@ func (rs *Rows) Scan(dest ...any) error {
return err
}
// rowsScanContext is used to pass a *Rows through ScanColumn into ConvertAssign.
type rowsScanContext struct {
rs *Rows
}
func (rs *Rows) scanLocked(dest ...any) error {
if rs.lasterr != nil && rs.lasterr != io.EOF {
return rs.lasterr
@ -3394,11 +3416,26 @@ func (rs *Rows) scanLocked(dest ...any) error {
return rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.lastcols == nil {
if !rs.nextCalled {
return errors.New("sql: Scan called without calling Next")
}
if len(dest) != len(rs.lastcols) {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
if len(dest) != rs.numCols {
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", rs.numCols, len(dest))
}
if rscan, ok := rs.rowsi.(driver.RowsColumnScanner); ok {
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
for i, d := range dest {
scanCtx := driver.ScanContext(internal.NewScanContext(rs))
if err := rscan.ScanColumn(scanCtx, i, d); err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
return nil
}
for i, sv := range rs.lastcols {

View file

@ -70,6 +70,16 @@ Test:
"Validator",
},
},
{
name: "scancols",
connector: &rowsColumnScannerConnector{name: fakeDBName},
features: []string{
"ConnBeginTx",
"NamedValue",
"Validator",
"ScanColumn",
},
},
} {
for _, req := range require {
if !slices.Contains(test.features, req) {
@ -1671,6 +1681,68 @@ func testCursorFake(t *testing.T, db *DB) {
}
}
func TestCursorDoubleRowsPointer(t *testing.T) {
testDatabase(t, testCursorDoubleRowsPointer)
}
func testCursorDoubleRowsPointer(t *testing.T, db *DB) {
exec(t, db, "CREATE|table1|col=string")
exec(t, db, "INSERT|table1|col=value")
exec(t, db, "CREATE|cursor|list=table")
exec(t, db, "INSERT|cursor|list=table1!col")
rows, err := db.QueryContext(t.Context(), `SELECT|cursor|list|`)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
if !rows.Next() {
t.Fatal("no rows")
}
var cursor *Rows
if err := rows.Scan(&cursor); err != nil {
t.Fatal(err)
}
defer cursor.Close()
if !cursor.Next() {
t.Fatal("no child rows")
}
var col string
if err := cursor.Scan(&col); err != nil {
t.Fatal(err)
}
if got, want := col, "value"; got != want {
t.Errorf("read col=%q, want %q", got, want)
}
}
func TestCursorNull(t *testing.T) {
testDatabase(t, testCursorNull)
}
func testCursorNull(t *testing.T, db *DB) {
exec(t, db, "CREATE|cursor|list=nulltable")
exec(t, db, "INSERT|cursor|list=?", nil)
rows, err := db.QueryContext(t.Context(), `SELECT|cursor|list|`)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
if !rows.Next() {
t.Fatal("no rows")
}
var cursor *Rows
if err := rows.Scan(&cursor); err != nil {
t.Fatal(err)
}
if cursor != nil {
t.Errorf("Scan returned cursor, expected nil")
}
}
// TestCursorCancel exercises calling Rows.Close at various places,
// including canceling a cursor (child Rows).
func TestCursorCancel(t *testing.T) {
@ -2949,7 +3021,7 @@ func testRowsImplicitClose(t *testing.T, db *DB) {
}
want, fail := 2, errors.New("fail")
r := rows.rowsi.(*rowsCursor)
r := getRowsCursor(rows)
r.errPos, r.err = want, fail
got := 0
@ -2982,10 +3054,7 @@ func testRowsCloseError(t *testing.T, db *DB) {
}
got := []row{}
rc, ok := rows.rowsi.(*rowsCursor)
if !ok {
t.Fatal("not using *rowsCursor")
}
rc := getRowsCursor(rows)
rc.closeErr = errors.New("rowsCursor: failed to close")
for rows.Next() {
@ -5594,3 +5663,114 @@ func TestNullTypeScanNil(t *testing.T) {
})
}
}
type testStringType struct {
s string
}
func TestQueryRowsScanner(t *testing.T) {
testDatabase(t, testQueryRowsScanner, requireFeature("ScanColumn"))
}
func testQueryRowsScanner(t *testing.T, db *DB) {
populate(t, db, "people")
rows, err := db.Query("SELECT|people|age,name|")
if err != nil {
t.Fatalf("Query: %v", err)
}
defer rows.Close()
type row struct {
age int
name testStringType
}
got := []row{}
for rows.Next() {
var r row
err = rows.Scan(&r.age, &r.name)
if err != nil {
t.Fatalf("Scan: %v", err)
}
got = append(got, r)
}
err = rows.Err()
if err != nil {
t.Fatalf("Err: %v", err)
}
want := []row{
{age: 1, name: testStringType{"Alice"}},
{age: 2, name: testStringType{"Bob"}},
{age: 3, name: testStringType{"Chris"}},
}
if !slices.Equal(got, want) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
}
}
type rowsColumnScannerConnector struct {
fakeConnector
}
func (c *rowsColumnScannerConnector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.fakeConnector.Connect(ctx)
fc := getFakeConn(conn)
return &rowsColumnScannerConn{fc}, err
}
// rowsColumnScannerConn is a Conn with rows that implement RowsColumnScanner.
type rowsColumnScannerConn struct {
*fakeConn
}
func (s *rowsColumnScannerConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
stmt, err := s.fakeConn.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &rowsColumnScannerStmt{stmt.(*fakeStmt)}, nil
}
type rowsColumnScannerStmt struct {
*fakeStmt
}
func (s *rowsColumnScannerStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
rows, err := s.fakeStmt.QueryContext(ctx, args)
if err != nil {
return nil, err
}
return &rowsColumnScannerRows{rowsCursor: rows.(*rowsCursor)}, nil
}
type rowsColumnScannerRows struct {
*rowsCursor
row []driver.Value
}
func (c *rowsColumnScannerRows) NextRow() error {
if c.row == nil {
c.row = make([]driver.Value, len(c.rowsCursor.Columns()))
}
return c.rowsCursor.Next(c.row)
}
func (c *rowsColumnScannerRows) NextResultSet() error {
c.row = nil
return c.rowsCursor.NextResultSet()
}
func (c *rowsColumnScannerRows) ScanColumn(ctx driver.ScanContext, index int, dest any) error {
if index < 0 || index >= len(c.row) {
return fmt.Errorf("index %v out of range", index)
}
switch d := dest.(type) {
case *testStringType:
switch s := c.row[index].(type) {
case string:
d.s = s
return nil
case []byte:
d.s = string(s)
return nil
}
}
return ConvertAssign(ctx, dest, c.row[index])
}