mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
database/sql: allow drivers to support custom arg types
Previously all arguments were passed through driver.IsValid. This checked arguments against a few fundamental go types and prevented others from being passed in as arguments. The new interface driver.NamedValueChecker may be implemented by both driver.Stmt and driver.Conn. This allows this new interface to completely supersede the driver.ColumnConverter interface as it can be used for checking arguments known to a prepared statement and arbitrary query arguments. The NamedValueChecker may be skipped with driver.ErrSkip after all special cases are exhausted to use the default argument converter. In addition if driver.ErrRemoveArgument is returned the argument will not be passed to the query at all, useful for passing in driver specific per-query options. Add a canonical Out argument wrapper to be passed to OUTPUT parameters. This will unify checks that need to be written in the NameValueChecker. The statement number check is also moved to the argument converter so the NamedValueChecker may remove arguments passed to the query. Fixes #13567 Fixes #18079 Updates #18417 Updates #17834 Updates #16235 Updates #13067 Updates #19797 Change-Id: I89088bd9cca4596a48bba37bfd20d987453ef237 Reviewed-on: https://go-review.googlesource.com/38533 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
9044cb04f2
commit
a9bf3b2e19
6 changed files with 360 additions and 90 deletions
|
|
@ -278,6 +278,27 @@ type Scanner interface {
|
|||
Scan(src interface{}) error
|
||||
}
|
||||
|
||||
// Out may be used to retrieve OUTPUT value parameters from stored procedures.
|
||||
//
|
||||
// Not all drivers and databases support OUTPUT value parameters.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// var outArg string
|
||||
// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg}))
|
||||
type Out struct {
|
||||
_Named_Fields_Required struct{}
|
||||
|
||||
// Dest is a pointer to the value that will be set to the result of the
|
||||
// stored procedure's OUTPUT parameter.
|
||||
Dest interface{}
|
||||
|
||||
// In is whether the parameter is an INOUT parameter. If so, the input value to the stored
|
||||
// procedure is the dereferenced value of Dest's pointer, which is then replaced with
|
||||
// the output value.
|
||||
In bool
|
||||
}
|
||||
|
||||
// ErrNoRows is returned by Scan when QueryRow doesn't return a
|
||||
// row. In such a case, QueryRow returns a placeholder *Row value that
|
||||
// defers this error until a Scan.
|
||||
|
|
@ -1206,7 +1227,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
|
|||
}()
|
||||
if execer, ok := dc.ci.(driver.Execer); ok {
|
||||
var dargs []driver.NamedValue
|
||||
dargs, err = driverArgs(nil, args)
|
||||
dargs, err = driverArgs(dc.ci, nil, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -1231,7 +1252,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
|
|||
}
|
||||
ds := &driverStmt{Locker: dc, si: si}
|
||||
defer ds.Close()
|
||||
return resultFromStatement(ctx, ds, args...)
|
||||
return resultFromStatement(ctx, dc.ci, ds, args...)
|
||||
}
|
||||
|
||||
// QueryContext executes a query that returns rows, typically a SELECT.
|
||||
|
|
@ -1270,7 +1291,7 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
|
|||
// The connection gets released by the releaseConn function.
|
||||
func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
|
||||
if queryer, ok := dc.ci.(driver.Queryer); ok {
|
||||
dargs, err := driverArgs(nil, args)
|
||||
dargs, err := driverArgs(dc.ci, nil, args)
|
||||
if err != nil {
|
||||
releaseConn(err)
|
||||
return nil, err
|
||||
|
|
@ -1307,7 +1328,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro
|
|||
}
|
||||
|
||||
ds := &driverStmt{Locker: dc, si: si}
|
||||
rowsi, err := rowsiFromStatement(ctx, ds, args...)
|
||||
rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
|
||||
if err != nil {
|
||||
ds.Close()
|
||||
releaseConn(err)
|
||||
|
|
@ -2009,7 +2030,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
|
|||
|
||||
var res Result
|
||||
for i := 0; i < maxBadConnRetries; i++ {
|
||||
_, releaseConn, ds, err := s.connStmt(ctx)
|
||||
dc, releaseConn, ds, err := s.connStmt(ctx)
|
||||
if err != nil {
|
||||
if err == driver.ErrBadConn {
|
||||
continue
|
||||
|
|
@ -2017,7 +2038,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
|
|||
return nil, err
|
||||
}
|
||||
|
||||
res, err = resultFromStatement(ctx, ds, args...)
|
||||
res, err = resultFromStatement(ctx, dc.ci, ds, args...)
|
||||
releaseConn(err)
|
||||
if err != driver.ErrBadConn {
|
||||
return res, err
|
||||
|
|
@ -2032,23 +2053,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
|
|||
return s.ExecContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
func driverNumInput(ds *driverStmt) int {
|
||||
ds.Lock()
|
||||
defer ds.Unlock() // in case NumInput panics
|
||||
return ds.si.NumInput()
|
||||
}
|
||||
|
||||
func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
|
||||
want := driverNumInput(ds)
|
||||
|
||||
// -1 means the driver doesn't know how to count the number of
|
||||
// placeholders, so we won't sanity check input here and instead let the
|
||||
// driver deal with errors.
|
||||
if want != -1 && len(args) != want {
|
||||
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
|
||||
}
|
||||
|
||||
dargs, err := driverArgs(ds, args)
|
||||
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
|
||||
dargs, err := driverArgs(ci, ds, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -2174,7 +2180,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
|
|||
return nil, err
|
||||
}
|
||||
|
||||
rowsi, err = rowsiFromStatement(ctx, ds, args...)
|
||||
rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
|
||||
if err == nil {
|
||||
// Note: ownership of ci passes to the *Rows, to be freed
|
||||
// with releaseConn.
|
||||
|
|
@ -2211,7 +2217,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
|||
return s.QueryContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
|
||||
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
|
||||
var want int
|
||||
withLock(ds, func() {
|
||||
want = ds.si.NumInput()
|
||||
|
|
@ -2224,7 +2230,7 @@ func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}
|
|||
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
|
||||
}
|
||||
|
||||
dargs, err := driverArgs(ds, args)
|
||||
dargs, err := driverArgs(ci, ds, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue