summary refs log tree commit diff
path: root/lib/sqlalchemy/orm/scoping.py
blob: 3ae63b49aebba7e53c04b53941e2e41dfa3322e7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from sqlalchemy.util import ScopedRegistry, warn_deprecated
from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import global_extensions
from sqlalchemy import exceptions
import types

# Public names of this module: only ScopedSession is exported by a
# ``from ... import *``; the helper functions and _ScopedExt are not.
__all__ = ['ScopedSession']


class ScopedSession(object):
    """Provides thread-local management of Sessions.

    An instance wraps a session factory in a ``ScopedRegistry``.  Calling
    the instance without arguments returns the registry's current session;
    calling it with keyword arguments builds a new session from the factory
    (see ``__call__``).

    Usage::

      Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)

    """

    def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
        """Set up the registry around ``session_factory``.

        ``scopefunc`` is handed to ``ScopedRegistry`` unchanged.  When
        ``enhance_classes`` is true, a ``_ScopedExt`` bound to this object
        is appended to the mapper ``global_extensions`` list.
        """
        self.registry = ScopedRegistry(session_factory, scopefunc)
        self.session_factory = session_factory
        self.enhance_classes = enhance_classes
        if enhance_classes:
            global_extensions.append(_ScopedExt(self))

    def __call__(self, **kwargs):
        """Return a session.

        * No keyword arguments: return the registry's current session.
        * ``scope=None`` given: build a session from the factory with the
          remaining keyword arguments, without storing it in the registry.
        * Any other keyword arguments: build a session from the factory,
          store it in the registry and return it.  Raises
          ``InvalidRequestError`` if the registry already holds a session.
        """
        if not kwargs:
            return self.registry()

        # 'scope' is consumed here and never forwarded to the factory; only
        # an explicit None selects the unregistered path.
        if kwargs.pop('scope', False) is None:
            return self.session_factory(**kwargs)

        if self.registry.has():
            raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")

        new_session = self.session_factory(**kwargs)
        self.registry.set(new_session)
        return new_session

    def configure(self, **kwargs):
        """Reconfigure the session factory used by this ScopedSession."""
        factory = self.session_factory
        factory.configure(**kwargs)

def instrument(name):
    """Build a method that forwards ``name`` to the current session.

    The returned function expects ``self`` to expose a callable
    ``registry`` attribute; it looks ``name`` up on whatever that call
    returns and invokes it with the given arguments.
    """
    def do(self, *args, **kwargs):
        current = self.registry()
        target = getattr(current, name)
        return target(*args, **kwargs)
    return do
# Install a forwarding method on ScopedSession for each listed name, so that
# e.g. ``scoped.commit()`` is relayed to ``scoped.registry().commit()``.
for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
    setattr(ScopedSession, meth, instrument(meth))

def makeprop(name):
    """Build a property that proxies attribute ``name`` of the current session.

    Both reading and assigning go through ``self.registry()``, so the
    property always acts on whatever object the registry returns at the
    time of access.
    """
    def _read(self):
        current = self.registry()
        return getattr(current, name)

    def _write(self, attr):
        current = self.registry()
        setattr(current, name, attr)

    return property(_read, _write)
# Install a proxy property on ScopedSession for each listed attribute name;
# reads and writes are relayed to the object returned by ``registry()``.
for prop in ('bind', 'dirty', 'identity_map'):
    setattr(ScopedSession, prop, makeprop(prop))

def clslevel(name):
    """Build a classmethod that forwards to the ``Session`` attribute ``name``.

    The receiving class (``cls``) is ignored: the call is always made on
    ``Session`` itself, with the caller's arguments passed through.
    """
    def do(cls, *args, **kwargs):
        target = getattr(Session, name)
        return target(*args, **kwargs)
    return classmethod(do)
# Install class-level forwarders on ScopedSession; each one calls the
# same-named attribute of Session.
for prop in ('close_all',):
    setattr(ScopedSession, prop, clslevel(prop))
    
class _ScopedExt(MapperExtension):
    """Mapper extension that ties mapped classes to a scoped session.

    ``context`` is the object whose ``registry()`` call yields the current
    session (a ``ScopedSession`` when installed by
    ``ScopedSession.__init__``).
    """

    def __init__(self, context):
        self.context = context

    def get_session(self):
        """Return the session currently provided by the context's registry."""
        return self.context.registry()

    def instrument_class(self, mapper, class_):
        """Attach a ``query`` accessor to ``class_`` unless it already has one.

        The accessor can be called (``class_.query()``) to obtain
        ``session.query(class_)`` from the current session, or have
        attributes looked up on it directly (``class_.query.foo``), which
        are resolved against a freshly created query.
        """
        # Bind the registry to a local so the nested class can close over
        # it: inside the accessor's methods ``self`` is the accessor object,
        # not this extension, and there is no module-level ``registry``.
        registry = self.context.registry

        class query(object):
            def __getattr__(self, key):
                return getattr(registry().query(class_), key)
            def __call__(self):
                return registry().query(class_)

        if not hasattr(class_, 'query'):
            class_.query = query()

    def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
        """Save a newly constructed instance into a session.

        ``kwargs`` is modified in place: the private ``_sa_session`` and
        ``_sa_entity_name`` keys are removed.  The instance is saved into
        the session passed as ``_sa_session`` or, if none was given, into
        the context's current session.  When ``oldinit`` is not a
        ``types.MethodType`` (presumably: the class defines no ``__init__``
        of its own -- confirm against the mapper), the remaining keyword
        arguments are assigned onto the instance as attributes.
        """
        # Remove the private keywords before the attribute loop so that
        # they are never copied onto the instance as ordinary attributes.
        entity_name = kwargs.pop('_sa_entity_name', None)
        session = kwargs.pop('_sa_session', None)
        if session is None:
            # Only consult the registry when no explicit session was given;
            # calling it may create a session as a side effect.
            session = self.context.registry()
        if not isinstance(oldinit, types.MethodType):
            # No validation of the keyword names is performed here.
            for key, value in kwargs.items():
                setattr(instance, key, value)
        session._save_impl(instance, entity_name=entity_name)
        return EXT_CONTINUE

    def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
        """Remove ``instance`` from its session after a failed ``__init__``."""
        # Imported locally: the module-level imports do not provide
        # object_session.
        from sqlalchemy.orm.session import object_session
        session = object_session(instance)
        # The instance may not be attached to any session (for example when
        # init_instance itself failed before saving), so guard the expunge.
        if session is not None:
            session.expunge(instance)
        return EXT_CONTINUE

    def dispose_class(self, mapper, class_):
        """Undo the class-level changes visible on ``class_``.

        If ``class_.__init__`` carries an ``_oldinit`` attribute, the
        original ``__init__`` is restored (or ``__init__`` is deleted when
        ``_oldinit`` is None).  A ``query`` attribute, if present, is
        deleted as well.
        """
        if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
            if class_.__init__._oldinit is not None:
                class_.__init__ = class_.__init__._oldinit
            else:
                delattr(class_, '__init__')
        if hasattr(class_, 'query'):
            delattr(class_, 'query')