summaryrefslogtreecommitdiff
path: root/Lib/test/test_functools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r--Lib/test/test_functools.py192
1 files changed, 111 insertions, 81 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 5076644952..9ea6747188 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -8,6 +8,7 @@ import sys
from test import support
import unittest
from weakref import proxy
+import contextlib
try:
import threading
except ImportError:
@@ -20,6 +21,14 @@ c_functools = support.import_fresh_module('functools', fresh=['_functools'])
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
+@contextlib.contextmanager
+def replaced_module(name, replacement):
+ original_module = sys.modules[name]
+ sys.modules[name] = replacement
+ try:
+ yield
+ finally:
+ sys.modules[name] = original_module
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
@@ -167,58 +176,35 @@ class TestPartial:
p2.new_attr = 'spam'
self.assertEqual(p2.new_attr, 'spam')
-
-@unittest.skipUnless(c_functools, 'requires the C _functools module')
-class TestPartialC(TestPartial, unittest.TestCase):
- if c_functools:
- partial = c_functools.partial
-
- def test_attributes_unwritable(self):
- # attributes should not be writable
- p = self.partial(capture, 1, 2, a=10, b=20)
- self.assertRaises(AttributeError, setattr, p, 'func', map)
- self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
- self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
-
- p = self.partial(hex)
- try:
- del p.__dict__
- except TypeError:
- pass
- else:
- self.fail('partial object allowed __dict__ to be deleted')
-
def test_repr(self):
args = (object(), object())
args_repr = ', '.join(repr(a) for a in args)
kwargs = {'a': object(), 'b': object()}
kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
'b={b!r}, a={a!r}'.format_map(kwargs)]
- if self.partial is c_functools.partial:
+ if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial'
else:
name = self.partial.__name__
f = self.partial(capture)
- self.assertEqual('{}({!r})'.format(name, capture),
- repr(f))
+ self.assertEqual(f'{name}({capture!r})', repr(f))
f = self.partial(capture, *args)
- self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
- repr(f))
+ self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
f = self.partial(capture, **kwargs)
self.assertIn(repr(f),
- ['{}({!r}, {})'.format(name, capture, kwargs_repr)
+ [f'{name}({capture!r}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs])
f = self.partial(capture, *args, **kwargs)
self.assertIn(repr(f),
- ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
+ [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs])
def test_recursive_repr(self):
- if self.partial is c_functools.partial:
+ if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial'
else:
name = self.partial.__name__
@@ -226,30 +212,31 @@ class TestPartialC(TestPartial, unittest.TestCase):
f = self.partial(capture)
f.__setstate__((f, (), {}, {}))
try:
- self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
+ self.assertEqual(repr(f), '%s(...)' % (name,))
finally:
f.__setstate__((capture, (), {}, {}))
f = self.partial(capture)
f.__setstate__((capture, (f,), {}, {}))
try:
- self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
+ self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
finally:
f.__setstate__((capture, (), {}, {}))
f = self.partial(capture)
f.__setstate__((capture, (), {'a': f}, {}))
try:
- self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
+ self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
finally:
f.__setstate__((capture, (), {}, {}))
def test_pickle(self):
- f = self.partial(signature, ['asdf'], bar=[True])
- f.attr = []
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- f_copy = pickle.loads(pickle.dumps(f, proto))
- self.assertEqual(signature(f_copy), signature(f))
+ with self.AllowPickle():
+ f = self.partial(signature, ['asdf'], bar=[True])
+ f.attr = []
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ self.assertEqual(signature(f_copy), signature(f))
def test_copy(self):
f = self.partial(signature, ['asdf'], bar=[True])
@@ -274,11 +261,13 @@ class TestPartialC(TestPartial, unittest.TestCase):
def test_setstate(self):
f = self.partial(signature)
f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
+
self.assertEqual(signature(f),
(capture, (1,), dict(a=10), dict(attr=[])))
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
f.__setstate__((capture, (1,), dict(a=10), None))
+
self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
@@ -325,38 +314,39 @@ class TestPartialC(TestPartial, unittest.TestCase):
self.assertIs(type(r[0]), tuple)
def test_recursive_pickle(self):
- f = self.partial(capture)
- f.__setstate__((f, (), {}, {}))
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- with self.assertRaises(RecursionError):
- pickle.dumps(f, proto)
- finally:
- f.__setstate__((capture, (), {}, {}))
-
- f = self.partial(capture)
- f.__setstate__((capture, (f,), {}, {}))
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- f_copy = pickle.loads(pickle.dumps(f, proto))
- try:
- self.assertIs(f_copy.args[0], f_copy)
- finally:
- f_copy.__setstate__((capture, (), {}, {}))
- finally:
- f.__setstate__((capture, (), {}, {}))
-
- f = self.partial(capture)
- f.__setstate__((capture, (), {'a': f}, {}))
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- f_copy = pickle.loads(pickle.dumps(f, proto))
- try:
- self.assertIs(f_copy.keywords['a'], f_copy)
- finally:
- f_copy.__setstate__((capture, (), {}, {}))
- finally:
- f.__setstate__((capture, (), {}, {}))
+ with self.AllowPickle():
+ f = self.partial(capture)
+ f.__setstate__((f, (), {}, {}))
+ try:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.assertRaises(RecursionError):
+ pickle.dumps(f, proto)
+ finally:
+ f.__setstate__((capture, (), {}, {}))
+
+ f = self.partial(capture)
+ f.__setstate__((capture, (f,), {}, {}))
+ try:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ try:
+ self.assertIs(f_copy.args[0], f_copy)
+ finally:
+ f_copy.__setstate__((capture, (), {}, {}))
+ finally:
+ f.__setstate__((capture, (), {}, {}))
+
+ f = self.partial(capture)
+ f.__setstate__((capture, (), {'a': f}, {}))
+ try:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ try:
+ self.assertIs(f_copy.keywords['a'], f_copy)
+ finally:
+ f_copy.__setstate__((capture, (), {}, {}))
+ finally:
+ f.__setstate__((capture, (), {}, {}))
# Issue 6083: Reference counting bug
def test_setstate_refcount(self):
@@ -375,24 +365,60 @@ class TestPartialC(TestPartial, unittest.TestCase):
f = self.partial(object)
self.assertRaises(TypeError, f.__setstate__, BadSequence())
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialC(TestPartial, unittest.TestCase):
+ if c_functools:
+ partial = c_functools.partial
+
+ class AllowPickle:
+ def __enter__(self):
+ return self
+ def __exit__(self, type, value, tb):
+ return False
+
+ def test_attributes_unwritable(self):
+ # attributes should not be writable
+ p = self.partial(capture, 1, 2, a=10, b=20)
+ self.assertRaises(AttributeError, setattr, p, 'func', map)
+ self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
+ self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
+
+ p = self.partial(hex)
+ try:
+ del p.__dict__
+ except TypeError:
+ pass
+ else:
+ self.fail('partial object allowed __dict__ to be deleted')
class TestPartialPy(TestPartial, unittest.TestCase):
- partial = staticmethod(py_functools.partial)
+ partial = py_functools.partial
+ class AllowPickle:
+ def __init__(self):
+ self._cm = replaced_module("functools", py_functools)
+ def __enter__(self):
+ return self._cm.__enter__()
+ def __exit__(self, type, value, tb):
+ return self._cm.__exit__(type, value, tb)
if c_functools:
- class PartialSubclass(c_functools.partial):
+ class CPartialSubclass(c_functools.partial):
pass
+class PyPartialSubclass(py_functools.partial):
+ pass
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
if c_functools:
- partial = PartialSubclass
+ partial = CPartialSubclass
# partial subclasses are not optimized for nested calls
test_nested_optimization = None
+class TestPartialPySubclass(TestPartialPy):
+ partial = PyPartialSubclass
class TestPartialMethod(unittest.TestCase):
@@ -683,9 +709,10 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
-
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
- func = functools.reduce
+ if c_functools:
+ func = c_functools.reduce
def test_reduce(self):
class Squares:
@@ -1548,13 +1575,15 @@ class TestSingleDispatch(unittest.TestCase):
bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
for haystack in permutations(bases):
m = mro(dict, haystack)
- self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
- c.Iterable, c.Container, object])
+ self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
+ c.Collection, c.Sized, c.Iterable,
+ c.Container, object])
bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
for haystack in permutations(bases):
m = mro(c.ChainMap, haystack)
self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
- c.Sized, c.Iterable, c.Container, object])
+ c.Collection, c.Sized, c.Iterable,
+ c.Container, object])
# If there's a generic function with implementations registered for
# both Sized and Container, passing a defaultdict to it results in an
@@ -1575,9 +1604,9 @@ class TestSingleDispatch(unittest.TestCase):
bases = [c.MutableSequence, c.MutableMapping]
for haystack in permutations(bases):
m = mro(D, bases)
- self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
- c.defaultdict, dict, c.MutableMapping,
- c.Mapping, c.Sized, c.Iterable, c.Container,
+ self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
+ c.defaultdict, dict, c.MutableMapping, c.Mapping,
+ c.Collection, c.Sized, c.Iterable, c.Container,
object])
# Container and Callable are registered on different base classes and
@@ -1590,7 +1619,8 @@ class TestSingleDispatch(unittest.TestCase):
for haystack in permutations(bases):
m = mro(C, haystack)
self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
- c.Sized, c.Iterable, c.Container, object])
+ c.Collection, c.Sized, c.Iterable,
+ c.Container, object])
def test_register_abc(self):
c = collections