diff options
Diffstat (limited to 'lib/sqlalchemy/util.py')
-rw-r--r-- | lib/sqlalchemy/util.py | 83 |
1 files changed, 64 insertions, 19 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 8b68fb108..1356fa324 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import inspect, itertools, new, operator, sys, warnings, weakref +import inspect, itertools, operator, sys, warnings, weakref import __builtin__ types = __import__('types') @@ -18,8 +18,13 @@ except ImportError: import dummy_threading as threading from dummy_threading import local as ThreadLocal -if sys.version_info < (2, 6): +py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) + +if py3k: + set_types = set +elif sys.version_info < (2, 6): import sets + set_types = set, sets.Set else: # 2.6 deprecates sets.Set, but we still need to be able to detect them # in user code and as return values from DB-APIs @@ -32,15 +37,24 @@ else: import sets warnings.filters.remove(ignore) -set_types = set, sets.Set + set_types = set, sets.Set EMPTY_SET = frozenset() -try: - import cPickle as pickle -except ImportError: +if py3k: import pickle +else: + try: + import cPickle as pickle + except ImportError: + import pickle +if py3k: + def buffer(x): + return x # no-op until we figure out what MySQLdb is going to use +else: + buffer = __builtin__.buffer + if sys.version_info >= (2, 5): class PopulateDict(dict): """A dict which populates missing values via a creation function. @@ -70,6 +84,17 @@ else: self[key] = value = self.creator(key) return value +if py3k: + def callable(fn): + return hasattr(fn, '__call__') +else: + callable = __builtin__.callable + +if py3k: + from functools import reduce +else: + reduce = __builtin__.reduce + try: from collections import defaultdict except ImportError: @@ -125,6 +150,14 @@ def to_set(x): else: return x +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + try: from functools import update_wrapper @@ -823,10 +856,11 @@ class IdentitySet(object): This strategy has edge cases for builtin types- it's possible to have two 'foo' strings in one of these sets, for example. Use sparingly. + """ _working_set = set - + def __init__(self, iterable=None): self._members = dict() if iterable: @@ -918,7 +952,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).union(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).union(_iter_id(iterable))) return result def __or__(self, other): @@ -939,7 +973,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).difference(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).difference(_iter_id(iterable))) return result def __sub__(self, other): @@ -960,7 +994,7 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).intersection(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable))) return result def __and__(self, other): @@ -981,9 +1015,12 @@ class IdentitySet(object): result = type(self)() # testlib.pragma exempt:__hash__ result._members.update( - self._working_set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) + self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable))) return result - + + def _member_id_tuples(self): + return ((id(v), v) for v in self._members.itervalues()) + def __xor__(self, other): if not isinstance(other, IdentitySet): return NotImplemented @@ -1016,11 +1053,6 @@ class IdentitySet(object): return '%s(%r)' % (type(self).__name__, self._members.values()) -def _iter_id(iterable): - """Generator: ((id(o), o) for o in iterable).""" - for item in iterable: - yield id(item), item - class OrderedIdentitySet(IdentitySet): class _working_set(OrderedSet): # a testing pragma: exempt the OIDS working set from the test suite's @@ -1028,7 +1060,7 @@ class OrderedIdentitySet(IdentitySet): # but it's safe here: IDS operates on (id, instance) tuples in the # working set. __sa_hash_exempt__ = True - + def __init__(self, iterable=None): IdentitySet.__init__(self) self._members = OrderedDict() @@ -1036,6 +1068,19 @@ class OrderedIdentitySet(IdentitySet): for o in iterable: self.add(o) +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + + for item in iterable: + yield id(item), item + +# define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet +populate_column_dict = PopulateDict + def unique_list(seq, compare_with=set): seen = compare_with() return [x for x in seq if x not in seen and not seen.add(x)] @@ -1296,7 +1341,7 @@ def function_named(fn, name): try: fn.__name__ = name except TypeError: - fn = new.function(fn.func_code, fn.func_globals, name, + fn = types.FunctionType(fn.func_code, fn.func_globals, name, fn.func_defaults, fn.func_closure) return fn |