summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/session.py12
-rw-r--r--test/orm/session.py12
2 files changed, 20 insertions, 4 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 076727486..6c27f082e 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -127,6 +127,13 @@ class SessionExtension(object):
state. An actual commit() may or may not have occured, depending on whether or not
the flush started its own transaction or participated in a larger transaction.
"""
+
+ def after_begin(self, session, transaction, connection):
+ """Execute after a transaction is begun on a connection
+
+ `transaction` is the SessionTransaction. This method is called after an
+ engine level transaction is begun on a connection.
+ """
class SessionTransaction(object):
"""Represents a Session-level Transaction.
@@ -214,6 +221,8 @@ class SessionTransaction(object):
transaction = conn.begin()
self._connections[conn] = self._connections[conn.engine] = (conn, transaction, conn is not bind)
+ if self.session.extension is not None:
+ self.session.extension.after_begin(self.session, self, conn)
return conn
def prepare(self):
@@ -266,7 +275,7 @@ class SessionTransaction(object):
for subtransaction in self.session.transaction._iterate_parents(upto=self):
subtransaction.close()
- if self.is_active:
+ if self.is_active or self._prepared:
for transaction in self._iterate_parents():
if transaction._parent is None or transaction.nested:
transaction._rollback_impl()
@@ -274,6 +283,7 @@ class SessionTransaction(object):
break
else:
transaction._deactivate()
+
self.close()
return self._parent
diff --git a/test/orm/session.py b/test/orm/session.py
index c429add40..49932f8d9 100644
--- a/test/orm/session.py
+++ b/test/orm/session.py
@@ -881,19 +881,20 @@ class SessionTest(TestBase, AssertsExecutionResults):
log.append('after_flush')
def after_flush_postexec(self, session, flush_context):
log.append('after_flush_postexec')
+ def after_begin(self, session, transaction, connection):
+ log.append('after_begin')
sess = create_session(extension = MyExt())
u = User()
sess.save(u)
sess.flush()
-
- assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
+ assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
log = []
sess = create_session(transactional=True, extension=MyExt())
u = User()
sess.save(u)
sess.flush()
- assert log == ['before_flush', 'after_flush', 'after_flush_postexec']
+ assert log == ['before_flush', 'after_begin', 'after_flush', 'after_flush_postexec']
log = []
u.user_name = 'ed'
@@ -903,6 +904,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
log = []
sess.commit()
assert log == ['before_commit', 'after_commit']
+
+ log = []
+ sess = create_session(transactional=True, extension=MyExt(), bind=testing.db)
+ conn = sess.connection()
+ assert log == ['after_begin']
def test_pickled_update(self):
mapper(User, users)