summaryrefslogtreecommitdiff
path: root/test/base/utils.py
diff options
context:
space:
mode:
authorJason Kirtland <jek@discorporate.us>2007-10-31 09:13:12 +0000
committerJason Kirtland <jek@discorporate.us>2007-10-31 09:13:12 +0000
commit661774055c58d096faf929f34abd947fb5931788 (patch)
tree7a241a40971b1f666d26ecd5b527c029314f6dc6 /test/base/utils.py
parent0462f95f80b815a6997cd62276f8fcbcdd431156 (diff)
downloadsqlalchemy-661774055c58d096faf929f34abd947fb5931788.tar.gz
Added util.IdentitySet to support [ticket:676] and [ticket:834]
Diffstat (limited to 'test/base/utils.py')
-rw-r--r--test/base/utils.py141
1 files changed, 139 insertions, 2 deletions
diff --git a/test/base/utils.py b/test/base/utils.py
index 28258e9c3..1cfcd8fb5 100644
--- a/test/base/utils.py
+++ b/test/base/utils.py
@@ -83,7 +83,144 @@ class ArgSingletonTest(unittest.TestCase):
m1 = m2 = m3 = None
MyClass.dispose(MyClass)
assert len(util.ArgSingleton.instances) == 0
-
-
+
+class ImmutableSubclass(str):
+ pass
+
+class HashOverride(object):
+ def __init__(self, value=None):
+ self.value = value
+ def __hash__(self):
+ return hash(self.value)
+
+class EqOverride(object):
+ def __init__(self, value=None):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value == other.value
+ else:
+ return False
+ def __ne__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value != other.value
+ else:
+ return True
+
+class HashEqOverride(object):
+ def __init__(self, value=None):
+ self.value = value
+ def __hash__(self):
+ return hash(self.value)
+ def __eq__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value == other.value
+ else:
+ return False
+ def __ne__(self, other):
+ if isinstance(other, EqOverride):
+ return self.value != other.value
+ else:
+ return True
+
+
+class IdentitySetTest(unittest.TestCase):
+ def assert_eq(self, identityset, expected_iterable):
+ found = sorted(list(identityset))
+ expected = sorted(expected_iterable)
+ self.assertEquals(found, expected)
+
+ def test_init(self):
+ ids = util.IdentitySet([1,2,3,2,1])
+ self.assert_eq(ids, [1,2,3])
+
+ ids = util.IdentitySet(ids)
+ self.assert_eq(ids, [1,2,3])
+
+ ids = util.IdentitySet()
+ self.assert_eq(ids, [])
+
+ ids = util.IdentitySet([])
+ self.assert_eq(ids, [])
+
+ ids = util.IdentitySet(ids)
+ self.assert_eq(ids, [])
+
+ def test_add(self):
+ for type_ in (object, ImmutableSubclass):
+ data = [type_(), type_()]
+ ids = util.IdentitySet()
+ for i in range(2) + range(2):
+ ids.add(data[i])
+ self.assert_eq(ids, data)
+
+ for type_ in (EqOverride, HashOverride, HashEqOverride):
+ data = [type_(1), type_(1), type_(2)]
+ ids = util.IdentitySet()
+ for i in range(3) + range(3):
+ ids.add(data[i])
+ self.assert_eq(ids, data)
+
+ def test_basic_sanity(self):
+ IdentitySet = util.IdentitySet
+
+ o1, o2, o3 = object(), object(), object()
+ ids = IdentitySet([o1])
+ ids.discard(o1)
+ ids.discard(o1)
+ ids.add(o1)
+ ids.remove(o1)
+ self.assertRaises(KeyError, ids.remove, o1)
+
+ self.assert_(ids.copy() == ids)
+ self.assert_(ids != None)
+ self.assert_(not(ids == None))
+ self.assert_(ids != IdentitySet([o1,o2,o3]))
+ ids.clear()
+ self.assert_(o1 not in ids)
+ ids.add(o2)
+ self.assert_(o2 in ids)
+ self.assert_(ids.pop() == o2)
+ ids.add(o1)
+ self.assert_(len(ids) == 1)
+
+ isuper = IdentitySet([o1,o2])
+ self.assert_(ids < isuper)
+ self.assert_(ids.issubset(isuper))
+ self.assert_(isuper.issuperset(ids))
+ self.assert_(isuper > ids)
+
+ self.assert_(ids.union(isuper) == isuper)
+ self.assert_(ids | isuper == isuper)
+ self.assert_(isuper - ids == IdentitySet([o2]))
+ self.assert_(isuper.difference(ids) == IdentitySet([o2]))
+ self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
+ self.assert_(ids & isuper == IdentitySet([o1]))
+ self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
+ self.assert_(ids ^ isuper == IdentitySet([o2]))
+
+ ids.update(isuper)
+ ids |= isuper
+ ids.difference_update(isuper)
+ ids -= isuper
+ ids.intersection_update(isuper)
+ ids &= isuper
+ ids.symmetric_difference_update(isuper)
+ ids ^= isuper
+
+ ids.update('foobar')
+ try:
+ ids |= 'foobar'
+ self.assert_(False)
+ except TypeError:
+ self.assert_(True)
+ s = set([o1,o2])
+ s |= ids
+ self.assert_(isinstance(s, IdentitySet))
+
+ self.assertRaises(TypeError, cmp, ids)
+ self.assertRaises(TypeError, hash, ids)
+
+
if __name__ == "__main__":
testbase.main()