database/sql: add support for multiple result sets

Many database systems allow returning multiple result sets
in a single query. This can be useful when dealing with many
intermediate results on the server and there is a need
to return more then one arity of data to the client.

Fixes #12382

Change-Id: I480a9ac6dadfc8743e0ba8b6d868ccf8442a9ca1
Reviewed-on: https://go-review.googlesource.com/30592
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Daniel Theophanes 2016-10-06 11:06:21 -07:00 committed by Brad Fitzpatrick
parent be48aa3f3a
commit 86b2f29676
5 changed files with 335 additions and 88 deletions

View file

@ -36,6 +36,8 @@ var _ = log.Printf
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
@ -109,6 +111,8 @@ type fakeStmt struct {
table string
panic string
next *fakeStmt // used for returning multiple results.
closed bool
colName []string // used by CREATE, INSERT, SELECT (selected columns)
@ -377,7 +381,7 @@ func errf(msg string, args ...interface{}) error {
// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
// (note that where columns must always contain ? marks,
// just a limitation for fakedb)
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 3 {
stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
@ -411,7 +415,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
// parts are table|col=type,col2=type2
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
@ -430,7 +434,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
}
// parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@ -492,38 +496,52 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
return nil, driver.ErrBadConn
}
parts := strings.Split(query, "|")
if len(parts) < 1 {
return nil, errf("empty query")
}
stmt := &fakeStmt{q: query, c: c}
if len(parts) >= 3 && parts[0] == "PANIC" {
stmt.panic = parts[1]
parts = parts[2:]
}
cmd := parts[0]
stmt.cmd = cmd
parts = parts[1:]
var firstStmt, prev *fakeStmt
for _, query := range strings.Split(query, ";") {
parts := strings.Split(query, "|")
if len(parts) < 1 {
return nil, errf("empty query")
}
stmt := &fakeStmt{q: query, c: c}
if firstStmt == nil {
firstStmt = stmt
}
if len(parts) >= 3 && parts[0] == "PANIC" {
stmt.panic = parts[1]
parts = parts[2:]
}
cmd := parts[0]
stmt.cmd = cmd
parts = parts[1:]
c.incrStat(&c.stmtsMade)
switch cmd {
case "WIPE":
// Nothing
case "SELECT":
return c.prepareSelect(stmt, parts)
case "CREATE":
return c.prepareCreate(stmt, parts)
case "INSERT":
return c.prepareInsert(stmt, parts)
case "NOSERT":
// Do all the prep-work like for an INSERT but don't actually insert the row.
// Used for some of the concurrent tests.
return c.prepareInsert(stmt, parts)
default:
stmt.Close()
return nil, errf("unsupported command type %q", cmd)
c.incrStat(&c.stmtsMade)
var err error
switch cmd {
case "WIPE":
// Nothing
case "SELECT":
stmt, err = c.prepareSelect(stmt, parts)
case "CREATE":
stmt, err = c.prepareCreate(stmt, parts)
case "INSERT":
stmt, err = c.prepareInsert(stmt, parts)
case "NOSERT":
// Do all the prep-work like for an INSERT but don't actually insert the row.
// Used for some of the concurrent tests.
stmt, err = c.prepareInsert(stmt, parts)
default:
stmt.Close()
return nil, errf("unsupported command type %q", cmd)
}
if err != nil {
return nil, err
}
if prev != nil {
prev.next = stmt
}
prev = stmt
}
return stmt, nil
return firstStmt, nil
}
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
@ -550,6 +568,9 @@ func (s *fakeStmt) Close() error {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
}
if s.next != nil {
s.next.Close()
}
return nil
}
@ -667,64 +688,80 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("error in pkg db; should only get here if size is correct")
}
db.mu.Lock()
t, ok := db.table(s.table)
db.mu.Unlock()
if !ok {
return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
setMRows := make([][]*row, 0, 1)
setColumns := make([][]string, 0, 1)
if s.table == "magicquery" {
if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
if args[0] == "sleep" {
time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
for {
db.mu.Lock()
t, ok := db.table(s.table)
db.mu.Unlock()
if !ok {
return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
if s.table == "magicquery" {
if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
if args[0] == "sleep" {
time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
}
}
}
}
t.mu.Lock()
defer t.mu.Unlock()
t.mu.Lock()
colIdx := make(map[string]int) // select column name -> column index in table
for _, name := range s.colName {
idx := t.columnIndex(name)
if idx == -1 {
return nil, fmt.Errorf("fakedb: unknown column name %q", name)
}
colIdx[name] = idx
}
mrows := []*row{}
rows:
for _, trow := range t.rows {
// Process the where clause, skipping non-match rows. This is lazy
// and just uses fmt.Sprintf("%v") to test equality. Good enough
// for test code.
for widx, wcol := range s.whereCol {
idx := t.columnIndex(wcol)
colIdx := make(map[string]int) // select column name -> column index in table
for _, name := range s.colName {
idx := t.columnIndex(name)
if idx == -1 {
return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
}
tcol := trow.cols[idx]
if bs, ok := tcol.([]byte); ok {
// lazy hack to avoid sprintf %v on a []byte
tcol = string(bs)
}
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
continue rows
t.mu.Unlock()
return nil, fmt.Errorf("fakedb: unknown column name %q", name)
}
colIdx[name] = idx
}
mrow := &row{cols: make([]interface{}, len(s.colName))}
for seli, name := range s.colName {
mrow.cols[seli] = trow.cols[colIdx[name]]
mrows := []*row{}
rows:
for _, trow := range t.rows {
// Process the where clause, skipping non-match rows. This is lazy
// and just uses fmt.Sprintf("%v") to test equality. Good enough
// for test code.
for widx, wcol := range s.whereCol {
idx := t.columnIndex(wcol)
if idx == -1 {
t.mu.Unlock()
return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
}
tcol := trow.cols[idx]
if bs, ok := tcol.([]byte); ok {
// lazy hack to avoid sprintf %v on a []byte
tcol = string(bs)
}
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
continue rows
}
}
mrow := &row{cols: make([]interface{}, len(s.colName))}
for seli, name := range s.colName {
mrow.cols[seli] = trow.cols[colIdx[name]]
}
mrows = append(mrows, mrow)
}
mrows = append(mrows, mrow)
t.mu.Unlock()
setMRows = append(setMRows, mrows)
setColumns = append(setColumns, s.colName)
if s.next == nil {
break
}
s = s.next
}
cursor := &rowsCursor{
pos: -1,
rows: mrows,
cols: s.colName,
posRow: -1,
rows: setMRows,
cols: setColumns,
errPos: -1,
}
return cursor, nil
@ -760,9 +797,10 @@ func (tx *fakeTx) Rollback() error {
}
type rowsCursor struct {
cols []string
pos int
rows []*row
cols [][]string
posSet int
posRow int
rows [][]*row
closed bool
// errPos and err are for making Next return early with error.
@ -786,7 +824,7 @@ func (rc *rowsCursor) Close() error {
}
func (rc *rowsCursor) Columns() []string {
return rc.cols
return rc.cols[rc.posSet]
}
var rowsCursorNextHook func(dest []driver.Value) error
@ -799,14 +837,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed {
return errors.New("fakedb: cursor is closed")
}
rc.pos++
if rc.pos == rc.errPos {
rc.posRow++
if rc.posRow == rc.errPos {
return rc.err
}
if rc.pos >= len(rc.rows) {
if rc.posRow >= len(rc.rows[rc.posSet]) {
return io.EOF // per interface spec
}
for i, v := range rc.rows[rc.pos].cols {
for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
// TODO(bradfitz): convert to subset types? naah, I
// think the subset types should only be input to
// driver, but the sql package should be able to handle
@ -831,6 +869,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
return nil
}
func (rc *rowsCursor) HasNextResultSet() bool {
return rc.posSet < len(rc.rows)-1
}
func (rc *rowsCursor) NextResultSet() error {
if rc.HasNextResultSet() {
rc.posSet++
rc.posRow = -1
return nil
}
return io.EOF // Per interface spec.
}
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//