database/sql: allow drivers to support custom arg types

Previously all arguments were passed through driver.IsValid.
This checked arguments against a few fundamental go types and
prevented others from being passed in as arguments.

The new interface driver.NamedValueChecker may be implemented
by both driver.Stmt and driver.Conn. This allows
this new interface to completely supersede the
driver.ColumnConverter interface as it can be used for
checking arguments known to a prepared statement and
arbitrary query arguments. The NamedValueChecker may be
skipped with driver.ErrSkip after all special cases are
exhausted to use the default argument converter.

In addition if driver.ErrRemoveArgument is returned
the argument will not be passed to the query at all,
useful for passing in driver specific per-query options.

Add a canonical Out argument wrapper to be passed
to OUTPUT parameters. This will unify checks that need to
be written in the NameValueChecker.

The statement number check is also moved to the argument
converter so the NamedValueChecker may remove arguments
passed to the query.

Fixes #13567
Fixes #18079
Updates #18417
Updates #17834
Updates #16235
Updates #13067
Updates #19797

Change-Id: I89088bd9cca4596a48bba37bfd20d987453ef237
Reviewed-on: https://go-review.googlesource.com/38533
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Daniel Theophanes 2017-03-23 13:17:59 -07:00
parent 9044cb04f2
commit a9bf3b2e19
6 changed files with 360 additions and 90 deletions

View file

@ -3191,6 +3191,131 @@ func TestConnectionLeak(t *testing.T) {
wg.Wait()
}
type nvcDriver struct {
fakeDriver
skipNamedValueCheck bool
}
func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
c, err := d.fakeDriver.Open(dsn)
fc := c.(*fakeConn)
fc.db.allowAny = true
return &nvcConn{fc, d.skipNamedValueCheck}, err
}
type nvcConn struct {
*fakeConn
skipNamedValueCheck bool
}
type decimal struct {
value int
}
type doNotInclude struct{}
var _ driver.NamedValueChecker = &nvcConn{}
func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
if c.skipNamedValueCheck {
return driver.ErrSkip
}
switch v := nv.Value.(type) {
default:
return driver.ErrSkip
case Out:
switch ov := v.Dest.(type) {
default:
return errors.New("unkown NameValueCheck OUTPUT type")
case *string:
*ov = "from-server"
nv.Value = "OUT:*string"
}
return nil
case decimal, []int64:
return nil
case doNotInclude:
return driver.ErrRemoveArgument
}
}
func TestNamedValueChecker(t *testing.T) {
Register("NamedValueCheck", &nvcDriver{})
db, err := Open("NamedValueCheck", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = db.ExecContext(ctx, "WIPE")
if err != nil {
t.Fatal("exec wipe", err)
}
_, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
if err != nil {
t.Fatal("exec create", err)
}
o1 := ""
_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimal{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
if err != nil {
t.Fatal("exec insert", err)
}
var (
str1 string
dec1 decimal
arr1 []int64
)
err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
if err != nil {
t.Fatal("select", err)
}
list := []struct{ got, want interface{} }{
{o1, "from-server"},
{dec1, decimal{123}},
{str1, "hello"},
{arr1, []int64{42, 128, 707}},
}
for index, item := range list {
if !reflect.DeepEqual(item.got, item.want) {
t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
}
}
}
func TestNamedValueCheckerSkip(t *testing.T) {
Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
db, err := Open("NamedValueCheckSkip", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = db.ExecContext(ctx, "WIPE")
if err != nil {
t.Fatal("exec wipe", err)
}
_, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
if err != nil {
t.Fatal("exec create", err)
}
_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimal{123}))
if err == nil {
t.Fatalf("expected error with bad argument, got %v", err)
}
}
// badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics.
type badConn struct{}