diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-09 16:34:10 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2008-05-09 16:34:10 +0000 |
commit | 4a6afd469fad170868554bf28578849bf3dfd5dd (patch) | |
tree | b396edc33d567ae19dd244e87137296450467725 /lib/sqlalchemy/util.py | |
parent | 46b7c9dc57a38d5b9e44a4723dad2ad8ec57baca (diff) | |
download | sqlalchemy-4a6afd469fad170868554bf28578849bf3dfd5dd.tar.gz |
r4695 merged to trunk; trunk now becomes 0.5.
0.4 development continues at /sqlalchemy/branches/rel_0_4
Diffstat (limited to 'lib/sqlalchemy/util.py')
-rw-r--r-- | lib/sqlalchemy/util.py | 416 |
1 files changed, 341 insertions, 75 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index e88c4b3b9..ff1108c3b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -8,7 +8,7 @@ import inspect, itertools, new, operator, sets, sys, warnings, weakref import __builtin__ types = __import__('types') -from sqlalchemy import exceptions +from sqlalchemy import exc try: import thread, threading @@ -18,14 +18,16 @@ except ImportError: try: Set = set + FrozenSet = frozenset set_types = set, sets.Set except NameError: set_types = sets.Set, - # layer some of __builtin__.set's binop behavior onto sets.Set - class Set(sets.Set): + + def py24_style_ops(): + """Layer some of __builtin__.set's binop behavior onto sets.Set.""" + def _binary_sanity_check(self, other): pass - def issubset(self, iterable): other = type(self)(iterable) return sets.Set.issubset(self, other) @@ -38,7 +40,6 @@ except NameError: def __ge__(self, other): sets.Set._binary_sanity_check(self, other) return sets.Set.__ge__(self, other) - # lt and gt still require a BaseSet def __lt__(self, other): sets.Set._binary_sanity_check(self, other) @@ -63,6 +64,14 @@ except NameError: if not isinstance(other, sets.BaseSet): return NotImplemented return sets.Set.__isub__(self, other) + return locals() + + py24_style_ops = py24_style_ops() + Set = type('Set', (sets.Set,), py24_style_ops) + FrozenSet = type('FrozenSet', (sets.ImmutableSet,), py24_style_ops) + del py24_style_ops + +EMPTY_SET = FrozenSet() try: import cPickle as pickle @@ -96,10 +105,16 @@ except ImportError: try: from operator import attrgetter -except: +except ImportError: def attrgetter(attribute): return lambda value: getattr(value, attribute) +try: + from operator import itemgetter +except ImportError: + def itemgetter(attribute): + return lambda value: value[attribute] + if sys.version_info >= (2, 5): class PopulateDict(dict): """a dict which populates missing values via a creation function. @@ -169,17 +184,17 @@ except ImportError: class deque(list): def appendleft(self, x): self.insert(0, x) - + def extendleft(self, iterable): self[0:0] = list(iterable) def popleft(self): return self.pop(0) - + def rotate(self, n): for i in xrange(n): self.appendleft(self.pop()) - + def to_list(x, default=None): if x is None: return default @@ -188,18 +203,34 @@ def to_list(x, default=None): else: return x -def array_as_starargs_decorator(func): +def array_as_starargs_decorator(fn): """Interpret a single positional array argument as *args for the decorated method. - + """ + def starargs_as_list(self, *args, **kwargs): - if len(args) == 1: - return func(self, *to_list(args[0], []), **kwargs) + if isinstance(args, basestring) or (len(args) == 1 and not isinstance(args[0], tuple)): + return fn(self, *to_list(args[0], []), **kwargs) else: - return func(self, *args, **kwargs) - return starargs_as_list - + return fn(self, *args, **kwargs) + starargs_as_list.__doc__ = fn.__doc__ + return function_named(starargs_as_list, fn.__name__) + +def array_as_starargs_fn_decorator(fn): + """Interpret a single positional array argument as + *args for the decorated function. + + """ + + def starargs_as_list(*args, **kwargs): + if isinstance(args, basestring) or (len(args) == 1 and not isinstance(args[0], tuple)): + return fn(*to_list(args[0], []), **kwargs) + else: + return fn(*args, **kwargs) + starargs_as_list.__doc__ = fn.__doc__ + return function_named(starargs_as_list, fn.__name__) + def to_set(x): if x is None: return Set() @@ -281,14 +312,121 @@ def get_func_kwargs(func): """Return the full set of legal kwargs for the given `func`.""" return inspect.getargspec(func)[0] +def format_argspec_plus(fn, grouped=True): + """Returns a dictionary of formatted, introspected function arguments. + + A enhanced variant of inspect.formatargspec to support code generation. + + fn + An inspectable callable + grouped + Defaults to True; include (parens, around, argument) lists + + Returns: + + args + Full inspect.formatargspec for fn + self_arg + The name of the first positional argument, or None + apply_pos + args, re-written in calling rather than receiving syntax. Arguments are + passed positionally. + apply_kw + Like apply_pos, except keyword-ish args are passed as keywords. + + Example:: + + >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123) + {'args': '(self, a, b, c=3, **d)', + 'self_arg': 'self', + 'apply_kw': '(self, a, b, c=c, **d)', + 'apply_pos': '(self, a, b, c, **d)'} + + """ + spec = inspect.getargspec(fn) + args = inspect.formatargspec(*spec) + self_arg = spec[0] and spec[0][0] or None + apply_pos = inspect.formatargspec(spec[0], spec[1], spec[2]) + defaulted_vals = spec[3] is not None and spec[0][0-len(spec[3]):] or () + apply_kw = inspect.formatargspec(spec[0], spec[1], spec[2], defaulted_vals, + formatvalue=lambda x: '=' + x) + if grouped: + return dict(args=args, self_arg=self_arg, + apply_pos=apply_pos, apply_kw=apply_kw) + else: + return dict(args=args[1:-1], self_arg=self_arg, + apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1]) + +def format_argspec_init(method, grouped=True): + """format_argspec_plus with considerations for typical __init__ methods + + Wraps format_argspec_plus with error handling strategies for typical + __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return format_argspec_plus(method, grouped=grouped) + except TypeError: + self_arg = 'self' + if method is object.__init__: + args = grouped and '(self)' or 'self' + else: + args = (grouped and '(self, *args, **kwargs)' + or 'self, *args, **kwargs') + return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args) + +def getargspec_init(method): + """inspect.getargspec with considerations for typical __init__ methods + + Wraps inspect.getargspec with error handling for typical __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return inspect.getargspec(method) + except TypeError: + if method is object.__init__: + return (['self'], None, None, None) + else: + return (['self'], 'args', 'kwargs', None) + def unbound_method_to_callable(func_or_cls): """Adjust the incoming callable such that a 'self' argument is not required.""" - + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self: return func_or_cls.im_func else: return func_or_cls +def class_hierarchy(cls): + """Return an unordered sequence of all classes related to cls. + + Traverses diamond hierarchies. + + Fibs slightly: subclasses of builtin types are not returned. Thus + class_hierarchy(class A(object)) returns (A, object), not A plus every + class systemwide that derives from object. + + """ + hier = Set([cls]) + process = list(cls.__mro__) + while process: + c = process.pop() + for b in [_ for _ in c.__bases__ if _ not in hier]: + process.append(b) + hier.add(b) + if c.__module__ == '__builtin__': + continue + for s in [_ for _ in c.__subclasses__() if _ not in hier]: + process.append(s) + hier.add(s) + return list(hier) + # from paste.deploy.converters def asbool(obj): if isinstance(obj, (str, unicode)): @@ -328,9 +466,12 @@ def duck_type_collection(specimen, default=None): return specimen.__emulates__ isa = isinstance(specimen, type) and issubclass or isinstance - if isa(specimen, list): return list - if isa(specimen, set_types): return Set - if isa(specimen, dict): return dict + if isa(specimen, list): + return list + elif isa(specimen, set_types): + return Set + elif isa(specimen, dict): + return dict if hasattr(specimen, 'append'): return list @@ -370,10 +511,23 @@ def assert_arg_type(arg, argtype, name): return arg else: if isinstance(argtype, tuple): - raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) + raise exc.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) else: - raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) + raise exc.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) +_creation_order = 1 +def set_creation_order(instance): + """assign a '_creation_order' sequence to the given instance. + + This allows multiple instances to be sorted in order of + creation (typically within a single thread; the counter is + not particularly threadsafe). + + """ + global _creation_order + instance._creation_order = _creation_order + _creation_order +=1 + def warn_exception(func, *args, **kwargs): """executes the given function, catches all exceptions and converts to a warning.""" try: @@ -430,22 +584,22 @@ class SimpleProperty(object): class NotImplProperty(object): - """a property that raises ``NotImplementedError``.""" + """a property that raises ``NotImplementedError``.""" - def __init__(self, doc): - self.__doc__ = doc + def __init__(self, doc): + self.__doc__ = doc - def __set__(self, obj, value): - raise NotImplementedError() + def __set__(self, obj, value): + raise NotImplementedError() - def __delete__(self, obj): - raise NotImplementedError() + def __delete__(self, obj): + raise NotImplementedError() - def __get__(self, obj, owner): - if obj is None: - return self - else: - raise NotImplementedError() + def __get__(self, obj, owner): + if obj is None: + return self + else: + raise NotImplementedError() class OrderedProperties(object): """An object that maintains the order in which attributes are set upon it. @@ -496,10 +650,10 @@ class OrderedProperties(object): def __contains__(self, key): return key in self._data - + def update(self, value): self._data.update(value) - + def get(self, key, default=None): if key in self: return self[key] @@ -529,7 +683,10 @@ class OrderedDict(dict): def clear(self): self._list = [] dict.clear(self) - + + def sort(self, fn=None): + self._list.sort(fn) + def update(self, ____sequence=None, **kwargs): if ____sequence is not None: if hasattr(____sequence, 'keys'): @@ -622,22 +779,24 @@ class OrderedSet(Set): if d is not None: self.update(d) - def add(self, key): - if key not in self: - self._list.append(key) - Set.add(self, key) + def add(self, element): + if element not in self: + self._list.append(element) + Set.add(self, element) def remove(self, element): Set.remove(self, element) self._list.remove(element) + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + Set.add(self, element) + def discard(self, element): - try: - Set.remove(self, element) - except KeyError: - pass - else: + if element in self: self._list.remove(element) + Set.remove(self, element) def clear(self): Set.clear(self) @@ -650,22 +809,22 @@ class OrderedSet(Set): return iter(self._list) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self._list) + return '%s(%r)' % (self.__class__.__name__, self._list) __str__ = __repr__ def update(self, iterable): - add = self.add - for i in iterable: - add(i) - return self + add = self.add + for i in iterable: + add(i) + return self __ior__ = update def union(self, other): - result = self.__class__(self) - result.update(other) - return result + result = self.__class__(self) + result.update(other) + return result __or__ = union @@ -698,10 +857,10 @@ class OrderedSet(Set): __iand__ = intersection_update def symmetric_difference_update(self, other): - Set.symmetric_difference_update(self, other) - self._list = [ a for a in self._list if a in self] - self._list += [ a for a in other._list if a in self] - return self + Set.symmetric_difference_update(self, other) + self._list = [ a for a in self._list if a in self] + self._list += [ a for a in other._list if a in self] + return self __ixor__ = symmetric_difference_update @@ -1021,6 +1180,35 @@ class ScopedRegistry(object): def _get_key(self): return self.scopefunc() +class WeakCompositeKey(object): + """an weak-referencable, hashable collection which is strongly referenced + until any one of its members is garbage collected. + + """ + keys = Set() + + def __init__(self, *args): + self.args = [self.__ref(arg) for arg in args] + WeakCompositeKey.keys.add(self) + + def __ref(self, arg): + if isinstance(arg, type): + return weakref.ref(arg, self.__remover) + else: + return lambda: arg + + def __remover(self, wr): + WeakCompositeKey.keys.discard(self) + + def __hash__(self): + return hash(tuple(self)) + + def __cmp__(self, other): + return cmp(tuple(self), tuple(other)) + + def __iter__(self): + return iter([arg() for arg in self.args]) + class _symbol(object): def __init__(self, name): """Construct a new named symbol.""" @@ -1059,7 +1247,6 @@ class symbol(object): finally: symbol._lock.release() - def as_interface(obj, cls=None, methods=None, required=None): """Ensure basic interface compliance for an instance or dict of callables. @@ -1155,21 +1342,12 @@ def function_named(fn, name): fn.func_defaults, fn.func_closure) return fn -def conditional_cache_decorator(func): - """apply conditional caching to the return value of a function.""" - - return cache_decorator(func, conditional=True) - -def cache_decorator(func, conditional=False): +def cache_decorator(func): """apply caching to the return value of a function.""" name = '_cached_' + func.__name__ - + def do_with_cache(self, *args, **kwargs): - if conditional: - cache = kwargs.pop('cache', False) - if not cache: - return func(self, *args, **kwargs) try: return getattr(self, name) except AttributeError: @@ -1177,21 +1355,109 @@ def cache_decorator(func, conditional=False): setattr(self, name, value) return value return do_with_cache - + def reset_cached(instance, name): try: delattr(instance, '_cached_' + name) except AttributeError: pass +class WeakIdentityMapping(weakref.WeakKeyDictionary): + """A WeakKeyDictionary with an object identity index. + + Adds a .by_id dictionary to a regular WeakKeyDictionary. Trades + performance during mutation operations for accelerated lookups by id(). + + The usual cautions about weak dictionaries and iteration also apply to + this subclass. + + """ + _none = symbol('none') + + def __init__(self): + weakref.WeakKeyDictionary.__init__(self) + self.by_id = {} + self._weakrefs = {} + + def __setitem__(self, object, value): + oid = id(object) + self.by_id[oid] = value + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + weakref.WeakKeyDictionary.__setitem__(self, object, value) + + def __delitem__(self, object): + del self._weakrefs[id(object)] + del self.by_id[id(object)] + weakref.WeakKeyDictionary.__delitem__(self, object) + + def setdefault(self, object, default=None): + value = weakref.WeakKeyDictionary.setdefault(self, object, default) + oid = id(object) + if value is default: + self.by_id[oid] = default + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + return value + + def pop(self, object, default=_none): + if default is self._none: + value = weakref.WeakKeyDictionary.pop(self, object) + else: + value = weakref.WeakKeyDictionary.pop(self, object, default) + if id(object) in self.by_id: + del self._weakrefs[id(object)] + del self.by_id[id(object)] + return value + + def popitem(self): + item = weakref.WeakKeyDictionary.popitem(self) + oid = id(item[0]) + del self._weakrefs[oid] + del self.by_id[oid] + return item + + def clear(self): + self._weakrefs.clear() + self.by_id.clear() + weakref.WeakKeyDictionary.clear(self) + + def update(self, *a, **kw): + raise NotImplementedError + + def _cleanup(self, wr, key=None): + if key is None: + key = wr.key + try: + del self._weakrefs[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + try: + del self.by_id[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + if sys.version_info < (2, 4): # pragma: no cover + def _ref(self, object): + oid = id(object) + return weakref.ref(object, lambda wr: self._cleanup(wr, oid)) + else: + class _keyed_weakref(weakref.ref): + def __init__(self, object, callback): + weakref.ref.__init__(self, object, callback) + self.key = id(object) + + def _ref(self, object): + return self._keyed_weakref(object, self._cleanup) + + def warn(msg): if isinstance(msg, basestring): - warnings.warn(msg, exceptions.SAWarning, stacklevel=3) + warnings.warn(msg, exc.SAWarning, stacklevel=3) else: warnings.warn(msg, stacklevel=3) def warn_deprecated(msg): - warnings.warn(msg, exceptions.SADeprecationWarning, stacklevel=3) + warnings.warn(msg, exc.SADeprecationWarning, stacklevel=3) def deprecated(message=None, add_deprecation_to_docstring=True): """Decorates a function and issues a deprecation warning on use. @@ -1216,7 +1482,7 @@ def deprecated(message=None, add_deprecation_to_docstring=True): def decorate(fn): return _decorate_with_warning( - fn, exceptions.SADeprecationWarning, + fn, exc.SADeprecationWarning, message % dict(func=fn.__name__), header) return decorate @@ -1248,7 +1514,7 @@ def pending_deprecation(version, message=None, def decorate(fn): return _decorate_with_warning( - fn, exceptions.SAPendingDeprecationWarning, + fn, exc.SAPendingDeprecationWarning, message % dict(func=fn.__name__), header) return decorate |