diff options
Diffstat (limited to 'src/database/sql/sql.go')
-rw-r--r-- | src/database/sql/sql.go | 41 |
1 files changed, 36 insertions, 5 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 90f813d82..731b7a7f7 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1043,6 +1043,13 @@ type Tx struct { // or Rollback. once done, all operations fail with // ErrTxDone. done bool + + // All Stmts prepared for this transaction. These will be closed after the + // transaction has been committed or rolled back. + stmts struct { + sync.Mutex + v []*Stmt + } } var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") @@ -1064,6 +1071,15 @@ func (tx *Tx) grabConn() (*driverConn, error) { return tx.dc, nil } +// Closes all Stmts prepared for this transaction. +func (tx *Tx) closePrepared() { + tx.stmts.Lock() + for _, stmt := range tx.stmts.v { + stmt.Close() + } + tx.stmts.Unlock() +} + // Commit commits the transaction. func (tx *Tx) Commit() error { if tx.done { @@ -1071,8 +1087,12 @@ func (tx *Tx) Commit() error { } defer tx.close() tx.dc.Lock() - defer tx.dc.Unlock() - return tx.txi.Commit() + err := tx.txi.Commit() + tx.dc.Unlock() + if err != driver.ErrBadConn { + tx.closePrepared() + } + return err } // Rollback aborts the transaction. @@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error { } defer tx.close() tx.dc.Lock() - defer tx.dc.Unlock() - return tx.txi.Rollback() + err := tx.txi.Rollback() + tx.dc.Unlock() + if err != driver.ErrBadConn { + tx.closePrepared() + } + return err } // Prepare creates a prepared statement for use within a transaction. @@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { }, query: query, } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, stmt) + tx.stmts.Unlock() return stmt, nil } @@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { dc.Lock() si, err := dc.ci.Prepare(stmt.query) dc.Unlock() - return &Stmt{ + txs := &Stmt{ db: tx.db, tx: tx, txsi: &driverStmt{ @@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { query: stmt.query, stickyErr: err, } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, txs) + tx.stmts.Unlock() + return txs } // Exec executes a query that doesn't return rows. |