summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/scoping.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/scoping.py')
-rw-r--r--lib/sqlalchemy/orm/scoping.py111
1 files changed, 111 insertions, 0 deletions
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
new file mode 100644
index 000000000..3ae63b49a
--- /dev/null
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -0,0 +1,111 @@
+from sqlalchemy.util import ScopedRegistry, warn_deprecated
+from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm.mapper import global_extensions
+from sqlalchemy import exceptions
+import types
+
+__all__ = ['ScopedSession']
+
+
+class ScopedSession(object):
+ """Provides thread-local management of Sessions.
+
+ Usage::
+
+ Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)
+
+ """
+
+ def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
+ self.session_factory = session_factory
+ self.enhance_classes = enhance_classes
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+ if self.enhance_classes:
+ global_extensions.append(_ScopedExt(self))
+
+ def __call__(self, **kwargs):
+ if len(kwargs):
+ scope = kwargs.pop('scope', False)
+ if scope is not None:
+ if self.registry.has():
+ raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
+ else:
+ sess = self.session_factory(**kwargs)
+ self.registry.set(sess)
+ return sess
+ else:
+ return self.session_factory(**kwargs)
+ else:
+ return self.registry()
+
+ def configure(self, **kwargs):
+ """reconfigure the sessionmaker used by this SessionContext"""
+ self.session_factory.configure(**kwargs)
+
+def instrument(name):
+ def do(self, *args, **kwargs):
+ return getattr(self.registry(), name)(*args, **kwargs)
+ return do
+for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
+ setattr(ScopedSession, meth, instrument(meth))
+
+def makeprop(name):
+ def set(self, attr):
+ setattr(self.registry(), name, attr)
+ def get(self):
+ return getattr(self.registry(), name)
+ return property(get, set)
+for prop in ('bind', 'dirty', 'identity_map'):
+ setattr(ScopedSession, prop, makeprop(prop))
+
+def clslevel(name):
+ def do(cls, *args,**kwargs):
+ return getattr(Session, name)(*args, **kwargs)
+ return classmethod(do)
+for prop in ('close_all',):
+ setattr(ScopedSession, prop, clslevel(prop))
+
+class _ScopedExt(MapperExtension):
+ def __init__(self, context):
+ self.context = context
+
+ def get_session(self):
+ return self.context.registry()
+
+ def instrument_class(self, mapper, class_):
+ class query(object):
+ def __getattr__(self, key):
+ return getattr(registry().query(class_), key)
+ def __call__(self):
+ return registry().query(class_)
+
+ if not hasattr(class_, 'query'):
+ class_.query = query()
+
+ def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+ session = kwargs.pop('_sa_session', self.context.registry())
+ if not isinstance(oldinit, types.MethodType):
+ for key, value in kwargs.items():
+ #if validate:
+ # if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+ # raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ setattr(instance, key, value)
+ session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ return EXT_CONTINUE
+
+ def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
+ object_session(instance).expunge(instance)
+ return EXT_CONTINUE
+
+ def dispose_class(self, mapper, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+ if hasattr(class_, 'query'):
+ delattr(class_, 'query')
+
+
+