diff options
Diffstat (limited to 'lib/sqlalchemy/engine/threadlocal.py')
-rw-r--r-- | lib/sqlalchemy/engine/threadlocal.py | 28 |
1 files changed, 18 insertions, 10 deletions
diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index 2ff498db5..06246e854 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -6,16 +6,19 @@ """Provides a thread-local transactional wrapper around the root Engine class. -The ``threadlocal`` module is invoked when using the ``strategy="threadlocal"`` flag -with :func:`~sqlalchemy.engine.create_engine`. This module is semi-private and is -invoked automatically when the threadlocal engine strategy is used. +The ``threadlocal`` module is invoked when using the +``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`. +This module is semi-private and is invoked automatically when the threadlocal +engine strategy is used. """ -from .. import util, event +from .. import util from . import base import weakref + class TLConnection(base.Connection): + def __init__(self, *arg, **kw): super(TLConnection, self).__init__(*arg, **kw) self.__opencount = 0 @@ -33,16 +36,18 @@ class TLConnection(base.Connection): self.__opencount = 0 base.Connection.close(self) + class TLEngine(base.Engine): - """An Engine that includes support for thread-local managed transactions.""" + """An Engine that includes support for thread-local managed + transactions. + """ _tl_connection_cls = TLConnection def __init__(self, *args, **kwargs): super(TLEngine, self).__init__(*args, **kwargs) self._connections = util.threading.local() - def contextual_connect(self, **kw): if not hasattr(self._connections, 'conn'): connection = None @@ -52,21 +57,24 @@ class TLEngine(base.Engine): if connection is None or connection.closed: # guards against pool-level reapers, if desired. # or not connection.connection.is_valid: - connection = self._tl_connection_cls(self, self.pool.connect(), **kw) - self._connections.conn = conn = weakref.ref(connection) + connection = self._tl_connection_cls( + self, self.pool.connect(), **kw) + self._connections.conn = weakref.ref(connection) return connection._increment_connect() def begin_twophase(self, xid=None): if not hasattr(self._connections, 'trans'): self._connections.trans = [] - self._connections.trans.append(self.contextual_connect().begin_twophase(xid=xid)) + self._connections.trans.append( + self.contextual_connect().begin_twophase(xid=xid)) return self def begin_nested(self): if not hasattr(self._connections, 'trans'): self._connections.trans = [] - self._connections.trans.append(self.contextual_connect().begin_nested()) + self._connections.trans.append( + self.contextual_connect().begin_nested()) return self def begin(self): |