diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-06 00:47:36 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-06 00:47:36 +0000 |
commit | a9ff52f18b3295131113ad5fb65fd3b49a23e8fe (patch) | |
tree | ecd55808a3ecc1e6837d8d5f727ed1c8f6604e0f | |
parent | 6cfea9df087faf14ea461bc214edd8ae98301201 (diff) | |
download | sqlalchemy-a9ff52f18b3295131113ad5fb65fd3b49a23e8fe.tar.gz |
- added "after_begin()" hook to Session
- Session.rollback() will rollback on a prepared session
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 12 | ||||
-rw-r--r-- | test/orm/session.py | 12 |
2 files changed, 20 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 076727486..6c27f082e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -127,6 +127,13 @@ class SessionExtension(object): state. An actual commit() may or may not have occured, depending on whether or not the flush started its own transaction or participated in a larger transaction. """ + + def after_begin(self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + `transaction` is the SessionTransaction. This method is called after an + engine level transaction is begun on a connection. + """ class SessionTransaction(object): """Represents a Session-level Transaction. @@ -214,6 +221,8 @@ class SessionTransaction(object): transaction = conn.begin() self._connections[conn] = self._connections[conn.engine] = (conn, transaction, conn is not bind) + if self.session.extension is not None: + self.session.extension.after_begin(self.session, self, conn) return conn def prepare(self): @@ -266,7 +275,7 @@ class SessionTransaction(object): for subtransaction in self.session.transaction._iterate_parents(upto=self): subtransaction.close() - if self.is_active: + if self.is_active or self._prepared: for transaction in self._iterate_parents(): if transaction._parent is None or transaction.nested: transaction._rollback_impl() @@ -274,6 +283,7 @@ class SessionTransaction(object): break else: transaction._deactivate() + self.close() return self._parent diff --git a/test/orm/session.py b/test/orm/session.py index c429add40..49932f8d9 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -881,19 +881,20 @@ class SessionTest(TestBase, AssertsExecutionResults): log.append('after_flush') def after_flush_postexec(self, session, flush_context): log.append('after_flush_postexec') + def after_begin(self, session, transaction, connection): + log.append('after_begin') sess = create_session(extension = MyExt()) u = User() sess.save(u) sess.flush() - - assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec'] + assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec'] log = [] sess = create_session(transactional=True, extension=MyExt()) u = User() sess.save(u) sess.flush() - assert log == ['before_flush', 'after_flush', 'after_flush_postexec'] + assert log == ['before_flush', 'after_begin', 'after_flush', 'after_flush_postexec'] log = [] u.user_name = 'ed' @@ -903,6 +904,11 @@ class SessionTest(TestBase, AssertsExecutionResults): log = [] sess.commit() assert log == ['before_commit', 'after_commit'] + + log = [] + sess = create_session(transactional=True, extension=MyExt(), bind=testing.db) + conn = sess.connection() + assert log == ['after_begin'] def test_pickled_update(self): mapper(User, users) |