diff options
Diffstat (limited to 'lib/sqlalchemy/util.py')
-rw-r--r-- | lib/sqlalchemy/util.py | 50 |
1 files changed, 34 insertions, 16 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 76c73ca6a..735843d2d 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -746,7 +746,6 @@ class OrderedDict(dict): self._list.remove(item[0]) return item - class OrderedSet(set): def __init__(self, d=None): set.__init__(self) @@ -1101,41 +1100,60 @@ class ScopedRegistry(object): a callable that returns a new object to be placed in the registry scopefunc - a callable that will return a key to store/retrieve an object, - defaults to ``thread.get_ident`` for thread-local objects. Use - a value like ``lambda: True`` for application scope. - """ + a callable that will return a key to store/retrieve an object. + If None, ScopedRegistry uses a threading.local object instead. - def __init__(self, createfunc, scopefunc=None): - self.createfunc = createfunc - if scopefunc is None: - self.scopefunc = thread.get_ident + """ + def __new__(cls, createfunc, scopefunc=None): + if not scopefunc: + return object.__new__(_TLocalRegistry) else: - self.scopefunc = scopefunc + return object.__new__(cls) + + def __init__(self, createfunc, scopefunc): + self.createfunc = createfunc + self.scopefunc = scopefunc self.registry = {} def __call__(self): - key = self._get_key() + key = self.scopefunc() try: return self.registry[key] except KeyError: return self.registry.setdefault(key, self.createfunc()) def has(self): - return self._get_key() in self.registry + return self.scopefunc() in self.registry def set(self, obj): - self.registry[self._get_key()] = obj + self.registry[self.scopefunc()] = obj def clear(self): try: - del self.registry[self._get_key()] + del self.registry[self.scopefunc()] except KeyError: pass - def _get_key(self): - return self.scopefunc() +class _TLocalRegistry(ScopedRegistry): + def __init__(self, createfunc, scopefunc=None): + self.createfunc = createfunc + self.registry = threading.local() + + def __call__(self): + try: + return self.registry.value + except AttributeError: + val = self.registry.value = self.createfunc() + return val + def has(self): + return hasattr(self.registry, "value") + + def set(self, obj): + self.registry.value = obj + + def clear(self): + del self.registry.value class WeakCompositeKey(object): """an weak-referencable, hashable collection which is strongly referenced |