summaryrefslogtreecommitdiff
path: root/libgo/go/database
diff options
context:
space:
mode:
authorIan Lance Taylor <iant@golang.org>2018-01-09 01:23:08 +0000
committerIan Lance Taylor <ian@gcc.gnu.org>2018-01-09 01:23:08 +0000
commit1a2f01efa63036a5104f203a4789e682c0e0915d (patch)
tree373e15778dc8295354584e1f86915ae493b604ff /libgo/go/database
parent8799df67f2dab88f9fda11739c501780a85575e2 (diff)
downloadgcc-1a2f01efa63036a5104f203a4789e682c0e0915d.tar.gz
libgo: update to Go1.10beta1
Update the Go library to the 1.10beta1 release. Requires a few changes to the compiler for modifications to the map runtime code, and to handle some nowritebarrier cases in the runtime. Reviewed-on: https://go-review.googlesource.com/86455 gotools/: * Makefile.am (go_cmd_vet_files): New variable. (go_cmd_buildid_files, go_cmd_test2json_files): New variables. (s-zdefaultcc): Change from constants to functions. (noinst_PROGRAMS): Add vet, buildid, and test2json. (cgo$(EXEEXT)): Link against $(LIBGOTOOL). (vet$(EXEEXT)): New target. (buildid$(EXEEXT)): New target. (test2json$(EXEEXT)): New target. (install-exec-local): Install all $(noinst_PROGRAMS). (uninstall-local): Uninstasll all $(noinst_PROGRAMS). (check-go-tool): Depend on $(noinst_PROGRAMS). Copy down objabi.go. (check-runtime): Depend on $(noinst_PROGRAMS). (check-cgo-test, check-carchive-test): Likewise. (check-vet): New target. (check): Depend on check-vet. Look at cmd_vet-testlog. (.PHONY): Add check-vet. * Makefile.in: Rebuild. From-SVN: r256365
Diffstat (limited to 'libgo/go/database')
-rw-r--r--libgo/go/database/sql/convert.go32
-rw-r--r--libgo/go/database/sql/convert_test.go15
-rw-r--r--libgo/go/database/sql/ctxutil.go19
-rw-r--r--libgo/go/database/sql/driver/driver.go87
-rw-r--r--libgo/go/database/sql/fakedb_test.go73
-rw-r--r--libgo/go/database/sql/sql.go305
-rw-r--r--libgo/go/database/sql/sql_test.go248
7 files changed, 647 insertions, 132 deletions
diff --git a/libgo/go/database/sql/convert.go b/libgo/go/database/sql/convert.go
index 4983181fe75..b79ec3f7b27 100644
--- a/libgo/go/database/sql/convert.go
+++ b/libgo/go/database/sql/convert.go
@@ -12,7 +12,6 @@ import (
"fmt"
"reflect"
"strconv"
- "sync"
"time"
"unicode"
"unicode/utf8"
@@ -38,17 +37,10 @@ func validateNamedValueName(name string) error {
return fmt.Errorf("name %q does not begin with a letter", name)
}
-func driverNumInput(ds *driverStmt) int {
- ds.Lock()
- defer ds.Unlock() // in case NumInput panics
- return ds.si.NumInput()
-}
-
// ccChecker wraps the driver.ColumnConverter and allows it to be used
// as if it were a NamedValueChecker. If the driver ColumnConverter
// is not present then the NamedValueChecker will return driver.ErrSkip.
type ccChecker struct {
- sync.Locker
cci driver.ColumnConverter
want int
}
@@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
// same error.
var err error
arg := nv.Value
- c.Lock()
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
- c.Unlock()
if err != nil {
return err
}
@@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
// Stmt.Query into driver Values.
//
// The statement ds may be nil, if no statement is available.
-func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
+func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args))
// -1 means the driver doesn't know how to count the number of
@@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na
var cc ccChecker
if ds != nil {
si = ds.si
- want = driverNumInput(ds)
- cc.Locker = ds.Locker
+ want = ds.si.NumInput()
cc.want = want
}
@@ -204,7 +193,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na
}
}
- // Check the length of arguments after convertion to allow for omitted
+ // Check the length of arguments after conversion to allow for omitted
// arguments.
if want != -1 && len(nvargs) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
@@ -234,6 +223,12 @@ func convertAssign(dest, src interface{}) error {
}
*d = []byte(s)
return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = append((*d)[:0], s...)
+ return nil
}
case []byte:
switch d := dest.(type) {
@@ -264,6 +259,9 @@ func convertAssign(dest, src interface{}) error {
}
case time.Time:
switch d := dest.(type) {
+ case *time.Time:
+ *d = s
+ return nil
case *string:
*d = s.Format(time.RFC3339Nano)
return nil
@@ -273,6 +271,12 @@ func convertAssign(dest, src interface{}) error {
}
*d = []byte(s.Format(time.RFC3339Nano))
return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
+ return nil
}
case nil:
switch d := dest.(type) {
diff --git a/libgo/go/database/sql/convert_test.go b/libgo/go/database/sql/convert_test.go
index cfe52d7f548..47098c81ec1 100644
--- a/libgo/go/database/sql/convert_test.go
+++ b/libgo/go/database/sql/convert_test.go
@@ -106,6 +106,7 @@ var conversionTests = []conversionTest{
// To RawBytes
{s: nil, d: &scanraw, wantraw: nil},
{s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
+ {s: "string", d: &scanraw, wantraw: RawBytes("string")},
{s: 123, d: &scanraw, wantraw: RawBytes("123")},
{s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
{s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
@@ -114,6 +115,9 @@ var conversionTests = []conversionTest{
{s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
{s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
{s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
+ // time.Time has been placed here to check that the RawBytes slice gets
+ // correctly reset when calling time.Time.AppendFormat.
+ {s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},
// Strings to integers
{s: "255", d: &scanuint8, wantuint: 255},
@@ -222,6 +226,12 @@ func TestConversions(t *testing.T) {
if ct.wantstr != "" && ct.wantstr != scanstr {
errf("want string %q, got %q", ct.wantstr, scanstr)
}
+ if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
+ errf("want byte %q, got %q", ct.wantbytes, scanbytes)
+ }
+ if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
+ errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
+ }
if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
errf("want int %d, got %d", ct.wantint, intValue(ct.d))
}
@@ -341,6 +351,7 @@ func TestRawBytesAllocs(t *testing.T) {
{"float32", float32(1.5), "1.5"},
{"float64", float64(64), "64"},
{"bool", false, "false"},
+ {"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
}
buf := make(RawBytes, 10)
@@ -387,7 +398,7 @@ func TestRawBytesAllocs(t *testing.T) {
}
}
-// https://github.com/golang/go/issues/13905
+// https://golang.org/issues/13905
func TestUserDefinedBytes(t *testing.T) {
type userDefinedBytes []byte
var u userDefinedBytes
@@ -470,7 +481,7 @@ func TestDriverArgs(t *testing.T) {
}
for i, tt := range tests {
ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
- got, err := driverArgs(nil, ds, tt.args)
+ got, err := driverArgsConnLocked(nil, ds, tt.args)
if err != nil {
t.Errorf("test[%d]: %v", i, err)
continue
diff --git a/libgo/go/database/sql/ctxutil.go b/libgo/go/database/sql/ctxutil.go
index bd652b54625..af2afd5aa57 100644
--- a/libgo/go/database/sql/ctxutil.go
+++ b/libgo/go/database/sql/ctxutil.go
@@ -26,8 +26,8 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
return si, err
}
-func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
- if execerCtx, is := execer.(driver.ExecerContext); is {
+func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if execerCtx != nil {
return execerCtx.ExecContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
@@ -43,10 +43,9 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda
return execer.Exec(query, dargs)
}
-func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
- if queryerCtx, is := queryer.(driver.QueryerContext); is {
- ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
- return ret, err
+func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if queryerCtx != nil {
+ return queryerCtx.QueryContext(ctx, query, nvdargs)
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
@@ -107,10 +106,6 @@ func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (drive
return ciCtx.BeginTx(ctx, dopts)
}
- if ctx.Done() == context.Background().Done() {
- return ci.Begin()
- }
-
if opts != nil {
// Check the transaction level. If the transaction level is non-default
// then return an error here as the BeginTx driver value is not supported.
@@ -125,6 +120,10 @@ func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (drive
}
}
+ if ctx.Done() == nil {
+ return ci.Begin()
+ }
+
txi, err := ci.Begin()
if err == nil {
select {
diff --git a/libgo/go/database/sql/driver/driver.go b/libgo/go/database/sql/driver/driver.go
index 0262ca24ba2..83b2b3f535e 100644
--- a/libgo/go/database/sql/driver/driver.go
+++ b/libgo/go/database/sql/driver/driver.go
@@ -42,6 +42,10 @@ type NamedValue struct {
// Driver is the interface that must be implemented by a database
// driver.
+//
+// Database drivers may implement DriverContext for access
+// to contexts and to parse the name only once for a pool of connections,
+// instead of once per connection.
type Driver interface {
// Open returns a new connection to the database.
// The name is a string in a driver-specific format.
@@ -55,6 +59,47 @@ type Driver interface {
Open(name string) (Conn, error)
}
+// If a Driver implements DriverContext, then sql.DB will call
+// OpenConnector to obtain a Connector and then invoke
+// that Connector's Conn method to obtain each needed connection,
+// instead of invoking the Driver's Open method for each connection.
+// The two-step sequence allows drivers to parse the name just once
+// and also provides access to per-Conn contexts.
+type DriverContext interface {
+ // OpenConnector must parse the name in the same format that Driver.Open
+ // parses the name parameter.
+ OpenConnector(name string) (Connector, error)
+}
+
+// A Connector represents a driver in a fixed configuration
+// and can create any number of equivalent Conns for use
+// by multiple goroutines.
+//
+// A Connector can be passed to sql.OpenDB, to allow drivers
+// to implement their own sql.DB constructors, or returned by
+// DriverContext's OpenConnector method, to allow drivers
+// access to context and to avoid repeated parsing of driver
+// configuration.
+type Connector interface {
+ // Connect returns a connection to the database.
+ // Connect may return a cached connection (one previously
+ // closed), but doing so is unnecessary; the sql package
+ // maintains a pool of idle connections for efficient re-use.
+ //
+ // The provided context.Context is for dialing purposes only
+ // (see net.DialContext) and should not be stored or used for
+ // other purposes.
+ //
+ // The returned connection is only used by one goroutine at a
+ // time.
+ Connect(context.Context) (Conn, error)
+
+ // Driver returns the underlying Driver of the Connector,
+ // mainly to maintain compatibility with the Driver method
+ // on sql.DB.
+ Driver() Driver
+}
+
// ErrSkip may be returned by some optional interfaces' methods to
// indicate at runtime that the fast path is unavailable and the sql
// package should continue as if the optional interface was not
@@ -86,22 +131,23 @@ type Pinger interface {
// Execer is an optional interface that may be implemented by a Conn.
//
-// If a Conn does not implement Execer, the sql package's DB.Exec will
-// first prepare a query, execute the statement, and then close the
-// statement.
+// If a Conn implements neither ExecerContext nor Execer Execer,
+// the sql package's DB.Exec will first prepare a query, execute the statement,
+// and then close the statement.
//
// Exec may return ErrSkip.
//
-// Deprecated: Drivers should implement ExecerContext instead (or additionally).
+// Deprecated: Drivers should implement ExecerContext instead.
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
// ExecerContext is an optional interface that may be implemented by a Conn.
//
-// If a Conn does not implement ExecerContext, the sql package's DB.Exec will
-// first prepare a query, execute the statement, and then close the
-// statement.
+// If a Conn does not implement ExecerContext, the sql package's DB.Exec
+// will fall back to Execer; if the Conn does not implement Execer either,
+// DB.Exec will first prepare a query, execute the statement, and then
+// close the statement.
//
// ExecerContext may return ErrSkip.
//
@@ -112,22 +158,23 @@ type ExecerContext interface {
// Queryer is an optional interface that may be implemented by a Conn.
//
-// If a Conn does not implement Queryer, the sql package's DB.Query will
-// first prepare a query, execute the statement, and then close the
-// statement.
+// If a Conn implements neither QueryerContext nor Queryer,
+// the sql package's DB.Query will first prepare a query, execute the statement,
+// and then close the statement.
//
// Query may return ErrSkip.
//
-// Deprecated: Drivers should implement QueryerContext instead (or additionally).
+// Deprecated: Drivers should implement QueryerContext instead.
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
// QueryerContext is an optional interface that may be implemented by a Conn.
//
-// If a Conn does not implement QueryerContext, the sql package's DB.Query will
-// first prepare a query, execute the statement, and then close the
-// statement.
+// If a Conn does not implement QueryerContext, the sql package's DB.Query
+// will fall back to Queryer; if the Conn does not implement Queryer either,
+// DB.Query will first prepare a query, execute the statement, and then
+// close the statement.
//
// QueryerContext may return ErrSkip.
//
@@ -199,6 +246,18 @@ type ConnBeginTx interface {
BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
}
+// SessionResetter may be implemented by Conn to allow drivers to reset the
+// session state associated with the connection and to signal a bad connection.
+type SessionResetter interface {
+ // ResetSession is called while a connection is in the connection
+ // pool. No queries will run on this connection until this method returns.
+ //
+ // If the connection is bad this should return driver.ErrBadConn to prevent
+ // the connection from being returned to the connection pool. Any other
+ // error will be discarded.
+ ResetSession(ctx context.Context) error
+}
+
// Result is the result of a query execution.
type Result interface {
// LastInsertId returns the database's auto-generated ID
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go
index 4dcd096ca4d..e795412de01 100644
--- a/libgo/go/database/sql/fakedb_test.go
+++ b/libgo/go/database/sql/fakedb_test.go
@@ -55,6 +55,32 @@ type fakeDriver struct {
dbs map[string]*fakeDB
}
+type fakeConnector struct {
+ name string
+
+ waiter func(context.Context)
+}
+
+func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
+ conn, err := fdriver.Open(c.name)
+ conn.(*fakeConn).waiter = c.waiter
+ return conn, err
+}
+
+func (c *fakeConnector) Driver() driver.Driver {
+ return fdriver
+}
+
+type fakeDriverCtx struct {
+ fakeDriver
+}
+
+var _ driver.DriverContext = &fakeDriverCtx{}
+
+func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
+ return &fakeConnector{name: name}, nil
+}
+
type fakeDB struct {
name string
@@ -107,6 +133,16 @@ type fakeConn struct {
// bad connection tests; see isBad()
bad bool
stickyBad bool
+
+ skipDirtySession bool // tests that use Conn should set this to true.
+
+ // dirtySession tests ResetSession, true if a query has executed
+ // until ResetSession is called.
+ dirtySession bool
+
+ // The waiter is called before each query. May be used in place of the "WAIT"
+ // directive.
+ waiter func(context.Context)
}
func (c *fakeConn) touchMem() {
@@ -298,6 +334,9 @@ func (c *fakeConn) isBad() bool {
if c.stickyBad {
return true
} else if c.bad {
+ if c.db == nil {
+ return false
+ }
// alternate between bad conn and not bad conn
c.db.badConn = !c.db.badConn
return c.db.badConn
@@ -306,6 +345,21 @@ func (c *fakeConn) isBad() bool {
}
}
+func (c *fakeConn) isDirtyAndMark() bool {
+ if c.skipDirtySession {
+ return false
+ }
+ if c.currTx != nil {
+ c.dirtySession = true
+ return false
+ }
+ if c.dirtySession {
+ return true
+ }
+ c.dirtySession = true
+ return false
+}
+
func (c *fakeConn) Begin() (driver.Tx, error) {
if c.isBad() {
return nil, driver.ErrBadConn
@@ -337,6 +391,14 @@ func setStrictFakeConnClose(t *testing.T) {
testStrictClose = t
}
+func (c *fakeConn) ResetSession(ctx context.Context) error {
+ c.dirtySession = false
+ if c.isBad() {
+ return driver.ErrBadConn
+ }
+ return nil
+}
+
func (c *fakeConn) Close() (err error) {
drv := fdriver.(*fakeDriver)
defer func() {
@@ -572,6 +634,10 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
stmt.cmd = cmd
parts = parts[1:]
+ if c.waiter != nil {
+ c.waiter(ctx)
+ }
+
if stmt.wait > 0 {
wait := time.NewTimer(stmt.wait)
select {
@@ -662,6 +728,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
return nil, driver.ErrBadConn
}
+ if s.c.isDirtyAndMark() {
+ return nil, errors.New("session is dirty")
+ }
err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
@@ -774,6 +843,9 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
return nil, driver.ErrBadConn
}
+ if s.c.isDirtyAndMark() {
+ return nil, errors.New("session is dirty")
+ }
err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil {
@@ -943,6 +1015,7 @@ type rowsCursor struct {
}
func (rc *rowsCursor) touchMem() {
+ rc.parentMem.touchMem()
rc.line++
}
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go
index c609fe4cc43..9f4fa14534d 100644
--- a/libgo/go/database/sql/sql.go
+++ b/libgo/go/database/sql/sql.go
@@ -285,7 +285,7 @@ type Scanner interface {
// Example usage:
//
// var outArg string
-// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg}))
+// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg}))
type Out struct {
_Named_Fields_Required struct{}
@@ -317,8 +317,7 @@ var ErrNoRows = errors.New("sql: no rows in result set")
// connection is returned to DB's idle connection pool. The pool size
// can be controlled with SetMaxIdleConns.
type DB struct {
- driver driver.Driver
- dsn string
+ connector driver.Connector
// numClosed is an atomic counter which represents a total number of
// closed connections. Stmt.openStmt checks it before cleaning closed
// connections in Stmt.css.
@@ -335,6 +334,7 @@ type DB struct {
// It is closed during db.Close(). The close tells the connectionOpener
// goroutine to exit.
openerCh chan struct{}
+ resetterCh chan *driverConn
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
@@ -342,6 +342,8 @@ type DB struct {
maxOpen int // <= 0 means unlimited
maxLifetime time.Duration // maximum amount of time a connection may be reused
cleanerCh chan struct{}
+
+ stop func() // stop cancels the connection opener and the session resetter.
}
// connReuseStrategy determines how (*DB).conn returns database connections.
@@ -369,6 +371,7 @@ type driverConn struct {
closed bool
finalClosed bool // ci.Close has been called
openStmt map[*driverStmt]bool
+ lastErr error // lastError captures the result of the session resetter.
// guarded by db.mu
inUse bool
@@ -377,7 +380,7 @@ type driverConn struct {
}
func (dc *driverConn) releaseConn(err error) {
- dc.db.putConn(dc, err)
+ dc.db.putConn(dc, err, true)
}
func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
@@ -418,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que
return ds, nil
}
+// resetSession resets the connection session and sets the lastErr
+// that is checked before returning the connection to another query.
+//
+// resetSession assumes that the embedded mutex is locked when the connection
+// was returned to the pool. This unlocks the mutex.
+func (dc *driverConn) resetSession(ctx context.Context) {
+ defer dc.Unlock() // In case of panic.
+ if dc.closed { // Check if the database has been closed.
+ return
+ }
+ dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx)
+}
+
// the dc.db's Mutex is held.
func (dc *driverConn) closeDBLocked() func() error {
dc.Lock()
@@ -575,6 +591,52 @@ func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
// to block until the connectionOpener can satisfy the backlog of requests.
var connectionRequestQueueSize = 1000000
+type dsnConnector struct {
+ dsn string
+ driver driver.Driver
+}
+
+func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
+ return t.driver.Open(t.dsn)
+}
+
+func (t dsnConnector) Driver() driver.Driver {
+ return t.driver
+}
+
+// OpenDB opens a database using a Connector, allowing drivers to
+// bypass a string based data source name.
+//
+// Most users will open a database via a driver-specific connection
+// helper function that returns a *DB. No database drivers are included
+// in the Go standard library. See https://golang.org/s/sqldrivers for
+// a list of third-party drivers.
+//
+// OpenDB may just validate its arguments without creating a connection
+// to the database. To verify that the data source name is valid, call
+// Ping.
+//
+// The returned DB is safe for concurrent use by multiple goroutines
+// and maintains its own pool of idle connections. Thus, the OpenDB
+// function should be called just once. It is rarely necessary to
+// close a DB.
+func OpenDB(c driver.Connector) *DB {
+ ctx, cancel := context.WithCancel(context.Background())
+ db := &DB{
+ connector: c,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ resetterCh: make(chan *driverConn, 50),
+ lastPut: make(map[*driverConn]string),
+ connRequests: make(map[uint64]chan connRequest),
+ stop: cancel,
+ }
+
+ go db.connectionOpener(ctx)
+ go db.connectionResetter(ctx)
+
+ return db
+}
+
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a
// database name and connection information.
@@ -599,15 +661,16 @@ func Open(driverName, dataSourceName string) (*DB, error) {
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
- db := &DB{
- driver: driveri,
- dsn: dataSourceName,
- openerCh: make(chan struct{}, connectionRequestQueueSize),
- lastPut: make(map[*driverConn]string),
- connRequests: make(map[uint64]chan connRequest),
+
+ if driverCtx, ok := driveri.(driver.DriverContext); ok {
+ connector, err := driverCtx.OpenConnector(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ return OpenDB(connector), nil
}
- go db.connectionOpener()
- return db, nil
+
+ return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
@@ -659,7 +722,6 @@ func (db *DB) Close() error {
db.mu.Unlock()
return nil
}
- close(db.openerCh)
if db.cleanerCh != nil {
close(db.cleanerCh)
}
@@ -680,6 +742,7 @@ func (db *DB) Close() error {
err = err1
}
}
+ db.stop()
return err
}
@@ -867,18 +930,40 @@ func (db *DB) maybeOpenNewConnections() {
}
// Runs in a separate goroutine, opens new connections when requested.
-func (db *DB) connectionOpener() {
- for range db.openerCh {
- db.openNewConnection()
+func (db *DB) connectionOpener(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-db.openerCh:
+ db.openNewConnection(ctx)
+ }
+ }
+}
+
+// connectionResetter runs in a separate goroutine to reset connections async
+// to exported API.
+func (db *DB) connectionResetter(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ close(db.resetterCh)
+ for dc := range db.resetterCh {
+ dc.Unlock()
+ }
+ return
+ case dc := <-db.resetterCh:
+ dc.resetSession(ctx)
+ }
}
}
// Open one new connection
-func (db *DB) openNewConnection() {
+func (db *DB) openNewConnection(ctx context.Context) {
// maybeOpenNewConnctions has already executed db.numOpen++ before it sent
// on db.openerCh. This function must execute db.numOpen-- if the
// connection fails or is closed before returning.
- ci, err := db.driver.Open(db.dsn)
+ ci, err := db.connector.Connect(ctx)
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
@@ -953,6 +1038,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
conn.Close()
return nil, driver.ErrBadConn
}
+ // Lock around reading lastErr to ensure the session resetter finished.
+ conn.Lock()
+ err := conn.lastErr
+ conn.Unlock()
+ if err == driver.ErrBadConn {
+ conn.Close()
+ return nil, driver.ErrBadConn
+ }
return conn, nil
}
@@ -978,7 +1071,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
default:
case ret, ok := <-req:
if ok {
- db.putConn(ret.conn, ret.err)
+ db.putConn(ret.conn, ret.err, false)
}
}
return nil, ctx.Err()
@@ -990,13 +1083,24 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
ret.conn.Close()
return nil, driver.ErrBadConn
}
+ if ret.conn == nil {
+ return nil, ret.err
+ }
+ // Lock around reading lastErr to ensure the session resetter finished.
+ ret.conn.Lock()
+ err := ret.conn.lastErr
+ ret.conn.Unlock()
+ if err == driver.ErrBadConn {
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
return ret.conn, ret.err
}
}
db.numOpen++ // optimistically
db.mu.Unlock()
- ci, err := db.driver.Open(db.dsn)
+ ci, err := db.connector.Connect(ctx)
if err != nil {
db.mu.Lock()
db.numOpen-- // correct for earlier optimism
@@ -1045,7 +1149,7 @@ const debugGetPut = false
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
-func (db *DB) putConn(dc *driverConn, err error) {
+func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
db.mu.Lock()
if !dc.inUse {
if debugGetPut {
@@ -1076,11 +1180,40 @@ func (db *DB) putConn(dc *driverConn, err error) {
if putConnHook != nil {
putConnHook(db, dc)
}
+ if db.closed {
+ // Connections do not need to be reset if they will be closed.
+ // Prevents writing to resetterCh after the DB has closed.
+ resetSession = false
+ }
+ if resetSession {
+ if _, resetSession = dc.ci.(driver.SessionResetter); resetSession {
+ // Lock the driverConn here so it isn't released until
+ // the connection is reset.
+ // The lock must be taken before the connection is put into
+ // the pool to prevent it from being taken out before it is reset.
+ dc.Lock()
+ }
+ }
added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
if !added {
+ if resetSession {
+ dc.Unlock()
+ }
dc.Close()
+ return
+ }
+ if !resetSession {
+ return
+ }
+ select {
+ default:
+ // If the resetterCh is blocking then mark the connection
+ // as bad and continue on.
+ dc.lastErr = driver.ErrBadConn
+ dc.Unlock()
+ case db.resetterCh <- dc:
}
}
@@ -1242,15 +1375,20 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
defer func() {
release(err)
}()
- if execer, ok := dc.ci.(driver.Execer); ok {
- var dargs []driver.NamedValue
- dargs, err = driverArgs(dc.ci, nil, args)
- if err != nil {
- return nil, err
- }
+ execerCtx, ok := dc.ci.(driver.ExecerContext)
+ var execer driver.Execer
+ if !ok {
+ execer, ok = dc.ci.(driver.Execer)
+ }
+ if ok {
+ var nvdargs []driver.NamedValue
var resi driver.Result
withLock(dc, func() {
- resi, err = ctxDriverExec(ctx, execer, query, dargs)
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
+ resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
@@ -1309,15 +1447,21 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
- if queryer, ok := dc.ci.(driver.Queryer); ok {
- dargs, err := driverArgs(dc.ci, nil, args)
- if err != nil {
- releaseConn(err)
- return nil, err
- }
+ queryerCtx, ok := dc.ci.(driver.QueryerContext)
+ var queryer driver.Queryer
+ if !ok {
+ queryer, ok = dc.ci.(driver.Queryer)
+ }
+ if ok {
+ var nvdargs []driver.NamedValue
var rowsi driver.Rows
+ var err error
withLock(dc, func() {
- rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
+ rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
@@ -1454,11 +1598,11 @@ func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error),
// Driver returns the database's underlying driver.
func (db *DB) Driver() driver.Driver {
- return db.driver
+ return db.connector.Driver()
}
// ErrConnDone is returned by any operation that is performed on a connection
-// that has already been committed or rolled back.
+// that has already been returned to the connection pool.
var ErrConnDone = errors.New("database/sql: connection is already closed")
// Conn returns a single connection by either opening a new connection
@@ -1493,9 +1637,9 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) {
type releaseConn func(error)
-// Conn represents a single database session rather a pool of database
-// sessions. Prefer running queries from DB unless there is a specific
-// need for a continuous single database session.
+// Conn represents a single database connection rather than a pool of database
+// connections. Prefer running queries from DB unless there is a specific
+// need for a continuous single database connection.
//
// A Conn must call Close to return the connection to the database pool
// and may do so concurrently with a running query.
@@ -1769,14 +1913,20 @@ func (tx *Tx) closePrepared() {
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
- return ErrTxDone
- }
+ // Check context first to avoid transaction leak.
+ // If put it behind tx.done CompareAndSwap statement, we cant't ensure
+ // the consistency between tx.done and the real COMMIT operation.
select {
default:
case <-tx.ctx.Done():
+ if atomic.LoadInt32(&tx.done) == 1 {
+ return ErrTxDone
+ }
return tx.ctx.Err()
}
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
+ return ErrTxDone
+ }
var err error
withLock(tx.dc, func() {
err = tx.txi.Commit()
@@ -1859,6 +2009,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// ...
// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
@@ -1902,11 +2055,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.mu.Unlock()
if si == nil {
- cs, err := stmt.prepareOnConnLocked(ctx, dc)
+ withLock(dc, func() {
+ var ds *driverStmt
+ ds, err = stmt.prepareOnConnLocked(ctx, dc)
+ si = ds.si
+ })
if err != nil {
return &Stmt{stickyErr: err}
}
- si = cs.si
}
parentStmt = stmt
}
@@ -2098,13 +2254,20 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
}
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
- dargs, err := driverArgs(ci, ds, args)
+ ds.Lock()
+ defer ds.Unlock()
+
+ dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
- ds.Lock()
- defer ds.Unlock()
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ if want := ds.si.NumInput(); want >= 0 && want != len(dargs) {
+ return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs))
+ }
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
@@ -2269,25 +2432,20 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
}
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
- var want int
- withLock(ds, func() {
- want = ds.si.NumInput()
- })
-
- // -1 means the driver doesn't know how to count the number of
- // placeholders, so we won't sanity check input here and instead let the
- // driver deal with errors.
- if want != -1 && len(args) != want {
- return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
- }
+ ds.Lock()
+ defer ds.Unlock()
- dargs, err := driverArgs(ci, ds, args)
+ dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
- ds.Lock()
- defer ds.Unlock()
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ if want := ds.si.NumInput(); want >= 0 && want != len(dargs) {
+ return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs))
+ }
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
@@ -2451,9 +2609,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
+
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
@@ -2503,6 +2668,12 @@ func (rs *Rows) NextResultSet() bool {
doClose = true
return false
}
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil {
doClose = true
@@ -2534,6 +2705,9 @@ func (rs *Rows) Columns() ([]string, error) {
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
return rs.rowsi.Columns(), nil
}
@@ -2548,7 +2722,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
- return rowsColumnInfoSetup(rs.rowsi), nil
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
}
// ColumnType contains the name and type of a column.
@@ -2609,7 +2786,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType
}
-func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns()
list := make([]*ColumnType, len(names))
diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go
index c935eb43480..8137eff82b4 100644
--- a/libgo/go/database/sql/sql_test.go
+++ b/libgo/go/database/sql/sql_test.go
@@ -60,10 +60,12 @@ const fakeDBName = "foo"
var chrisBirthday = time.Unix(123456789, 0)
func newTestDB(t testing.TB, name string) *DB {
- db, err := Open("test", fakeDBName)
- if err != nil {
- t.Fatalf("Open: %v", err)
- }
+ return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
+}
+
+func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
+ fc.name = fakeDBName
+ db := OpenDB(fc)
if _, err := db.Exec("WIPE"); err != nil {
t.Fatalf("exec wipe: %v", err)
}
@@ -81,6 +83,13 @@ func newTestDB(t testing.TB, name string) *DB {
return db
}
+func TestOpenDB(t *testing.T) {
+ db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver})
+ if db.Driver() != fdriver {
+ t.Fatalf("OpenDB should return the driver of the Connector")
+ }
+}
+
func TestDriverPanic(t *testing.T) {
// Test that if driver panics, database/sql does not deadlock.
db, err := Open("test", fakeDBName)
@@ -439,6 +448,20 @@ func TestTxContextWait(t *testing.T) {
waitForFree(t, db, 5*time.Second, 0)
}
+// TestUnsupportedOptions checks that the database fails when a driver that
+// doesn't implement ConnBeginTx is used with non-default options and an
+// un-cancellable context.
+func TestUnsupportedOptions(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ _, err := db.BeginTx(context.Background(), &TxOptions{
+ Isolation: LevelSerializable, ReadOnly: true,
+ })
+ if err == nil {
+ t.Fatal("expected error when using unsupported options, got nil")
+ }
+}
+
func TestMultiResultSetQuery(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -564,24 +587,46 @@ func TestPoolExhaustOnCancel(t *testing.T) {
if testing.Short() {
t.Skip("long test")
}
- db := newTestDB(t, "people")
- defer closeDB(t, db)
max := 3
+ var saturate, saturateDone sync.WaitGroup
+ saturate.Add(max)
+ saturateDone.Add(max)
+
+ donePing := make(chan bool)
+ state := 0
+
+ // waiter will be called for all queries, including
+ // initial setup queries. The state is only assigned when no
+ // no queries are made.
+ //
+ // Only allow the first batch of queries to finish once the
+ // second batch of Ping queries have finished.
+ waiter := func(ctx context.Context) {
+ switch state {
+ case 0:
+ // Nothing. Initial database setup.
+ case 1:
+ saturate.Done()
+ select {
+ case <-ctx.Done():
+ case <-donePing:
+ }
+ case 2:
+ }
+ }
+ db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")
+ defer closeDB(t, db)
db.SetMaxOpenConns(max)
// First saturate the connection pool.
// Then start new requests for a connection that is cancelled after it is requested.
- var saturate, saturateDone sync.WaitGroup
- saturate.Add(max)
- saturateDone.Add(max)
-
+ state = 1
for i := 0; i < max; i++ {
go func() {
- saturate.Done()
- rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|")
+ rows, err := db.Query("SELECT|people|name,photo|")
if err != nil {
t.Fatalf("Query: %v", err)
}
@@ -591,6 +636,7 @@ func TestPoolExhaustOnCancel(t *testing.T) {
}
saturate.Wait()
+ state = 2
// Now cancel the request while it is waiting.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
@@ -607,7 +653,7 @@ func TestPoolExhaustOnCancel(t *testing.T) {
t.Fatalf("PingContext (Exhaust): %v", err)
}
}
-
+ close(donePing)
saturateDone.Wait()
// Now try to open a normal connection.
@@ -705,15 +751,15 @@ func TestRowsColumnTypes(t *testing.T) {
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
- ct++
- if ct == 0 {
- if values[0].(string) != "Bob" {
- t.Errorf("Expected Bob, got %v", values[0])
+ if ct == 1 {
+ if age := *values[0].(*int32); age != 2 {
+ t.Errorf("Expected 2, got %v", age)
}
- if values[1].(int) != 2 {
- t.Errorf("Expected 2, got %v", values[1])
+ if name := *values[1].(*string); name != "Bob" {
+ t.Errorf("Expected Bob, got %v", name)
}
}
+ ct++
}
if ct != 3 {
t.Errorf("expected 3 rows, got %d", ct)
@@ -1311,6 +1357,7 @@ func TestConnQuery(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
defer conn.Close()
var name string
@@ -1338,6 +1385,7 @@ func TestConnTx(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
defer conn.Close()
tx, err := conn.BeginTx(ctx, nil)
@@ -1658,7 +1706,7 @@ func TestIssue4902(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
- driver := db.driver.(*fakeDriver)
+ driver := db.Driver().(*fakeDriver)
opens0 := driver.openCount
var stmt *Stmt
@@ -1751,7 +1799,7 @@ func TestMaxOpenConns(t *testing.T) {
db := newTestDB(t, "magicquery")
defer closeDB(t, db)
- driver := db.driver.(*fakeDriver)
+ driver := db.Driver().(*fakeDriver)
// Force the number of open connections to 0 so we can get an accurate
// count for the test
@@ -2043,7 +2091,7 @@ func TestConnMaxLifetime(t *testing.T) {
db := newTestDB(t, "magicquery")
defer closeDB(t, db)
- driver := db.driver.(*fakeDriver)
+ driver := db.Driver().(*fakeDriver)
// Force the number of open connections to 0 so we can get an accurate
// count for the test
@@ -2132,7 +2180,7 @@ func TestStmtCloseDeps(t *testing.T) {
db := newTestDB(t, "magicquery")
defer closeDB(t, db)
- driver := db.driver.(*fakeDriver)
+ driver := db.Driver().(*fakeDriver)
driver.mu.Lock()
opens0 := driver.openCount
@@ -2363,7 +2411,9 @@ func TestManyErrBadConn(t *testing.T) {
t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
}
for _, conn := range db.freeConn {
+ conn.Lock()
conn.ci.(*fakeConn).stickyBad = true
+ conn.Unlock()
}
return db
}
@@ -2453,6 +2503,7 @@ func TestManyErrBadConn(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
err = conn.Close()
if err != nil {
t.Fatal(err)
@@ -3057,7 +3108,7 @@ func TestIssue6081(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
- drv := db.driver.(*fakeDriver)
+ drv := db.Driver().(*fakeDriver)
drv.mu.Lock()
opens0 := drv.openCount
closes0 := drv.closeCount
@@ -3106,6 +3157,9 @@ func TestIssue6081(t *testing.T) {
// In the test, a context is canceled while the query is in process so
// the internal rollback will run concurrently with the explicitly called
// Tx.Rollback.
+//
+// The addition of calling rows.Next also tests
+// Issue 21117.
func TestIssue18429(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
@@ -3138,6 +3192,9 @@ func TestIssue18429(t *testing.T) {
// reported.
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
if rows != nil {
+ // Call Next to test Issue 21117 and check for races.
+ for rows.Next() {
+ }
rows.Close()
}
// This call will race with the context cancel rollback to complete
@@ -3217,9 +3274,8 @@ func TestIssue18719(t *testing.T) {
// This call will grab the connection and cancel the context
// after it has done so. Code after must deal with the canceled state.
- rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
+ _, err = tx.QueryContext(ctx, "SELECT|people|name|")
if err != nil {
- rows.Close()
t.Fatalf("expected error %v but got %v", nil, err)
}
@@ -3242,6 +3298,7 @@ func TestIssue20647(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
@@ -3312,7 +3369,7 @@ func TestConnectionLeak(t *testing.T) {
// Now we have defaultMaxIdleConns busy connections. Open
// a new one, but wait until the busy connections are released
// before returning control to DB.
- drv := db.driver.(*fakeDriver)
+ drv := db.Driver().(*fakeDriver)
drv.waitCh = make(chan struct{}, 1)
drv.waitingCh = make(chan struct{}, 1)
var wg sync.WaitGroup
@@ -3376,7 +3433,7 @@ func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
case Out:
switch ov := v.Dest.(type) {
default:
- return errors.New("unkown NameValueCheck OUTPUT type")
+ return errors.New("unknown NameValueCheck OUTPUT type")
case *string:
*ov = "from-server"
nv.Value = "OUT:*string"
@@ -3466,6 +3523,141 @@ func TestNamedValueCheckerSkip(t *testing.T) {
}
}
+func TestOpenConnector(t *testing.T) {
+ Register("testctx", &fakeDriverCtx{})
+ db, err := Open("testctx", "people")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ if _, is := db.connector.(*fakeConnector); !is {
+ t.Fatal("not using *fakeConnector")
+ }
+}
+
+type ctxOnlyDriver struct {
+ fakeDriver
+}
+
+func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) {
+ conn, err := d.fakeDriver.Open(dsn)
+ if err != nil {
+ return nil, err
+ }
+ return &ctxOnlyConn{fc: conn.(*fakeConn)}, nil
+}
+
+var (
+ _ driver.Conn = &ctxOnlyConn{}
+ _ driver.QueryerContext = &ctxOnlyConn{}
+ _ driver.ExecerContext = &ctxOnlyConn{}
+)
+
+type ctxOnlyConn struct {
+ fc *fakeConn
+
+ queryCtxCalled bool
+ execCtxCalled bool
+}
+
+func (c *ctxOnlyConn) Begin() (driver.Tx, error) {
+ return c.fc.Begin()
+}
+
+func (c *ctxOnlyConn) Close() error {
+ return c.fc.Close()
+}
+
+// Prepare is still part of the Conn interface, so while it isn't used
+// must be defined for compatibility.
+func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) {
+ panic("not used")
+}
+
+func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
+ return c.fc.PrepareContext(ctx, q)
+}
+
+func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
+ c.queryCtxCalled = true
+ return c.fc.QueryContext(ctx, q, args)
+}
+
+func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
+ c.execCtxCalled = true
+ return c.fc.ExecContext(ctx, q, args)
+}
+
+// TestQueryExecContextOnly ensures drivers only need to implement QueryContext
+// and ExecContext methods.
+func TestQueryExecContextOnly(t *testing.T) {
+ // Ensure connection does not implment non-context interfaces.
+ var connType driver.Conn = &ctxOnlyConn{}
+ if _, ok := connType.(driver.Execer); ok {
+ t.Fatalf("%T must not implement driver.Execer", connType)
+ }
+ if _, ok := connType.(driver.Queryer); ok {
+ t.Fatalf("%T must not implement driver.Queryer", connType)
+ }
+
+ Register("ContextOnly", &ctxOnlyDriver{})
+ db, err := Open("ContextOnly", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal("db.Conn", err)
+ }
+ defer conn.Close()
+ coc := conn.dc.ci.(*ctxOnlyConn)
+ coc.fc.skipDirtySession = true
+
+ _, err = conn.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = conn.ExecContext(ctx, "CREATE|keys|v1=string")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+ expectedValue := "value1"
+ _, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue)
+ if err != nil {
+ t.Fatal("exec insert", err)
+ }
+ rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|")
+ if err != nil {
+ t.Fatal("query select", err)
+ }
+ v1 := ""
+ for rows.Next() {
+ err = rows.Scan(&v1)
+ if err != nil {
+ t.Fatal("rows scan", err)
+ }
+ }
+ rows.Close()
+
+ if v1 != expectedValue {
+ t.Fatalf("expected %q, got %q", expectedValue, v1)
+ }
+
+ if !coc.execCtxCalled {
+ t.Error("ExecContext not called")
+ }
+ if !coc.queryCtxCalled {
+ t.Error("QueryContext not called")
+ }
+}
+
// badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics.
type badConn struct{}