diff options
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r-- | Lib/test/test_functools.py | 65 |
1 files changed, 64 insertions, 1 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 7d11b53a71..97d7524f59 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -436,19 +436,82 @@ class TestReduce(unittest.TestCase): self.assertEqual(self.func(add, d), "".join(d.keys())) class TestCmpToKey(unittest.TestCase): + def test_cmp_to_key(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(cmp1) + self.assertEqual(key(3), key(3)) + self.assertGreater(key(3), key(1)) + def cmp2(x, y): + return int(x) - int(y) + key = functools.cmp_to_key(cmp2) + self.assertEqual(key(4.0), key('4')) + self.assertLess(key(2), key('35')) + + def test_cmp_to_key_arguments(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(obj=3), key(obj=3)) + self.assertGreater(key(obj=3), key(obj=1)) + with self.assertRaises((TypeError, AttributeError)): + key(3) > 1 # rhs is not a K object + with self.assertRaises((TypeError, AttributeError)): + 1 < key(3) # lhs is not a K object + with self.assertRaises(TypeError): + key = functools.cmp_to_key() # too few args + with self.assertRaises(TypeError): + key = functools.cmp_to_key(cmp1, None) # too many args + key = functools.cmp_to_key(cmp1) + with self.assertRaises(TypeError): + key() # too few args + with self.assertRaises(TypeError): + key(None, None) # too many args + + def test_bad_cmp(self): + def cmp1(x, y): + raise ZeroDivisionError + key = functools.cmp_to_key(cmp1) + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + class BadCmp: + def __lt__(self, other): + raise ZeroDivisionError + def cmp1(x, y): + return BadCmp() + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + def test_obj_field(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(50).obj, 50) + + def test_sort_int(self): def mycmp(x, y): return y - x self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) + def test_sort_int_str(self): + def mycmp(x, y): + x, y = int(x), int(y) + return (x > y) - (x < y) + values = [5, '3', 7, 2, '0', '1', 4, '10', 1] + values = sorted(values, key=functools.cmp_to_key(mycmp)) + self.assertEqual([int(value) for value in values], + [0, 1, 1, 2, 3, 4, 5, 7, 10]) + def test_hash(self): def mycmp(x, y): return y - x key = functools.cmp_to_key(mycmp) k = key(10) self.assertRaises(TypeError, hash, k) - self.assertFalse(isinstance(k, collections.Hashable)) + self.assertNotIsInstance(k, collections.Hashable) class TestTotalOrdering(unittest.TestCase): |