summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/util.py')
-rw-r--r--lib/sqlalchemy/util.py83
1 files changed, 64 insertions, 19 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 8b68fb108..1356fa324 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import inspect, itertools, new, operator, sys, warnings, weakref
+import inspect, itertools, operator, sys, warnings, weakref
import __builtin__
types = __import__('types')
@@ -18,8 +18,13 @@ except ImportError:
import dummy_threading as threading
from dummy_threading import local as ThreadLocal
-if sys.version_info < (2, 6):
+py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
+
+if py3k:
+ set_types = set
+elif sys.version_info < (2, 6):
import sets
+ set_types = set, sets.Set
else:
# 2.6 deprecates sets.Set, but we still need to be able to detect them
# in user code and as return values from DB-APIs
@@ -32,15 +37,24 @@ else:
import sets
warnings.filters.remove(ignore)
-set_types = set, sets.Set
+ set_types = set, sets.Set
EMPTY_SET = frozenset()
-try:
- import cPickle as pickle
-except ImportError:
+if py3k:
import pickle
+else:
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle
+if py3k:
+ def buffer(x):
+ return x # no-op until we figure out what MySQLdb is going to use
+else:
+ buffer = __builtin__.buffer
+
if sys.version_info >= (2, 5):
class PopulateDict(dict):
"""A dict which populates missing values via a creation function.
@@ -70,6 +84,17 @@ else:
self[key] = value = self.creator(key)
return value
+if py3k:
+ def callable(fn):
+ return hasattr(fn, '__call__')
+else:
+ callable = __builtin__.callable
+
+if py3k:
+ from functools import reduce
+else:
+ reduce = __builtin__.reduce
+
try:
from collections import defaultdict
except ImportError:
@@ -125,6 +150,14 @@ def to_set(x):
else:
return x
+def to_column_set(x):
+ if x is None:
+ return column_set()
+ if not isinstance(x, column_set):
+ return column_set(to_list(x))
+ else:
+ return x
+
try:
from functools import update_wrapper
@@ -823,10 +856,11 @@ class IdentitySet(object):
This strategy has edge cases for builtin types- it's possible to have
two 'foo' strings in one of these sets, for example. Use sparingly.
+
"""
_working_set = set
-
+
def __init__(self, iterable=None):
self._members = dict()
if iterable:
@@ -918,7 +952,7 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).union(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).union(_iter_id(iterable)))
return result
def __or__(self, other):
@@ -939,7 +973,7 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).difference(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).difference(_iter_id(iterable)))
return result
def __sub__(self, other):
@@ -960,7 +994,7 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).intersection(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).intersection(_iter_id(iterable)))
return result
def __and__(self, other):
@@ -981,9 +1015,12 @@ class IdentitySet(object):
result = type(self)()
# testlib.pragma exempt:__hash__
result._members.update(
- self._working_set(self._members.iteritems()).symmetric_difference(_iter_id(iterable)))
+ self._working_set(self._member_id_tuples()).symmetric_difference(_iter_id(iterable)))
return result
-
+
+ def _member_id_tuples(self):
+ return ((id(v), v) for v in self._members.itervalues())
+
def __xor__(self, other):
if not isinstance(other, IdentitySet):
return NotImplemented
@@ -1016,11 +1053,6 @@ class IdentitySet(object):
return '%s(%r)' % (type(self).__name__, self._members.values())
-def _iter_id(iterable):
- """Generator: ((id(o), o) for o in iterable)."""
- for item in iterable:
- yield id(item), item
-
class OrderedIdentitySet(IdentitySet):
class _working_set(OrderedSet):
# a testing pragma: exempt the OIDS working set from the test suite's
@@ -1028,7 +1060,7 @@ class OrderedIdentitySet(IdentitySet):
# but it's safe here: IDS operates on (id, instance) tuples in the
# working set.
__sa_hash_exempt__ = True
-
+
def __init__(self, iterable=None):
IdentitySet.__init__(self)
self._members = OrderedDict()
@@ -1036,6 +1068,19 @@ class OrderedIdentitySet(IdentitySet):
for o in iterable:
self.add(o)
+def _iter_id(iterable):
+ """Generator: ((id(o), o) for o in iterable)."""
+
+ for item in iterable:
+ yield id(item), item
+
+# define collections that are capable of storing
+# ColumnElement objects as hashable keys/elements.
+column_set = set
+column_dict = dict
+ordered_column_set = OrderedSet
+populate_column_dict = PopulateDict
+
def unique_list(seq, compare_with=set):
seen = compare_with()
return [x for x in seq if x not in seen and not seen.add(x)]
@@ -1296,7 +1341,7 @@ def function_named(fn, name):
try:
fn.__name__ = name
except TypeError:
- fn = new.function(fn.func_code, fn.func_globals, name,
+ fn = types.FunctionType(fn.func_code, fn.func_globals, name,
fn.func_defaults, fn.func_closure)
return fn