From ed4fc64bb0ac61c27bc4af32962fb129e74a36bf Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 27 Jul 2007 04:08:53 +0000 Subject: merging 0.4 branch to trunk. see CHANGES for details. 0.3 moves to maintenance branch in branches/rel_0_3. --- lib/sqlalchemy/ext/sessioncontext.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) (limited to 'lib/sqlalchemy/ext/sessioncontext.py') diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py index 2f81e55d2..fcbf29c3f 100644 --- a/lib/sqlalchemy/ext/sessioncontext.py +++ b/lib/sqlalchemy/ext/sessioncontext.py @@ -1,5 +1,5 @@ from sqlalchemy.util import ScopedRegistry -from sqlalchemy.orm.mapper import MapperExtension +from sqlalchemy.orm import create_session, object_session, MapperExtension, EXT_PASS __all__ = ['SessionContext', 'SessionContextExt'] @@ -15,16 +15,18 @@ class SessionContext(object): engine = create_engine(...) def session_factory(): - return Session(bind_to=engine) + return Session(bind=engine) context = SessionContext(session_factory) s = context.current # get thread-local session - context.current = Session(bind_to=other_engine) # set current session + context.current = Session(bind=other_engine) # set current session del context.current # discard the thread-local session (a new one will # be created on the next call to context.current) """ - def __init__(self, session_factory, scopefunc=None): + def __init__(self, session_factory=None, scopefunc=None): + if session_factory is None: + session_factory = create_session self.registry = ScopedRegistry(session_factory, scopefunc) super(SessionContext, self).__init__() @@ -60,3 +62,21 @@ class SessionContextExt(MapperExtension): def get_session(self): return self.context.current + + def init_instance(self, mapper, class_, instance, args, kwargs): + session = kwargs.pop('_sa_session', self.context.current) + session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None)) + return EXT_PASS + + def init_failed(self, mapper, class_, instance, args, kwargs): + object_session(instance).expunge(instance) + return EXT_PASS + + def dispose_class(self, mapper, class_): + if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): + if class_.__init__._oldinit is not None: + class_.__init__ = class_.__init__._oldinit + else: + delattr(class_, '__init__') + + -- cgit v1.2.1