summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/ext/sessioncontext.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-07-27 04:08:53 +0000
commited4fc64bb0ac61c27bc4af32962fb129e74a36bf (patch)
treec1cf2fb7b1cafced82a8898e23d2a0bf5ced8526 /lib/sqlalchemy/ext/sessioncontext.py
parent3a8e235af64e36b3b711df1f069d32359fe6c967 (diff)
downloadsqlalchemy-ed4fc64bb0ac61c27bc4af32962fb129e74a36bf.tar.gz
merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3.
Diffstat (limited to 'lib/sqlalchemy/ext/sessioncontext.py')
-rw-r--r--lib/sqlalchemy/ext/sessioncontext.py28
1 files changed, 24 insertions, 4 deletions
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
index 2f81e55d2..fcbf29c3f 100644
--- a/lib/sqlalchemy/ext/sessioncontext.py
+++ b/lib/sqlalchemy/ext/sessioncontext.py
@@ -1,5 +1,5 @@
from sqlalchemy.util import ScopedRegistry
-from sqlalchemy.orm.mapper import MapperExtension
+from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS
__all__ = ['SessionContext', 'SessionContextExt']
@@ -15,16 +15,18 @@ class SessionContext(object):
engine = create_engine(...)
def session_factory():
- return Session(bind_to=engine)
+ return Session(bind=engine)
context = SessionContext(session_factory)
s = context.current # get thread-local session
- context.current = Session(bind_to=other_engine) # set current session
+ context.current = Session(bind=other_engine) # set current session
del context.current # discard the thread-local session (a new one will
# be created on the next call to context.current)
"""
- def __init__(self, session_factory, scopefunc=None):
+ def __init__(self, session_factory=None, scopefunc=None):
+ if session_factory is None:
+ session_factory = create_session
self.registry = ScopedRegistry(session_factory, scopefunc)
super(SessionContext, self).__init__()
@@ -60,3 +62,21 @@ class SessionContextExt(MapperExtension):
def get_session(self):
return self.context.current
+
+ def init_instance(self, mapper, class_, instance, args, kwargs):
+ session = kwargs.pop('_sa_session', self.context.current)
+ session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ return EXT_PASS
+
+ def init_failed(self, mapper, class_, instance, args, kwargs):
+ object_session(instance).expunge(instance)
+ return EXT_PASS
+
+ def dispose_class(self, mapper, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+
+