diff options
Diffstat (limited to 'lib/sqlalchemy/engine/util.py')
-rw-r--r-- | lib/sqlalchemy/engine/util.py | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index ede263198..17e3510aa 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -153,3 +153,82 @@ def _distill_params_20(params): return (params,), _no_kw else: raise exc.ArgumentError("mapping or sequence expected for parameters") + + +class TransactionalContext(object): + """Apply Python context manager behavior to transaction objects. + + Performs validation to ensure the subject of the transaction is not + used if the transaction were ended prematurely. + + """ + + _trans_subject = None + + def _transaction_is_active(self): + raise NotImplementedError() + + def _transaction_is_closed(self): + raise NotImplementedError() + + def _get_subject(self): + raise NotImplementedError() + + @classmethod + def _trans_ctx_check(cls, subject): + trans_context = subject._trans_context_manager + if trans_context: + if not trans_context._transaction_is_active(): + raise exc.InvalidRequestError( + "Can't operate on closed transaction inside context " + "manager. Please complete the context manager " + "before emitting further commands." + ) + + def __enter__(self): + subject = self._get_subject() + + # none for outer transaction, may be non-None for nested + # savepoint, legacy nesting cases + trans_context = subject._trans_context_manager + self._outer_trans_ctx = trans_context + + self._trans_subject = subject + subject._trans_context_manager = self + return self + + def __exit__(self, type_, value, traceback): + subject = self._trans_subject + + # simplistically we could assume that + # "subject._trans_context_manager is self". However, any calling + # code that is manipulating __exit__ directly would break this + # assumption. alembic context manager + # is an example of partial use that just calls __exit__ and + # not __enter__ at the moment. it's safe to assume this is being done + # in the wild also + out_of_band_exit = ( + subject is None or subject._trans_context_manager is not self + ) + + if type_ is None and self._transaction_is_active(): + try: + self.commit() + except: + with util.safe_reraise(): + self.rollback() + finally: + if not out_of_band_exit: + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None + else: + try: + if not self._transaction_is_active(): + if not self._transaction_is_closed(): + self.close() + else: + self.rollback() + finally: + if not out_of_band_exit: + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None |