diff options
author | Ian Lance Taylor <iant@golang.org> | 2018-01-09 01:23:08 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2018-01-09 01:23:08 +0000 |
commit | 1a2f01efa63036a5104f203a4789e682c0e0915d (patch) | |
tree | 373e15778dc8295354584e1f86915ae493b604ff /libgo/go/database | |
parent | 8799df67f2dab88f9fda11739c501780a85575e2 (diff) | |
download | gcc-1a2f01efa63036a5104f203a4789e682c0e0915d.tar.gz |
libgo: update to Go1.10beta1
Update the Go library to the 1.10beta1 release.
Requires a few changes to the compiler for modifications to the map
runtime code, and to handle some nowritebarrier cases in the runtime.
Reviewed-on: https://go-review.googlesource.com/86455
gotools/:
* Makefile.am (go_cmd_vet_files): New variable.
(go_cmd_buildid_files, go_cmd_test2json_files): New variables.
(s-zdefaultcc): Change from constants to functions.
(noinst_PROGRAMS): Add vet, buildid, and test2json.
(cgo$(EXEEXT)): Link against $(LIBGOTOOL).
(vet$(EXEEXT)): New target.
(buildid$(EXEEXT)): New target.
(test2json$(EXEEXT)): New target.
(install-exec-local): Install all $(noinst_PROGRAMS).
(uninstall-local): Uninstasll all $(noinst_PROGRAMS).
(check-go-tool): Depend on $(noinst_PROGRAMS). Copy down
objabi.go.
(check-runtime): Depend on $(noinst_PROGRAMS).
(check-cgo-test, check-carchive-test): Likewise.
(check-vet): New target.
(check): Depend on check-vet. Look at cmd_vet-testlog.
(.PHONY): Add check-vet.
* Makefile.in: Rebuild.
From-SVN: r256365
Diffstat (limited to 'libgo/go/database')
-rw-r--r-- | libgo/go/database/sql/convert.go | 32 | ||||
-rw-r--r-- | libgo/go/database/sql/convert_test.go | 15 | ||||
-rw-r--r-- | libgo/go/database/sql/ctxutil.go | 19 | ||||
-rw-r--r-- | libgo/go/database/sql/driver/driver.go | 87 | ||||
-rw-r--r-- | libgo/go/database/sql/fakedb_test.go | 73 | ||||
-rw-r--r-- | libgo/go/database/sql/sql.go | 305 | ||||
-rw-r--r-- | libgo/go/database/sql/sql_test.go | 248 |
7 files changed, 647 insertions, 132 deletions
diff --git a/libgo/go/database/sql/convert.go b/libgo/go/database/sql/convert.go index 4983181fe75..b79ec3f7b27 100644 --- a/libgo/go/database/sql/convert.go +++ b/libgo/go/database/sql/convert.go @@ -12,7 +12,6 @@ import ( "fmt" "reflect" "strconv" - "sync" "time" "unicode" "unicode/utf8" @@ -38,17 +37,10 @@ func validateNamedValueName(name string) error { return fmt.Errorf("name %q does not begin with a letter", name) } -func driverNumInput(ds *driverStmt) int { - ds.Lock() - defer ds.Unlock() // in case NumInput panics - return ds.si.NumInput() -} - // ccChecker wraps the driver.ColumnConverter and allows it to be used // as if it were a NamedValueChecker. If the driver ColumnConverter // is not present then the NamedValueChecker will return driver.ErrSkip. type ccChecker struct { - sync.Locker cci driver.ColumnConverter want int } @@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { // same error. var err error arg := nv.Value - c.Lock() nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) - c.Unlock() if err != nil { return err } @@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { // Stmt.Query into driver Values. // // The statement ds may be nil, if no statement is available. -func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { +func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { nvargs := make([]driver.NamedValue, len(args)) // -1 means the driver doesn't know how to count the number of @@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na var cc ccChecker if ds != nil { si = ds.si - want = driverNumInput(ds) - cc.Locker = ds.Locker + want = ds.si.NumInput() cc.want = want } @@ -204,7 +193,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na } } - // Check the length of arguments after convertion to allow for omitted + // Check the length of arguments after conversion to allow for omitted // arguments. if want != -1 && len(nvargs) != want { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) @@ -234,6 +223,12 @@ func convertAssign(dest, src interface{}) error { } *d = []byte(s) return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil } case []byte: switch d := dest.(type) { @@ -264,6 +259,9 @@ func convertAssign(dest, src interface{}) error { } case time.Time: switch d := dest.(type) { + case *time.Time: + *d = s + return nil case *string: *d = s.Format(time.RFC3339Nano) return nil @@ -273,6 +271,12 @@ func convertAssign(dest, src interface{}) error { } *d = []byte(s.Format(time.RFC3339Nano)) return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil } case nil: switch d := dest.(type) { diff --git a/libgo/go/database/sql/convert_test.go b/libgo/go/database/sql/convert_test.go index cfe52d7f548..47098c81ec1 100644 --- a/libgo/go/database/sql/convert_test.go +++ b/libgo/go/database/sql/convert_test.go @@ -106,6 +106,7 @@ var conversionTests = []conversionTest{ // To RawBytes {s: nil, d: &scanraw, wantraw: nil}, {s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")}, + {s: "string", d: &scanraw, wantraw: RawBytes("string")}, {s: 123, d: &scanraw, wantraw: RawBytes("123")}, {s: int8(123), d: &scanraw, wantraw: RawBytes("123")}, {s: int64(123), d: &scanraw, wantraw: RawBytes("123")}, @@ -114,6 +115,9 @@ var conversionTests = []conversionTest{ {s: uint32(123), d: &scanraw, wantraw: RawBytes("123")}, {s: uint64(123), d: &scanraw, wantraw: RawBytes("123")}, {s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")}, + // time.Time has been placed here to check that the RawBytes slice gets + // correctly reset when calling time.Time.AppendFormat. + {s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")}, // Strings to integers {s: "255", d: &scanuint8, wantuint: 255}, @@ -222,6 +226,12 @@ func TestConversions(t *testing.T) { if ct.wantstr != "" && ct.wantstr != scanstr { errf("want string %q, got %q", ct.wantstr, scanstr) } + if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) { + errf("want byte %q, got %q", ct.wantbytes, scanbytes) + } + if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) { + errf("want RawBytes %q, got %q", ct.wantraw, scanraw) + } if ct.wantint != 0 && ct.wantint != intValue(ct.d) { errf("want int %d, got %d", ct.wantint, intValue(ct.d)) } @@ -341,6 +351,7 @@ func TestRawBytesAllocs(t *testing.T) { {"float32", float32(1.5), "1.5"}, {"float64", float64(64), "64"}, {"bool", false, "false"}, + {"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"}, } buf := make(RawBytes, 10) @@ -387,7 +398,7 @@ func TestRawBytesAllocs(t *testing.T) { } } -// https://github.com/golang/go/issues/13905 +// https://golang.org/issues/13905 func TestUserDefinedBytes(t *testing.T) { type userDefinedBytes []byte var u userDefinedBytes @@ -470,7 +481,7 @@ func TestDriverArgs(t *testing.T) { } for i, tt := range tests { ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}} - got, err := driverArgs(nil, ds, tt.args) + got, err := driverArgsConnLocked(nil, ds, tt.args) if err != nil { t.Errorf("test[%d]: %v", i, err) continue diff --git a/libgo/go/database/sql/ctxutil.go b/libgo/go/database/sql/ctxutil.go index bd652b54625..af2afd5aa57 100644 --- a/libgo/go/database/sql/ctxutil.go +++ b/libgo/go/database/sql/ctxutil.go @@ -26,8 +26,8 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver return si, err } -func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { - if execerCtx, is := execer.(driver.ExecerContext); is { +func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { + if execerCtx != nil { return execerCtx.ExecContext(ctx, query, nvdargs) } dargs, err := namedValueToValue(nvdargs) @@ -43,10 +43,9 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvda return execer.Exec(query, dargs) } -func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { - if queryerCtx, is := queryer.(driver.QueryerContext); is { - ret, err := queryerCtx.QueryContext(ctx, query, nvdargs) - return ret, err +func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { + if queryerCtx != nil { + return queryerCtx.QueryContext(ctx, query, nvdargs) } dargs, err := namedValueToValue(nvdargs) if err != nil { @@ -107,10 +106,6 @@ func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (drive return ciCtx.BeginTx(ctx, dopts) } - if ctx.Done() == context.Background().Done() { - return ci.Begin() - } - if opts != nil { // Check the transaction level. If the transaction level is non-default // then return an error here as the BeginTx driver value is not supported. @@ -125,6 +120,10 @@ func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (drive } } + if ctx.Done() == nil { + return ci.Begin() + } + txi, err := ci.Begin() if err == nil { select { diff --git a/libgo/go/database/sql/driver/driver.go b/libgo/go/database/sql/driver/driver.go index 0262ca24ba2..83b2b3f535e 100644 --- a/libgo/go/database/sql/driver/driver.go +++ b/libgo/go/database/sql/driver/driver.go @@ -42,6 +42,10 @@ type NamedValue struct { // Driver is the interface that must be implemented by a database // driver. +// +// Database drivers may implement DriverContext for access +// to contexts and to parse the name only once for a pool of connections, +// instead of once per connection. type Driver interface { // Open returns a new connection to the database. // The name is a string in a driver-specific format. @@ -55,6 +59,47 @@ type Driver interface { Open(name string) (Conn, error) } +// If a Driver implements DriverContext, then sql.DB will call +// OpenConnector to obtain a Connector and then invoke +// that Connector's Conn method to obtain each needed connection, +// instead of invoking the Driver's Open method for each connection. +// The two-step sequence allows drivers to parse the name just once +// and also provides access to per-Conn contexts. +type DriverContext interface { + // OpenConnector must parse the name in the same format that Driver.Open + // parses the name parameter. + OpenConnector(name string) (Connector, error) +} + +// A Connector represents a driver in a fixed configuration +// and can create any number of equivalent Conns for use +// by multiple goroutines. +// +// A Connector can be passed to sql.OpenDB, to allow drivers +// to implement their own sql.DB constructors, or returned by +// DriverContext's OpenConnector method, to allow drivers +// access to context and to avoid repeated parsing of driver +// configuration. +type Connector interface { + // Connect returns a connection to the database. + // Connect may return a cached connection (one previously + // closed), but doing so is unnecessary; the sql package + // maintains a pool of idle connections for efficient re-use. + // + // The provided context.Context is for dialing purposes only + // (see net.DialContext) and should not be stored or used for + // other purposes. + // + // The returned connection is only used by one goroutine at a + // time. + Connect(context.Context) (Conn, error) + + // Driver returns the underlying Driver of the Connector, + // mainly to maintain compatibility with the Driver method + // on sql.DB. + Driver() Driver +} + // ErrSkip may be returned by some optional interfaces' methods to // indicate at runtime that the fast path is unavailable and the sql // package should continue as if the optional interface was not @@ -86,22 +131,23 @@ type Pinger interface { // Execer is an optional interface that may be implemented by a Conn. // -// If a Conn does not implement Execer, the sql package's DB.Exec will -// first prepare a query, execute the statement, and then close the -// statement. +// If a Conn implements neither ExecerContext nor Execer Execer, +// the sql package's DB.Exec will first prepare a query, execute the statement, +// and then close the statement. // // Exec may return ErrSkip. // -// Deprecated: Drivers should implement ExecerContext instead (or additionally). +// Deprecated: Drivers should implement ExecerContext instead. type Execer interface { Exec(query string, args []Value) (Result, error) } // ExecerContext is an optional interface that may be implemented by a Conn. // -// If a Conn does not implement ExecerContext, the sql package's DB.Exec will -// first prepare a query, execute the statement, and then close the -// statement. +// If a Conn does not implement ExecerContext, the sql package's DB.Exec +// will fall back to Execer; if the Conn does not implement Execer either, +// DB.Exec will first prepare a query, execute the statement, and then +// close the statement. // // ExecerContext may return ErrSkip. // @@ -112,22 +158,23 @@ type ExecerContext interface { // Queryer is an optional interface that may be implemented by a Conn. // -// If a Conn does not implement Queryer, the sql package's DB.Query will -// first prepare a query, execute the statement, and then close the -// statement. +// If a Conn implements neither QueryerContext nor Queryer, +// the sql package's DB.Query will first prepare a query, execute the statement, +// and then close the statement. // // Query may return ErrSkip. // -// Deprecated: Drivers should implement QueryerContext instead (or additionally). +// Deprecated: Drivers should implement QueryerContext instead. type Queryer interface { Query(query string, args []Value) (Rows, error) } // QueryerContext is an optional interface that may be implemented by a Conn. // -// If a Conn does not implement QueryerContext, the sql package's DB.Query will -// first prepare a query, execute the statement, and then close the -// statement. +// If a Conn does not implement QueryerContext, the sql package's DB.Query +// will fall back to Queryer; if the Conn does not implement Queryer either, +// DB.Query will first prepare a query, execute the statement, and then +// close the statement. // // QueryerContext may return ErrSkip. // @@ -199,6 +246,18 @@ type ConnBeginTx interface { BeginTx(ctx context.Context, opts TxOptions) (Tx, error) } +// SessionResetter may be implemented by Conn to allow drivers to reset the +// session state associated with the connection and to signal a bad connection. +type SessionResetter interface { + // ResetSession is called while a connection is in the connection + // pool. No queries will run on this connection until this method returns. + // + // If the connection is bad this should return driver.ErrBadConn to prevent + // the connection from being returned to the connection pool. Any other + // error will be discarded. + ResetSession(ctx context.Context) error +} + // Result is the result of a query execution. type Result interface { // LastInsertId returns the database's auto-generated ID diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go index 4dcd096ca4d..e795412de01 100644 --- a/libgo/go/database/sql/fakedb_test.go +++ b/libgo/go/database/sql/fakedb_test.go @@ -55,6 +55,32 @@ type fakeDriver struct { dbs map[string]*fakeDB } +type fakeConnector struct { + name string + + waiter func(context.Context) +} + +func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { + conn, err := fdriver.Open(c.name) + conn.(*fakeConn).waiter = c.waiter + return conn, err +} + +func (c *fakeConnector) Driver() driver.Driver { + return fdriver +} + +type fakeDriverCtx struct { + fakeDriver +} + +var _ driver.DriverContext = &fakeDriverCtx{} + +func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { + return &fakeConnector{name: name}, nil +} + type fakeDB struct { name string @@ -107,6 +133,16 @@ type fakeConn struct { // bad connection tests; see isBad() bad bool stickyBad bool + + skipDirtySession bool // tests that use Conn should set this to true. + + // dirtySession tests ResetSession, true if a query has executed + // until ResetSession is called. + dirtySession bool + + // The waiter is called before each query. May be used in place of the "WAIT" + // directive. + waiter func(context.Context) } func (c *fakeConn) touchMem() { @@ -298,6 +334,9 @@ func (c *fakeConn) isBad() bool { if c.stickyBad { return true } else if c.bad { + if c.db == nil { + return false + } // alternate between bad conn and not bad conn c.db.badConn = !c.db.badConn return c.db.badConn @@ -306,6 +345,21 @@ func (c *fakeConn) isBad() bool { } } +func (c *fakeConn) isDirtyAndMark() bool { + if c.skipDirtySession { + return false + } + if c.currTx != nil { + c.dirtySession = true + return false + } + if c.dirtySession { + return true + } + c.dirtySession = true + return false +} + func (c *fakeConn) Begin() (driver.Tx, error) { if c.isBad() { return nil, driver.ErrBadConn @@ -337,6 +391,14 @@ func setStrictFakeConnClose(t *testing.T) { testStrictClose = t } +func (c *fakeConn) ResetSession(ctx context.Context) error { + c.dirtySession = false + if c.isBad() { + return driver.ErrBadConn + } + return nil +} + func (c *fakeConn) Close() (err error) { drv := fdriver.(*fakeDriver) defer func() { @@ -572,6 +634,10 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm stmt.cmd = cmd parts = parts[1:] + if c.waiter != nil { + c.waiter(ctx) + } + if stmt.wait > 0 { wait := time.NewTimer(stmt.wait) select { @@ -662,6 +728,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { return nil, driver.ErrBadConn } + if s.c.isDirtyAndMark() { + return nil, errors.New("session is dirty") + } err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { @@ -774,6 +843,9 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { return nil, driver.ErrBadConn } + if s.c.isDirtyAndMark() { + return nil, errors.New("session is dirty") + } err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { @@ -943,6 +1015,7 @@ type rowsCursor struct { } func (rc *rowsCursor) touchMem() { + rc.parentMem.touchMem() rc.line++ } diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go index c609fe4cc43..9f4fa14534d 100644 --- a/libgo/go/database/sql/sql.go +++ b/libgo/go/database/sql/sql.go @@ -285,7 +285,7 @@ type Scanner interface { // Example usage: // // var outArg string -// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg})) +// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg})) type Out struct { _Named_Fields_Required struct{} @@ -317,8 +317,7 @@ var ErrNoRows = errors.New("sql: no rows in result set") // connection is returned to DB's idle connection pool. The pool size // can be controlled with SetMaxIdleConns. type DB struct { - driver driver.Driver - dsn string + connector driver.Connector // numClosed is an atomic counter which represents a total number of // closed connections. Stmt.openStmt checks it before cleaning closed // connections in Stmt.css. @@ -335,6 +334,7 @@ type DB struct { // It is closed during db.Close(). The close tells the connectionOpener // goroutine to exit. openerCh chan struct{} + resetterCh chan *driverConn closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only @@ -342,6 +342,8 @@ type DB struct { maxOpen int // <= 0 means unlimited maxLifetime time.Duration // maximum amount of time a connection may be reused cleanerCh chan struct{} + + stop func() // stop cancels the connection opener and the session resetter. } // connReuseStrategy determines how (*DB).conn returns database connections. @@ -369,6 +371,7 @@ type driverConn struct { closed bool finalClosed bool // ci.Close has been called openStmt map[*driverStmt]bool + lastErr error // lastError captures the result of the session resetter. // guarded by db.mu inUse bool @@ -377,7 +380,7 @@ type driverConn struct { } func (dc *driverConn) releaseConn(err error) { - dc.db.putConn(dc, err) + dc.db.putConn(dc, err, true) } func (dc *driverConn) removeOpenStmt(ds *driverStmt) { @@ -418,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que return ds, nil } +// resetSession resets the connection session and sets the lastErr +// that is checked before returning the connection to another query. +// +// resetSession assumes that the embedded mutex is locked when the connection +// was returned to the pool. This unlocks the mutex. +func (dc *driverConn) resetSession(ctx context.Context) { + defer dc.Unlock() // In case of panic. + if dc.closed { // Check if the database has been closed. + return + } + dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx) +} + // the dc.db's Mutex is held. func (dc *driverConn) closeDBLocked() func() error { dc.Lock() @@ -575,6 +591,52 @@ func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { // to block until the connectionOpener can satisfy the backlog of requests. var connectionRequestQueueSize = 1000000 +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} + +// OpenDB opens a database using a Connector, allowing drivers to +// bypass a string based data source name. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See https://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// OpenDB may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the OpenDB +// function should be called just once. It is rarely necessary to +// close a DB. +func OpenDB(c driver.Connector) *DB { + ctx, cancel := context.WithCancel(context.Background()) + db := &DB{ + connector: c, + openerCh: make(chan struct{}, connectionRequestQueueSize), + resetterCh: make(chan *driverConn, 50), + lastPut: make(map[*driverConn]string), + connRequests: make(map[uint64]chan connRequest), + stop: cancel, + } + + go db.connectionOpener(ctx) + go db.connectionResetter(ctx) + + return db +} + // Open opens a database specified by its database driver name and a // driver-specific data source name, usually consisting of at least a // database name and connection information. @@ -599,15 +661,16 @@ func Open(driverName, dataSourceName string) (*DB, error) { if !ok { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } - db := &DB{ - driver: driveri, - dsn: dataSourceName, - openerCh: make(chan struct{}, connectionRequestQueueSize), - lastPut: make(map[*driverConn]string), - connRequests: make(map[uint64]chan connRequest), + + if driverCtx, ok := driveri.(driver.DriverContext); ok { + connector, err := driverCtx.OpenConnector(dataSourceName) + if err != nil { + return nil, err + } + return OpenDB(connector), nil } - go db.connectionOpener() - return db, nil + + return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil } func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error { @@ -659,7 +722,6 @@ func (db *DB) Close() error { db.mu.Unlock() return nil } - close(db.openerCh) if db.cleanerCh != nil { close(db.cleanerCh) } @@ -680,6 +742,7 @@ func (db *DB) Close() error { err = err1 } } + db.stop() return err } @@ -867,18 +930,40 @@ func (db *DB) maybeOpenNewConnections() { } // Runs in a separate goroutine, opens new connections when requested. -func (db *DB) connectionOpener() { - for range db.openerCh { - db.openNewConnection() +func (db *DB) connectionOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-db.openerCh: + db.openNewConnection(ctx) + } + } +} + +// connectionResetter runs in a separate goroutine to reset connections async +// to exported API. +func (db *DB) connectionResetter(ctx context.Context) { + for { + select { + case <-ctx.Done(): + close(db.resetterCh) + for dc := range db.resetterCh { + dc.Unlock() + } + return + case dc := <-db.resetterCh: + dc.resetSession(ctx) + } } } // Open one new connection -func (db *DB) openNewConnection() { +func (db *DB) openNewConnection(ctx context.Context) { // maybeOpenNewConnctions has already executed db.numOpen++ before it sent // on db.openerCh. This function must execute db.numOpen-- if the // connection fails or is closed before returning. - ci, err := db.driver.Open(db.dsn) + ci, err := db.connector.Connect(ctx) db.mu.Lock() defer db.mu.Unlock() if db.closed { @@ -953,6 +1038,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn conn.Close() return nil, driver.ErrBadConn } + // Lock around reading lastErr to ensure the session resetter finished. + conn.Lock() + err := conn.lastErr + conn.Unlock() + if err == driver.ErrBadConn { + conn.Close() + return nil, driver.ErrBadConn + } return conn, nil } @@ -978,7 +1071,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn default: case ret, ok := <-req: if ok { - db.putConn(ret.conn, ret.err) + db.putConn(ret.conn, ret.err, false) } } return nil, ctx.Err() @@ -990,13 +1083,24 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn ret.conn.Close() return nil, driver.ErrBadConn } + if ret.conn == nil { + return nil, ret.err + } + // Lock around reading lastErr to ensure the session resetter finished. + ret.conn.Lock() + err := ret.conn.lastErr + ret.conn.Unlock() + if err == driver.ErrBadConn { + ret.conn.Close() + return nil, driver.ErrBadConn + } return ret.conn, ret.err } } db.numOpen++ // optimistically db.mu.Unlock() - ci, err := db.driver.Open(db.dsn) + ci, err := db.connector.Connect(ctx) if err != nil { db.mu.Lock() db.numOpen-- // correct for earlier optimism @@ -1045,7 +1149,7 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(dc *driverConn, err error) { +func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { db.mu.Lock() if !dc.inUse { if debugGetPut { @@ -1076,11 +1180,40 @@ func (db *DB) putConn(dc *driverConn, err error) { if putConnHook != nil { putConnHook(db, dc) } + if db.closed { + // Connections do not need to be reset if they will be closed. + // Prevents writing to resetterCh after the DB has closed. + resetSession = false + } + if resetSession { + if _, resetSession = dc.ci.(driver.SessionResetter); resetSession { + // Lock the driverConn here so it isn't released until + // the connection is reset. + // The lock must be taken before the connection is put into + // the pool to prevent it from being taken out before it is reset. + dc.Lock() + } + } added := db.putConnDBLocked(dc, nil) db.mu.Unlock() if !added { + if resetSession { + dc.Unlock() + } dc.Close() + return + } + if !resetSession { + return + } + select { + default: + // If the resetterCh is blocking then mark the connection + // as bad and continue on. + dc.lastErr = driver.ErrBadConn + dc.Unlock() + case db.resetterCh <- dc: } } @@ -1242,15 +1375,20 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q defer func() { release(err) }() - if execer, ok := dc.ci.(driver.Execer); ok { - var dargs []driver.NamedValue - dargs, err = driverArgs(dc.ci, nil, args) - if err != nil { - return nil, err - } + execerCtx, ok := dc.ci.(driver.ExecerContext) + var execer driver.Execer + if !ok { + execer, ok = dc.ci.(driver.Execer) + } + if ok { + var nvdargs []driver.NamedValue var resi driver.Result withLock(dc, func() { - resi, err = ctxDriverExec(ctx, execer, query, dargs) + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs) }) if err != driver.ErrSkip { if err != nil { @@ -1309,15 +1447,21 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat // The ctx context is from a query method and the txctx context is from an // optional transaction context. func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { - if queryer, ok := dc.ci.(driver.Queryer); ok { - dargs, err := driverArgs(dc.ci, nil, args) - if err != nil { - releaseConn(err) - return nil, err - } + queryerCtx, ok := dc.ci.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = dc.ci.(driver.Queryer) + } + if ok { + var nvdargs []driver.NamedValue var rowsi driver.Rows + var err error withLock(dc, func() { - rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs) + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs) }) if err != driver.ErrSkip { if err != nil { @@ -1454,11 +1598,11 @@ func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), // Driver returns the database's underlying driver. func (db *DB) Driver() driver.Driver { - return db.driver + return db.connector.Driver() } // ErrConnDone is returned by any operation that is performed on a connection -// that has already been committed or rolled back. +// that has already been returned to the connection pool. var ErrConnDone = errors.New("database/sql: connection is already closed") // Conn returns a single connection by either opening a new connection @@ -1493,9 +1637,9 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) { type releaseConn func(error) -// Conn represents a single database session rather a pool of database -// sessions. Prefer running queries from DB unless there is a specific -// need for a continuous single database session. +// Conn represents a single database connection rather than a pool of database +// connections. Prefer running queries from DB unless there is a specific +// need for a continuous single database connection. // // A Conn must call Close to return the connection to the database pool // and may do so concurrently with a running query. @@ -1769,14 +1913,20 @@ func (tx *Tx) closePrepared() { // Commit commits the transaction. func (tx *Tx) Commit() error { - if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { - return ErrTxDone - } + // Check context first to avoid transaction leak. + // If put it behind tx.done CompareAndSwap statement, we cant't ensure + // the consistency between tx.done and the real COMMIT operation. select { default: case <-tx.ctx.Done(): + if atomic.LoadInt32(&tx.done) == 1 { + return ErrTxDone + } return tx.ctx.Err() } + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { + return ErrTxDone + } var err error withLock(tx.dc, func() { err = tx.txi.Commit() @@ -1859,6 +2009,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // ... // res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) // +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +// // The returned statement operates within the transaction and will be closed // when the transaction has been committed or rolled back. func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { @@ -1902,11 +2055,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { stmt.mu.Unlock() if si == nil { - cs, err := stmt.prepareOnConnLocked(ctx, dc) + withLock(dc, func() { + var ds *driverStmt + ds, err = stmt.prepareOnConnLocked(ctx, dc) + si = ds.si + }) if err != nil { return &Stmt{stickyErr: err} } - si = cs.si } parentStmt = stmt } @@ -2098,13 +2254,20 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { } func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) { - dargs, err := driverArgs(ci, ds, args) + ds.Lock() + defer ds.Unlock() + + dargs, err := driverArgsConnLocked(ci, ds, args) if err != nil { return nil, err } - ds.Lock() - defer ds.Unlock() + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := ds.si.NumInput(); want >= 0 && want != len(dargs) { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs)) + } resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) if err != nil { @@ -2269,25 +2432,20 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { } func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) { - var want int - withLock(ds, func() { - want = ds.si.NumInput() - }) - - // -1 means the driver doesn't know how to count the number of - // placeholders, so we won't sanity check input here and instead let the - // driver deal with errors. - if want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) - } + ds.Lock() + defer ds.Unlock() - dargs, err := driverArgs(ci, ds, args) + dargs, err := driverArgsConnLocked(ci, ds, args) if err != nil { return nil, err } - ds.Lock() - defer ds.Unlock() + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := ds.si.NumInput(); want >= 0 && want != len(dargs) { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs)) + } rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) if err != nil { @@ -2451,9 +2609,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) { if rs.closed { return false, false } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + if rs.lastcols == nil { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) } + rs.lasterr = rs.rowsi.Next(rs.lastcols) if rs.lasterr != nil { // Close the connection if there is a driver error. @@ -2503,6 +2668,12 @@ func (rs *Rows) NextResultSet() bool { doClose = true return false } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + rs.lasterr = nextResultSet.NextResultSet() if rs.lasterr != nil { doClose = true @@ -2534,6 +2705,9 @@ func (rs *Rows) Columns() ([]string, error) { if rs.rowsi == nil { return nil, errors.New("sql: no Rows available") } + rs.dc.Lock() + defer rs.dc.Unlock() + return rs.rowsi.Columns(), nil } @@ -2548,7 +2722,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { if rs.rowsi == nil { return nil, errors.New("sql: no Rows available") } - return rowsColumnInfoSetup(rs.rowsi), nil + rs.dc.Lock() + defer rs.dc.Unlock() + + return rowsColumnInfoSetupConnLocked(rs.rowsi), nil } // ColumnType contains the name and type of a column. @@ -2609,7 +2786,7 @@ func (ci *ColumnType) DatabaseTypeName() string { return ci.databaseType } -func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { +func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { names := rowsi.Columns() list := make([]*ColumnType, len(names)) diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go index c935eb43480..8137eff82b4 100644 --- a/libgo/go/database/sql/sql_test.go +++ b/libgo/go/database/sql/sql_test.go @@ -60,10 +60,12 @@ const fakeDBName = "foo" var chrisBirthday = time.Unix(123456789, 0) func newTestDB(t testing.TB, name string) *DB { - db, err := Open("test", fakeDBName) - if err != nil { - t.Fatalf("Open: %v", err) - } + return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name) +} + +func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB { + fc.name = fakeDBName + db := OpenDB(fc) if _, err := db.Exec("WIPE"); err != nil { t.Fatalf("exec wipe: %v", err) } @@ -81,6 +83,13 @@ func newTestDB(t testing.TB, name string) *DB { return db } +func TestOpenDB(t *testing.T) { + db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver}) + if db.Driver() != fdriver { + t.Fatalf("OpenDB should return the driver of the Connector") + } +} + func TestDriverPanic(t *testing.T) { // Test that if driver panics, database/sql does not deadlock. db, err := Open("test", fakeDBName) @@ -439,6 +448,20 @@ func TestTxContextWait(t *testing.T) { waitForFree(t, db, 5*time.Second, 0) } +// TestUnsupportedOptions checks that the database fails when a driver that +// doesn't implement ConnBeginTx is used with non-default options and an +// un-cancellable context. +func TestUnsupportedOptions(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + _, err := db.BeginTx(context.Background(), &TxOptions{ + Isolation: LevelSerializable, ReadOnly: true, + }) + if err == nil { + t.Fatal("expected error when using unsupported options, got nil") + } +} + func TestMultiResultSetQuery(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -564,24 +587,46 @@ func TestPoolExhaustOnCancel(t *testing.T) { if testing.Short() { t.Skip("long test") } - db := newTestDB(t, "people") - defer closeDB(t, db) max := 3 + var saturate, saturateDone sync.WaitGroup + saturate.Add(max) + saturateDone.Add(max) + + donePing := make(chan bool) + state := 0 + + // waiter will be called for all queries, including + // initial setup queries. The state is only assigned when no + // no queries are made. + // + // Only allow the first batch of queries to finish once the + // second batch of Ping queries have finished. + waiter := func(ctx context.Context) { + switch state { + case 0: + // Nothing. Initial database setup. + case 1: + saturate.Done() + select { + case <-ctx.Done(): + case <-donePing: + } + case 2: + } + } + db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people") + defer closeDB(t, db) db.SetMaxOpenConns(max) // First saturate the connection pool. // Then start new requests for a connection that is cancelled after it is requested. - var saturate, saturateDone sync.WaitGroup - saturate.Add(max) - saturateDone.Add(max) - + state = 1 for i := 0; i < max; i++ { go func() { - saturate.Done() - rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|") + rows, err := db.Query("SELECT|people|name,photo|") if err != nil { t.Fatalf("Query: %v", err) } @@ -591,6 +636,7 @@ func TestPoolExhaustOnCancel(t *testing.T) { } saturate.Wait() + state = 2 // Now cancel the request while it is waiting. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -607,7 +653,7 @@ func TestPoolExhaustOnCancel(t *testing.T) { t.Fatalf("PingContext (Exhaust): %v", err) } } - + close(donePing) saturateDone.Wait() // Now try to open a normal connection. @@ -705,15 +751,15 @@ func TestRowsColumnTypes(t *testing.T) { if err != nil { t.Fatalf("failed to scan values in %v", err) } - ct++ - if ct == 0 { - if values[0].(string) != "Bob" { - t.Errorf("Expected Bob, got %v", values[0]) + if ct == 1 { + if age := *values[0].(*int32); age != 2 { + t.Errorf("Expected 2, got %v", age) } - if values[1].(int) != 2 { - t.Errorf("Expected 2, got %v", values[1]) + if name := *values[1].(*string); name != "Bob" { + t.Errorf("Expected Bob, got %v", name) } } + ct++ } if ct != 3 { t.Errorf("expected 3 rows, got %d", ct) @@ -1311,6 +1357,7 @@ func TestConnQuery(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true defer conn.Close() var name string @@ -1338,6 +1385,7 @@ func TestConnTx(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true defer conn.Close() tx, err := conn.BeginTx(ctx, nil) @@ -1658,7 +1706,7 @@ func TestIssue4902(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) - driver := db.driver.(*fakeDriver) + driver := db.Driver().(*fakeDriver) opens0 := driver.openCount var stmt *Stmt @@ -1751,7 +1799,7 @@ func TestMaxOpenConns(t *testing.T) { db := newTestDB(t, "magicquery") defer closeDB(t, db) - driver := db.driver.(*fakeDriver) + driver := db.Driver().(*fakeDriver) // Force the number of open connections to 0 so we can get an accurate // count for the test @@ -2043,7 +2091,7 @@ func TestConnMaxLifetime(t *testing.T) { db := newTestDB(t, "magicquery") defer closeDB(t, db) - driver := db.driver.(*fakeDriver) + driver := db.Driver().(*fakeDriver) // Force the number of open connections to 0 so we can get an accurate // count for the test @@ -2132,7 +2180,7 @@ func TestStmtCloseDeps(t *testing.T) { db := newTestDB(t, "magicquery") defer closeDB(t, db) - driver := db.driver.(*fakeDriver) + driver := db.Driver().(*fakeDriver) driver.mu.Lock() opens0 := driver.openCount @@ -2363,7 +2411,9 @@ func TestManyErrBadConn(t *testing.T) { t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn) } for _, conn := range db.freeConn { + conn.Lock() conn.ci.(*fakeConn).stickyBad = true + conn.Unlock() } return db } @@ -2453,6 +2503,7 @@ func TestManyErrBadConn(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true err = conn.Close() if err != nil { t.Fatal(err) @@ -3057,7 +3108,7 @@ func TestIssue6081(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) - drv := db.driver.(*fakeDriver) + drv := db.Driver().(*fakeDriver) drv.mu.Lock() opens0 := drv.openCount closes0 := drv.closeCount @@ -3106,6 +3157,9 @@ func TestIssue6081(t *testing.T) { // In the test, a context is canceled while the query is in process so // the internal rollback will run concurrently with the explicitly called // Tx.Rollback. +// +// The addition of calling rows.Next also tests +// Issue 21117. func TestIssue18429(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -3138,6 +3192,9 @@ func TestIssue18429(t *testing.T) { // reported. rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") if rows != nil { + // Call Next to test Issue 21117 and check for races. + for rows.Next() { + } rows.Close() } // This call will race with the context cancel rollback to complete @@ -3217,9 +3274,8 @@ func TestIssue18719(t *testing.T) { // This call will grab the connection and cancel the context // after it has done so. Code after must deal with the canceled state. - rows, err := tx.QueryContext(ctx, "SELECT|people|name|") + _, err = tx.QueryContext(ctx, "SELECT|people|name|") if err != nil { - rows.Close() t.Fatalf("expected error %v but got %v", nil, err) } @@ -3242,6 +3298,7 @@ func TestIssue20647(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true defer conn.Close() stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|") @@ -3312,7 +3369,7 @@ func TestConnectionLeak(t *testing.T) { // Now we have defaultMaxIdleConns busy connections. Open // a new one, but wait until the busy connections are released // before returning control to DB. - drv := db.driver.(*fakeDriver) + drv := db.Driver().(*fakeDriver) drv.waitCh = make(chan struct{}, 1) drv.waitingCh = make(chan struct{}, 1) var wg sync.WaitGroup @@ -3376,7 +3433,7 @@ func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error { case Out: switch ov := v.Dest.(type) { default: - return errors.New("unkown NameValueCheck OUTPUT type") + return errors.New("unknown NameValueCheck OUTPUT type") case *string: *ov = "from-server" nv.Value = "OUT:*string" @@ -3466,6 +3523,141 @@ func TestNamedValueCheckerSkip(t *testing.T) { } } +func TestOpenConnector(t *testing.T) { + Register("testctx", &fakeDriverCtx{}) + db, err := Open("testctx", "people") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, is := db.connector.(*fakeConnector); !is { + t.Fatal("not using *fakeConnector") + } +} + +type ctxOnlyDriver struct { + fakeDriver +} + +func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) { + conn, err := d.fakeDriver.Open(dsn) + if err != nil { + return nil, err + } + return &ctxOnlyConn{fc: conn.(*fakeConn)}, nil +} + +var ( + _ driver.Conn = &ctxOnlyConn{} + _ driver.QueryerContext = &ctxOnlyConn{} + _ driver.ExecerContext = &ctxOnlyConn{} +) + +type ctxOnlyConn struct { + fc *fakeConn + + queryCtxCalled bool + execCtxCalled bool +} + +func (c *ctxOnlyConn) Begin() (driver.Tx, error) { + return c.fc.Begin() +} + +func (c *ctxOnlyConn) Close() error { + return c.fc.Close() +} + +// Prepare is still part of the Conn interface, so while it isn't used +// must be defined for compatibility. +func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) { + panic("not used") +} + +func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) { + return c.fc.PrepareContext(ctx, q) +} + +func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) { + c.queryCtxCalled = true + return c.fc.QueryContext(ctx, q, args) +} + +func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) { + c.execCtxCalled = true + return c.fc.ExecContext(ctx, q, args) +} + +// TestQueryExecContextOnly ensures drivers only need to implement QueryContext +// and ExecContext methods. +func TestQueryExecContextOnly(t *testing.T) { + // Ensure connection does not implment non-context interfaces. + var connType driver.Conn = &ctxOnlyConn{} + if _, ok := connType.(driver.Execer); ok { + t.Fatalf("%T must not implement driver.Execer", connType) + } + if _, ok := connType.(driver.Queryer); ok { + t.Fatalf("%T must not implement driver.Queryer", connType) + } + + Register("ContextOnly", &ctxOnlyDriver{}) + db, err := Open("ContextOnly", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal("db.Conn", err) + } + defer conn.Close() + coc := conn.dc.ci.(*ctxOnlyConn) + coc.fc.skipDirtySession = true + + _, err = conn.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = conn.ExecContext(ctx, "CREATE|keys|v1=string") + if err != nil { + t.Fatal("exec create", err) + } + expectedValue := "value1" + _, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue) + if err != nil { + t.Fatal("exec insert", err) + } + rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|") + if err != nil { + t.Fatal("query select", err) + } + v1 := "" + for rows.Next() { + err = rows.Scan(&v1) + if err != nil { + t.Fatal("rows scan", err) + } + } + rows.Close() + + if v1 != expectedValue { + t.Fatalf("expected %q, got %q", expectedValue, v1) + } + + if !coc.execCtxCalled { + t.Error("ExecContext not called") + } + if !coc.queryCtxCalled { + t.Error("QueryContext not called") + } +} + // badConn implements a bad driver.Conn, for TestBadDriver. // The Exec method panics. type badConn struct{} |