diff options
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r-- | Lib/test/test_functools.py | 118 |
1 files changed, 114 insertions, 4 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 11e6e84420..db1e9348dd 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -246,6 +246,7 @@ class TestUpdateWrapper(unittest.TestCase): self.check_wrapper(wrapper, f) self.assertIs(wrapper.__wrapped__, f) self.assertEqual(wrapper.__name__, 'f') + self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') self.assertNotIn('b', wrapper.__annotations__) @@ -266,6 +267,7 @@ class TestUpdateWrapper(unittest.TestCase): functools.update_wrapper(wrapper, f, (), ()) self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') + self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.__annotations__, {}) self.assertFalse(hasattr(wrapper, 'attr')) @@ -283,6 +285,7 @@ class TestUpdateWrapper(unittest.TestCase): functools.update_wrapper(wrapper, f, assign, update) self.check_wrapper(wrapper, f, assign, update) self.assertEqual(wrapper.__name__, 'wrapper') + self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) @@ -330,17 +333,18 @@ class TestWraps(TestUpdateWrapper): def wrapper(): pass self.check_wrapper(wrapper, f) - return wrapper + return wrapper, f def test_default_update(self): - wrapper = self._default_update() + wrapper, f = self._default_update() self.assertEqual(wrapper.__name__, 'f') + self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") def test_default_update_doc(self): - wrapper = self._default_update() + wrapper, _ = self._default_update() self.assertEqual(wrapper.__doc__, 'This is a test') def test_no_update(self): @@ -353,6 +357,7 @@ class TestWraps(TestUpdateWrapper): pass self.check_wrapper(wrapper, f, (), ()) self.assertEqual(wrapper.__name__, 'wrapper') + self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertFalse(hasattr(wrapper, 'attr')) @@ -372,6 +377,7 @@ class TestWraps(TestUpdateWrapper): pass self.check_wrapper(wrapper, f, assign, update) self.assertEqual(wrapper.__name__, 'wrapper') + self.assertNotEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.__doc__, None) self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) @@ -457,19 +463,82 @@ class TestReduce(unittest.TestCase): self.assertEqual(self.func(add, d), "".join(d.keys())) class TestCmpToKey(unittest.TestCase): + def test_cmp_to_key(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(cmp1) + self.assertEqual(key(3), key(3)) + self.assertGreater(key(3), key(1)) + def cmp2(x, y): + return int(x) - int(y) + key = functools.cmp_to_key(cmp2) + self.assertEqual(key(4.0), key('4')) + self.assertLess(key(2), key('35')) + + def test_cmp_to_key_arguments(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(obj=3), key(obj=3)) + self.assertGreater(key(obj=3), key(obj=1)) + with self.assertRaises((TypeError, AttributeError)): + key(3) > 1 # rhs is not a K object + with self.assertRaises((TypeError, AttributeError)): + 1 < key(3) # lhs is not a K object + with self.assertRaises(TypeError): + key = functools.cmp_to_key() # too few args + with self.assertRaises(TypeError): + key = functools.cmp_to_key(cmp1, None) # too many args + key = functools.cmp_to_key(cmp1) + with self.assertRaises(TypeError): + key() # too few args + with self.assertRaises(TypeError): + key(None, None) # too many args + + def test_bad_cmp(self): + def cmp1(x, y): + raise ZeroDivisionError + key = functools.cmp_to_key(cmp1) + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + class BadCmp: + def __lt__(self, other): + raise ZeroDivisionError + def cmp1(x, y): + return BadCmp() + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + def test_obj_field(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(50).obj, 50) + + def test_sort_int(self): def mycmp(x, y): return y - x self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) + def test_sort_int_str(self): + def mycmp(x, y): + x, y = int(x), int(y) + return (x > y) - (x < y) + values = [5, '3', 7, 2, '0', '1', 4, '10', 1] + values = sorted(values, key=functools.cmp_to_key(mycmp)) + self.assertEqual([int(value) for value in values], + [0, 1, 1, 2, 3, 4, 5, 7, 10]) + def test_hash(self): def mycmp(x, y): return y - x key = functools.cmp_to_key(mycmp) k = key(10) self.assertRaises(TypeError, hash, k) - self.assertFalse(isinstance(k, collections.Hashable)) + self.assertNotIsInstance(k, collections.Hashable) class TestTotalOrdering(unittest.TestCase): @@ -692,6 +761,47 @@ class TestLRU(unittest.TestCase): with self.assertRaises(IndexError): func(15) + def test_lru_with_types(self): + for maxsize in (None, 100): + @functools.lru_cache(maxsize=maxsize, typed=True) + def square(x): + return x * x + self.assertEqual(square(3), 9) + self.assertEqual(type(square(3)), type(9)) + self.assertEqual(square(3.0), 9.0) + self.assertEqual(type(square(3.0)), type(9.0)) + self.assertEqual(square(x=3), 9) + self.assertEqual(type(square(x=3)), type(9)) + self.assertEqual(square(x=3.0), 9.0) + self.assertEqual(type(square(x=3.0)), type(9.0)) + self.assertEqual(square.cache_info().hits, 4) + self.assertEqual(square.cache_info().misses, 4) + + def test_need_for_rlock(self): + # This will deadlock on an LRU cache that uses a regular lock + + @functools.lru_cache(maxsize=10) + def test_func(x): + 'Used to demonstrate a reentrant lru_cache call within a single thread' + return x + + class DoubleEq: + 'Demonstrate a reentrant lru_cache call within a single thread' + def __init__(self, x): + self.x = x + def __hash__(self): + return self.x + def __eq__(self, other): + if self.x == 2: + test_func(DoubleEq(1)) + return self.x == other.x + + test_func(DoubleEq(1)) # Load the cache + test_func(DoubleEq(2)) # Load the cache + self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call + DoubleEq(2)) # Verify the correct return value + + def test_main(verbose=None): test_classes = ( TestPartial, |