diff options
author | Jason Kirtland <jek@discorporate.us> | 2007-10-31 09:13:12 +0000 |
---|---|---|
committer | Jason Kirtland <jek@discorporate.us> | 2007-10-31 09:13:12 +0000 |
commit | 661774055c58d096faf929f34abd947fb5931788 (patch) | |
tree | 7a241a40971b1f666d26ecd5b527c029314f6dc6 /test/base/utils.py | |
parent | 0462f95f80b815a6997cd62276f8fcbcdd431156 (diff) | |
download | sqlalchemy-661774055c58d096faf929f34abd947fb5931788.tar.gz |
Added util.IdentitySet to support [ticket:676] and [ticket:834]
Diffstat (limited to 'test/base/utils.py')
-rw-r--r-- | test/base/utils.py | 141 |
1 files changed, 139 insertions, 2 deletions
diff --git a/test/base/utils.py b/test/base/utils.py index 28258e9c3..1cfcd8fb5 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -83,7 +83,144 @@ class ArgSingletonTest(unittest.TestCase): m1 = m2 = m3 = None MyClass.dispose(MyClass) assert len(util.ArgSingleton.instances) == 0 - - + +class ImmutableSubclass(str): + pass + +class HashOverride(object): + def __init__(self, value=None): + self.value = value + def __hash__(self): + return hash(self.value) + +class EqOverride(object): + def __init__(self, value=None): + self.value = value + def __eq__(self, other): + if isinstance(other, EqOverride): + return self.value == other.value + else: + return False + def __ne__(self, other): + if isinstance(other, EqOverride): + return self.value != other.value + else: + return True + +class HashEqOverride(object): + def __init__(self, value=None): + self.value = value + def __hash__(self): + return hash(self.value) + def __eq__(self, other): + if isinstance(other, EqOverride): + return self.value == other.value + else: + return False + def __ne__(self, other): + if isinstance(other, EqOverride): + return self.value != other.value + else: + return True + + +class IdentitySetTest(unittest.TestCase): + def assert_eq(self, identityset, expected_iterable): + found = sorted(list(identityset)) + expected = sorted(expected_iterable) + self.assertEquals(found, expected) + + def test_init(self): + ids = util.IdentitySet([1,2,3,2,1]) + self.assert_eq(ids, [1,2,3]) + + ids = util.IdentitySet(ids) + self.assert_eq(ids, [1,2,3]) + + ids = util.IdentitySet() + self.assert_eq(ids, []) + + ids = util.IdentitySet([]) + self.assert_eq(ids, []) + + ids = util.IdentitySet(ids) + self.assert_eq(ids, []) + + def test_add(self): + for type_ in (object, ImmutableSubclass): + data = [type_(), type_()] + ids = util.IdentitySet() + for i in range(2) + range(2): + ids.add(data[i]) + self.assert_eq(ids, data) + + for type_ in (EqOverride, HashOverride, HashEqOverride): + data = [type_(1), type_(1), type_(2)] + ids = util.IdentitySet() + for i in range(3) + range(3): + ids.add(data[i]) + self.assert_eq(ids, data) + + def test_basic_sanity(self): + IdentitySet = util.IdentitySet + + o1, o2, o3 = object(), object(), object() + ids = IdentitySet([o1]) + ids.discard(o1) + ids.discard(o1) + ids.add(o1) + ids.remove(o1) + self.assertRaises(KeyError, ids.remove, o1) + + self.assert_(ids.copy() == ids) + self.assert_(ids != None) + self.assert_(not(ids == None)) + self.assert_(ids != IdentitySet([o1,o2,o3])) + ids.clear() + self.assert_(o1 not in ids) + ids.add(o2) + self.assert_(o2 in ids) + self.assert_(ids.pop() == o2) + ids.add(o1) + self.assert_(len(ids) == 1) + + isuper = IdentitySet([o1,o2]) + self.assert_(ids < isuper) + self.assert_(ids.issubset(isuper)) + self.assert_(isuper.issuperset(ids)) + self.assert_(isuper > ids) + + self.assert_(ids.union(isuper) == isuper) + self.assert_(ids | isuper == isuper) + self.assert_(isuper - ids == IdentitySet([o2])) + self.assert_(isuper.difference(ids) == IdentitySet([o2])) + self.assert_(ids.intersection(isuper) == IdentitySet([o1])) + self.assert_(ids & isuper == IdentitySet([o1])) + self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2])) + self.assert_(ids ^ isuper == IdentitySet([o2])) + + ids.update(isuper) + ids |= isuper + ids.difference_update(isuper) + ids -= isuper + ids.intersection_update(isuper) + ids &= isuper + ids.symmetric_difference_update(isuper) + ids ^= isuper + + ids.update('foobar') + try: + ids |= 'foobar' + self.assert_(False) + except TypeError: + self.assert_(True) + s = set([o1,o2]) + s |= ids + self.assert_(isinstance(s, IdentitySet)) + + self.assertRaises(TypeError, cmp, ids) + self.assertRaises(TypeError, hash, ids) + + if __name__ == "__main__": testbase.main() |