summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/util.py
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2007-10-31 09:13:12 +0000
committerJason Kirtland <jek@discorporate.us>2007-10-31 09:13:12 +0000
commit661774055c58d096faf929f34abd947fb5931788 (patch)
tree7a241a40971b1f666d26ecd5b527c029314f6dc6 /lib/sqlalchemy/util.py
parent0462f95f80b815a6997cd62276f8fcbcdd431156 (diff)
downloadsqlalchemy-661774055c58d096faf929f34abd947fb5931788.tar.gz
Added util.IdentitySet to support [ticket:676] and [ticket:834]
Diffstat (limited to 'lib/sqlalchemy/util.py')
-rw-r--r--lib/sqlalchemy/util.py193
1 files changed, 192 insertions, 1 deletions
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 6c74e115f..9317d4b9b 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -4,7 +4,7 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, warnings, sets
+import itertools, sys, warnings, sets
import __builtin__
from sqlalchemy import exceptions, logging
@@ -524,6 +524,197 @@ class OrderedSet(Set):
__isub__ = difference_update
+class IdentitySet(object):
+ """A set that considers only object id() for uniqueness.
+
+ This strategy has edge cases for builtin types- it's possible to have
+ two 'foo' strings in one of these sets, for example. Use sparingly.
+ """
+
+ class _IdentityProxy(object):
+ """Proxies an object's id() as its hash and basis for equality."""
+
+ __slots__ = ('obj',)
+
+ def __init__(self, value):
+ self.obj = value
+ def __hash__(self):
+ return id(self.obj)
+ def __eq__(self, other):
+ if isinstance(other, type(self)):
+ return id(self.obj) == id(other.obj)
+ else:
+ return id(self.obj) == id(other)
+ def __ne__(self, other):
+ if isinstance(other, type(self)):
+ return id(self.obj) != id(other.obj)
+ else:
+ return id(self.obj) != id(other)
+
+ def __init__(self, iterable=None):
+ self.set = Set()
+ if iterable:
+ for o in iterable:
+ self.add(o)
+
+ def add(self, value):
+ self.set.add(_id_proxy(value))
+
+ def remove(self, value):
+ value = _id_proxy(value)
+ if value not in self:
+ raise KeyError(value.obj)
+ self.set.remove(value)
+
+ def discard(self, value):
+ self.set.discard(_id_proxy(value))
+
+ def pop(self):
+ proxied = self.set.pop()
+ return proxied.obj
+
+ def issubset(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return self.set.issubset(iterable)
+ __le__ = issubset
+
+ def __lt__(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return len(self) < len(iterable) and self.issubset(iterable)
+
+ def issuperset(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return self.set.issuperset(iterable)
+ __ge__ = issuperset
+
+ def __gt__(self, iterable):
+ if not isinstance(iterable, type(self)):
+ iterable = type(self)(iterable)
+ return len(self) > len(iterable) and self.issuperset(iterable)
+
+ def __eq__(self, other):
+ if isinstance(other, IdentitySet):
+ return self.set == other.set
+ else:
+ return False
+
+ def __ne__(self, other):
+ if isinstance(other, IdentitySet):
+ return self.set != other.set
+ else:
+ return True
+
+ def __cmp__(self, other):
+ raise TypeError('cannot compare sets using cmp()')
+
+ def clear(self):
+ self.set.clear()
+
+ def copy(self):
+ return type(self)(self.set)
+
+ def union(self, iterable):
+ return type(self)(self.set.union(_proxyiter(iterable)))
+
+ def __or__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.union(iterable)
+ __ror__ = union
+
+ def update(self, iterable):
+ self.set.update(_proxyiter(iterable))
+
+ def __ior__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.update(iterable)
+ return self
+
+ def difference(self, iterable):
+ return type(self)(self.set.difference(_proxyiter(iterable)))
+
+ def __sub__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.difference(iterable)
+ __rsub__ = __sub__
+
+ def difference_update(self, iterable):
+ self.set.difference_update(_proxyiter(iterable))
+
+ def __isub__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.difference_update(iterable)
+ return self
+
+ def intersection(self, iterable):
+ return type(self)(self.set.intersection(_proxyiter(iterable)))
+
+ def __and__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.intersection(iterable)
+ __rand__ = __and__
+
+ def intersection_update(self, iterable):
+ self.set.intersection_update(_proxyiter(iterable))
+
+ def __iand__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.intersection_update(iterable)
+ return self
+
+ def symmetric_difference(self, iterable):
+ return type(self)(self.set.symmetric_difference(_proxyiter(iterable)))
+
+ def __xor__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ return self.symmetric_difference(iterable)
+ __rxor__ = __xor__
+
+ def symmetric_difference_update(self, iterable):
+ self.set.symmetric_difference_update(_proxyiter(iterable))
+
+ def __ixor__(self, iterable):
+ if not isinstance(iterable, set_types + (IdentitySet,)):
+ return NotImplemented
+ self.symmetric_difference_update(iterable)
+ return self
+
+ def __iter__(self):
+ for proxy in self.set:
+ assert isinstance(proxy, self._IdentityProxy)
+ yield proxy.obj
+
+ def __len__(self):
+ return len(self.set)
+
+ def __contains__(self, value):
+ return _id_proxy(value) in self.set
+
+ def __hash__(self):
+ raise TypeError('set objects are unhashable')
+
+ def __repr__(self):
+ return '%s(%r)' % (type(self).__name__, list(self))
+
+def _proxyiter(iterable):
+ return itertools.imap(_id_proxy, iterable)
+
+def _id_proxy(item):
+ if isinstance(item, IdentitySet._IdentityProxy):
+ return item
+ else:
+ return IdentitySet._IdentityProxy(item)
+
+
class UniqueAppender(object):
"""appends items to a collection such that only unique items
are added."""