diff options
Diffstat (limited to 'libgo/go/database/sql/sql.go')
-rw-r--r-- | libgo/go/database/sql/sql.go | 793 |
1 files changed, 615 insertions, 178 deletions
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go index 29aef78b240..a80782bfedc 100644 --- a/libgo/go/database/sql/sql.go +++ b/libgo/go/database/sql/sql.go @@ -4,6 +4,9 @@ // Package sql provides a generic interface around SQL (or SQL-like) // databases. +// +// The sql package must be used in conjunction with a database driver. +// See http://golang.org/s/sqldrivers for a list of drivers. package sql import ( @@ -11,6 +14,7 @@ import ( "errors" "fmt" "io" + "runtime" "sync" ) @@ -176,22 +180,197 @@ var ErrNoRows = errors.New("sql: no rows in result set") // DB is a database handle. It's safe for concurrent use by multiple // goroutines. // -// If the underlying database driver has the concept of a connection -// and per-connection session state, the sql package manages creating -// and freeing connections automatically, including maintaining a free -// pool of idle connections. If observing session state is required, -// either do not share a *DB between multiple concurrent goroutines or -// create and observe all state only within a transaction. Once -// DB.Open is called, the returned Tx is bound to a single isolated -// connection. Once Tx.Commit or Tx.Rollback is called, that -// connection is returned to DB's idle connection pool. +// The sql package creates and frees connections automatically; it +// also maintains a free pool of idle connections. If the database has +// a concept of per-connection state, such state can only be reliably +// observed within a transaction. Once DB.Begin is called, the +// returned Tx is bound to a single connection. Once Commit or +// Rollback is called on the transaction, that transaction's +// connection is returned to DB's idle connection pool. The pool size +// can be controlled with SetMaxIdleConns. type DB struct { driver driver.Driver dsn string - mu sync.Mutex // protects freeConn and closed - freeConn []driver.Conn + mu sync.Mutex // protects following fields + freeConn []*driverConn closed bool + dep map[finalCloser]depSet + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only + maxIdle int // zero means defaultMaxIdleConns; negative means 0 +} + +// driverConn wraps a driver.Conn with a mutex, to +// be held during all calls into the Conn. (including any calls onto +// interfaces returned via that Conn, such as calls on Tx, Stmt, +// Result, Rows) +type driverConn struct { + db *DB + + sync.Mutex // guards following + ci driver.Conn + closed bool + finalClosed bool // ci.Close has been called + openStmt map[driver.Stmt]bool + + // guarded by db.mu + inUse bool + onPut []func() // code (with db.mu held) run when conn is next returned + dbmuClosed bool // same as closed, but guarded by db.mu, for connIfFree +} + +func (dc *driverConn) removeOpenStmt(si driver.Stmt) { + dc.Lock() + defer dc.Unlock() + delete(dc.openStmt, si) +} + +func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { + si, err := dc.ci.Prepare(query) + if err == nil { + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // TODO(bradfitz): let drivers opt out of caring about + // stmt closes if the conn is about to close anyway? For now + // do the safe thing, in case stmts need to be closed. + // + // TODO(bradfitz): after Go 1.1, closing driver.Stmts + // should be moved to driverStmt, using unique + // *driverStmts everywhere (including from + // *Stmt.connStmt, instead of returning a + // driver.Stmt), using driverStmt as a pointer + // everywhere, and making it a finalCloser. + if dc.openStmt == nil { + dc.openStmt = make(map[driver.Stmt]bool) + } + dc.openStmt[si] = true + } + return si, err +} + +// the dc.db's Mutex is held. +func (dc *driverConn) closeDBLocked() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + return dc.db.removeDepLocked(dc, dc)() +} + +func (dc *driverConn) Close() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + + // And now updates that require holding dc.mu.Lock. + dc.db.mu.Lock() + dc.dbmuClosed = true + fn := dc.db.removeDepLocked(dc, dc) + dc.db.mu.Unlock() + return fn() +} + +func (dc *driverConn) finalClose() error { + dc.Lock() + + for si := range dc.openStmt { + si.Close() + } + dc.openStmt = nil + + err := dc.ci.Close() + dc.ci = nil + dc.finalClosed = true + + dc.Unlock() + return err +} + +// driverStmt associates a driver.Stmt with the +// *driverConn from which it came, so the driverConn's lock can be +// held during calls. +type driverStmt struct { + sync.Locker // the *driverConn + si driver.Stmt +} + +func (ds *driverStmt) Close() error { + ds.Lock() + defer ds.Unlock() + return ds.si.Close() +} + +// depSet is a finalCloser's outstanding dependencies +type depSet map[interface{}]bool // set of true bools + +// The finalCloser interface is used by (*DB).addDep and related +// dependency reference counting. +type finalCloser interface { + // finalClose is called when the reference count of an object + // goes to zero. (*DB).mu is not held while calling it. + finalClose() error +} + +// addDep notes that x now depends on dep, and x's finalClose won't be +// called until all of x's dependencies are removed with removeDep. +func (db *DB) addDep(x finalCloser, dep interface{}) { + //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep)) + db.mu.Lock() + defer db.mu.Unlock() + db.addDepLocked(x, dep) +} + +func (db *DB) addDepLocked(x finalCloser, dep interface{}) { + if db.dep == nil { + db.dep = make(map[finalCloser]depSet) + } + xdep := db.dep[x] + if xdep == nil { + xdep = make(depSet) + db.dep[x] = xdep + } + xdep[dep] = true +} + +// removeDep notes that x no longer depends on dep. +// If x still has dependencies, nil is returned. +// If x no longer has any dependencies, its finalClose method will be +// called and its error value will be returned. +func (db *DB) removeDep(x finalCloser, dep interface{}) error { + db.mu.Lock() + fn := db.removeDepLocked(x, dep) + db.mu.Unlock() + return fn() +} + +func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { + //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep)) + done := false + + xdep := db.dep[x] + if xdep != nil { + delete(xdep, dep) + if len(xdep) == 0 { + delete(db.dep, x) + done = true + } + } + + if !done { + return func() error { return nil } + } + return func() error { + //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x)) + return x.finalClose() + } } // Open opens a database specified by its database driver name and a @@ -199,13 +378,38 @@ type DB struct { // database name and connection information. // // Most users will open a database via a driver-specific connection -// helper function that returns a *DB. +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See http://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// Open may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. func Open(driverName, dataSourceName string) (*DB, error) { - driver, ok := drivers[driverName] + driveri, ok := drivers[driverName] if !ok { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } - return &DB{driver: driver, dsn: dataSourceName}, nil + db := &DB{ + driver: driveri, + dsn: dataSourceName, + lastPut: make(map[*driverConn]string), + } + return db, nil +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) Ping() error { + // TODO(bradfitz): give drivers an optional hook to implement + // this in a more efficient or more reliable way, if they + // have one. + dc, err := db.conn() + if err != nil { + return err + } + db.putConn(dc, nil) + return nil } // Close closes the database, releasing any open resources. @@ -213,8 +417,8 @@ func (db *DB) Close() error { db.mu.Lock() defer db.mu.Unlock() var err error - for _, c := range db.freeConn { - err1 := c.Close() + for _, dc := range db.freeConn { + err1 := dc.closeDBLocked() if err1 != nil { err = err1 } @@ -224,15 +428,45 @@ func (db *DB) Close() error { return err } -func (db *DB) maxIdleConns() int { - const defaultMaxIdleConns = 2 - // TODO(bradfitz): ask driver, if supported, for its default preference - // TODO(bradfitz): let users override? - return defaultMaxIdleConns +const defaultMaxIdleConns = 2 + +func (db *DB) maxIdleConnsLocked() int { + n := db.maxIdle + switch { + case n == 0: + // TODO(bradfitz): ask driver, if supported, for its default preference + return defaultMaxIdleConns + case n < 0: + return 0 + default: + return n + } +} + +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +// +// If n <= 0, no idle connections are retained. +func (db *DB) SetMaxIdleConns(n int) { + db.mu.Lock() + defer db.mu.Unlock() + if n > 0 { + db.maxIdle = n + } else { + // No idle connections. + db.maxIdle = -1 + } + for len(db.freeConn) > 0 && len(db.freeConn) > n { + nfree := len(db.freeConn) + dc := db.freeConn[nfree-1] + db.freeConn[nfree-1] = nil + db.freeConn = db.freeConn[:nfree-1] + go dc.Close() + } } -// conn returns a newly-opened or cached driver.Conn -func (db *DB) conn() (driver.Conn, error) { +// conn returns a newly-opened or cached *driverConn +func (db *DB) conn() (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() @@ -241,53 +475,132 @@ func (db *DB) conn() (driver.Conn, error) { if n := len(db.freeConn); n > 0 { conn := db.freeConn[n-1] db.freeConn = db.freeConn[:n-1] + conn.inUse = true db.mu.Unlock() return conn, nil } db.mu.Unlock() - return db.driver.Open(db.dsn) + + ci, err := db.driver.Open(db.dsn) + if err != nil { + return nil, err + } + dc := &driverConn{ + db: db, + ci: ci, + } + db.mu.Lock() + db.addDepLocked(dc, dc) + dc.inUse = true + db.mu.Unlock() + return dc, nil } -func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { +var ( + errConnClosed = errors.New("database/sql: internal sentinel error: conn is closed") + errConnBusy = errors.New("database/sql: internal sentinel error: conn is busy") +) + +// connIfFree returns (wanted, nil) if wanted is still a valid conn and +// isn't in use. +// +// The error is errConnClosed if the connection if the requested connection +// is invalid because it's been closed. +// +// The error is errConnBusy if the connection is in use. +func (db *DB) connIfFree(wanted *driverConn) (*driverConn, error) { db.mu.Lock() defer db.mu.Unlock() + if wanted.inUse { + return nil, errConnBusy + } + if wanted.dbmuClosed { + return nil, errConnClosed + } for i, conn := range db.freeConn { if conn != wanted { continue } db.freeConn[i] = db.freeConn[len(db.freeConn)-1] db.freeConn = db.freeConn[:len(db.freeConn)-1] - return wanted, true - } - return nil, false + wanted.inUse = true + return wanted, nil + } + // TODO(bradfitz): shouldn't get here. After Go 1.1, change this to: + // panic("connIfFree call requested a non-closed, non-busy, non-free conn") + // Which passes all the tests, but I'm too paranoid to include this + // late in Go 1.1. + // Instead, treat it like a busy connection: + return nil, errConnBusy } // putConnHook is a hook for testing. -var putConnHook func(*DB, driver.Conn) +var putConnHook func(*DB, *driverConn) + +// noteUnusedDriverStatement notes that si is no longer used and should +// be closed whenever possible (when c is next not in use), unless c is +// already closed. +func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { + db.mu.Lock() + defer db.mu.Unlock() + if c.inUse { + c.onPut = append(c.onPut, func() { + si.Close() + }) + } else { + c.Lock() + defer c.Unlock() + if !c.finalClosed { + si.Close() + } + } +} + +// debugGetPut determines whether getConn & putConn calls' stack traces +// are returned for more verbose crashes. +const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(c driver.Conn, err error) { +func (db *DB) putConn(dc *driverConn, err error) { + db.mu.Lock() + if !dc.inUse { + if debugGetPut { + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) + } + panic("sql: connection returned that was never out") + } + if debugGetPut { + db.lastPut[dc] = stack() + } + dc.inUse = false + + for _, fn := range dc.onPut { + fn() + } + dc.onPut = nil + if err == driver.ErrBadConn { // Don't reuse bad connections. + db.mu.Unlock() return } - db.mu.Lock() if putConnHook != nil { - putConnHook(db, c) + putConnHook(db, dc) } - if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { - db.freeConn = append(db.freeConn, c) + if n := len(db.freeConn); !db.closed && n < db.maxIdleConnsLocked() { + db.freeConn = append(db.freeConn, dc) db.mu.Unlock() return } - // TODO: check to see if we need this Conn for any prepared - // statements which are still active? db.mu.Unlock() - c.Close() + + dc.Close() } -// Prepare creates a prepared statement for later execution. +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. func (db *DB) Prepare(query string) (*Stmt, error) { var stmt *Stmt var err error @@ -300,34 +613,36 @@ func (db *DB) Prepare(query string) (*Stmt, error) { return stmt, err } -func (db *DB) prepare(query string) (stmt *Stmt, err error) { +func (db *DB) prepare(query string) (*Stmt, error) { // TODO: check if db.driver supports an optional // driver.Preparer interface and call that instead, if so, // otherwise we make a prepared statement that's bound // to a connection, and to execute this prepared statement // we either need to use this connection (if it's free), else // get a new connection + re-prepare + execute on that one. - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } - defer func() { - db.putConn(ci, err) - }() - - si, err := ci.Prepare(query) + dc.Lock() + si, err := dc.prepareLocked(query) + dc.Unlock() if err != nil { + db.putConn(dc, err) return nil, err } - stmt = &Stmt{ + stmt := &Stmt{ db: db, query: query, - css: []connStmt{{ci, si}}, + css: []connStmt{{dc, si}}, } + db.addDep(stmt, stmt) + db.putConn(dc, nil) return stmt, nil } // Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. func (db *DB) Exec(query string, args ...interface{}) (Result, error) { var res Result var err error @@ -341,50 +656,119 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { } func (db *DB) exec(query string, args []interface{}) (res Result, err error) { - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } defer func() { - db.putConn(ci, err) + db.putConn(dc, err) }() - if execer, ok := ci.(driver.Execer); ok { + if execer, ok := dc.ci.(driver.Execer); ok { dargs, err := driverArgs(nil, args) if err != nil { return nil, err } + dc.Lock() resi, err := execer.Exec(query, dargs) + dc.Unlock() if err != driver.ErrSkip { if err != nil { return nil, err } - return result{resi}, nil + return driverResult{dc, resi}, nil } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } - defer sti.Close() - - return resultFromStatement(sti, args...) + defer withLock(dc, func() { si.Close() }) + return resultFromStatement(driverStmt{dc, si}, args...) } // Query executes a query that returns rows, typically a SELECT. // The args are for any placeholder parameters in the query. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { - stmt, err := db.Prepare(query) + var rows *Rows + var err error + for i := 0; i < 10; i++ { + rows, err = db.query(query, args) + if err != driver.ErrBadConn { + break + } + } + return rows, err +} + +func (db *DB) query(query string, args []interface{}) (*Rows, error) { + ci, err := db.conn() + if err != nil { + return nil, err + } + + releaseConn := func(err error) { db.putConn(ci, err) } + + return db.queryConn(ci, releaseConn, query, args) +} + +// queryConn executes a query on the given connection. +// The connection gets released by the releaseConn function. +func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { + if queryer, ok := dc.ci.(driver.Queryer); ok { + dargs, err := driverArgs(nil, args) + if err != nil { + releaseConn(err) + return nil, err + } + dc.Lock() + rowsi, err := queryer.Query(query, dargs) + dc.Unlock() + if err != driver.ErrSkip { + if err != nil { + releaseConn(err) + return nil, err + } + // Note: ownership of dc passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + } + return rows, nil + } + } + + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { + releaseConn(err) return nil, err } - rows, err := stmt.Query(args...) + + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) if err != nil { - stmt.Close() + releaseConn(err) + dc.Lock() + si.Close() + dc.Unlock() return nil, err } - rows.closeStmt = stmt + + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + closeStmt: si, + } return rows, nil } @@ -411,18 +795,20 @@ func (db *DB) Begin() (*Tx, error) { } func (db *DB) begin() (tx *Tx, err error) { - ci, err := db.conn() + dc, err := db.conn() if err != nil { return nil, err } - txi, err := ci.Begin() + dc.Lock() + txi, err := dc.ci.Begin() + dc.Unlock() if err != nil { - db.putConn(ci, err) + db.putConn(dc, err) return nil, err } return &Tx{ db: db, - ci: ci, + dc: dc, txi: txi, }, nil } @@ -441,15 +827,11 @@ func (db *DB) Driver() driver.Driver { type Tx struct { db *DB - // ci is owned exclusively until Commit or Rollback, at which point + // dc is owned exclusively until Commit or Rollback, at which point // it's returned with putConn. - ci driver.Conn + dc *driverConn txi driver.Tx - // cimu is held while somebody is using ci (between grabConn - // and releaseConn) - cimu sync.Mutex - // done transitions from false to true exactly once, on Commit // or Rollback. once done, all operations fail with // ErrTxDone. @@ -463,21 +845,16 @@ func (tx *Tx) close() { panic("double close") // internal error } tx.done = true - tx.db.putConn(tx.ci, nil) - tx.ci = nil + tx.db.putConn(tx.dc, nil) + tx.dc = nil tx.txi = nil } -func (tx *Tx) grabConn() (driver.Conn, error) { +func (tx *Tx) grabConn() (*driverConn, error) { if tx.done { return nil, ErrTxDone } - tx.cimu.Lock() - return tx.ci, nil -} - -func (tx *Tx) releaseConn() { - tx.cimu.Unlock() + return tx.dc, nil } // Commit commits the transaction. @@ -486,6 +863,8 @@ func (tx *Tx) Commit() error { return ErrTxDone } defer tx.close() + tx.dc.Lock() + defer tx.dc.Unlock() return tx.txi.Commit() } @@ -495,6 +874,8 @@ func (tx *Tx) Rollback() error { return ErrTxDone } defer tx.close() + tx.dc.Lock() + defer tx.dc.Unlock() return tx.txi.Rollback() } @@ -518,21 +899,25 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // Perhaps just looking at the reference count (by noting // Stmt.Close) would be enough. We might also want a finalizer // on Stmt to drop the reference count. - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } - defer tx.releaseConn() - si, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } stmt := &Stmt{ - db: tx.db, - tx: tx, - txsi: si, + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, query: query, } return stmt, nil @@ -556,16 +941,20 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return &Stmt{stickyErr: err} } - defer tx.releaseConn() - si, err := ci.Prepare(stmt.query) + dc.Lock() + si, err := dc.ci.Prepare(stmt.query) + dc.Unlock() return &Stmt{ - db: tx.db, - tx: tx, - txsi: si, + db: tx.db, + tx: tx, + txsi: &driverStmt{ + Locker: dc, + si: si, + }, query: stmt.query, stickyErr: err, } @@ -574,51 +963,46 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { // Exec executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - ci, err := tx.grabConn() + dc, err := tx.grabConn() if err != nil { return nil, err } - defer tx.releaseConn() - if execer, ok := ci.(driver.Execer); ok { + if execer, ok := dc.ci.(driver.Execer); ok { dargs, err := driverArgs(nil, args) if err != nil { return nil, err } + dc.Lock() resi, err := execer.Exec(query, dargs) + dc.Unlock() if err == nil { - return result{resi}, nil + return driverResult{dc, resi}, nil } if err != driver.ErrSkip { return nil, err } } - sti, err := ci.Prepare(query) + dc.Lock() + si, err := dc.ci.Prepare(query) + dc.Unlock() if err != nil { return nil, err } - defer sti.Close() + defer withLock(dc, func() { si.Close() }) - return resultFromStatement(sti, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } // Query executes a query that returns rows, typically a SELECT. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - if tx.done { - return nil, ErrTxDone - } - stmt, err := tx.Prepare(query) + dc, err := tx.grabConn() if err != nil { return nil, err } - rows, err := stmt.Query(args...) - if err != nil { - stmt.Close() - return nil, err - } - rows.closeStmt = stmt - return rows, err + releaseConn := func(error) {} + return tx.db.queryConn(dc, releaseConn, query, args) } // QueryRow executes a query that is expected to return at most one row. @@ -631,7 +1015,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { // connStmt is a prepared statement on a particular connection. type connStmt struct { - ci driver.Conn + dc *driverConn si driver.Stmt } @@ -642,9 +1026,11 @@ type Stmt struct { query string // that created the Stmt stickyErr error // if non-nil, this error is returned for all operations + closemu sync.RWMutex // held exclusively during close, for read otherwise. + // If in a transaction, else both nil: tx *Tx - txsi driver.Stmt + txsi *driverStmt mu sync.Mutex // protects the rest of the fields closed bool @@ -659,39 +1045,47 @@ type Stmt struct { // Exec executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. func (s *Stmt) Exec(args ...interface{}) (Result, error) { - _, releaseConn, si, err := s.connStmt() + s.closemu.RLock() + defer s.closemu.RUnlock() + dc, releaseConn, si, err := s.connStmt() if err != nil { return nil, err } defer releaseConn(nil) - return resultFromStatement(si, args...) + return resultFromStatement(driverStmt{dc, si}, args...) } -func resultFromStatement(si driver.Stmt, args ...interface{}) (Result, error) { +func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. - if want := si.NumInput(); want != -1 && len(args) != want { + if want != -1 && len(args) != want { return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(si, args) + dargs, err := driverArgs(&ds, args) if err != nil { return nil, err } - resi, err := si.Exec(dargs) + ds.Lock() + resi, err := ds.si.Exec(dargs) + ds.Unlock() if err != nil { return nil, err } - return result{resi}, nil + return driverResult{ds.Locker, resi}, nil } // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { if err = s.stickyErr; err != nil { return } @@ -710,19 +1104,27 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St if err != nil { return } - releaseConn = func(error) { s.tx.releaseConn() } - return ci, releaseConn, s.txsi, nil + releaseConn = func(error) {} + return ci, releaseConn, s.txsi.si, nil } var cs connStmt match := false - for _, v := range s.css { - // TODO(bradfitz): lazily clean up entries in this - // list with dead conns while enumerating - if _, match = s.db.connIfFree(v.ci); match { + for i := 0; i < len(s.css); i++ { + v := s.css[i] + _, err := s.db.connIfFree(v.dc) + if err == nil { + match = true cs = v break } + if err == errConnClosed { + // Lazily remove dead conn from our freelist. + s.css[i] = s.css[len(s.css)-1] + s.css = s.css[:len(s.css)-1] + i-- + } + } s.mu.Unlock() @@ -730,11 +1132,13 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St // TODO(bradfitz): or wait for one? make configurable later? if !match { for i := 0; ; i++ { - ci, err := s.db.conn() + dc, err := s.db.conn() if err != nil { return nil, nil, nil, err } - si, err := ci.Prepare(s.query) + dc.Lock() + si, err := dc.prepareLocked(s.query) + dc.Unlock() if err == driver.ErrBadConn && i < 10 { continue } @@ -742,14 +1146,14 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St return nil, nil, nil, err } s.mu.Lock() - cs = connStmt{ci, si} + cs = connStmt{dc, si} s.css = append(s.css, cs) s.mu.Unlock() break } } - conn := cs.ci + conn := cs.dc releaseConn = func(err error) { s.db.putConn(conn, err) } return conn, releaseConn, cs.si, nil } @@ -757,37 +1161,60 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St // Query executes a prepared query statement with the given arguments // and returns the query results as a *Rows. func (s *Stmt) Query(args ...interface{}) (*Rows, error) { - ci, releaseConn, si, err := s.connStmt() + s.closemu.RLock() + defer s.closemu.RUnlock() + + dc, releaseConn, si, err := s.connStmt() + if err != nil { + return nil, err + } + + ds := driverStmt{dc, si} + rowsi, err := rowsiFromStatement(ds, args...) if err != nil { + releaseConn(err) return nil, err } + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + rowsi: rowsi, + // releaseConn set below + } + s.db.addDep(s, rows) + rows.releaseConn = func(err error) { + releaseConn(err) + s.db.removeDep(s, rows) + } + return rows, nil +} + +func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { + ds.Lock() + want := ds.si.NumInput() + ds.Unlock() + // -1 means the driver doesn't know how to count the number of // placeholders, so we won't sanity check input here and instead let the // driver deal with errors. - if want := si.NumInput(); want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", si.NumInput(), len(args)) + if want != -1 && len(args) != want { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(si, args) + dargs, err := driverArgs(&ds, args) if err != nil { return nil, err } - rowsi, err := si.Query(dargs) + ds.Lock() + rowsi, err := ds.si.Query(dargs) + ds.Unlock() if err != nil { - releaseConn(err) return nil, err } - // Note: ownership of ci passes to the *Rows, to be freed - // with releaseConn. - rows := &Rows{ - db: s.db, - ci: ci, - releaseConn: releaseConn, - rowsi: rowsi, - } - return rows, nil + return rowsi, nil } // QueryRow executes a prepared query statement with the given arguments. @@ -811,6 +1238,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row { // Close closes the statement. func (s *Stmt) Close() error { + s.closemu.Lock() + defer s.closemu.Unlock() + if s.stickyErr != nil { return s.stickyErr } @@ -823,18 +1253,19 @@ func (s *Stmt) Close() error { if s.tx != nil { s.txsi.Close() - } else { - for _, v := range s.css { - if ci, match := s.db.connIfFree(v.ci); match { - v.si.Close() - s.db.putConn(ci, nil) - } else { - // TODO(bradfitz): care that we can't close - // this statement because the statement's - // connection is in use? - } - } + return nil + } + + return s.db.removeDep(s, s) +} + +func (s *Stmt) finalClose() error { + for _, v := range s.css { + s.db.noteUnusedDriverStatement(v.dc, v.si) + v.dc.removeOpenStmt(v.si) + s.db.removeDep(v.dc, s) } + s.css = nil return nil } @@ -852,15 +1283,14 @@ func (s *Stmt) Close() error { // err = rows.Err() // get any error encountered during iteration // ... type Rows struct { - db *DB - ci driver.Conn // owned; must call putconn when closed to release + dc *driverConn // owned; must call releaseConn when closed to release releaseConn func(error) rowsi driver.Rows closed bool lastcols []driver.Value lasterr error - closeStmt *Stmt // if non-nil, statement to Close on close + closeStmt driver.Stmt // if non-nil, statement to Close on close } // Next prepares the next result row for reading with the Scan method. @@ -936,24 +1366,6 @@ func (rs *Rows) Scan(dest ...interface{}) error { return fmt.Errorf("sql: Scan error on column index %d: %v", i, err) } } - for _, dp := range dest { - b, ok := dp.(*[]byte) - if !ok { - continue - } - if *b == nil { - // If the []byte is now nil (for a NULL value), - // don't fall through to below which would - // turn it into a non-nil 0-length byte slice - continue - } - if _, ok = dp.(*RawBytes); ok { - continue - } - clone := make([]byte, len(*b)) - copy(clone, *b) - *b = clone - } return nil } @@ -966,10 +1378,10 @@ func (rs *Rows) Close() error { } rs.closed = true err := rs.rowsi.Close() - rs.releaseConn(err) if rs.closeStmt != nil { rs.closeStmt.Close() } + rs.releaseConn(err) return err } @@ -1026,6 +1438,31 @@ type Result interface { RowsAffected() (int64, error) } -type result struct { - driver.Result +type driverResult struct { + sync.Locker // the *driverConn + resi driver.Result +} + +func (dr driverResult) LastInsertId() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.LastInsertId() +} + +func (dr driverResult) RowsAffected() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.RowsAffected() +} + +func stack() string { + var buf [2 << 10]byte + return string(buf[:runtime.Stack(buf[:], false)]) +} + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + fn() + lk.Unlock() } |