diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-10-31 19:53:27 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-10-31 19:53:27 +0000 |
commit | d5aa08160a6b93dcb39343b41de78b73cb68ac7d (patch) | |
tree | 57096cee278ab6783b60e289bf6bbb2f2dd180ed /lib/sqlalchemy/util.py | |
parent | 3f4d34b42c9ad2b27b62a64d5d80e47ff3bff5d6 (diff) | |
download | sqlalchemy-d5aa08160a6b93dcb39343b41de78b73cb68ac7d.tar.gz |
- A more efficient IdentitySet
Diffstat (limited to 'lib/sqlalchemy/util.py')
-rw-r--r-- | lib/sqlalchemy/util.py | 222 |
1 files changed, 116 insertions, 106 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 9317d4b9b..a4ccaac6a 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -531,188 +531,198 @@ class IdentitySet(object): two 'foo' strings in one of these sets, for example. Use sparingly. """ - class _IdentityProxy(object): - """Proxies an object's id() as its hash and basis for equality.""" - - __slots__ = ('obj',) - - def __init__(self, value): - self.obj = value - def __hash__(self): - return id(self.obj) - def __eq__(self, other): - if isinstance(other, type(self)): - return id(self.obj) == id(other.obj) - else: - return id(self.obj) == id(other) - def __ne__(self, other): - if isinstance(other, type(self)): - return id(self.obj) != id(other.obj) - else: - return id(self.obj) != id(other) - def __init__(self, iterable=None): - self.set = Set() + self._members = {} if iterable: for o in iterable: self.add(o) def add(self, value): - self.set.add(_id_proxy(value)) + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members def remove(self, value): - value = _id_proxy(value) - if value not in self: - raise KeyError(value.obj) - self.set.remove(value) + del self._members[id(value)] def discard(self, value): - self.set.discard(_id_proxy(value)) + try: + self.remove(value) + except KeyError: + pass def pop(self): - proxied = self.set.pop() - return proxied.obj - - def issubset(self, iterable): - if not isinstance(iterable, type(self)): - iterable = type(self)(iterable) - return self.set.issubset(iterable) - __le__ = issubset - - def __lt__(self, iterable): - if not isinstance(iterable, type(self)): - iterable = type(self)(iterable) - return len(self) < len(iterable) and self.issubset(iterable) + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError('pop from an empty set') - def issuperset(self, iterable): - if not isinstance(iterable, type(self)): - iterable = type(self)(iterable) - return self.set.issuperset(iterable) - __ge__ = issuperset + def clear(self): + self._members.clear() - def __gt__(self, iterable): - if not isinstance(iterable, type(self)): - iterable = type(self)(iterable) - return len(self) > len(iterable) and self.issuperset(iterable) + def __cmp__(self, other): + raise TypeError('cannot compare sets using cmp()') def __eq__(self, other): if isinstance(other, IdentitySet): - return self.set == other.set + return self._members == other._members else: return False def __ne__(self, other): if isinstance(other, IdentitySet): - return self.set != other.set + return self._members != other._members else: return True - def __cmp__(self, other): - raise TypeError('cannot compare sets using cmp()') + def issubset(self, iterable): + other = type(self)(iterable) - def clear(self): - self.set.clear() + if len(self) > len(other): + return False + for m in itertools.ifilterfalse(other._members.has_key, + self._members.iterkeys()): + return False + return True - def copy(self): - return type(self)(self.set) + def __le__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + other = type(self)(iterable) + + if len(self) < len(other): + return False + + for m in itertools.ifilterfalse(self._members.has_key, + other._members.iterkeys()): + return False + return True + + def __ge__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) def union(self, iterable): - return type(self)(self.set.union(_proxyiter(iterable))) + result = type(self)() + result._members.update( + Set(self._members.iteritems()).union(_iter_id(iterable))) + return result - def __or__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __or__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - return self.union(iterable) - __ror__ = union + return self.union(other) + __ror__ = __or__ def update(self, iterable): - self.set.update(_proxyiter(iterable)) + self._members = self.union(iterable)._members - def __ior__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __ior__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - self.update(iterable) + self.update(other) return self def difference(self, iterable): - return type(self)(self.set.difference(_proxyiter(iterable))) + result = type(self)() + result._members.update( + Set(self._members.iteritems()).difference(_iter_id(iterable))) + return result - def __sub__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __sub__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - return self.difference(iterable) + return self.difference(other) __rsub__ = __sub__ def difference_update(self, iterable): - self.set.difference_update(_proxyiter(iterable)) + self._members = self.difference(iterable)._members - def __isub__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __isub__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - self.difference_update(iterable) + self.difference_update(other) return self def intersection(self, iterable): - return type(self)(self.set.intersection(_proxyiter(iterable))) + result = type(self)() + result._members.update( + Set(self._members.iteritems()).intersection(_iter_id(iterable))) + return result - def __and__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __and__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - return self.intersection(iterable) + return self.intersection(other) __rand__ = __and__ def intersection_update(self, iterable): - self.set.intersection_update(_proxyiter(iterable)) + self._members = self.intersection(iterable)._members - def __iand__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __iand__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - self.intersection_update(iterable) + self.intersection_update(other) return self def symmetric_difference(self, iterable): - return type(self)(self.set.symmetric_difference(_proxyiter(iterable))) + result = type(self)() + result._members.update( + Set(self._members.iteritems()).symmetric_difference(_iter_id(iterable))) + return result - def __xor__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __xor__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - return self.symmetric_difference(iterable) + return self.symmetric_difference(other) __rxor__ = __xor__ def symmetric_difference_update(self, iterable): - self.set.symmetric_difference_update(_proxyiter(iterable)) + self._members = self.symmetric_difference(iterable)._members - def __ixor__(self, iterable): - if not isinstance(iterable, set_types + (IdentitySet,)): + def __ixor__(self, other): + if not isinstance(other, set_types + (IdentitySet,)): return NotImplemented - self.symmetric_difference_update(iterable) + self.symmetric_difference(other) return self - def __iter__(self): - for proxy in self.set: - assert isinstance(proxy, self._IdentityProxy) - yield proxy.obj + def copy(self): + return type(self)(self._members.itervalues()) + + __copy__ = copy def __len__(self): - return len(self.set) + return len(self._members) - def __contains__(self, value): - return _id_proxy(value) in self.set + def __iter__(self): + return self._members.itervalues() def __hash__(self): raise TypeError('set objects are unhashable') def __repr__(self): - return '%s(%r)' % (type(self).__name__, list(self)) - -def _proxyiter(iterable): - return itertools.imap(_id_proxy, iterable) + return '%s(%r)' % (type(self).__name__, self._members.values()) -def _id_proxy(item): - if isinstance(item, IdentitySet._IdentityProxy): - return item - else: - return IdentitySet._IdentityProxy(item) +def _iter_id(iterable): + """Generator: ((id(o), o) for o in iterable).""" + for item in iterable: + yield id(item), item class UniqueAppender(object): |