summaryrefslogtreecommitdiff
path: root/Lib/test
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2013-11-16 12:56:54 +0200
committerSerhiy Storchaka <storchaka@gmail.com>2013-11-16 12:56:54 +0200
commitf19e8586ca9bbc4480ec223245bd51fcabbbdff3 (patch)
tree14dc4b5edea247e57435d1b0a6434de6afded87b /Lib/test
parent4414544dd630838071a0fceda8b2e9295687b65d (diff)
parent49ff21dbb638497b4ea431c20b2fa73a51349dab (diff)
downloadcpython-f19e8586ca9bbc4480ec223245bd51fcabbbdff3.tar.gz
Issue #19590: Use specific asserts in email tests.
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/__main__.py14
-rw-r--r--Lib/test/_test_multiprocessing.py (renamed from Lib/test/test_multiprocessing.py)588
-rw-r--r--Lib/test/audiodata/pluck-pcm24.aubin0 -> 19866 bytes
-rw-r--r--Lib/test/audiotests.py63
-rw-r--r--Lib/test/badsyntax_future10.py3
-rw-r--r--Lib/test/bytecode_helper.py41
-rw-r--r--Lib/test/datetimetester.py4
-rw-r--r--Lib/test/final_a.py19
-rw-r--r--Lib/test/final_b.py19
-rw-r--r--Lib/test/fork_wait.py2
-rw-r--r--Lib/test/keycert3.pem73
-rw-r--r--Lib/test/keycert4.pem73
-rw-r--r--Lib/test/leakers/test_gestalt.py14
-rw-r--r--Lib/test/lock_tests.py20
-rw-r--r--Lib/test/make_ssl_certs.py112
-rw-r--r--Lib/test/mock_socket.py8
-rw-r--r--Lib/test/mp_fork_bomb.py5
-rw-r--r--Lib/test/multibytecodec_support.py2
-rw-r--r--Lib/test/pickletester.py53
-rw-r--r--Lib/test/pycacert.pem78
-rw-r--r--Lib/test/pycakey.pem28
-rwxr-xr-xLib/test/regrtest.py1151
-rw-r--r--Lib/test/script_helper.py36
-rw-r--r--Lib/test/sortperf.py6
-rw-r--r--Lib/test/ssl_servers.py10
-rw-r--r--Lib/test/subprocessdata/fd_status.py26
-rw-r--r--Lib/test/support/__init__.py219
-rw-r--r--Lib/test/test___all__.py30
-rw-r--r--Lib/test/test_abc.py16
-rw-r--r--Lib/test/test_aifc.py21
-rw-r--r--Lib/test/test_argparse.py36
-rwxr-xr-xLib/test/test_array.py4
-rw-r--r--Lib/test/test_ast.py98
-rw-r--r--Lib/test/test_asynchat.py35
-rw-r--r--Lib/test/test_asyncio/__init__.py31
-rw-r--r--Lib/test/test_asyncio/__main__.py5
-rw-r--r--Lib/test/test_asyncio/echo.py8
-rw-r--r--Lib/test/test_asyncio/echo2.py6
-rw-r--r--Lib/test/test_asyncio/echo3.py11
-rw-r--r--Lib/test/test_asyncio/sample.crt14
-rw-r--r--Lib/test/test_asyncio/sample.key15
-rw-r--r--Lib/test/test_asyncio/test_base_events.py684
-rw-r--r--Lib/test/test_asyncio/test_events.py1655
-rw-r--r--Lib/test/test_asyncio/test_futures.py329
-rw-r--r--Lib/test/test_asyncio/test_locks.py834
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py484
-rw-r--r--Lib/test/test_asyncio/test_queues.py470
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py1554
-rw-r--r--Lib/test/test_asyncio/test_streams.py364
-rw-r--r--Lib/test/test_asyncio/test_tasks.py1518
-rw-r--r--Lib/test/test_asyncio/test_transports.py59
-rw-r--r--Lib/test/test_asyncio/test_unix_events.py1488
-rw-r--r--Lib/test/test_asyncio/test_windows_events.py139
-rw-r--r--Lib/test/test_asyncio/test_windows_utils.py141
-rw-r--r--Lib/test/test_asyncio/tests.txt13
-rw-r--r--Lib/test/test_asyncore.py25
-rw-r--r--Lib/test/test_atexit.py42
-rw-r--r--Lib/test/test_audioop.py221
-rw-r--r--Lib/test/test_base64.py110
-rw-r--r--Lib/test/test_bisect.py10
-rw-r--r--Lib/test/test_buffer.py6
-rw-r--r--Lib/test/test_builtin.py73
-rw-r--r--Lib/test/test_bytes.py34
-rw-r--r--Lib/test/test_bz2.py218
-rw-r--r--Lib/test/test_capi.py118
-rw-r--r--Lib/test/test_cmd_line.py100
-rw-r--r--Lib/test/test_cmd_line_script.py10
-rw-r--r--Lib/test/test_code_module.py14
-rw-r--r--Lib/test/test_codeccallbacks.py4
-rw-r--r--Lib/test/test_codecs.py223
-rw-r--r--Lib/test/test_coding.py63
-rw-r--r--Lib/test/test_collections.py126
-rw-r--r--Lib/test/test_colorsys.py32
-rw-r--r--Lib/test/test_compileall.py52
-rw-r--r--Lib/test/test_complex.py3
-rw-r--r--Lib/test/test_concurrent_futures.py22
-rw-r--r--Lib/test/test_contextlib.py141
-rw-r--r--Lib/test/test_cprofile.py2
-rw-r--r--Lib/test/test_crypt.py6
-rw-r--r--Lib/test/test_csv.py21
-rw-r--r--Lib/test/test_dbm.py2
-rw-r--r--Lib/test/test_decimal.py1
-rw-r--r--Lib/test/test_deque.py4
-rw-r--r--Lib/test/test_descr.py22
-rw-r--r--Lib/test/test_devpoll.py38
-rw-r--r--Lib/test/test_dis.py409
-rw-r--r--Lib/test/test_doctest.py281
-rw-r--r--Lib/test/test_dynamicclassattribute.py304
-rw-r--r--Lib/test/test_email/__init__.py26
-rw-r--r--Lib/test/test_email/test_contentmanager.py796
-rw-r--r--Lib/test/test_email/test_email.py33
-rw-r--r--Lib/test/test_email/test_headerregistry.py14
-rw-r--r--Lib/test/test_email/test_message.py742
-rw-r--r--Lib/test/test_email/test_policy.py1
-rw-r--r--Lib/test/test_email/torture_test.py2
-rw-r--r--Lib/test/test_ensurepip.py128
-rw-r--r--Lib/test/test_enum.py1325
-rw-r--r--Lib/test/test_enumerate.py10
-rw-r--r--Lib/test/test_epoll.py54
-rw-r--r--Lib/test/test_exceptions.py28
-rw-r--r--Lib/test/test_faulthandler.py77
-rw-r--r--Lib/test/test_fcntl.py19
-rw-r--r--Lib/test/test_file.py6
-rw-r--r--Lib/test/test_filecmp.py23
-rw-r--r--Lib/test/test_fileinput.py22
-rw-r--r--Lib/test/test_fileio.py20
-rw-r--r--Lib/test/test_finalization.py513
-rw-r--r--Lib/test/test_float.py1
-rw-r--r--Lib/test/test_fork1.py2
-rw-r--r--Lib/test/test_format.py30
-rw-r--r--Lib/test/test_fractions.py32
-rw-r--r--Lib/test/test_frame.py116
-rw-r--r--Lib/test/test_frozen.py6
-rw-r--r--Lib/test/test_ftplib.py54
-rw-r--r--Lib/test/test_funcattrs.py7
-rw-r--r--Lib/test/test_functools.py918
-rw-r--r--Lib/test/test_future.py8
-rw-r--r--Lib/test/test_gc.py80
-rw-r--r--Lib/test/test_gdb.py11
-rw-r--r--Lib/test/test_generators.py74
-rw-r--r--Lib/test/test_genericpath.py82
-rw-r--r--Lib/test/test_getargs2.py88
-rw-r--r--Lib/test/test_getpass.py155
-rw-r--r--[-rwxr-xr-x]Lib/test/test_gzip.py51
-rw-r--r--Lib/test/test_hashlib.py267
-rw-r--r--Lib/test/test_hmac.py14
-rw-r--r--Lib/test/test_htmlparser.py17
-rw-r--r--Lib/test/test_httplib.py30
-rw-r--r--Lib/test/test_httpservers.py21
-rw-r--r--Lib/test/test_imaplib.py2
-rw-r--r--Lib/test/test_imp.py87
-rw-r--r--Lib/test/test_import.py88
-rw-r--r--Lib/test/test_importhooks.py250
-rw-r--r--Lib/test/test_importlib/__main__.py13
-rw-r--r--Lib/test/test_importlib/abc.py4
-rw-r--r--Lib/test/test_importlib/builtin/test_finder.py19
-rw-r--r--Lib/test/test_importlib/builtin/test_loader.py43
-rw-r--r--Lib/test/test_importlib/extension/test_case_sensitivity.py27
-rw-r--r--Lib/test/test_importlib/extension/test_finder.py19
-rw-r--r--Lib/test/test_importlib/extension/test_loader.py20
-rw-r--r--Lib/test/test_importlib/extension/test_path_hook.py20
-rw-r--r--Lib/test/test_importlib/extension/util.py1
-rw-r--r--Lib/test/test_importlib/frozen/test_finder.py14
-rw-r--r--Lib/test/test_importlib/frozen/test_loader.py56
-rw-r--r--Lib/test/test_importlib/import_/test___loader__.py48
-rw-r--r--Lib/test/test_importlib/import_/test___package__.py33
-rw-r--r--Lib/test/test_importlib/import_/test_api.py37
-rw-r--r--Lib/test/test_importlib/import_/test_caching.py30
-rw-r--r--Lib/test/test_importlib/import_/test_fromlist.py34
-rw-r--r--Lib/test/test_importlib/import_/test_meta_path.py24
-rw-r--r--Lib/test/test_importlib/import_/test_packages.py30
-rw-r--r--Lib/test/test_importlib/import_/test_path.py40
-rw-r--r--Lib/test/test_importlib/import_/test_relative_imports.py51
-rw-r--r--Lib/test/test_importlib/import_/util.py20
-rw-r--r--Lib/test/test_importlib/source/test_abc_loader.py906
-rw-r--r--Lib/test/test_importlib/source/test_case_sensitivity.py30
-rw-r--r--Lib/test/test_importlib/source/test_file_loader.py110
-rw-r--r--Lib/test/test_importlib/source/test_finder.py29
-rw-r--r--Lib/test/test_importlib/source/test_path_hook.py18
-rw-r--r--Lib/test/test_importlib/source/test_source_encoding.py20
-rw-r--r--Lib/test/test_importlib/source/util.py1
-rw-r--r--Lib/test/test_importlib/test_abc.py988
-rw-r--r--Lib/test/test_importlib/test_api.py281
-rw-r--r--Lib/test/test_importlib/test_locks.py71
-rw-r--r--Lib/test/test_importlib/test_util.py360
-rw-r--r--Lib/test/test_importlib/util.py24
-rw-r--r--Lib/test/test_inspect.py252
-rw-r--r--Lib/test/test_int.py41
-rw-r--r--Lib/test/test_io.py116
-rw-r--r--Lib/test/test_ioctl.py6
-rw-r--r--Lib/test/test_ipaddress.py12
-rw-r--r--Lib/test/test_iterlen.py62
-rw-r--r--Lib/test/test_itertools.py5
-rw-r--r--Lib/test/test_json/test_decode.py23
-rw-r--r--Lib/test/test_json/test_enum.py120
-rw-r--r--Lib/test/test_json/test_fail.py89
-rw-r--r--Lib/test/test_json/test_indent.py4
-rw-r--r--Lib/test/test_keyword.py138
-rw-r--r--Lib/test/test_keywordonlyarg.py12
-rw-r--r--Lib/test/test_kqueue.py29
-rw-r--r--Lib/test/test_largefile.py2
-rw-r--r--Lib/test/test_logging.py439
-rw-r--r--Lib/test/test_long.py2
-rw-r--r--Lib/test/test_lzma.py46
-rw-r--r--Lib/test/test_mailbox.py2
-rw-r--r--Lib/test/test_marshal.py163
-rw-r--r--Lib/test/test_memoryio.py12
-rw-r--r--Lib/test/test_memoryview.py9
-rw-r--r--Lib/test/test_mmap.py26
-rw-r--r--Lib/test/test_module.py52
-rw-r--r--Lib/test/test_multibytecodec.py73
-rw-r--r--Lib/test/test_multiprocessing_fork.py7
-rw-r--r--Lib/test/test_multiprocessing_forkserver.py7
-rw-r--r--Lib/test/test_multiprocessing_spawn.py7
-rw-r--r--Lib/test/test_namespace_pkgs.py29
-rw-r--r--Lib/test/test_nntplib.py2
-rw-r--r--Lib/test/test_normalization.py2
-rw-r--r--Lib/test/test_ntpath.py34
-rw-r--r--Lib/test/test_openpty.py2
-rw-r--r--Lib/test/test_operator.py122
-rw-r--r--Lib/test/test_optparse.py8
-rw-r--r--Lib/test/test_os.py329
-rw-r--r--Lib/test/test_ossaudiodev.py4
-rw-r--r--Lib/test/test_pdb.py43
-rw-r--r--Lib/test/test_peepholer.py295
-rw-r--r--Lib/test/test_pep277.py4
-rw-r--r--Lib/test/test_pep352.py5
-rw-r--r--Lib/test/test_pkgimport.py2
-rw-r--r--Lib/test/test_pkgutil.py4
-rw-r--r--Lib/test/test_poll.py20
-rw-r--r--Lib/test/test_poplib.py178
-rw-r--r--Lib/test/test_posix.py72
-rw-r--r--Lib/test/test_posixpath.py62
-rw-r--r--Lib/test/test_pprint.py50
-rw-r--r--Lib/test/test_print.py74
-rw-r--r--Lib/test/test_profile.py29
-rw-r--r--Lib/test/test_pty.py2
-rw-r--r--Lib/test/test_py_compile.py46
-rw-r--r--Lib/test/test_pyclbr.py4
-rw-r--r--Lib/test/test_pydoc.py221
-rw-r--r--Lib/test/test_random.py207
-rw-r--r--Lib/test/test_range.py2
-rw-r--r--Lib/test/test_re.py147
-rw-r--r--Lib/test/test_regrtest.py275
-rw-r--r--Lib/test/test_reprlib.py5
-rw-r--r--Lib/test/test_resource.py26
-rw-r--r--Lib/test/test_runpy.py9
-rw-r--r--Lib/test/test_sched.py5
-rw-r--r--Lib/test/test_scope.py17
-rw-r--r--Lib/test/test_select.py4
-rw-r--r--Lib/test/test_selectors.py429
-rw-r--r--Lib/test/test_set.py2
-rw-r--r--Lib/test/test_shelve.py13
-rw-r--r--Lib/test/test_shutil.py45
-rw-r--r--Lib/test/test_signal.py60
-rw-r--r--Lib/test/test_site.py39
-rw-r--r--Lib/test/test_slice.py114
-rw-r--r--Lib/test/test_smtplib.py23
-rw-r--r--Lib/test/test_sndhdr.py2
-rw-r--r--Lib/test/test_socket.py392
-rw-r--r--Lib/test/test_socketserver.py20
-rw-r--r--Lib/test/test_source_encoding.py (renamed from Lib/test/test_pep263.py)67
-rw-r--r--Lib/test/test_ssl.py431
-rw-r--r--Lib/test/test_stat.py61
-rw-r--r--Lib/test/test_statistics.py1534
-rw-r--r--Lib/test/test_strftime.py11
-rw-r--r--Lib/test/test_struct.py89
-rw-r--r--Lib/test/test_structseq.py2
-rw-r--r--Lib/test/test_subprocess.py189
-rw-r--r--Lib/test/test_sunau.py28
-rw-r--r--Lib/test/test_sundry.py39
-rw-r--r--Lib/test/test_super.py33
-rw-r--r--Lib/test/test_support.py1
-rw-r--r--Lib/test/test_syntax.py4
-rw-r--r--Lib/test/test_sys.py84
-rw-r--r--Lib/test/test_sysconfig.py44
-rw-r--r--Lib/test/test_tarfile.py10
-rw-r--r--Lib/test/test_tcl.py34
-rw-r--r--Lib/test/test_telnetlib.py101
-rw-r--r--Lib/test/test_tempfile.py29
-rw-r--r--Lib/test/test_textwrap.py159
-rw-r--r--Lib/test/test_thread.py2
-rw-r--r--Lib/test/test_threaded_import.py2
-rw-r--r--Lib/test/test_threading.py433
-rw-r--r--Lib/test/test_threadsignals.py2
-rw-r--r--Lib/test/test_timeout.py2
-rw-r--r--Lib/test/test_traceback.py82
-rw-r--r--Lib/test/test_types.py37
-rw-r--r--Lib/test/test_ucn.py2
-rw-r--r--Lib/test/test_unicode.py369
-rw-r--r--Lib/test/test_unicodedata.py4
-rw-r--r--Lib/test/test_urllib.py103
-rw-r--r--Lib/test/test_urllib2.py137
-rw-r--r--Lib/test/test_urllib2_localnet.py30
-rw-r--r--Lib/test/test_urllib2net.py43
-rw-r--r--Lib/test/test_urllibnet.py19
-rwxr-xr-xLib/test/test_urlparse.py8
-rw-r--r--Lib/test/test_venv.py59
-rw-r--r--Lib/test/test_wait3.py12
-rw-r--r--Lib/test/test_warnings.py32
-rw-r--r--Lib/test/test_weakref.py316
-rw-r--r--Lib/test/test_winreg.py28
-rw-r--r--Lib/test/test_winsound.py2
-rw-r--r--Lib/test/test_xml_etree.py194
-rw-r--r--Lib/test/test_xml_etree_c.py9
-rw-r--r--Lib/test/test_xmlrpc.py74
-rw-r--r--Lib/test/test_xmlrpc_net.py31
-rw-r--r--Lib/test/test_zipfile.py221
-rw-r--r--Lib/test/test_zipimport.py37
-rw-r--r--Lib/test/test_zipimport_support.py6
-rw-r--r--Lib/test/tf_inherit_check.py2
291 files changed, 30151 insertions, 5549 deletions
diff --git a/Lib/test/__main__.py b/Lib/test/__main__.py
index ce5615b889..d5fbe159d7 100644
--- a/Lib/test/__main__.py
+++ b/Lib/test/__main__.py
@@ -1,13 +1,3 @@
-from test import regrtest, support
+from test import regrtest
-
-TEMPDIR, TESTCWD = regrtest._make_temp_dir_for_build(regrtest.TEMPDIR)
-regrtest.TEMPDIR = TEMPDIR
-regrtest.TESTCWD = TESTCWD
-
-# Run the tests in a context manager that temporary changes the CWD to a
-# temporary and writable directory. If it's not possible to create or
-# change the CWD, the original CWD will be used. The original CWD is
-# available from support.SAVEDCWD.
-with support.temp_cwd(TESTCWD, quiet=True):
- regrtest.main()
+regrtest.main_in_temp_cwd()
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/_test_multiprocessing.py
index 86cf5c188f..ad77260c05 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/_test_multiprocessing.py
@@ -43,7 +43,7 @@ from multiprocessing import util
try:
from multiprocessing import reduction
- HAS_REDUCTION = True
+ HAS_REDUCTION = reduction.HAVE_SEND_HANDLE
except ImportError:
HAS_REDUCTION = False
@@ -99,6 +99,9 @@ try:
except:
MAXFD = 256
+# To speed up tests when using the forkserver, we can preload these:
+PRELOAD = ['__main__', 'test.test_multiprocessing_forkserver']
+
#
# Some tests require ctypes
#
@@ -294,17 +297,22 @@ class _TestProcess(BaseTestCase):
self.assertTimingAlmostEqual(join.elapsed, 0.0)
self.assertEqual(p.is_alive(), True)
+ # XXX maybe terminating too soon causes the problems on Gentoo...
+ time.sleep(1)
+
p.terminate()
if hasattr(signal, 'alarm'):
+ # On the Gentoo buildbot waitpid() often seems to block forever.
+ # We use alarm() to interrupt it if it blocks for too long.
def handler(*args):
raise RuntimeError('join took too long: %s' % p)
old_handler = signal.signal(signal.SIGALRM, handler)
try:
signal.alarm(10)
self.assertEqual(join(), None)
- signal.alarm(0)
finally:
+ signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
else:
self.assertEqual(join(), None)
@@ -342,7 +350,6 @@ class _TestProcess(BaseTestCase):
@classmethod
def _test_recursion(cls, wconn, id):
- from multiprocessing import forking
wconn.send(id)
if len(id) < 2:
for i in range(2):
@@ -390,7 +397,7 @@ class _TestProcess(BaseTestCase):
self.assertFalse(wait_for_handle(sentinel, timeout=0.0))
event.set()
p.join()
- self.assertTrue(wait_for_handle(sentinel, timeout=DELTA))
+ self.assertTrue(wait_for_handle(sentinel, timeout=1))
#
#
@@ -1779,6 +1786,35 @@ class _TestPool(BaseTestCase):
self.assertEqual(r.get(), expected)
self.assertRaises(ValueError, p.map_async, sqr, L)
+ @classmethod
+ def _test_traceback(cls):
+ raise RuntimeError(123) # some comment
+
+ def test_traceback(self):
+ # We want ensure that the traceback from the child process is
+ # contained in the traceback raised in the main process.
+ if self.TYPE == 'processes':
+ with self.Pool(1) as p:
+ try:
+ p.apply(self._test_traceback)
+ except Exception as e:
+ exc = e
+ else:
+ raise AssertionError('expected RuntimeError')
+ self.assertIs(type(exc), RuntimeError)
+ self.assertEqual(exc.args, (123,))
+ cause = exc.__cause__
+ self.assertIs(type(cause), multiprocessing.pool.RemoteTraceback)
+ self.assertIn('raise RuntimeError(123) # some comment', cause.tb)
+
+ with test.support.captured_stderr() as f1:
+ try:
+ raise exc
+ except RuntimeError:
+ sys.excepthook(*sys.exc_info())
+ self.assertIn('raise RuntimeError(123) # some comment',
+ f1.getvalue())
+
def raising():
raise KeyError("key")
@@ -2486,7 +2522,7 @@ class _TestPicklingConnections(BaseTestCase):
@classmethod
def tearDownClass(cls):
- from multiprocessing.reduction import resource_sharer
+ from multiprocessing import resource_sharer
resource_sharer.stop(timeout=5)
@classmethod
@@ -2800,30 +2836,40 @@ class _TestFinalize(BaseTestCase):
# Test that from ... import * works for each module
#
-class _TestImportStar(BaseTestCase):
+class _TestImportStar(unittest.TestCase):
- ALLOWED_TYPES = ('processes',)
+ def get_module_names(self):
+ import glob
+ folder = os.path.dirname(multiprocessing.__file__)
+ pattern = os.path.join(folder, '*.py')
+ files = glob.glob(pattern)
+ modules = [os.path.splitext(os.path.split(f)[1])[0] for f in files]
+ modules = ['multiprocessing.' + m for m in modules]
+ modules.remove('multiprocessing.__init__')
+ modules.append('multiprocessing')
+ return modules
def test_import(self):
- modules = [
- 'multiprocessing', 'multiprocessing.connection',
- 'multiprocessing.heap', 'multiprocessing.managers',
- 'multiprocessing.pool', 'multiprocessing.process',
- 'multiprocessing.synchronize', 'multiprocessing.util'
- ]
-
- if HAS_REDUCTION:
- modules.append('multiprocessing.reduction')
+ modules = self.get_module_names()
+ if sys.platform == 'win32':
+ modules.remove('multiprocessing.popen_fork')
+ modules.remove('multiprocessing.popen_forkserver')
+ modules.remove('multiprocessing.popen_spawn_posix')
+ else:
+ modules.remove('multiprocessing.popen_spawn_win32')
+ if not HAS_REDUCTION:
+ modules.remove('multiprocessing.popen_forkserver')
- if c_int is not None:
+ if c_int is None:
# This module requires _ctypes
- modules.append('multiprocessing.sharedctypes')
+ modules.remove('multiprocessing.sharedctypes')
for name in modules:
__import__(name)
mod = sys.modules[name]
+ self.assertTrue(hasattr(mod, '__all__'), name)
- for attr in getattr(mod, '__all__', ()):
+ for attr in mod.__all__:
self.assertTrue(
hasattr(mod, attr),
'%r does not have attribute %r' % (mod, attr)
@@ -2906,7 +2952,7 @@ class _TestPollEintr(BaseTestCase):
@classmethod
def _killer(cls, pid):
- time.sleep(0.5)
+ time.sleep(0.1)
os.kill(pid, signal.SIGUSR1)
@unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1')
@@ -2919,12 +2965,14 @@ class _TestPollEintr(BaseTestCase):
try:
killer = self.Process(target=self._killer, args=(pid,))
killer.start()
- p = self.Process(target=time.sleep, args=(1,))
- p.start()
- p.join()
+ try:
+ p = self.Process(target=time.sleep, args=(2,))
+ p.start()
+ p.join()
+ finally:
+ killer.join()
self.assertTrue(got_signal[0])
self.assertEqual(p.exitcode, 0)
- killer.join()
finally:
signal.signal(signal.SIGUSR1, oldhandler)
@@ -2937,8 +2985,11 @@ class TestInvalidHandle(unittest.TestCase):
@unittest.skipIf(WIN32, "skipped on Windows")
def test_invalid_handles(self):
conn = multiprocessing.connection.Connection(44977608)
+ # check that poll() doesn't crash
try:
- self.assertRaises((ValueError, OSError), conn.poll)
+ conn.poll()
+ except (ValueError, OSError):
+ pass
finally:
# Hack private attribute _handle to avoid printing an error
# in conn.__del__
@@ -2946,131 +2997,6 @@ class TestInvalidHandle(unittest.TestCase):
self.assertRaises((ValueError, OSError),
multiprocessing.connection.Connection, -1)
-#
-# Functions used to create test cases from the base ones in this module
-#
-
-def create_test_cases(Mixin, type):
- result = {}
- glob = globals()
- Type = type.capitalize()
- ALL_TYPES = {'processes', 'threads', 'manager'}
-
- for name in list(glob.keys()):
- if name.startswith('_Test'):
- base = glob[name]
- assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES)
- if type in base.ALLOWED_TYPES:
- newname = 'With' + Type + name[1:]
- class Temp(base, Mixin, unittest.TestCase):
- pass
- result[newname] = Temp
- Temp.__name__ = Temp.__qualname__ = newname
- Temp.__module__ = Mixin.__module__
- return result
-
-#
-# Create test cases
-#
-
-class ProcessesMixin(object):
- TYPE = 'processes'
- Process = multiprocessing.Process
- connection = multiprocessing.connection
- current_process = staticmethod(multiprocessing.current_process)
- active_children = staticmethod(multiprocessing.active_children)
- Pool = staticmethod(multiprocessing.Pool)
- Pipe = staticmethod(multiprocessing.Pipe)
- Queue = staticmethod(multiprocessing.Queue)
- JoinableQueue = staticmethod(multiprocessing.JoinableQueue)
- Lock = staticmethod(multiprocessing.Lock)
- RLock = staticmethod(multiprocessing.RLock)
- Semaphore = staticmethod(multiprocessing.Semaphore)
- BoundedSemaphore = staticmethod(multiprocessing.BoundedSemaphore)
- Condition = staticmethod(multiprocessing.Condition)
- Event = staticmethod(multiprocessing.Event)
- Barrier = staticmethod(multiprocessing.Barrier)
- Value = staticmethod(multiprocessing.Value)
- Array = staticmethod(multiprocessing.Array)
- RawValue = staticmethod(multiprocessing.RawValue)
- RawArray = staticmethod(multiprocessing.RawArray)
-
-testcases_processes = create_test_cases(ProcessesMixin, type='processes')
-globals().update(testcases_processes)
-
-
-class ManagerMixin(object):
- TYPE = 'manager'
- Process = multiprocessing.Process
- Queue = property(operator.attrgetter('manager.Queue'))
- JoinableQueue = property(operator.attrgetter('manager.JoinableQueue'))
- Lock = property(operator.attrgetter('manager.Lock'))
- RLock = property(operator.attrgetter('manager.RLock'))
- Semaphore = property(operator.attrgetter('manager.Semaphore'))
- BoundedSemaphore = property(operator.attrgetter('manager.BoundedSemaphore'))
- Condition = property(operator.attrgetter('manager.Condition'))
- Event = property(operator.attrgetter('manager.Event'))
- Barrier = property(operator.attrgetter('manager.Barrier'))
- Value = property(operator.attrgetter('manager.Value'))
- Array = property(operator.attrgetter('manager.Array'))
- list = property(operator.attrgetter('manager.list'))
- dict = property(operator.attrgetter('manager.dict'))
- Namespace = property(operator.attrgetter('manager.Namespace'))
-
- @classmethod
- def Pool(cls, *args, **kwds):
- return cls.manager.Pool(*args, **kwds)
-
- @classmethod
- def setUpClass(cls):
- cls.manager = multiprocessing.Manager()
-
- @classmethod
- def tearDownClass(cls):
- # only the manager process should be returned by active_children()
- # but this can take a bit on slow machines, so wait a few seconds
- # if there are other children too (see #17395)
- t = 0.01
- while len(multiprocessing.active_children()) > 1 and t < 5:
- time.sleep(t)
- t *= 2
- gc.collect() # do garbage collection
- if cls.manager._number_of_objects() != 0:
- # This is not really an error since some tests do not
- # ensure that all processes which hold a reference to a
- # managed object have been joined.
- print('Shared objects which still exist at manager shutdown:')
- print(cls.manager._debug_info())
- cls.manager.shutdown()
- cls.manager.join()
- cls.manager = None
-
-testcases_manager = create_test_cases(ManagerMixin, type='manager')
-globals().update(testcases_manager)
-
-
-class ThreadsMixin(object):
- TYPE = 'threads'
- Process = multiprocessing.dummy.Process
- connection = multiprocessing.dummy.connection
- current_process = staticmethod(multiprocessing.dummy.current_process)
- active_children = staticmethod(multiprocessing.dummy.active_children)
- Pool = staticmethod(multiprocessing.Pool)
- Pipe = staticmethod(multiprocessing.dummy.Pipe)
- Queue = staticmethod(multiprocessing.dummy.Queue)
- JoinableQueue = staticmethod(multiprocessing.dummy.JoinableQueue)
- Lock = staticmethod(multiprocessing.dummy.Lock)
- RLock = staticmethod(multiprocessing.dummy.RLock)
- Semaphore = staticmethod(multiprocessing.dummy.Semaphore)
- BoundedSemaphore = staticmethod(multiprocessing.dummy.BoundedSemaphore)
- Condition = staticmethod(multiprocessing.dummy.Condition)
- Event = staticmethod(multiprocessing.dummy.Event)
- Barrier = staticmethod(multiprocessing.dummy.Barrier)
- Value = staticmethod(multiprocessing.dummy.Value)
- Array = staticmethod(multiprocessing.dummy.Array)
-
-testcases_threads = create_test_cases(ThreadsMixin, type='threads')
-globals().update(testcases_threads)
class OtherTest(unittest.TestCase):
@@ -3420,7 +3346,7 @@ class TestFlags(unittest.TestCase):
def test_flags(self):
import json, subprocess
# start child process using unusual flags
- prog = ('from test.test_multiprocessing import TestFlags; ' +
+ prog = ('from test._test_multiprocessing import TestFlags; ' +
'TestFlags.run_in_child()')
data = subprocess.check_output(
[sys.executable, '-E', '-S', '-O', '-c', prog])
@@ -3467,13 +3393,14 @@ class TestTimeouts(unittest.TestCase):
class TestNoForkBomb(unittest.TestCase):
def test_noforkbomb(self):
+ sm = multiprocessing.get_start_method()
name = os.path.join(os.path.dirname(__file__), 'mp_fork_bomb.py')
- if WIN32:
- rc, out, err = test.script_helper.assert_python_failure(name)
+ if sm != 'fork':
+ rc, out, err = test.script_helper.assert_python_failure(name, sm)
self.assertEqual('', out.decode('ascii'))
self.assertIn('RuntimeError', err.decode('ascii'))
else:
- rc, out, err = test.script_helper.assert_python_ok(name)
+ rc, out, err = test.script_helper.assert_python_ok(name, sm)
self.assertEqual('123', out.decode('ascii').rstrip())
self.assertEqual('', err.decode('ascii'))
@@ -3491,7 +3418,8 @@ class TestForkAwareThreadLock(unittest.TestCase):
if n > 1:
p = multiprocessing.Process(target=cls.child, args=(n-1, conn))
p.start()
- p.join()
+ conn.close()
+ p.join(timeout=5)
else:
conn.send(len(util._afterfork_registry))
conn.close()
@@ -3502,11 +3430,78 @@ class TestForkAwareThreadLock(unittest.TestCase):
old_size = len(util._afterfork_registry)
p = multiprocessing.Process(target=self.child, args=(5, w))
p.start()
+ w.close()
new_size = r.recv()
- p.join()
+ p.join(timeout=5)
self.assertLessEqual(new_size, old_size)
#
+# Check that non-forked child processes do not inherit unneeded fds/handles
+#
+
+class TestCloseFds(unittest.TestCase):
+
+ def get_high_socket_fd(self):
+ if WIN32:
+ # The child process will not have any socket handles, so
+ # calling socket.fromfd() should produce WSAENOTSOCK even
+ # if there is a handle of the same number.
+ return socket.socket().detach()
+ else:
+ # We want to produce a socket with an fd high enough that a
+ # freshly created child process will not have any fds as high.
+ fd = socket.socket().detach()
+ to_close = []
+ while fd < 50:
+ to_close.append(fd)
+ fd = os.dup(fd)
+ for x in to_close:
+ os.close(x)
+ return fd
+
+ def close(self, fd):
+ if WIN32:
+ socket.socket(fileno=fd).close()
+ else:
+ os.close(fd)
+
+ @classmethod
+ def _test_closefds(cls, conn, fd):
+ try:
+ s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
+ except Exception as e:
+ conn.send(e)
+ else:
+ s.close()
+ conn.send(None)
+
+ def test_closefd(self):
+ if not HAS_REDUCTION:
+ raise unittest.SkipTest('requires fd pickling')
+
+ reader, writer = multiprocessing.Pipe()
+ fd = self.get_high_socket_fd()
+ try:
+ p = multiprocessing.Process(target=self._test_closefds,
+ args=(writer, fd))
+ p.start()
+ writer.close()
+ e = reader.recv()
+ p.join(timeout=5)
+ finally:
+ self.close(fd)
+ writer.close()
+ reader.close()
+
+ if multiprocessing.get_start_method() == 'fork':
+ self.assertIs(e, None)
+ else:
+ WSAENOTSOCK = 10038
+ self.assertIsInstance(e, OSError)
+ self.assertTrue(e.errno == errno.EBADF or
+ e.winerror == WSAENOTSOCK, e)
+
+#
# Issue #17097: EINTR should be ignored by recv(), send(), accept() etc
#
@@ -3550,10 +3545,10 @@ class TestIgnoreEINTR(unittest.TestCase):
def handler(signum, frame):
pass
signal.signal(signal.SIGUSR1, handler)
- l = multiprocessing.connection.Listener()
- conn.send(l.address)
- a = l.accept()
- a.send('welcome')
+ with multiprocessing.connection.Listener() as l:
+ conn.send(l.address)
+ a = l.accept()
+ a.send('welcome')
@unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1')
def test_ignore_listener(self):
@@ -3574,26 +3569,267 @@ class TestIgnoreEINTR(unittest.TestCase):
finally:
conn.close()
+class TestStartMethod(unittest.TestCase):
+ @classmethod
+ def _check_context(cls, conn):
+ conn.send(multiprocessing.get_start_method())
+
+ def check_context(self, ctx):
+ r, w = ctx.Pipe(duplex=False)
+ p = ctx.Process(target=self._check_context, args=(w,))
+ p.start()
+ w.close()
+ child_method = r.recv()
+ r.close()
+ p.join()
+ self.assertEqual(child_method, ctx.get_start_method())
+
+ def test_context(self):
+ for method in ('fork', 'spawn', 'forkserver'):
+ try:
+ ctx = multiprocessing.get_context(method)
+ except ValueError:
+ continue
+ self.assertEqual(ctx.get_start_method(), method)
+ self.assertIs(ctx.get_context(), ctx)
+ self.assertRaises(ValueError, ctx.set_start_method, 'spawn')
+ self.assertRaises(ValueError, ctx.set_start_method, None)
+ self.check_context(ctx)
+
+ def test_set_get(self):
+ multiprocessing.set_forkserver_preload(PRELOAD)
+ count = 0
+ old_method = multiprocessing.get_start_method()
+ try:
+ for method in ('fork', 'spawn', 'forkserver'):
+ try:
+ multiprocessing.set_start_method(method, force=True)
+ except ValueError:
+ continue
+ self.assertEqual(multiprocessing.get_start_method(), method)
+ ctx = multiprocessing.get_context()
+ self.assertEqual(ctx.get_start_method(), method)
+ self.assertTrue(type(ctx).__name__.lower().startswith(method))
+ self.assertTrue(
+ ctx.Process.__name__.lower().startswith(method))
+ self.check_context(multiprocessing)
+ count += 1
+ finally:
+ multiprocessing.set_start_method(old_method, force=True)
+ self.assertGreaterEqual(count, 1)
+
+ def test_get_all(self):
+ methods = multiprocessing.get_all_start_methods()
+ if sys.platform == 'win32':
+ self.assertEqual(methods, ['spawn'])
+ else:
+ self.assertTrue(methods == ['fork', 'spawn'] or
+ methods == ['fork', 'spawn', 'forkserver'])
+
+#
+# Check that killing process does not leak named semaphores
+#
+
+@unittest.skipIf(sys.platform == "win32",
+ "test semantics don't make sense on Windows")
+class TestSemaphoreTracker(unittest.TestCase):
+ def test_semaphore_tracker(self):
+ import subprocess
+ cmd = '''if 1:
+ import multiprocessing as mp, time, os
+ mp.set_start_method("spawn")
+ lock1 = mp.Lock()
+ lock2 = mp.Lock()
+ os.write(%d, lock1._semlock.name.encode("ascii") + b"\\n")
+ os.write(%d, lock2._semlock.name.encode("ascii") + b"\\n")
+ time.sleep(10)
+ '''
+ r, w = os.pipe()
+ p = subprocess.Popen([sys.executable,
+ '-c', cmd % (w, w)],
+ pass_fds=[w],
+ stderr=subprocess.PIPE)
+ os.close(w)
+ with open(r, 'rb', closefd=True) as f:
+ name1 = f.readline().rstrip().decode('ascii')
+ name2 = f.readline().rstrip().decode('ascii')
+ _multiprocessing.sem_unlink(name1)
+ p.terminate()
+ p.wait()
+ time.sleep(1.0)
+ with self.assertRaises(OSError) as ctx:
+ _multiprocessing.sem_unlink(name2)
+ # docs say it should be ENOENT, but OSX seems to give EINVAL
+ self.assertIn(ctx.exception.errno, (errno.ENOENT, errno.EINVAL))
+ err = p.stderr.read().decode('utf-8')
+ p.stderr.close()
+ expected = 'semaphore_tracker: There appear to be 2 leaked semaphores'
+ self.assertRegex(err, expected)
+ self.assertRegex(err, 'semaphore_tracker: %r: \[Errno' % name1)
+
#
-#
+# Mixins
#
-def setUpModule():
- if sys.platform.startswith("linux"):
- try:
- lock = multiprocessing.RLock()
- except OSError:
- raise unittest.SkipTest("OSError raises on RLock creation, "
- "see issue 3111!")
- check_enough_semaphores()
- util.get_temp_dir() # creates temp directory for use by all processes
- multiprocessing.get_logger().setLevel(LOG_LEVEL)
+class ProcessesMixin(object):
+ TYPE = 'processes'
+ Process = multiprocessing.Process
+ connection = multiprocessing.connection
+ current_process = staticmethod(multiprocessing.current_process)
+ active_children = staticmethod(multiprocessing.active_children)
+ Pool = staticmethod(multiprocessing.Pool)
+ Pipe = staticmethod(multiprocessing.Pipe)
+ Queue = staticmethod(multiprocessing.Queue)
+ JoinableQueue = staticmethod(multiprocessing.JoinableQueue)
+ Lock = staticmethod(multiprocessing.Lock)
+ RLock = staticmethod(multiprocessing.RLock)
+ Semaphore = staticmethod(multiprocessing.Semaphore)
+ BoundedSemaphore = staticmethod(multiprocessing.BoundedSemaphore)
+ Condition = staticmethod(multiprocessing.Condition)
+ Event = staticmethod(multiprocessing.Event)
+ Barrier = staticmethod(multiprocessing.Barrier)
+ Value = staticmethod(multiprocessing.Value)
+ Array = staticmethod(multiprocessing.Array)
+ RawValue = staticmethod(multiprocessing.RawValue)
+ RawArray = staticmethod(multiprocessing.RawArray)
-def tearDownModule():
- # pause a bit so we don't get warning about dangling threads/processes
- time.sleep(0.5)
+class ManagerMixin(object):
+ TYPE = 'manager'
+ Process = multiprocessing.Process
+ Queue = property(operator.attrgetter('manager.Queue'))
+ JoinableQueue = property(operator.attrgetter('manager.JoinableQueue'))
+ Lock = property(operator.attrgetter('manager.Lock'))
+ RLock = property(operator.attrgetter('manager.RLock'))
+ Semaphore = property(operator.attrgetter('manager.Semaphore'))
+ BoundedSemaphore = property(operator.attrgetter('manager.BoundedSemaphore'))
+ Condition = property(operator.attrgetter('manager.Condition'))
+ Event = property(operator.attrgetter('manager.Event'))
+ Barrier = property(operator.attrgetter('manager.Barrier'))
+ Value = property(operator.attrgetter('manager.Value'))
+ Array = property(operator.attrgetter('manager.Array'))
+ list = property(operator.attrgetter('manager.list'))
+ dict = property(operator.attrgetter('manager.dict'))
+ Namespace = property(operator.attrgetter('manager.Namespace'))
+
+ @classmethod
+ def Pool(cls, *args, **kwds):
+ return cls.manager.Pool(*args, **kwds)
+
+ @classmethod
+ def setUpClass(cls):
+ cls.manager = multiprocessing.Manager()
+
+ @classmethod
+ def tearDownClass(cls):
+ # only the manager process should be returned by active_children()
+ # but this can take a bit on slow machines, so wait a few seconds
+ # if there are other children too (see #17395)
+ t = 0.01
+ while len(multiprocessing.active_children()) > 1 and t < 5:
+ time.sleep(t)
+ t *= 2
+ gc.collect() # do garbage collection
+ if cls.manager._number_of_objects() != 0:
+ # This is not really an error since some tests do not
+ # ensure that all processes which hold a reference to a
+ # managed object have been joined.
+ print('Shared objects which still exist at manager shutdown:')
+ print(cls.manager._debug_info())
+ cls.manager.shutdown()
+ cls.manager.join()
+ cls.manager = None
-if __name__ == '__main__':
- unittest.main()
+class ThreadsMixin(object):
+ TYPE = 'threads'
+ Process = multiprocessing.dummy.Process
+ connection = multiprocessing.dummy.connection
+ current_process = staticmethod(multiprocessing.dummy.current_process)
+ active_children = staticmethod(multiprocessing.dummy.active_children)
+ Pool = staticmethod(multiprocessing.Pool)
+ Pipe = staticmethod(multiprocessing.dummy.Pipe)
+ Queue = staticmethod(multiprocessing.dummy.Queue)
+ JoinableQueue = staticmethod(multiprocessing.dummy.JoinableQueue)
+ Lock = staticmethod(multiprocessing.dummy.Lock)
+ RLock = staticmethod(multiprocessing.dummy.RLock)
+ Semaphore = staticmethod(multiprocessing.dummy.Semaphore)
+ BoundedSemaphore = staticmethod(multiprocessing.dummy.BoundedSemaphore)
+ Condition = staticmethod(multiprocessing.dummy.Condition)
+ Event = staticmethod(multiprocessing.dummy.Event)
+ Barrier = staticmethod(multiprocessing.dummy.Barrier)
+ Value = staticmethod(multiprocessing.dummy.Value)
+ Array = staticmethod(multiprocessing.dummy.Array)
+
+#
+# Functions used to create test cases from the base ones in this module
+#
+
+def install_tests_in_module_dict(remote_globs, start_method):
+ __module__ = remote_globs['__name__']
+ local_globs = globals()
+ ALL_TYPES = {'processes', 'threads', 'manager'}
+
+ for name, base in local_globs.items():
+ if not isinstance(base, type):
+ continue
+ if issubclass(base, BaseTestCase):
+ if base is BaseTestCase:
+ continue
+ assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES
+ for type_ in base.ALLOWED_TYPES:
+ newname = 'With' + type_.capitalize() + name[1:]
+ Mixin = local_globs[type_.capitalize() + 'Mixin']
+ class Temp(base, Mixin, unittest.TestCase):
+ pass
+ Temp.__name__ = Temp.__qualname__ = newname
+ Temp.__module__ = __module__
+ remote_globs[newname] = Temp
+ elif issubclass(base, unittest.TestCase):
+ class Temp(base, object):
+ pass
+ Temp.__name__ = Temp.__qualname__ = name
+ Temp.__module__ = __module__
+ remote_globs[name] = Temp
+
+ dangling = [None, None]
+ old_start_method = [None]
+
+ def setUpModule():
+ multiprocessing.set_forkserver_preload(PRELOAD)
+ multiprocessing.process._cleanup()
+ dangling[0] = multiprocessing.process._dangling.copy()
+ dangling[1] = threading._dangling.copy()
+ old_start_method[0] = multiprocessing.get_start_method(allow_none=True)
+ try:
+ multiprocessing.set_start_method(start_method, force=True)
+ except ValueError:
+ raise unittest.SkipTest(start_method +
+ ' start method not supported')
+
+ if sys.platform.startswith("linux"):
+ try:
+ lock = multiprocessing.RLock()
+ except OSError:
+ raise unittest.SkipTest("OSError raises on RLock creation, "
+ "see issue 3111!")
+ check_enough_semaphores()
+ util.get_temp_dir() # creates temp directory
+ multiprocessing.get_logger().setLevel(LOG_LEVEL)
+
+ def tearDownModule():
+ multiprocessing.set_start_method(old_start_method[0], force=True)
+ # pause a bit so we don't get warning about dangling threads/processes
+ time.sleep(0.5)
+ multiprocessing.process._cleanup()
+ gc.collect()
+ tmp = set(multiprocessing.process._dangling) - set(dangling[0])
+ if tmp:
+ print('Dangling processes:', tmp, file=sys.stderr)
+ del tmp
+ tmp = set(threading._dangling) - set(dangling[1])
+ if tmp:
+ print('Dangling threads:', tmp, file=sys.stderr)
+
+ remote_globs['setUpModule'] = setUpModule
+ remote_globs['tearDownModule'] = tearDownModule
diff --git a/Lib/test/audiodata/pluck-pcm24.au b/Lib/test/audiodata/pluck-pcm24.au
new file mode 100644
index 0000000000..0bb230418a
--- /dev/null
+++ b/Lib/test/audiodata/pluck-pcm24.au
Binary files differ
diff --git a/Lib/test/audiotests.py b/Lib/test/audiotests.py
index 59e9928796..7b39269ec9 100644
--- a/Lib/test/audiotests.py
+++ b/Lib/test/audiotests.py
@@ -47,6 +47,12 @@ class AudioTests:
params = f.getparams()
self.assertEqual(params,
(nchannels, sampwidth, framerate, nframes, comptype, compname))
+ self.assertEqual(params.nchannels, nchannels)
+ self.assertEqual(params.sampwidth, sampwidth)
+ self.assertEqual(params.framerate, framerate)
+ self.assertEqual(params.nframes, nframes)
+ self.assertEqual(params.comptype, comptype)
+ self.assertEqual(params.compname, compname)
dump = pickle.dumps(params)
self.assertEqual(pickle.loads(dump), params)
@@ -63,15 +69,12 @@ class AudioWriteTests(AudioTests):
return f
def check_file(self, testfile, nframes, frames):
- f = self.module.open(testfile, 'rb')
- try:
+ with self.module.open(testfile, 'rb') as f:
self.assertEqual(f.getnchannels(), self.nchannels)
self.assertEqual(f.getsampwidth(), self.sampwidth)
self.assertEqual(f.getframerate(), self.framerate)
self.assertEqual(f.getnframes(), nframes)
self.assertEqual(f.readframes(nframes), frames)
- finally:
- f.close()
def test_write_params(self):
f = self.create_file(TESTFN)
@@ -81,6 +84,53 @@ class AudioWriteTests(AudioTests):
self.nframes, self.comptype, self.compname)
f.close()
+ def test_write_context_manager_calls_close(self):
+ # Close checks for a minimum header and will raise an error
+ # if it is not set, so this proves that close is called.
+ with self.assertRaises(self.module.Error):
+ with self.module.open(TESTFN, 'wb'):
+ pass
+ with self.assertRaises(self.module.Error):
+ with open(TESTFN, 'wb') as testfile:
+ with self.module.open(testfile):
+ pass
+
+ def test_context_manager_with_open_file(self):
+ with open(TESTFN, 'wb') as testfile:
+ with self.module.open(testfile) as f:
+ f.setnchannels(self.nchannels)
+ f.setsampwidth(self.sampwidth)
+ f.setframerate(self.framerate)
+ f.setcomptype(self.comptype, self.compname)
+ self.assertEqual(testfile.closed, self.close_fd)
+ with open(TESTFN, 'rb') as testfile:
+ with self.module.open(testfile) as f:
+ self.assertFalse(f.getfp().closed)
+ params = f.getparams()
+ self.assertEqual(params.nchannels, self.nchannels)
+ self.assertEqual(params.sampwidth, self.sampwidth)
+ self.assertEqual(params.framerate, self.framerate)
+ if not self.close_fd:
+ self.assertIsNone(f.getfp())
+ self.assertEqual(testfile.closed, self.close_fd)
+
+ def test_context_manager_with_filename(self):
+ # If the file doesn't get closed, this test won't fail, but it will
+ # produce a resource leak warning.
+ with self.module.open(TESTFN, 'wb') as f:
+ f.setnchannels(self.nchannels)
+ f.setsampwidth(self.sampwidth)
+ f.setframerate(self.framerate)
+ f.setcomptype(self.comptype, self.compname)
+ with self.module.open(TESTFN) as f:
+ self.assertFalse(f.getfp().closed)
+ params = f.getparams()
+ self.assertEqual(params.nchannels, self.nchannels)
+ self.assertEqual(params.sampwidth, self.sampwidth)
+ self.assertEqual(params.framerate, self.framerate)
+ if not self.close_fd:
+ self.assertIsNone(f.getfp())
+
def test_write(self):
f = self.create_file(TESTFN)
f.setnframes(self.nframes)
@@ -203,12 +253,9 @@ class AudioTestsWithSourceFile(AudioTests):
with open(TESTFN, 'rb') as testfile:
self.assertEqual(testfile.read(13), b'ababagalamaga')
- f = self.module.open(testfile, 'rb')
- try:
+ with self.module.open(testfile, 'rb') as f:
self.assertEqual(f.getnchannels(), self.nchannels)
self.assertEqual(f.getsampwidth(), self.sampwidth)
self.assertEqual(f.getframerate(), self.framerate)
self.assertEqual(f.getnframes(), self.sndfilenframes)
self.assertEqual(f.readframes(self.nframes), self.frames)
- finally:
- f.close()
diff --git a/Lib/test/badsyntax_future10.py b/Lib/test/badsyntax_future10.py
new file mode 100644
index 0000000000..fa5ab67a98
--- /dev/null
+++ b/Lib/test/badsyntax_future10.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+"spam, bar, blah"
+from __future__ import print_function
diff --git a/Lib/test/bytecode_helper.py b/Lib/test/bytecode_helper.py
new file mode 100644
index 0000000000..58b4209f55
--- /dev/null
+++ b/Lib/test/bytecode_helper.py
@@ -0,0 +1,41 @@
+"""bytecode_helper - support tools for testing correct bytecode generation"""
+
+import unittest
+import dis
+import io
+
+_UNSPECIFIED = object()
+
+class BytecodeTestCase(unittest.TestCase):
+ """Custom assertion methods for inspecting bytecode."""
+
+ def get_disassembly_as_string(self, co):
+ s = io.StringIO()
+ dis.dis(co, file=s)
+ return s.getvalue()
+
+ def assertInBytecode(self, x, opname, argval=_UNSPECIFIED):
+ """Returns instr if op is found, otherwise throws AssertionError"""
+ for instr in dis.get_instructions(x):
+ if instr.opname == opname:
+ if argval is _UNSPECIFIED or instr.argval == argval:
+ return instr
+ disassembly = self.get_disassembly_as_string(x)
+ if argval is _UNSPECIFIED:
+ msg = '%s not found in bytecode:\n%s' % (opname, disassembly)
+ else:
+ msg = '(%s,%r) not found in bytecode:\n%s'
+ msg = msg % (opname, argval, disassembly)
+ self.fail(msg)
+
+ def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED):
+ """Throws AssertionError if op is found"""
+ for instr in dis.get_instructions(x):
+ if instr.opname == opname:
+ disassembly = self.get_disassembly_as_string(co)
+ if opargval is _UNSPECIFIED:
+ msg = '%s occurs in bytecode:\n%s' % (opname, disassembly)
+ elif instr.argval == argval:
+ msg = '(%s,%r) occurs in bytecode:\n%s'
+ msg = msg % (opname, argval, disassembly)
+ self.fail(msg)
diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py
index 3226bce7da..49140a5526 100644
--- a/Lib/test/datetimetester.py
+++ b/Lib/test/datetimetester.py
@@ -619,6 +619,10 @@ class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase):
eq(td(hours=-.2/us_per_hour), td(0))
eq(td(days=-.4/us_per_day, hours=-.2/us_per_hour), td(microseconds=-1))
+ # Test for a patch in Issue 8860
+ eq(td(microseconds=0.5), 0.5*td(microseconds=1.0))
+ eq(td(microseconds=0.5)//td.resolution, 0.5*td.resolution//td.resolution)
+
def test_massive_normalization(self):
td = timedelta(microseconds=-1)
self.assertEqual((td.days, td.seconds, td.microseconds),
diff --git a/Lib/test/final_a.py b/Lib/test/final_a.py
new file mode 100644
index 0000000000..390ee8895a
--- /dev/null
+++ b/Lib/test/final_a.py
@@ -0,0 +1,19 @@
+"""
+Fodder for module finalization tests in test_module.
+"""
+
+import shutil
+import test.final_b
+
+x = 'a'
+
+class C:
+ def __del__(self):
+ # Inspect module globals and builtins
+ print("x =", x)
+ print("final_b.x =", test.final_b.x)
+ print("shutil.rmtree =", getattr(shutil.rmtree, '__name__', None))
+ print("len =", getattr(len, '__name__', None))
+
+c = C()
+_underscored = C()
diff --git a/Lib/test/final_b.py b/Lib/test/final_b.py
new file mode 100644
index 0000000000..7228d82b88
--- /dev/null
+++ b/Lib/test/final_b.py
@@ -0,0 +1,19 @@
+"""
+Fodder for module finalization tests in test_module.
+"""
+
+import shutil
+import test.final_a
+
+x = 'b'
+
+class C:
+ def __del__(self):
+ # Inspect module globals and builtins
+ print("x =", x)
+ print("final_a.x =", test.final_a.x)
+ print("shutil.rmtree =", getattr(shutil.rmtree, '__name__', None))
+ print("len =", getattr(len, '__name__', None))
+
+c = C()
+_underscored = C()
diff --git a/Lib/test/fork_wait.py b/Lib/test/fork_wait.py
index 88527df25e..19b54ec736 100644
--- a/Lib/test/fork_wait.py
+++ b/Lib/test/fork_wait.py
@@ -28,7 +28,7 @@ class ForkWait(unittest.TestCase):
self.alive[id] = os.getpid()
try:
time.sleep(SHORTSLEEP)
- except IOError:
+ except OSError:
pass
def wait_impl(self, cpid):
diff --git a/Lib/test/keycert3.pem b/Lib/test/keycert3.pem
new file mode 100644
index 0000000000..5bfa62c4ca
--- /dev/null
+++ b/Lib/test/keycert3.pem
@@ -0,0 +1,73 @@
+-----BEGIN PRIVATE KEY-----
+MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP
+jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM
+9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ
+aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe
+yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j
+y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+
+AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW
+5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL
+9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9
+1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT
+DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh
+1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m
+JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3
+RnJdHOMXWem7/w==
+-----END PRIVATE KEY-----
+Certificate:
+ Data:
+ Version: 1 (0x0)
+ Serial Number: 12723342612721443281 (0xb09264b1f2da21d1)
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Validity
+ Not Before: Jan 4 19:47:07 2013 GMT
+ Not After : Nov 13 19:47:07 2022 GMT
+ Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d:
+ 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb:
+ c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99:
+ 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c:
+ f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93:
+ 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23:
+ f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5:
+ af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6:
+ 21:82:a5:3c:88:e5:be:1b:b1
+ Exponent: 65537 (0x10001)
+ Signature Algorithm: sha1WithRSAEncryption
+ 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a:
+ e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93:
+ f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13:
+ e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92:
+ d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59:
+ 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8:
+ ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1:
+ 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75:
+ 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96:
+ 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48:
+ 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a:
+ f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6:
+ 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41:
+ a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb:
+ fc:a9:94:71
+-----BEGIN CERTIFICATE-----
+MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY
+WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV
+BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3
+WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV
+BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv
+c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C
+tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola
+N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1
+TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR
+iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG
+xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo
+5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv
+mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF
+YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh
+2EJ36/yplHE=
+-----END CERTIFICATE-----
diff --git a/Lib/test/keycert4.pem b/Lib/test/keycert4.pem
new file mode 100644
index 0000000000..53355c8a50
--- /dev/null
+++ b/Lib/test/keycert4.pem
@@ -0,0 +1,73 @@
+-----BEGIN PRIVATE KEY-----
+MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAK5UQiMI5VkNs2Qv
+L7gUaiDdFevNUXRjU4DHAe3ZzzYLZNE69h9gO9VCSS16tJ5fT5VEu0EZyGr0e3V2
+NkX0ZoU0Hc/UaY4qx7LHmn5SYZpIxhJnkf7SyHJK1zUaGlU0/LxYqIuGCtF5dqx1
+L2OQhEx1GM6RydHdgX69G64LXcY5AgMBAAECgYAhsRMfJkb9ERLMl/oG/5sLQu9L
+pWDKt6+ZwdxzlZbggQ85CMYshjLKIod2DLL/sLf2x1PRXyRG131M1E3k8zkkz6de
+R1uDrIN/x91iuYzfLQZGh8bMY7Yjd2eoroa6R/7DjpElGejLxOAaDWO0ST2IFQy9
+myTGS2jSM97wcXfsSQJBANP3jelJoS5X6BRjTSneY21wcocxVuQh8pXpErALVNsT
+drrFTeaBuZp7KvbtnIM5g2WRNvaxLZlAY/hXPJvi6ncCQQDSix1cebml6EmPlEZS
+Mm8gwI2F9ufUunwJmBJcz826Do0ZNGByWDAM/JQZH4FX4GfAFNuj8PUb+GQfadkx
+i1DPAkEA0lVsNHojvuDsIo8HGuzarNZQT2beWjJ1jdxh9t7HrTx7LIps6rb/fhOK
+Zs0R6gVAJaEbcWAPZ2tFyECInAdnsQJAUjaeXXjuxFkjOFym5PvqpvhpivEx78Bu
+JPTr3rAKXmfGMxxfuOa0xK1wSyshP6ZR/RBn/+lcXPKubhHQDOegwwJAJF1DBQnN
++/tLmOPULtDwfP4Zixn+/8GmGOahFoRcu6VIGHmRilJTn6MOButw7Glv2YdeC6l/
+e83Gq6ffLVfKNQ==
+-----END PRIVATE KEY-----
+Certificate:
+ Data:
+ Version: 1 (0x0)
+ Serial Number: 12723342612721443282 (0xb09264b1f2da21d2)
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Validity
+ Not Before: Jan 4 19:47:07 2013 GMT
+ Not After : Nov 13 19:47:07 2022 GMT
+ Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=fakehostname
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (1024 bit)
+ Modulus:
+ 00:ae:54:42:23:08:e5:59:0d:b3:64:2f:2f:b8:14:
+ 6a:20:dd:15:eb:cd:51:74:63:53:80:c7:01:ed:d9:
+ cf:36:0b:64:d1:3a:f6:1f:60:3b:d5:42:49:2d:7a:
+ b4:9e:5f:4f:95:44:bb:41:19:c8:6a:f4:7b:75:76:
+ 36:45:f4:66:85:34:1d:cf:d4:69:8e:2a:c7:b2:c7:
+ 9a:7e:52:61:9a:48:c6:12:67:91:fe:d2:c8:72:4a:
+ d7:35:1a:1a:55:34:fc:bc:58:a8:8b:86:0a:d1:79:
+ 76:ac:75:2f:63:90:84:4c:75:18:ce:91:c9:d1:dd:
+ 81:7e:bd:1b:ae:0b:5d:c6:39
+ Exponent: 65537 (0x10001)
+ Signature Algorithm: sha1WithRSAEncryption
+ ad:45:8a:8e:ef:c6:ef:04:41:5c:2c:4a:84:dc:02:76:0c:d0:
+ 66:0f:f0:16:04:58:4d:fd:68:b7:b8:d3:a8:41:a5:5c:3c:6f:
+ 65:3c:d1:f8:ce:43:35:e7:41:5f:53:3d:c9:2c:c3:7d:fc:56:
+ 4a:fa:47:77:38:9d:bb:97:28:0a:3b:91:19:7f:bc:74:ae:15:
+ 6b:bd:20:36:67:45:a5:1e:79:d7:75:e6:89:5c:6d:54:84:d1:
+ 95:d7:a7:b4:33:3c:af:37:c4:79:8f:5e:75:dc:75:c2:18:fb:
+ 61:6f:2d:dc:38:65:5b:ba:67:28:d0:88:d7:8d:b9:23:5a:8e:
+ e8:c6:bb:db:ce:d5:b8:41:2a:ce:93:08:b6:95:ad:34:20:18:
+ d5:3b:37:52:74:50:0b:07:2c:b0:6d:a4:4c:7b:f4:e0:fd:d1:
+ af:17:aa:20:cd:62:e3:f0:9d:37:69:db:41:bd:d4:1c:fb:53:
+ 20:da:88:9d:76:26:67:ce:01:90:a7:80:1d:a9:5b:39:73:68:
+ 54:0a:d1:2a:03:1b:8f:3c:43:5d:5d:c4:51:f1:a7:e7:11:da:
+ 31:2c:49:06:af:04:f4:b8:3c:99:c4:20:b9:06:36:a2:00:92:
+ 61:1d:0c:6d:24:05:e2:82:e1:47:db:a0:5f:ba:b9:fb:ba:fa:
+ 49:12:1e:ce
+-----BEGIN CERTIFICATE-----
+MIICpzCCAY8CCQCwkmSx8toh0jANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY
+WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV
+BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3
+WjBiMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV
+BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRUwEwYDVQQDEwxmYWtlaG9z
+dG5hbWUwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAK5UQiMI5VkNs2QvL7gU
+aiDdFevNUXRjU4DHAe3ZzzYLZNE69h9gO9VCSS16tJ5fT5VEu0EZyGr0e3V2NkX0
+ZoU0Hc/UaY4qx7LHmn5SYZpIxhJnkf7SyHJK1zUaGlU0/LxYqIuGCtF5dqx1L2OQ
+hEx1GM6RydHdgX69G64LXcY5AgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAK1Fio7v
+xu8EQVwsSoTcAnYM0GYP8BYEWE39aLe406hBpVw8b2U80fjOQzXnQV9TPcksw338
+Vkr6R3c4nbuXKAo7kRl/vHSuFWu9IDZnRaUeedd15olcbVSE0ZXXp7QzPK83xHmP
+XnXcdcIY+2FvLdw4ZVu6ZyjQiNeNuSNajujGu9vO1bhBKs6TCLaVrTQgGNU7N1J0
+UAsHLLBtpEx79OD90a8XqiDNYuPwnTdp20G91Bz7UyDaiJ12JmfOAZCngB2pWzlz
+aFQK0SoDG488Q11dxFHxp+cR2jEsSQavBPS4PJnEILkGNqIAkmEdDG0kBeKC4Ufb
+oF+6ufu6+kkSHs4=
+-----END CERTIFICATE-----
diff --git a/Lib/test/leakers/test_gestalt.py b/Lib/test/leakers/test_gestalt.py
deleted file mode 100644
index e0081c1fc5..0000000000
--- a/Lib/test/leakers/test_gestalt.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import sys
-
-if sys.platform != 'darwin':
- raise ValueError("This test only leaks on Mac OS X")
-
-def leak():
- # taken from platform._mac_ver_lookup()
- from gestalt import gestalt
- import MacOS
-
- try:
- gestalt('sysu')
- except MacOS.Error:
- pass
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py
index bfbf44e0db..1cbcea2343 100644
--- a/Lib/test/lock_tests.py
+++ b/Lib/test/lock_tests.py
@@ -80,6 +80,11 @@ class BaseLockTests(BaseTestCase):
lock = self.locktype()
del lock
+ def test_repr(self):
+ lock = self.locktype()
+ repr(lock)
+ del lock
+
def test_acquire_destroy(self):
lock = self.locktype()
lock.acquire()
@@ -413,6 +418,17 @@ class ConditionTests(BaseTestCase):
self.assertRaises(RuntimeError, cond.notify)
def _check_notify(self, cond):
+ # Note that this test is sensitive to timing. If the worker threads
+ # don't execute in a timely fashion, the main thread may think they
+ # are further along then they are. The main thread therefore issues
+ # _wait() statements to try to make sure that it doesn't race ahead
+ # of the workers.
+ # Secondly, this test assumes that condition variables are not subject
+ # to spurious wakeups. The absence of spurious wakeups is an implementation
+ # detail of Condition Cariables in current CPython, but in general, not
+ # a guaranteed property of condition variables as a programming
+ # construct. In particular, it is possible that this can no longer
+ # be conveniently guaranteed should their implementation ever change.
N = 5
results1 = []
results2 = []
@@ -440,6 +456,9 @@ class ConditionTests(BaseTestCase):
_wait()
self.assertEqual(results1, [(True, 1)] * 3)
self.assertEqual(results2, [])
+ # first wait, to ensure all workers settle into cond.wait() before
+ # we continue. See issue #8799
+ _wait()
# Notify 5 threads: they might be in their first or second wait
cond.acquire()
cond.notify(5)
@@ -450,6 +469,7 @@ class ConditionTests(BaseTestCase):
_wait()
self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
self.assertEqual(results2, [(True, 2)] * 3)
+ _wait() # make sure all workers settle into cond.wait()
# Notify all threads: they are all in their second wait
cond.acquire()
cond.notify_all()
diff --git a/Lib/test/make_ssl_certs.py b/Lib/test/make_ssl_certs.py
index 48d2e57f4b..f630813b2c 100644
--- a/Lib/test/make_ssl_certs.py
+++ b/Lib/test/make_ssl_certs.py
@@ -2,6 +2,7 @@
and friends."""
import os
+import shutil
import sys
import tempfile
from subprocess import *
@@ -20,11 +21,52 @@ req_template = """
[req_x509_extensions]
subjectAltName = DNS:{hostname}
+
+ [ ca ]
+ default_ca = CA_default
+
+ [ CA_default ]
+ dir = cadir
+ database = $dir/index.txt
+ default_md = sha1
+ default_days = 3600
+ certificate = pycacert.pem
+ private_key = pycakey.pem
+ serial = $dir/serial
+ RANDFILE = $dir/.rand
+
+ policy = policy_match
+
+ [ policy_match ]
+ countryName = match
+ stateOrProvinceName = optional
+ organizationName = match
+ organizationalUnitName = optional
+ commonName = supplied
+ emailAddress = optional
+
+ [ policy_anything ]
+ countryName = optional
+ stateOrProvinceName = optional
+ localityName = optional
+ organizationName = optional
+ organizationalUnitName = optional
+ commonName = supplied
+ emailAddress = optional
+
+
+ [ v3_ca ]
+
+ subjectKeyIdentifier=hash
+ authorityKeyIdentifier=keyid:always,issuer
+ basicConstraints = CA:true
+
"""
here = os.path.abspath(os.path.dirname(__file__))
-def make_cert_key(hostname):
+def make_cert_key(hostname, sign=False):
+ print("creating cert for " + hostname)
tempnames = []
for i in range(3):
with tempfile.NamedTemporaryFile(delete=False) as f:
@@ -33,10 +75,25 @@ def make_cert_key(hostname):
try:
with open(req_file, 'w') as f:
f.write(req_template.format(hostname=hostname))
- args = ['req', '-new', '-days', '3650', '-nodes', '-x509',
+ args = ['req', '-new', '-days', '3650', '-nodes',
'-newkey', 'rsa:1024', '-keyout', key_file,
- '-out', cert_file, '-config', req_file]
+ '-config', req_file]
+ if sign:
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ tempnames.append(f.name)
+ reqfile = f.name
+ args += ['-out', reqfile ]
+
+ else:
+ args += ['-x509', '-out', cert_file ]
check_call(['openssl'] + args)
+
+ if sign:
+ args = ['ca', '-config', req_file, '-out', cert_file, '-outdir', 'cadir',
+ '-policy', 'policy_anything', '-batch', '-infiles', reqfile ]
+ check_call(['openssl'] + args)
+
+
with open(cert_file, 'r') as f:
cert = f.read()
with open(key_file, 'r') as f:
@@ -46,6 +103,32 @@ def make_cert_key(hostname):
for name in tempnames:
os.remove(name)
+TMP_CADIR = 'cadir'
+
+def unmake_ca():
+ shutil.rmtree(TMP_CADIR)
+
+def make_ca():
+ os.mkdir(TMP_CADIR)
+ with open(os.path.join('cadir','index.txt'),'a+') as f:
+ pass # empty file
+ with open(os.path.join('cadir','index.txt.attr'),'w+') as f:
+ f.write('unique_subject = no')
+
+ with tempfile.NamedTemporaryFile("w") as t:
+ t.write(req_template.format(hostname='our-ca-server'))
+ t.flush()
+ with tempfile.NamedTemporaryFile() as f:
+ args = ['req', '-new', '-days', '3650', '-extensions', 'v3_ca', '-nodes',
+ '-newkey', 'rsa:2048', '-keyout', 'pycakey.pem',
+ '-out', f.name,
+ '-subj', '/C=XY/L=Castle Anthrax/O=Python Software Foundation CA/CN=our-ca-server']
+ check_call(['openssl'] + args)
+ args = ['ca', '-config', t.name, '-create_serial',
+ '-out', 'pycacert.pem', '-batch', '-outdir', TMP_CADIR,
+ '-keyfile', 'pycakey.pem', '-days', '3650',
+ '-selfsign', '-extensions', 'v3_ca', '-infiles', f.name ]
+ check_call(['openssl'] + args)
if __name__ == '__main__':
os.chdir(here)
@@ -54,11 +137,34 @@ if __name__ == '__main__':
f.write(cert)
with open('ssl_key.pem', 'w') as f:
f.write(key)
+ print("password protecting ssl_key.pem in ssl_key.passwd.pem")
+ check_call(['openssl','rsa','-in','ssl_key.pem','-out','ssl_key.passwd.pem','-des3','-passout','pass:somepass'])
+ check_call(['openssl','rsa','-in','ssl_key.pem','-out','keycert.passwd.pem','-des3','-passout','pass:somepass'])
+
with open('keycert.pem', 'w') as f:
f.write(key)
f.write(cert)
+
+ with open('keycert.passwd.pem', 'a+') as f:
+ f.write(cert)
+
# For certificate matching tests
+ make_ca()
cert, key = make_cert_key('fakehostname')
with open('keycert2.pem', 'w') as f:
f.write(key)
f.write(cert)
+
+ cert, key = make_cert_key('localhost', True)
+ with open('keycert3.pem', 'w') as f:
+ f.write(key)
+ f.write(cert)
+
+ cert, key = make_cert_key('fakehostname', True)
+ with open('keycert4.pem', 'w') as f:
+ f.write(key)
+ f.write(cert)
+
+ unmake_ca()
+ print("\n\nPlease change the values in test_ssl.py, test_parse_cert function related to notAfter,notBefore and serialNumber")
+ check_call(['openssl','x509','-in','keycert.pem','-dates','-serial','-noout'])
diff --git a/Lib/test/mock_socket.py b/Lib/test/mock_socket.py
index d09e78c1d5..8ef0ec8c8d 100644
--- a/Lib/test/mock_socket.py
+++ b/Lib/test/mock_socket.py
@@ -140,12 +140,8 @@ def gethostbyname(name):
return ""
-class gaierror(Exception):
- pass
-
-
-class error(Exception):
- pass
+gaierror = socket_module.gaierror
+error = socket_module.error
# Constants
diff --git a/Lib/test/mp_fork_bomb.py b/Lib/test/mp_fork_bomb.py
index 908afe3045..017e010ba0 100644
--- a/Lib/test/mp_fork_bomb.py
+++ b/Lib/test/mp_fork_bomb.py
@@ -7,6 +7,11 @@ def foo():
# correctly on Windows. However, we should get a RuntimeError rather
# than the Windows equivalent of a fork bomb.
+if len(sys.argv) > 1:
+ multiprocessing.set_start_method(sys.argv[1])
+else:
+ multiprocessing.set_start_method('spawn')
+
p = multiprocessing.Process(target=foo)
p.start()
p.join()
diff --git a/Lib/test/multibytecodec_support.py b/Lib/test/multibytecodec_support.py
index 26bac7be10..dcaae7b95c 100644
--- a/Lib/test/multibytecodec_support.py
+++ b/Lib/test/multibytecodec_support.py
@@ -282,7 +282,7 @@ class TestBase_Mapping(unittest.TestCase):
unittest.TestCase.__init__(self, *args, **kw)
try:
self.open_mapping_file().close() # test it to report the error early
- except (IOError, HTTPException):
+ except (OSError, HTTPException):
self.skipTest("Could not retrieve "+self.mapfileurl)
def open_mapping_file(self):
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 052290d764..197112024b 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -601,30 +601,6 @@ class AbstractPickleTests(unittest.TestCase):
self.assertRaises(KeyError, self.loads, b'g0\np0')
self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)])
- def test_insecure_strings(self):
- # XXX Some of these tests are temporarily disabled
- insecure = [b"abc", b"2 + 2", # not quoted
- ## b"'abc' + 'def'", # not a single quoted string
- b"'abc", # quote is not closed
- b"'abc\"", # open quote and close quote don't match
- b"'abc' ?", # junk after close quote
- b"'\\'", # trailing backslash
- # Variations on issue #17710
- b"'",
- b'"',
- b"' ",
- b"' ",
- b"' ",
- b"' ",
- b'" ',
- # some tests of the quoting rules
- ## b"'abc\"\''",
- ## b"'\\\\a\'\'\'\\\'\\\\\''",
- ]
- for b in insecure:
- buf = b"S" + b + b"\012p0\012."
- self.assertRaises(ValueError, self.loads, buf)
-
def test_unicode(self):
endcases = ['', '<\\u>', '<\\\u1234>', '<\n>',
'<\\>', '<\\\U00012345>',
@@ -1214,6 +1190,35 @@ class AbstractPickleTests(unittest.TestCase):
dumped = b'\x80\x03X\x01\x00\x00\x00ar\xff\xff\xff\xff.'
self.assertRaises(ValueError, self.loads, dumped)
+ def test_badly_escaped_string(self):
+ self.assertRaises(ValueError, self.loads, b"S'\\'\n.")
+
+ def test_badly_quoted_string(self):
+ # Issue #17710
+ badpickles = [b"S'\n.",
+ b'S"\n.',
+ b'S\' \n.',
+ b'S" \n.',
+ b'S\'"\n.',
+ b'S"\'\n.',
+ b"S' ' \n.",
+ b'S" " \n.',
+ b"S ''\n.",
+ b'S ""\n.',
+ b'S \n.',
+ b'S\n.',
+ b'S.']
+ for p in badpickles:
+ self.assertRaises(pickle.UnpicklingError, self.loads, p)
+
+ def test_correctly_quoted_string(self):
+ goodpickles = [(b"S''\n.", ''),
+ (b'S""\n.', ''),
+ (b'S"\\n"\n.', '\n'),
+ (b"S'\\n'\n.", '\n')]
+ for p, expected in goodpickles:
+ self.assertEqual(self.loads(p), expected)
+
def _check_pickling_with_opcode(self, obj, opcode, proto):
pickled = self.dumps(obj, proto)
self.assertTrue(opcode_in_pickle(opcode, pickled))
diff --git a/Lib/test/pycacert.pem b/Lib/test/pycacert.pem
new file mode 100644
index 0000000000..09b1f3e08a
--- /dev/null
+++ b/Lib/test/pycacert.pem
@@ -0,0 +1,78 @@
+Certificate:
+ Data:
+ Version: 3 (0x2)
+ Serial Number: 12723342612721443280 (0xb09264b1f2da21d0)
+ Signature Algorithm: sha1WithRSAEncryption
+ Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Validity
+ Not Before: Jan 4 19:47:07 2013 GMT
+ Not After : Jan 2 19:47:07 2023 GMT
+ Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server
+ Subject Public Key Info:
+ Public Key Algorithm: rsaEncryption
+ Public-Key: (2048 bit)
+ Modulus:
+ 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2:
+ 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4:
+ e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f:
+ e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f:
+ 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf:
+ 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d:
+ a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3:
+ e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4:
+ 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf:
+ 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c:
+ e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6:
+ c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a:
+ cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01:
+ 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87:
+ 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f:
+ 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14:
+ e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4:
+ c5:4d
+ Exponent: 65537 (0x10001)
+ X509v3 extensions:
+ X509v3 Subject Key Identifier:
+ BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B
+ X509v3 Authority Key Identifier:
+ keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B
+
+ X509v3 Basic Constraints:
+ CA:TRUE
+ Signature Algorithm: sha1WithRSAEncryption
+ 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6:
+ 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d:
+ a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95:
+ 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17:
+ 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c:
+ 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4:
+ fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7:
+ 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24:
+ 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33:
+ 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61:
+ ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f:
+ 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64:
+ b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb:
+ 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3:
+ 5e:58:c8:9e
+-----BEGIN CERTIFICATE-----
+MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV
+BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW
+MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx
+OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg
+Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI
+hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV
+q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/
+AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA
+Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni
+0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx
+6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w
+HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2
+2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB
+AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4
+QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1
+Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O
+JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR
+f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf
+9mmvtk57HVjsO6lTo15YyJ4=
+-----END CERTIFICATE-----
diff --git a/Lib/test/pycakey.pem b/Lib/test/pycakey.pem
new file mode 100644
index 0000000000..fc6effefb2
--- /dev/null
+++ b/Lib/test/pycakey.pem
@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDn3unjDJ8AtqH9
+K1uW0m/M4L6GuSBe7AN6VavqpOn5SYXSZtXtx3rqVo4tj+dC4mIoqZ/WG47rtbSc
+nxSr3+aUi3YdPm0kYe0MvwCKYQzfXMg2cxYAzUe6baSkdIiDIwoZ/AmnPEpL0+cd
+LeTqTFQh8ybbiTcY1AK7QDJfpP8tHPfUu+yOz1yCrOZ8CGxIhWEHfyXgXOC8NF/g
+uQRHdchHC4281shoXzODYtIgRDWxrYEais28NbBci0fWGOmcGJfMATwpzOge5OTB
+uN7nwhEYh1qTNNimJfcUcevkIaLSDy4u1GIANdPW71xgS0ypFOLdFVhGNzMmt+cu
+Xe1C5MVNAgMBAAECggEBAJPM7QuUrPn4cLN/Ysd15lwTWn9oHDFFgkYFvCs66gXE
+ju/6Kx2BjWE4wTJby09AHM/MqB0DvguT7Mf1Q2j3tPQ1HZowg8OwRDleuwp6KIls
+jBbhL0Jdl/5HC67ktWvZ9wNvO/wFG1rQfT6FVajf9LUbWEaSZbOG2SLhHfsHorzu
+xjTJaI3bQ/0+79B1exwk5ruwhzFRd/XpY8hls7D/RfPIuHDlBghkW3N59KFWrf5h
+6bNEh2THm0+IyGcGqs0FD+QCOXyvsjwSUswqrr2ctLREOeDcd5ReUjSxYgjcJRrm
+J7ceIY/+uwDJxw/OlnmBvF6pQMkKwYW2gFztu+g2t4UCgYEA/9yo01Exz4crxXsy
+tAlnDJM++nZcm07rtFjTKHUfKY/cCgNTa8udM0svnfwlid/dpgLsI38gx04HHC1i
+EZ4acz+ToIWedLxM0nq73//xeRWEazOvCz1mMTZaMldahTWAyzN8qVK2B/625Yy4
+wNYWyweBBwEB8MzaCs73spksXOsCgYEA5/7wvhiofYGFAfMuANeJIwDL2OtBnoOv
+mVNfCmi3GC38fzwyi5ZpskWDiS2woJ+LQfs9Qu4EcZbUFLd7gbeOvb5gmFUtYope
+LitUUKunIR18MkQ+mQDBpQPQPhk4QJP5reCbWkrfTu7b5o/iS41s6fBTFmuzhLcT
+C71vFdCyeKcCgYAiCCqYeOtELDmBOeLDmaCQRqGQ1N96dOPbCBmF/xYXBCCDYG/f
+HaUaJnz96YTgstsbcrYP/p/Qgqtlbw/lQf9IpwMuzbcG1ejt8g89OyDWNyt2ytgU
+iaUnFJCos3/Byh0Iah/BsdOueo2/OJl2ZMOBW80orlSgv86cs2y037TL4wKBgQDm
+OOyW+MlbowhnIvfoBfwlLEkefnej4nKD6WRLZBcue5Qyf355X06Mhsc9foXlH+6G
+D9h/bswiHNdhp6N82rdgPGiHQx/CxiUoE/+b/nvgNO5mw6qLE2EXbG1e8pAMJcyE
+bHw+YkawggDfELI036fRj5gki8SeUz8nS1nNgElbyQKBgCRDX9Jh+MwSLu4QBWdt
+/fi+lv3K6kun/fI7EOV1vCV/j871tICu7pu5BrOLxAHqoVfU9AUX299/2KjCb5pv
+kjogiUK6qWCWBlfuqDNWGCoUGt1rhznUva0nNjSMy5rinBhhjpROZC2pw48lOluP
+UuvXsaPph7GTqPuy4Kab12YC
+-----END PRIVATE KEY-----
diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py
index ae62c6e7a0..a5d707edae 100755
--- a/Lib/test/regrtest.py
+++ b/Lib/test/regrtest.py
@@ -1,11 +1,18 @@
#! /usr/bin/env python3
"""
-Usage:
+Script to run Python regression tests.
+Run this script with -h or --help for documentation.
+"""
+
+USAGE = """\
python -m test [options] [test_name1 [test_name2 ...]]
python path/to/Lib/test/regrtest.py [options] [test_name1 [test_name2 ...]]
+"""
+DESCRIPTION = """\
+Run Python regression tests.
If no arguments or options are provided, finds all files matching
the pattern "test_*" in the Lib/test subdirectory and runs
@@ -15,63 +22,10 @@ For more rigorous testing, it is useful to use the following
command line:
python -E -Wd -m test [options] [test_name1 ...]
+"""
-
-Options:
-
--h/--help -- print this text and exit
---timeout TIMEOUT
- -- dump the traceback and exit if a test takes more
- than TIMEOUT seconds; disabled if TIMEOUT is negative
- or equals to zero
---wait -- wait for user input, e.g., allow a debugger to be attached
-
-Verbosity
-
--v/--verbose -- run tests in verbose mode with output to stdout
--w/--verbose2 -- re-run failed tests in verbose mode
--W/--verbose3 -- display test output on failure
--d/--debug -- print traceback for failed tests
--q/--quiet -- no output unless one or more tests fail
--o/--slow -- print the slowest 10 tests
- --header -- print header with interpreter info
-
-Selecting tests
-
--r/--randomize -- randomize test execution order (see below)
- --randseed -- pass a random seed to reproduce a previous random run
--f/--fromfile -- read names of tests to run from a file (see below)
--x/--exclude -- arguments are tests to *exclude*
--s/--single -- single step through a set of tests (see below)
--m/--match PAT -- match test cases and methods with glob pattern PAT
--G/--failfast -- fail as soon as a test fails (only with -v or -W)
--u/--use RES1,RES2,...
- -- specify which special resource intensive tests to run
--M/--memlimit LIMIT
- -- run very large memory-consuming tests
- --testdir DIR
- -- execute test files in the specified directory (instead
- of the Python stdlib test suite)
-
-Special runs
-
--l/--findleaks -- if GC is available detect tests that leak memory
--L/--runleaks -- run the leaks(1) command just before exit
--R/--huntrleaks RUNCOUNTS
- -- search for reference leaks (needs debug build, v. slow)
--j/--multiprocess PROCESSES
- -- run PROCESSES processes at once
--T/--coverage -- turn on code coverage tracing using the trace module
--D/--coverdir DIRECTORY
- -- Directory where coverage files are put
--N/--nocoverdir -- Put coverage files alongside modules
--t/--threshold THRESHOLD
- -- call gc.set_threshold(THRESHOLD)
--n/--nowindows -- suppress error message boxes on Windows
--F/--forever -- run the specified tests in a loop, until an error happens
-
-
-Additional Option Details:
+EPILOG = """\
+Additional option details:
-r randomizes test execution order. You can use --randseed=int to provide a
int seed value for the randomizer; this is useful for reproducing troublesome
@@ -168,11 +122,12 @@ option '-uall,-gui'.
# We import importlib *ASAP* in order to test #15386
import importlib
+import argparse
import builtins
import faulthandler
-import getopt
import io
import json
+import locale
import logging
import os
import platform
@@ -194,7 +149,7 @@ try:
except ImportError:
threading = None
try:
- import multiprocessing.process
+ import _multiprocessing, multiprocessing.process
except ImportError:
multiprocessing = None
@@ -246,20 +201,262 @@ from test import support
RESOURCE_NAMES = ('audio', 'curses', 'largefile', 'network',
'decimal', 'cpu', 'subprocess', 'urlfetch', 'gui')
-TEMPDIR = os.path.abspath(tempfile.gettempdir())
-
-def usage(msg):
- print(msg, file=sys.stderr)
- print("Use --help for usage", file=sys.stderr)
- sys.exit(2)
-
+# When tests are run from the Python build directory, it is best practice
+# to keep the test files in a subfolder. This eases the cleanup of leftover
+# files using the "make distclean" command.
+if sysconfig.is_python_build():
+ TEMPDIR = os.path.join(sysconfig.get_config_var('srcdir'), 'build')
+else:
+ TEMPDIR = tempfile.gettempdir()
+TEMPDIR = os.path.abspath(TEMPDIR)
+
+class _ArgParser(argparse.ArgumentParser):
+
+ def error(self, message):
+ super().error(message + "\nPass -h or --help for complete help.")
+
+def _create_parser():
+ # Set prog to prevent the uninformative "__main__.py" from displaying in
+ # error messages when using "python -m test ...".
+ parser = _ArgParser(prog='regrtest.py',
+ usage=USAGE,
+ description=DESCRIPTION,
+ epilog=EPILOG,
+ add_help=False,
+ formatter_class=argparse.RawDescriptionHelpFormatter)
+
+ # Arguments with this clause added to its help are described further in
+ # the epilog's "Additional option details" section.
+ more_details = ' See the section at bottom for more details.'
+
+ group = parser.add_argument_group('General options')
+ # We add help explicitly to control what argument group it renders under.
+ group.add_argument('-h', '--help', action='help',
+ help='show this help message and exit')
+ group.add_argument('--timeout', metavar='TIMEOUT', type=float,
+ help='dump the traceback and exit if a test takes '
+ 'more than TIMEOUT seconds; disabled if TIMEOUT '
+ 'is negative or equals to zero')
+ group.add_argument('--wait', action='store_true',
+ help='wait for user input, e.g., allow a debugger '
+ 'to be attached')
+ group.add_argument('--slaveargs', metavar='ARGS')
+ group.add_argument('-S', '--start', metavar='START',
+ help='the name of the test at which to start.' +
+ more_details)
+
+ group = parser.add_argument_group('Verbosity')
+ group.add_argument('-v', '--verbose', action='count',
+ help='run tests in verbose mode with output to stdout')
+ group.add_argument('-w', '--verbose2', action='store_true',
+ help='re-run failed tests in verbose mode')
+ group.add_argument('-W', '--verbose3', action='store_true',
+ help='display test output on failure')
+ group.add_argument('-q', '--quiet', action='store_true',
+ help='no output unless one or more tests fail')
+ group.add_argument('-o', '--slow', action='store_true', dest='print_slow',
+ help='print the slowest 10 tests')
+ group.add_argument('--header', action='store_true',
+ help='print header with interpreter info')
+
+ group = parser.add_argument_group('Selecting tests')
+ group.add_argument('-r', '--randomize', action='store_true',
+ help='randomize test execution order.' + more_details)
+ group.add_argument('--randseed', metavar='SEED',
+ dest='random_seed', type=int,
+ help='pass a random seed to reproduce a previous '
+ 'random run')
+ group.add_argument('-f', '--fromfile', metavar='FILE',
+ help='read names of tests to run from a file.' +
+ more_details)
+ group.add_argument('-x', '--exclude', action='store_true',
+ help='arguments are tests to *exclude*')
+ group.add_argument('-s', '--single', action='store_true',
+ help='single step through a set of tests.' +
+ more_details)
+ group.add_argument('-m', '--match', metavar='PAT',
+ dest='match_tests',
+ help='match test cases and methods with glob pattern PAT')
+ group.add_argument('-G', '--failfast', action='store_true',
+ help='fail as soon as a test fails (only with -v or -W)')
+ group.add_argument('-u', '--use', metavar='RES1,RES2,...',
+ action='append', type=resources_list,
+ help='specify which special resource intensive tests '
+ 'to run.' + more_details)
+ group.add_argument('-M', '--memlimit', metavar='LIMIT',
+ help='run very large memory-consuming tests.' +
+ more_details)
+ group.add_argument('--testdir', metavar='DIR',
+ type=relative_filename,
+ help='execute test files in the specified directory '
+ '(instead of the Python stdlib test suite)')
+
+ group = parser.add_argument_group('Special runs')
+ group.add_argument('-l', '--findleaks', action='store_true',
+ help='if GC is available detect tests that leak memory')
+ group.add_argument('-L', '--runleaks', action='store_true',
+ help='run the leaks(1) command just before exit.' +
+ more_details)
+ group.add_argument('-R', '--huntrleaks', metavar='RUNCOUNTS',
+ type=huntrleaks,
+ help='search for reference leaks (needs debug build, '
+ 'very slow).' + more_details)
+ group.add_argument('-j', '--multiprocess', metavar='PROCESSES',
+ dest='use_mp', type=int,
+ help='run PROCESSES processes at once')
+ group.add_argument('-T', '--coverage', action='store_true',
+ dest='trace',
+ help='turn on code coverage tracing using the trace '
+ 'module')
+ group.add_argument('-D', '--coverdir', metavar='DIR',
+ type=relative_filename,
+ help='directory where coverage files are put')
+ group.add_argument('-N', '--nocoverdir',
+ action='store_const', const=None, dest='coverdir',
+ help='put coverage files alongside modules')
+ group.add_argument('-t', '--threshold', metavar='THRESHOLD',
+ type=int,
+ help='call gc.set_threshold(THRESHOLD)')
+ group.add_argument('-n', '--nowindows', action='store_true',
+ help='suppress error message boxes on Windows')
+ group.add_argument('-F', '--forever', action='store_true',
+ help='run the specified tests in a loop, until an '
+ 'error happens')
+
+ parser.add_argument('args', nargs=argparse.REMAINDER,
+ help=argparse.SUPPRESS)
+
+ return parser
+
+def relative_filename(string):
+ # CWD is replaced with a temporary dir before calling main(), so we
+ # join it with the saved CWD so it ends up where the user expects.
+ return os.path.join(support.SAVEDCWD, string)
+
+def huntrleaks(string):
+ args = string.split(':')
+ if len(args) not in (2, 3):
+ raise argparse.ArgumentTypeError(
+ 'needs 2 or 3 colon-separated arguments')
+ nwarmup = int(args[0]) if args[0] else 5
+ ntracked = int(args[1]) if args[1] else 4
+ fname = args[2] if len(args) > 2 and args[2] else 'reflog.txt'
+ return nwarmup, ntracked, fname
+
+def resources_list(string):
+ u = [x.lower() for x in string.split(',')]
+ for r in u:
+ if r == 'all' or r == 'none':
+ continue
+ if r[0] == '-':
+ r = r[1:]
+ if r not in RESOURCE_NAMES:
+ raise argparse.ArgumentTypeError('invalid resource: ' + r)
+ return u
-def main(tests=None, testdir=None, verbose=0, quiet=False,
+def _parse_args(args, **kwargs):
+ # Defaults
+ ns = argparse.Namespace(testdir=None, verbose=0, quiet=False,
exclude=False, single=False, randomize=False, fromfile=None,
findleaks=False, use_resources=None, trace=False, coverdir='coverage',
runleaks=False, huntrleaks=False, verbose2=False, print_slow=False,
random_seed=None, use_mp=None, verbose3=False, forever=False,
- header=False, failfast=False, match_tests=None):
+ header=False, failfast=False, match_tests=None)
+ for k, v in kwargs.items():
+ if not hasattr(ns, k):
+ raise TypeError('%r is an invalid keyword argument '
+ 'for this function' % k)
+ setattr(ns, k, v)
+ if ns.use_resources is None:
+ ns.use_resources = []
+
+ parser = _create_parser()
+ parser.parse_args(args=args, namespace=ns)
+
+ if ns.single and ns.fromfile:
+ parser.error("-s and -f don't go together!")
+ if ns.use_mp and ns.trace:
+ parser.error("-T and -j don't go together!")
+ if ns.use_mp and ns.findleaks:
+ parser.error("-l and -j don't go together!")
+ if ns.use_mp and ns.memlimit:
+ parser.error("-M and -j don't go together!")
+ if ns.failfast and not (ns.verbose or ns.verbose3):
+ parser.error("-G/--failfast needs either -v or -W")
+
+ if ns.quiet:
+ ns.verbose = 0
+ if ns.timeout is not None:
+ if hasattr(faulthandler, 'dump_traceback_later'):
+ if ns.timeout <= 0:
+ ns.timeout = None
+ else:
+ print("Warning: The timeout option requires "
+ "faulthandler.dump_traceback_later")
+ ns.timeout = None
+ if ns.use_mp is not None:
+ if ns.use_mp <= 0:
+ # Use all cores + extras for tests that like to sleep
+ ns.use_mp = 2 + (os.cpu_count() or 1)
+ if ns.use_mp == 1:
+ ns.use_mp = None
+ if ns.use:
+ for a in ns.use:
+ for r in a:
+ if r == 'all':
+ ns.use_resources[:] = RESOURCE_NAMES
+ continue
+ if r == 'none':
+ del ns.use_resources[:]
+ continue
+ remove = False
+ if r[0] == '-':
+ remove = True
+ r = r[1:]
+ if remove:
+ if r in ns.use_resources:
+ ns.use_resources.remove(r)
+ elif r not in ns.use_resources:
+ ns.use_resources.append(r)
+ if ns.random_seed is not None:
+ ns.randomize = True
+
+ return ns
+
+
+def run_test_in_subprocess(testname, ns):
+ """Run the given test in a subprocess with --slaveargs.
+
+ ns is the option Namespace parsed from command-line arguments. regrtest
+ is invoked in a subprocess with the --slaveargs argument; when the
+ subprocess exits, its return code, stdout and stderr are returned as a
+ 3-tuple.
+ """
+ from subprocess import Popen, PIPE
+ base_cmd = ([sys.executable] + support.args_from_interpreter_flags() +
+ ['-X', 'faulthandler', '-m', 'test.regrtest'])
+
+ slaveargs = (
+ (testname, ns.verbose, ns.quiet),
+ dict(huntrleaks=ns.huntrleaks,
+ use_resources=ns.use_resources,
+ output_on_failure=ns.verbose3,
+ timeout=ns.timeout, failfast=ns.failfast,
+ match_tests=ns.match_tests))
+ # Running the child from the same working directory as regrtest's original
+ # invocation ensures that TEMPDIR for the child is the same when
+ # sysconfig.is_python_build() is true. See issue 15300.
+ popen = Popen(base_cmd + ['--slaveargs', json.dumps(slaveargs)],
+ stdout=PIPE, stderr=PIPE,
+ universal_newlines=True,
+ close_fds=(os.name != 'nt'),
+ cwd=support.SAVEDCWD)
+ stdout, stderr = popen.communicate()
+ retcode = popen.wait()
+ return retcode, stdout, stderr
+
+
+def main(tests=None, **kwargs):
"""Execute a test suite.
This also parses command-line options and modifies its behavior
@@ -282,7 +479,6 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
directly to set the values that would normally be set by flags
on the command line.
"""
-
# Display the Python traceback on fatal errors (e.g. segfault)
faulthandler.enable(all_threads=True)
@@ -298,187 +494,51 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
replace_stdout()
support.record_original_stdout(sys.stdout)
- try:
- opts, args = getopt.getopt(sys.argv[1:], 'hvqxsoS:rf:lu:t:TD:NLR:FdwWM:nj:Gm:',
- ['help', 'verbose', 'verbose2', 'verbose3', 'quiet',
- 'exclude', 'single', 'slow', 'randomize', 'fromfile=', 'findleaks',
- 'use=', 'threshold=', 'coverdir=', 'nocoverdir',
- 'runleaks', 'huntrleaks=', 'memlimit=', 'randseed=',
- 'multiprocess=', 'coverage', 'slaveargs=', 'forever', 'debug',
- 'start=', 'nowindows', 'header', 'testdir=', 'timeout=', 'wait',
- 'failfast', 'match='])
- except getopt.error as msg:
- usage(msg)
- # Defaults
- if random_seed is None:
- random_seed = random.randrange(10000000)
- if use_resources is None:
- use_resources = []
- debug = False
- start = None
- timeout = None
- for o, a in opts:
- if o in ('-h', '--help'):
- print(__doc__)
- return
- elif o in ('-v', '--verbose'):
- verbose += 1
- elif o in ('-w', '--verbose2'):
- verbose2 = True
- elif o in ('-d', '--debug'):
- debug = True
- elif o in ('-W', '--verbose3'):
- verbose3 = True
- elif o in ('-G', '--failfast'):
- failfast = True
- elif o in ('-q', '--quiet'):
- quiet = True;
- verbose = 0
- elif o in ('-x', '--exclude'):
- exclude = True
- elif o in ('-S', '--start'):
- start = a
- elif o in ('-s', '--single'):
- single = True
- elif o in ('-o', '--slow'):
- print_slow = True
- elif o in ('-r', '--randomize'):
- randomize = True
- elif o == '--randseed':
- randomize = True
- random_seed = int(a)
- elif o in ('-f', '--fromfile'):
- fromfile = a
- elif o in ('-m', '--match'):
- match_tests = a
- elif o in ('-l', '--findleaks'):
- findleaks = True
- elif o in ('-L', '--runleaks'):
- runleaks = True
- elif o in ('-t', '--threshold'):
- import gc
- gc.set_threshold(int(a))
- elif o in ('-T', '--coverage'):
- trace = True
- elif o in ('-D', '--coverdir'):
- # CWD is replaced with a temporary dir before calling main(), so we
- # need join it with the saved CWD so it goes where the user expects.
- coverdir = os.path.join(support.SAVEDCWD, a)
- elif o in ('-N', '--nocoverdir'):
- coverdir = None
- elif o in ('-R', '--huntrleaks'):
- huntrleaks = a.split(':')
- if len(huntrleaks) not in (2, 3):
- print(a, huntrleaks)
- usage('-R takes 2 or 3 colon-separated arguments')
- if not huntrleaks[0]:
- huntrleaks[0] = 5
- else:
- huntrleaks[0] = int(huntrleaks[0])
- if not huntrleaks[1]:
- huntrleaks[1] = 4
- else:
- huntrleaks[1] = int(huntrleaks[1])
- if len(huntrleaks) == 2 or not huntrleaks[2]:
- huntrleaks[2:] = ["reflog.txt"]
- # Avoid false positives due to various caches
- # filling slowly with random data:
- warm_caches()
- elif o in ('-M', '--memlimit'):
- support.set_memlimit(a)
- elif o in ('-u', '--use'):
- u = [x.lower() for x in a.split(',')]
- for r in u:
- if r == 'all':
- use_resources[:] = RESOURCE_NAMES
- continue
- if r == 'none':
- del use_resources[:]
- continue
- remove = False
- if r[0] == '-':
- remove = True
- r = r[1:]
- if r not in RESOURCE_NAMES:
- usage('Invalid -u/--use option: ' + a)
- if remove:
- if r in use_resources:
- use_resources.remove(r)
- elif r not in use_resources:
- use_resources.append(r)
- elif o in ('-n', '--nowindows'):
- import msvcrt
- msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS|
- msvcrt.SEM_NOALIGNMENTFAULTEXCEPT|
- msvcrt.SEM_NOGPFAULTERRORBOX|
- msvcrt.SEM_NOOPENFILEERRORBOX)
- try:
- msvcrt.CrtSetReportMode
- except AttributeError:
- # release build
- pass
- else:
- for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]:
- msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE)
- msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR)
- elif o in ('-F', '--forever'):
- forever = True
- elif o in ('-j', '--multiprocess'):
- use_mp = int(a)
- if use_mp <= 0:
- try:
- import multiprocessing
- # Use all cores + extras for tests that like to sleep
- use_mp = 2 + multiprocessing.cpu_count()
- except (ImportError, NotImplementedError):
- use_mp = 3
- if use_mp == 1:
- use_mp = None
- elif o == '--header':
- header = True
- elif o == '--slaveargs':
- args, kwargs = json.loads(a)
- try:
- result = runtest(*args, **kwargs)
- except KeyboardInterrupt:
- result = INTERRUPTED, ''
- except BaseException as e:
- traceback.print_exc()
- result = CHILD_ERROR, str(e)
- sys.stdout.flush()
- print() # Force a newline (just in case)
- print(json.dumps(result))
- sys.exit(0)
- elif o == '--testdir':
- # CWD is replaced with a temporary dir before calling main(), so we
- # join it with the saved CWD so it ends up where the user expects.
- testdir = os.path.join(support.SAVEDCWD, a)
- elif o == '--timeout':
- if hasattr(faulthandler, 'dump_traceback_later'):
- timeout = float(a)
- if timeout <= 0:
- timeout = None
- else:
- print("Warning: The timeout option requires "
- "faulthandler.dump_traceback_later")
- timeout = None
- elif o == '--wait':
- input("Press any key to continue...")
+ ns = _parse_args(sys.argv[1:], **kwargs)
+
+ if ns.huntrleaks:
+ # Avoid false positives due to various caches
+ # filling slowly with random data:
+ warm_caches()
+ if ns.memlimit is not None:
+ support.set_memlimit(ns.memlimit)
+ if ns.threshold is not None:
+ import gc
+ gc.set_threshold(ns.threshold)
+ if ns.nowindows:
+ import msvcrt
+ msvcrt.SetErrorMode(msvcrt.SEM_FAILCRITICALERRORS|
+ msvcrt.SEM_NOALIGNMENTFAULTEXCEPT|
+ msvcrt.SEM_NOGPFAULTERRORBOX|
+ msvcrt.SEM_NOOPENFILEERRORBOX)
+ try:
+ msvcrt.CrtSetReportMode
+ except AttributeError:
+ # release build
+ pass
else:
- print(("No handler for option {}. Please report this as a bug "
- "at http://bugs.python.org.").format(o), file=sys.stderr)
- sys.exit(1)
- if single and fromfile:
- usage("-s and -f don't go together!")
- if use_mp and trace:
- usage("-T and -j don't go together!")
- if use_mp and findleaks:
- usage("-l and -j don't go together!")
- if use_mp and support.max_memuse:
- usage("-M and -j don't go together!")
- if failfast and not (verbose or verbose3):
- usage("-G/--failfast needs either -v or -W")
+ for m in [msvcrt.CRT_WARN, msvcrt.CRT_ERROR, msvcrt.CRT_ASSERT]:
+ msvcrt.CrtSetReportMode(m, msvcrt.CRTDBG_MODE_FILE)
+ msvcrt.CrtSetReportFile(m, msvcrt.CRTDBG_FILE_STDERR)
+ if ns.wait:
+ input("Press any key to continue...")
+
+ if ns.slaveargs is not None:
+ args, kwargs = json.loads(ns.slaveargs)
+ if kwargs.get('huntrleaks'):
+ unittest.BaseTestSuite._cleanup = False
+ try:
+ result = runtest(*args, **kwargs)
+ except KeyboardInterrupt:
+ result = INTERRUPTED, ''
+ except BaseException as e:
+ traceback.print_exc()
+ result = CHILD_ERROR, str(e)
+ sys.stdout.flush()
+ print() # Force a newline (just in case)
+ print(json.dumps(result))
+ sys.exit(0)
good = []
bad = []
@@ -487,12 +547,12 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
environment_changed = []
interrupted = False
- if findleaks:
+ if ns.findleaks:
try:
import gc
except ImportError:
print('No GC available, disabling findleaks.')
- findleaks = False
+ ns.findleaks = False
else:
# Uncomment the line below to report garbage that is not
# freeable by reference counting alone. By default only
@@ -500,42 +560,43 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
#gc.set_debug(gc.DEBUG_SAVEALL)
found_garbage = []
- if single:
+ if ns.huntrleaks:
+ unittest.BaseTestSuite._cleanup = False
+
+ if ns.single:
filename = os.path.join(TEMPDIR, 'pynexttest')
try:
- fp = open(filename, 'r')
- next_test = fp.read().strip()
- tests = [next_test]
- fp.close()
- except IOError:
+ with open(filename, 'r') as fp:
+ next_test = fp.read().strip()
+ tests = [next_test]
+ except OSError:
pass
- if fromfile:
+ if ns.fromfile:
tests = []
- fp = open(os.path.join(support.SAVEDCWD, fromfile))
- count_pat = re.compile(r'\[\s*\d+/\s*\d+\]')
- for line in fp:
- line = count_pat.sub('', line)
- guts = line.split() # assuming no test has whitespace in its name
- if guts and not guts[0].startswith('#'):
- tests.extend(guts)
- fp.close()
+ with open(os.path.join(support.SAVEDCWD, ns.fromfile)) as fp:
+ count_pat = re.compile(r'\[\s*\d+/\s*\d+\]')
+ for line in fp:
+ line = count_pat.sub('', line)
+ guts = line.split() # assuming no test has whitespace in its name
+ if guts and not guts[0].startswith('#'):
+ tests.extend(guts)
# Strip .py extensions.
- removepy(args)
+ removepy(ns.args)
removepy(tests)
stdtests = STDTESTS[:]
nottests = NOTTESTS.copy()
- if exclude:
- for arg in args:
+ if ns.exclude:
+ for arg in ns.args:
if arg in stdtests:
stdtests.remove(arg)
nottests.add(arg)
- args = []
+ ns.args = []
# For a partial run, we do not need to clutter the output.
- if verbose or header or not (quiet or single or tests or args):
+ if ns.verbose or ns.header or not (ns.quiet or ns.single or tests or ns.args):
# Print basic platform information
print("==", platform.python_implementation(), *sys.version.split())
print("== ", platform.platform(aliased=True),
@@ -545,37 +606,39 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
# if testdir is set, then we are not running the python tests suite, so
# don't add default tests to be executed or skipped (pass empty values)
- if testdir:
- alltests = findtests(testdir, list(), set())
+ if ns.testdir:
+ alltests = findtests(ns.testdir, list(), set())
else:
- alltests = findtests(testdir, stdtests, nottests)
+ alltests = findtests(ns.testdir, stdtests, nottests)
- selected = tests or args or alltests
- if single:
+ selected = tests or ns.args or alltests
+ if ns.single:
selected = selected[:1]
try:
next_single_test = alltests[alltests.index(selected[0])+1]
except IndexError:
next_single_test = None
# Remove all the selected tests that precede start if it's set.
- if start:
+ if ns.start:
try:
- del selected[:selected.index(start)]
+ del selected[:selected.index(ns.start)]
except ValueError:
- print("Couldn't find starting test (%s), using all tests" % start)
- if randomize:
- random.seed(random_seed)
- print("Using random seed", random_seed)
+ print("Couldn't find starting test (%s), using all tests" % ns.start)
+ if ns.randomize:
+ if ns.random_seed is None:
+ ns.random_seed = random.randrange(10000000)
+ random.seed(ns.random_seed)
+ print("Using random seed", ns.random_seed)
random.shuffle(selected)
- if trace:
+ if ns.trace:
import trace, tempfile
tracer = trace.Trace(ignoredirs=[sys.base_prefix, sys.base_exec_prefix,
tempfile.gettempdir()],
trace=False, count=True)
test_times = []
- support.verbose = verbose # Tell tests to be moderately quiet
- support.use_resources = use_resources
+ support.verbose = ns.verbose # Tell tests to be moderately quiet
+ support.use_resources = ns.use_resources
save_modules = sys.modules.keys()
def accumulate_result(test, result):
@@ -593,7 +656,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
skipped.append(test)
resource_denieds.append(test)
- if forever:
+ if ns.forever:
def test_forever(tests=list(selected)):
while True:
for test in tests:
@@ -608,19 +671,16 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
test_count = '/{}'.format(len(selected))
test_count_width = len(test_count) - 1
- if use_mp:
+ if ns.use_mp:
try:
from threading import Thread
except ImportError:
print("Multiprocess option requires thread support")
sys.exit(2)
from queue import Queue
- from subprocess import Popen, PIPE
- debug_output_pat = re.compile(r"\[\d+ refs\]$")
+ debug_output_pat = re.compile(r"\[\d+ refs, \d+ blocks\]$")
output = Queue()
pending = MultiprocessTests(tests)
- opt_args = support.args_from_interpreter_flags()
- base_cmd = [sys.executable] + opt_args + ['-m', 'test.regrtest']
def work():
# A worker thread.
try:
@@ -630,24 +690,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
except StopIteration:
output.put((None, None, None, None))
return
- args_tuple = (
- (test, verbose, quiet),
- dict(huntrleaks=huntrleaks, use_resources=use_resources,
- debug=debug, output_on_failure=verbose3,
- timeout=timeout, failfast=failfast,
- match_tests=match_tests)
- )
- # -E is needed by some tests, e.g. test_import
- # Running the child from the same working directory ensures
- # that TEMPDIR for the child is the same when
- # sysconfig.is_python_build() is true. See issue 15300.
- popen = Popen(base_cmd + ['--slaveargs', json.dumps(args_tuple)],
- stdout=PIPE, stderr=PIPE,
- universal_newlines=True,
- close_fds=(os.name != 'nt'),
- cwd=support.SAVEDCWD)
- stdout, stderr = popen.communicate()
- retcode = popen.wait()
+ retcode, stdout, stderr = run_test_in_subprocess(test, ns)
# Strip last refcount output line if it exists, since it
# comes from the shutdown of the interpreter in the subcommand.
stderr = debug_output_pat.sub("", stderr)
@@ -664,19 +707,19 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
except BaseException:
output.put((None, None, None, None))
raise
- workers = [Thread(target=work) for i in range(use_mp)]
+ workers = [Thread(target=work) for i in range(ns.use_mp)]
for worker in workers:
worker.start()
finished = 0
test_index = 1
try:
- while finished < use_mp:
+ while finished < ns.use_mp:
test, stdout, stderr, result = output.get()
if test is None:
finished += 1
continue
accumulate_result(test, result)
- if not quiet:
+ if not ns.quiet:
fmt = "[{1:{0}}{2}/{3}] {4}" if bad else "[{1:{0}}{2}] {4}"
print(fmt.format(
test_count_width, test_index, test_count,
@@ -699,29 +742,30 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
worker.join()
else:
for test_index, test in enumerate(tests, 1):
- if not quiet:
+ if not ns.quiet:
fmt = "[{1:{0}}{2}/{3}] {4}" if bad else "[{1:{0}}{2}] {4}"
print(fmt.format(
test_count_width, test_index, test_count, len(bad), test))
sys.stdout.flush()
- if trace:
+ if ns.trace:
# If we're tracing code coverage, then we don't exit with status
# if on a false return value from main.
- tracer.runctx('runtest(test, verbose, quiet, timeout=timeout)',
+ tracer.runctx('runtest(test, ns.verbose, ns.quiet, timeout=ns.timeout)',
globals=globals(), locals=vars())
else:
try:
- result = runtest(test, verbose, quiet, huntrleaks, debug,
- output_on_failure=verbose3,
- timeout=timeout, failfast=failfast,
- match_tests=match_tests)
+ result = runtest(test, ns.verbose, ns.quiet,
+ ns.huntrleaks,
+ output_on_failure=ns.verbose3,
+ timeout=ns.timeout, failfast=ns.failfast,
+ match_tests=ns.match_tests)
accumulate_result(test, result)
except KeyboardInterrupt:
interrupted = True
break
except:
raise
- if findleaks:
+ if ns.findleaks:
gc.collect()
if gc.garbage:
print("Warning: test created", len(gc.garbage), end=' ')
@@ -742,11 +786,11 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
omitted = set(selected) - set(good) - set(bad) - set(skipped)
print(count(len(omitted), "test"), "omitted:")
printlist(omitted)
- if good and not quiet:
+ if good and not ns.quiet:
if not bad and not skipped and not interrupted and len(good) > 1:
print("All", end=' ')
print(count(len(good), "test"), "OK.")
- if print_slow:
+ if ns.print_slow:
test_times.sort(reverse=True)
print("10 slowest tests:")
for time, test in test_times[:10]:
@@ -760,32 +804,19 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
print("{} altered the execution environment:".format(
count(len(environment_changed), "test")))
printlist(environment_changed)
- if skipped and not quiet:
+ if skipped and not ns.quiet:
print(count(len(skipped), "test"), "skipped:")
printlist(skipped)
- e = _ExpectedSkips()
- plat = sys.platform
- if e.isvalid():
- surprise = set(skipped) - e.getexpected() - set(resource_denieds)
- if surprise:
- print(count(len(surprise), "skip"), \
- "unexpected on", plat + ":")
- printlist(surprise)
- else:
- print("Those skips are all expected on", plat + ".")
- else:
- print("Ask someone to teach regrtest.py about which tests are")
- print("expected to get skipped on", plat + ".")
-
- if verbose2 and bad:
+ if ns.verbose2 and bad:
print("Re-running failed tests in verbose mode")
for test in bad:
print("Re-running test %r in verbose mode" % test)
sys.stdout.flush()
try:
- verbose = True
- ok = runtest(test, True, quiet, huntrleaks, debug, timeout=timeout)
+ ns.verbose = True
+ ok = runtest(test, True, ns.quiet, ns.huntrleaks,
+ timeout=ns.timeout)
except KeyboardInterrupt:
# print a newline separate from the ^C
print()
@@ -793,18 +824,18 @@ def main(tests=None, testdir=None, verbose=0, quiet=False,
except:
raise
- if single:
+ if ns.single:
if next_single_test:
with open(filename, 'w') as fp:
fp.write(next_single_test + '\n')
else:
os.unlink(filename)
- if trace:
+ if ns.trace:
r = tracer.results()
- r.write_results(show_missing=True, summary=True, coverdir=coverdir)
+ r.write_results(show_missing=True, summary=True, coverdir=ns.coverdir)
- if runleaks:
+ if ns.runleaks:
os.system("leaks %d" % os.getpid())
sys.exit(len(bad) > 0 or interrupted)
@@ -877,7 +908,7 @@ def replace_stdout():
atexit.register(restore_stdout)
def runtest(test, verbose, quiet,
- huntrleaks=False, debug=False, use_resources=None,
+ huntrleaks=False, use_resources=None,
output_on_failure=False, failfast=False, match_tests=None,
timeout=None):
"""Run a single test.
@@ -885,14 +916,15 @@ def runtest(test, verbose, quiet,
test -- the name of the test
verbose -- if true, print more messages
quiet -- if true, don't print 'skipped' messages (probably redundant)
- test_times -- a list of (time, test_name) pairs
huntrleaks -- run multiple times to test for leaks; requires a debug
build; a triple corresponding to -R's three arguments
+ use_resources -- list of extra resources to use
output_on_failure -- if true, display test output on failure
timeout -- dump the traceback and exit if a test takes more than
timeout seconds
+ failfast, match_tests -- See regrtest command-line flags for these.
- Returns one of the test result constants:
+ Returns the tuple result, test_time, where result is one of the constants:
INTERRUPTED KeyboardInterrupt when run under -j
RESOURCE_DENIED test skipped because resource denied
SKIPPED test skipped for some other reason
@@ -930,7 +962,7 @@ def runtest(test, verbose, quiet,
sys.stdout = stream
sys.stderr = stream
result = runtest_inner(test, verbose, quiet, huntrleaks,
- debug, display_failure=False)
+ display_failure=False)
if result[0] == FAILED:
output = stream.getvalue()
orig_stderr.write(output)
@@ -940,7 +972,7 @@ def runtest(test, verbose, quiet,
sys.stderr = orig_stderr
else:
support.verbose = verbose # Tell tests to be moderately quiet
- result = runtest_inner(test, verbose, quiet, huntrleaks, debug,
+ result = runtest_inner(test, verbose, quiet, huntrleaks,
display_failure=not verbose)
return result
finally:
@@ -992,10 +1024,12 @@ class saved_test_environment:
'os.environ', 'sys.path', 'sys.path_hooks', '__import__',
'warnings.filters', 'asyncore.socket_map',
'logging._handlers', 'logging._handlerList', 'sys.gettrace',
- 'sys.warnoptions', 'threading._dangling',
- 'multiprocessing.process._dangling',
+ 'sys.warnoptions',
+ # multiprocessing.process._cleanup() may release ref
+ # to a thread, so check processes first.
+ 'multiprocessing.process._dangling', 'threading._dangling',
'sysconfig._CONFIG_VARS', 'sysconfig._INSTALL_SCHEMES',
- 'support.TESTFN',
+ 'support.TESTFN', 'locale', 'warnings.showwarning',
)
def get_sys_argv(self):
@@ -1123,6 +1157,8 @@ class saved_test_environment:
def get_multiprocessing_process__dangling(self):
if not multiprocessing:
return None
+ # Unjoined process objects can survive after process exits
+ multiprocessing.process._cleanup()
# This copies the weakrefs without making any strong reference
return multiprocessing.process._dangling.copy()
def restore_multiprocessing_process__dangling(self, saved):
@@ -1164,6 +1200,25 @@ class saved_test_environment:
elif os.path.isdir(support.TESTFN):
shutil.rmtree(support.TESTFN)
+ _lc = [getattr(locale, lc) for lc in dir(locale)
+ if lc.startswith('LC_')]
+ def get_locale(self):
+ pairings = []
+ for lc in self._lc:
+ try:
+ pairings.append((lc, locale.setlocale(lc, None)))
+ except (TypeError, ValueError):
+ continue
+ return pairings
+ def restore_locale(self, saved):
+ for lc, setting in saved:
+ locale.setlocale(lc, setting)
+
+ def get_warnings_showwarning(self):
+ return warnings.showwarning
+ def restore_warnings_showwarning(self, fxn):
+ warnings.showwarning = fxn
+
def resource_info(self):
for name in self.resources:
method_suffix = name.replace('.', '_')
@@ -1198,7 +1253,7 @@ class saved_test_environment:
def runtest_inner(test, verbose, quiet,
- huntrleaks=False, debug=False, display_failure=True):
+ huntrleaks=False, display_failure=True):
support.unload(test)
test_time = 0.0
@@ -1211,8 +1266,7 @@ def runtest_inner(test, verbose, quiet,
abstest = 'test.' + test
with saved_test_environment(test, verbose, quiet) as environment:
start_time = time.time()
- the_package = __import__(abstest, globals(), locals(), [])
- the_module = getattr(the_package, test)
+ the_module = importlib.import_module(abstest)
# If the test has a test_main, that will run the appropriate
# tests. If not, use normal unittest test loading.
test_runner = getattr(the_module, "test_main", None)
@@ -1221,8 +1275,7 @@ def runtest_inner(test, verbose, quiet,
test_runner = lambda: support.run_unittest(tests)
test_runner()
if huntrleaks:
- refleak = dash_R(the_module, test, test_runner,
- huntrleaks)
+ refleak = dash_R(the_module, test, test_runner, huntrleaks)
test_time = time.time() - start_time
except support.ResourceDenied as msg:
if not quiet:
@@ -1328,41 +1381,50 @@ def dash_R(the_module, test, indirect_test, huntrleaks):
for obj in abc.__subclasses__() + [abc]:
abcs[obj] = obj._abc_registry.copy()
- if indirect_test:
- def run_the_test():
- indirect_test()
- else:
- def run_the_test():
- del sys.modules[the_module.__name__]
- exec('import ' + the_module.__name__)
-
- deltas = []
nwarmup, ntracked, fname = huntrleaks
fname = os.path.join(support.SAVEDCWD, fname)
repcount = nwarmup + ntracked
+ rc_deltas = [0] * repcount
+ alloc_deltas = [0] * repcount
+
print("beginning", repcount, "repetitions", file=sys.stderr)
print(("1234567890"*(repcount//10 + 1))[:repcount], file=sys.stderr)
sys.stderr.flush()
- dash_R_cleanup(fs, ps, pic, zdc, abcs)
for i in range(repcount):
- rc_before = sys.gettotalrefcount()
- run_the_test()
+ indirect_test()
+ alloc_after, rc_after = dash_R_cleanup(fs, ps, pic, zdc, abcs)
sys.stderr.write('.')
sys.stderr.flush()
- dash_R_cleanup(fs, ps, pic, zdc, abcs)
- rc_after = sys.gettotalrefcount()
if i >= nwarmup:
- deltas.append(rc_after - rc_before)
+ rc_deltas[i] = rc_after - rc_before
+ alloc_deltas[i] = alloc_after - alloc_before
+ alloc_before, rc_before = alloc_after, rc_after
print(file=sys.stderr)
- if any(deltas):
- msg = '%s leaked %s references, sum=%s' % (test, deltas, sum(deltas))
- print(msg, file=sys.stderr)
- sys.stderr.flush()
- with open(fname, "a") as refrep:
- print(msg, file=refrep)
- refrep.flush()
- return True
- return False
+ # These checkers return False on success, True on failure
+ def check_rc_deltas(deltas):
+ return any(deltas)
+ def check_alloc_deltas(deltas):
+ # At least 1/3rd of 0s
+ if 3 * deltas.count(0) < len(deltas):
+ return True
+ # Nothing else than 1s, 0s and -1s
+ if not set(deltas) <= {1,0,-1}:
+ return True
+ return False
+ failed = False
+ for deltas, item_name, checker in [
+ (rc_deltas, 'references', check_rc_deltas),
+ (alloc_deltas, 'memory blocks', check_alloc_deltas)]:
+ if checker(deltas):
+ msg = '%s leaked %s %s, sum=%s' % (
+ test, deltas[nwarmup:], item_name, sum(deltas))
+ print(msg, file=sys.stderr)
+ sys.stderr.flush()
+ with open(fname, "a") as refrep:
+ print(msg, file=refrep)
+ refrep.flush()
+ failed = True
+ return failed
def dash_R_cleanup(fs, ps, pic, zdc, abcs):
import gc, copyreg
@@ -1428,8 +1490,11 @@ def dash_R_cleanup(fs, ps, pic, zdc, abcs):
else:
ctypes._reset_cache()
- # Collect cyclic trash.
+ # Collect cyclic trash and read memory statistics immediately after.
+ func1 = sys.getallocatedblocks
+ func2 = sys.gettotalrefcount
gc.collect()
+ return func1(), func2()
def warm_caches():
# char cache
@@ -1472,307 +1537,10 @@ def printlist(x, width=70, indent=4):
print(fill(' '.join(str(elt) for elt in sorted(x)), width,
initial_indent=blanks, subsequent_indent=blanks))
-# Map sys.platform to a string containing the basenames of tests
-# expected to be skipped on that platform.
-#
-# Special cases:
-# test_pep277
-# The _ExpectedSkips constructor adds this to the set of expected
-# skips if not os.path.supports_unicode_filenames.
-# test_timeout
-# Controlled by test_timeout.skip_expected. Requires the network
-# resource and a socket module.
-#
-# Tests that are expected to be skipped everywhere except on one platform
-# are also handled separately.
-
-_expectations = (
- ('win32',
- """
- test__locale
- test_crypt
- test_curses
- test_dbm
- test_devpoll
- test_fcntl
- test_fork1
- test_epoll
- test_dbm_gnu
- test_dbm_ndbm
- test_grp
- test_ioctl
- test_largefile
- test_kqueue
- test_openpty
- test_ossaudiodev
- test_pipes
- test_poll
- test_posix
- test_pty
- test_pwd
- test_resource
- test_signal
- test_syslog
- test_threadsignals
- test_wait3
- test_wait4
- """),
- ('linux',
- """
- test_curses
- test_devpoll
- test_largefile
- test_kqueue
- test_ossaudiodev
- """),
- ('unixware',
- """
- test_epoll
- test_largefile
- test_kqueue
- test_minidom
- test_openpty
- test_pyexpat
- test_sax
- test_sundry
- """),
- ('openunix',
- """
- test_epoll
- test_largefile
- test_kqueue
- test_minidom
- test_openpty
- test_pyexpat
- test_sax
- test_sundry
- """),
- ('sco_sv',
- """
- test_asynchat
- test_fork1
- test_epoll
- test_gettext
- test_largefile
- test_locale
- test_kqueue
- test_minidom
- test_openpty
- test_pyexpat
- test_queue
- test_sax
- test_sundry
- test_thread
- test_threaded_import
- test_threadedtempfile
- test_threading
- """),
- ('darwin',
- """
- test__locale
- test_curses
- test_devpoll
- test_epoll
- test_dbm_gnu
- test_gdb
- test_largefile
- test_locale
- test_minidom
- test_ossaudiodev
- test_poll
- """),
- ('sunos',
- """
- test_curses
- test_dbm
- test_epoll
- test_kqueue
- test_dbm_gnu
- test_gzip
- test_openpty
- test_zipfile
- test_zlib
- """),
- ('hp-ux',
- """
- test_curses
- test_epoll
- test_dbm_gnu
- test_gzip
- test_largefile
- test_locale
- test_kqueue
- test_minidom
- test_openpty
- test_pyexpat
- test_sax
- test_zipfile
- test_zlib
- """),
- ('cygwin',
- """
- test_curses
- test_dbm
- test_devpoll
- test_epoll
- test_ioctl
- test_kqueue
- test_largefile
- test_locale
- test_ossaudiodev
- test_socketserver
- """),
- ('os2emx',
- """
- test_audioop
- test_curses
- test_epoll
- test_kqueue
- test_largefile
- test_mmap
- test_openpty
- test_ossaudiodev
- test_pty
- test_resource
- test_signal
- """),
- ('freebsd',
- """
- test_devpoll
- test_epoll
- test_dbm_gnu
- test_locale
- test_ossaudiodev
- test_pep277
- test_pty
- test_socketserver
- test_tcl
- test_tk
- test_ttk_guionly
- test_ttk_textonly
- test_timeout
- test_urllibnet
- test_multiprocessing
- """),
- ('aix',
- """
- test_bz2
- test_epoll
- test_dbm_gnu
- test_gzip
- test_kqueue
- test_ossaudiodev
- test_tcl
- test_tk
- test_ttk_guionly
- test_ttk_textonly
- test_zipimport
- test_zlib
- """),
- ('openbsd',
- """
- test_ctypes
- test_devpoll
- test_epoll
- test_dbm_gnu
- test_locale
- test_normalization
- test_ossaudiodev
- test_pep277
- test_tcl
- test_tk
- test_ttk_guionly
- test_ttk_textonly
- test_multiprocessing
- """),
- ('netbsd',
- """
- test_ctypes
- test_curses
- test_devpoll
- test_epoll
- test_dbm_gnu
- test_locale
- test_ossaudiodev
- test_pep277
- test_tcl
- test_tk
- test_ttk_guionly
- test_ttk_textonly
- test_multiprocessing
- """),
-)
-
-class _ExpectedSkips:
- def __init__(self):
- import os.path
- from test import test_timeout
-
- self.valid = False
- expected = None
- for item in _expectations:
- if sys.platform.startswith(item[0]):
- expected = item[1]
- break
- if expected is not None:
- self.expected = set(expected.split())
-
- # These are broken tests, for now skipped on every platform.
- # XXX Fix these!
- self.expected.add('test_nis')
-
- # expected to be skipped on every platform, even Linux
- if not os.path.supports_unicode_filenames:
- self.expected.add('test_pep277')
-
- # doctest, profile and cProfile tests fail when the codec for the
- # fs encoding isn't built in because PyUnicode_Decode() adds two
- # calls into Python.
- encs = ("utf-8", "latin-1", "ascii", "mbcs", "utf-16", "utf-32")
- if sys.getfilesystemencoding().lower() not in encs:
- self.expected.add('test_profile')
- self.expected.add('test_cProfile')
- self.expected.add('test_doctest')
-
- if test_timeout.skip_expected:
- self.expected.add('test_timeout')
-
- if sys.platform != "win32":
- # test_sqlite is only reliable on Windows where the library
- # is distributed with Python
- WIN_ONLY = {"test_unicode_file", "test_winreg",
- "test_winsound", "test_startfile",
- "test_sqlite", "test_msilib"}
- self.expected |= WIN_ONLY
-
- if sys.platform != 'sunos5':
- self.expected.add('test_nis')
-
- if support.python_is_optimized():
- self.expected.add("test_gdb")
-
- self.valid = True
-
- def isvalid(self):
- "Return true iff _ExpectedSkips knows about the current platform."
- return self.valid
-
- def getexpected(self):
- """Return set of test names we expect to skip on current platform.
-
- self.isvalid() must be true.
- """
-
- assert self.isvalid()
- return self.expected
-
-def _make_temp_dir_for_build(TEMPDIR):
- # When tests are run from the Python build directory, it is best practice
- # to keep the test files in a subfolder. It eases the cleanup of leftover
- # files using command "make distclean".
+
+def main_in_temp_cwd():
+ """Run main() in a temporary working directory."""
if sysconfig.is_python_build():
- TEMPDIR = os.path.join(sysconfig.get_config_var('srcdir'), 'build')
- TEMPDIR = os.path.abspath(TEMPDIR)
try:
os.mkdir(TEMPDIR)
except FileExistsError:
@@ -1781,10 +1549,16 @@ def _make_temp_dir_for_build(TEMPDIR):
# Define a writable temp dir that will be used as cwd while running
# the tests. The name of the dir includes the pid to allow parallel
# testing (see the -j option).
- TESTCWD = 'test_python_{}'.format(os.getpid())
+ test_cwd = 'test_python_{}'.format(os.getpid())
+ test_cwd = os.path.join(TEMPDIR, test_cwd)
+
+ # Run the tests in a context manager that temporarily changes the CWD to a
+ # temporary and writable directory. If it's not possible to create or
+ # change the CWD, the original CWD will be used. The original CWD is
+ # available from support.SAVEDCWD.
+ with support.temp_cwd(test_cwd, quiet=True):
+ main()
- TESTCWD = os.path.join(TEMPDIR, TESTCWD)
- return TEMPDIR, TESTCWD
if __name__ == '__main__':
# Remove regrtest.py's own directory from the module search path. Despite
@@ -1808,11 +1582,4 @@ if __name__ == '__main__':
# sanity check
assert __file__ == os.path.abspath(sys.argv[0])
- TEMPDIR, TESTCWD = _make_temp_dir_for_build(TEMPDIR)
-
- # Run the tests in a context manager that temporary changes the CWD to a
- # temporary and writable directory. If it's not possible to create or
- # change the CWD, the original CWD will be used. The original CWD is
- # available from support.SAVEDCWD.
- with support.temp_cwd(TESTCWD, quiet=True):
- main()
+ main_in_temp_cwd()
diff --git a/Lib/test/script_helper.py b/Lib/test/script_helper.py
index ab201649e5..4d5c1f120a 100644
--- a/Lib/test/script_helper.py
+++ b/Lib/test/script_helper.py
@@ -12,13 +12,22 @@ import contextlib
import shutil
import zipfile
-from imp import source_from_cache
+from importlib.util import source_from_cache
from test.support import make_legacy_pyc, strip_python_stderr, temp_dir
# Executing the interpreter in a subprocess
def _assert_python(expected_success, *args, **env_vars):
- cmd_line = [sys.executable]
- if not env_vars:
+ if '__isolated' in env_vars:
+ isolated = env_vars.pop('__isolated')
+ else:
+ isolated = not env_vars
+ cmd_line = [sys.executable, '-X', 'faulthandler']
+ if isolated:
+ # isolated mode: ignore Python environment variables, ignore user
+ # site-packages, and don't add the current directory to sys.path
+ cmd_line.append('-I')
+ elif not env_vars:
+ # ignore Python environment variables
cmd_line.append('-E')
# Need to preserve the original environment, for in-place testing of
# shared library builds.
@@ -39,7 +48,7 @@ def _assert_python(expected_success, *args, **env_vars):
p.stdout.close()
p.stderr.close()
rc = p.returncode
- err = strip_python_stderr(err)
+ err = strip_python_stderr(err)
if (rc and expected_success) or (not rc and not expected_success):
raise AssertionError(
"Process return code is %d, "
@@ -49,18 +58,32 @@ def _assert_python(expected_success, *args, **env_vars):
def assert_python_ok(*args, **env_vars):
"""
Assert that running the interpreter with `args` and optional environment
- variables `env_vars` is ok and return a (return code, stdout, stderr) tuple.
+ variables `env_vars` succeeds (rc == 0) and return a (return code, stdout,
+ stderr) tuple.
+
+ If the __cleanenv keyword is set, env_vars is used a fresh environment.
+
+ Python is started in isolated mode (command line option -I),
+ except if the __isolated keyword is set to False.
"""
return _assert_python(True, *args, **env_vars)
def assert_python_failure(*args, **env_vars):
"""
Assert that running the interpreter with `args` and optional environment
- variables `env_vars` fails and return a (return code, stdout, stderr) tuple.
+ variables `env_vars` fails (rc != 0) and return a (return code, stdout,
+ stderr) tuple.
+
+ See assert_python_ok() for more options.
"""
return _assert_python(False, *args, **env_vars)
def spawn_python(*args, **kw):
+ """Run a Python subprocess with the given arguments.
+
+ kw is extra keyword args to pass to subprocess.Popen. Returns a Popen
+ object.
+ """
cmd_line = [sys.executable, '-E']
cmd_line.extend(args)
return subprocess.Popen(cmd_line, stdin=subprocess.PIPE,
@@ -68,6 +91,7 @@ def spawn_python(*args, **kw):
**kw)
def kill_python(p):
+ """Run the given Popen process until completion and return stdout."""
p.stdin.close()
data = p.stdout.read()
p.stdout.close()
diff --git a/Lib/test/sortperf.py b/Lib/test/sortperf.py
index af7c0b4bd1..90722f7abc 100644
--- a/Lib/test/sortperf.py
+++ b/Lib/test/sortperf.py
@@ -22,7 +22,7 @@ def randfloats(n):
fn = os.path.join(td, "rr%06d" % n)
try:
fp = open(fn, "rb")
- except IOError:
+ except OSError:
r = random.random
result = [r() for i in range(n)]
try:
@@ -35,9 +35,9 @@ def randfloats(n):
if fp:
try:
os.unlink(fn)
- except os.error:
+ except OSError:
pass
- except IOError as msg:
+ except OSError as msg:
print("can't write", fn, ":", msg)
else:
result = marshal.load(fp)
diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py
index 8686153a17..759b3f487e 100644
--- a/Lib/test/ssl_servers.py
+++ b/Lib/test/ssl_servers.py
@@ -35,7 +35,7 @@ class HTTPSServer(_HTTPServer):
try:
sock, addr = self.socket.accept()
sslconn = self.context.wrap_socket(sock, server_side=True)
- except socket.error as e:
+ except OSError as e:
# socket errors are silenced by the caller, print them here
if support.verbose:
sys.stderr.write("Got an error:\n%s\n" % e)
@@ -147,9 +147,11 @@ class HTTPSServerThread(threading.Thread):
self.server.shutdown()
-def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None):
- # we assume the certfile contains both private key and certificate
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+def make_https_server(case, *, context=None, certfile=CERTFILE,
+ host=HOST, handler_class=None):
+ if context is None:
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ # We assume the certfile contains both private key and certificate
context.load_cert_chain(certfile)
server = HTTPSServerThread(context, host, handler_class)
flag = threading.Event()
diff --git a/Lib/test/subprocessdata/fd_status.py b/Lib/test/subprocessdata/fd_status.py
index 1f61e13a34..d12bd95abe 100644
--- a/Lib/test/subprocessdata/fd_status.py
+++ b/Lib/test/subprocessdata/fd_status.py
@@ -1,17 +1,27 @@
"""When called as a script, print a comma-separated list of the open
-file descriptors on stdout."""
+file descriptors on stdout.
+
+Usage:
+fd_stats.py: check all file descriptors
+fd_status.py fd1 fd2 ...: check only specified file descriptors
+"""
import errno
import os
-
-try:
- _MAXFD = os.sysconf("SC_OPEN_MAX")
-except:
- _MAXFD = 256
+import stat
+import sys
if __name__ == "__main__":
fds = []
- for fd in range(0, _MAXFD):
+ if len(sys.argv) == 1:
+ try:
+ _MAXFD = os.sysconf("SC_OPEN_MAX")
+ except:
+ _MAXFD = 256
+ test_fds = range(0, _MAXFD)
+ else:
+ test_fds = map(int, sys.argv[1:])
+ for fd in test_fds:
try:
st = os.fstat(fd)
except OSError as e:
@@ -19,6 +29,6 @@ if __name__ == "__main__":
continue
raise
# Ignore Solaris door files
- if st.st_mode & 0xF000 != 0xd000:
+ if not stat.S_ISDOOR(st.st_mode):
fds.append(fd)
print(','.join(map(str, fds)))
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index dbd7846737..20a5e85fdb 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -15,10 +15,10 @@ import shutil
import warnings
import unittest
import importlib
+import importlib.util
import collections.abc
import re
import subprocess
-import imp
import time
import sysconfig
import fnmatch
@@ -57,26 +57,49 @@ try:
except ImportError:
lzma = None
+try:
+ import resource
+except ImportError:
+ resource = None
+
__all__ = [
- "Error", "TestFailed", "ResourceDenied", "import_module", "verbose",
- "use_resources", "max_memuse", "record_original_stdout",
- "get_original_stdout", "unload", "unlink", "rmtree", "forget",
+ # globals
+ "PIPE_MAX_SIZE", "verbose", "max_memuse", "use_resources", "failfast",
+ # exceptions
+ "Error", "TestFailed", "ResourceDenied",
+ # imports
+ "import_module", "import_fresh_module", "CleanImport",
+ # modules
+ "unload", "forget",
+ # io
+ "record_original_stdout", "get_original_stdout", "captured_stdout",
+ "captured_stdin", "captured_stderr",
+ # filesystem
+ "TESTFN", "SAVEDCWD", "unlink", "rmtree", "temp_cwd", "findfile",
+ "create_empty_file", "can_symlink",
+ # unittest
"is_resource_enabled", "requires", "requires_freebsd_version",
- "requires_linux_version", "requires_mac_ver", "find_unused_port",
- "bind_port", "IPV6_ENABLED", "is_jython", "TESTFN", "HOST", "SAVEDCWD",
- "temp_cwd", "findfile", "create_empty_file", "sortdict",
- "check_syntax_error", "open_urlresource", "check_warnings", "CleanImport",
- "EnvironmentVarGuard", "TransientResource", "captured_stdout",
- "captured_stdin", "captured_stderr", "time_out", "socket_peer_reset",
- "ioerror_peer_reset", "run_with_locale", 'temp_umask',
- "transient_internet", "set_memlimit", "bigmemtest", "bigaddrspacetest",
- "BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
- "threading_cleanup", "reap_children", "cpython_only", "check_impl_detail",
- "get_attribute", "swap_item", "swap_attr", "requires_IEEE_754",
- "TestHandler", "Matcher", "can_symlink", "skip_unless_symlink",
- "skip_unless_xattr", "import_fresh_module", "requires_zlib",
- "PIPE_MAX_SIZE", "failfast", "anticipate_failure", "run_with_tz",
- "requires_gzip", "requires_bz2", "requires_lzma", "suppress_crash_popup",
+ "requires_linux_version", "requires_mac_ver", "check_syntax_error",
+ "TransientResource", "time_out", "socket_peer_reset", "ioerror_peer_reset",
+ "transient_internet", "BasicTestRunner", "run_unittest", "run_doctest",
+ "skip_unless_symlink", "requires_gzip", "requires_bz2", "requires_lzma",
+ "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute",
+ "requires_IEEE_754", "skip_unless_xattr", "requires_zlib",
+ "anticipate_failure",
+ # sys
+ "is_jython", "check_impl_detail",
+ # network
+ "HOST", "IPV6_ENABLED", "find_unused_port", "bind_port", "open_urlresource",
+ # processes
+ 'temp_umask', "reap_children",
+ # logging
+ "TestHandler",
+ # threads
+ "threading_setup", "threading_cleanup",
+ # miscellaneous
+ "check_warnings", "EnvironmentVarGuard", "run_with_locale", "swap_item",
+ "swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict",
+ "run_with_tz",
]
class Error(Exception):
@@ -98,7 +121,8 @@ def _ignore_deprecated_imports(ignore=True):
"""Context manager to suppress package and module deprecation
warnings when importing them.
- If ignore is False, this context manager has no effect."""
+ If ignore is False, this context manager has no effect.
+ """
if ignore:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".+ (module|package)",
@@ -108,16 +132,21 @@ def _ignore_deprecated_imports(ignore=True):
yield
-def import_module(name, deprecated=False):
+def import_module(name, deprecated=False, *, required_on=()):
"""Import and return the module to be tested, raising SkipTest if
it is not available.
If deprecated is True, any module or package deprecation messages
- will be suppressed."""
+ will be suppressed. If a module is required on a platform but optional for
+ others, set required_on to an iterable of platform prefixes which will be
+ compared against sys.platform.
+ """
with _ignore_deprecated_imports(deprecated):
try:
return importlib.import_module(name)
except ImportError as msg:
+ if sys.platform.startswith(tuple(required_on)):
+ raise
raise unittest.SkipTest(str(msg))
@@ -303,25 +332,20 @@ else:
def unlink(filename):
try:
_unlink(filename)
- except OSError as error:
- # The filename need not exist.
- if error.errno not in (errno.ENOENT, errno.ENOTDIR):
- raise
+ except (FileNotFoundError, NotADirectoryError):
+ pass
def rmdir(dirname):
try:
_rmdir(dirname)
- except OSError as error:
- # The directory need not exist.
- if error.errno != errno.ENOENT:
- raise
+ except FileNotFoundError:
+ pass
def rmtree(path):
try:
_rmtree(path)
- except OSError as error:
- if error.errno != errno.ENOENT:
- raise
+ except FileNotFoundError:
+ pass
def make_legacy_pyc(source):
"""Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location.
@@ -333,7 +357,7 @@ def make_legacy_pyc(source):
does not need to exist, however the PEP 3147 pyc file must exist.
:return: The file system path to the legacy pyc file.
"""
- pyc_file = imp.cache_from_source(source)
+ pyc_file = importlib.util.cache_from_source(source)
up_one = os.path.dirname(os.path.abspath(source))
legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o'))
os.rename(pyc_file, legacy_pyc)
@@ -352,8 +376,8 @@ def forget(modname):
# combinations of PEP 3147 and legacy pyc and pyo files.
unlink(source + 'c')
unlink(source + 'o')
- unlink(imp.cache_from_source(source, debug_override=True))
- unlink(imp.cache_from_source(source, debug_override=False))
+ unlink(importlib.util.cache_from_source(source, debug_override=True))
+ unlink(importlib.util.cache_from_source(source, debug_override=False))
# On some platforms, should not run gui test even if it is allowed
# in `use_resources'.
@@ -514,7 +538,7 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
the SO_REUSEADDR socket option having different semantics on Windows versus
Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind,
listen and then accept connections on identical host/ports. An EADDRINUSE
- socket.error will be raised at some point (depending on the platform and
+ OSError will be raised at some point (depending on the platform and
the order bind and listen were called on each socket).
However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE
@@ -586,9 +610,9 @@ def _is_ipv6_enabled():
sock = None
try:
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
- sock.bind(('::1', 0))
+ sock.bind((HOSTv6, 0))
return True
- except (socket.error, socket.gaierror):
+ except OSError:
pass
finally:
if sock:
@@ -1175,9 +1199,9 @@ class TransientResource(object):
# Context managers that raise ResourceDenied when various issues
# with the Internet connection manifest themselves as exceptions.
# XXX deprecate these and use transient_internet() instead
-time_out = TransientResource(IOError, errno=errno.ETIMEDOUT)
-socket_peer_reset = TransientResource(socket.error, errno=errno.ECONNRESET)
-ioerror_peer_reset = TransientResource(IOError, errno=errno.ECONNRESET)
+time_out = TransientResource(OSError, errno=errno.ETIMEDOUT)
+socket_peer_reset = TransientResource(OSError, errno=errno.ECONNRESET)
+ioerror_peer_reset = TransientResource(OSError, errno=errno.ECONNRESET)
@contextlib.contextmanager
@@ -1223,17 +1247,17 @@ def transient_internet(resource_name, *, timeout=30.0, errnos=()):
if timeout is not None:
socket.setdefaulttimeout(timeout)
yield
- except IOError as err:
+ except OSError as err:
# urllib can wrap original socket errors multiple times (!), we must
# unwrap to get at the original error.
while True:
a = err.args
- if len(a) >= 1 and isinstance(a[0], IOError):
+ if len(a) >= 1 and isinstance(a[0], OSError):
err = a[0]
# The error can also be wrapped as args[1]:
# except socket.error as msg:
- # raise IOError('socket error', msg).with_traceback(sys.exc_info()[2])
- elif len(a) >= 2 and isinstance(a[1], IOError):
+ # raise OSError('socket error', msg).with_traceback(sys.exc_info()[2])
+ elif len(a) >= 2 and isinstance(a[1], OSError):
err = a[1]
else:
break
@@ -1691,9 +1715,18 @@ def run_unittest(*classes):
#=======================================================================
# Check for the presence of docstrings.
-HAVE_DOCSTRINGS = (check_impl_detail(cpython=False) or
- sys.platform == 'win32' or
- sysconfig.get_config_var('WITH_DOC_STRINGS'))
+# Rather than trying to enumerate all the cases where docstrings may be
+# disabled, we just check for that directly
+
+def _check_docstrings():
+ """Just used to check if docstrings are enabled"""
+
+MISSING_C_DOCSTRINGS = (check_impl_detail() and
+ sys.platform != 'win32' and
+ not sysconfig.get_config_var('WITH_DOC_STRINGS'))
+
+HAVE_DOCSTRINGS = (_check_docstrings.__doc__ is not None and
+ not MISSING_C_DOCSTRINGS)
requires_docstrings = unittest.skipUnless(HAVE_DOCSTRINGS,
"test requires docstrings")
@@ -1768,12 +1801,12 @@ def threading_setup():
def threading_cleanup(*original_values):
if not _thread:
return
- _MAX_COUNT = 10
+ _MAX_COUNT = 100
for count in range(_MAX_COUNT):
values = _thread._count(), threading._dangling
if values == original_values:
break
- time.sleep(0.1)
+ time.sleep(0.01)
gc_collect()
# XXX print a warning in case of failure?
@@ -1875,7 +1908,7 @@ def strip_python_stderr(stderr):
This will typically be run on the result of the communicate() method
of a subprocess.Popen object.
"""
- stderr = re.sub(br"\[\d+ refs\]\r?\n?", b"", stderr).strip()
+ stderr = re.sub(br"\[\d+ refs, \d+ blocks\]\r?\n?", b"", stderr).strip()
return stderr
def args_from_interpreter_flags():
@@ -2006,27 +2039,67 @@ def skip_unless_xattr(test):
return test if ok else unittest.skip(msg)(test)
-if sys.platform.startswith('win'):
- @contextlib.contextmanager
- def suppress_crash_popup():
- """Disable Windows Error Reporting dialogs using SetErrorMode."""
- # see http://msdn.microsoft.com/en-us/library/windows/desktop/ms680621%28v=vs.85%29.aspx
- # GetErrorMode is not available on Windows XP and Windows Server 2003,
- # but SetErrorMode returns the previous value, so we can use that
- import ctypes
- k32 = ctypes.windll.kernel32
- SEM_NOGPFAULTERRORBOX = 0x02
- old_error_mode = k32.SetErrorMode(SEM_NOGPFAULTERRORBOX)
- k32.SetErrorMode(old_error_mode | SEM_NOGPFAULTERRORBOX)
- try:
- yield
- finally:
- k32.SetErrorMode(old_error_mode)
-else:
- # this is a no-op for other platforms
- @contextlib.contextmanager
- def suppress_crash_popup():
- yield
+class SuppressCrashReport:
+ """Try to prevent a crash report from popping up.
+
+ On Windows, don't display the Windows Error Reporting dialog. On UNIX,
+ disable the creation of coredump file.
+ """
+ old_value = None
+
+ def __enter__(self):
+ """On Windows, disable Windows Error Reporting dialogs using
+ SetErrorMode.
+
+ On UNIX, try to save the previous core file size limit, then set
+ soft limit to 0.
+ """
+ if sys.platform.startswith('win'):
+ # see http://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
+ # GetErrorMode is not available on Windows XP and Windows Server 2003,
+ # but SetErrorMode returns the previous value, so we can use that
+ import ctypes
+ self._k32 = ctypes.windll.kernel32
+ SEM_NOGPFAULTERRORBOX = 0x02
+ self.old_value = self._k32.SetErrorMode(SEM_NOGPFAULTERRORBOX)
+ self._k32.SetErrorMode(self.old_value | SEM_NOGPFAULTERRORBOX)
+ else:
+ if resource is not None:
+ try:
+ self.old_value = resource.getrlimit(resource.RLIMIT_CORE)
+ resource.setrlimit(resource.RLIMIT_CORE,
+ (0, self.old_value[1]))
+ except (ValueError, OSError):
+ pass
+ if sys.platform == 'darwin':
+ # Check if the 'Crash Reporter' on OSX was configured
+ # in 'Developer' mode and warn that it will get triggered
+ # when it is.
+ #
+ # This assumes that this context manager is used in tests
+ # that might trigger the next manager.
+ value = subprocess.Popen(['/usr/bin/defaults', 'read',
+ 'com.apple.CrashReporter', 'DialogType'],
+ stdout=subprocess.PIPE).communicate()[0]
+ if value.strip() == b'developer':
+ print("this test triggers the Crash Reporter, "
+ "that is intentional", end='', flush=True)
+
+ return self
+
+ def __exit__(self, *ignore_exc):
+ """Restore Windows ErrorMode or core file behavior to initial value."""
+ if self.old_value is None:
+ return
+
+ if sys.platform.startswith('win'):
+ self._k32.SetErrorMode(self.old_value)
+ else:
+ if resource is not None:
+ try:
+ resource.setrlimit(resource.RLIMIT_CORE, self.old_value)
+ except (ValueError, OSError):
+ pass
def patch(test_instance, object_to_patch, attr_name, new_value):
diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py
index 608ec01f14..8cc285f70a 100644
--- a/Lib/test/test___all__.py
+++ b/Lib/test/test___all__.py
@@ -29,17 +29,20 @@ class AllTest(unittest.TestCase):
if not hasattr(sys.modules[modname], "__all__"):
raise NoAll(modname)
names = {}
- try:
- exec("from %s import *" % modname, names)
- except Exception as e:
- # Include the module name in the exception string
- self.fail("__all__ failure in {}: {}: {}".format(
- modname, e.__class__.__name__, e))
- if "__builtins__" in names:
- del names["__builtins__"]
- keys = set(names)
- all = set(sys.modules[modname].__all__)
- self.assertEqual(keys, all)
+ with self.subTest(module=modname):
+ try:
+ exec("from %s import *" % modname, names)
+ except Exception as e:
+ # Include the module name in the exception string
+ self.fail("__all__ failure in {}: {}: {}".format(
+ modname, e.__class__.__name__, e))
+ if "__builtins__" in names:
+ del names["__builtins__"]
+ keys = set(names)
+ all_list = sys.modules[modname].__all__
+ all_set = set(all_list)
+ self.assertCountEqual(all_set, all_list, "in module {}".format(modname))
+ self.assertEqual(keys, all_set, "in module {}".format(modname))
def walk_modules(self, basedir, modpath):
for fn in sorted(os.listdir(basedir)):
@@ -110,8 +113,5 @@ class AllTest(unittest.TestCase):
print('Following modules failed to be imported:', failed_imports)
-def test_main():
- support.run_unittest(AllTest)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py
index d4d7556518..93f9dae215 100644
--- a/Lib/test/test_abc.py
+++ b/Lib/test/test_abc.py
@@ -68,6 +68,19 @@ class TestLegacyAPI(unittest.TestCase):
class TestABC(unittest.TestCase):
+ def test_ABC_helper(self):
+ # create an ABC using the helper class and perform basic checks
+ class C(abc.ABC):
+ @classmethod
+ @abc.abstractmethod
+ def foo(cls): return cls.__name__
+ self.assertEqual(type(C), abc.ABCMeta)
+ self.assertRaises(TypeError, C)
+ class D(C):
+ @classmethod
+ def foo(cls): return super().foo()
+ self.assertEqual(D.foo(), 'D')
+
def test_abstractmethod_basics(self):
@abc.abstractmethod
def foo(self): pass
@@ -288,7 +301,10 @@ class TestABC(unittest.TestCase):
b = B()
self.assertFalse(isinstance(b, A))
self.assertFalse(isinstance(b, (A,)))
+ token_old = abc.get_cache_token()
A.register(B)
+ token_new = abc.get_cache_token()
+ self.assertNotEqual(token_old, token_new)
self.assertTrue(isinstance(b, A))
self.assertTrue(isinstance(b, (A,)))
diff --git a/Lib/test/test_aifc.py b/Lib/test/test_aifc.py
index b77354bb07..3e9a1d13c2 100644
--- a/Lib/test/test_aifc.py
+++ b/Lib/test/test_aifc.py
@@ -166,6 +166,21 @@ class AifcMiscTest(audiotests.AudioTests, unittest.TestCase):
#This file contains chunk types aifc doesn't recognize.
self.f = aifc.open(findfile('Sine-1000Hz-300ms.aif'))
+ def test_params_added(self):
+ f = self.f = aifc.open(TESTFN, 'wb')
+ f.aiff()
+ f.setparams((1, 1, 1, 1, b'NONE', b''))
+ f.close()
+
+ f = self.f = aifc.open(TESTFN, 'rb')
+ params = f.getparams()
+ self.assertEqual(params.nchannels, f.getnchannels())
+ self.assertEqual(params.sampwidth, f.getsampwidth())
+ self.assertEqual(params.framerate, f.getframerate())
+ self.assertEqual(params.nframes, f.getnframes())
+ self.assertEqual(params.comptype, f.getcomptype())
+ self.assertEqual(params.compname, f.getcompname())
+
def test_write_header_comptype_sampwidth(self):
for comptype in (b'ULAW', b'ulaw', b'ALAW', b'alaw', b'G722'):
fout = aifc.open(io.BytesIO(), 'wb')
@@ -369,12 +384,14 @@ class AIFCLowLevelTest(unittest.TestCase):
def test_write_aiff_by_extension(self):
sampwidth = 2
- fout = self.fout = aifc.open(TESTFN + '.aiff', 'wb')
+ filename = TESTFN + '.aiff'
+ fout = self.fout = aifc.open(filename, 'wb')
+ self.addCleanup(unlink, filename)
fout.setparams((1, sampwidth, 1, 1, b'ULAW', b''))
frames = b'\x00' * fout.getnchannels() * sampwidth
fout.writeframes(frames)
fout.close()
- f = self.f = aifc.open(TESTFN + '.aiff', 'rb')
+ f = self.f = aifc.open(filename, 'rb')
self.assertEqual(f.getcomptype(), b'NONE')
f.close()
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index c06c940bf2..c10c5909bf 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -14,6 +14,7 @@ import argparse
from io import StringIO
from test import support
+from unittest import mock
class StdIOBuffer(StringIO):
pass
@@ -1421,6 +1422,19 @@ class TestFileTypeRepr(TestCase):
type = argparse.FileType('wb', 1)
self.assertEqual("FileType('wb', 1)", repr(type))
+ def test_r_latin(self):
+ type = argparse.FileType('r', encoding='latin_1')
+ self.assertEqual("FileType('r', encoding='latin_1')", repr(type))
+
+ def test_w_big5_ignore(self):
+ type = argparse.FileType('w', encoding='big5', errors='ignore')
+ self.assertEqual("FileType('w', encoding='big5', errors='ignore')",
+ repr(type))
+
+ def test_r_1_replace(self):
+ type = argparse.FileType('r', 1, errors='replace')
+ self.assertEqual("FileType('r', 1, errors='replace')", repr(type))
+
class RFile(object):
seen = {}
@@ -1557,6 +1571,24 @@ class TestFileTypeWB(TempDirMixin, ParserTestCase):
]
+class TestFileTypeOpenArgs(TestCase):
+ """Test that open (the builtin) is correctly called"""
+
+ def test_open_args(self):
+ FT = argparse.FileType
+ cases = [
+ (FT('rb'), ('rb', -1, None, None)),
+ (FT('w', 1), ('w', 1, None, None)),
+ (FT('w', errors='replace'), ('w', -1, None, 'replace')),
+ (FT('wb', encoding='big5'), ('wb', -1, 'big5', None)),
+ (FT('w', 0, 'l1', 'strict'), ('w', 0, 'l1', 'strict')),
+ ]
+ with mock.patch('builtins.open') as m:
+ for type, args in cases:
+ type('foo')
+ m.assert_called_with('foo', *args)
+
+
class TestTypeCallable(ParserTestCase):
"""Test some callables as option/argument types"""
@@ -4327,7 +4359,7 @@ class TestOptionalsHelpVersionActions(TestCase):
def test_version_format(self):
parser = ErrorRaisingArgumentParser(prog='PPP')
parser.add_argument('-v', '--version', action='version', version='%(prog)s 3.5')
- msg = self._get_error(parser.parse_args, ['-v']).stderr
+ msg = self._get_error(parser.parse_args, ['-v']).stdout
self.assertEqual('PPP 3.5\n', msg)
def test_version_no_help(self):
@@ -4340,7 +4372,7 @@ class TestOptionalsHelpVersionActions(TestCase):
def test_version_action(self):
parser = ErrorRaisingArgumentParser(prog='XXX')
parser.add_argument('-V', action='version', version='%(prog)s 3.7')
- msg = self._get_error(parser.parse_args, ['-V']).stderr
+ msg = self._get_error(parser.parse_args, ['-V']).stdout
self.assertEqual('XXX 3.7\n', msg)
def test_no_help(self):
diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py
index e57ff24774..a5c4f23181 100755
--- a/Lib/test/test_array.py
+++ b/Lib/test/test_array.py
@@ -354,12 +354,12 @@ class BaseTest:
support.unlink(support.TESTFN)
def test_fromfile_ioerror(self):
- # Issue #5395: Check if fromfile raises a proper IOError
+ # Issue #5395: Check if fromfile raises a proper OSError
# instead of EOFError.
a = array.array(self.typecode)
f = open(support.TESTFN, 'wb')
try:
- self.assertRaises(IOError, a.fromfile, f, len(self.example))
+ self.assertRaises(OSError, a.fromfile, f, len(self.example))
finally:
f.close()
support.unlink(support.TESTFN)
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 63528885e9..e69422a292 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -180,20 +180,36 @@ eval_tests = [
class AST_Tests(unittest.TestCase):
- def _assertTrueorder(self, ast_node, parent_pos):
+ def _assertTrueorder(self, ast_node, parent_pos, reverse_check = False):
+ def should_reverse_check(parent, child):
+ # In some situations, the children of nodes occur before
+ # their parents, for example in a.b.c, a occurs before b
+ # but a is a child of b.
+ if isinstance(parent, ast.Call):
+ if parent.func == child:
+ return True
+ if isinstance(parent, (ast.Attribute, ast.Subscript)):
+ return True
+ return False
+
if not isinstance(ast_node, ast.AST) or ast_node._fields is None:
return
if isinstance(ast_node, (ast.expr, ast.stmt, ast.excepthandler)):
node_pos = (ast_node.lineno, ast_node.col_offset)
- self.assertTrue(node_pos >= parent_pos)
+ if reverse_check:
+ self.assertTrue(node_pos <= parent_pos)
+ else:
+ self.assertTrue(node_pos >= parent_pos)
parent_pos = (ast_node.lineno, ast_node.col_offset)
for name in ast_node._fields:
value = getattr(ast_node, name)
if isinstance(value, list):
for child in value:
- self._assertTrueorder(child, parent_pos)
+ self._assertTrueorder(child, parent_pos,
+ should_reverse_check(ast_node, child))
elif value is not None:
- self._assertTrueorder(value, parent_pos)
+ self._assertTrueorder(value, parent_pos,
+ should_reverse_check(ast_node, value))
def test_AST_objects(self):
x = ast.AST()
@@ -262,14 +278,14 @@ class AST_Tests(unittest.TestCase):
def test_arguments(self):
x = ast.arguments()
- self.assertEqual(x._fields, ('args', 'vararg', 'varargannotation',
- 'kwonlyargs', 'kwarg', 'kwargannotation',
- 'defaults', 'kw_defaults'))
+ self.assertEqual(x._fields, ('args', 'vararg',
+ 'kwonlyargs', 'kw_defaults',
+ 'kwarg', 'defaults'))
with self.assertRaises(AttributeError):
x.vararg
- x = ast.arguments(*range(1, 9))
+ x = ast.arguments(*range(1, 7))
self.assertEqual(x.vararg, 2)
def test_field_attr_writable(self):
@@ -439,7 +455,7 @@ class ASTHelpers_Test(unittest.TestCase):
"lineno=1, col_offset=0), args=[Name(id='eggs', ctx=Load(), "
"lineno=1, col_offset=5), Str(s='and cheese', lineno=1, "
"col_offset=11)], keywords=[], starargs=None, kwargs=None, "
- "lineno=1, col_offset=0), lineno=1, col_offset=0)])"
+ "lineno=1, col_offset=4), lineno=1, col_offset=0)])"
)
def test_copy_location(self):
@@ -460,7 +476,7 @@ class ASTHelpers_Test(unittest.TestCase):
"Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), "
"lineno=1, col_offset=0), args=[Str(s='spam', lineno=1, "
"col_offset=6)], keywords=[], starargs=None, kwargs=None, "
- "lineno=1, col_offset=0), lineno=1, col_offset=0), "
+ "lineno=1, col_offset=5), lineno=1, col_offset=0), "
"Expr(value=Call(func=Name(id='spam', ctx=Load(), lineno=1, "
"col_offset=0), args=[Str(s='eggs', lineno=1, col_offset=0)], "
"keywords=[], starargs=None, kwargs=None, lineno=1, "
@@ -560,8 +576,8 @@ class ASTValidatorTests(unittest.TestCase):
self.mod(m, "must have Load context", "eval")
def _check_arguments(self, fac, check):
- def arguments(args=None, vararg=None, varargannotation=None,
- kwonlyargs=None, kwarg=None, kwargannotation=None,
+ def arguments(args=None, vararg=None,
+ kwonlyargs=None, kwarg=None,
defaults=None, kw_defaults=None):
if args is None:
args = []
@@ -571,20 +587,12 @@ class ASTValidatorTests(unittest.TestCase):
defaults = []
if kw_defaults is None:
kw_defaults = []
- args = ast.arguments(args, vararg, varargannotation, kwonlyargs,
- kwarg, kwargannotation, defaults, kw_defaults)
+ args = ast.arguments(args, vararg, kwonlyargs, kw_defaults,
+ kwarg, defaults)
return fac(args)
args = [ast.arg("x", ast.Name("x", ast.Store()))]
check(arguments(args=args), "must have Load context")
- check(arguments(varargannotation=ast.Num(3)),
- "varargannotation but no vararg")
- check(arguments(varargannotation=ast.Name("x", ast.Store()), vararg="x"),
- "must have Load context")
check(arguments(kwonlyargs=args), "must have Load context")
- check(arguments(kwargannotation=ast.Num(42)),
- "kwargannotation but no kwarg")
- check(arguments(kwargannotation=ast.Name("x", ast.Store()),
- kwarg="x"), "must have Load context")
check(arguments(defaults=[ast.Num(3)]),
"more positional defaults than args")
check(arguments(kw_defaults=[ast.Num(4)]),
@@ -599,7 +607,7 @@ class ASTValidatorTests(unittest.TestCase):
"must have Load context")
def test_funcdef(self):
- a = ast.arguments([], None, None, [], None, None, [], [])
+ a = ast.arguments([], None, [], [], None, [])
f = ast.FunctionDef("x", a, [], [], None)
self.stmt(f, "empty body on FunctionDef")
f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())],
@@ -770,7 +778,7 @@ class ASTValidatorTests(unittest.TestCase):
self.expr(u, "must have Load context")
def test_lambda(self):
- a = ast.arguments([], None, None, [], None, None, [], [])
+ a = ast.arguments([], None, [], [], None, [])
self.expr(ast.Lambda(a, ast.Name("x", ast.Store())),
"must have Load context")
def fac(args):
@@ -928,6 +936,9 @@ class ASTValidatorTests(unittest.TestCase):
def test_tuple(self):
self._sequence(ast.Tuple)
+ def test_nameconstant(self):
+ self.expr(ast.NameConstant(4), "singleton must be True, False, or None")
+
def test_stdlib_validates(self):
stdlib = os.path.dirname(ast.__file__)
tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
@@ -936,13 +947,10 @@ class ASTValidatorTests(unittest.TestCase):
fn = os.path.join(stdlib, module)
with open(fn, "r", encoding="utf-8") as fp:
source = fp.read()
- mod = ast.parse(source)
+ mod = ast.parse(source, fn)
compile(mod, fn, "exec")
-def test_main():
- support.run_unittest(AST_Tests, ASTHelpers_Test, ASTValidatorTests)
-
def main():
if __name__ != '__main__':
return
@@ -955,20 +963,20 @@ def main():
print("]")
print("main()")
raise SystemExit
- test_main()
+ unittest.main()
#### EVERYTHING BELOW IS GENERATED #####
exec_results = [
-('Module', [('Expr', (1, 0), ('Name', (1, 0), 'None', ('Load',)))]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, None, [], None, None, [], []), [('Pass', (1, 9))], [], None)]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', 'a', None)], None, None, [], None, None, [], []), [('Pass', (1, 10))], [], None)]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', 'a', None)], None, None, [], None, None, [('Num', (1, 8), 0)], []), [('Pass', (1, 12))], [], None)]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], 'args', None, [], None, None, [], []), [('Pass', (1, 14))], [], None)]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, None, [], 'kwargs', None, [], []), [('Pass', (1, 17))], [], None)]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', 'a', None), ('arg', 'b', None), ('arg', 'c', None), ('arg', 'd', None), ('arg', 'e', None)], 'args', None, [], 'kwargs', None, [('Num', (1, 11), 1), ('Name', (1, 16), 'None', ('Load',)), ('List', (1, 24), [], ('Load',)), ('Dict', (1, 30), [], [])], []), [('Pass', (1, 52))], [], None)]),
+('Module', [('Expr', (1, 0), ('NameConstant', (1, 0), None))]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], None, []), [('Pass', (1, 9))], [], None)]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None)], None, [], [], None, []), [('Pass', (1, 10))], [], None)]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None)], None, [], [], None, [('Num', (1, 8), 0)]), [('Pass', (1, 12))], [], None)]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], ('arg', (1, 7), 'args', None), [], [], None, []), [('Pass', (1, 14))], [], None)]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], ('arg', (1, 8), 'kwargs', None), []), [('Pass', (1, 17))], [], None)]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None), ('arg', (1, 9), 'b', None), ('arg', (1, 14), 'c', None), ('arg', (1, 22), 'd', None), ('arg', (1, 28), 'e', None)], ('arg', (1, 35), 'args', None), [], [], ('arg', (1, 43), 'kwargs', None), [('Num', (1, 11), 1), ('NameConstant', (1, 16), None), ('List', (1, 24), [], ('Load',)), ('Dict', (1, 30), [], [])]), [('Pass', (1, 52))], [], None)]),
('Module', [('ClassDef', (1, 0), 'C', [], [], None, None, [('Pass', (1, 8))], [])]),
('Module', [('ClassDef', (1, 0), 'C', [('Name', (1, 8), 'object', ('Load',))], [], None, None, [('Pass', (1, 17))], [])]),
-('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, None, [], None, None, [], []), [('Return', (1, 8), ('Num', (1, 15), 1))], [], None)]),
+('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], None, []), [('Return', (1, 8), ('Num', (1, 15), 1))], [], None)]),
('Module', [('Delete', (1, 0), [('Name', (1, 4), 'v', ('Del',))])]),
('Module', [('Assign', (1, 0), [('Name', (1, 0), 'v', ('Store',))], ('Num', (1, 4), 1))]),
('Module', [('AugAssign', (1, 0), ('Name', (1, 0), 'v', ('Store',)), ('Add',), ('Num', (1, 5), 1))]),
@@ -977,7 +985,7 @@ exec_results = [
('Module', [('If', (1, 0), ('Name', (1, 3), 'v', ('Load',)), [('Pass', (1, 5))], [])]),
('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',)))], [('Pass', (1, 13))])]),
('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',))), ('withitem', ('Name', (1, 13), 'z', ('Load',)), ('Name', (1, 18), 'q', ('Store',)))], [('Pass', (1, 21))])]),
-('Module', [('Raise', (1, 0), ('Call', (1, 6), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], [], None, None), None)]),
+('Module', [('Raise', (1, 0), ('Call', (1, 15), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], [], None, None), None)]),
('Module', [('Try', (1, 0), [('Pass', (2, 2))], [('ExceptHandler', (3, 0), ('Name', (3, 7), 'Exception', ('Load',)), None, [('Pass', (4, 2))])], [], [])]),
('Module', [('Try', (1, 0), [('Pass', (2, 2))], [], [], [('Pass', (4, 2))])]),
('Module', [('Assert', (1, 0), ('Name', (1, 7), 'v', ('Load',)), None)]),
@@ -1002,29 +1010,29 @@ single_results = [
('Interactive', [('Expr', (1, 0), ('BinOp', (1, 0), ('Num', (1, 0), 1), ('Add',), ('Num', (1, 2), 2)))]),
]
eval_results = [
-('Expression', ('Name', (1, 0), 'None', ('Load',))),
+('Expression', ('NameConstant', (1, 0), None)),
('Expression', ('BoolOp', (1, 0), ('And',), [('Name', (1, 0), 'a', ('Load',)), ('Name', (1, 6), 'b', ('Load',))])),
('Expression', ('BinOp', (1, 0), ('Name', (1, 0), 'a', ('Load',)), ('Add',), ('Name', (1, 4), 'b', ('Load',)))),
('Expression', ('UnaryOp', (1, 0), ('Not',), ('Name', (1, 4), 'v', ('Load',)))),
-('Expression', ('Lambda', (1, 0), ('arguments', [], None, None, [], None, None, [], []), ('Name', (1, 7), 'None', ('Load',)))),
+('Expression', ('Lambda', (1, 0), ('arguments', [], None, [], [], None, []), ('NameConstant', (1, 7), None))),
('Expression', ('Dict', (1, 0), [('Num', (1, 2), 1)], [('Num', (1, 4), 2)])),
('Expression', ('Dict', (1, 0), [], [])),
-('Expression', ('Set', (1, 0), [('Name', (1, 1), 'None', ('Load',))])),
+('Expression', ('Set', (1, 0), [('NameConstant', (1, 1), None)])),
('Expression', ('Dict', (1, 0), [('Num', (2, 6), 1)], [('Num', (4, 10), 2)])),
('Expression', ('ListComp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), [('comprehension', ('Name', (1, 7), 'b', ('Store',)), ('Name', (1, 12), 'c', ('Load',)), [('Name', (1, 17), 'd', ('Load',))])])),
('Expression', ('GeneratorExp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), [('comprehension', ('Name', (1, 7), 'b', ('Store',)), ('Name', (1, 12), 'c', ('Load',)), [('Name', (1, 17), 'd', ('Load',))])])),
('Expression', ('Compare', (1, 0), ('Num', (1, 0), 1), [('Lt',), ('Lt',)], [('Num', (1, 4), 2), ('Num', (1, 8), 3)])),
-('Expression', ('Call', (1, 0), ('Name', (1, 0), 'f', ('Load',)), [('Num', (1, 2), 1), ('Num', (1, 4), 2)], [('keyword', 'c', ('Num', (1, 8), 3))], ('Name', (1, 11), 'd', ('Load',)), ('Name', (1, 15), 'e', ('Load',)))),
+('Expression', ('Call', (1, 1), ('Name', (1, 0), 'f', ('Load',)), [('Num', (1, 2), 1), ('Num', (1, 4), 2)], [('keyword', 'c', ('Num', (1, 8), 3))], ('Name', (1, 11), 'd', ('Load',)), ('Name', (1, 15), 'e', ('Load',)))),
('Expression', ('Num', (1, 0), 10)),
('Expression', ('Str', (1, 0), 'string')),
-('Expression', ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',))),
-('Expression', ('Subscript', (1, 0), ('Name', (1, 0), 'a', ('Load',)), ('Slice', ('Name', (1, 2), 'b', ('Load',)), ('Name', (1, 4), 'c', ('Load',)), None), ('Load',))),
+('Expression', ('Attribute', (1, 2), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',))),
+('Expression', ('Subscript', (1, 2), ('Name', (1, 0), 'a', ('Load',)), ('Slice', ('Name', (1, 2), 'b', ('Load',)), ('Name', (1, 4), 'c', ('Load',)), None), ('Load',))),
('Expression', ('Name', (1, 0), 'v', ('Load',))),
('Expression', ('List', (1, 0), [('Num', (1, 1), 1), ('Num', (1, 3), 2), ('Num', (1, 5), 3)], ('Load',))),
('Expression', ('List', (1, 0), [], ('Load',))),
('Expression', ('Tuple', (1, 0), [('Num', (1, 0), 1), ('Num', (1, 2), 2), ('Num', (1, 4), 3)], ('Load',))),
('Expression', ('Tuple', (1, 1), [('Num', (1, 1), 1), ('Num', (1, 3), 2), ('Num', (1, 5), 3)], ('Load',))),
('Expression', ('Tuple', (1, 0), [], ('Load',))),
-('Expression', ('Call', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 8), ('Attribute', (1, 8), ('Name', (1, 8), 'a', ('Load',)), 'b', ('Load',)), ('Slice', ('Num', (1, 12), 1), ('Num', (1, 14), 2), None), ('Load',))], [], None, None)),
+('Expression', ('Call', (1, 7), ('Attribute', (1, 6), ('Attribute', (1, 4), ('Attribute', (1, 2), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 12), ('Attribute', (1, 10), ('Name', (1, 8), 'a', ('Load',)), 'b', ('Load',)), ('Slice', ('Num', (1, 12), 1), ('Num', (1, 14), 2), None), ('Load',))], [], None, None)),
]
main()
diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py
index c79fe6f613..f93a52d8c9 100644
--- a/Lib/test/test_asynchat.py
+++ b/Lib/test/test_asynchat.py
@@ -15,6 +15,7 @@ except ImportError:
HOST = support.HOST
SERVER_QUIT = b'QUIT\n'
+TIMEOUT = 3.0
if threading:
class echo_server(threading.Thread):
@@ -123,7 +124,9 @@ class TestAsynchat(unittest.TestCase):
c.push(b"I'm not dead yet!" + term)
c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"])
@@ -154,7 +157,9 @@ class TestAsynchat(unittest.TestCase):
c.push(data)
c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents, [data[:termlen]])
@@ -174,7 +179,9 @@ class TestAsynchat(unittest.TestCase):
c.push(data)
c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents, [])
self.assertEqual(c.buffer, data)
@@ -186,7 +193,9 @@ class TestAsynchat(unittest.TestCase):
p = asynchat.simple_producer(data+SERVER_QUIT, buffer_size=8)
c.push_with_producer(p)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"])
@@ -196,7 +205,9 @@ class TestAsynchat(unittest.TestCase):
data = b"hello world\nI'm not dead yet!\n"
c.push_with_producer(data+SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents, [b"hello world", b"I'm not dead yet!"])
@@ -207,7 +218,9 @@ class TestAsynchat(unittest.TestCase):
c.push(b"hello world\n\nI'm not dead yet!\n")
c.push(SERVER_QUIT)
asyncore.loop(use_poll=self.usepoll, count=300, timeout=.01)
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents,
[b"hello world", b"", b"I'm not dead yet!"])
@@ -226,7 +239,9 @@ class TestAsynchat(unittest.TestCase):
# where the server echoes all of its data before we can check that it
# got any down below.
s.start_resend_event.set()
- s.join()
+ s.join(timeout=TIMEOUT)
+ if s.is_alive():
+ self.fail("join() timed out")
self.assertEqual(c.contents, [])
# the server might have been able to send a byte or two back, but this
@@ -268,9 +283,5 @@ class TestFifo(unittest.TestCase):
self.assertEqual(f.pop(), (0, None))
-def test_main(verbose=None):
- support.run_unittest(TestAsynchat, TestAsynchat_WithPoll,
- TestHelperFunctions, TestFifo)
-
if __name__ == "__main__":
- test_main(verbose=True)
+ unittest.main()
diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py
new file mode 100644
index 0000000000..024d3e1743
--- /dev/null
+++ b/Lib/test/test_asyncio/__init__.py
@@ -0,0 +1,31 @@
+import os
+import sys
+import unittest
+from test.support import run_unittest
+
+try:
+ import threading
+except ImportError:
+ raise unittest.SkipTest("No module named '_thread'")
+
+
+def suite():
+ tests_file = os.path.join(os.path.dirname(__file__), 'tests.txt')
+ with open(tests_file) as fp:
+ test_names = fp.read().splitlines()
+ tests = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ for test_name in test_names:
+ mod_name = 'test.' + test_name
+ try:
+ __import__(mod_name)
+ except unittest.SkipTest:
+ pass
+ else:
+ mod = sys.modules[mod_name]
+ tests.addTests(loader.loadTestsFromModule(mod))
+ return tests
+
+
+def test_main():
+ run_unittest(suite())
diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py
new file mode 100644
index 0000000000..b549492038
--- /dev/null
+++ b/Lib/test/test_asyncio/__main__.py
@@ -0,0 +1,5 @@
+from . import test_main
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py
new file mode 100644
index 0000000000..006364bb00
--- /dev/null
+++ b/Lib/test/test_asyncio/echo.py
@@ -0,0 +1,8 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ if not buf:
+ break
+ os.write(1, buf)
diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py
new file mode 100644
index 0000000000..e83ca09fb7
--- /dev/null
+++ b/Lib/test/test_asyncio/echo2.py
@@ -0,0 +1,6 @@
+import os
+
+if __name__ == '__main__':
+ buf = os.read(0, 1024)
+ os.write(1, b'OUT:'+buf)
+ os.write(2, b'ERR:'+buf)
diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py
new file mode 100644
index 0000000000..064496736b
--- /dev/null
+++ b/Lib/test/test_asyncio/echo3.py
@@ -0,0 +1,11 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ if not buf:
+ break
+ try:
+ os.write(1, b'OUT:'+buf)
+ except OSError as ex:
+ os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii'))
diff --git a/Lib/test/test_asyncio/sample.crt b/Lib/test/test_asyncio/sample.crt
new file mode 100644
index 0000000000..6a1e3f3c2e
--- /dev/null
+++ b/Lib/test/test_asyncio/sample.crt
@@ -0,0 +1,14 @@
+-----BEGIN CERTIFICATE-----
+MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV
+UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi
+MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3
+MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp
+Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g
+U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn
+t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7
+gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg
+Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB
+BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB
+MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I
+I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB
+-----END CERTIFICATE-----
diff --git a/Lib/test/test_asyncio/sample.key b/Lib/test/test_asyncio/sample.key
new file mode 100644
index 0000000000..edfea8dcab
--- /dev/null
+++ b/Lib/test/test_asyncio/sample.key
@@ -0,0 +1,15 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx
+/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi
+qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB
+AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y
+bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1
+iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb
+DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP
+lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL
+21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF
+ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0
+zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u
+GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq
+V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof
+-----END RSA PRIVATE KEY-----
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
new file mode 100644
index 0000000000..ff537ab2bd
--- /dev/null
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -0,0 +1,684 @@
+"""Tests for base_events.py"""
+
+import errno
+import logging
+import socket
+import time
+import unittest
+import unittest.mock
+from test.support import find_unused_port, IPV6_ENABLED
+
+from asyncio import base_events
+from asyncio import constants
+from asyncio import events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class BaseEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = base_events.BaseEventLoop()
+ self.loop._selector = unittest.mock.Mock()
+ events.set_event_loop(None)
+
+ def test_not_implemented(self):
+ m = unittest.mock.Mock()
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_socket_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_ssl_transport, m, m, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_datagram_transport, m, m)
+ self.assertRaises(
+ NotImplementedError, self.loop._process_events, [])
+ self.assertRaises(
+ NotImplementedError, self.loop._write_to_self)
+ self.assertRaises(
+ NotImplementedError, self.loop._read_from_self)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_read_pipe_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_write_pipe_transport, m, m)
+ gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m)
+ self.assertRaises(NotImplementedError, next, iter(gen))
+
+ def test__add_callback_handle(self):
+ h = events.Handle(lambda: False, ())
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertIn(h, self.loop._ready)
+
+ def test__add_callback_timer(self):
+ h = events.TimerHandle(time.monotonic()+10, lambda: False, ())
+
+ self.loop._add_callback(h)
+ self.assertIn(h, self.loop._scheduled)
+
+ def test__add_callback_cancelled_handle(self):
+ h = events.Handle(lambda: False, ())
+ h.cancel()
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertFalse(self.loop._ready)
+
+ def test_set_default_executor(self):
+ executor = unittest.mock.Mock()
+ self.loop.set_default_executor(executor)
+ self.assertIs(executor, self.loop._default_executor)
+
+ def test_getnameinfo(self):
+ sockaddr = unittest.mock.Mock()
+ self.loop.run_in_executor = unittest.mock.Mock()
+ self.loop.getnameinfo(sockaddr)
+ self.assertEqual(
+ (None, socket.getnameinfo, sockaddr, 0),
+ self.loop.run_in_executor.call_args[0])
+
+ def test_call_soon(self):
+ def cb():
+ pass
+
+ h = self.loop.call_soon(cb)
+ self.assertEqual(h._callback, cb)
+ self.assertIsInstance(h, events.Handle)
+ self.assertIn(h, self.loop._ready)
+
+ def test_call_later(self):
+ def cb():
+ pass
+
+ h = self.loop.call_later(10.0, cb)
+ self.assertIsInstance(h, events.TimerHandle)
+ self.assertIn(h, self.loop._scheduled)
+ self.assertNotIn(h, self.loop._ready)
+
+ def test_call_later_negative_delays(self):
+ calls = []
+
+ def cb(arg):
+ calls.append(arg)
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop.call_later(-1, cb, 'a')
+ self.loop.call_later(-2, cb, 'b')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(calls, ['b', 'a'])
+
+ def test_time_and_call_at(self):
+ def cb():
+ self.loop.stop()
+
+ self.loop._process_events = unittest.mock.Mock()
+ when = self.loop.time() + 0.1
+ self.loop.call_at(when, cb)
+ t0 = self.loop.time()
+ self.loop.run_forever()
+ t1 = self.loop.time()
+ self.assertTrue(0.09 <= t1-t0 <= 0.9, t1-t0)
+
+ def test_run_once_in_executor_handle(self):
+ def cb():
+ pass
+
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, events.Handle(cb, ()), ('',))
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, events.TimerHandle(10, cb, ()))
+
+ def test_run_once_in_executor_cancelled(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ h.cancel()
+
+ f = self.loop.run_in_executor(None, h)
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test_run_once_in_executor_plain(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ f = futures.Future(loop=self.loop)
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+
+ self.loop.set_default_executor(executor)
+
+ res = self.loop.run_in_executor(None, h)
+ self.assertIs(f, res)
+
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+ res = self.loop.run_in_executor(executor, h)
+ self.assertIs(f, res)
+ self.assertTrue(executor.submit.called)
+
+ f.cancel() # Don't complain about abandoned Future.
+
+ def test__run_once(self):
+ h1 = events.TimerHandle(time.monotonic() + 5.0, lambda: True, ())
+ h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ())
+
+ h1.cancel()
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._scheduled.append(h1)
+ self.loop._scheduled.append(h2)
+ self.loop._run_once()
+
+ t = self.loop._selector.select.call_args[0][0]
+ self.assertTrue(9.9 < t < 10.1, t)
+ self.assertEqual([h2], self.loop._scheduled)
+ self.assertTrue(self.loop._process_events.called)
+
+ @unittest.mock.patch('asyncio.base_events.time')
+ @unittest.mock.patch('asyncio.base_events.logger')
+ def test__run_once_logging(self, m_logging, m_time):
+ # Log to INFO level if timeout > 1.0 sec.
+ idx = -1
+ data = [10.0, 10.0, 12.0, 13.0]
+
+ def monotonic():
+ nonlocal data, idx
+ idx += 1
+ return data[idx]
+
+ m_time.monotonic = monotonic
+ m_logging.INFO = logging.INFO
+ m_logging.DEBUG = logging.DEBUG
+
+ self.loop._scheduled.append(
+ events.TimerHandle(11.0, lambda: True, ()))
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._run_once()
+ self.assertEqual(logging.INFO, m_logging.log.call_args[0][0])
+
+ idx = -1
+ data = [10.0, 10.0, 10.3, 13.0]
+ self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())]
+ self.loop._run_once()
+ self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0])
+
+ def test__run_once_schedule_handle(self):
+ handle = None
+ processed = False
+
+ def cb(loop):
+ nonlocal processed, handle
+ processed = True
+ handle = loop.call_soon(lambda: True)
+
+ h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,))
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._scheduled.append(h)
+ self.loop._run_once()
+
+ self.assertTrue(processed)
+ self.assertEqual([handle], list(self.loop._ready))
+
+ def test_run_until_complete_type_error(self):
+ self.assertRaises(
+ TypeError, self.loop.run_until_complete, 'blah')
+
+
+class MyProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = futures.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = futures.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def error_received(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class BaseEventLoopWithSelectorTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = events.new_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_connection_multiple_errors(self, m_socket):
+
+ class MyProto(protocols.Protocol):
+ pass
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ idx = -1
+ errors = ['err1', 'err2']
+
+ def _socket(*args, **kw):
+ nonlocal idx, errors
+ idx += 1
+ raise OSError(errors[idx])
+
+ m_socket.socket = _socket
+
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
+
+ def test_create_connection_host_port_sock(self):
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_host_port_sock(self):
+ coro = self.loop.create_connection(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_getaddrinfo(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_connect_err(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_multiple(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET)
+ with self.assertRaises(OSError):
+ self.loop.run_until_complete(coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_connection_multiple_errors_local_addr(self, m_socket):
+
+ def bind(addr):
+ if addr[0] == '0.0.0.1':
+ err = OSError('Err')
+ err.strerror = 'Err'
+ raise err
+
+ m_socket.socket.return_value.bind = bind
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError('Err2')
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
+ self.assertTrue(m_socket.socket.return_value.close.called)
+
+ def test_create_connection_no_local_addr(self):
+ @tasks.coroutine
+ def getaddrinfo(host, *args, **kw):
+ if host == 'example.com':
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+ else:
+ return []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_ssl_server_hostname_default(self):
+ self.loop.getaddrinfo = unittest.mock.Mock()
+
+ def mock_getaddrinfo(*args, **kwds):
+ f = futures.Future(loop=self.loop)
+ f.set_result([(socket.AF_INET, socket.SOCK_STREAM,
+ socket.SOL_TCP, '', ('1.2.3.4', 80))])
+ return f
+
+ self.loop.getaddrinfo.side_effect = mock_getaddrinfo
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.return_value = ()
+ self.loop._make_ssl_transport = unittest.mock.Mock()
+
+ class _SelectorTransportMock:
+ _sock = None
+
+ def close(self):
+ self._sock.close()
+
+ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
+ **kwds):
+ waiter.set_result(None)
+ transport = _SelectorTransportMock()
+ transport._sock = sock
+ return transport
+
+ self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
+ ANY = unittest.mock.ANY
+ # First try the default server_hostname.
+ self.loop._make_ssl_transport.reset_mock()
+ coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
+ transport, _ = self.loop.run_until_complete(coro)
+ transport.close()
+ self.loop._make_ssl_transport.assert_called_with(
+ ANY, ANY, ANY, ANY,
+ server_side=False,
+ server_hostname='python.org')
+ # Next try an explicit server_hostname.
+ self.loop._make_ssl_transport.reset_mock()
+ coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
+ server_hostname='perl.com')
+ transport, _ = self.loop.run_until_complete(coro)
+ transport.close()
+ self.loop._make_ssl_transport.assert_called_with(
+ ANY, ANY, ANY, ANY,
+ server_side=False,
+ server_hostname='perl.com')
+ # Finally try an explicit empty server_hostname.
+ self.loop._make_ssl_transport.reset_mock()
+ coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
+ server_hostname='')
+ transport, _ = self.loop.run_until_complete(coro)
+ transport.close()
+ self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
+ server_side=False,
+ server_hostname='')
+
+ def test_create_connection_no_ssl_server_hostname_errors(self):
+ # When not using ssl, server_hostname must be None.
+ coro = self.loop.create_connection(MyProto, 'python.org', 80,
+ server_hostname='')
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_connection(MyProto, 'python.org', 80,
+ server_hostname='python.org')
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_ssl_server_hostname_errors(self):
+ # When using ssl, server_hostname may be None if host is non-empty.
+ coro = self.loop.create_connection(MyProto, '', 80, ssl=True)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_connection(MyProto, None, 80, ssl=True)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+ sock = socket.socket()
+ coro = self.loop.create_connection(MyProto, None, None,
+ ssl=True, sock=sock)
+ self.addCleanup(sock.close)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_server_empty_host(self):
+ # if host is empty string use None instead
+ host = object()
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ nonlocal host
+ host = args[0]
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ fut = self.loop.create_server(MyProto, '', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertIsNone(host)
+
+ def test_create_server_host_port_sock(self):
+ fut = self.loop.create_server(
+ MyProto, '0.0.0.0', 0, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_host_port_sock(self):
+ fut = self.loop.create_server(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_getaddrinfo(self):
+ getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock()
+ getaddrinfo.return_value = []
+
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, f)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_server_cant_bind(self, m_socket):
+
+ class Err(OSError):
+ strerror = 'error'
+
+ m_socket.getaddrinfo.return_value = [
+ (2, 1, 6, '', ('127.0.0.1', 10100))]
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
+ m_socket.getaddrinfo.return_value = []
+
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_addr_error(self):
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr='localhost')
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 1, 2, 3))
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_connect_err(self):
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_socket_err(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.socket.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, local_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
+ def test_create_datagram_endpoint_no_matching_family(self):
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol,
+ remote_addr=('127.0.0.1', 0), local_addr=('::1', 0))
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_setblk_err(self, m_socket):
+ m_socket.socket.return_value.setblocking.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+ self.assertTrue(
+ m_socket.socket.return_value.close.called)
+
+ def test_create_datagram_endpoint_noaddr_nofamily(self):
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_cant_bind(self, m_socket):
+ class Err(OSError):
+ pass
+
+ m_socket.AF_INET6 = socket.AF_INET6
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto,
+ local_addr=('127.0.0.1', 0), family=socket.AF_INET)
+ self.assertRaises(Err, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ def test_accept_connection_retry(self):
+ sock = unittest.mock.Mock()
+ sock.accept.side_effect = BlockingIOError()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertFalse(sock.close.called)
+
+ @unittest.mock.patch('asyncio.selector_events.logger')
+ def test_accept_connection_exception(self, m_log):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files')
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop.call_later = unittest.mock.Mock()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertTrue(m_log.exception.called)
+ self.assertFalse(sock.close.called)
+ self.loop.remove_reader.assert_called_with(10)
+ self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY,
+ # self.loop._start_serving
+ unittest.mock.ANY,
+ MyProto, sock, None, None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
new file mode 100644
index 0000000000..3a2dece031
--- /dev/null
+++ b/Lib/test/test_asyncio/test_events.py
@@ -0,0 +1,1655 @@
+"""Tests for events.py."""
+
+import functools
+import gc
+import io
+import os
+import signal
+import socket
+try:
+ import ssl
+except ImportError:
+ ssl = None
+import subprocess
+import sys
+import threading
+import time
+import errno
+import unittest
+import unittest.mock
+from test.support import find_unused_port, IPV6_ENABLED
+
+
+from asyncio import futures
+from asyncio import events
+from asyncio import transports
+from asyncio import protocols
+from asyncio import selector_events
+from asyncio import tasks
+from asyncio import test_utils
+from asyncio import locks
+
+
+class MyProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def error_received(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyReadPipeProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = ['INITIAL']
+ self.nbytes = 0
+ self.transport = None
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == ['INITIAL'], self.state
+ self.state.append('CONNECTED')
+
+ def data_received(self, data):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.state.append('EOF')
+
+ def connection_lost(self, exc):
+ assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state
+ self.state.append('CLOSED')
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyWritePipeProto(protocols.BaseProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.transport = None
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MySubprocessProtocol(protocols.SubprocessProtocol):
+
+ def __init__(self, loop):
+ self.state = 'INITIAL'
+ self.transport = None
+ self.connected = futures.Future(loop=loop)
+ self.completed = futures.Future(loop=loop)
+ self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)}
+ self.data = {1: b'', 2: b''}
+ self.returncode = None
+ self.got_data = {1: locks.Event(loop=loop),
+ 2: locks.Event(loop=loop)}
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ self.connected.set_result(None)
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ self.completed.set_result(None)
+
+ def pipe_data_received(self, fd, data):
+ assert self.state == 'CONNECTED', self.state
+ self.data[fd] += data
+ self.got_data[fd].set()
+
+ def pipe_connection_lost(self, fd, exc):
+ assert self.state == 'CONNECTED', self.state
+ if exc:
+ self.disconnects[fd].set_exception(exc)
+ else:
+ self.disconnects[fd].set_result(exc)
+
+ def process_exited(self):
+ assert self.state == 'CONNECTED', self.state
+ self.returncode = self.transport.get_returncode()
+
+
+class EventLoopTestsMixin:
+
+ def setUp(self):
+ super().setUp()
+ self.loop = self.create_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+ super().tearDown()
+
+ def test_run_until_complete_nesting(self):
+ @tasks.coroutine
+ def coro1():
+ yield
+
+ @tasks.coroutine
+ def coro2():
+ self.assertTrue(self.loop.is_running())
+ self.loop.run_until_complete(coro1())
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, coro2())
+
+ # Note: because of the default Windows timing granularity of
+ # 15.6 msec, we use fairly long sleep times here (~100 msec).
+
+ def test_run_until_complete(self):
+ t0 = self.loop.time()
+ self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop))
+ t1 = self.loop.time()
+ self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
+
+ def test_run_until_complete_stopped(self):
+ @tasks.coroutine
+ def cb():
+ self.loop.stop()
+ yield from tasks.sleep(0.1, loop=self.loop)
+ task = cb()
+ self.assertRaises(RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_call_later(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ self.loop.stop()
+
+ self.loop.call_later(0.1, callback, 'hello world')
+ t0 = time.monotonic()
+ self.loop.run_forever()
+ t1 = time.monotonic()
+ self.assertEqual(results, ['hello world'])
+ self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0)
+
+ def test_call_soon(self):
+ results = []
+
+ def callback(arg1, arg2):
+ results.append((arg1, arg2))
+ self.loop.stop()
+
+ self.loop.call_soon(callback, 'hello', 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, [('hello', 'world')])
+
+ def test_call_soon_threadsafe(self):
+ results = []
+ lock = threading.Lock()
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ def run_in_thread():
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ lock.release()
+
+ lock.acquire()
+ t = threading.Thread(target=run_in_thread)
+ t.start()
+
+ with lock:
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ t.join()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_call_soon_threadsafe_same_thread(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_run_in_executor(self):
+ def run(arg):
+ return (arg, threading.get_ident())
+ f2 = self.loop.run_in_executor(None, run, 'yo')
+ res, thread_id = self.loop.run_until_complete(f2)
+ self.assertEqual(res, 'yo')
+ self.assertNotEqual(thread_id, threading.get_ident())
+
+ def test_reader_callback(self):
+ r, w = test_utils.socketpair()
+ bytes_read = []
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ # Spurious readiness notifications are possible
+ # at least on Linux -- see man select.
+ return
+ if data:
+ bytes_read.append(data)
+ else:
+ self.assertTrue(self.loop.remove_reader(r.fileno()))
+ r.close()
+
+ self.loop.add_reader(r.fileno(), reader)
+ self.loop.call_soon(w.send, b'abc')
+ test_utils.run_briefly(self.loop)
+ self.loop.call_soon(w.send, b'def')
+ test_utils.run_briefly(self.loop)
+ self.loop.call_soon(w.close)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(b''.join(bytes_read), b'abcdef')
+
+ def test_writer_callback(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+ self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024))
+ test_utils.run_briefly(self.loop)
+
+ def remove_writer():
+ self.assertTrue(self.loop.remove_writer(w.fileno()))
+
+ self.loop.call_soon(remove_writer)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ w.close()
+ data = r.recv(256*1024)
+ r.close()
+ self.assertGreaterEqual(len(data), 200)
+
+ def test_sock_client_ops(self):
+ with test_utils.run_test_server() as httpd:
+ sock = socket.socket()
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ data = self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ # consume data
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ sock.close()
+
+ self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+
+ def test_sock_client_fail(self):
+ # Make sure that we will get an unused port
+ address = None
+ try:
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ address = s.getsockname()
+ finally:
+ s.close()
+
+ sock = socket.socket()
+ sock.setblocking(False)
+ with self.assertRaises(ConnectionRefusedError):
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ sock.close()
+
+ def test_sock_accept(self):
+ listener = socket.socket()
+ listener.setblocking(False)
+ listener.bind(('127.0.0.1', 0))
+ listener.listen(1)
+ client = socket.socket()
+ client.connect(listener.getsockname())
+
+ f = self.loop.sock_accept(listener)
+ conn, addr = self.loop.run_until_complete(f)
+ self.assertEqual(conn.gettimeout(), 0)
+ self.assertEqual(addr, client.getsockname())
+ self.assertEqual(client.getpeername(), listener.getsockname())
+ client.close()
+ conn.close()
+ listener.close()
+
+ @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
+ def test_add_signal_handler(self):
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ # Check error behavior first.
+ self.assertRaises(
+ TypeError, self.loop.add_signal_handler, 'boom', my_handler)
+ self.assertRaises(
+ TypeError, self.loop.remove_signal_handler, 'boom')
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, signal.NSIG+1,
+ my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, signal.NSIG+1)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, 0, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, 0)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, -1, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, -1)
+ self.assertRaises(
+ RuntimeError, self.loop.add_signal_handler, signal.SIGKILL,
+ my_handler)
+ # Removing SIGKILL doesn't raise, since we don't call signal().
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL))
+ # Now set a handler and handle it.
+ self.loop.add_signal_handler(signal.SIGINT, my_handler)
+ test_utils.run_briefly(self.loop)
+ os.kill(os.getpid(), signal.SIGINT)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(caught, 1)
+ # Removing it should restore the default handler.
+ self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertEqual(signal.getsignal(signal.SIGINT),
+ signal.default_int_handler)
+ # Removing again returns False.
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT))
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_while_selecting(self):
+ # Test with a signal actually arriving during a select() call.
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+ self.loop.stop()
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once.
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_args(self):
+ some_args = (42,)
+ caught = 0
+
+ def my_handler(*args):
+ nonlocal caught
+ caught += 1
+ self.assertEqual(args, some_args)
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once.
+ self.loop.call_later(0.5, self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ def test_create_connection(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertIsInstance(tr, transports.Transport)
+ self.assertIsInstance(pr, protocols.Protocol)
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection_sock(self):
+ with test_utils.run_test_server() as httpd:
+ sock = None
+ infos = self.loop.run_until_complete(
+ self.loop.getaddrinfo(
+ *httpd.address, type=socket.SOCK_STREAM))
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ except:
+ pass
+ else:
+ break
+ else:
+ assert False, 'Can not create socket.'
+
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), sock=sock)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertIsInstance(tr, transports.Transport)
+ self.assertIsInstance(pr, protocols.Protocol)
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_ssl_connection(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address,
+ ssl=test_utils.dummy_ssl_context())
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertIsInstance(tr, transports.Transport)
+ self.assertIsInstance(pr, protocols.Protocol)
+ self.assertTrue('ssl' in tr.__class__.__name__.lower())
+ self.assertIsNotNone(tr.get_extra_info('sockname'))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection_local_addr(self):
+ with test_utils.run_test_server() as httpd:
+ port = find_unused_port()
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=(httpd.address[0], port))
+ tr, pr = self.loop.run_until_complete(f)
+ expected = pr.transport.get_extra_info('sockname')[1]
+ self.assertEqual(port, expected)
+ tr.close()
+
+ def test_create_connection_local_addr_in_use(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=httpd.address)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+ self.assertIn(str(httpd.address), cm.exception.strerror)
+
+ def test_create_server(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyProto()
+ return proto
+
+ f = self.loop.create_server(factory, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.sendall(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ self.assertEqual('INITIAL', proto.state)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
+ timeout=10)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ test_utils.run_briefly(self.loop) # windows iocp
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # close server
+ server.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl(self):
+ proto = None
+
+ class ClientMyProto(MyProto):
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def factory():
+ nonlocal proto
+ proto = MyProto(loop=self.loop)
+ return proto
+
+ here = os.path.dirname(__file__)
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.load_cert_chain(
+ certfile=os.path.join(here, 'sample.crt'),
+ keyfile=os.path.join(here, 'sample.key'))
+
+ f = self.loop.create_server(
+ factory, '127.0.0.1', 0, ssl=sslcontext)
+
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '127.0.0.1')
+
+ f_c = self.loop.create_connection(ClientMyProto, host, port,
+ ssl=test_utils.dummy_ssl_context())
+ client, pr = self.loop.run_until_complete(f_c)
+
+ client.write(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
+ timeout=10)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # stop serving
+ server.close()
+
+ def test_create_server_sock(self):
+ proto = futures.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ proto.set_result(self)
+
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(TestMyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ self.assertIs(sock, sock_ob)
+
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+ server.close()
+
+ def test_create_server_addr_in_use(self):
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(MyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ f = self.loop.create_server(MyProto, host=host, port=port)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+
+ server.close()
+
+ @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
+ def test_create_server_dual_stack(self):
+ f_proto = futures.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ f_proto.set_result(self)
+
+ try_count = 0
+ while True:
+ try:
+ port = find_unused_port()
+ f = self.loop.create_server(TestMyProto, host=None, port=port)
+ server = self.loop.run_until_complete(f)
+ except OSError as ex:
+ if ex.errno == errno.EADDRINUSE:
+ try_count += 1
+ self.assertGreaterEqual(5, try_count)
+ continue
+ else:
+ raise
+ else:
+ break
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ f_proto = futures.Future(loop=self.loop)
+ client = socket.socket(socket.AF_INET6)
+ client.connect(('::1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ server.close()
+
+ def test_server_close(self):
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+
+ server.close()
+
+ client = socket.socket()
+ self.assertRaises(
+ ConnectionRefusedError, client.connect, ('127.0.0.1', port))
+ client.close()
+
+ def test_create_datagram_endpoint(self):
+ class TestMyDatagramProto(MyDatagramProto):
+ def __init__(inner_self):
+ super().__init__(loop=self.loop)
+
+ def datagram_received(self, data, addr):
+ super().datagram_received(data, addr)
+ self.transport.sendto(b'resp:'+data, addr)
+
+ coro = self.loop.create_datagram_endpoint(
+ TestMyDatagramProto, local_addr=('127.0.0.1', 0))
+ s_transport, server = self.loop.run_until_complete(coro)
+ host, port = s_transport.get_extra_info('sockname')
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ remote_addr=(host, port))
+ transport, client = self.loop.run_until_complete(coro)
+
+ self.assertEqual('INITIALIZED', client.state)
+ transport.sendto(b'xxx')
+ for _ in range(1000):
+ if server.nbytes:
+ break
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(3, server.nbytes)
+ for _ in range(1000):
+ if client.nbytes:
+ break
+ test_utils.run_briefly(self.loop)
+
+ # received
+ self.assertEqual(8, client.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(transport.get_extra_info('sockname'))
+
+ # close connection
+ transport.close()
+ self.loop.run_until_complete(client.done)
+ self.assertEqual('CLOSED', client.state)
+ server.transport.close()
+
+ def test_internal_fds(self):
+ loop = self.create_event_loop()
+ if not isinstance(loop, selector_events.BaseSelectorEventLoop):
+ return
+
+ self.assertEqual(1, loop._internal_fds)
+ loop.close()
+ self.assertEqual(0, loop._internal_fds)
+ self.assertIsNone(loop._csock)
+ self.assertIsNone(loop._ssock)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_read_pipe(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyReadPipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(rpipe, 'rb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(0, proto.nbytes)
+
+ self.loop.run_until_complete(connect())
+
+ os.write(wpipe, b'1')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(1, proto.nbytes)
+
+ os.write(wpipe, b'2345')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(5, proto.nbytes)
+
+ os.close(wpipe)
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual(
+ ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ nonlocal transport
+ t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.loop.run_until_complete(connect())
+
+ transport.write(b'1')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'1', data)
+
+ transport.write(b'2345')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'2345', data)
+ self.assertEqual('CONNECTED', proto.state)
+
+ os.close(rpipe)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe_disconnect_on_close(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto(loop=self.loop)
+ return proto
+
+ rsock, wsock = self.loop._socketpair()
+ pipeobj = io.open(wsock.detach(), 'wb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ nonlocal transport
+ t, p = yield from self.loop.connect_write_pipe(factory,
+ pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.loop.run_until_complete(connect())
+ self.assertEqual('CONNECTED', proto.state)
+
+ transport.write(b'1')
+ data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024))
+ self.assertEqual(b'1', data)
+
+ rsock.close()
+
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ def test_prompt_cancellation(self):
+ r, w = test_utils.socketpair()
+ r.setblocking(False)
+ f = self.loop.sock_recv(r, 1)
+ ov = getattr(f, 'ov', None)
+ if ov is not None:
+ self.assertTrue(ov.pending)
+
+ @tasks.coroutine
+ def main():
+ try:
+ self.loop.call_soon(f.cancel)
+ yield from f
+ except futures.CancelledError:
+ res = 'cancelled'
+ else:
+ res = None
+ finally:
+ self.loop.stop()
+ return res
+
+ start = time.monotonic()
+ t = tasks.Task(main(), loop=self.loop)
+ self.loop.run_forever()
+ elapsed = time.monotonic() - start
+
+ self.assertLess(elapsed, 0.1)
+ self.assertEqual(t.result(), 'cancelled')
+ self.assertRaises(futures.CancelledError, f.result)
+ if ov is not None:
+ self.assertFalse(ov.pending)
+ self.loop._stop_serving(r)
+
+ r.close()
+ w.close()
+
+
+class SubprocessTestsMixin:
+
+ def check_terminated(self, returncode):
+ if sys.platform == 'win32':
+ self.assertIsInstance(returncode, int)
+ # expect 1 but sometimes get 0
+ else:
+ self.assertEqual(-signal.SIGTERM, returncode)
+
+ def check_killed(self, returncode):
+ if sys.platform == 'win32':
+ self.assertIsInstance(returncode, int)
+ # expect 1 but sometimes get 0
+ else:
+ self.assertEqual(-signal.SIGKILL, returncode)
+
+ def test_subprocess_exec(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.check_terminated(proto.returncode)
+ self.assertEqual(b'Python The Winner', proto.data[1])
+
+ def test_subprocess_interactive(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ try:
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python ')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ proto.got_data[1].clear()
+ self.assertEqual(b'Python ', proto.data[1])
+
+ stdin.write(b'The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'Python The Winner', proto.data[1])
+ finally:
+ transp.close()
+
+ self.loop.run_until_complete(proto.completed)
+ self.check_terminated(proto.returncode)
+
+ def test_subprocess_shell(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'echo Python')
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.get_pipe_transport(0).close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(0, proto.returncode)
+ self.assertTrue(all(f.done() for f in proto.disconnects.values()))
+ self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python')
+ self.assertEqual(proto.data[2], b'')
+
+ def test_subprocess_exitcode(self):
+ proto = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+
+ def test_subprocess_close_after_finish(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.assertIsNone(transp.get_pipe_transport(0))
+ self.assertIsNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+ self.assertIsNone(transp.close())
+
+ def test_subprocess_kill(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.kill()
+ self.loop.run_until_complete(proto.completed)
+ self.check_killed(proto.returncode)
+
+ def test_subprocess_terminate(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.terminate()
+ self.loop.run_until_complete(proto.completed)
+ self.check_terminated(proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP")
+ def test_subprocess_send_signal(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.send_signal(signal.SIGHUP)
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGHUP, proto.returncode)
+
+ def test_subprocess_stderr(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'test')
+
+ self.loop.run_until_complete(proto.completed)
+
+ transp.close()
+ self.assertEqual(b'OUT:test', proto.data[1])
+ self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2])
+ self.assertEqual(0, proto.returncode)
+
+ def test_subprocess_stderr_redirect_to_stdout(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog, stderr=subprocess.STDOUT)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ self.assertIsNotNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.completed)
+ self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'),
+ proto.data[1])
+ self.assertEqual(b'', proto.data[2])
+
+ transp.close()
+ self.assertEqual(0, proto.returncode)
+
+ def test_subprocess_close_client_stream(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo3.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdout = transp.get_pipe_transport(1)
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'OUT:test', proto.data[1])
+
+ stdout.close()
+ self.loop.run_until_complete(proto.disconnects[1])
+ stdin.write(b'xxx')
+ self.loop.run_until_complete(proto.got_data[2].wait())
+ if sys.platform != 'win32':
+ self.assertEqual(b'ERR:BrokenPipeError', proto.data[2])
+ else:
+ # After closing the read-end of a pipe, writing to the
+ # write-end using os.write() fails with errno==EINVAL and
+ # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using
+ # WriteFile() we get ERROR_BROKEN_PIPE as expected.)
+ self.assertEqual(b'ERR:OSError', proto.data[2])
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.check_terminated(proto.returncode)
+
+ def test_subprocess_wait_no_same_group(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto
+ # start the new process in a new session
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None,
+ start_new_session=True)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+
+
+if sys.platform == 'win32':
+ from asyncio import windows_events
+
+ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return windows_events.SelectorEventLoop()
+
+ class ProactorEventLoopTests(EventLoopTestsMixin,
+ SubprocessTestsMixin,
+ unittest.TestCase):
+
+ def create_event_loop(self):
+ return windows_events.ProactorEventLoop()
+
+ def test_create_ssl_connection(self):
+ raise unittest.SkipTest("IocpEventLoop imcompatible with SSL")
+
+ def test_create_server_ssl(self):
+ raise unittest.SkipTest("IocpEventLoop imcompatible with SSL")
+
+ def test_reader_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_reader_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_writer_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_writer_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_create_datagram_endpoint(self):
+ raise unittest.SkipTest(
+ "IocpEventLoop does not have create_datagram_endpoint()")
+else:
+ from asyncio import selectors
+ from asyncio import unix_events
+
+ class UnixEventLoopTestsMixin(EventLoopTestsMixin):
+ def setUp(self):
+ super().setUp()
+ watcher = unix_events.SafeChildWatcher()
+ watcher.attach_loop(self.loop)
+ events.set_child_watcher(watcher)
+
+ def tearDown(self):
+ events.set_child_watcher(None)
+ super().tearDown()
+
+ if hasattr(selectors, 'KqueueSelector'):
+ class KqueueEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(
+ selectors.KqueueSelector())
+
+ if hasattr(selectors, 'EpollSelector'):
+ class EPollEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.EpollSelector())
+
+ if hasattr(selectors, 'PollSelector'):
+ class PollEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.PollSelector())
+
+ # Should always exist.
+ class SelectEventLoopTests(UnixEventLoopTestsMixin,
+ SubprocessTestsMixin,
+ unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.SelectSelector())
+
+
+class HandleTests(unittest.TestCase):
+
+ def test_handle(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ h = events.Handle(callback, args)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())<cancelled>'), r)
+
+ def test_make_handle(self):
+ def callback(*args):
+ return args
+ h1 = events.Handle(callback, ())
+ self.assertRaises(
+ AssertionError, events.make_handle, h1, ())
+
+ @unittest.mock.patch('asyncio.events.logger')
+ def test_callback_with_exception(self, log):
+ def callback():
+ raise ValueError()
+
+ h = events.Handle(callback, ())
+ h._run()
+ self.assertTrue(log.exception.called)
+
+
+class TimerTests(unittest.TestCase):
+
+ def test_hash(self):
+ when = time.monotonic()
+ h = events.TimerHandle(when, lambda: False, ())
+ self.assertEqual(hash(h), hash(when))
+
+ def test_timer(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ when = time.monotonic()
+ h = events.TimerHandle(when, callback, args)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())<cancelled>'), r)
+
+ self.assertRaises(AssertionError,
+ events.TimerHandle, None, callback, args)
+
+ def test_timer_comparison(self):
+ def callback(*args):
+ return args
+
+ when = time.monotonic()
+
+ h1 = events.TimerHandle(when, callback, ())
+ h2 = events.TimerHandle(when, callback, ())
+ # TODO: Use assertLess etc.
+ self.assertFalse(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertTrue(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertFalse(h2 > h1)
+ self.assertTrue(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertTrue(h1 == h2)
+ self.assertFalse(h1 != h2)
+
+ h2.cancel()
+ self.assertFalse(h1 == h2)
+
+ h1 = events.TimerHandle(when, callback, ())
+ h2 = events.TimerHandle(when + 10.0, callback, ())
+ self.assertTrue(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertFalse(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertTrue(h2 > h1)
+ self.assertFalse(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h3 = events.Handle(callback, ())
+ self.assertIs(NotImplemented, h1.__eq__(h3))
+ self.assertIs(NotImplemented, h1.__ne__(h3))
+
+
+class AbstractEventLoopTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = unittest.mock.Mock()
+ loop = events.AbstractEventLoop()
+ self.assertRaises(
+ NotImplementedError, loop.run_forever)
+ self.assertRaises(
+ NotImplementedError, loop.run_until_complete, None)
+ self.assertRaises(
+ NotImplementedError, loop.stop)
+ self.assertRaises(
+ NotImplementedError, loop.is_running)
+ self.assertRaises(
+ NotImplementedError, loop.close)
+ self.assertRaises(
+ NotImplementedError, loop.call_later, None, None)
+ self.assertRaises(
+ NotImplementedError, loop.call_at, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon, None)
+ self.assertRaises(
+ NotImplementedError, loop.time)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon_threadsafe, None)
+ self.assertRaises(
+ NotImplementedError, loop.run_in_executor, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.set_default_executor, f)
+ self.assertRaises(
+ NotImplementedError, loop.getaddrinfo, 'localhost', 8080)
+ self.assertRaises(
+ NotImplementedError, loop.getnameinfo, ('localhost', 8080))
+ self.assertRaises(
+ NotImplementedError, loop.create_connection, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_server, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_datagram_endpoint, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_reader, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_reader, 1)
+ self.assertRaises(
+ NotImplementedError, loop.add_writer, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_writer, 1)
+ self.assertRaises(
+ NotImplementedError, loop.sock_recv, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_sendall, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_connect, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.sock_accept, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_signal_handler, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.connect_read_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.connect_write_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_shell, f,
+ unittest.mock.sentinel)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_exec, f)
+
+
+class ProtocolsAbsTests(unittest.TestCase):
+
+ def test_empty(self):
+ f = unittest.mock.Mock()
+ p = protocols.Protocol()
+ self.assertIsNone(p.connection_made(f))
+ self.assertIsNone(p.connection_lost(f))
+ self.assertIsNone(p.data_received(f))
+ self.assertIsNone(p.eof_received())
+
+ dp = protocols.DatagramProtocol()
+ self.assertIsNone(dp.connection_made(f))
+ self.assertIsNone(dp.connection_lost(f))
+ self.assertIsNone(dp.error_received(f))
+ self.assertIsNone(dp.datagram_received(f, f))
+
+ sp = protocols.SubprocessProtocol()
+ self.assertIsNone(sp.connection_made(f))
+ self.assertIsNone(sp.connection_lost(f))
+ self.assertIsNone(sp.pipe_data_received(1, f))
+ self.assertIsNone(sp.pipe_connection_lost(1, f))
+ self.assertIsNone(sp.process_exited())
+
+
+class PolicyTests(unittest.TestCase):
+
+ def create_policy(self):
+ if sys.platform == "win32":
+ from asyncio import windows_events
+ return windows_events.DefaultEventLoopPolicy()
+ else:
+ from asyncio import unix_events
+ return unix_events.DefaultEventLoopPolicy()
+
+ def test_event_loop_policy(self):
+ policy = events.AbstractEventLoopPolicy()
+ self.assertRaises(NotImplementedError, policy.get_event_loop)
+ self.assertRaises(NotImplementedError, policy.set_event_loop, object())
+ self.assertRaises(NotImplementedError, policy.new_event_loop)
+ self.assertRaises(NotImplementedError, policy.get_child_watcher)
+ self.assertRaises(NotImplementedError, policy.set_child_watcher,
+ object())
+
+ def test_get_event_loop(self):
+ policy = self.create_policy()
+ self.assertIsNone(policy._local._loop)
+
+ loop = policy.get_event_loop()
+ self.assertIsInstance(loop, events.AbstractEventLoop)
+
+ self.assertIs(policy._local._loop, loop)
+ self.assertIs(loop, policy.get_event_loop())
+ loop.close()
+
+ def test_get_event_loop_after_set_none(self):
+ policy = self.create_policy()
+ policy.set_event_loop(None)
+ self.assertRaises(AssertionError, policy.get_event_loop)
+
+ @unittest.mock.patch('asyncio.events.threading.current_thread')
+ def test_get_event_loop_thread(self, m_current_thread):
+
+ def f():
+ policy = self.create_policy()
+ self.assertRaises(AssertionError, policy.get_event_loop)
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ def test_new_event_loop(self):
+ policy = self.create_policy()
+
+ loop = policy.new_event_loop()
+ self.assertIsInstance(loop, events.AbstractEventLoop)
+ loop.close()
+
+ def test_set_event_loop(self):
+ policy = self.create_policy()
+ old_loop = policy.get_event_loop()
+
+ self.assertRaises(AssertionError, policy.set_event_loop, object())
+
+ loop = policy.new_event_loop()
+ policy.set_event_loop(loop)
+ self.assertIs(loop, policy.get_event_loop())
+ self.assertIsNot(old_loop, policy.get_event_loop())
+ loop.close()
+ old_loop.close()
+
+ def test_get_event_loop_policy(self):
+ policy = events.get_event_loop_policy()
+ self.assertIsInstance(policy, events.AbstractEventLoopPolicy)
+ self.assertIs(policy, events.get_event_loop_policy())
+
+ def test_set_event_loop_policy(self):
+ self.assertRaises(
+ AssertionError, events.set_event_loop_policy, object())
+
+ old_policy = events.get_event_loop_policy()
+
+ policy = self.create_policy()
+ events.set_event_loop_policy(policy)
+ self.assertIs(policy, events.get_event_loop_policy())
+ self.assertIsNot(policy, old_policy)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
new file mode 100644
index 0000000000..ccea2ffded
--- /dev/null
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -0,0 +1,329 @@
+"""Tests for futures.py."""
+
+import concurrent.futures
+import threading
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import test_utils
+
+
+def _fakefunc(f):
+ return f
+
+
+class FutureTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_initial_state(self):
+ f = futures.Future(loop=self.loop)
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.done())
+ f.cancel()
+ self.assertTrue(f.cancelled())
+
+ def test_init_constructor_default_loop(self):
+ try:
+ events.set_event_loop(self.loop)
+ f = futures.Future()
+ self.assertIs(f._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_constructor_positional(self):
+ # Make sure Future does't accept a positional argument
+ self.assertRaises(TypeError, futures.Future, 42)
+
+ def test_cancel(self):
+ f = futures.Future(loop=self.loop)
+ self.assertTrue(f.cancel())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(futures.CancelledError, f.result)
+ self.assertRaises(futures.CancelledError, f.exception)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_result(self):
+ f = futures.Future(loop=self.loop)
+ self.assertRaises(futures.InvalidStateError, f.result)
+
+ f.set_result(42)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 42)
+ self.assertEqual(f.exception(), None)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_exception(self):
+ exc = RuntimeError()
+ f = futures.Future(loop=self.loop)
+ self.assertRaises(futures.InvalidStateError, f.exception)
+
+ f.set_exception(exc)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(RuntimeError, f.result)
+ self.assertEqual(f.exception(), exc)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_yield_from_twice(self):
+ f = futures.Future(loop=self.loop)
+
+ def fixture():
+ yield 'A'
+ x = yield from f
+ yield 'B', x
+ y = yield from f
+ yield 'C', y
+
+ g = fixture()
+ self.assertEqual(next(g), 'A') # yield 'A'.
+ self.assertEqual(next(g), f) # First yield from f.
+ f.set_result(42)
+ self.assertEqual(next(g), ('B', 42)) # yield 'B', x.
+ # The second "yield from f" does not yield f.
+ self.assertEqual(next(g), ('C', 42)) # yield 'C', y.
+
+ def test_repr(self):
+ f_pending = futures.Future(loop=self.loop)
+ self.assertEqual(repr(f_pending), 'Future<PENDING>')
+ f_pending.cancel()
+
+ f_cancelled = futures.Future(loop=self.loop)
+ f_cancelled.cancel()
+ self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>')
+
+ f_result = futures.Future(loop=self.loop)
+ f_result.set_result(4)
+ self.assertEqual(repr(f_result), 'Future<result=4>')
+ self.assertEqual(f_result.result(), 4)
+
+ exc = RuntimeError()
+ f_exception = futures.Future(loop=self.loop)
+ f_exception.set_exception(exc)
+ self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>')
+ self.assertIs(f_exception.exception(), exc)
+
+ f_few_callbacks = futures.Future(loop=self.loop)
+ f_few_callbacks.add_done_callback(_fakefunc)
+ self.assertIn('Future<PENDING, [<function _fakefunc',
+ repr(f_few_callbacks))
+ f_few_callbacks.cancel()
+
+ f_many_callbacks = futures.Future(loop=self.loop)
+ for i in range(20):
+ f_many_callbacks.add_done_callback(_fakefunc)
+ r = repr(f_many_callbacks)
+ self.assertIn('Future<PENDING, [<function _fakefunc', r)
+ self.assertIn('<18 more>', r)
+ f_many_callbacks.cancel()
+
+ def test_copy_state(self):
+ # Test the internal _copy_state method since it's being directly
+ # invoked in other modules.
+ f = futures.Future(loop=self.loop)
+ f.set_result(10)
+
+ newf = futures.Future(loop=self.loop)
+ newf._copy_state(f)
+ self.assertTrue(newf.done())
+ self.assertEqual(newf.result(), 10)
+
+ f_exception = futures.Future(loop=self.loop)
+ f_exception.set_exception(RuntimeError())
+
+ newf_exception = futures.Future(loop=self.loop)
+ newf_exception._copy_state(f_exception)
+ self.assertTrue(newf_exception.done())
+ self.assertRaises(RuntimeError, newf_exception.result)
+
+ f_cancelled = futures.Future(loop=self.loop)
+ f_cancelled.cancel()
+
+ newf_cancelled = futures.Future(loop=self.loop)
+ newf_cancelled._copy_state(f_cancelled)
+ self.assertTrue(newf_cancelled.cancelled())
+
+ def test_iter(self):
+ fut = futures.Future(loop=self.loop)
+
+ def coro():
+ yield from fut
+
+ def test():
+ arg1, arg2 = coro()
+
+ self.assertRaises(AssertionError, test)
+ fut.cancel()
+
+ @unittest.mock.patch('asyncio.futures.logger')
+ def test_tb_logger_abandoned(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.logger')
+ def test_tb_logger_result_unretrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.logger')
+ def test_tb_logger_result_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ fut.result()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.logger')
+ def test_tb_logger_exception_unretrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ del fut
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.logger')
+ def test_tb_logger_exception_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ fut.exception()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.logger')
+ def test_tb_logger_exception_result_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ self.assertRaises(RuntimeError, fut.result)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ def test_wrap_future(self):
+
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = futures.wrap_future(f1, loop=self.loop)
+ res, ident = self.loop.run_until_complete(f2)
+ self.assertIsInstance(f2, futures.Future)
+ self.assertEqual(res, 'oi')
+ self.assertNotEqual(ident, threading.get_ident())
+
+ def test_wrap_future_future(self):
+ f1 = futures.Future(loop=self.loop)
+ f2 = futures.wrap_future(f1)
+ self.assertIs(f1, f2)
+
+ @unittest.mock.patch('asyncio.futures.events')
+ def test_wrap_future_use_global_loop(self, m_events):
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = futures.wrap_future(f1)
+ self.assertIs(m_events.get_event_loop.return_value, f2._loop)
+
+
+class FutureDoneCallbackTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def run_briefly(self):
+ test_utils.run_briefly(self.loop)
+
+ def _make_callback(self, bag, thing):
+ # Create a callback function that appends thing to bag.
+ def bag_appender(future):
+ bag.append(thing)
+ return bag_appender
+
+ def _new_future(self):
+ return futures.Future(loop=self.loop)
+
+ def test_callbacks_invoked_on_set_result(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 42))
+ f.add_done_callback(self._make_callback(bag, 17))
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [42, 17])
+ self.assertEqual(f.result(), 'foo')
+
+ def test_callbacks_invoked_on_set_exception(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 100))
+
+ self.assertEqual(bag, [])
+ exc = RuntimeError()
+ f.set_exception(exc)
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [100])
+ self.assertEqual(f.exception(), exc)
+
+ def test_remove_done_callback(self):
+ bag = []
+ f = self._new_future()
+ cb1 = self._make_callback(bag, 1)
+ cb2 = self._make_callback(bag, 2)
+ cb3 = self._make_callback(bag, 3)
+
+ # Add one cb1 and one cb2.
+ f.add_done_callback(cb1)
+ f.add_done_callback(cb2)
+
+ # One instance of cb2 removed. Now there's only one cb1.
+ self.assertEqual(f.remove_done_callback(cb2), 1)
+
+ # Never had any cb3 in there.
+ self.assertEqual(f.remove_done_callback(cb3), 0)
+
+ # After this there will be 6 instances of cb1 and one of cb2.
+ f.add_done_callback(cb2)
+ for i in range(5):
+ f.add_done_callback(cb1)
+
+ # Remove all instances of cb1. One cb2 remains.
+ self.assertEqual(f.remove_done_callback(cb1), 6)
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [2])
+ self.assertEqual(f.result(), 'foo')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
new file mode 100644
index 0000000000..19ef877af0
--- /dev/null
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -0,0 +1,834 @@
+"""Tests for lock.py"""
+
+import unittest
+import unittest.mock
+import re
+
+from asyncio import events
+from asyncio import futures
+from asyncio import locks
+from asyncio import tasks
+from asyncio import test_utils
+
+
+STR_RGX_REPR = (
+ r'^<(?P<class>.*?) object at (?P<address>.*?)'
+ r'\[(?P<extras>'
+ r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?'
+ r')\]>\Z'
+)
+RGX_REPR = re.compile(STR_RGX_REPR)
+
+
+class LockTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ lock = locks.Lock(loop=loop)
+ self.assertIs(lock._loop, loop)
+
+ lock = locks.Lock(loop=self.loop)
+ self.assertIs(lock._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ lock = locks.Lock()
+ self.assertIs(lock._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ lock = locks.Lock(loop=self.loop)
+ self.assertTrue(repr(lock).endswith('[unlocked]>'))
+ self.assertTrue(RGX_REPR.match(repr(lock)))
+
+ @tasks.coroutine
+ def acquire_lock():
+ yield from lock
+
+ self.loop.run_until_complete(acquire_lock())
+ self.assertTrue(repr(lock).endswith('[locked]>'))
+ self.assertTrue(RGX_REPR.match(repr(lock)))
+
+ def test_lock(self):
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_acquire(self):
+ lock = locks.Lock(loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from lock.acquire()):
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from lock.acquire()):
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from lock.acquire()):
+ result.append(3)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_acquire_cancel(self):
+ lock = locks.Lock(loop=self.loop)
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ task = tasks.Task(lock.acquire(), loop=self.loop)
+ self.loop.call_soon(task.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, task)
+ self.assertFalse(lock._waiters)
+
+ def test_cancel_race(self):
+ # Several tasks:
+ # - A acquires the lock
+ # - B is blocked in aqcuire()
+ # - C is blocked in aqcuire()
+ #
+ # Now, concurrently:
+ # - B is cancelled
+ # - A releases the lock
+ #
+ # If B's waiter is marked cancelled but not yet removed from
+ # _waiters, A's release() call will crash when trying to set
+ # B's waiter; instead, it should move on to C's waiter.
+
+ # Setup: A has the lock, b and c are waiting.
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def lockit(name, blocker):
+ yield from lock.acquire()
+ try:
+ if blocker is not None:
+ yield from blocker
+ finally:
+ lock.release()
+
+ fa = futures.Future(loop=self.loop)
+ ta = tasks.Task(lockit('A', fa), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(lock.locked())
+ tb = tasks.Task(lockit('B', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 1)
+ tc = tasks.Task(lockit('C', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 2)
+
+ # Create the race and check.
+ # Without the fix this failed at the last assert.
+ fa.set_result(None)
+ tb.cancel()
+ self.assertTrue(lock._waiters[0].cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(lock.locked())
+ self.assertTrue(ta.done())
+ self.assertTrue(tb.cancelled())
+ self.assertTrue(tc.done())
+
+ def test_release_not_acquired(self):
+ lock = locks.Lock(loop=self.loop)
+
+ self.assertRaises(RuntimeError, lock.release)
+
+ def test_release_no_waiters(self):
+ lock = locks.Lock(loop=self.loop)
+ self.loop.run_until_complete(lock.acquire())
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_context_manager(self):
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ def test_context_manager_no_yield(self):
+ lock = locks.Lock(loop=self.loop)
+
+ try:
+ with lock:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+
+class EventTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ ev = locks.Event(loop=loop)
+ self.assertIs(ev._loop, loop)
+
+ ev = locks.Event(loop=self.loop)
+ self.assertIs(ev._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ ev = locks.Event()
+ self.assertIs(ev._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertTrue(repr(ev).endswith('[unset]>'))
+ match = RGX_REPR.match(repr(ev))
+ self.assertEqual(match.group('extras'), 'unset')
+
+ ev.set()
+ self.assertTrue(repr(ev).endswith('[set]>'))
+ self.assertTrue(RGX_REPR.match(repr(ev)))
+
+ ev._waiters.append(unittest.mock.Mock())
+ self.assertTrue('waiters:1' in repr(ev))
+ self.assertTrue(RGX_REPR.match(repr(ev)))
+
+ def test_wait(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from ev.wait()):
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from ev.wait()):
+ result.append(3)
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ ev.set()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([3, 1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertIsNone(t1.result())
+ self.assertTrue(t2.done())
+ self.assertIsNone(t2.result())
+ self.assertTrue(t3.done())
+ self.assertIsNone(t3.result())
+
+ def test_wait_on_set(self):
+ ev = locks.Event(loop=self.loop)
+ ev.set()
+
+ res = self.loop.run_until_complete(ev.wait())
+ self.assertTrue(res)
+
+ def test_wait_cancel(self):
+ ev = locks.Event(loop=self.loop)
+
+ wait = tasks.Task(ev.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(ev._waiters)
+
+ def test_clear(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ self.assertTrue(ev.is_set())
+
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ def test_clear_with_waiters(self):
+ ev = locks.Event(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+ return True
+
+ t = tasks.Task(c1(result), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ ev.set()
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ ev.set()
+ self.assertEqual(1, len(ev._waiters))
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertEqual(0, len(ev._waiters))
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+
+class ConditionTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ cond = locks.Condition(loop=loop)
+ self.assertIs(cond._loop, loop)
+
+ cond = locks.Condition(loop=self.loop)
+ self.assertIs(cond._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ cond = locks.Condition()
+ self.assertIs(cond._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_wait(self):
+ cond = locks.Condition(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertFalse(cond.locked())
+
+ self.assertTrue(self.loop.run_until_complete(cond.acquire()))
+ cond.notify()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.notify(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(cond.locked())
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_wait_cancel(self):
+ cond = locks.Condition(loop=self.loop)
+ self.loop.run_until_complete(cond.acquire())
+
+ wait = tasks.Task(cond.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(cond._waiters)
+ self.assertTrue(cond.locked())
+
+ def test_wait_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, cond.wait())
+
+ def test_wait_for(self):
+ cond = locks.Condition(loop=self.loop)
+ presult = False
+
+ def predicate():
+ return presult
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait_for(predicate)):
+ result.append(1)
+ cond.release()
+ return True
+
+ t = tasks.Task(c1(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ presult = True
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_wait_for_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+
+ # predicate can return true immediately
+ res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3]))
+ self.assertEqual([1, 2, 3], res)
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete,
+ cond.wait_for(lambda: False))
+
+ def test_notify(self):
+ cond = locks.Condition(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ cond.release()
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.notify(2048)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_notify_all(self):
+ cond = locks.Condition(loop=self.loop)
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify_all()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+
+ def test_notify_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify)
+
+ def test_notify_all_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify_all)
+
+ def test_repr(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertTrue('unlocked' in repr(cond))
+ self.assertTrue(RGX_REPR.match(repr(cond)))
+
+ self.loop.run_until_complete(cond.acquire())
+ self.assertTrue('locked' in repr(cond))
+
+ cond._waiters.append(unittest.mock.Mock())
+ self.assertTrue('waiters:1' in repr(cond))
+ self.assertTrue(RGX_REPR.match(repr(cond)))
+
+ cond._waiters.append(unittest.mock.Mock())
+ self.assertTrue('waiters:2' in repr(cond))
+ self.assertTrue(RGX_REPR.match(repr(cond)))
+
+ def test_context_manager(self):
+ cond = locks.Condition(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_cond():
+ return (yield from cond)
+
+ with self.loop.run_until_complete(acquire_cond()):
+ self.assertTrue(cond.locked())
+
+ self.assertFalse(cond.locked())
+
+ def test_context_manager_no_yield(self):
+ cond = locks.Condition(loop=self.loop)
+
+ try:
+ with cond:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+
+class SemaphoreTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ sem = locks.Semaphore(loop=loop)
+ self.assertIs(sem._loop, loop)
+
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertIs(sem._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ sem = locks.Semaphore()
+ self.assertIs(sem._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(repr(sem).endswith('[locked]>'))
+ self.assertTrue('waiters' not in repr(sem))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ sem._waiters.append(unittest.mock.Mock())
+ self.assertTrue('waiters:1' in repr(sem))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ sem._waiters.append(unittest.mock.Mock())
+ self.assertTrue('waiters:2' in repr(sem))
+ self.assertTrue(RGX_REPR.match(repr(sem)))
+
+ def test_semaphore(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertEqual(1, sem._value)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(sem.locked())
+ self.assertEqual(0, sem._value)
+
+ sem.release()
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ def test_semaphore_value(self):
+ self.assertRaises(ValueError, locks.Semaphore, -1)
+
+ def test_acquire(self):
+ sem = locks.Semaphore(3, loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertFalse(sem.locked())
+
+ @tasks.coroutine
+ def c1(result):
+ yield from sem.acquire()
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from sem.acquire()
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from sem.acquire()
+ result.append(3)
+ return True
+
+ @tasks.coroutine
+ def c4(result):
+ yield from sem.acquire()
+ result.append(4)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(2, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ t4 = tasks.Task(c4(result), loop=self.loop)
+
+ sem.release()
+ sem.release()
+ self.assertEqual(2, sem._value)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(0, sem._value)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(1, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+ self.assertFalse(t4.done())
+
+ # cleanup locked semaphore
+ sem.release()
+
+ def test_acquire_cancel(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+
+ acquire = tasks.Task(sem.acquire(), loop=self.loop)
+ self.loop.call_soon(acquire.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, acquire)
+ self.assertFalse(sem._waiters)
+
+ def test_release_not_acquired(self):
+ sem = locks.Semaphore(bound=True, loop=self.loop)
+
+ self.assertRaises(ValueError, sem.release)
+
+ def test_release_no_waiters(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(sem.locked())
+
+ sem.release()
+ self.assertFalse(sem.locked())
+
+ def test_context_manager(self):
+ sem = locks.Semaphore(2, loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(sem.locked())
+
+ self.assertEqual(2, sem._value)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
new file mode 100644
index 0000000000..5a2a51c42e
--- /dev/null
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -0,0 +1,484 @@
+"""Tests for proactor_events.py"""
+
+import socket
+import unittest
+import unittest.mock
+
+import asyncio
+from asyncio.proactor_events import BaseProactorEventLoop
+from asyncio.proactor_events import _ProactorSocketTransport
+from asyncio.proactor_events import _ProactorWritePipeTransport
+from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio import test_utils
+
+
+class ProactorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.proactor = unittest.mock.Mock()
+ self.loop._proactor = self.proactor
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+
+ def test_ctor(self):
+ fut = asyncio.Future(loop=self.loop)
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+ self.protocol.connection_made(tr)
+ self.proactor.recv.assert_called_with(self.sock, 4096)
+
+ def test_loop_reading(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._loop_reading()
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertFalse(self.protocol.eof_received.called)
+
+ def test_loop_reading_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'data')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_loop_reading_no_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+
+ self.assertRaises(AssertionError, tr._loop_reading, res)
+
+ tr.close = unittest.mock.Mock()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.assertFalse(self.loop._proactor.recv.called)
+ self.assertTrue(self.protocol.eof_received.called)
+ self.assertTrue(tr.close.called)
+
+ def test_loop_reading_aborted(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(err)
+
+ def test_loop_reading_aborted_closing(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = True
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+
+ def test_loop_reading_aborted_is_fatal(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = False
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertTrue(tr._fatal_error.called)
+
+ def test_loop_reading_conn_reset_lost(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionResetError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = False
+ tr._fatal_error = unittest.mock.Mock()
+ tr._force_close = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+ tr._force_close.assert_called_with(err)
+
+ def test_loop_reading_exception(self):
+ err = self.loop._proactor.recv.side_effect = (OSError())
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(err)
+
+ def test_write(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._loop_writing = unittest.mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [b'data'])
+ self.assertTrue(tr._loop_writing.called)
+
+ def test_write_no_data(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr.write(b'')
+ self.assertFalse(tr._buffer)
+
+ def test_write_more(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = unittest.mock.Mock()
+ tr._loop_writing = unittest.mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [b'data'])
+ self.assertFalse(tr._loop_writing.called)
+
+ def test_loop_writing(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ self.loop._proactor.send.assert_called_with(self.sock, b'data')
+ self.loop._proactor.send.return_value.add_done_callback.\
+ assert_called_with(tr._loop_writing)
+
+ @unittest.mock.patch('asyncio.proactor_events.logger')
+ def test_loop_writing_err(self, m_log):
+ err = self.loop._proactor.send.side_effect = OSError()
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ tr._fatal_error.assert_called_with(err)
+ tr._conn_lost = 1
+
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [])
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_loop_writing_stop(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(b'data')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = fut
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+
+ def test_loop_writing_closing(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(1)
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = fut
+ tr.close()
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_abort(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._force_close = unittest.mock.Mock()
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertTrue(tr._closing)
+ self.assertEqual(tr._conn_lost, 1)
+
+ self.protocol.connection_lost.reset_mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_write_fut(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = unittest.mock.Mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_buffer(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ @unittest.mock.patch('asyncio.proactor_events.logger')
+ def test_fatal_error(self, m_logging):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._force_close = unittest.mock.Mock()
+ tr._fatal_error(None)
+ self.assertTrue(tr._force_close.called)
+ self.assertTrue(m_logging.exception.called)
+
+ def test_force_close(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ read_fut = tr._read_fut = unittest.mock.Mock()
+ write_fut = tr._write_fut = unittest.mock.Mock()
+ tr._force_close(None)
+
+ read_fut.cancel.assert_called_with()
+ write_fut.cancel.assert_called_with()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual([], tr._buffer)
+ self.assertEqual(tr._conn_lost, 1)
+
+ def test_force_close_idempotent(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = True
+ tr._force_close(None)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_fatal_error_2(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ tr._force_close(None)
+
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual([], tr._buffer)
+
+ def test_call_connection_lost(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._call_connection_lost(None)
+ self.assertTrue(self.protocol.connection_lost.called)
+ self.assertTrue(self.sock.close.called)
+
+ def test_write_eof(self):
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._eof_written)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+ def test_write_eof_write_pipe(self):
+ tr = _ProactorWritePipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_buffer_write_pipe(self):
+ tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_duplex_pipe(self):
+ tr = _ProactorDuplexPipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr.can_write_eof())
+ with self.assertRaises(NotImplementedError):
+ tr.write_eof()
+ tr.close()
+
+ def test_pause_resume_reading(self):
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ futures = []
+ for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
+ f = asyncio.Future(loop=self.loop)
+ f.set_result(msg)
+ futures.append(f)
+ self.loop._proactor.recv.side_effect = futures
+ self.loop._run_once()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data1')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.pause_reading()
+ self.assertTrue(tr._paused)
+ for i in range(10):
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.resume_reading()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data3')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data4')
+ tr.close()
+
+
+class BaseProactorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.proactor = unittest.mock.Mock()
+
+ self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock()
+
+ class EventLoop(BaseProactorEventLoop):
+ def _socketpair(s):
+ return (self.ssock, self.csock)
+
+ self.loop = EventLoop(self.proactor)
+
+ @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon')
+ @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair')
+ def test_ctor(self, socketpair, call_soon):
+ ssock, csock = socketpair.return_value = (
+ unittest.mock.Mock(), unittest.mock.Mock())
+ loop = BaseProactorEventLoop(self.proactor)
+ self.assertIs(loop._ssock, ssock)
+ self.assertIs(loop._csock, csock)
+ self.assertEqual(loop._internal_fds, 1)
+ call_soon.assert_called_with(loop._loop_self_reading)
+
+ def test_close_self_pipe(self):
+ self.loop._close_self_pipe()
+ self.assertEqual(self.loop._internal_fds, 0)
+ self.assertTrue(self.ssock.close.called)
+ self.assertTrue(self.csock.close.called)
+ self.assertIsNone(self.loop._ssock)
+ self.assertIsNone(self.loop._csock)
+
+ def test_close(self):
+ self.loop._close_self_pipe = unittest.mock.Mock()
+ self.loop.close()
+ self.assertTrue(self.loop._close_self_pipe.called)
+ self.assertTrue(self.proactor.close.called)
+ self.assertIsNone(self.loop._proactor)
+
+ self.loop._close_self_pipe.reset_mock()
+ self.loop.close()
+ self.assertFalse(self.loop._close_self_pipe.called)
+
+ def test_sock_recv(self):
+ self.loop.sock_recv(self.sock, 1024)
+ self.proactor.recv.assert_called_with(self.sock, 1024)
+
+ def test_sock_sendall(self):
+ self.loop.sock_sendall(self.sock, b'data')
+ self.proactor.send.assert_called_with(self.sock, b'data')
+
+ def test_sock_connect(self):
+ self.loop.sock_connect(self.sock, 123)
+ self.proactor.connect.assert_called_with(self.sock, 123)
+
+ def test_sock_accept(self):
+ self.loop.sock_accept(self.sock)
+ self.proactor.accept.assert_called_with(self.sock)
+
+ def test_socketpair(self):
+ self.assertRaises(
+ NotImplementedError, BaseProactorEventLoop, self.proactor)
+
+ def test_make_socket_transport(self):
+ tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
+ self.assertIsInstance(tr, _ProactorSocketTransport)
+
+ def test_loop_self_reading(self):
+ self.loop._loop_self_reading()
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_fut(self):
+ fut = unittest.mock.Mock()
+ self.loop._loop_self_reading(fut)
+ self.assertTrue(fut.result.called)
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_exception(self):
+ self.loop.close = unittest.mock.Mock()
+ self.proactor.recv.side_effect = OSError()
+ self.assertRaises(OSError, self.loop._loop_self_reading)
+ self.assertTrue(self.loop.close.called)
+
+ def test_write_to_self(self):
+ self.loop._write_to_self()
+ self.csock.send.assert_called_with(b'x')
+
+ def test_process_events(self):
+ self.loop._process_events([])
+
+ @unittest.mock.patch('asyncio.proactor_events.logger')
+ def test_create_server(self, m_log):
+ pf = unittest.mock.Mock()
+ call_soon = self.loop.call_soon = unittest.mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ self.assertTrue(call_soon.called)
+
+ # callback
+ loop = call_soon.call_args[0][0]
+ loop()
+ self.proactor.accept.assert_called_with(self.sock)
+
+ # conn
+ fut = unittest.mock.Mock()
+ fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock())
+
+ make_tr = self.loop._make_socket_transport = unittest.mock.Mock()
+ loop(fut)
+ self.assertTrue(fut.result.called)
+ self.assertTrue(make_tr.called)
+
+ # exception
+ fut.result.side_effect = OSError()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+ self.assertTrue(m_log.exception.called)
+
+ def test_create_server_cancel(self):
+ pf = unittest.mock.Mock()
+ call_soon = self.loop.call_soon = unittest.mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ loop = call_soon.call_args[0][0]
+
+ # cancelled
+ fut = asyncio.Future(loop=self.loop)
+ fut.cancel()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+
+ def test_stop_serving(self):
+ sock = unittest.mock.Mock()
+ self.loop._stop_serving(sock)
+ self.assertTrue(sock.close.called)
+ self.proactor._stop_serving.assert_called_with(sock)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py
new file mode 100644
index 0000000000..8af4ee7f9b
--- /dev/null
+++ b/Lib/test/test_asyncio/test_queues.py
@@ -0,0 +1,470 @@
+"""Tests for queues.py"""
+
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import locks
+from asyncio import queues
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class _QueueTestBase(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+
+class QueueBasicTests(_QueueTestBase):
+
+ def _test_repr_or_str(self, fn, expect_id):
+ """Test Queue's repr or str.
+
+ fn is repr or str. expect_id is True if we expect the Queue's id to
+ appear in fn(Queue()).
+ """
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+ self.assertTrue(fn(q).startswith('<Queue'), fn(q))
+ id_is_present = hex(id(q)) in fn(q)
+ self.assertEqual(expect_id, id_is_present)
+
+ @tasks.coroutine
+ def add_getter():
+ q = queues.Queue(loop=loop)
+ # Start a task that waits to get.
+ tasks.Task(q.get(), loop=loop)
+ # Let it start waiting.
+ yield from tasks.sleep(0.1, loop=loop)
+ self.assertTrue('_getters[1]' in fn(q))
+ # resume q.get coroutine to finish generator
+ q.put_nowait(0)
+
+ loop.run_until_complete(add_getter())
+
+ @tasks.coroutine
+ def add_putter():
+ q = queues.Queue(maxsize=1, loop=loop)
+ q.put_nowait(1)
+ # Start a task that waits to put.
+ tasks.Task(q.put(2), loop=loop)
+ # Let it start waiting.
+ yield from tasks.sleep(0.1, loop=loop)
+ self.assertTrue('_putters[1]' in fn(q))
+ # resume q.put coroutine to finish generator
+ q.get_nowait()
+
+ loop.run_until_complete(add_putter())
+
+ q = queues.Queue(loop=loop)
+ q.put_nowait(1)
+ self.assertTrue('_queue=[1]' in fn(q))
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ q = queues.Queue(loop=loop)
+ self.assertIs(q._loop, loop)
+
+ q = queues.Queue(loop=self.loop)
+ self.assertIs(q._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ q = queues.Queue()
+ self.assertIs(q._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ self._test_repr_or_str(repr, True)
+
+ def test_str(self):
+ self._test_repr_or_str(str, False)
+
+ def test_empty(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertTrue(q.empty())
+ q.put_nowait(1)
+ self.assertFalse(q.empty())
+ self.assertEqual(1, q.get_nowait())
+ self.assertTrue(q.empty())
+
+ def test_full(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertFalse(q.full())
+
+ q = queues.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertTrue(q.full())
+
+ def test_order(self):
+ q = queues.Queue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 3, 2], items)
+
+ def test_maxsize(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.02, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(maxsize=2, loop=loop)
+ self.assertEqual(2, q.maxsize)
+ have_been_put = []
+
+ @tasks.coroutine
+ def putter():
+ for i in range(3):
+ yield from q.put(i)
+ have_been_put.append(i)
+ return True
+
+ @tasks.coroutine
+ def test():
+ t = tasks.Task(putter(), loop=loop)
+ yield from tasks.sleep(0.01, loop=loop)
+
+ # The putter is blocked after putting two items.
+ self.assertEqual([0, 1], have_been_put)
+ self.assertEqual(0, q.get_nowait())
+
+ # Let the putter resume and put last item.
+ yield from tasks.sleep(0.01, loop=loop)
+ self.assertEqual([0, 1, 2], have_been_put)
+ self.assertEqual(1, q.get_nowait())
+ self.assertEqual(2, q.get_nowait())
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ loop.run_until_complete(test())
+ self.assertAlmostEqual(0.02, loop.time())
+
+
+class QueueGetTests(_QueueTestBase):
+
+ def test_blocking_get(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from q.get())
+
+ res = self.loop.run_until_complete(queue_get())
+ self.assertEqual(1, res)
+
+ def test_get_with_putters(self):
+ q = queues.Queue(1, loop=self.loop)
+ q.put_nowait(1)
+
+ waiter = futures.Future(loop=self.loop)
+ q._putters.append((2, waiter))
+
+ res = self.loop.run_until_complete(q.get())
+ self.assertEqual(1, res)
+ self.assertTrue(waiter.done())
+ self.assertIsNone(waiter.result())
+
+ def test_blocking_get_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+ started = locks.Event(loop=loop)
+ finished = False
+
+ @tasks.coroutine
+ def queue_get():
+ nonlocal finished
+ started.set()
+ res = yield from q.get()
+ finished = True
+ return res
+
+ @tasks.coroutine
+ def queue_put():
+ loop.call_later(0.01, q.put_nowait, 1)
+ queue_get_task = tasks.Task(queue_get(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ res = yield from queue_get_task
+ self.assertTrue(finished)
+ return res
+
+ res = loop.run_until_complete(queue_put())
+ self.assertEqual(1, res)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_get(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_get_exception(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertRaises(queues.Empty, q.get_nowait)
+
+ def test_get_cancelled(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.061, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from tasks.wait_for(q.get(), 0.051, loop=loop))
+
+ @tasks.coroutine
+ def test():
+ get_task = tasks.Task(queue_get(), loop=loop)
+ yield from tasks.sleep(0.01, loop=loop) # let the task start
+ q.put_nowait(1)
+ return (yield from get_task)
+
+ self.assertEqual(1, loop.run_until_complete(test()))
+ self.assertAlmostEqual(0.06, loop.time())
+
+ def test_get_cancelled_race(self):
+ q = queues.Queue(loop=self.loop)
+
+ t1 = tasks.Task(q.get(), loop=self.loop)
+ t2 = tasks.Task(q.get(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t1.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t1.done())
+ q.put_nowait('a')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(t2.result(), 'a')
+
+ def test_get_with_waiting_putters(self):
+ q = queues.Queue(loop=self.loop, maxsize=1)
+ tasks.Task(q.put('a'), loop=self.loop)
+ tasks.Task(q.put('b'), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'b')
+
+
+class QueuePutTests(_QueueTestBase):
+
+ def test_blocking_put(self):
+ q = queues.Queue(loop=self.loop)
+
+ @tasks.coroutine
+ def queue_put():
+ # No maxsize, won't block.
+ yield from q.put(1)
+
+ self.loop.run_until_complete(queue_put())
+
+ def test_blocking_put_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(maxsize=1, loop=loop)
+ started = locks.Event(loop=loop)
+ finished = False
+
+ @tasks.coroutine
+ def queue_put():
+ nonlocal finished
+ started.set()
+ yield from q.put(1)
+ yield from q.put(2)
+ finished = True
+
+ @tasks.coroutine
+ def queue_get():
+ loop.call_later(0.01, q.get_nowait)
+ queue_put_task = tasks.Task(queue_put(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ yield from queue_put_task
+ self.assertTrue(finished)
+
+ loop.run_until_complete(queue_get())
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_put(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_put_exception(self):
+ q = queues.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertRaises(queues.Full, q.put_nowait, 2)
+
+ def test_put_cancelled(self):
+ q = queues.Queue(loop=self.loop)
+
+ @tasks.coroutine
+ def queue_put():
+ yield from q.put(1)
+ return True
+
+ @tasks.coroutine
+ def test():
+ return (yield from q.get())
+
+ t = tasks.Task(queue_put(), loop=self.loop)
+ self.assertEqual(1, self.loop.run_until_complete(test()))
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_put_cancelled_race(self):
+ q = queues.Queue(loop=self.loop, maxsize=1)
+
+ tasks.Task(q.put('a'), loop=self.loop)
+ tasks.Task(q.put('c'), loop=self.loop)
+ t = tasks.Task(q.put('b'), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t.done())
+ self.assertEqual(q.get_nowait(), 'a')
+ self.assertEqual(q.get_nowait(), 'c')
+
+ def test_put_with_waiting_getters(self):
+ q = queues.Queue(loop=self.loop)
+ t = tasks.Task(q.get(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.loop.run_until_complete(q.put('a'))
+ self.assertEqual(self.loop.run_until_complete(t), 'a')
+
+
+class LifoQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.LifoQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([2, 3, 1], items)
+
+
+class PriorityQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.PriorityQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 2, 3], items)
+
+
+class JoinableQueueTests(_QueueTestBase):
+
+ def test_task_done_underflow(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ self.assertRaises(ValueError, q.task_done)
+
+ def test_task_done(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ for i in range(100):
+ q.put_nowait(i)
+
+ accumulator = 0
+
+ # Two workers get items from the queue and call task_done after each.
+ # Join the queue and assert all items have been processed.
+ running = True
+
+ @tasks.coroutine
+ def worker():
+ nonlocal accumulator
+
+ while running:
+ item = yield from q.get()
+ accumulator += item
+ q.task_done()
+
+ @tasks.coroutine
+ def test():
+ for _ in range(2):
+ tasks.Task(worker(), loop=self.loop)
+
+ yield from q.join()
+
+ self.loop.run_until_complete(test())
+ self.assertEqual(sum(range(100)), accumulator)
+
+ # close running generators
+ running = False
+ for i in range(2):
+ q.put_nowait(0)
+
+ def test_join_empty_queue(self):
+ q = queues.JoinableQueue(loop=self.loop)
+
+ # Test that a queue join()s successfully, and before anything else
+ # (done twice for insurance).
+
+ @tasks.coroutine
+ def join():
+ yield from q.join()
+ yield from q.join()
+
+ self.loop.run_until_complete(join())
+
+ def test_format(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ self.assertEqual(q._format(), 'maxsize=0')
+
+ q._unfinished_tasks = 2
+ self.assertEqual(q._format(), 'maxsize=0 tasks=2')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
new file mode 100644
index 0000000000..4aef2fde88
--- /dev/null
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -0,0 +1,1554 @@
+"""Tests for selector_events.py"""
+
+import collections
+import errno
+import gc
+import pprint
+import socket
+import sys
+import unittest
+import unittest.mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from asyncio import futures
+from asyncio import selectors
+from asyncio import test_utils
+from asyncio.protocols import DatagramProtocol, Protocol
+from asyncio.selector_events import BaseSelectorEventLoop
+from asyncio.selector_events import _SelectorTransport
+from asyncio.selector_events import _SelectorSslTransport
+from asyncio.selector_events import _SelectorSocketTransport
+from asyncio.selector_events import _SelectorDatagramTransport
+
+
+class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
+
+ def _make_self_pipe(self):
+ self._ssock = unittest.mock.Mock()
+ self._csock = unittest.mock.Mock()
+ self._internal_fds += 1
+
+
+class BaseSelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock())
+
+ def test_make_socket_transport(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_make_ssl_transport(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
+
+ @unittest.mock.patch('asyncio.selector_events.ssl', None)
+ def test_make_ssl_transport_without_ssl_error(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+ with self.assertRaises(RuntimeError):
+ self.loop._make_ssl_transport(m, m, m, m)
+
+ def test_close(self):
+ ssock = self.loop._ssock
+ ssock.fileno.return_value = 7
+ csock = self.loop._csock
+ csock.fileno.return_value = 1
+ remove_reader = self.loop.remove_reader = unittest.mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = selector = unittest.mock.Mock()
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+ self.assertIsNone(self.loop._csock)
+ self.assertIsNone(self.loop._ssock)
+ selector.close.assert_called_with()
+ ssock.close.assert_called_with()
+ csock.close.assert_called_with()
+ remove_reader.assert_called_with(7)
+
+ self.loop.close()
+ self.loop.close()
+
+ def test_close_no_selector(self):
+ ssock = self.loop._ssock
+ csock = self.loop._csock
+ remove_reader = self.loop.remove_reader = unittest.mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = None
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+ self.assertFalse(ssock.close.called)
+ self.assertFalse(csock.close.called)
+ self.assertFalse(remove_reader.called)
+
+ def test_socketpair(self):
+ self.assertRaises(NotImplementedError, self.loop._socketpair)
+
+ def test_read_from_self_tryagain(self):
+ self.loop._ssock.recv.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._read_from_self())
+
+ def test_read_from_self_exception(self):
+ self.loop._ssock.recv.side_effect = OSError
+ self.assertRaises(OSError, self.loop._read_from_self)
+
+ def test_write_to_self_tryagain(self):
+ self.loop._csock.send.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._write_to_self())
+
+ def test_write_to_self_exception(self):
+ self.loop._csock.send.side_effect = OSError()
+ self.assertRaises(OSError, self.loop._write_to_self)
+
+ def test_sock_recv(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_recv = unittest.mock.Mock()
+
+ f = self.loop.sock_recv(sock, 1024)
+ self.assertIsInstance(f, futures.Future)
+ self.loop._sock_recv.assert_called_with(f, False, sock, 1024)
+
+ def test__sock_recv_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertFalse(sock.recv.called)
+
+ def test__sock_recv_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._sock_recv(f, True, sock, 1024)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_recv_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.recv.side_effect = BlockingIOError
+
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_recv_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.recv.side_effect = OSError()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertIs(err, f.exception())
+
+ def test_sock_sendall(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'data')
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, b'data'),
+ self.loop._sock_sendall.call_args[0])
+
+ def test_sock_sendall_nodata(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'')
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertFalse(self.loop._sock_sendall.called)
+
+ def test__sock_sendall_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(sock.send.called)
+
+ def test__sock_sendall_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, True, sock, b'data')
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_sendall_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = BlockingIOError
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_interrupted(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = InterruptedError
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.send.side_effect = OSError()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertIs(f.exception(), err)
+
+ def test__sock_sendall(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 4
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test__sock_sendall_partial(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 2
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'ta'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_none(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 0
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test_sock_connect(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_connect = unittest.mock.Mock()
+
+ f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, ('127.0.0.1', 8080)),
+ self.loop._sock_connect.call_args[0])
+
+ def test__sock_connect(self):
+ f = futures.Future(loop=self.loop)
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertTrue(sock.connect.called)
+
+ def test__sock_connect_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertFalse(sock.connect.called)
+
+ def test__sock_connect_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_connect_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.EAGAIN
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual(
+ (10, self.loop._sock_connect, f,
+ True, sock, ('127.0.0.1', 8080)),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_connect_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.ENOTCONN
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f.exception(), OSError)
+
+ def test_sock_accept(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_accept = unittest.mock.Mock()
+
+ f = self.loop.sock_accept(sock)
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock), self.loop._sock_accept.call_args[0])
+
+ def test__sock_accept(self):
+ f = futures.Future(loop=self.loop)
+
+ conn = unittest.mock.Mock()
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.return_value = conn, ('127.0.0.1', 1000)
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertTrue(f.done())
+ self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
+ self.assertEqual((False,), conn.setblocking.call_args[0])
+
+ def test__sock_accept_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertFalse(sock.accept.called)
+
+ def test__sock_accept_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._sock_accept(f, True, sock)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_accept_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = BlockingIOError
+
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop._sock_accept(f, False, sock)
+ self.assertEqual(
+ (10, self.loop._sock_accept, f, True, sock),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_accept_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.accept.side_effect = OSError()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertIs(err, f.exception())
+
+ def test_add_reader(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertIsNone(w)
+
+ def test_add_reader_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (reader, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(reader.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_add_reader_existing_writer(self):
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_remove_reader(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (None, None))
+ self.assertFalse(self.loop.remove_reader(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_reader_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_reader(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE, (None, writer)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_reader_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_reader(1))
+
+ def test_add_writer(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE, mask)
+ self.assertIsNone(r)
+ self.assertEqual(cb, w._callback)
+
+ def test_add_writer_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, writer))
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(writer.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(reader, r)
+ self.assertEqual(cb, w._callback)
+
+ def test_remove_writer(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, None))
+ self.assertFalse(self.loop.remove_writer(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_writer_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_writer(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_READ, (reader, None)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_writer_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_writer(1))
+
+ def test_process_events_read(self):
+ reader = unittest.mock.Mock()
+ reader._cancelled = False
+
+ self.loop._add_callback = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.assertTrue(self.loop._add_callback.called)
+ self.loop._add_callback.assert_called_with(reader)
+
+ def test_process_events_read_cancelled(self):
+ reader = unittest.mock.Mock()
+ reader.cancelled = True
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.loop.remove_reader.assert_called_with(1)
+
+ def test_process_events_write(self):
+ writer = unittest.mock.Mock()
+ writer._cancelled = False
+
+ self.loop._add_callback = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop._add_callback.assert_called_with(writer)
+
+ def test_process_events_write_cancelled(self):
+ writer = unittest.mock.Mock()
+ writer.cancelled = True
+ self.loop.remove_writer = unittest.mock.Mock()
+
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop.remove_writer.assert_called_with(1)
+
+
+class SelectorTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def test_ctor(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ self.assertIs(tr._loop, self.loop)
+ self.assertIs(tr._sock, self.sock)
+ self.assertIs(tr._sock_fd, 7)
+
+ def test_abort(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._force_close = unittest.mock.Mock()
+
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr.close()
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+ self.protocol.connection_lost(None)
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ def test_close_write_buffer(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._buffer.append(b'data')
+ tr.close()
+
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_force_close(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._buffer.append(b'1')
+ self.loop.add_reader(7, unittest.mock.sentinel)
+ self.loop.add_writer(7, unittest.mock.sentinel)
+ tr._force_close(None)
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(tr._buffer, collections.deque())
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+
+ # second close should not remove reader
+ tr._force_close(None)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ @unittest.mock.patch('asyncio.log.logger.exception')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._force_close = unittest.mock.Mock()
+ tr._fatal_error(exc)
+
+ m_exc.assert_called_with('Fatal error for %s', tr)
+ tr._force_close.assert_called_with(exc)
+
+ def test_connection_lost(self):
+ exc = OSError()
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._call_connection_lost(exc)
+
+ self.protocol.connection_lost.assert_called_with(exc)
+ self.sock.close.assert_called_with()
+ self.assertIsNone(tr._sock)
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+
+class SelectorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock_fd = self.sock.fileno.return_value = 7
+
+ def test_ctor(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.loop.assert_reader(7, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+
+ _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+
+ def test_pause_resume_reading(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+ tr.pause_reading()
+ self.assertTrue(tr._paused)
+ self.assertFalse(7 in self.loop.readers)
+ tr.resume_reading()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+
+ def test_read_ready(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recv.return_value = b'data'
+ transport._read_ready()
+
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_read_ready_eof(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close = unittest.mock.Mock()
+
+ self.sock.recv.return_value = b''
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ transport.close.assert_called_with()
+
+ def test_read_ready_eof_keep_open(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close = unittest.mock.Mock()
+
+ self.sock.recv.return_value = b''
+ self.protocol.eof_received.return_value = True
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ self.assertFalse(transport.close.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain(self, m_exc):
+ self.sock.recv.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain_interrupted(self, m_exc):
+ self.sock.recv.side_effect = InterruptedError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_conn_reset(self, m_exc):
+ err = self.sock.recv.side_effect = ConnectionResetError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._force_close = unittest.mock.Mock()
+ transport._read_ready()
+ transport._force_close.assert_called_with(err)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_err(self, m_exc):
+ err = self.sock.recv.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_write(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+
+ def test_write_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_buffer(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(b'data1')
+ transport.write(b'data2')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(collections.deque([b'data1', b'data2']),
+ transport._buffer)
+
+ def test_write_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'ta']), transport._buffer)
+
+ def test_write_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+ self.sock.fileno.return_value = 7
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ @unittest.mock.patch('asyncio.selector_events.logger')
+ def test_write_exception(self, m_log):
+ err = self.sock.send.side_effect = OSError()
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.write(data)
+ transport._fatal_error.assert_called_with(err)
+ transport._conn_lost = 1
+
+ self.sock.reset_mock()
+ transport.write(data)
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(transport._conn_lost, 2)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_write_str(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_write_ready(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.assertEqual(self.sock.send.call_args[0], (data,))
+ self.assertFalse(self.loop.writers)
+
+ def test_write_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.sock.send.assert_called_with(data)
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport._write_ready)
+
+ def test_write_ready_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'ta']), transport._buffer)
+
+ def test_write_ready_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_ready_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+
+ def test_write_ready_exception(self):
+ err = self.sock.send.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ @unittest.mock.patch('asyncio.selector_events.logger')
+ def test_write_ready_exception_and_close(self, m_log):
+ self.sock.send.side_effect = OSError()
+ remove_writer = self.loop.remove_writer = unittest.mock.Mock()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ remove_writer.assert_called_with(self.sock_fd)
+
+ def test_write_eof(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.sock.send.side_effect = BlockingIOError
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertEqual(tr._buffer, collections.deque([b'data']))
+ self.assertTrue(tr._eof)
+ self.assertFalse(self.sock.shutdown.called)
+ self.sock.send.side_effect = lambda _: 4
+ tr._write_ready()
+ self.sock.send.assert_called_with(b'data')
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorSslTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+ self.sslsock = unittest.mock.Mock()
+ self.sslsock.fileno.return_value = 1
+ self.sslcontext = unittest.mock.Mock()
+ self.sslcontext.wrap_socket.return_value = self.sslsock
+
+ def _make_one(self, create_waiter=None):
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ self.sock.reset_mock()
+ self.sslsock.reset_mock()
+ self.sslcontext.reset_mock()
+ self.loop.reset_counters()
+ return transport
+
+ def test_on_handshake(self):
+ waiter = futures.Future(loop=self.loop)
+ tr = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ waiter=waiter)
+ self.assertTrue(self.sslsock.do_handshake.called)
+ self.loop.assert_reader(1, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(waiter.result())
+
+ def test_on_handshake_reader_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._on_handshake()
+ self.loop.assert_reader(1, transport._on_handshake)
+
+ def test_on_handshake_writer_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._on_handshake()
+ self.loop.assert_writer(1, transport._on_handshake)
+
+ def test_on_handshake_exc(self):
+ exc = ValueError()
+ self.sslsock.do_handshake.side_effect = exc
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._waiter = futures.Future(loop=self.loop)
+ transport._on_handshake()
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(transport._waiter.done())
+ self.assertIs(exc, transport._waiter.exception())
+
+ def test_on_handshake_base_exc(self):
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._waiter = futures.Future(loop=self.loop)
+ exc = BaseException()
+ self.sslsock.do_handshake.side_effect = exc
+ self.assertRaises(BaseException, transport._on_handshake)
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(transport._waiter.done())
+ self.assertIs(exc, transport._waiter.exception())
+
+ def test_pause_resume_reading(self):
+ tr = self._make_one()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._read_ready)
+ tr.pause_reading()
+ self.assertTrue(tr._paused)
+ self.assertFalse(1 in self.loop.readers)
+ tr.resume_reading()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._read_ready)
+
+ def test_write_no_data(self):
+ transport = self._make_one()
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_str(self):
+ transport = self._make_one()
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = self._make_one()
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ @unittest.mock.patch('asyncio.selector_events.logger')
+ def test_write_exception(self, m_log):
+ transport = self._make_one()
+ transport._conn_lost = 1
+ transport.write(b'data')
+ self.assertEqual(transport._buffer, collections.deque())
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_read_ready_recv(self):
+ self.sslsock.recv.return_value = b'data'
+ transport = self._make_one()
+ transport._read_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
+
+ def test_read_ready_write_wants_read(self):
+ self.loop.add_writer = unittest.mock.Mock()
+ self.sslsock.recv.side_effect = BlockingIOError
+ transport = self._make_one()
+ transport._write_wants_read = True
+ transport._write_ready = unittest.mock.Mock()
+ transport._buffer.append(b'data')
+ transport._read_ready()
+
+ self.assertFalse(transport._write_wants_read)
+ transport._write_ready.assert_called_with()
+ self.loop.add_writer.assert_called_with(
+ transport._sock_fd, transport._write_ready)
+
+ def test_read_ready_recv_eof(self):
+ self.sslsock.recv.return_value = b''
+ transport = self._make_one()
+ transport.close = unittest.mock.Mock()
+ transport._read_ready()
+ transport.close.assert_called_with()
+ self.protocol.eof_received.assert_called_with()
+
+ def test_read_ready_recv_conn_reset(self):
+ err = self.sslsock.recv.side_effect = ConnectionResetError()
+ transport = self._make_one()
+ transport._force_close = unittest.mock.Mock()
+ transport._read_ready()
+ transport._force_close.assert_called_with(err)
+
+ def test_read_ready_recv_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ transport = self._make_one()
+ transport._read_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = BlockingIOError
+ transport._read_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = InterruptedError
+ transport._read_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ def test_read_ready_recv_write(self):
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop.add_writer = unittest.mock.Mock()
+ self.sslsock.recv.side_effect = ssl.SSLWantWriteError
+ transport = self._make_one()
+ transport._read_ready()
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertTrue(transport._read_wants_write)
+
+ self.loop.remove_reader.assert_called_with(transport._sock_fd)
+ self.loop.add_writer.assert_called_with(
+ transport._sock_fd, transport._write_ready)
+
+ def test_read_ready_recv_exc(self):
+ err = self.sslsock.recv.side_effect = OSError()
+ transport = self._make_one()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ def test_write_ready_send(self):
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+ transport._write_ready()
+ self.assertEqual(collections.deque(), transport._buffer)
+ self.assertTrue(self.sslsock.send.called)
+
+ def test_write_ready_send_none(self):
+ self.sslsock.send.return_value = 0
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._write_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+
+ def test_write_ready_send_partial(self):
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._write_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'ta1data2']), transport._buffer)
+
+ def test_write_ready_send_closing_partial(self):
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._write_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertFalse(self.sslsock.close.called)
+
+ def test_write_ready_send_closing(self):
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = collections.deque([b'data'])
+ transport._write_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_send_closing_empty_buffer(self):
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = collections.deque()
+ transport._write_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_send_retry(self):
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+
+ self.sslsock.send.side_effect = ssl.SSLWantWriteError
+ transport._write_ready()
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ self.sslsock.send.side_effect = BlockingIOError()
+ transport._write_ready()
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_ready_send_read(self):
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.sslsock.send.side_effect = ssl.SSLWantReadError
+ transport._write_ready()
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertTrue(transport._write_wants_read)
+ self.loop.remove_writer.assert_called_with(transport._sock_fd)
+
+ def test_write_ready_send_exc(self):
+ err = self.sslsock.send.side_effect = OSError()
+
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+ transport._fatal_error = unittest.mock.Mock()
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(err)
+ self.assertEqual(collections.deque(), transport._buffer)
+
+ def test_write_ready_read_wants_write(self):
+ self.loop.add_reader = unittest.mock.Mock()
+ self.sslsock.send.side_effect = BlockingIOError
+ transport = self._make_one()
+ transport._read_wants_write = True
+ transport._read_ready = unittest.mock.Mock()
+ transport._write_ready()
+
+ self.assertFalse(transport._read_wants_write)
+ transport._read_ready.assert_called_with()
+ self.loop.add_reader.assert_called_with(
+ transport._sock_fd, transport._read_ready)
+
+ def test_write_eof(self):
+ tr = self._make_one()
+ self.assertFalse(tr.can_write_eof())
+ self.assertRaises(NotImplementedError, tr.write_eof)
+
+ def test_close(self):
+ tr = self._make_one()
+ tr.close()
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+
+ @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support')
+ def test_server_hostname(self):
+ _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ server_hostname='localhost')
+ self.sslcontext.wrap_socket.assert_called_with(
+ self.sock, do_handshake_on_connect=False, server_side=False,
+ server_hostname='localhost')
+
+
+class SelectorSslWithoutSslTransportTests(unittest.TestCase):
+
+ @unittest.mock.patch('asyncio.selector_events.ssl', None)
+ def test_ssl_transport_requires_ssl_module(self):
+ Mock = unittest.mock.Mock
+ with self.assertRaises(RuntimeError):
+ transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
+
+
+class SelectorDatagramTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(DatagramProtocol)
+ self.sock = unittest.mock.Mock(spec_set=socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def test_read_ready(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234))
+ transport._read_ready()
+
+ self.protocol.datagram_received.assert_called_with(
+ b'data', ('0.0.0.0', 1234))
+
+ def test_read_ready_tryagain(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.side_effect = BlockingIOError
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_read_ready_err(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ err = self.sock.recvfrom.side_effect = RuntimeError()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_read_ready_oserr(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ err = self.sock.recvfrom.side_effect = OSError()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+ self.protocol.error_received.assert_called_with(err)
+
+ def test_sendto(self):
+ data = b'data'
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+ transport.sendto(b'', ())
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_buffer(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(b'data2', ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+
+ def test_sendto_tryagain(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 12345))
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ @unittest.mock.patch('asyncio.selector_events.logger')
+ def test_sendto_exception(self, m_log):
+ data = b'data'
+ err = self.sock.sendto.side_effect = RuntimeError()
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertTrue(transport._fatal_error.called)
+ transport._fatal_error.assert_called_with(err)
+ transport._conn_lost = 1
+
+ transport._address = ('123',)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_sendto_error_received(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertEqual(transport._conn_lost, 0)
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_error_received_connected(self):
+ data = b'data'
+
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data)
+
+ self.assertFalse(transport._fatal_error.called)
+ self.assertTrue(self.protocol.error_received.called)
+
+ def test_sendto_str(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.sendto, 'str', ())
+
+ def test_sendto_connected_addr(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ self.assertRaises(
+ AssertionError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+ def test_sendto_closing(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, address=(1,))
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.sendto(b'data', (1,))
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_sendto_ready(self):
+ data = b'data'
+ self.sock.sendto.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((data, ('0.0.0.0', 12345)))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345)))
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append((data, ()))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.sock.sendto.assert_called_with(data, ())
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_sendto_ready_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertFalse(self.sock.sendto.called)
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_tryagain(self):
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.extend([(b'data1', ()), (b'data2', ())])
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data1', ()), (b'data2', ())],
+ list(transport._buffer))
+
+ def test_sendto_ready_exception(self):
+ err = self.sock.sendto.side_effect = RuntimeError()
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto_ready_error_received(self):
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_ready_error_received_connection(self):
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+ self.assertTrue(self.protocol.error_received.called)
+
+ @unittest.mock.patch('asyncio.log.logger.exception')
+ def test_fatal_error_connected(self, m_exc):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ err = ConnectionRefusedError()
+ transport._fatal_error(err)
+ self.assertFalse(self.protocol.error_received.called)
+ m_exc.assert_called_with('Fatal error for %s', transport)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
new file mode 100644
index 0000000000..69e2246f44
--- /dev/null
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -0,0 +1,364 @@
+"""Tests for streams.py."""
+
+import gc
+import unittest
+import unittest.mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from asyncio import events
+from asyncio import streams
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class StreamReaderTests(unittest.TestCase):
+
+ DATA = b'line1\nline2\nline3\n'
+
+ def setUp(self):
+ self.loop = events.new_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+
+ @unittest.mock.patch('asyncio.streams.events')
+ def test_ctor_global_loop(self, m_events):
+ stream = streams.StreamReader()
+ self.assertIs(stream._loop, m_events.get_event_loop.return_value)
+
+ def test_open_connection(self):
+ with test_utils.run_test_server() as httpd:
+ f = streams.open_connection(*httpd.address, loop=self.loop)
+ reader, writer = self.loop.run_until_complete(f)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_open_connection_no_loop_ssl(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ try:
+ events.set_event_loop(self.loop)
+ f = streams.open_connection(*httpd.address,
+ ssl=test_utils.dummy_ssl_context())
+ reader, writer = self.loop.run_until_complete(f)
+ finally:
+ events.set_event_loop(None)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ def test_open_connection_error(self):
+ with test_utils.run_test_server() as httpd:
+ f = streams.open_connection(*httpd.address, loop=self.loop)
+ reader, writer = self.loop.run_until_complete(f)
+ writer._protocol.connection_lost(ZeroDivisionError())
+ f = reader.read()
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(f)
+
+ writer.close()
+ test_utils.run_briefly(self.loop)
+
+ def test_feed_empty_data(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ stream.feed_data(b'')
+ self.assertEqual(0, stream._byte_count)
+
+ def test_feed_data_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ stream.feed_data(self.DATA)
+ self.assertEqual(len(self.DATA), stream._byte_count)
+
+ def test_read_zero(self):
+ # Read zero bytes.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.read(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream._byte_count)
+
+ def test_read(self):
+ # Read bytes.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(30), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream._byte_count)
+
+ def test_read_line_breaks(self):
+ # Read bytes without line breaks.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line1')
+ stream.feed_data(b'line2')
+
+ data = self.loop.run_until_complete(stream.read(5))
+
+ self.assertEqual(b'line1', data)
+ self.assertEqual(5, stream._byte_count)
+
+ def test_read_eof(self):
+ # Read bytes, stop at eof.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(1024), loop=self.loop)
+
+ def cb():
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'', data)
+ self.assertFalse(stream._byte_count)
+
+ def test_read_until_eof(self):
+ # Read all bytes until eof.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(-1), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1\n')
+ stream.feed_data(b'chunk2')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+
+ self.assertEqual(b'chunk1\nchunk2', data)
+ self.assertFalse(stream._byte_count)
+
+ def test_read_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.read(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.read(2))
+
+ def test_readline(self):
+ # Read one line.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'chunk1 ')
+ read_task = tasks.Task(stream.readline(), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk2 ')
+ stream.feed_data(b'chunk3 ')
+ stream.feed_data(b'\n chunk4')
+ self.loop.call_soon(cb)
+
+ line = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
+ self.assertEqual(len(b'\n chunk4')-1, stream._byte_count)
+
+ def test_readline_limit_with_existing_data(self):
+ stream = streams.StreamReader(3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1\nline2\n')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'line2\n'], list(stream._buffer))
+
+ stream = streams.StreamReader(3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1')
+ stream.feed_data(b'li')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'li'], list(stream._buffer))
+ self.assertEqual(2, stream._byte_count)
+
+ def test_readline_limit(self):
+ stream = streams.StreamReader(7, loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1')
+ stream.feed_data(b'chunk2')
+ stream.feed_data(b'chunk3\n')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'chunk3\n'], list(stream._buffer))
+ self.assertEqual(7, stream._byte_count)
+
+ def test_readline_line_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA[:6])
+ stream.feed_data(self.DATA[6:])
+
+ line = self.loop.run_until_complete(stream.readline())
+
+ self.assertEqual(b'line1\n', line)
+ self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count)
+
+ def test_readline_eof(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'some data')
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'some data', line)
+
+ def test_readline_empty_eof(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', line)
+
+ def test_readline_read_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ self.loop.run_until_complete(stream.readline())
+
+ data = self.loop.run_until_complete(stream.read(7))
+
+ self.assertEqual(b'line2\nl', data)
+ self.assertEqual(
+ len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
+ stream._byte_count)
+
+ def test_readline_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'line\n', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+
+ def test_readexactly_zero_or_less(self):
+ # Read exact number of bytes (zero or less).
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.readexactly(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream._byte_count)
+
+ data = self.loop.run_until_complete(stream.readexactly(-1))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream._byte_count)
+
+ def test_readexactly(self):
+ # Read exact number of bytes.
+ stream = streams.StreamReader(loop=self.loop)
+
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA + self.DATA, data)
+ self.assertEqual(len(self.DATA), stream._byte_count)
+
+ def test_readexactly_eof(self):
+ # Read exact number of bytes (eof).
+ stream = streams.StreamReader(loop=self.loop)
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream._byte_count)
+
+ def test_readexactly_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readexactly(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readexactly(2))
+
+ def test_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ self.assertIsNone(stream.exception())
+
+ exc = ValueError()
+ stream.set_exception(exc)
+ self.assertIs(stream.exception(), exc)
+
+ def test_exception_waiter(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ @tasks.coroutine
+ def set_err():
+ stream.set_exception(ValueError())
+
+ @tasks.coroutine
+ def readline():
+ yield from stream.readline()
+
+ t1 = tasks.Task(stream.readline(), loop=self.loop)
+ t2 = tasks.Task(set_err(), loop=self.loop)
+
+ self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop))
+
+ self.assertRaises(ValueError, t1.result)
+
+ def test_exception_cancel(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ @tasks.coroutine
+ def read_a_line():
+ yield from stream.readline()
+
+ t = tasks.Task(read_a_line(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ # The following line fails if set_exception() isn't careful.
+ stream.set_exception(RuntimeError('message'))
+ test_utils.run_briefly(self.loop)
+ self.assertIs(stream._waiter, None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
new file mode 100644
index 0000000000..8f0d081554
--- /dev/null
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -0,0 +1,1518 @@
+"""Tests for tasks.py."""
+
+import gc
+import unittest
+import unittest.mock
+from unittest.mock import Mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class Dummy:
+
+ def __repr__(self):
+ return 'Dummy()'
+
+ def __call__(self, *args):
+ pass
+
+
+class TaskTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ gc.collect()
+
+ def test_task_class(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t = tasks.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = events.new_event_loop()
+ t = tasks.Task(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.close()
+
+ def test_async_coroutine(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t = tasks.async(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = events.new_event_loop()
+ t = tasks.async(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.close()
+
+ def test_async_future(self):
+ f_orig = futures.Future(loop=self.loop)
+ f_orig.set_result('ko')
+
+ f = tasks.async(f_orig)
+ self.loop.run_until_complete(f)
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 'ko')
+ self.assertIs(f, f_orig)
+
+ loop = events.new_event_loop()
+
+ with self.assertRaises(ValueError):
+ f = tasks.async(f_orig, loop=loop)
+
+ loop.close()
+
+ f = tasks.async(f_orig, loop=self.loop)
+ self.assertIs(f, f_orig)
+
+ def test_async_task(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t_orig = tasks.Task(notmuch(), loop=self.loop)
+ t = tasks.async(t_orig)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t, t_orig)
+
+ loop = events.new_event_loop()
+
+ with self.assertRaises(ValueError):
+ t = tasks.async(t_orig, loop=loop)
+
+ loop.close()
+
+ t = tasks.async(t_orig, loop=self.loop)
+ self.assertIs(t, t_orig)
+
+ def test_async_neither(self):
+ with self.assertRaises(TypeError):
+ tasks.async('ok')
+
+ def test_task_repr(self):
+ @tasks.coroutine
+ def notmuch():
+ yield from []
+ return 'abc'
+
+ t = tasks.Task(notmuch(), loop=self.loop)
+ t.add_done_callback(Dummy())
+ self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>')
+ t.cancel() # Does not take immediate effect!
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>')
+ self.assertRaises(futures.CancelledError,
+ self.loop.run_until_complete, t)
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>')
+ t = tasks.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>")
+
+ def test_task_repr_custom(self):
+ @tasks.coroutine
+ def coro():
+ pass
+
+ class T(futures.Future):
+ def __repr__(self):
+ return 'T[]'
+
+ class MyTask(tasks.Task, T):
+ def __repr__(self):
+ return super().__repr__()
+
+ gen = coro()
+ t = MyTask(gen, loop=self.loop)
+ self.assertEqual(repr(t), 'T[](<coro>)')
+ gen.close()
+
+ def test_task_basics(self):
+ @tasks.coroutine
+ def outer():
+ a = yield from inner1()
+ b = yield from inner2()
+ return a+b
+
+ @tasks.coroutine
+ def inner1():
+ return 42
+
+ @tasks.coroutine
+ def inner2():
+ return 1000
+
+ t = outer()
+ self.assertEqual(self.loop.run_until_complete(t), 1042)
+
+ def test_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def task():
+ yield from tasks.sleep(10.0, loop=loop)
+ return 12
+
+ t = tasks.Task(task(), loop=loop)
+ loop.call_soon(t.cancel)
+ with self.assertRaises(futures.CancelledError):
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_yield(self):
+ @tasks.coroutine
+ def task():
+ yield
+ yield
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start coro
+ t.cancel()
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_inner_future(self):
+ f = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start task
+ f.cancel()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(t)
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_both_task_and_inner_future(self):
+ f = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+
+ f.cancel()
+ t.cancel()
+
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(t)
+
+ self.assertTrue(t.done())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_task_catching(self):
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except futures.CancelledError:
+ return 42
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_task_ignoring(self):
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+ fut3 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except futures.CancelledError:
+ pass
+ res = yield from fut3
+ return res
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut3) # White-box test.
+ fut3.set_result(42)
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(fut3.cancelled())
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_current_task(self):
+ loop = events.new_event_loop()
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def task():
+ t.cancel()
+ self.assertTrue(t._must_cancel) # White-box test.
+ # The sleep should be cancelled immediately.
+ yield from tasks.sleep(100, loop=loop)
+ return 12
+
+ t = tasks.Task(task(), loop=loop)
+ self.assertRaises(
+ futures.CancelledError, loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t._must_cancel) # White-box test.
+ self.assertFalse(t.cancel())
+
+ def test_stop_while_run_in_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.3, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ x = 0
+ waiters = []
+
+ @tasks.coroutine
+ def task():
+ nonlocal x
+ while x < 10:
+ waiters.append(tasks.sleep(0.1, loop=loop))
+ yield from waiters[-1]
+ x += 1
+ if x == 2:
+ loop.stop()
+
+ t = tasks.Task(task(), loop=loop)
+ self.assertRaises(
+ RuntimeError, loop.run_until_complete, t)
+ self.assertFalse(t.done())
+ self.assertEqual(x, 2)
+ self.assertAlmostEqual(0.3, loop.time())
+
+ # close generators
+ for w in waiters:
+ w.close()
+
+ def test_wait_for(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.4, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def foo():
+ yield from tasks.sleep(0.2, loop=loop)
+ return 'done'
+
+ fut = tasks.Task(foo(), loop=loop)
+
+ with self.assertRaises(futures.TimeoutError):
+ loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop))
+
+ self.assertFalse(fut.done())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # wait for result
+ res = loop.run_until_complete(
+ tasks.wait_for(fut, 0.3, loop=loop))
+ self.assertEqual(res, 'done')
+ self.assertAlmostEqual(0.2, loop.time())
+
+ def test_wait_for_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def foo():
+ yield from tasks.sleep(0.2, loop=loop)
+ return 'done'
+
+ events.set_event_loop(loop)
+ try:
+ fut = tasks.Task(foo(), loop=loop)
+ with self.assertRaises(futures.TimeoutError):
+ loop.run_until_complete(tasks.wait_for(fut, 0.01))
+ finally:
+ events.set_event_loop(None)
+
+ self.assertAlmostEqual(0.01, loop.time())
+ self.assertFalse(fut.done())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(fut)
+
+ def test_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], loop=loop)
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertEqual(res, 42)
+ self.assertAlmostEqual(0.15, loop.time())
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertEqual(res, 42)
+
+ def test_wait_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0
+ self.assertAlmostEqual(0.015, when)
+ yield 0.015
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a])
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ events.set_event_loop(loop)
+ try:
+ res = loop.run_until_complete(
+ tasks.Task(foo(), loop=loop))
+ finally:
+ events.set_event_loop(None)
+
+ self.assertEqual(res, 42)
+
+ def test_wait_errors(self):
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ tasks.wait(set(), loop=self.loop))
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ tasks.wait([tasks.sleep(10.0, loop=self.loop)],
+ return_when=-1, loop=self.loop))
+
+ def test_wait_first_completed(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertFalse(a.done())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_really_done(self):
+ # there is possibility that some tasks in the pending list
+ # became done but their callbacks haven't all been called yet
+
+ @tasks.coroutine
+ def coro1():
+ yield
+
+ @tasks.coroutine
+ def coro2():
+ yield
+ yield
+
+ a = tasks.Task(coro1(), loop=self.loop)
+ b = tasks.Task(coro2(), loop=self.loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED,
+ loop=self.loop),
+ loop=self.loop)
+
+ done, pending = self.loop.run_until_complete(task)
+ self.assertEqual({a, b}, done)
+ self.assertTrue(a.done())
+ self.assertIsNone(a.result())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+
+ def test_wait_first_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ # first_exception, task already has exception
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def exc():
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc(), loop=loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_first_exception_in_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ # first_exception, exception during waiting
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def exc():
+ yield from tasks.sleep(0.01, loop=loop)
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc(), loop=loop)
+ task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION,
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_with_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(0.15, loop=loop)
+ raise ZeroDivisionError('really')
+
+ b = tasks.Task(sleeper(), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], loop=loop)
+ self.assertEqual(len(done), 2)
+ self.assertEqual(pending, set())
+ errors = set(f for f in done if f.exception() is not None)
+ self.assertEqual(len(errors), 1)
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_wait_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.11, when)
+ yield 0.11
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], timeout=0.11,
+ loop=loop)
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.11, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_concurrent_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ done, pending = loop.run_until_complete(
+ tasks.wait([b, a], timeout=0.1, loop=loop))
+
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_as_completed(self):
+
+ def gen():
+ yield 0
+ yield 0
+ yield 0.01
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+ completed = set()
+ time_shifted = False
+
+ @tasks.coroutine
+ def sleeper(dt, x):
+ nonlocal time_shifted
+ yield from tasks.sleep(dt, loop=loop)
+ completed.add(x)
+ if not time_shifted and 'a' in completed and 'b' in completed:
+ time_shifted = True
+ loop.advance_time(0.14)
+ return x
+
+ a = sleeper(0.01, 'a')
+ b = sleeper(0.01, 'b')
+ c = sleeper(0.15, 'c')
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([b, c, a], loop=loop):
+ values.append((yield from f))
+ return values
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertTrue('a' in res[:2])
+ self.assertTrue('b' in res[:2])
+ self.assertEqual(res[2], 'c')
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_as_completed_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.12, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.12, when)
+ yield 0.02
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.1, 'a', loop=loop)
+ b = tasks.sleep(0.15, 'b', loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([a, b], timeout=0.12, loop=loop):
+ try:
+ v = yield from f
+ values.append((1, v))
+ except futures.TimeoutError as exc:
+ values.append((2, exc))
+ return values
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertEqual(len(res), 2, res)
+ self.assertEqual(res[0], (1, 'a'))
+ self.assertEqual(res[1][0], 2)
+ self.assertIsInstance(res[1][1], futures.TimeoutError)
+ self.assertAlmostEqual(0.12, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_as_completed_reverse_wait(self):
+
+ def gen():
+ yield 0
+ yield 0.05
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.05, 'a', loop=loop)
+ b = tasks.sleep(0.10, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(tasks.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+
+ x = loop.run_until_complete(futs[1])
+ self.assertEqual(x, 'a')
+ self.assertAlmostEqual(0.05, loop.time())
+ loop.advance_time(0.05)
+ y = loop.run_until_complete(futs[0])
+ self.assertEqual(y, 'b')
+ self.assertAlmostEqual(0.10, loop.time())
+
+ def test_as_completed_concurrent(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0
+ self.assertAlmostEqual(0.05, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.05, 'a', loop=loop)
+ b = tasks.sleep(0.05, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(tasks.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+ waiter = tasks.wait(futs, loop=loop)
+ done, pending = loop.run_until_complete(waiter)
+ self.assertEqual(set(f.result() for f in done), {'a', 'b'})
+
+ def test_sleep(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0.05
+ self.assertAlmostEqual(0.1, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def sleeper(dt, arg):
+ yield from tasks.sleep(dt/2, loop=loop)
+ res = yield from tasks.sleep(dt/2, arg, loop=loop)
+ return res
+
+ t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop)
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'yeah')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_sleep_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop),
+ loop=loop)
+
+ handle = None
+ orig_call_later = loop.call_later
+
+ def call_later(self, delay, callback, *args):
+ nonlocal handle
+ handle = orig_call_later(self, delay, callback, *args)
+ return handle
+
+ loop.call_later = call_later
+ test_utils.run_briefly(loop)
+
+ self.assertFalse(handle._cancelled)
+
+ t.cancel()
+ test_utils.run_briefly(loop)
+ self.assertTrue(handle._cancelled)
+
+ def test_task_cancel_sleeping_task(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(5000, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ sleepfut = None
+
+ @tasks.coroutine
+ def sleep(dt):
+ nonlocal sleepfut
+ sleepfut = tasks.sleep(dt, loop=loop)
+ yield from sleepfut
+
+ @tasks.coroutine
+ def doit():
+ sleeper = tasks.Task(sleep(5000), loop=loop)
+ loop.call_later(0.1, sleeper.cancel)
+ try:
+ yield from sleeper
+ except futures.CancelledError:
+ return 'cancelled'
+ else:
+ return 'slept in'
+
+ doer = doit()
+ self.assertEqual(loop.run_until_complete(doer), 'cancelled')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_task_cancel_waiter_future(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def coro():
+ yield from fut
+
+ task = tasks.Task(coro(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(task._fut_waiter, fut)
+
+ task.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, task)
+ self.assertIsNone(task._fut_waiter)
+ self.assertTrue(fut.cancelled())
+
+ def test_step_in_completed_task(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ko'
+
+ gen = notmuch()
+ task = tasks.Task(gen, loop=self.loop)
+ task.set_result('ok')
+
+ self.assertRaises(AssertionError, task._step)
+ gen.close()
+
+ def test_step_result(self):
+ @tasks.coroutine
+ def notmuch():
+ yield None
+ yield 1
+ return 'ko'
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, notmuch())
+
+ def test_step_result_future(self):
+ # If coroutine returns future, task waits on this future.
+
+ class Fut(futures.Future):
+ def __init__(self, *args, **kwds):
+ self.cb_added = False
+ super().__init__(*args, **kwds)
+
+ def add_done_callback(self, fn):
+ self.cb_added = True
+ super().add_done_callback(fn)
+
+ fut = Fut(loop=self.loop)
+ result = None
+
+ @tasks.coroutine
+ def wait_for_future():
+ nonlocal result
+ result = yield from fut
+
+ t = tasks.Task(wait_for_future(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(fut.cb_added)
+
+ res = object()
+ fut.set_result(res)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(res, result)
+ self.assertTrue(t.done())
+ self.assertIsNone(t.result())
+
+ def test_step_with_baseexception(self):
+ @tasks.coroutine
+ def notmutch():
+ raise BaseException()
+
+ task = tasks.Task(notmutch(), loop=self.loop)
+ self.assertRaises(BaseException, task._step)
+
+ self.assertTrue(task.done())
+ self.assertIsInstance(task.exception(), BaseException)
+
+ def test_baseexception_during_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(10, loop=loop)
+
+ base_exc = BaseException()
+
+ @tasks.coroutine
+ def notmutch():
+ try:
+ yield from sleeper()
+ except futures.CancelledError:
+ raise base_exc
+
+ task = tasks.Task(notmutch(), loop=loop)
+ test_utils.run_briefly(loop)
+
+ task.cancel()
+ self.assertFalse(task.done())
+
+ self.assertRaises(BaseException, test_utils.run_briefly, loop)
+
+ self.assertTrue(task.done())
+ self.assertFalse(task.cancelled())
+ self.assertIs(task.exception(), base_exc)
+
+ def test_iscoroutinefunction(self):
+ def fn():
+ pass
+
+ self.assertFalse(tasks.iscoroutinefunction(fn))
+
+ def fn1():
+ yield
+ self.assertFalse(tasks.iscoroutinefunction(fn1))
+
+ @tasks.coroutine
+ def fn2():
+ yield
+ self.assertTrue(tasks.iscoroutinefunction(fn2))
+
+ def test_yield_vs_yield_from(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def wait_for_future():
+ yield fut
+
+ task = wait_for_future()
+ with self.assertRaises(RuntimeError):
+ self.loop.run_until_complete(task)
+
+ self.assertFalse(fut.done())
+
+ def test_yield_vs_yield_from_generator(self):
+ @tasks.coroutine
+ def coro():
+ yield
+
+ @tasks.coroutine
+ def wait_for_future():
+ gen = coro()
+ try:
+ yield gen
+ finally:
+ gen.close()
+
+ task = wait_for_future()
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_coroutine_non_gen_function(self):
+ @tasks.coroutine
+ def func():
+ return 'test'
+
+ self.assertTrue(tasks.iscoroutinefunction(func))
+
+ coro = func()
+ self.assertTrue(tasks.iscoroutine(coro))
+
+ res = self.loop.run_until_complete(coro)
+ self.assertEqual(res, 'test')
+
+ def test_coroutine_non_gen_function_return_future(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def func():
+ return fut
+
+ @tasks.coroutine
+ def coro():
+ fut.set_result('test')
+
+ t1 = tasks.Task(func(), loop=self.loop)
+ t2 = tasks.Task(coro(), loop=self.loop)
+ res = self.loop.run_until_complete(t1)
+ self.assertEqual(res, 'test')
+ self.assertIsNone(t2.result())
+
+ # Some thorough tests for cancellation propagation through
+ # coroutines, tasks and wait().
+
+ def test_yield_future_passes_cancel(self):
+ # Cancelling outer() cancels inner() cancels waiter.
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ try:
+ yield from waiter
+ except futures.CancelledError:
+ proof += 1
+ raise
+ else:
+ self.fail('got past sleep() in inner()')
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ try:
+ yield from inner()
+ except futures.CancelledError:
+ proof += 100 # Expect this path.
+ else:
+ proof += 10
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.loop.run_until_complete(f)
+ self.assertEqual(proof, 101)
+ self.assertTrue(waiter.cancelled())
+
+ def test_yield_wait_does_not_shield_cancel(self):
+ # Cancelling outer() makes wait() return early, leaves inner()
+ # running.
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ d, p = yield from tasks.wait([inner()], loop=self.loop)
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_result(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ inner.set_result(42)
+ res = self.loop.run_until_complete(outer)
+ self.assertEqual(res, 42)
+
+ def test_shield_exception(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ test_utils.run_briefly(self.loop)
+ exc = RuntimeError('expected')
+ inner.set_exception(exc)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(outer.exception(), exc)
+
+ def test_shield_cancel(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ test_utils.run_briefly(self.loop)
+ inner.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+
+ def test_shield_shortcut(self):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ res = self.loop.run_until_complete(tasks.shield(fut))
+ self.assertEqual(res, 42)
+
+ def test_shield_effect(self):
+ # Cancelling outer() does not affect inner().
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ yield from tasks.shield(inner(), loop=self.loop)
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_gather(self):
+ child1 = futures.Future(loop=self.loop)
+ child2 = futures.Future(loop=self.loop)
+ parent = tasks.gather(child1, child2, loop=self.loop)
+ outer = tasks.shield(parent, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(parent.result(), [1, 2])
+
+ def test_gather_shield(self):
+ child1 = futures.Future(loop=self.loop)
+ child2 = futures.Future(loop=self.loop)
+ inner1 = tasks.shield(child1, loop=self.loop)
+ inner2 = tasks.shield(child2, loop=self.loop)
+ parent = tasks.gather(inner1, inner2, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ parent.cancel()
+ # This should cancel inner1 and inner2 but bot child1 and child2.
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(parent.exception(), futures.CancelledError)
+ self.assertTrue(inner1.cancelled())
+ self.assertTrue(inner2.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+
+
+class GatherTestsBase:
+
+ def setUp(self):
+ self.one_loop = test_utils.TestLoop()
+ self.other_loop = test_utils.TestLoop()
+
+ def tearDown(self):
+ self.one_loop.close()
+ self.other_loop.close()
+
+ def _run_loop(self, loop):
+ while loop._ready:
+ test_utils.run_briefly(loop)
+
+ def _check_success(self, **kwargs):
+ a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ b.set_result(1)
+ a.set_result(2)
+ self._run_loop(self.one_loop)
+ self.assertEqual(cb.called, False)
+ self.assertFalse(fut.done())
+ c.set_result(3)
+ self._run_loop(self.one_loop)
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [2, 1, 3])
+
+ def test_success(self):
+ self._check_success()
+ self._check_success(return_exceptions=False)
+
+ def test_result_exception_success(self):
+ self._check_success(return_exceptions=True)
+
+ def test_one_exception(self):
+ a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c, d, e))
+ cb = Mock()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ a.set_result(1)
+ b.set_exception(exc)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertIs(fut.exception(), exc)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+
+ def test_return_exceptions(self):
+ a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c, d),
+ return_exceptions=True)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ exc2 = RuntimeError()
+ b.set_result(1)
+ c.set_exception(exc)
+ a.set_result(3)
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_exception(exc2)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [3, 1, exc, exc2])
+
+
+class FutureGatherTests(GatherTestsBase, unittest.TestCase):
+
+ def wrap_futures(self, *futures):
+ return futures
+
+ def _check_empty_sequence(self, seq_or_iter):
+ events.set_event_loop(self.one_loop)
+ self.addCleanup(events.set_event_loop, None)
+ fut = tasks.gather(*seq_or_iter)
+ self.assertIsInstance(fut, futures.Future)
+ self.assertIs(fut._loop, self.one_loop)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ self.assertEqual(fut.result(), [])
+ fut = tasks.gather(*seq_or_iter, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+
+ def test_constructor_empty_sequence(self):
+ self._check_empty_sequence([])
+ self._check_empty_sequence(())
+ self._check_empty_sequence(set())
+ self._check_empty_sequence(iter(""))
+
+ def test_constructor_heterogenous_futures(self):
+ fut1 = futures.Future(loop=self.one_loop)
+ fut2 = futures.Future(loop=self.other_loop)
+ with self.assertRaises(ValueError):
+ tasks.gather(fut1, fut2)
+ with self.assertRaises(ValueError):
+ tasks.gather(fut1, loop=self.other_loop)
+
+ def test_constructor_homogenous_futures(self):
+ children = [futures.Future(loop=self.other_loop) for i in range(3)]
+ fut = tasks.gather(*children)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+ fut = tasks.gather(*children, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+
+ def test_one_cancellation(self):
+ a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)]
+ fut = tasks.gather(a, b, c, d, e)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ b.cancel()
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertFalse(fut.cancelled())
+ self.assertIsInstance(fut.exception(), futures.CancelledError)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+
+ def test_result_exception_one_cancellation(self):
+ a, b, c, d, e, f = [futures.Future(loop=self.one_loop)
+ for i in range(6)]
+ fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ zde = ZeroDivisionError()
+ b.set_exception(zde)
+ c.cancel()
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_result(3)
+ e.cancel()
+ rte = RuntimeError()
+ f.set_exception(rte)
+ res = self.one_loop.run_until_complete(fut)
+ self.assertIsInstance(res[2], futures.CancelledError)
+ self.assertIsInstance(res[4], futures.CancelledError)
+ res[2] = res[4] = None
+ self.assertEqual(res, [1, zde, None, 3, None, rte])
+ cb.assert_called_once_with(fut)
+
+
+class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ events.set_event_loop(self.one_loop)
+
+ def tearDown(self):
+ events.set_event_loop(None)
+ super().tearDown()
+
+ def wrap_futures(self, *futures):
+ coros = []
+ for fut in futures:
+ @tasks.coroutine
+ def coro(fut=fut):
+ return (yield from fut)
+ coros.append(coro())
+ return coros
+
+ def test_constructor_loop_selection(self):
+ @tasks.coroutine
+ def coro():
+ return 'abc'
+ gen1 = coro()
+ gen2 = coro()
+ fut = tasks.gather(gen1, gen2)
+ self.assertIs(fut._loop, self.one_loop)
+ gen1.close()
+ gen2.close()
+ gen3 = coro()
+ gen4 = coro()
+ fut = tasks.gather(gen3, gen4, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ gen3.close()
+ gen4.close()
+
+ def test_cancellation_broadcast(self):
+ # Cancelling outer() cancels all children.
+ proof = 0
+ waiter = futures.Future(loop=self.one_loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ child1 = tasks.async(inner(), loop=self.one_loop)
+ child2 = tasks.async(inner(), loop=self.one_loop)
+ gatherer = None
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof, gatherer
+ gatherer = tasks.gather(child1, child2, loop=self.one_loop)
+ yield from gatherer
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ self.assertTrue(f.cancel())
+ with self.assertRaises(futures.CancelledError):
+ self.one_loop.run_until_complete(f)
+ self.assertFalse(gatherer.cancel())
+ self.assertTrue(waiter.cancelled())
+ self.assertTrue(child1.cancelled())
+ self.assertTrue(child2.cancelled())
+ test_utils.run_briefly(self.one_loop)
+ self.assertEqual(proof, 0)
+
+ def test_exception_marking(self):
+ # Test for the first line marked "Mark exception retrieved."
+
+ @tasks.coroutine
+ def inner(f):
+ yield from f
+ raise RuntimeError('should not be ignored')
+
+ a = futures.Future(loop=self.one_loop)
+ b = futures.Future(loop=self.one_loop)
+
+ @tasks.coroutine
+ def outer():
+ yield from tasks.gather(inner(a), inner(b), loop=self.one_loop)
+
+ f = tasks.async(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ a.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ b.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ self.assertIsInstance(f.exception(), RuntimeError)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py
new file mode 100644
index 0000000000..f96445c19c
--- /dev/null
+++ b/Lib/test/test_asyncio/test_transports.py
@@ -0,0 +1,59 @@
+"""Tests for transports.py."""
+
+import unittest
+import unittest.mock
+
+from asyncio import transports
+
+
+class TransportTests(unittest.TestCase):
+
+ def test_ctor_extra_is_none(self):
+ transport = transports.Transport()
+ self.assertEqual(transport._extra, {})
+
+ def test_get_extra_info(self):
+ transport = transports.Transport({'extra': 'info'})
+ self.assertEqual('info', transport.get_extra_info('extra'))
+ self.assertIsNone(transport.get_extra_info('unknown'))
+
+ default = object()
+ self.assertIs(default, transport.get_extra_info('unknown', default))
+
+ def test_writelines(self):
+ transport = transports.Transport()
+ transport.write = unittest.mock.Mock()
+
+ transport.writelines(['line1', 'line2', 'line3'])
+ self.assertEqual(3, transport.write.call_count)
+
+ def test_not_implemented(self):
+ transport = transports.Transport()
+
+ self.assertRaises(NotImplementedError, transport.write, 'data')
+ self.assertRaises(NotImplementedError, transport.write_eof)
+ self.assertRaises(NotImplementedError, transport.can_write_eof)
+ self.assertRaises(NotImplementedError, transport.pause_reading)
+ self.assertRaises(NotImplementedError, transport.resume_reading)
+ self.assertRaises(NotImplementedError, transport.close)
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_dgram_not_implemented(self):
+ transport = transports.DatagramTransport()
+
+ self.assertRaises(NotImplementedError, transport.sendto, 'data')
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_subprocess_transport_not_implemented(self):
+ transport = transports.SubprocessTransport()
+
+ self.assertRaises(NotImplementedError, transport.get_pid)
+ self.assertRaises(NotImplementedError, transport.get_returncode)
+ self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
+ self.assertRaises(NotImplementedError, transport.send_signal, 1)
+ self.assertRaises(NotImplementedError, transport.terminate)
+ self.assertRaises(NotImplementedError, transport.kill)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py
new file mode 100644
index 0000000000..fdd904955d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_unix_events.py
@@ -0,0 +1,1488 @@
+"""Tests for unix_events.py."""
+
+import collections
+import gc
+import errno
+import io
+import os
+import pprint
+import signal
+import stat
+import sys
+import threading
+import unittest
+import unittest.mock
+
+if sys.platform == 'win32':
+ raise unittest.SkipTest('UNIX only')
+
+
+from asyncio import events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import test_utils
+from asyncio import unix_events
+
+
+@unittest.skipUnless(signal, 'Signals are not supported')
+class SelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = unix_events.SelectorEventLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_check_signal(self):
+ self.assertRaises(
+ TypeError, self.loop._check_signal, '1')
+ self.assertRaises(
+ ValueError, self.loop._check_signal, signal.NSIG + 1)
+
+ def test_handle_signal_no_handler(self):
+ self.loop._handle_signal(signal.NSIG + 1, ())
+
+ def test_handle_signal_cancelled_handler(self):
+ h = events.Handle(unittest.mock.Mock(), ())
+ h.cancel()
+ self.loop._signal_handlers[signal.NSIG + 1] = h
+ self.loop.remove_signal_handler = unittest.mock.Mock()
+ self.loop._handle_signal(signal.NSIG + 1, ())
+ self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_setup_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ cb = lambda: True
+ self.loop.add_signal_handler(signal.SIGHUP, cb)
+ h = self.loop._signal_handlers.get(signal.SIGHUP)
+ self.assertIsInstance(h, events.Handle)
+ self.assertEqual(h._callback, cb)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_install_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ def set_wakeup_fd(fd):
+ if fd == -1:
+ raise ValueError()
+ m_signal.set_wakeup_fd = set_wakeup_fd
+
+ class Err(OSError):
+ errno = errno.EFAULT
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ Err,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.logger')
+ def test_add_signal_handler_install_error2(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.loop._signal_handlers[signal.SIGHUP] = lambda: True
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(1, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.logger')
+ def test_add_signal_handler_install_error3(self, m_logging, m_signal):
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+ m_signal.NSIG = signal.NSIG
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(2, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGHUP))
+ self.assertTrue(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.SIGINT = signal.SIGINT
+
+ self.loop.add_signal_handler(signal.SIGINT, lambda: True)
+ self.loop._signal_handlers[signal.SIGHUP] = object()
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertFalse(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGINT, m_signal.default_int_handler),
+ m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.logger')
+ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.loop.remove_signal_handler(signal.SIGHUP)
+ self.assertTrue(m_logging.info)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.signal.side_effect = OSError
+
+ self.assertRaises(
+ OSError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_close(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+ self.loop.add_signal_handler(signal.SIGCHLD, lambda: True)
+
+ self.assertEqual(len(self.loop._signal_handlers), 2)
+
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.loop.close()
+
+ self.assertEqual(len(self.loop._signal_handlers), 0)
+ m_signal.set_wakeup_fd.assert_called_once_with(-1)
+
+
+class UnixReadPipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(protocols.Protocol)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ fcntl_patcher = unittest.mock.patch('fcntl.fcntl')
+ fcntl_patcher.start()
+ self.addCleanup(fcntl_patcher.stop)
+
+ fstat_patcher = unittest.mock.patch('os.fstat')
+ m_fstat = fstat_patcher.start()
+ st = unittest.mock.Mock()
+ st.st_mode = stat.S_IFIFO
+ m_fstat.return_value = st
+ self.addCleanup(fstat_patcher.stop)
+
+ def test_ctor(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+ unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.return_value = b'data'
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready_eof(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.return_value = b''
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.eof_received.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready_blocked(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.side_effect = BlockingIOError
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.data_received.called)
+
+ @unittest.mock.patch('asyncio.log.logger.exception')
+ @unittest.mock.patch('os.read')
+ def test__read_ready_error(self, m_read, m_logexc):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ err = OSError()
+ m_read.side_effect = err
+ tr._close = unittest.mock.Mock()
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ tr._close.assert_called_with(err)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+
+ @unittest.mock.patch('os.read')
+ def test_pause_reading(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m = unittest.mock.Mock()
+ self.loop.add_reader(5, m)
+ tr.pause_reading()
+ self.assertFalse(self.loop.readers)
+
+ @unittest.mock.patch('os.read')
+ def test_resume_reading(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.resume_reading()
+ self.loop.assert_reader(5, tr._read_ready)
+
+ @unittest.mock.patch('os.read')
+ def test_close(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ tr._close.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ def test_close_already_closing(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr._closing = True
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ self.assertFalse(tr._close.called)
+
+ @unittest.mock.patch('os.read')
+ def test__close(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = object()
+ tr._close(err)
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ def test__call_connection_lost(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test__call_connection_lost_with_err(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+
+class UnixWritePipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ fcntl_patcher = unittest.mock.patch('fcntl.fcntl')
+ fcntl_patcher.start()
+ self.addCleanup(fcntl_patcher.stop)
+
+ fstat_patcher = unittest.mock.patch('os.fstat')
+ m_fstat = fstat_patcher.start()
+ st = unittest.mock.Mock()
+ st.st_mode = stat.S_IFIFO
+ m_fstat.return_value = st
+ self.addCleanup(fstat_patcher.stop)
+
+ def test_ctor(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol, fut)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(None, fut.result())
+
+ def test_can_write_eof(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+
+ @unittest.mock.patch('os.write')
+ def test_write(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.return_value = 4
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_no_data(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write(b'')
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_partial(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.return_value = 2
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'ta'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_buffer(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'previous']
+ tr.write(b'data')
+ self.assertFalse(m_write.called)
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'previous', b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_again(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.side_effect = BlockingIOError()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('asyncio.unix_events.logger')
+ @unittest.mock.patch('os.write')
+ def test_write_err(self, m_write, m_log):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ m_write.side_effect = err
+ tr._fatal_error = unittest.mock.Mock()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ tr._fatal_error.assert_called_with(err)
+ self.assertEqual(1, tr._conn_lost)
+
+ tr.write(b'data')
+ self.assertEqual(2, tr._conn_lost)
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ # This is a bit overspecified. :-(
+ m_log.warning.assert_called_with(
+ 'pipe closed by peer or os.write(pipe, data) raised exception.')
+
+ @unittest.mock.patch('os.write')
+ def test_write_close(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ tr._read_ready() # pipe was closed by peer
+
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 1)
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 2)
+
+ def test__read_ready(self):
+ tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
+ self.protocol)
+ tr._read_ready()
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertTrue(tr._closing)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_partial(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 3
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'a'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_again(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = BlockingIOError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_empty(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 0
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('asyncio.log.logger.exception')
+ @unittest.mock.patch('os.write')
+ def test__write_ready_err(self, m_write, m_logexc):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = err = OSError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+ self.assertEqual(1, tr._conn_lost)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_closing(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._closing = True
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.pipe.close.assert_called_with()
+
+ @unittest.mock.patch('os.write')
+ def test_abort(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ self.loop.add_reader(5, tr._read_ready)
+ tr._buffer = [b'da', b'ta']
+ tr.abort()
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test__call_connection_lost(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test__call_connection_lost_with_err(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test_close(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr.close()
+ tr.write_eof.assert_called_with()
+
+ def test_close_closing(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr._closing = True
+ tr.close()
+ self.assertFalse(tr.write_eof.called)
+
+ def test_write_eof(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_eof_pending(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ tr._buffer = [b'data']
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+
+class AbstractChildWatcherTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = unittest.mock.Mock()
+ watcher = unix_events.AbstractChildWatcher()
+ self.assertRaises(
+ NotImplementedError, watcher.add_child_handler, f, f)
+ self.assertRaises(
+ NotImplementedError, watcher.remove_child_handler, f)
+ self.assertRaises(
+ NotImplementedError, watcher.attach_loop, f)
+ self.assertRaises(
+ NotImplementedError, watcher.close)
+ self.assertRaises(
+ NotImplementedError, watcher.__enter__)
+ self.assertRaises(
+ NotImplementedError, watcher.__exit__, f, f, f)
+
+
+class BaseChildWatcherTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = unittest.mock.Mock()
+ watcher = unix_events.BaseChildWatcher()
+ self.assertRaises(
+ NotImplementedError, watcher._do_waitpid, f)
+
+
+WaitPidMocks = collections.namedtuple("WaitPidMocks",
+ ("waitpid",
+ "WIFEXITED",
+ "WIFSIGNALED",
+ "WEXITSTATUS",
+ "WTERMSIG",
+ ))
+
+
+class ChildWatcherTestsMixin:
+
+ ignore_warnings = unittest.mock.patch.object(unix_events.logger, "warning")
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.running = False
+ self.zombies = {}
+
+ with unittest.mock.patch.object(
+ self.loop, "add_signal_handler") as self.m_add_signal_handler:
+ self.watcher = self.create_watcher()
+ self.watcher.attach_loop(self.loop)
+
+ def waitpid(self, pid, flags):
+ if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1:
+ self.assertGreater(pid, 0)
+ try:
+ if pid < 0:
+ return self.zombies.popitem()
+ else:
+ return pid, self.zombies.pop(pid)
+ except KeyError:
+ pass
+ if self.running:
+ return 0, 0
+ else:
+ raise ChildProcessError()
+
+ def add_zombie(self, pid, returncode):
+ self.zombies[pid] = returncode + 32768
+
+ def WIFEXITED(self, status):
+ return status >= 32768
+
+ def WIFSIGNALED(self, status):
+ return 32700 < status < 32768
+
+ def WEXITSTATUS(self, status):
+ self.assertTrue(self.WIFEXITED(status))
+ return status - 32768
+
+ def WTERMSIG(self, status):
+ self.assertTrue(self.WIFSIGNALED(status))
+ return 32768 - status
+
+ def test_create_watcher(self):
+ self.m_add_signal_handler.assert_called_once_with(
+ signal.SIGCHLD, self.watcher._sig_chld)
+
+ def waitpid_mocks(func):
+ def wrapped_func(self):
+ def patch(target, wrapper):
+ return unittest.mock.patch(target, wraps=wrapper,
+ new_callable=unittest.mock.Mock)
+
+ with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \
+ patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \
+ patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \
+ patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \
+ patch('os.waitpid', self.waitpid) as m_waitpid:
+ func(self, WaitPidMocks(m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG,
+ ))
+ return wrapped_func
+
+ @waitpid_mocks
+ def test_sigchld(self, m):
+ # register a child
+ callback = unittest.mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(42, callback, 9, 10, 14)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child is running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates (returncode 12)
+ self.running = False
+ self.add_zombie(42, 12)
+ self.watcher._sig_chld()
+
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+ callback.assert_called_once_with(42, 12, 9, 10, 14)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+ callback.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(42, 13)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+
+ # sigchld called again
+ self.zombies.clear()
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_two_children(self, m):
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+
+ # register child 1
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(43, callback1, 7, 8)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # register child 2
+ with self.watcher:
+ self.watcher.add_child_handler(44, callback2, 147, 18)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # childen are running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child 1 terminates (signal 3)
+ self.add_zombie(43, -3)
+ self.watcher._sig_chld()
+
+ callback1.assert_called_once_with(43, -3, 7, 8)
+ self.assertFalse(callback2.called)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertTrue(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WTERMSIG.reset_mock()
+ callback1.reset_mock()
+
+ # child 2 still running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child 2 terminates (code 108)
+ self.add_zombie(44, 108)
+ self.running = False
+ self.watcher._sig_chld()
+
+ callback2.assert_called_once_with(44, 108, 147, 18)
+ self.assertFalse(callback1.called)
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+ callback2.reset_mock()
+
+ # ensure that the children are effectively reaped
+ self.add_zombie(43, 14)
+ self.add_zombie(44, 15)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+
+ # sigchld called again
+ self.zombies.clear()
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_two_children_terminating_together(self, m):
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+
+ # register child 1
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(45, callback1, 17, 8)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # register child 2
+ with self.watcher:
+ self.watcher.add_child_handler(46, callback2, 1147, 18)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # childen are running
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child 1 terminates (code 78)
+ # child 2 terminates (signal 5)
+ self.add_zombie(45, 78)
+ self.add_zombie(46, -5)
+ self.running = False
+ self.watcher._sig_chld()
+
+ callback1.assert_called_once_with(45, 78, 17, 8)
+ callback2.assert_called_once_with(46, -5, 1147, 18)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WEXITSTATUS.called)
+ self.assertTrue(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WTERMSIG.reset_mock()
+ m.WEXITSTATUS.reset_mock()
+ callback1.reset_mock()
+ callback2.reset_mock()
+
+ # ensure that the children are effectively reaped
+ self.add_zombie(45, 14)
+ self.add_zombie(46, 15)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_race_condition(self, m):
+ # register a child
+ callback = unittest.mock.Mock()
+
+ with self.watcher:
+ # child terminates before being registered
+ self.add_zombie(50, 4)
+ self.watcher._sig_chld()
+
+ self.watcher.add_child_handler(50, callback, 1, 12)
+
+ callback.assert_called_once_with(50, 4, 1, 12)
+ callback.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(50, -1)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+
+ @waitpid_mocks
+ def test_sigchld_replace_handler(self, m):
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(51, callback1, 19)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # register the same child again
+ with self.watcher:
+ self.watcher.add_child_handler(51, callback2, 21)
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates (signal 8)
+ self.running = False
+ self.add_zombie(51, -8)
+ self.watcher._sig_chld()
+
+ callback2.assert_called_once_with(51, -8, 21)
+ self.assertFalse(callback1.called)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertTrue(m.WTERMSIG.called)
+
+ m.WIFSIGNALED.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WTERMSIG.reset_mock()
+ callback2.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(51, 13)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ @waitpid_mocks
+ def test_sigchld_remove_handler(self, m):
+ callback = unittest.mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(52, callback, 1984)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # unregister the child
+ self.watcher.remove_child_handler(52)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates (code 99)
+ self.running = False
+ self.add_zombie(52, 99)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+
+ @waitpid_mocks
+ def test_sigchld_unknown_status(self, m):
+ callback = unittest.mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(53, callback, -19)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # terminate with unknown status
+ self.zombies[53] = 1178
+ self.running = False
+ self.watcher._sig_chld()
+
+ callback.assert_called_once_with(53, 1178, -19)
+ self.assertTrue(m.WIFEXITED.called)
+ self.assertTrue(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ callback.reset_mock()
+ m.WIFEXITED.reset_mock()
+ m.WIFSIGNALED.reset_mock()
+
+ # ensure that the child is effectively reaped
+ self.add_zombie(53, 101)
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback.called)
+
+ @waitpid_mocks
+ def test_remove_child_handler(self, m):
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+ callback3 = unittest.mock.Mock()
+
+ # register children
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(54, callback1, 1)
+ self.watcher.add_child_handler(55, callback2, 2)
+ self.watcher.add_child_handler(56, callback3, 3)
+
+ # remove child handler 1
+ self.assertTrue(self.watcher.remove_child_handler(54))
+
+ # remove child handler 2 multiple times
+ self.assertTrue(self.watcher.remove_child_handler(55))
+ self.assertFalse(self.watcher.remove_child_handler(55))
+ self.assertFalse(self.watcher.remove_child_handler(55))
+
+ # all children terminate
+ self.add_zombie(54, 0)
+ self.add_zombie(55, 1)
+ self.add_zombie(56, 2)
+ self.running = False
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ callback3.assert_called_once_with(56, 2, 3)
+
+ @waitpid_mocks
+ def test_sigchld_unhandled_exception(self, m):
+ callback = unittest.mock.Mock()
+
+ # register a child
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(57, callback)
+
+ # raise an exception
+ m.waitpid.side_effect = ValueError
+
+ with unittest.mock.patch.object(unix_events.logger,
+ "exception") as m_exception:
+
+ self.assertEqual(self.watcher._sig_chld(), None)
+ self.assertTrue(m_exception.called)
+
+ @waitpid_mocks
+ def test_sigchld_child_reaped_elsewhere(self, m):
+ # register a child
+ callback = unittest.mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(58, callback)
+
+ self.assertFalse(callback.called)
+ self.assertFalse(m.WIFEXITED.called)
+ self.assertFalse(m.WIFSIGNALED.called)
+ self.assertFalse(m.WEXITSTATUS.called)
+ self.assertFalse(m.WTERMSIG.called)
+
+ # child terminates
+ self.running = False
+ self.add_zombie(58, 4)
+
+ # waitpid is called elsewhere
+ os.waitpid(58, os.WNOHANG)
+
+ m.waitpid.reset_mock()
+
+ # sigchld
+ with self.ignore_warnings:
+ self.watcher._sig_chld()
+
+ callback.assert_called(m.waitpid)
+ if isinstance(self.watcher, unix_events.FastChildWatcher):
+ # here the FastChildWatche enters a deadlock
+ # (there is no way to prevent it)
+ self.assertFalse(callback.called)
+ else:
+ callback.assert_called_once_with(58, 255)
+
+ @waitpid_mocks
+ def test_sigchld_unknown_pid_during_registration(self, m):
+ # register two children
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+
+ with self.ignore_warnings, self.watcher:
+ self.running = True
+ # child 1 terminates
+ self.add_zombie(591, 7)
+ # an unknown child terminates
+ self.add_zombie(593, 17)
+
+ self.watcher._sig_chld()
+
+ self.watcher.add_child_handler(591, callback1)
+ self.watcher.add_child_handler(592, callback2)
+
+ callback1.assert_called_once_with(591, 7)
+ self.assertFalse(callback2.called)
+
+ @waitpid_mocks
+ def test_set_loop(self, m):
+ # register a child
+ callback = unittest.mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(60, callback)
+
+ # attach a new loop
+ old_loop = self.loop
+ self.loop = test_utils.TestLoop()
+
+ with unittest.mock.patch.object(
+ old_loop,
+ "remove_signal_handler") as m_old_remove_signal_handler, \
+ unittest.mock.patch.object(
+ self.loop,
+ "add_signal_handler") as m_new_add_signal_handler:
+
+ self.watcher.attach_loop(self.loop)
+
+ m_old_remove_signal_handler.assert_called_once_with(
+ signal.SIGCHLD)
+ m_new_add_signal_handler.assert_called_once_with(
+ signal.SIGCHLD, self.watcher._sig_chld)
+
+ # child terminates
+ self.running = False
+ self.add_zombie(60, 9)
+ self.watcher._sig_chld()
+
+ callback.assert_called_once_with(60, 9)
+
+ @waitpid_mocks
+ def test_set_loop_race_condition(self, m):
+ # register 3 children
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+ callback3 = unittest.mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ self.watcher.add_child_handler(61, callback1)
+ self.watcher.add_child_handler(62, callback2)
+ self.watcher.add_child_handler(622, callback3)
+
+ # detach the loop
+ old_loop = self.loop
+ self.loop = None
+
+ with unittest.mock.patch.object(
+ old_loop, "remove_signal_handler") as m_remove_signal_handler:
+
+ self.watcher.attach_loop(None)
+
+ m_remove_signal_handler.assert_called_once_with(
+ signal.SIGCHLD)
+
+ # child 1 & 2 terminate
+ self.add_zombie(61, 11)
+ self.add_zombie(62, -5)
+
+ # SIGCHLD was not catched
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ self.assertFalse(callback3.called)
+
+ # attach a new loop
+ self.loop = test_utils.TestLoop()
+
+ with unittest.mock.patch.object(
+ self.loop, "add_signal_handler") as m_add_signal_handler:
+
+ self.watcher.attach_loop(self.loop)
+
+ m_add_signal_handler.assert_called_once_with(
+ signal.SIGCHLD, self.watcher._sig_chld)
+ callback1.assert_called_once_with(61, 11) # race condition!
+ callback2.assert_called_once_with(62, -5) # race condition!
+ self.assertFalse(callback3.called)
+
+ callback1.reset_mock()
+ callback2.reset_mock()
+
+ # child 3 terminates
+ self.running = False
+ self.add_zombie(622, 19)
+ self.watcher._sig_chld()
+
+ self.assertFalse(callback1.called)
+ self.assertFalse(callback2.called)
+ callback3.assert_called_once_with(622, 19)
+
+ @waitpid_mocks
+ def test_close(self, m):
+ # register two children
+ callback1 = unittest.mock.Mock()
+ callback2 = unittest.mock.Mock()
+
+ with self.watcher:
+ self.running = True
+ # child 1 terminates
+ self.add_zombie(63, 9)
+ # other child terminates
+ self.add_zombie(65, 18)
+ self.watcher._sig_chld()
+
+ self.watcher.add_child_handler(63, callback1)
+ self.watcher.add_child_handler(64, callback1)
+
+ self.assertEqual(len(self.watcher._callbacks), 1)
+ if isinstance(self.watcher, unix_events.FastChildWatcher):
+ self.assertEqual(len(self.watcher._zombies), 1)
+
+ with unittest.mock.patch.object(
+ self.loop,
+ "remove_signal_handler") as m_remove_signal_handler:
+
+ self.watcher.close()
+
+ m_remove_signal_handler.assert_called_once_with(
+ signal.SIGCHLD)
+ self.assertFalse(self.watcher._callbacks)
+ if isinstance(self.watcher, unix_events.FastChildWatcher):
+ self.assertFalse(self.watcher._zombies)
+
+
+class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
+ def create_watcher(self):
+ return unix_events.SafeChildWatcher()
+
+
+class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
+ def create_watcher(self):
+ return unix_events.FastChildWatcher()
+
+
+class PolicyTests(unittest.TestCase):
+
+ def create_policy(self):
+ return unix_events.DefaultEventLoopPolicy()
+
+ def test_get_child_watcher(self):
+ policy = self.create_policy()
+ self.assertIsNone(policy._watcher)
+
+ watcher = policy.get_child_watcher()
+ self.assertIsInstance(watcher, unix_events.SafeChildWatcher)
+
+ self.assertIs(policy._watcher, watcher)
+
+ self.assertIs(watcher, policy.get_child_watcher())
+ self.assertIsNone(watcher._loop)
+
+ def test_get_child_watcher_after_set(self):
+ policy = self.create_policy()
+ watcher = unix_events.FastChildWatcher()
+
+ policy.set_child_watcher(watcher)
+ self.assertIs(policy._watcher, watcher)
+ self.assertIs(watcher, policy.get_child_watcher())
+
+ def test_get_child_watcher_with_mainloop_existing(self):
+ policy = self.create_policy()
+ loop = policy.get_event_loop()
+
+ self.assertIsNone(policy._watcher)
+ watcher = policy.get_child_watcher()
+
+ self.assertIsInstance(watcher, unix_events.SafeChildWatcher)
+ self.assertIs(watcher._loop, loop)
+
+ loop.close()
+
+ def test_get_child_watcher_thread(self):
+
+ def f():
+ policy.set_event_loop(policy.new_event_loop())
+
+ self.assertIsInstance(policy.get_event_loop(),
+ events.AbstractEventLoop)
+ watcher = policy.get_child_watcher()
+
+ self.assertIsInstance(watcher, unix_events.SafeChildWatcher)
+ self.assertIsNone(watcher._loop)
+
+ policy.get_event_loop().close()
+
+ policy = self.create_policy()
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ def test_child_watcher_replace_mainloop_existing(self):
+ policy = self.create_policy()
+ loop = policy.get_event_loop()
+
+ watcher = policy.get_child_watcher()
+
+ self.assertIs(watcher._loop, loop)
+
+ new_loop = policy.new_event_loop()
+ policy.set_event_loop(new_loop)
+
+ self.assertIs(watcher._loop, new_loop)
+
+ policy.set_event_loop(None)
+
+ self.assertIs(watcher._loop, None)
+
+ loop.close()
+ new_loop.close()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py
new file mode 100644
index 0000000000..f5147de25d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_events.py
@@ -0,0 +1,139 @@
+import os
+import sys
+import unittest
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import _winapi
+
+import asyncio
+
+from asyncio import windows_events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import streams
+from asyncio import transports
+from asyncio import test_utils
+from asyncio import _overlapped
+
+
+class UpperProto(protocols.Protocol):
+ def __init__(self):
+ self.buf = []
+
+ def connection_made(self, trans):
+ self.trans = trans
+
+ def data_received(self, data):
+ self.buf.append(data)
+ if b'\n' in data:
+ self.trans.write(b''.join(self.buf).upper())
+ self.trans.close()
+
+
+class ProactorTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = windows_events.ProactorEventLoop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ self.loop = None
+
+ def test_close(self):
+ a, b = self.loop._socketpair()
+ trans = self.loop._make_socket_transport(a, protocols.Protocol())
+ f = asyncio.async(self.loop.sock_recv(b, 100))
+ trans.close()
+ self.loop.run_until_complete(f)
+ self.assertEqual(f.result(), b'')
+
+ def test_double_bind(self):
+ ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid()
+ server1 = windows_events.PipeServer(ADDRESS)
+ with self.assertRaises(PermissionError):
+ server2 = windows_events.PipeServer(ADDRESS)
+ server1.close()
+
+ def test_pipe(self):
+ res = self.loop.run_until_complete(self._test_pipe())
+ self.assertEqual(res, 'done')
+
+ def _test_pipe(self):
+ ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ protocols.Protocol, ADDRESS)
+
+ [server] = yield from self.loop.start_serving_pipe(
+ UpperProto, ADDRESS)
+ self.assertIsInstance(server, windows_events.PipeServer)
+
+ clients = []
+ for i in range(5):
+ stream_reader = streams.StreamReader(loop=self.loop)
+ protocol = streams.StreamReaderProtocol(stream_reader)
+ trans, proto = yield from self.loop.create_pipe_connection(
+ lambda: protocol, ADDRESS)
+ self.assertIsInstance(trans, transports.Transport)
+ self.assertEqual(protocol, proto)
+ clients.append((stream_reader, trans))
+
+ for i, (r, w) in enumerate(clients):
+ w.write('lower-{}\n'.format(i).encode())
+
+ for i, (r, w) in enumerate(clients):
+ response = yield from r.readline()
+ self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
+ w.close()
+
+ server.close()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ protocols.Protocol, ADDRESS)
+
+ return 'done'
+
+ def test_wait_for_handle(self):
+ event = _overlapped.CreateEvent(None, True, False, None)
+ self.addCleanup(_winapi.CloseHandle, event)
+
+ # Wait for unset event with 0.2s timeout;
+ # result should be False at timeout
+ f = self.loop._proactor.wait_for_handle(event, 0.2)
+ start = self.loop.time()
+ self.loop.run_until_complete(f)
+ elapsed = self.loop.time() - start
+ self.assertFalse(f.result())
+ self.assertTrue(0.18 < elapsed < 0.22, elapsed)
+
+ _overlapped.SetEvent(event)
+
+ # Wait for for set event;
+ # result should be True immediately
+ f = self.loop._proactor.wait_for_handle(event, 10)
+ start = self.loop.time()
+ self.loop.run_until_complete(f)
+ elapsed = self.loop.time() - start
+ self.assertTrue(f.result())
+ self.assertTrue(0 <= elapsed < 0.02, elapsed)
+
+ _overlapped.ResetEvent(event)
+
+ # Wait for unset event with a cancelled future;
+ # CancelledError should be raised immediately
+ f = self.loop._proactor.wait_for_handle(event, 10)
+ f.cancel()
+ start = self.loop.time()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(f)
+ elapsed = self.loop.time() - start
+ self.assertTrue(0 <= elapsed < 0.02, elapsed)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py
new file mode 100644
index 0000000000..fa9d66c021
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_utils.py
@@ -0,0 +1,141 @@
+"""Tests for window_utils"""
+
+import sys
+import test.support
+import unittest
+import unittest.mock
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import _winapi
+
+from asyncio import windows_utils
+from asyncio import _overlapped
+
+
+class WinsocketpairTests(unittest.TestCase):
+
+ def test_winsocketpair(self):
+ ssock, csock = windows_utils.socketpair()
+
+ csock.send(b'xxx')
+ self.assertEqual(b'xxx', ssock.recv(1024))
+
+ csock.close()
+ ssock.close()
+
+ @unittest.mock.patch('asyncio.windows_utils.socket')
+ def test_winsocketpair_exc(self, m_socket):
+ m_socket.socket.return_value.getsockname.return_value = ('', 12345)
+ m_socket.socket.return_value.accept.return_value = object(), object()
+ m_socket.socket.return_value.connect.side_effect = OSError()
+
+ self.assertRaises(OSError, windows_utils.socketpair)
+
+
+class PipeTests(unittest.TestCase):
+
+ def test_pipe_overlapped(self):
+ h1, h2 = windows_utils.pipe(overlapped=(True, True))
+ try:
+ ov1 = _overlapped.Overlapped()
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, 0)
+
+ ov1.ReadFile(h1, 100)
+ self.assertTrue(ov1.pending)
+ self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
+ ERROR_IO_INCOMPLETE = 996
+ try:
+ ov1.getresult()
+ except OSError as e:
+ self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
+ else:
+ raise RuntimeError('expected ERROR_IO_INCOMPLETE')
+
+ ov2 = _overlapped.Overlapped()
+ self.assertFalse(ov2.pending)
+ self.assertEqual(ov2.error, 0)
+
+ ov2.WriteFile(h2, b"hello")
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+
+ res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
+ self.assertFalse(ov2.pending)
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+ self.assertEqual(ov1.getresult(), b"hello")
+ finally:
+ _winapi.CloseHandle(h1)
+ _winapi.CloseHandle(h2)
+
+ def test_pipe_handle(self):
+ h, _ = windows_utils.pipe(overlapped=(True, True))
+ _winapi.CloseHandle(_)
+ p = windows_utils.PipeHandle(h)
+ self.assertEqual(p.fileno(), h)
+ self.assertEqual(p.handle, h)
+
+ # check garbage collection of p closes handle
+ del p
+ test.support.gc_collect()
+ try:
+ _winapi.CloseHandle(h)
+ except OSError as e:
+ self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE
+ else:
+ raise RuntimeError('expected ERROR_INVALID_HANDLE')
+
+
+class PopenTests(unittest.TestCase):
+
+ def test_popen(self):
+ command = r"""if 1:
+ import sys
+ s = sys.stdin.readline()
+ sys.stdout.write(s.upper())
+ sys.stderr.write('stderr')
+ """
+ msg = b"blah\n"
+
+ p = windows_utils.Popen([sys.executable, '-c', command],
+ stdin=windows_utils.PIPE,
+ stdout=windows_utils.PIPE,
+ stderr=windows_utils.PIPE)
+
+ for f in [p.stdin, p.stdout, p.stderr]:
+ self.assertIsInstance(f, windows_utils.PipeHandle)
+
+ ovin = _overlapped.Overlapped()
+ ovout = _overlapped.Overlapped()
+ overr = _overlapped.Overlapped()
+
+ ovin.WriteFile(p.stdin.handle, msg)
+ ovout.ReadFile(p.stdout.handle, 100)
+ overr.ReadFile(p.stderr.handle, 100)
+
+ events = [ovin.event, ovout.event, overr.event]
+ # Super-long timeout for slow buildbots.
+ res = _winapi.WaitForMultipleObjects(events, True, 10000)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+ self.assertFalse(ovout.pending)
+ self.assertFalse(overr.pending)
+ self.assertFalse(ovin.pending)
+
+ self.assertEqual(ovin.getresult(), len(msg))
+ out = ovout.getresult().rstrip()
+ err = overr.getresult().rstrip()
+
+ self.assertGreater(len(out), 0)
+ self.assertGreater(len(err), 0)
+ # allow for partial reads...
+ self.assertTrue(msg.upper().rstrip().startswith(out))
+ self.assertTrue(b"stderr".startswith(err))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/tests.txt b/Lib/test/test_asyncio/tests.txt
new file mode 100644
index 0000000000..30609cd5fe
--- /dev/null
+++ b/Lib/test/test_asyncio/tests.txt
@@ -0,0 +1,13 @@
+test_asyncio.test_base_events
+test_asyncio.test_events
+test_asyncio.test_futures
+test_asyncio.test_locks
+test_asyncio.test_proactor_events
+test_asyncio.test_queues
+test_asyncio.test_selector_events
+test_asyncio.test_streams
+test_asyncio.test_tasks
+test_asyncio.test_transports
+test_asyncio.test_unix_events
+test_asyncio.test_windows_events
+test_asyncio.test_windows_utils
diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py
index 5d0632ef26..084d247295 100644
--- a/Lib/test/test_asyncore.py
+++ b/Lib/test/test_asyncore.py
@@ -19,6 +19,7 @@ try:
except ImportError:
threading = None
+TIMEOUT = 3
HAS_UNIX_SOCKETS = hasattr(socket, 'AF_UNIX')
class dummysocket:
@@ -395,7 +396,10 @@ class DispatcherWithSendTests(unittest.TestCase):
self.assertEqual(cap.getvalue(), data*2)
finally:
- t.join()
+ t.join(timeout=TIMEOUT)
+ if t.is_alive():
+ self.fail("join() timed out")
+
class DispatcherWithSendTests_UsePoll(DispatcherWithSendTests):
@@ -740,7 +744,12 @@ class BaseTestAPI:
s.create_socket(self.family)
self.assertEqual(s.socket.family, self.family)
SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
- self.assertEqual(s.socket.type, socket.SOCK_STREAM | SOCK_NONBLOCK)
+ sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK
+ if hasattr(socket, 'SOCK_CLOEXEC'):
+ self.assertIn(s.socket.type,
+ (sock_type | socket.SOCK_CLOEXEC, sock_type))
+ else:
+ self.assertEqual(s.socket.type, sock_type)
def test_bind(self):
if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
@@ -754,7 +763,7 @@ class BaseTestAPI:
s2 = asyncore.dispatcher()
s2.create_socket(self.family)
# EADDRINUSE indicates the socket was correctly bound
- self.assertRaises(socket.error, s2.bind, (self.addr[0], port))
+ self.assertRaises(OSError, s2.bind, (self.addr[0], port))
def test_set_reuse_addr(self):
if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
@@ -762,7 +771,7 @@ class BaseTestAPI:
sock = socket.socket(self.family)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- except socket.error:
+ except OSError:
unittest.skip("SO_REUSEADDR not supported on this platform")
else:
# if SO_REUSEADDR succeeded for sock we expect asyncore
@@ -787,7 +796,11 @@ class BaseTestAPI:
t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1,
count=500))
t.start()
- self.addCleanup(t.join)
+ def cleanup():
+ t.join(timeout=TIMEOUT)
+ if t.is_alive():
+ self.fail("join() timed out")
+ self.addCleanup(cleanup)
s = socket.socket(self.family, socket.SOCK_STREAM)
s.settimeout(.2)
@@ -795,7 +808,7 @@ class BaseTestAPI:
struct.pack('ii', 1, 0))
try:
s.connect(server.address)
- except socket.error:
+ except OSError:
pass
finally:
s.close()
diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py
index 30c3b4a07b..b641015b70 100644
--- a/Lib/test/test_atexit.py
+++ b/Lib/test/test_atexit.py
@@ -2,6 +2,7 @@ import sys
import unittest
import io
import atexit
+import _testcapi
from test import support
### helpers
@@ -23,7 +24,9 @@ def raise1():
def raise2():
raise SystemError
-class TestCase(unittest.TestCase):
+
+class GeneralTest(unittest.TestCase):
+
def setUp(self):
self.save_stdout = sys.stdout
self.save_stderr = sys.stderr
@@ -141,8 +144,43 @@ class TestCase(unittest.TestCase):
self.assertEqual(l, [5])
+class SubinterpreterTest(unittest.TestCase):
+
+ def test_callbacks_leak(self):
+ # This test shows a leak in refleak mode if atexit doesn't
+ # take care to free callbacks in its per-subinterpreter module
+ # state.
+ n = atexit._ncallbacks()
+ code = r"""if 1:
+ import atexit
+ def f():
+ pass
+ atexit.register(f)
+ del atexit
+ """
+ ret = _testcapi.run_in_subinterp(code)
+ self.assertEqual(ret, 0)
+ self.assertEqual(atexit._ncallbacks(), n)
+
+ def test_callbacks_leak_refcycle(self):
+ # Similar to the above, but with a refcycle through the atexit
+ # module.
+ n = atexit._ncallbacks()
+ code = r"""if 1:
+ import atexit
+ def f():
+ pass
+ atexit.register(f)
+ atexit.__atexit = atexit
+ """
+ ret = _testcapi.run_in_subinterp(code)
+ self.assertEqual(ret, 0)
+ self.assertEqual(atexit._ncallbacks(), n)
+
+
def test_main():
- support.run_unittest(TestCase)
+ support.run_unittest(__name__)
+
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_audioop.py b/Lib/test/test_audioop.py
index a92cf874bd..fe96b75dfa 100644
--- a/Lib/test/test_audioop.py
+++ b/Lib/test/test_audioop.py
@@ -5,13 +5,18 @@ import unittest
def pack(width, data):
return b''.join(v.to_bytes(width, sys.byteorder, signed=True) for v in data)
-packs = {w: (lambda *data, width=w: pack(width, data)) for w in (1, 2, 4)}
-maxvalues = {w: (1 << (8 * w - 1)) - 1 for w in (1, 2, 4)}
-minvalues = {w: -1 << (8 * w - 1) for w in (1, 2, 4)}
+def unpack(width, data):
+ return [int.from_bytes(data[i: i + width], sys.byteorder, signed=True)
+ for i in range(0, len(data), width)]
+
+packs = {w: (lambda *data, width=w: pack(width, data)) for w in (1, 2, 3, 4)}
+maxvalues = {w: (1 << (8 * w - 1)) - 1 for w in (1, 2, 3, 4)}
+minvalues = {w: -1 << (8 * w - 1) for w in (1, 2, 3, 4)}
datas = {
1: b'\x00\x12\x45\xbb\x7f\x80\xff',
2: packs[2](0, 0x1234, 0x4567, -0x4567, 0x7fff, -0x8000, -1),
+ 3: packs[3](0, 0x123456, 0x456789, -0x456789, 0x7fffff, -0x800000, -1),
4: packs[4](0, 0x12345678, 0x456789ab, -0x456789ab,
0x7fffffff, -0x80000000, -1),
}
@@ -19,6 +24,7 @@ datas = {
INVALID_DATA = [
(b'abc', 0),
(b'abc', 2),
+ (b'ab', 3),
(b'abc', 4),
]
@@ -26,8 +32,10 @@ INVALID_DATA = [
class TestAudioop(unittest.TestCase):
def test_max(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.max(b'', w), 0)
+ self.assertEqual(audioop.max(bytearray(), w), 0)
+ self.assertEqual(audioop.max(memoryview(b''), w), 0)
p = packs[w]
self.assertEqual(audioop.max(p(5), w), 5)
self.assertEqual(audioop.max(p(5, -8, -1), w), 8)
@@ -36,9 +44,13 @@ class TestAudioop(unittest.TestCase):
self.assertEqual(audioop.max(datas[w], w), -minvalues[w])
def test_minmax(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.minmax(b'', w),
(0x7fffffff, -0x80000000))
+ self.assertEqual(audioop.minmax(bytearray(), w),
+ (0x7fffffff, -0x80000000))
+ self.assertEqual(audioop.minmax(memoryview(b''), w),
+ (0x7fffffff, -0x80000000))
p = packs[w]
self.assertEqual(audioop.minmax(p(5), w), (5, 5))
self.assertEqual(audioop.minmax(p(5, -8, -1), w), (-8, 5))
@@ -50,16 +62,20 @@ class TestAudioop(unittest.TestCase):
(minvalues[w], maxvalues[w]))
def test_maxpp(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.maxpp(b'', w), 0)
+ self.assertEqual(audioop.maxpp(bytearray(), w), 0)
+ self.assertEqual(audioop.maxpp(memoryview(b''), w), 0)
self.assertEqual(audioop.maxpp(packs[w](*range(100)), w), 0)
self.assertEqual(audioop.maxpp(packs[w](9, 10, 5, 5, 0, 1), w), 10)
self.assertEqual(audioop.maxpp(datas[w], w),
maxvalues[w] - minvalues[w])
def test_avg(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.avg(b'', w), 0)
+ self.assertEqual(audioop.avg(bytearray(), w), 0)
+ self.assertEqual(audioop.avg(memoryview(b''), w), 0)
p = packs[w]
self.assertEqual(audioop.avg(p(5), w), 5)
self .assertEqual(audioop.avg(p(5, 8), w), 6)
@@ -74,17 +90,22 @@ class TestAudioop(unittest.TestCase):
-0x60000000)
def test_avgpp(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.avgpp(b'', w), 0)
+ self.assertEqual(audioop.avgpp(bytearray(), w), 0)
+ self.assertEqual(audioop.avgpp(memoryview(b''), w), 0)
self.assertEqual(audioop.avgpp(packs[w](*range(100)), w), 0)
self.assertEqual(audioop.avgpp(packs[w](9, 10, 5, 5, 0, 1), w), 10)
self.assertEqual(audioop.avgpp(datas[1], 1), 196)
self.assertEqual(audioop.avgpp(datas[2], 2), 50534)
+ self.assertEqual(audioop.avgpp(datas[3], 3), 12937096)
self.assertEqual(audioop.avgpp(datas[4], 4), 3311897002)
def test_rms(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.rms(b'', w), 0)
+ self.assertEqual(audioop.rms(bytearray(), w), 0)
+ self.assertEqual(audioop.rms(memoryview(b''), w), 0)
p = packs[w]
self.assertEqual(audioop.rms(p(*range(100)), w), 57)
self.assertAlmostEqual(audioop.rms(p(maxvalues[w]) * 5, w),
@@ -93,11 +114,14 @@ class TestAudioop(unittest.TestCase):
-minvalues[w], delta=1)
self.assertEqual(audioop.rms(datas[1], 1), 77)
self.assertEqual(audioop.rms(datas[2], 2), 20001)
+ self.assertEqual(audioop.rms(datas[3], 3), 5120523)
self.assertEqual(audioop.rms(datas[4], 4), 1310854152)
def test_cross(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.cross(b'', w), -1)
+ self.assertEqual(audioop.cross(bytearray(), w), -1)
+ self.assertEqual(audioop.cross(memoryview(b''), w), -1)
p = packs[w]
self.assertEqual(audioop.cross(p(0, 1, 2), w), 0)
self.assertEqual(audioop.cross(p(1, 2, -3, -4), w), 1)
@@ -106,22 +130,29 @@ class TestAudioop(unittest.TestCase):
self.assertEqual(audioop.cross(p(minvalues[w], maxvalues[w]), w), 1)
def test_add(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.add(b'', b'', w), b'')
+ self.assertEqual(audioop.add(bytearray(), bytearray(), w), b'')
+ self.assertEqual(audioop.add(memoryview(b''), memoryview(b''), w), b'')
self.assertEqual(audioop.add(datas[w], b'\0' * len(datas[w]), w),
datas[w])
self.assertEqual(audioop.add(datas[1], datas[1], 1),
b'\x00\x24\x7f\x80\x7f\x80\xfe')
self.assertEqual(audioop.add(datas[2], datas[2], 2),
packs[2](0, 0x2468, 0x7fff, -0x8000, 0x7fff, -0x8000, -2))
+ self.assertEqual(audioop.add(datas[3], datas[3], 3),
+ packs[3](0, 0x2468ac, 0x7fffff, -0x800000,
+ 0x7fffff, -0x800000, -2))
self.assertEqual(audioop.add(datas[4], datas[4], 4),
packs[4](0, 0x2468acf0, 0x7fffffff, -0x80000000,
0x7fffffff, -0x80000000, -2))
def test_bias(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
for bias in 0, 1, -1, 127, -128, 0x7fffffff, -0x80000000:
self.assertEqual(audioop.bias(b'', w, bias), b'')
+ self.assertEqual(audioop.bias(bytearray(), w, bias), b'')
+ self.assertEqual(audioop.bias(memoryview(b''), w, bias), b'')
self.assertEqual(audioop.bias(datas[1], 1, 1),
b'\x01\x13\x46\xbc\x80\x81\x00')
self.assertEqual(audioop.bias(datas[1], 1, -1),
@@ -138,6 +169,17 @@ class TestAudioop(unittest.TestCase):
packs[2](-1, 0x1233, 0x4566, -0x4568, 0x7ffe, 0x7fff, -2))
self.assertEqual(audioop.bias(datas[2], 2, -0x80000000),
datas[2])
+ self.assertEqual(audioop.bias(datas[3], 3, 1),
+ packs[3](1, 0x123457, 0x45678a, -0x456788,
+ -0x800000, -0x7fffff, 0))
+ self.assertEqual(audioop.bias(datas[3], 3, -1),
+ packs[3](-1, 0x123455, 0x456788, -0x45678a,
+ 0x7ffffe, 0x7fffff, -2))
+ self.assertEqual(audioop.bias(datas[3], 3, 0x7fffffff),
+ packs[3](-1, 0x123455, 0x456788, -0x45678a,
+ 0x7ffffe, 0x7fffff, -2))
+ self.assertEqual(audioop.bias(datas[3], 3, -0x80000000),
+ datas[3])
self.assertEqual(audioop.bias(datas[4], 4, 1),
packs[4](1, 0x12345679, 0x456789ac, -0x456789aa,
-0x80000000, -0x7fffffff, 0))
@@ -152,99 +194,141 @@ class TestAudioop(unittest.TestCase):
-1, 0, 0x7fffffff))
def test_lin2lin(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.lin2lin(datas[w], w, w), datas[w])
+ self.assertEqual(audioop.lin2lin(bytearray(datas[w]), w, w),
+ datas[w])
+ self.assertEqual(audioop.lin2lin(memoryview(datas[w]), w, w),
+ datas[w])
self.assertEqual(audioop.lin2lin(datas[1], 1, 2),
packs[2](0, 0x1200, 0x4500, -0x4500, 0x7f00, -0x8000, -0x100))
+ self.assertEqual(audioop.lin2lin(datas[1], 1, 3),
+ packs[3](0, 0x120000, 0x450000, -0x450000,
+ 0x7f0000, -0x800000, -0x10000))
self.assertEqual(audioop.lin2lin(datas[1], 1, 4),
packs[4](0, 0x12000000, 0x45000000, -0x45000000,
0x7f000000, -0x80000000, -0x1000000))
self.assertEqual(audioop.lin2lin(datas[2], 2, 1),
b'\x00\x12\x45\xba\x7f\x80\xff')
+ self.assertEqual(audioop.lin2lin(datas[2], 2, 3),
+ packs[3](0, 0x123400, 0x456700, -0x456700,
+ 0x7fff00, -0x800000, -0x100))
self.assertEqual(audioop.lin2lin(datas[2], 2, 4),
packs[4](0, 0x12340000, 0x45670000, -0x45670000,
0x7fff0000, -0x80000000, -0x10000))
+ self.assertEqual(audioop.lin2lin(datas[3], 3, 1),
+ b'\x00\x12\x45\xba\x7f\x80\xff')
+ self.assertEqual(audioop.lin2lin(datas[3], 3, 2),
+ packs[2](0, 0x1234, 0x4567, -0x4568, 0x7fff, -0x8000, -1))
+ self.assertEqual(audioop.lin2lin(datas[3], 3, 4),
+ packs[4](0, 0x12345600, 0x45678900, -0x45678900,
+ 0x7fffff00, -0x80000000, -0x100))
self.assertEqual(audioop.lin2lin(datas[4], 4, 1),
b'\x00\x12\x45\xba\x7f\x80\xff')
self.assertEqual(audioop.lin2lin(datas[4], 4, 2),
packs[2](0, 0x1234, 0x4567, -0x4568, 0x7fff, -0x8000, -1))
+ self.assertEqual(audioop.lin2lin(datas[4], 4, 3),
+ packs[3](0, 0x123456, 0x456789, -0x45678a,
+ 0x7fffff, -0x800000, -1))
def test_adpcm2lin(self):
self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 1, None),
(b'\x00\x00\x00\xff\x00\xff', (-179, 40)))
+ self.assertEqual(audioop.adpcm2lin(bytearray(b'\x07\x7f\x7f'), 1, None),
+ (b'\x00\x00\x00\xff\x00\xff', (-179, 40)))
+ self.assertEqual(audioop.adpcm2lin(memoryview(b'\x07\x7f\x7f'), 1, None),
+ (b'\x00\x00\x00\xff\x00\xff', (-179, 40)))
self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 2, None),
(packs[2](0, 0xb, 0x29, -0x16, 0x72, -0xb3), (-179, 40)))
+ self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 3, None),
+ (packs[3](0, 0xb00, 0x2900, -0x1600, 0x7200,
+ -0xb300), (-179, 40)))
self.assertEqual(audioop.adpcm2lin(b'\x07\x7f\x7f', 4, None),
(packs[4](0, 0xb0000, 0x290000, -0x160000, 0x720000,
-0xb30000), (-179, 40)))
# Very cursory test
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.adpcm2lin(b'\0' * 5, w, None),
(b'\0' * w * 10, (0, 0)))
def test_lin2adpcm(self):
self.assertEqual(audioop.lin2adpcm(datas[1], 1, None),
(b'\x07\x7f\x7f', (-221, 39)))
- self.assertEqual(audioop.lin2adpcm(datas[2], 2, None),
- (b'\x07\x7f\x7f', (31, 39)))
- self.assertEqual(audioop.lin2adpcm(datas[4], 4, None),
- (b'\x07\x7f\x7f', (31, 39)))
+ self.assertEqual(audioop.lin2adpcm(bytearray(datas[1]), 1, None),
+ (b'\x07\x7f\x7f', (-221, 39)))
+ self.assertEqual(audioop.lin2adpcm(memoryview(datas[1]), 1, None),
+ (b'\x07\x7f\x7f', (-221, 39)))
+ for w in 2, 3, 4:
+ self.assertEqual(audioop.lin2adpcm(datas[w], w, None),
+ (b'\x07\x7f\x7f', (31, 39)))
# Very cursory test
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.lin2adpcm(b'\0' * w * 10, w, None),
(b'\0' * 5, (0, 0)))
def test_lin2alaw(self):
self.assertEqual(audioop.lin2alaw(datas[1], 1),
b'\xd5\x87\xa4\x24\xaa\x2a\x5a')
- self.assertEqual(audioop.lin2alaw(datas[2], 2),
- b'\xd5\x87\xa4\x24\xaa\x2a\x55')
- self.assertEqual(audioop.lin2alaw(datas[4], 4),
- b'\xd5\x87\xa4\x24\xaa\x2a\x55')
+ self.assertEqual(audioop.lin2alaw(bytearray(datas[1]), 1),
+ b'\xd5\x87\xa4\x24\xaa\x2a\x5a')
+ self.assertEqual(audioop.lin2alaw(memoryview(datas[1]), 1),
+ b'\xd5\x87\xa4\x24\xaa\x2a\x5a')
+ for w in 2, 3, 4:
+ self.assertEqual(audioop.lin2alaw(datas[w], w),
+ b'\xd5\x87\xa4\x24\xaa\x2a\x55')
def test_alaw2lin(self):
encoded = b'\x00\x03\x24\x2a\x51\x54\x55\x58\x6b\x71\x7f'\
b'\x80\x83\xa4\xaa\xd1\xd4\xd5\xd8\xeb\xf1\xff'
src = [-688, -720, -2240, -4032, -9, -3, -1, -27, -244, -82, -106,
688, 720, 2240, 4032, 9, 3, 1, 27, 244, 82, 106]
- for w in 1, 2, 4:
- self.assertEqual(audioop.alaw2lin(encoded, w),
- packs[w](*(x << (w * 8) >> 13 for x in src)))
+ for w in 1, 2, 3, 4:
+ decoded = packs[w](*(x << (w * 8) >> 13 for x in src))
+ self.assertEqual(audioop.alaw2lin(encoded, w), decoded)
+ self.assertEqual(audioop.alaw2lin(bytearray(encoded), w), decoded)
+ self.assertEqual(audioop.alaw2lin(memoryview(encoded), w), decoded)
encoded = bytes(range(256))
- for w in 2, 4:
+ for w in 2, 3, 4:
decoded = audioop.alaw2lin(encoded, w)
self.assertEqual(audioop.lin2alaw(decoded, w), encoded)
def test_lin2ulaw(self):
self.assertEqual(audioop.lin2ulaw(datas[1], 1),
b'\xff\xad\x8e\x0e\x80\x00\x67')
- self.assertEqual(audioop.lin2ulaw(datas[2], 2),
- b'\xff\xad\x8e\x0e\x80\x00\x7e')
- self.assertEqual(audioop.lin2ulaw(datas[4], 4),
- b'\xff\xad\x8e\x0e\x80\x00\x7e')
+ self.assertEqual(audioop.lin2ulaw(bytearray(datas[1]), 1),
+ b'\xff\xad\x8e\x0e\x80\x00\x67')
+ self.assertEqual(audioop.lin2ulaw(memoryview(datas[1]), 1),
+ b'\xff\xad\x8e\x0e\x80\x00\x67')
+ for w in 2, 3, 4:
+ self.assertEqual(audioop.lin2ulaw(datas[w], w),
+ b'\xff\xad\x8e\x0e\x80\x00\x7e')
def test_ulaw2lin(self):
encoded = b'\x00\x0e\x28\x3f\x57\x6a\x76\x7c\x7e\x7f'\
b'\x80\x8e\xa8\xbf\xd7\xea\xf6\xfc\xfe\xff'
src = [-8031, -4447, -1471, -495, -163, -53, -18, -6, -2, 0,
8031, 4447, 1471, 495, 163, 53, 18, 6, 2, 0]
- for w in 1, 2, 4:
- self.assertEqual(audioop.ulaw2lin(encoded, w),
- packs[w](*(x << (w * 8) >> 14 for x in src)))
+ for w in 1, 2, 3, 4:
+ decoded = packs[w](*(x << (w * 8) >> 14 for x in src))
+ self.assertEqual(audioop.ulaw2lin(encoded, w), decoded)
+ self.assertEqual(audioop.ulaw2lin(bytearray(encoded), w), decoded)
+ self.assertEqual(audioop.ulaw2lin(memoryview(encoded), w), decoded)
# Current u-law implementation has two codes fo 0: 0x7f and 0xff.
encoded = bytes(range(127)) + bytes(range(128, 256))
- for w in 2, 4:
+ for w in 2, 3, 4:
decoded = audioop.ulaw2lin(encoded, w)
self.assertEqual(audioop.lin2ulaw(decoded, w), encoded)
def test_mul(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.mul(b'', w, 2), b'')
+ self.assertEqual(audioop.mul(bytearray(), w, 2), b'')
+ self.assertEqual(audioop.mul(memoryview(b''), w, 2), b'')
self.assertEqual(audioop.mul(datas[w], w, 0),
b'\0' * len(datas[w]))
self.assertEqual(audioop.mul(datas[w], w, 1),
@@ -253,14 +337,21 @@ class TestAudioop(unittest.TestCase):
b'\x00\x24\x7f\x80\x7f\x80\xfe')
self.assertEqual(audioop.mul(datas[2], 2, 2),
packs[2](0, 0x2468, 0x7fff, -0x8000, 0x7fff, -0x8000, -2))
+ self.assertEqual(audioop.mul(datas[3], 3, 2),
+ packs[3](0, 0x2468ac, 0x7fffff, -0x800000,
+ 0x7fffff, -0x800000, -2))
self.assertEqual(audioop.mul(datas[4], 4, 2),
packs[4](0, 0x2468acf0, 0x7fffffff, -0x80000000,
0x7fffffff, -0x80000000, -2))
def test_ratecv(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.ratecv(b'', w, 1, 8000, 8000, None),
(b'', (-1, ((0, 0),))))
+ self.assertEqual(audioop.ratecv(bytearray(), w, 1, 8000, 8000, None),
+ (b'', (-1, ((0, 0),))))
+ self.assertEqual(audioop.ratecv(memoryview(b''), w, 1, 8000, 8000, None),
+ (b'', (-1, ((0, 0),))))
self.assertEqual(audioop.ratecv(b'', w, 5, 8000, 8000, None),
(b'', (-1, ((0, 0),) * 5)))
self.assertEqual(audioop.ratecv(b'', w, 1, 8000, 16000, None),
@@ -272,7 +363,7 @@ class TestAudioop(unittest.TestCase):
d2, state = audioop.ratecv(b'\x00\x01\x02', 1, 1, 8000, 16000, state)
self.assertEqual(d1 + d2, b'\000\000\001\001\002\001\000\000\001\001\002')
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
d0, state0 = audioop.ratecv(datas[w], w, 1, 8000, 16000, None)
d, state = b'', None
for i in range(0, len(datas[w]), w):
@@ -283,13 +374,15 @@ class TestAudioop(unittest.TestCase):
self.assertEqual(state, state0)
def test_reverse(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
self.assertEqual(audioop.reverse(b'', w), b'')
+ self.assertEqual(audioop.reverse(bytearray(), w), b'')
+ self.assertEqual(audioop.reverse(memoryview(b''), w), b'')
self.assertEqual(audioop.reverse(packs[w](0, 1, 2), w),
packs[w](2, 1, 0))
def test_tomono(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
data1 = datas[w]
data2 = bytearray(2 * len(data1))
for k in range(w):
@@ -299,9 +392,13 @@ class TestAudioop(unittest.TestCase):
for k in range(w):
data2[k+w::2*w] = data1[k::w]
self.assertEqual(audioop.tomono(data2, w, 0.5, 0.5), data1)
+ self.assertEqual(audioop.tomono(bytearray(data2), w, 0.5, 0.5),
+ data1)
+ self.assertEqual(audioop.tomono(memoryview(data2), w, 0.5, 0.5),
+ data1)
def test_tostereo(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
data1 = datas[w]
data2 = bytearray(2 * len(data1))
for k in range(w):
@@ -311,14 +408,25 @@ class TestAudioop(unittest.TestCase):
for k in range(w):
data2[k+w::2*w] = data1[k::w]
self.assertEqual(audioop.tostereo(data1, w, 1, 1), data2)
+ self.assertEqual(audioop.tostereo(bytearray(data1), w, 1, 1), data2)
+ self.assertEqual(audioop.tostereo(memoryview(data1), w, 1, 1),
+ data2)
def test_findfactor(self):
self.assertEqual(audioop.findfactor(datas[2], datas[2]), 1.0)
+ self.assertEqual(audioop.findfactor(bytearray(datas[2]),
+ bytearray(datas[2])), 1.0)
+ self.assertEqual(audioop.findfactor(memoryview(datas[2]),
+ memoryview(datas[2])), 1.0)
self.assertEqual(audioop.findfactor(b'\0' * len(datas[2]), datas[2]),
0.0)
def test_findfit(self):
self.assertEqual(audioop.findfit(datas[2], datas[2]), (0, 1.0))
+ self.assertEqual(audioop.findfit(bytearray(datas[2]),
+ bytearray(datas[2])), (0, 1.0))
+ self.assertEqual(audioop.findfit(memoryview(datas[2]),
+ memoryview(datas[2])), (0, 1.0))
self.assertEqual(audioop.findfit(datas[2], packs[2](1, 2, 0)),
(1, 8038.8))
self.assertEqual(audioop.findfit(datas[2][:-2] * 5 + datas[2], datas[2]),
@@ -326,11 +434,15 @@ class TestAudioop(unittest.TestCase):
def test_findmax(self):
self.assertEqual(audioop.findmax(datas[2], 1), 5)
+ self.assertEqual(audioop.findmax(bytearray(datas[2]), 1), 5)
+ self.assertEqual(audioop.findmax(memoryview(datas[2]), 1), 5)
def test_getsample(self):
- for w in 1, 2, 4:
+ for w in 1, 2, 3, 4:
data = packs[w](0, 1, -1, maxvalues[w], minvalues[w])
self.assertEqual(audioop.getsample(data, w, 0), 0)
+ self.assertEqual(audioop.getsample(bytearray(data), w, 0), 0)
+ self.assertEqual(audioop.getsample(memoryview(data), w, 0), 0)
self.assertEqual(audioop.getsample(data, w, 1), 1)
self.assertEqual(audioop.getsample(data, w, 2), -1)
self.assertEqual(audioop.getsample(data, w, 3), maxvalues[w])
@@ -365,10 +477,33 @@ class TestAudioop(unittest.TestCase):
self.assertRaises(audioop.error, audioop.lin2alaw, data, size)
self.assertRaises(audioop.error, audioop.lin2adpcm, data, size, state)
+ def test_string(self):
+ data = 'abcd'
+ size = 2
+ self.assertRaises(TypeError, audioop.getsample, data, size, 0)
+ self.assertRaises(TypeError, audioop.max, data, size)
+ self.assertRaises(TypeError, audioop.minmax, data, size)
+ self.assertRaises(TypeError, audioop.avg, data, size)
+ self.assertRaises(TypeError, audioop.rms, data, size)
+ self.assertRaises(TypeError, audioop.avgpp, data, size)
+ self.assertRaises(TypeError, audioop.maxpp, data, size)
+ self.assertRaises(TypeError, audioop.cross, data, size)
+ self.assertRaises(TypeError, audioop.mul, data, size, 1.0)
+ self.assertRaises(TypeError, audioop.tomono, data, size, 0.5, 0.5)
+ self.assertRaises(TypeError, audioop.tostereo, data, size, 0.5, 0.5)
+ self.assertRaises(TypeError, audioop.add, data, data, size)
+ self.assertRaises(TypeError, audioop.bias, data, size, 0)
+ self.assertRaises(TypeError, audioop.reverse, data, size)
+ self.assertRaises(TypeError, audioop.lin2lin, data, size, size)
+ self.assertRaises(TypeError, audioop.ratecv, data, size, 1, 1, 1, None)
+ self.assertRaises(TypeError, audioop.lin2ulaw, data, size)
+ self.assertRaises(TypeError, audioop.lin2alaw, data, size)
+ self.assertRaises(TypeError, audioop.lin2adpcm, data, size, None)
+
def test_wrongsize(self):
data = b'abcdefgh'
state = None
- for size in (-1, 0, 3, 5, 1024):
+ for size in (-1, 0, 5, 1024):
self.assertRaises(audioop.error, audioop.ulaw2lin, data, size)
self.assertRaises(audioop.error, audioop.alaw2lin, data, size)
self.assertRaises(audioop.error, audioop.adpcm2lin, data, size, state)
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
index 13695de67e..54f392d4d6 100644
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -5,10 +5,21 @@ import binascii
import os
import sys
import subprocess
-
+import struct
+from array import array
class LegacyBase64TestCase(unittest.TestCase):
+
+ # Legacy API is not as permissive as the modern API
+ def check_type_errors(self, f):
+ self.assertRaises(TypeError, f, "")
+ self.assertRaises(TypeError, f, [])
+ multidimensional = memoryview(b"1234").cast('B', (2, 2))
+ self.assertRaises(TypeError, f, multidimensional)
+ int_data = memoryview(b"1234").cast('I')
+ self.assertRaises(TypeError, f, int_data)
+
def test_encodebytes(self):
eq = self.assertEqual
eq(base64.encodebytes(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=\n")
@@ -24,7 +35,9 @@ class LegacyBase64TestCase(unittest.TestCase):
b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n")
# Non-bytes
eq(base64.encodebytes(bytearray(b'abc')), b'YWJj\n')
- self.assertRaises(TypeError, base64.encodebytes, "")
+ eq(base64.encodebytes(memoryview(b'abc')), b'YWJj\n')
+ eq(base64.encodebytes(array('B', b'abc')), b'YWJj\n')
+ self.check_type_errors(base64.encodebytes)
def test_decodebytes(self):
eq = self.assertEqual
@@ -41,7 +54,9 @@ class LegacyBase64TestCase(unittest.TestCase):
eq(base64.decodebytes(b''), b'')
# Non-bytes
eq(base64.decodebytes(bytearray(b'YWJj\n')), b'abc')
- self.assertRaises(TypeError, base64.decodebytes, "")
+ eq(base64.decodebytes(memoryview(b'YWJj\n')), b'abc')
+ eq(base64.decodebytes(array('B', b'YWJj\n')), b'abc')
+ self.check_type_errors(base64.decodebytes)
def test_encode(self):
eq = self.assertEqual
@@ -73,6 +88,38 @@ class LegacyBase64TestCase(unittest.TestCase):
class BaseXYTestCase(unittest.TestCase):
+
+ # Modern API completely ignores exported dimension and format data and
+ # treats any buffer as a stream of bytes
+ def check_encode_type_errors(self, f):
+ self.assertRaises(TypeError, f, "")
+ self.assertRaises(TypeError, f, [])
+
+ def check_decode_type_errors(self, f):
+ self.assertRaises(TypeError, f, [])
+
+ def check_other_types(self, f, bytes_data, expected):
+ eq = self.assertEqual
+ eq(f(bytearray(bytes_data)), expected)
+ eq(f(memoryview(bytes_data)), expected)
+ eq(f(array('B', bytes_data)), expected)
+ self.check_nonbyte_element_format(base64.b64encode, bytes_data)
+ self.check_multidimensional(base64.b64encode, bytes_data)
+
+ def check_multidimensional(self, f, data):
+ padding = b"\x00" if len(data) % 2 else b""
+ bytes_data = data + padding # Make sure cast works
+ shape = (len(bytes_data) // 2, 2)
+ multidimensional = memoryview(bytes_data).cast('B', shape)
+ self.assertEqual(f(multidimensional), f(bytes_data))
+
+ def check_nonbyte_element_format(self, f, data):
+ padding = b"\x00" * ((4 - len(data)) % 4)
+ bytes_data = data + padding # Make sure cast works
+ int_data = memoryview(bytes_data).cast('I')
+ self.assertEqual(f(int_data), f(bytes_data))
+
+
def test_b64encode(self):
eq = self.assertEqual
# Test default alphabet
@@ -90,13 +137,16 @@ class BaseXYTestCase(unittest.TestCase):
b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==")
# Test with arbitrary alternative characters
eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=b'*$'), b'01a*b$cd')
- # Non-bytes
- eq(base64.b64encode(bytearray(b'abcd')), b'YWJjZA==')
eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=bytearray(b'*$')),
b'01a*b$cd')
- # Check if passing a str object raises an error
- self.assertRaises(TypeError, base64.b64encode, "")
- self.assertRaises(TypeError, base64.b64encode, b"", altchars="")
+ eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=memoryview(b'*$')),
+ b'01a*b$cd')
+ eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=array('B', b'*$')),
+ b'01a*b$cd')
+ # Non-bytes
+ self.check_other_types(base64.b64encode, b'abcd', b'YWJjZA==')
+ self.check_encode_type_errors(base64.b64encode)
+ self.assertRaises(TypeError, base64.b64encode, b"", altchars="*$")
# Test standard alphabet
eq(base64.standard_b64encode(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=")
eq(base64.standard_b64encode(b"a"), b"YQ==")
@@ -110,15 +160,15 @@ class BaseXYTestCase(unittest.TestCase):
b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==")
# Non-bytes
- eq(base64.standard_b64encode(bytearray(b'abcd')), b'YWJjZA==')
- # Check if passing a str object raises an error
- self.assertRaises(TypeError, base64.standard_b64encode, "")
+ self.check_other_types(base64.standard_b64encode,
+ b'abcd', b'YWJjZA==')
+ self.check_encode_type_errors(base64.standard_b64encode)
# Test with 'URL safe' alternative characters
eq(base64.urlsafe_b64encode(b'\xd3V\xbeo\xf7\x1d'), b'01a-b_cd')
# Non-bytes
- eq(base64.urlsafe_b64encode(bytearray(b'\xd3V\xbeo\xf7\x1d')), b'01a-b_cd')
- # Check if passing a str object raises an error
- self.assertRaises(TypeError, base64.urlsafe_b64encode, "")
+ self.check_other_types(base64.urlsafe_b64encode,
+ b'\xd3V\xbeo\xf7\x1d', b'01a-b_cd')
+ self.check_encode_type_errors(base64.urlsafe_b64encode)
def test_b64decode(self):
eq = self.assertEqual
@@ -141,7 +191,8 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.b64decode(data), res)
eq(base64.b64decode(data.decode('ascii')), res)
# Non-bytes
- eq(base64.b64decode(bytearray(b"YWJj")), b"abc")
+ self.check_other_types(base64.b64decode, b"YWJj", b"abc")
+ self.check_decode_type_errors(base64.b64decode)
# Test with arbitrary alternative characters
tests_altchars = {(b'01a*b$cd', b'*$'): b'\xd3V\xbeo\xf7\x1d',
@@ -160,7 +211,8 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.standard_b64decode(data), res)
eq(base64.standard_b64decode(data.decode('ascii')), res)
# Non-bytes
- eq(base64.standard_b64decode(bytearray(b"YWJj")), b"abc")
+ self.check_other_types(base64.standard_b64decode, b"YWJj", b"abc")
+ self.check_decode_type_errors(base64.standard_b64decode)
# Test with 'URL safe' alternative characters
tests_urlsafe = {b'01a-b_cd': b'\xd3V\xbeo\xf7\x1d',
@@ -170,7 +222,9 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.urlsafe_b64decode(data), res)
eq(base64.urlsafe_b64decode(data.decode('ascii')), res)
# Non-bytes
- eq(base64.urlsafe_b64decode(bytearray(b'01a-b_cd')), b'\xd3V\xbeo\xf7\x1d')
+ self.check_other_types(base64.urlsafe_b64decode, b'01a-b_cd',
+ b'\xd3V\xbeo\xf7\x1d')
+ self.check_decode_type_errors(base64.urlsafe_b64decode)
def test_b64decode_padding_error(self):
self.assertRaises(binascii.Error, base64.b64decode, b'abc')
@@ -205,8 +259,8 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.b32encode(b'abcd'), b'MFRGGZA=')
eq(base64.b32encode(b'abcde'), b'MFRGGZDF')
# Non-bytes
- eq(base64.b32encode(bytearray(b'abcd')), b'MFRGGZA=')
- self.assertRaises(TypeError, base64.b32encode, "")
+ self.check_other_types(base64.b32encode, b'abcd', b'MFRGGZA=')
+ self.check_encode_type_errors(base64.b32encode)
def test_b32decode(self):
eq = self.assertEqual
@@ -222,7 +276,8 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.b32decode(data), res)
eq(base64.b32decode(data.decode('ascii')), res)
# Non-bytes
- eq(base64.b32decode(bytearray(b'MFRGG===')), b'abc')
+ self.check_other_types(base64.b32decode, b'MFRGG===', b"abc")
+ self.check_decode_type_errors(base64.b32decode)
def test_b32decode_casefold(self):
eq = self.assertEqual
@@ -277,8 +332,9 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.b16encode(b'\x01\x02\xab\xcd\xef'), b'0102ABCDEF')
eq(base64.b16encode(b'\x00'), b'00')
# Non-bytes
- eq(base64.b16encode(bytearray(b'\x01\x02\xab\xcd\xef')), b'0102ABCDEF')
- self.assertRaises(TypeError, base64.b16encode, "")
+ self.check_other_types(base64.b16encode, b'\x01\x02\xab\xcd\xef',
+ b'0102ABCDEF')
+ self.check_encode_type_errors(base64.b16encode)
def test_b16decode(self):
eq = self.assertEqual
@@ -293,7 +349,15 @@ class BaseXYTestCase(unittest.TestCase):
eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef')
eq(base64.b16decode('0102abcdef', True), b'\x01\x02\xab\xcd\xef')
# Non-bytes
- eq(base64.b16decode(bytearray(b"0102ABCDEF")), b'\x01\x02\xab\xcd\xef')
+ self.check_other_types(base64.b16decode, b"0102ABCDEF",
+ b'\x01\x02\xab\xcd\xef')
+ self.check_decode_type_errors(base64.b16decode)
+ eq(base64.b16decode(bytearray(b"0102abcdef"), True),
+ b'\x01\x02\xab\xcd\xef')
+ eq(base64.b16decode(memoryview(b"0102abcdef"), True),
+ b'\x01\x02\xab\xcd\xef')
+ eq(base64.b16decode(array('B', b"0102abcdef"), True),
+ b'\x01\x02\xab\xcd\xef')
def test_decode_nonascii_str(self):
decode_funcs = (base64.b64decode,
diff --git a/Lib/test/test_bisect.py b/Lib/test/test_bisect.py
index 7b9bd19fcb..580a963f62 100644
--- a/Lib/test/test_bisect.py
+++ b/Lib/test/test_bisect.py
@@ -7,7 +7,7 @@ py_bisect = support.import_fresh_module('bisect', blocked=['_bisect'])
c_bisect = support.import_fresh_module('bisect', fresh=['_bisect'])
class Range(object):
- """A trivial range()-like object without any integer width limitations."""
+ """A trivial range()-like object that has an insert() method."""
def __init__(self, start, stop):
self.start = start
self.stop = stop
@@ -120,10 +120,10 @@ class TestBisect:
def test_negative_lo(self):
# Issue 3301
mod = self.module
- self.assertRaises(ValueError, mod.bisect_left, [1, 2, 3], 5, -1, 3),
- self.assertRaises(ValueError, mod.bisect_right, [1, 2, 3], 5, -1, 3),
- self.assertRaises(ValueError, mod.insort_left, [1, 2, 3], 5, -1, 3),
- self.assertRaises(ValueError, mod.insort_right, [1, 2, 3], 5, -1, 3),
+ self.assertRaises(ValueError, mod.bisect_left, [1, 2, 3], 5, -1, 3)
+ self.assertRaises(ValueError, mod.bisect_right, [1, 2, 3], 5, -1, 3)
+ self.assertRaises(ValueError, mod.insort_left, [1, 2, 3], 5, -1, 3)
+ self.assertRaises(ValueError, mod.insort_right, [1, 2, 3], 5, -1, 3)
def test_large_range(self):
# Issue 13496
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py
index 04e3f61e89..1667847a9d 100644
--- a/Lib/test/test_buffer.py
+++ b/Lib/test/test_buffer.py
@@ -4290,9 +4290,5 @@ class TestBufferProtocol(unittest.TestCase):
self.assertRaises(BufferError, memoryview, x)
-def test_main():
- support.run_unittest(TestBufferProtocol)
-
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index c342a4329f..2411c9b311 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -464,6 +464,11 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, eval, ())
self.assertRaises(SyntaxError, eval, bom[:2] + b'a')
+ class X:
+ def __getitem__(self, key):
+ raise ValueError
+ self.assertRaises(ValueError, eval, "foo", {}, X())
+
def test_general_eval(self):
# Tests that general mappings can be used for the locals argument
@@ -579,7 +584,10 @@ class BuiltinTest(unittest.TestCase):
raise frozendict_error("frozendict is readonly")
# read-only builtins
- frozen_builtins = frozendict(__builtins__)
+ if isinstance(__builtins__, types.ModuleType):
+ frozen_builtins = frozendict(__builtins__.__dict__)
+ else:
+ frozen_builtins = frozendict(__builtins__)
code = compile("__builtins__['superglobal']=2; print(superglobal)", "test", "exec")
self.assertRaises(frozendict_error,
exec, code, {'__builtins__': frozen_builtins})
@@ -839,8 +847,19 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(max(1, 2.0, 3), 3)
self.assertEqual(max(1.0, 2, 3), 3)
+ self.assertRaises(TypeError, max)
+ self.assertRaises(TypeError, max, 42)
+ self.assertRaises(ValueError, max, ())
+ class BadSeq:
+ def __getitem__(self, index):
+ raise ValueError
+ self.assertRaises(ValueError, max, BadSeq())
+
for stmt in (
"max(key=int)", # no args
+ "max(default=None)",
+ "max(1, 2, default=None)", # require container for default
+ "max(default=None, key=int)",
"max(1, key=int)", # single arg not iterable
"max(1, 2, keystone=int)", # wrong keyword
"max(1, 2, key=int, abc=int)", # two many keywords
@@ -857,6 +876,13 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(max((1,2), key=neg), 1) # two elem iterable
self.assertEqual(max(1, 2, key=neg), 1) # two elems
+ self.assertEqual(max((), default=None), None) # zero elem iterable
+ self.assertEqual(max((1,), default=None), 1) # one elem iterable
+ self.assertEqual(max((1,2), default=None), 2) # two elem iterable
+
+ self.assertEqual(max((), default=1, key=neg), 1)
+ self.assertEqual(max((1, 2), default=3, key=neg), 1)
+
data = [random.randrange(200) for i in range(100)]
keys = dict((elem, random.randrange(50)) for elem in data)
f = keys.__getitem__
@@ -883,6 +909,9 @@ class BuiltinTest(unittest.TestCase):
for stmt in (
"min(key=int)", # no args
+ "min(default=None)",
+ "min(1, 2, default=None)", # require container for default
+ "min(default=None, key=int)",
"min(1, key=int)", # single arg not iterable
"min(1, 2, keystone=int)", # wrong keyword
"min(1, 2, key=int, abc=int)", # two many keywords
@@ -899,6 +928,13 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(min((1,2), key=neg), 2) # two elem iterable
self.assertEqual(min(1, 2, key=neg), 2) # two elems
+ self.assertEqual(min((), default=None), None) # zero elem iterable
+ self.assertEqual(min((1,), default=None), 1) # one elem iterable
+ self.assertEqual(min((1,2), default=None), 1) # two elem iterable
+
+ self.assertEqual(min((), default=1, key=neg), 1)
+ self.assertEqual(min((1, 2), default=1, key=neg), 2)
+
data = [random.randrange(200) for i in range(100)]
keys = dict((elem, random.randrange(50)) for elem in data)
f = keys.__getitem__
@@ -940,29 +976,25 @@ class BuiltinTest(unittest.TestCase):
def write_testfile(self):
# NB the first 4 lines are also used to test input, below
fp = open(TESTFN, 'w')
- try:
+ self.addCleanup(unlink, TESTFN)
+ with fp:
fp.write('1+1\n')
fp.write('The quick brown fox jumps over the lazy dog')
fp.write('.\n')
fp.write('Dear John\n')
fp.write('XXX'*100)
fp.write('YYY'*100)
- finally:
- fp.close()
def test_open(self):
self.write_testfile()
fp = open(TESTFN, 'r')
- try:
+ with fp:
self.assertEqual(fp.readline(4), '1+1\n')
self.assertEqual(fp.readline(), 'The quick brown fox jumps over the lazy dog.\n')
self.assertEqual(fp.readline(4), 'Dear')
self.assertEqual(fp.readline(100), ' John\n')
self.assertEqual(fp.read(300), 'XXX'*100)
self.assertEqual(fp.read(1000), 'YYY'*100)
- finally:
- fp.close()
- unlink(TESTFN)
def test_open_default_encoding(self):
old_environ = dict(os.environ)
@@ -977,15 +1009,17 @@ class BuiltinTest(unittest.TestCase):
self.write_testfile()
current_locale_encoding = locale.getpreferredencoding(False)
fp = open(TESTFN, 'w')
- try:
+ with fp:
self.assertEqual(fp.encoding, current_locale_encoding)
- finally:
- fp.close()
- unlink(TESTFN)
finally:
os.environ.clear()
os.environ.update(old_environ)
+ def test_open_non_inheritable(self):
+ fileobj = open(__file__)
+ with fileobj:
+ self.assertFalse(os.get_inheritable(fileobj.fileno()))
+
def test_ord(self):
self.assertEqual(ord(' '), 32)
self.assertEqual(ord('A'), 65)
@@ -1096,7 +1130,6 @@ class BuiltinTest(unittest.TestCase):
sys.stdin = savestdin
sys.stdout = savestdout
fp.close()
- unlink(TESTFN)
@unittest.skipUnless(pty, "the pty and signal modules must be available")
def check_input_tty(self, prompt, terminal_input, stdio_encoding=None):
@@ -1476,17 +1509,11 @@ class BuiltinTest(unittest.TestCase):
# --------------------------------------------------------------------
# Issue #7994: object.__format__ with a non-empty format string is
# deprecated
- def test_deprecated_format_string(obj, fmt_str, should_raise_warning):
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always", DeprecationWarning)
- format(obj, fmt_str)
- if should_raise_warning:
- self.assertEqual(len(w), 1)
- self.assertIsInstance(w[0].message, DeprecationWarning)
- self.assertIn('object.__format__ with a non-empty format '
- 'string', str(w[0].message))
+ def test_deprecated_format_string(obj, fmt_str, should_raise):
+ if should_raise:
+ self.assertRaises(TypeError, format, obj, fmt_str)
else:
- self.assertEqual(len(w), 0)
+ format(obj, fmt_str)
fmt_strs = ['', 's']
diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
index 3520e837a1..847c7a613f 100644
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -288,8 +288,22 @@ class BaseBytesTest:
self.assertEqual(self.type2test(b"").join(lst), b"abc")
self.assertEqual(self.type2test(b"").join(tuple(lst)), b"abc")
self.assertEqual(self.type2test(b"").join(iter(lst)), b"abc")
- self.assertEqual(self.type2test(b".").join([b"ab", b"cd"]), b"ab.cd")
- # XXX more...
+ dot_join = self.type2test(b".:").join
+ self.assertEqual(dot_join([b"ab", b"cd"]), b"ab.:cd")
+ self.assertEqual(dot_join([memoryview(b"ab"), b"cd"]), b"ab.:cd")
+ self.assertEqual(dot_join([b"ab", memoryview(b"cd")]), b"ab.:cd")
+ self.assertEqual(dot_join([bytearray(b"ab"), b"cd"]), b"ab.:cd")
+ self.assertEqual(dot_join([b"ab", bytearray(b"cd")]), b"ab.:cd")
+ # Stress it with many items
+ seq = [b"abc"] * 1000
+ expected = b"abc" + b".:abc" * 999
+ self.assertEqual(dot_join(seq), expected)
+ # Error handling and cleanup when some item in the middle of the
+ # sequence has the wrong type.
+ with self.assertRaises(TypeError):
+ dot_join([bytearray(b"ab"), "cd", b"ef"])
+ with self.assertRaises(TypeError):
+ dot_join([memoryview(b"ab"), "cd", b"ef"])
def test_count(self):
b = self.type2test(b'mississippi')
@@ -759,7 +773,7 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase):
finally:
try:
os.remove(tfn)
- except os.error:
+ except OSError:
pass
def test_reverse(self):
@@ -895,6 +909,15 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase):
with self.assertRaises(ValueError):
b[3:4] = elem
+ def test_setslice_extend(self):
+ # Exercise the resizing logic (see issue #19087)
+ b = bytearray(range(100))
+ self.assertEqual(list(b), list(range(100)))
+ del b[:10]
+ self.assertEqual(list(b), list(range(10, 100)))
+ b.extend(range(100, 110))
+ self.assertEqual(list(b), list(range(10, 110)))
+
def test_extended_set_del_slice(self):
indices = (0, None, 1, 3, 19, 300, 1<<333, -1, -2, -31, -300)
for start in indices:
@@ -1274,6 +1297,11 @@ class BytearrayPEP3137Test(unittest.TestCase,
self.assertEqual(val, newval)
self.assertTrue(val is not newval,
expr+' returned val on a mutable object')
+ sep = self.marshal(b'')
+ newval = sep.join([val])
+ self.assertEqual(val, newval)
+ self.assertIsNot(val, newval)
+
class FixedStringTest(test.string_tests.BaseTest):
diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py
index 6764e59553..8d93e2d88f 100644
--- a/Lib/test/test_bz2.py
+++ b/Lib/test/test_bz2.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from test import support
-from test.support import TESTFN, bigmemtest, _4G
+from test.support import bigmemtest, _4G
import unittest
from io import BytesIO
@@ -9,6 +9,7 @@ import pickle
import random
import subprocess
import sys
+from test.support import unlink
try:
import threading
@@ -19,10 +20,10 @@ except ImportError:
bz2 = support.import_module('bz2')
from bz2 import BZ2File, BZ2Compressor, BZ2Decompressor
-has_cmdline_bunzip2 = sys.platform not in ("win32", "os2emx")
class BaseTest(unittest.TestCase):
"Base for other testcases."
+
TEXT_LINES = [
b'root:x:0:0:root:/root:/bin/bash\n',
b'bin:x:1:1:bin:/bin:\n',
@@ -51,13 +52,17 @@ class BaseTest(unittest.TestCase):
EMPTY_DATA = b'BZh9\x17rE8P\x90\x00\x00\x00\x00'
def setUp(self):
- self.filename = TESTFN
+ self.filename = support.TESTFN
def tearDown(self):
if os.path.isfile(self.filename):
os.unlink(self.filename)
- if has_cmdline_bunzip2:
+ if sys.platform == "win32":
+ # bunzip2 isn't available to run on Windows.
+ def decompress(self, data):
+ return bz2.decompress(data)
+ else:
def decompress(self, data):
pop = subprocess.Popen("bunzip2", shell=True,
stdin=subprocess.PIPE,
@@ -71,31 +76,21 @@ class BaseTest(unittest.TestCase):
ret = bz2.decompress(data)
return ret
- else:
- # bunzip2 isn't available to run on Windows.
- def decompress(self, data):
- return bz2.decompress(data)
class BZ2FileTest(BaseTest):
- "Test BZ2File type miscellaneous methods."
+ "Test the BZ2File class."
def createTempFile(self, streams=1):
with open(self.filename, "wb") as f:
f.write(self.DATA * streams)
def testBadArgs(self):
- with self.assertRaises(TypeError):
- BZ2File(123.456)
- with self.assertRaises(ValueError):
- BZ2File("/dev/null", "z")
- with self.assertRaises(ValueError):
- BZ2File("/dev/null", "rx")
- with self.assertRaises(ValueError):
- BZ2File("/dev/null", "rbt")
- with self.assertRaises(ValueError):
- BZ2File("/dev/null", compresslevel=0)
- with self.assertRaises(ValueError):
- BZ2File("/dev/null", compresslevel=10)
+ self.assertRaises(TypeError, BZ2File, 123.456)
+ self.assertRaises(ValueError, BZ2File, "/dev/null", "z")
+ self.assertRaises(ValueError, BZ2File, "/dev/null", "rx")
+ self.assertRaises(ValueError, BZ2File, "/dev/null", "rbt")
+ self.assertRaises(ValueError, BZ2File, "/dev/null", compresslevel=0)
+ self.assertRaises(ValueError, BZ2File, "/dev/null", compresslevel=10)
def testRead(self):
self.createTempFile()
@@ -216,9 +211,8 @@ class BZ2FileTest(BaseTest):
self.createTempFile()
bz2f = BZ2File(self.filename)
bz2f.close()
- self.assertRaises(ValueError, bz2f.__next__)
- # This call will deadlock if the above .__next__ call failed to
- # release the lock.
+ self.assertRaises(ValueError, next, bz2f)
+ # This call will deadlock if the above call failed to release the lock.
self.assertRaises(ValueError, bz2f.readlines)
def testWrite(self):
@@ -262,8 +256,8 @@ class BZ2FileTest(BaseTest):
bz2f.write(b"abc")
with BZ2File(self.filename, "r") as bz2f:
- self.assertRaises(IOError, bz2f.write, b"a")
- self.assertRaises(IOError, bz2f.writelines, [b"a"])
+ self.assertRaises(OSError, bz2f.write, b"a")
+ self.assertRaises(OSError, bz2f.writelines, [b"a"])
def testAppend(self):
with BZ2File(self.filename, "w") as bz2f:
@@ -381,7 +375,7 @@ class BZ2FileTest(BaseTest):
bz2f.close()
self.assertRaises(ValueError, bz2f.seekable)
- bz2f = BZ2File(BytesIO(), mode="w")
+ bz2f = BZ2File(BytesIO(), "w")
try:
self.assertFalse(bz2f.seekable())
finally:
@@ -407,7 +401,7 @@ class BZ2FileTest(BaseTest):
bz2f.close()
self.assertRaises(ValueError, bz2f.readable)
- bz2f = BZ2File(BytesIO(), mode="w")
+ bz2f = BZ2File(BytesIO(), "w")
try:
self.assertFalse(bz2f.readable())
finally:
@@ -424,7 +418,7 @@ class BZ2FileTest(BaseTest):
bz2f.close()
self.assertRaises(ValueError, bz2f.writable)
- bz2f = BZ2File(BytesIO(), mode="w")
+ bz2f = BZ2File(BytesIO(), "w")
try:
self.assertTrue(bz2f.writable())
finally:
@@ -438,7 +432,7 @@ class BZ2FileTest(BaseTest):
del o
def testOpenNonexistent(self):
- self.assertRaises(IOError, BZ2File, "/non/existent")
+ self.assertRaises(OSError, BZ2File, "/non/existent")
def testReadlinesNoNewline(self):
# Issue #1191043: readlines() fails on a file containing no newline.
@@ -478,7 +472,7 @@ class BZ2FileTest(BaseTest):
# Issue #7205: Using a BZ2File from several threads shouldn't deadlock.
data = b"1" * 2**20
nthreads = 10
- with bz2.BZ2File(self.filename, 'wb') as f:
+ with BZ2File(self.filename, 'wb') as f:
def comp():
for i in range(5):
f.write(data)
@@ -489,28 +483,27 @@ class BZ2FileTest(BaseTest):
t.join()
def testWithoutThreading(self):
- bz2 = support.import_fresh_module("bz2", blocked=("threading",))
- with bz2.BZ2File(self.filename, "wb") as f:
+ module = support.import_fresh_module("bz2", blocked=("threading",))
+ with module.BZ2File(self.filename, "wb") as f:
f.write(b"abc")
- with bz2.BZ2File(self.filename, "rb") as f:
+ with module.BZ2File(self.filename, "rb") as f:
self.assertEqual(f.read(), b"abc")
def testMixedIterationAndReads(self):
self.createTempFile()
linelen = len(self.TEXT_LINES[0])
halflen = linelen // 2
- with bz2.BZ2File(self.filename) as bz2f:
+ with BZ2File(self.filename) as bz2f:
bz2f.read(halflen)
self.assertEqual(next(bz2f), self.TEXT_LINES[0][halflen:])
self.assertEqual(bz2f.read(), self.TEXT[linelen:])
- with bz2.BZ2File(self.filename) as bz2f:
+ with BZ2File(self.filename) as bz2f:
bz2f.readline()
self.assertEqual(next(bz2f), self.TEXT_LINES[1])
self.assertEqual(bz2f.readline(), self.TEXT_LINES[2])
- with bz2.BZ2File(self.filename) as bz2f:
+ with BZ2File(self.filename) as bz2f:
bz2f.readlines()
- with self.assertRaises(StopIteration):
- next(bz2f)
+ self.assertRaises(StopIteration, next, bz2f)
self.assertEqual(bz2f.readlines(), [])
def testMultiStreamOrdering(self):
@@ -578,6 +571,20 @@ class BZ2FileTest(BaseTest):
bz2f.seek(-150, 1)
self.assertEqual(bz2f.read(), self.TEXT[500-150:])
+ def test_read_truncated(self):
+ # Drop the eos_magic field (6 bytes) and CRC (4 bytes).
+ truncated = self.DATA[:-10]
+ with BZ2File(BytesIO(truncated)) as f:
+ self.assertRaises(EOFError, f.read)
+ with BZ2File(BytesIO(truncated)) as f:
+ self.assertEqual(f.read(len(self.TEXT)), self.TEXT)
+ self.assertRaises(EOFError, f.read, 1)
+ # Incomplete 4-byte file header, and block header of at least 146 bits.
+ for i in range(22):
+ with BZ2File(BytesIO(truncated[:i])) as f:
+ self.assertRaises(EOFError, f.read, 1)
+
+
class BZ2CompressorTest(BaseTest):
def testCompress(self):
bz2c = BZ2Compressor()
@@ -713,97 +720,122 @@ class CompressDecompressTest(BaseTest):
class OpenTest(BaseTest):
+ "Test the open function."
+
+ def open(self, *args, **kwargs):
+ return bz2.open(*args, **kwargs)
+
def test_binary_modes(self):
- with bz2.open(self.filename, "wb") as f:
- f.write(self.TEXT)
- with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read())
- self.assertEqual(file_data, self.TEXT)
- with bz2.open(self.filename, "rb") as f:
- self.assertEqual(f.read(), self.TEXT)
- with bz2.open(self.filename, "ab") as f:
- f.write(self.TEXT)
- with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read())
- self.assertEqual(file_data, self.TEXT * 2)
+ for mode in ("wb", "xb"):
+ if mode == "xb":
+ unlink(self.filename)
+ with self.open(self.filename, mode) as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = self.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT)
+ with self.open(self.filename, "rb") as f:
+ self.assertEqual(f.read(), self.TEXT)
+ with self.open(self.filename, "ab") as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = self.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT * 2)
def test_implicit_binary_modes(self):
# Test implicit binary modes (no "b" or "t" in mode string).
- with bz2.open(self.filename, "w") as f:
- f.write(self.TEXT)
- with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read())
- self.assertEqual(file_data, self.TEXT)
- with bz2.open(self.filename, "r") as f:
- self.assertEqual(f.read(), self.TEXT)
- with bz2.open(self.filename, "a") as f:
- f.write(self.TEXT)
- with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read())
- self.assertEqual(file_data, self.TEXT * 2)
+ for mode in ("w", "x"):
+ if mode == "x":
+ unlink(self.filename)
+ with self.open(self.filename, mode) as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = self.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT)
+ with self.open(self.filename, "r") as f:
+ self.assertEqual(f.read(), self.TEXT)
+ with self.open(self.filename, "a") as f:
+ f.write(self.TEXT)
+ with open(self.filename, "rb") as f:
+ file_data = self.decompress(f.read())
+ self.assertEqual(file_data, self.TEXT * 2)
def test_text_modes(self):
text = self.TEXT.decode("ascii")
text_native_eol = text.replace("\n", os.linesep)
- with bz2.open(self.filename, "wt") as f:
- f.write(text)
- with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read()).decode("ascii")
- self.assertEqual(file_data, text_native_eol)
- with bz2.open(self.filename, "rt") as f:
- self.assertEqual(f.read(), text)
- with bz2.open(self.filename, "at") as f:
- f.write(text)
- with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read()).decode("ascii")
- self.assertEqual(file_data, text_native_eol * 2)
+ for mode in ("wt", "xt"):
+ if mode == "xt":
+ unlink(self.filename)
+ with self.open(self.filename, mode) as f:
+ f.write(text)
+ with open(self.filename, "rb") as f:
+ file_data = self.decompress(f.read()).decode("ascii")
+ self.assertEqual(file_data, text_native_eol)
+ with self.open(self.filename, "rt") as f:
+ self.assertEqual(f.read(), text)
+ with self.open(self.filename, "at") as f:
+ f.write(text)
+ with open(self.filename, "rb") as f:
+ file_data = self.decompress(f.read()).decode("ascii")
+ self.assertEqual(file_data, text_native_eol * 2)
+
+ def test_x_mode(self):
+ for mode in ("x", "xb", "xt"):
+ unlink(self.filename)
+ with self.open(self.filename, mode) as f:
+ pass
+ with self.assertRaises(FileExistsError):
+ with self.open(self.filename, mode) as f:
+ pass
def test_fileobj(self):
- with bz2.open(BytesIO(self.DATA), "r") as f:
+ with self.open(BytesIO(self.DATA), "r") as f:
self.assertEqual(f.read(), self.TEXT)
- with bz2.open(BytesIO(self.DATA), "rb") as f:
+ with self.open(BytesIO(self.DATA), "rb") as f:
self.assertEqual(f.read(), self.TEXT)
text = self.TEXT.decode("ascii")
- with bz2.open(BytesIO(self.DATA), "rt") as f:
+ with self.open(BytesIO(self.DATA), "rt") as f:
self.assertEqual(f.read(), text)
def test_bad_params(self):
# Test invalid parameter combinations.
- with self.assertRaises(ValueError):
- bz2.open(self.filename, "wbt")
- with self.assertRaises(ValueError):
- bz2.open(self.filename, "rb", encoding="utf-8")
- with self.assertRaises(ValueError):
- bz2.open(self.filename, "rb", errors="ignore")
- with self.assertRaises(ValueError):
- bz2.open(self.filename, "rb", newline="\n")
+ self.assertRaises(ValueError,
+ self.open, self.filename, "wbt")
+ self.assertRaises(ValueError,
+ self.open, self.filename, "xbt")
+ self.assertRaises(ValueError,
+ self.open, self.filename, "rb", encoding="utf-8")
+ self.assertRaises(ValueError,
+ self.open, self.filename, "rb", errors="ignore")
+ self.assertRaises(ValueError,
+ self.open, self.filename, "rb", newline="\n")
def test_encoding(self):
# Test non-default encoding.
text = self.TEXT.decode("ascii")
text_native_eol = text.replace("\n", os.linesep)
- with bz2.open(self.filename, "wt", encoding="utf-16-le") as f:
+ with self.open(self.filename, "wt", encoding="utf-16-le") as f:
f.write(text)
with open(self.filename, "rb") as f:
- file_data = bz2.decompress(f.read()).decode("utf-16-le")
+ file_data = self.decompress(f.read()).decode("utf-16-le")
self.assertEqual(file_data, text_native_eol)
- with bz2.open(self.filename, "rt", encoding="utf-16-le") as f:
+ with self.open(self.filename, "rt", encoding="utf-16-le") as f:
self.assertEqual(f.read(), text)
def test_encoding_error_handler(self):
# Test with non-default encoding error handler.
- with bz2.open(self.filename, "wb") as f:
+ with self.open(self.filename, "wb") as f:
f.write(b"foo\xffbar")
- with bz2.open(self.filename, "rt", encoding="ascii", errors="ignore") \
+ with self.open(self.filename, "rt", encoding="ascii", errors="ignore") \
as f:
self.assertEqual(f.read(), "foobar")
def test_newline(self):
# Test with explicit newline (universal newline mode disabled).
text = self.TEXT.decode("ascii")
- with bz2.open(self.filename, "wt", newline="\n") as f:
+ with self.open(self.filename, "wt", newline="\n") as f:
f.write(text)
- with bz2.open(self.filename, "rt", newline="\r") as f:
+ with self.open(self.filename, "rt", newline="\r") as f:
self.assertEqual(f.readlines(), [text])
diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py
index 9013a7b188..d37c0578a3 100644
--- a/Lib/test/test_capi.py
+++ b/Lib/test/test_capi.py
@@ -43,7 +43,7 @@ class CAPITest(unittest.TestCase):
@unittest.skipUnless(threading, 'Threading required for this test.')
def test_no_FatalError_infinite_loop(self):
- with support.suppress_crash_popup():
+ with support.SuppressCrashReport():
p = subprocess.Popen([sys.executable, "-c",
'import _testcapi;'
'_testcapi.crash_no_current_thread()'],
@@ -192,6 +192,9 @@ class TestPendingCalls(unittest.TestCase):
self.pendingcalls_submit(l, n)
self.pendingcalls_wait(l, n)
+
+class SubinterpreterTest(unittest.TestCase):
+
def test_subinterps(self):
import builtins
r, w = os.pipe()
@@ -207,42 +210,104 @@ class TestPendingCalls(unittest.TestCase):
self.assertNotEqual(pickle.load(f), id(sys.modules))
self.assertNotEqual(pickle.load(f), id(builtins))
+
# Bug #6012
class Test6012(unittest.TestCase):
def test(self):
self.assertEqual(_testcapi.argparsing("Hello", "World"), 1)
-class EmbeddingTest(unittest.TestCase):
-
- @unittest.skipIf(
- sys.platform.startswith('win'),
- "test doesn't work under Windows")
- def test_subinterps(self):
- # XXX only tested under Unix checkouts
+class EmbeddingTests(unittest.TestCase):
+ def setUp(self):
basepath = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
- oldcwd = os.getcwd()
+ exename = "_testembed"
+ if sys.platform.startswith("win"):
+ ext = ("_d" if "_d" in sys.executable else "") + ".exe"
+ exename += ext
+ exepath = os.path.dirname(sys.executable)
+ else:
+ exepath = os.path.join(basepath, "Modules")
+ self.test_exe = exe = os.path.join(exepath, exename)
+ if not os.path.exists(exe):
+ self.skipTest("%r doesn't exist" % exe)
# This is needed otherwise we get a fatal error:
# "Py_Initialize: Unable to get the locale encoding
# LookupError: no codec search functions registered: can't find encoding"
+ self.oldcwd = os.getcwd()
os.chdir(basepath)
+
+ def tearDown(self):
+ os.chdir(self.oldcwd)
+
+ def run_embedded_interpreter(self, *args):
+ """Runs a test in the embedded interpreter"""
+ cmd = [self.test_exe]
+ cmd.extend(args)
+ p = subprocess.Popen(cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ (out, err) = p.communicate()
+ self.assertEqual(p.returncode, 0,
+ "bad returncode %d, stderr is %r" %
+ (p.returncode, err))
+ return out.decode("latin1"), err.decode("latin1")
+
+ def test_subinterps(self):
+ # This is just a "don't crash" test
+ out, err = self.run_embedded_interpreter()
+ if support.verbose:
+ print()
+ print(out)
+ print(err)
+
+ @staticmethod
+ def _get_default_pipe_encoding():
+ rp, wp = os.pipe()
try:
- exe = os.path.join(basepath, "Modules", "_testembed")
- if not os.path.exists(exe):
- self.skipTest("%r doesn't exist" % exe)
- p = subprocess.Popen([exe],
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
- (out, err) = p.communicate()
- self.assertEqual(p.returncode, 0,
- "bad returncode %d, stderr is %r" %
- (p.returncode, err))
- if support.verbose:
- print()
- print(out.decode('latin1'))
- print(err.decode('latin1'))
+ with os.fdopen(wp, 'w') as w:
+ default_pipe_encoding = w.encoding
finally:
- os.chdir(oldcwd)
+ os.close(rp)
+ return default_pipe_encoding
+
+ def test_forced_io_encoding(self):
+ # Checks forced configuration of embedded interpreter IO streams
+ out, err = self.run_embedded_interpreter("forced_io_encoding")
+ if support.verbose:
+ print()
+ print(out)
+ print(err)
+ expected_stdin_encoding = sys.__stdin__.encoding
+ expected_pipe_encoding = self._get_default_pipe_encoding()
+ expected_output = os.linesep.join([
+ "--- Use defaults ---",
+ "Expected encoding: default",
+ "Expected errors: default",
+ "stdin: {0}:strict",
+ "stdout: {1}:strict",
+ "stderr: {1}:backslashreplace",
+ "--- Set errors only ---",
+ "Expected encoding: default",
+ "Expected errors: surrogateescape",
+ "stdin: {0}:surrogateescape",
+ "stdout: {1}:surrogateescape",
+ "stderr: {1}:backslashreplace",
+ "--- Set encoding only ---",
+ "Expected encoding: latin-1",
+ "Expected errors: default",
+ "stdin: latin-1:strict",
+ "stdout: latin-1:strict",
+ "stderr: latin-1:backslashreplace",
+ "--- Set encoding and errors ---",
+ "Expected encoding: latin-1",
+ "Expected errors: surrogateescape",
+ "stdin: latin-1:surrogateescape",
+ "stdout: latin-1:surrogateescape",
+ "stderr: latin-1:backslashreplace"]).format(expected_stdin_encoding,
+ expected_pipe_encoding)
+ # This is useful if we ever trip over odd platform behaviour
+ self.maxDiff = None
+ self.assertEqual(out.strip(), expected_output)
class SkipitemTest(unittest.TestCase):
@@ -354,8 +419,9 @@ class Test_testcapi(unittest.TestCase):
def test__testcapi(self):
for name in dir(_testcapi):
if name.startswith('test_'):
- test = getattr(_testcapi, name)
- test()
+ with self.subTest("internal", name=name):
+ test = getattr(_testcapi, name)
+ test()
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index a89d7e4f10..327c1455fc 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -4,6 +4,7 @@
import test.support, unittest
import os
+import shutil
import sys
import subprocess
import tempfile
@@ -41,8 +42,10 @@ class CmdLineTest(unittest.TestCase):
def test_version(self):
version = ('Python %d.%d' % sys.version_info[:2]).encode("ascii")
- rc, out, err = assert_python_ok('-V')
- self.assertTrue(err.startswith(version))
+ for switch in '-V', '--version':
+ rc, out, err = assert_python_ok(switch)
+ self.assertFalse(err.startswith(version))
+ self.assertTrue(out.startswith(version))
def test_verbose(self):
# -v causes imports to write to stderr. If the write to
@@ -54,14 +57,49 @@ class CmdLineTest(unittest.TestCase):
self.assertNotIn(b'stack overflow', err)
def test_xoptions(self):
- rc, out, err = assert_python_ok('-c', 'import sys; print(sys._xoptions)')
- opts = eval(out.splitlines()[0])
+ def get_xoptions(*args):
+ # use subprocess module directly because test.script_helper adds
+ # "-X faulthandler" to the command line
+ args = (sys.executable, '-E') + args
+ args += ('-c', 'import sys; print(sys._xoptions)')
+ out = subprocess.check_output(args)
+ opts = eval(out.splitlines()[0])
+ return opts
+
+ opts = get_xoptions()
self.assertEqual(opts, {})
- rc, out, err = assert_python_ok(
- '-Xa', '-Xb=c,d=e', '-c', 'import sys; print(sys._xoptions)')
- opts = eval(out.splitlines()[0])
+
+ opts = get_xoptions('-Xa', '-Xb=c,d=e')
self.assertEqual(opts, {'a': True, 'b': 'c,d=e'})
+ def test_showrefcount(self):
+ def run_python(*args):
+ # this is similar to assert_python_ok but doesn't strip
+ # the refcount from stderr. It can be replaced once
+ # assert_python_ok stops doing that.
+ cmd = [sys.executable]
+ cmd.extend(args)
+ PIPE = subprocess.PIPE
+ p = subprocess.Popen(cmd, stdout=PIPE, stderr=PIPE)
+ out, err = p.communicate()
+ p.stdout.close()
+ p.stderr.close()
+ rc = p.returncode
+ self.assertEqual(rc, 0)
+ return rc, out, err
+ code = 'import sys; print(sys._xoptions)'
+ # normally the refcount is hidden
+ rc, out, err = run_python('-c', code)
+ self.assertEqual(out.rstrip(), b'{}')
+ self.assertEqual(err, b'')
+ # "-X showrefcount" shows the refcount, but only in debug builds
+ rc, out, err = run_python('-X', 'showrefcount', '-c', code)
+ self.assertEqual(out.rstrip(), b"{'showrefcount': True}")
+ if hasattr(sys, 'gettotalrefcount'): # debug build
+ self.assertRegex(err, br'^\[\d+ refs, \d+ blocks\]')
+ else:
+ self.assertEqual(err, b'')
+
def test_run_module(self):
# Test expected operation of the '-m' switch
# Switch needs an argument
@@ -70,9 +108,9 @@ class CmdLineTest(unittest.TestCase):
assert_python_failure('-m', 'fnord43520xyz')
# Check the runpy module also gives an error for
# a nonexistent module
- assert_python_failure('-m', 'runpy', 'fnord43520xyz'),
+ assert_python_failure('-m', 'runpy', 'fnord43520xyz')
# All good if module is located and run successfully
- assert_python_ok('-m', 'timeit', '-n', '1'),
+ assert_python_ok('-m', 'timeit', '-n', '1')
def test_run_module_bug1764407(self):
# -m and -i need to play well together
@@ -213,6 +251,23 @@ class CmdLineTest(unittest.TestCase):
self.assertIn(path1.encode('ascii'), out)
self.assertIn(path2.encode('ascii'), out)
+ def test_empty_PYTHONPATH_issue16309(self):
+ # On Posix, it is documented that setting PATH to the
+ # empty string is equivalent to not setting PATH at all,
+ # which is an exception to the rule that in a string like
+ # "/bin::/usr/bin" the empty string in the middle gets
+ # interpreted as '.'
+ code = """if 1:
+ import sys
+ path = ":".join(sys.path)
+ path = path.encode("ascii", "backslashreplace")
+ sys.stdout.buffer.write(path)"""
+ rc1, out1, err1 = assert_python_ok('-c', code, PYTHONPATH="")
+ rc2, out2, err2 = assert_python_ok('-c', code, __isolated=False)
+ # regarding to Posix specification, outputs should be equal
+ # for empty and unset PYTHONPATH
+ self.assertEqual(out1, out2)
+
def test_displayhook_unencodable(self):
for encoding in ('ascii', 'latin-1', 'utf-8'):
env = os.environ.copy()
@@ -290,7 +345,7 @@ class CmdLineTest(unittest.TestCase):
rc, out, err = assert_python_ok('-c', code)
self.assertEqual(b'', out)
self.assertRegex(err.decode('ascii', 'ignore'),
- 'Exception OSError: .* ignored')
+ 'Exception ignored in.*\nOSError: .*')
def test_closed_stdout(self):
# Issue #13444: if stdout has been explicitly closed, we should
@@ -385,6 +440,31 @@ class CmdLineTest(unittest.TestCase):
self.assertEqual(b'', out)
+ def test_isolatedmode(self):
+ self.verify_valid_flag('-I')
+ self.verify_valid_flag('-IEs')
+ rc, out, err = assert_python_ok('-I', '-c',
+ 'from sys import flags as f; '
+ 'print(f.no_user_site, f.ignore_environment, f.isolated)',
+ # dummyvar to prevent extranous -E
+ dummyvar="")
+ self.assertEqual(out.strip(), b'1 1 1')
+ with test.support.temp_cwd() as tmpdir:
+ fake = os.path.join(tmpdir, "uuid.py")
+ main = os.path.join(tmpdir, "main.py")
+ with open(fake, "w") as f:
+ f.write("raise RuntimeError('isolated mode test')\n")
+ with open(main, "w") as f:
+ f.write("import uuid\n")
+ f.write("print('ok')\n")
+ self.assertRaises(subprocess.CalledProcessError,
+ subprocess.check_output,
+ [sys.executable, main], cwd=tmpdir,
+ stderr=subprocess.DEVNULL)
+ out = subprocess.check_output([sys.executable, "-I", main],
+ cwd=tmpdir)
+ self.assertEqual(out.strip(), b"ok")
+
def test_main():
test.support.run_unittest(CmdLineTest)
test.support.reap_children()
diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py
index b4269b994e..f804d8645f 100644
--- a/Lib/test/test_cmd_line_script.py
+++ b/Lib/test/test_cmd_line_script.py
@@ -123,7 +123,7 @@ class CmdLineTest(unittest.TestCase):
if not __debug__:
cmd_line_switches += ('-' + 'O' * sys.flags.optimize,)
run_args = cmd_line_switches + (script_name,) + tuple(example_args)
- rc, out, err = assert_python_ok(*run_args)
+ rc, out, err = assert_python_ok(*run_args, __isolated=False)
self._check_output(script_name, rc, out + err, expected_file,
expected_argv0, expected_path0,
expected_package, expected_loader)
@@ -294,7 +294,7 @@ class CmdLineTest(unittest.TestCase):
pkg_dir = os.path.join(script_dir, 'test_pkg')
make_pkg(pkg_dir, "import sys; print('init_argv0==%r' % sys.argv[0])")
script_name = _make_test_script(pkg_dir, 'script')
- rc, out, err = assert_python_ok('-m', 'test_pkg.script', *example_args)
+ rc, out, err = assert_python_ok('-m', 'test_pkg.script', *example_args, __isolated=False)
if verbose > 1:
print(out)
expected = "init_argv0==%r" % '-m'
@@ -311,7 +311,8 @@ class CmdLineTest(unittest.TestCase):
with open("-c", "w") as f:
f.write("data")
rc, out, err = assert_python_ok('-c',
- 'import sys; print("sys.path[0]==%r" % sys.path[0])')
+ 'import sys; print("sys.path[0]==%r" % sys.path[0])',
+ __isolated=False)
if verbose > 1:
print(out)
expected = "sys.path[0]==%r" % ''
@@ -325,7 +326,8 @@ class CmdLineTest(unittest.TestCase):
with support.change_cwd(path=script_dir):
with open("-m", "w") as f:
f.write("data")
- rc, out, err = assert_python_ok('-m', 'other', *example_args)
+ rc, out, err = assert_python_ok('-m', 'other', *example_args,
+ __isolated=False)
self._check_output(script_name, rc, out,
script_name, script_name, '', '',
importlib.machinery.SourceFileLoader)
diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py
index adef1701d4..5fd21dc32c 100644
--- a/Lib/test/test_code_module.py
+++ b/Lib/test/test_code_module.py
@@ -64,6 +64,20 @@ class TestInteractiveConsole(unittest.TestCase):
self.console.interact()
self.assertTrue(hook.called)
+ def test_banner(self):
+ # with banner
+ self.infunc.side_effect = EOFError('Finished')
+ self.console.interact(banner='Foo')
+ self.assertEqual(len(self.stderr.method_calls), 2)
+ banner_call = self.stderr.method_calls[0]
+ self.assertEqual(banner_call, ['write', ('Foo\n',), {}])
+
+ # no banner
+ self.stderr.reset_mock()
+ self.infunc.side_effect = EOFError('Finished')
+ self.console.interact(banner='')
+ self.assertEqual(len(self.stderr.method_calls), 1)
+
def test_main():
support.run_unittest(TestInteractiveConsole)
diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
index fd88505081..84804bb0da 100644
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -875,8 +875,6 @@ class CodecCallbackTest(unittest.TestCase):
with self.assertRaises(TypeError):
data.decode(encoding, "test.replacing")
-def test_main():
- test.support.run_unittest(CodecCallbackTest)
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py
index 35170579a7..55becf4942 100644
--- a/Lib/test/test_codecs.py
+++ b/Lib/test/test_codecs.py
@@ -1,5 +1,6 @@
import _testcapi
import codecs
+import contextlib
import io
import locale
import sys
@@ -840,9 +841,10 @@ class UTF7Test(ReadTest, unittest.TestCase):
(b'a+////,+IKw-b', 'a\uffff\ufffd\u20acb'),
]
for raw, expected in tests:
- self.assertRaises(UnicodeDecodeError, codecs.utf_7_decode,
- raw, 'strict', True)
- self.assertEqual(raw.decode('utf-7', 'replace'), expected)
+ with self.subTest(raw=raw):
+ self.assertRaises(UnicodeDecodeError, codecs.utf_7_decode,
+ raw, 'strict', True)
+ self.assertEqual(raw.decode('utf-7', 'replace'), expected)
def test_nonbmp(self):
self.assertEqual('\U000104A0'.encode(self.encoding), b'+2AHcoA-')
@@ -2291,28 +2293,211 @@ class TransformCodecTest(unittest.TestCase):
def test_basics(self):
binput = bytes(range(256))
for encoding in bytes_transform_encodings:
- # generic codecs interface
- (o, size) = codecs.getencoder(encoding)(binput)
- self.assertEqual(size, len(binput))
- (i, size) = codecs.getdecoder(encoding)(o)
- self.assertEqual(size, len(o))
- self.assertEqual(i, binput)
+ with self.subTest(encoding=encoding):
+ # generic codecs interface
+ (o, size) = codecs.getencoder(encoding)(binput)
+ self.assertEqual(size, len(binput))
+ (i, size) = codecs.getdecoder(encoding)(o)
+ self.assertEqual(size, len(o))
+ self.assertEqual(i, binput)
def test_read(self):
for encoding in bytes_transform_encodings:
- sin = codecs.encode(b"\x80", encoding)
- reader = codecs.getreader(encoding)(io.BytesIO(sin))
- sout = reader.read()
- self.assertEqual(sout, b"\x80")
+ with self.subTest(encoding=encoding):
+ sin = codecs.encode(b"\x80", encoding)
+ reader = codecs.getreader(encoding)(io.BytesIO(sin))
+ sout = reader.read()
+ self.assertEqual(sout, b"\x80")
def test_readline(self):
for encoding in bytes_transform_encodings:
if encoding in ['uu_codec', 'zlib_codec']:
continue
- sin = codecs.encode(b"\x80", encoding)
- reader = codecs.getreader(encoding)(io.BytesIO(sin))
- sout = reader.readline()
- self.assertEqual(sout, b"\x80")
+ with self.subTest(encoding=encoding):
+ sin = codecs.encode(b"\x80", encoding)
+ reader = codecs.getreader(encoding)(io.BytesIO(sin))
+ sout = reader.readline()
+ self.assertEqual(sout, b"\x80")
+
+ def test_buffer_api_usage(self):
+ # We check all the transform codecs accept memoryview input
+ # for encoding and decoding
+ # and also that they roundtrip correctly
+ original = b"12345\x80"
+ for encoding in bytes_transform_encodings:
+ with self.subTest(encoding=encoding):
+ data = original
+ view = memoryview(data)
+ data = codecs.encode(data, encoding)
+ view_encoded = codecs.encode(view, encoding)
+ self.assertEqual(view_encoded, data)
+ view = memoryview(data)
+ data = codecs.decode(data, encoding)
+ self.assertEqual(data, original)
+ view_decoded = codecs.decode(view, encoding)
+ self.assertEqual(view_decoded, data)
+
+ def test_type_error_for_text_input(self):
+ # Check binary -> binary codecs give a good error for str input
+ bad_input = "bad input type"
+ for encoding in bytes_transform_encodings:
+ with self.subTest(encoding=encoding):
+ msg = "^encoding with '{}' codec failed".format(encoding)
+ with self.assertRaisesRegex(TypeError, msg) as failure:
+ bad_input.encode(encoding)
+ self.assertTrue(isinstance(failure.exception.__cause__,
+ TypeError))
+
+ def test_type_error_for_binary_input(self):
+ # Check str -> str codec gives a good error for binary input
+ for bad_input in (b"immutable", bytearray(b"mutable")):
+ with self.subTest(bad_input=bad_input):
+ msg = "^decoding with 'rot_13' codec failed"
+ with self.assertRaisesRegex(AttributeError, msg) as failure:
+ bad_input.decode("rot_13")
+ self.assertTrue(isinstance(failure.exception.__cause__,
+ AttributeError))
+
+ def test_bad_decoding_output_type(self):
+ # Check bytes.decode and bytearray.decode give a good error
+ # message for binary -> binary codecs
+ data = b"encode first to ensure we meet any format restrictions"
+ for encoding in bytes_transform_encodings:
+ with self.subTest(encoding=encoding):
+ encoded_data = codecs.encode(data, encoding)
+ fmt = ("'{}' decoder returned 'bytes' instead of 'str'; "
+ "use codecs.decode\(\) to decode to arbitrary types")
+ msg = fmt.format(encoding)
+ with self.assertRaisesRegex(TypeError, msg):
+ encoded_data.decode(encoding)
+ with self.assertRaisesRegex(TypeError, msg):
+ bytearray(encoded_data).decode(encoding)
+
+ def test_bad_encoding_output_type(self):
+ # Check str.encode gives a good error message for str -> str codecs
+ msg = ("'rot_13' encoder returned 'str' instead of 'bytes'; "
+ "use codecs.encode\(\) to encode to arbitrary types")
+ with self.assertRaisesRegex(TypeError, msg):
+ "just an example message".encode("rot_13")
+
+
+# The codec system tries to wrap exceptions in order to ensure the error
+# mentions the operation being performed and the codec involved. We
+# currently *only* want this to happen for relatively stateless
+# exceptions, where the only significant information they contain is their
+# type and a single str argument.
+
+# Use a local codec registry to avoid appearing to leak objects when
+# registering multiple seach functions
+_TEST_CODECS = {}
+
+def _get_test_codec(codec_name):
+ return _TEST_CODECS.get(codec_name)
+codecs.register(_get_test_codec) # Returns None, not usable as a decorator
+
+class ExceptionChainingTest(unittest.TestCase):
+
+ def setUp(self):
+ # There's no way to unregister a codec search function, so we just
+ # ensure we render this one fairly harmless after the test
+ # case finishes by using the test case repr as the codec name
+ # The codecs module normalizes codec names, although this doesn't
+ # appear to be formally documented...
+ self.codec_name = repr(self).lower().replace(" ", "-")
+
+ def tearDown(self):
+ _TEST_CODECS.pop(self.codec_name, None)
+
+ def set_codec(self, obj_to_raise):
+ def raise_obj(*args, **kwds):
+ raise obj_to_raise
+ codec_info = codecs.CodecInfo(raise_obj, raise_obj,
+ name=self.codec_name)
+ _TEST_CODECS[self.codec_name] = codec_info
+
+ @contextlib.contextmanager
+ def assertWrapped(self, operation, exc_type, msg):
+ full_msg = "{} with '{}' codec failed \({}: {}\)".format(
+ operation, self.codec_name, exc_type.__name__, msg)
+ with self.assertRaisesRegex(exc_type, full_msg) as caught:
+ yield caught
+
+ def check_wrapped(self, obj_to_raise, msg):
+ self.set_codec(obj_to_raise)
+ with self.assertWrapped("encoding", RuntimeError, msg):
+ "str_input".encode(self.codec_name)
+ with self.assertWrapped("encoding", RuntimeError, msg):
+ codecs.encode("str_input", self.codec_name)
+ with self.assertWrapped("decoding", RuntimeError, msg):
+ b"bytes input".decode(self.codec_name)
+ with self.assertWrapped("decoding", RuntimeError, msg):
+ codecs.decode(b"bytes input", self.codec_name)
+
+ def test_raise_by_type(self):
+ self.check_wrapped(RuntimeError, "")
+
+ def test_raise_by_value(self):
+ msg = "This should be wrapped"
+ self.check_wrapped(RuntimeError(msg), msg)
+
+ @contextlib.contextmanager
+ def assertNotWrapped(self, operation, exc_type, msg_re, msg=None):
+ if msg is None:
+ msg = msg_re
+ with self.assertRaisesRegex(exc_type, msg) as caught:
+ yield caught
+ self.assertEqual(str(caught.exception), msg)
+
+ def check_not_wrapped(self, obj_to_raise, msg_re, msg=None):
+ self.set_codec(obj_to_raise)
+ with self.assertNotWrapped("encoding", RuntimeError, msg_re, msg):
+ "str input".encode(self.codec_name)
+ with self.assertNotWrapped("encoding", RuntimeError, msg_re, msg):
+ codecs.encode("str input", self.codec_name)
+ with self.assertNotWrapped("decoding", RuntimeError, msg_re, msg):
+ b"bytes input".decode(self.codec_name)
+ with self.assertNotWrapped("decoding", RuntimeError, msg_re, msg):
+ codecs.decode(b"bytes input", self.codec_name)
+
+ def test_init_override_is_not_wrapped(self):
+ class CustomInit(RuntimeError):
+ def __init__(self):
+ pass
+ self.check_not_wrapped(CustomInit, "")
+
+ def test_new_override_is_not_wrapped(self):
+ class CustomNew(RuntimeError):
+ def __new__(cls):
+ return super().__new__(cls)
+ self.check_not_wrapped(CustomNew, "")
+
+ def test_instance_attribute_is_not_wrapped(self):
+ msg = "This should NOT be wrapped"
+ exc = RuntimeError(msg)
+ exc.attr = 1
+ self.check_not_wrapped(exc, msg)
+
+ def test_non_str_arg_is_not_wrapped(self):
+ self.check_not_wrapped(RuntimeError(1), "1")
+
+ def test_multiple_args_is_not_wrapped(self):
+ msg_re = "\('a', 'b', 'c'\)"
+ msg = "('a', 'b', 'c')"
+ self.check_not_wrapped(RuntimeError('a', 'b', 'c'), msg_re, msg)
+
+ # http://bugs.python.org/issue19609
+ def test_codec_lookup_failure_not_wrapped(self):
+ msg = "unknown encoding: %s" % self.codec_name
+ # The initial codec lookup should not be wrapped
+ with self.assertNotWrapped("encoding", LookupError, msg):
+ "str input".encode(self.codec_name)
+ with self.assertNotWrapped("encoding", LookupError, msg):
+ codecs.encode("str input", self.codec_name)
+ with self.assertNotWrapped("decoding", LookupError, msg):
+ b"bytes input".decode(self.codec_name)
+ with self.assertNotWrapped("decoding", LookupError, msg):
+ codecs.decode(b"bytes input", self.codec_name)
+
@unittest.skipUnless(sys.platform == 'win32',
@@ -2324,8 +2509,8 @@ class CodePageTest(unittest.TestCase):
def test_invalid_code_page(self):
self.assertRaises(ValueError, codecs.code_page_encode, -1, 'a')
self.assertRaises(ValueError, codecs.code_page_decode, -1, b'a')
- self.assertRaises(WindowsError, codecs.code_page_encode, 123, 'a')
- self.assertRaises(WindowsError, codecs.code_page_decode, 123, b'a')
+ self.assertRaises(OSError, codecs.code_page_encode, 123, 'a')
+ self.assertRaises(OSError, codecs.code_page_decode, 123, b'a')
def test_code_page_name(self):
self.assertRaisesRegex(UnicodeEncodeError, 'cp932',
diff --git a/Lib/test/test_coding.py b/Lib/test/test_coding.py
deleted file mode 100644
index 989c7a85d2..0000000000
--- a/Lib/test/test_coding.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import unittest
-from test.support import TESTFN, unlink, unload
-import importlib, os, sys
-
-class CodingTest(unittest.TestCase):
- def test_bad_coding(self):
- module_name = 'bad_coding'
- self.verify_bad_module(module_name)
-
- def test_bad_coding2(self):
- module_name = 'bad_coding2'
- self.verify_bad_module(module_name)
-
- def verify_bad_module(self, module_name):
- self.assertRaises(SyntaxError, __import__, 'test.' + module_name)
-
- path = os.path.dirname(__file__)
- filename = os.path.join(path, module_name + '.py')
- with open(filename, "rb") as fp:
- bytes = fp.read()
- self.assertRaises(SyntaxError, compile, bytes, filename, 'exec')
-
- def test_exec_valid_coding(self):
- d = {}
- exec(b'# coding: cp949\na = "\xaa\xa7"\n', d)
- self.assertEqual(d['a'], '\u3047')
-
- def test_file_parse(self):
- # issue1134: all encodings outside latin-1 and utf-8 fail on
- # multiline strings and long lines (>512 columns)
- unload(TESTFN)
- filename = TESTFN + ".py"
- f = open(filename, "w", encoding="cp1252")
- sys.path.insert(0, os.curdir)
- try:
- with f:
- f.write("# -*- coding: cp1252 -*-\n")
- f.write("'''A short string\n")
- f.write("'''\n")
- f.write("'A very long string %s'\n" % ("X" * 1000))
-
- importlib.invalidate_caches()
- __import__(TESTFN)
- finally:
- del sys.path[0]
- unlink(filename)
- unlink(filename + "c")
- unlink(filename + "o")
- unload(TESTFN)
-
- def test_error_from_string(self):
- # See http://bugs.python.org/issue6289
- input = "# coding: ascii\n\N{SNOWMAN}".encode('utf-8')
- with self.assertRaises(SyntaxError) as c:
- compile(input, "<string>", "exec")
- expected = "'ascii' codec can't decode byte 0xe2 in position 16: " \
- "ordinal not in range(128)"
- self.assertTrue(c.exception.args[0].startswith(expected),
- msg=c.exception.args[0])
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
index ff52755354..56e8120705 100644
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -112,6 +112,38 @@ class TestChainMap(unittest.TestCase):
self.assertEqual(dict(d), dict(a=1, b=2, c=30))
self.assertEqual(dict(d.items()), dict(a=1, b=2, c=30))
+ def test_new_child(self):
+ 'Tests for changes for issue #16613.'
+ c = ChainMap()
+ c['a'] = 1
+ c['b'] = 2
+ m = {'b':20, 'c': 30}
+ d = c.new_child(m)
+ self.assertEqual(d.maps, [{'b':20, 'c':30}, {'a':1, 'b':2}]) # check internal state
+ self.assertIs(m, d.maps[0])
+
+ # Use a different map than a dict
+ class lowerdict(dict):
+ def __getitem__(self, key):
+ if isinstance(key, str):
+ key = key.lower()
+ return dict.__getitem__(self, key)
+ def __contains__(self, key):
+ if isinstance(key, str):
+ key = key.lower()
+ return dict.__contains__(self, key)
+
+ c = ChainMap()
+ c['a'] = 1
+ c['b'] = 2
+ m = lowerdict(b=20, c=30)
+ d = c.new_child(m)
+ self.assertIs(m, d.maps[0])
+ for key in 'abc': # check contains
+ self.assertIn(key, d)
+ for k, v in dict(a=1, B=20, C=30, z=100).items(): # check get
+ self.assertEqual(d.get(k, 100), v)
+
################################################################################
### Named Tuples
@@ -750,6 +782,8 @@ class TestCollectionABCs(ABCTestCase):
self.assertTrue(issubclass(sample, Sequence))
self.assertIsInstance(range(10), Sequence)
self.assertTrue(issubclass(range, Sequence))
+ self.assertIsInstance(memoryview(b""), Sequence)
+ self.assertTrue(issubclass(memoryview, Sequence))
self.assertTrue(issubclass(str, Sequence))
self.validate_abstract_methods(Sequence, '__contains__', '__iter__', '__len__',
'__getitem__')
@@ -904,23 +938,26 @@ class TestCounter(unittest.TestCase):
words = Counter('which witch had which witches wrist watch'.split())
update_test = Counter()
update_test.update(words)
- for i, dup in enumerate([
- words.copy(),
- copy.copy(words),
- copy.deepcopy(words),
- pickle.loads(pickle.dumps(words, 0)),
- pickle.loads(pickle.dumps(words, 1)),
- pickle.loads(pickle.dumps(words, 2)),
- pickle.loads(pickle.dumps(words, -1)),
- eval(repr(words)),
- update_test,
- Counter(words),
- ]):
- msg = (i, dup, words)
- self.assertTrue(dup is not words)
- self.assertEqual(dup, words)
- self.assertEqual(len(dup), len(words))
- self.assertEqual(type(dup), type(words))
+ for label, dup in [
+ ('words.copy()', words.copy()),
+ ('copy.copy(words)', copy.copy(words)),
+ ('copy.deepcopy(words)', copy.deepcopy(words)),
+ ('pickle.loads(pickle.dumps(words, 0))',
+ pickle.loads(pickle.dumps(words, 0))),
+ ('pickle.loads(pickle.dumps(words, 1))',
+ pickle.loads(pickle.dumps(words, 1))),
+ ('pickle.loads(pickle.dumps(words, 2))',
+ pickle.loads(pickle.dumps(words, 2))),
+ ('pickle.loads(pickle.dumps(words, -1))',
+ pickle.loads(pickle.dumps(words, -1))),
+ ('eval(repr(words))', eval(repr(words))),
+ ('update_test', update_test),
+ ('Counter(words)', Counter(words)),
+ ]:
+ with self.subTest(label=label):
+ msg = "\ncopy: %s\nwords: %s" % (dup, words)
+ self.assertIsNot(dup, words, msg)
+ self.assertEqual(dup, words)
def test_copy_subclass(self):
class MyCounter(Counter):
@@ -1043,8 +1080,10 @@ class TestCounter(unittest.TestCase):
# test fidelity to the pure python version
c = CounterSubclassWithSetItem('abracadabra')
self.assertTrue(c.called)
+ self.assertEqual(dict(c), {'a': 5, 'b': 2, 'c': 1, 'd': 1, 'r':2 })
c = CounterSubclassWithGet('abracadabra')
self.assertTrue(c.called)
+ self.assertEqual(dict(c), {'a': 5, 'b': 2, 'c': 1, 'd': 1, 'r':2 })
################################################################################
@@ -1205,24 +1244,28 @@ class TestOrderedDict(unittest.TestCase):
od = OrderedDict(pairs)
update_test = OrderedDict()
update_test.update(od)
- for i, dup in enumerate([
- od.copy(),
- copy.copy(od),
- copy.deepcopy(od),
- pickle.loads(pickle.dumps(od, 0)),
- pickle.loads(pickle.dumps(od, 1)),
- pickle.loads(pickle.dumps(od, 2)),
- pickle.loads(pickle.dumps(od, 3)),
- pickle.loads(pickle.dumps(od, -1)),
- eval(repr(od)),
- update_test,
- OrderedDict(od),
- ]):
- self.assertTrue(dup is not od)
- self.assertEqual(dup, od)
- self.assertEqual(list(dup.items()), list(od.items()))
- self.assertEqual(len(dup), len(od))
- self.assertEqual(type(dup), type(od))
+ for label, dup in [
+ ('od.copy()', od.copy()),
+ ('copy.copy(od)', copy.copy(od)),
+ ('copy.deepcopy(od)', copy.deepcopy(od)),
+ ('pickle.loads(pickle.dumps(od, 0))',
+ pickle.loads(pickle.dumps(od, 0))),
+ ('pickle.loads(pickle.dumps(od, 1))',
+ pickle.loads(pickle.dumps(od, 1))),
+ ('pickle.loads(pickle.dumps(od, 2))',
+ pickle.loads(pickle.dumps(od, 2))),
+ ('pickle.loads(pickle.dumps(od, 3))',
+ pickle.loads(pickle.dumps(od, 3))),
+ ('pickle.loads(pickle.dumps(od, -1))',
+ pickle.loads(pickle.dumps(od, -1))),
+ ('eval(repr(od))', eval(repr(od))),
+ ('update_test', update_test),
+ ('OrderedDict(od)', OrderedDict(od)),
+ ]:
+ with self.subTest(label=label):
+ msg = "\ncopy: %s\nod: %s" % (dup, od)
+ self.assertIsNot(dup, od, msg)
+ self.assertEqual(dup, od)
def test_yaml_linkage(self):
# Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature.
@@ -1237,9 +1280,18 @@ class TestOrderedDict(unittest.TestCase):
# do not save instance dictionary if not needed
pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
od = OrderedDict(pairs)
- self.assertEqual(len(od.__reduce__()), 2)
+ self.assertIsNone(od.__reduce__()[2])
od.x = 10
- self.assertEqual(len(od.__reduce__()), 3)
+ self.assertIsNotNone(od.__reduce__()[2])
+
+ def test_pickle_recursive(self):
+ od = OrderedDict()
+ od[1] = od
+ for proto in range(-1, pickle.HIGHEST_PROTOCOL + 1):
+ dup = pickle.loads(pickle.dumps(od, proto))
+ self.assertIsNot(dup, od)
+ self.assertEqual(list(dup.keys()), [1])
+ self.assertIs(dup[1], dup)
def test_repr(self):
od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)])
diff --git a/Lib/test/test_colorsys.py b/Lib/test/test_colorsys.py
index e405b8a4d0..a24e3adcb4 100644
--- a/Lib/test/test_colorsys.py
+++ b/Lib/test/test_colorsys.py
@@ -1,4 +1,4 @@
-import unittest, test.support
+import unittest
import colorsys
def frange(start, stop, step):
@@ -69,8 +69,32 @@ class ColorsysTest(unittest.TestCase):
self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb))
self.assertTripleEqual(rgb, colorsys.hls_to_rgb(*hls))
-def test_main():
- test.support.run_unittest(ColorsysTest)
+ def test_yiq_roundtrip(self):
+ for r in frange(0.0, 1.0, 0.2):
+ for g in frange(0.0, 1.0, 0.2):
+ for b in frange(0.0, 1.0, 0.2):
+ rgb = (r, g, b)
+ self.assertTripleEqual(
+ rgb,
+ colorsys.yiq_to_rgb(*colorsys.rgb_to_yiq(*rgb))
+ )
+
+ def test_yiq_values(self):
+ values = [
+ # rgb, yiq
+ ((0.0, 0.0, 0.0), (0.0, 0.0, 0.0)), # black
+ ((0.0, 0.0, 1.0), (0.11, -0.3217, 0.3121)), # blue
+ ((0.0, 1.0, 0.0), (0.59, -0.2773, -0.5251)), # green
+ ((0.0, 1.0, 1.0), (0.7, -0.599, -0.213)), # cyan
+ ((1.0, 0.0, 0.0), (0.3, 0.599, 0.213)), # red
+ ((1.0, 0.0, 1.0), (0.41, 0.2773, 0.5251)), # purple
+ ((1.0, 1.0, 0.0), (0.89, 0.3217, -0.3121)), # yellow
+ ((1.0, 1.0, 1.0), (1.0, 0.0, 0.0)), # white
+ ((0.5, 0.5, 0.5), (0.5, 0.0, 0.0)), # grey
+ ]
+ for (rgb, yiq) in values:
+ self.assertTripleEqual(yiq, colorsys.rgb_to_yiq(*rgb))
+ self.assertTripleEqual(rgb, colorsys.yiq_to_rgb(*yiq))
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py
index fddb538efb..ff2515df4c 100644
--- a/Lib/test/test_compileall.py
+++ b/Lib/test/test_compileall.py
@@ -1,11 +1,12 @@
import sys
import compileall
-import imp
+import importlib.util
import os
import py_compile
import shutil
import struct
import subprocess
+import sys
import tempfile
import time
import unittest
@@ -18,11 +19,11 @@ class CompileallTests(unittest.TestCase):
def setUp(self):
self.directory = tempfile.mkdtemp()
self.source_path = os.path.join(self.directory, '_test.py')
- self.bc_path = imp.cache_from_source(self.source_path)
+ self.bc_path = importlib.util.cache_from_source(self.source_path)
with open(self.source_path, 'w') as file:
file.write('x = 123\n')
self.source_path2 = os.path.join(self.directory, '_test2.py')
- self.bc_path2 = imp.cache_from_source(self.source_path2)
+ self.bc_path2 = importlib.util.cache_from_source(self.source_path2)
shutil.copyfile(self.source_path, self.source_path2)
self.subdirectory = os.path.join(self.directory, '_subdir')
os.mkdir(self.subdirectory)
@@ -36,7 +37,7 @@ class CompileallTests(unittest.TestCase):
with open(self.bc_path, 'rb') as file:
data = file.read(8)
mtime = int(os.stat(self.source_path).st_mtime)
- compare = struct.pack('<4sl', imp.get_magic(), mtime)
+ compare = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, mtime)
return data, compare
@unittest.skipUnless(hasattr(os, 'stat'), 'test needs os.stat()')
@@ -56,7 +57,8 @@ class CompileallTests(unittest.TestCase):
def test_mtime(self):
# Test a change in mtime leads to a new .pyc.
- self.recreation_check(struct.pack('<4sl', imp.get_magic(), 1))
+ self.recreation_check(struct.pack('<4sl', importlib.util.MAGIC_NUMBER,
+ 1))
def test_magic_number(self):
# Test a change in mtime leads to a new .pyc.
@@ -96,14 +98,14 @@ class CompileallTests(unittest.TestCase):
# interpreter's creates the correct file names
optimize = 1 if __debug__ else 0
compileall.compile_dir(self.directory, quiet=True, optimize=optimize)
- cached = imp.cache_from_source(self.source_path,
- debug_override=not optimize)
+ cached = importlib.util.cache_from_source(self.source_path,
+ debug_override=not optimize)
self.assertTrue(os.path.isfile(cached))
- cached2 = imp.cache_from_source(self.source_path2,
- debug_override=not optimize)
+ cached2 = importlib.util.cache_from_source(self.source_path2,
+ debug_override=not optimize)
self.assertTrue(os.path.isfile(cached2))
- cached3 = imp.cache_from_source(self.source_path3,
- debug_override=not optimize)
+ cached3 = importlib.util.cache_from_source(self.source_path3,
+ debug_override=not optimize)
self.assertTrue(os.path.isfile(cached3))
@@ -151,10 +153,12 @@ class CommandLineTests(unittest.TestCase):
return rc, out, err
def assertCompiled(self, fn):
- self.assertTrue(os.path.exists(imp.cache_from_source(fn)))
+ path = importlib.util.cache_from_source(fn)
+ self.assertTrue(os.path.exists(path))
def assertNotCompiled(self, fn):
- self.assertFalse(os.path.exists(imp.cache_from_source(fn)))
+ path = importlib.util.cache_from_source(fn)
+ self.assertFalse(os.path.exists(path))
def setUp(self):
self.addCleanup(self._cleanup)
@@ -189,8 +193,8 @@ class CommandLineTests(unittest.TestCase):
['-m', 'compileall', '-q', self.pkgdir]))
# Verify the __pycache__ directory contents.
self.assertTrue(os.path.exists(self.pkgdir_cachedir))
- expected = sorted(base.format(imp.get_tag(), ext) for base in
- ('__init__.{}.{}', 'bar.{}.{}'))
+ expected = sorted(base.format(sys.implementation.cache_tag, ext)
+ for base in ('__init__.{}.{}', 'bar.{}.{}'))
self.assertEqual(sorted(os.listdir(self.pkgdir_cachedir)), expected)
# Make sure there are no .pyc files in the source directory.
self.assertFalse([fn for fn in os.listdir(self.pkgdir)
@@ -223,7 +227,7 @@ class CommandLineTests(unittest.TestCase):
def test_force(self):
self.assertRunOK('-q', self.pkgdir)
- pycpath = imp.cache_from_source(self.barfn)
+ pycpath = importlib.util.cache_from_source(self.barfn)
# set atime/mtime backward to avoid file timestamp resolution issues
os.utime(pycpath, (time.time()-60,)*2)
mtime = os.stat(pycpath).st_mtime
@@ -287,10 +291,10 @@ class CommandLineTests(unittest.TestCase):
bazfn = script_helper.make_script(self.pkgdir, 'baz', 'raise Exception')
self.assertRunOK('-q', '-d', 'dinsdale', self.pkgdir)
fn = script_helper.make_script(self.pkgdir, 'bing', 'import baz')
- pyc = imp.cache_from_source(bazfn)
+ pyc = importlib.util.cache_from_source(bazfn)
os.rename(pyc, os.path.join(self.pkgdir, 'baz.pyc'))
os.remove(bazfn)
- rc, out, err = script_helper.assert_python_failure(fn)
+ rc, out, err = script_helper.assert_python_failure(fn, __isolated=False)
self.assertRegex(err, b'File "dinsdale')
def test_include_bad_file(self):
@@ -298,7 +302,7 @@ class CommandLineTests(unittest.TestCase):
'-i', os.path.join(self.directory, 'nosuchfile'), self.pkgdir)
self.assertRegex(out, b'rror.*nosuchfile')
self.assertNotRegex(err, b'Traceback')
- self.assertFalse(os.path.exists(imp.cache_from_source(
+ self.assertFalse(os.path.exists(importlib.util.cache_from_source(
self.pkgdir_cachedir)))
def test_include_file_with_arg(self):
@@ -355,13 +359,5 @@ class CommandLineTests(unittest.TestCase):
self.assertRegex(out, b"Can't list 'badfilename'")
-def test_main():
- support.run_unittest(
- CommandLineTests,
- CompileallTests,
- EncodingTest,
- )
-
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py
index 1bf409798b..cd55375bdb 100644
--- a/Lib/test/test_complex.py
+++ b/Lib/test/test_complex.py
@@ -220,6 +220,8 @@ class ComplexTest(unittest.TestCase):
self.assertRaises(TypeError, complex, OS(None))
self.assertRaises(TypeError, complex, NS(None))
self.assertRaises(TypeError, complex, {})
+ self.assertRaises(TypeError, complex, NS(1.5))
+ self.assertRaises(TypeError, complex, NS(1))
self.assertAlmostEqual(complex("1+10j"), 1+10j)
self.assertAlmostEqual(complex(10), 10+0j)
@@ -301,6 +303,7 @@ class ComplexTest(unittest.TestCase):
self.assertRaises(TypeError, float, 5+3j)
self.assertRaises(ValueError, complex, "")
self.assertRaises(TypeError, complex, None)
+ self.assertRaisesRegex(TypeError, "not 'NoneType'", complex, None)
self.assertRaises(ValueError, complex, "\0")
self.assertRaises(ValueError, complex, "3\09")
self.assertRaises(TypeError, complex, "1", "2")
diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py
index 39230e1cf0..a7e945cb5e 100644
--- a/Lib/test/test_concurrent_futures.py
+++ b/Lib/test/test_concurrent_futures.py
@@ -15,6 +15,7 @@ import sys
import threading
import time
import unittest
+import weakref
from concurrent import futures
from concurrent.futures._base import (
@@ -52,6 +53,11 @@ def sleep_and_print(t, msg):
sys.stdout.flush()
+class MyObject(object):
+ def my_method(self):
+ pass
+
+
class ExecutorMixin:
worker_count = 5
@@ -396,6 +402,22 @@ class ExecutorTest:
self.executor.map(str, [2] * (self.worker_count + 1))
self.executor.shutdown()
+ @test.support.cpython_only
+ def test_no_stale_references(self):
+ # Issue #16284: check that the executors don't unnecessarily hang onto
+ # references.
+ my_object = MyObject()
+ my_object_collected = threading.Event()
+ my_object_callback = weakref.ref(
+ my_object, lambda obj: my_object_collected.set())
+ # Deliberately discarding the future.
+ self.executor.submit(my_object.my_method)
+ del my_object
+
+ collected = my_object_collected.wait(timeout=5.0)
+ self.assertTrue(collected,
+ "Stale reference not collected within timeout.")
+
class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, unittest.TestCase):
def test_map_submits_without_iteration(self):
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 9e45f70f85..b8770c828f 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -1,5 +1,6 @@
"""Unit tests for contextlib.py, and other context managers."""
+import io
import sys
import tempfile
import unittest
@@ -100,15 +101,25 @@ class ContextManagerTestCase(unittest.TestCase):
self.assertEqual(baz.__name__,'baz')
self.assertEqual(baz.foo, 'bar')
- @unittest.skipIf(sys.flags.optimize >= 2,
- "Docstrings are omitted with -O2 and above")
+ @support.requires_docstrings
def test_contextmanager_doc_attrib(self):
baz = self._create_contextmanager_attribs()
self.assertEqual(baz.__doc__, "Whee!")
+ @support.requires_docstrings
+ def test_instance_docstring_given_cm_docstring(self):
+ baz = self._create_contextmanager_attribs()(None)
+ self.assertEqual(baz.__doc__, "Whee!")
+
+
class ClosingTestCase(unittest.TestCase):
- # XXX This needs more work
+ @support.requires_docstrings
+ def test_instance_docs(self):
+ # Issue 19330: ensure context manager instances have good docstrings
+ cm_docstring = closing.__doc__
+ obj = closing(None)
+ self.assertEqual(obj.__doc__, cm_docstring)
def test_closing(self):
state = []
@@ -204,6 +215,7 @@ class LockContextTestCase(unittest.TestCase):
class mycontext(ContextDecorator):
+ """Example decoration-compatible context manager for testing"""
started = False
exc = None
catch = False
@@ -219,6 +231,13 @@ class mycontext(ContextDecorator):
class TestContextDecorator(unittest.TestCase):
+ @support.requires_docstrings
+ def test_instance_docs(self):
+ # Issue 19330: ensure context manager instances have good docstrings
+ cm_docstring = mycontext.__doc__
+ obj = mycontext()
+ self.assertEqual(obj.__doc__, cm_docstring)
+
def test_contextdecorator(self):
context = mycontext()
with context as result:
@@ -372,6 +391,13 @@ class TestContextDecorator(unittest.TestCase):
class TestExitStack(unittest.TestCase):
+ @support.requires_docstrings
+ def test_instance_docs(self):
+ # Issue 19330: ensure context manager instances have good docstrings
+ cm_docstring = ExitStack.__doc__
+ obj = ExitStack()
+ self.assertEqual(obj.__doc__, cm_docstring)
+
def test_no_resources(self):
with ExitStack():
pass
@@ -631,10 +657,111 @@ class TestExitStack(unittest.TestCase):
stack.push(cm)
self.assertIs(stack._exit_callbacks[-1], cm)
+class TestRedirectStdout(unittest.TestCase):
+
+ @support.requires_docstrings
+ def test_instance_docs(self):
+ # Issue 19330: ensure context manager instances have good docstrings
+ cm_docstring = redirect_stdout.__doc__
+ obj = redirect_stdout(None)
+ self.assertEqual(obj.__doc__, cm_docstring)
+
+ def test_no_redirect_in_init(self):
+ orig_stdout = sys.stdout
+ redirect_stdout(None)
+ self.assertIs(sys.stdout, orig_stdout)
+
+ def test_redirect_to_string_io(self):
+ f = io.StringIO()
+ msg = "Consider an API like help(), which prints directly to stdout"
+ orig_stdout = sys.stdout
+ with redirect_stdout(f):
+ print(msg)
+ self.assertIs(sys.stdout, orig_stdout)
+ s = f.getvalue().strip()
+ self.assertEqual(s, msg)
+
+ def test_enter_result_is_target(self):
+ f = io.StringIO()
+ with redirect_stdout(f) as enter_result:
+ self.assertIs(enter_result, f)
+
+ def test_cm_is_reusable(self):
+ f = io.StringIO()
+ write_to_f = redirect_stdout(f)
+ orig_stdout = sys.stdout
+ with write_to_f:
+ print("Hello", end=" ")
+ with write_to_f:
+ print("World!")
+ self.assertIs(sys.stdout, orig_stdout)
+ s = f.getvalue()
+ self.assertEqual(s, "Hello World!\n")
+
+ def test_cm_is_reentrant(self):
+ f = io.StringIO()
+ write_to_f = redirect_stdout(f)
+ orig_stdout = sys.stdout
+ with write_to_f:
+ print("Hello", end=" ")
+ with write_to_f:
+ print("World!")
+ self.assertIs(sys.stdout, orig_stdout)
+ s = f.getvalue()
+ self.assertEqual(s, "Hello World!\n")
+
+
+class TestSuppress(unittest.TestCase):
+
+ @support.requires_docstrings
+ def test_instance_docs(self):
+ # Issue 19330: ensure context manager instances have good docstrings
+ cm_docstring = suppress.__doc__
+ obj = suppress()
+ self.assertEqual(obj.__doc__, cm_docstring)
+
+ def test_no_result_from_enter(self):
+ with suppress(ValueError) as enter_result:
+ self.assertIsNone(enter_result)
+
+ def test_no_exception(self):
+ with suppress(ValueError):
+ self.assertEqual(pow(2, 5), 32)
+
+ def test_exact_exception(self):
+ with suppress(TypeError):
+ len(5)
+
+ def test_exception_hierarchy(self):
+ with suppress(LookupError):
+ 'Hello'[50]
+
+ def test_other_exception(self):
+ with self.assertRaises(ZeroDivisionError):
+ with suppress(TypeError):
+ 1/0
+
+ def test_no_args(self):
+ with self.assertRaises(ZeroDivisionError):
+ with suppress():
+ 1/0
-# This is needed to make the test actually run under regrtest.py!
-def test_main():
- support.run_unittest(__name__)
+ def test_multiple_exception_args(self):
+ with suppress(ZeroDivisionError, TypeError):
+ 1/0
+ with suppress(ZeroDivisionError, TypeError):
+ len(5)
+
+ def test_cm_is_reentrant(self):
+ ignore_exceptions = suppress(Exception)
+ with ignore_exceptions:
+ pass
+ with ignore_exceptions:
+ len(5)
+ with ignore_exceptions:
+ 1/0
+ with ignore_exceptions: # Check nested usage
+ len(5)
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py
index 56766682b3..c3eb7faf8f 100644
--- a/Lib/test/test_cprofile.py
+++ b/Lib/test/test_cprofile.py
@@ -6,9 +6,11 @@ from test.support import run_unittest, TESTFN, unlink
# rip off all interesting stuff from test_profile
import cProfile
from test.test_profile import ProfileTest, regenerate_expected_output
+from test.profilee import testfunc
class CProfileTest(ProfileTest):
profilerclass = cProfile.Profile
+ profilermodule = cProfile
expected_max_output = "{built-in method max}"
def get_expected_output(self):
diff --git a/Lib/test/test_crypt.py b/Lib/test/test_crypt.py
index cfb7341e20..624d702f99 100644
--- a/Lib/test/test_crypt.py
+++ b/Lib/test/test_crypt.py
@@ -1,11 +1,7 @@
from test import support
import unittest
-def setUpModule():
- # this import will raise unittest.SkipTest if _crypt doesn't exist,
- # so it has to be done in setUpModule for test discovery to work
- global crypt
- crypt = support.import_module('crypt')
+crypt = support.import_module('crypt')
class CryptTestCase(unittest.TestCase):
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 976e6205b7..cf0b09cbe6 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -136,12 +136,12 @@ class Test_Csv(unittest.TestCase):
return 10;
def __getitem__(self, i):
if i > 2:
- raise IOError
- self.assertRaises(IOError, self._write_test, BadList(), '')
+ raise OSError
+ self.assertRaises(OSError, self._write_test, BadList(), '')
class BadItem:
def __str__(self):
- raise IOError
- self.assertRaises(IOError, self._write_test, [BadItem()], '')
+ raise OSError
+ self.assertRaises(OSError, self._write_test, [BadItem()], '')
def test_write_bigfield(self):
# This exercises the buffer realloc functionality
@@ -186,9 +186,9 @@ class Test_Csv(unittest.TestCase):
def test_writerows(self):
class BrokenFile:
def write(self, buf):
- raise IOError
+ raise OSError
writer = csv.writer(BrokenFile())
- self.assertRaises(IOError, writer.writerows, [['a']])
+ self.assertRaises(OSError, writer.writerows, [['a']])
with TemporaryFile("w+", newline='') as fileobj:
writer = csv.writer(fileobj)
@@ -308,6 +308,15 @@ class Test_Csv(unittest.TestCase):
for i, row in enumerate(csv.reader(fileobj)):
self.assertEqual(row, rows[i])
+ def test_roundtrip_escaped_unquoted_newlines(self):
+ with TemporaryFile("w+", newline='') as fileobj:
+ writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")
+ rows = [['a\nb','b'],['c','x\r\nd']]
+ writer.writerows(rows)
+ fileobj.seek(0)
+ for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")):
+ self.assertEqual(row,rows[i])
+
class TestDialectRegistry(unittest.TestCase):
def test_registry_badargs(self):
self.assertRaises(TypeError, csv.list_dialects, None)
diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py
index 243a0af103..1c57770242 100644
--- a/Lib/test/test_dbm.py
+++ b/Lib/test/test_dbm.py
@@ -62,7 +62,7 @@ class AnyDBMTestCase:
return keys
def test_error(self):
- self.assertTrue(issubclass(self.module.error, IOError))
+ self.assertTrue(issubclass(self.module.error, OSError))
def test_anydbm_not_existing(self):
self.assertRaises(dbm.error, dbm.open, _fname)
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 71b93a661b..ac8a7bebc1 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -290,7 +290,6 @@ class IBMTestCases(unittest.TestCase):
global skip_expected
if skip_expected:
raise unittest.SkipTest
- return
with open(file) as f:
for line in f:
line = line.replace('\r\n', '').replace('\n', '')
diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py
index a8487d2d95..7bff1d2798 100644
--- a/Lib/test/test_deque.py
+++ b/Lib/test/test_deque.py
@@ -543,8 +543,8 @@ class TestBasic(unittest.TestCase):
check = self.check_sizeof
check(deque(), basesize + blocksize)
check(deque('a'), basesize + blocksize)
- check(deque('a' * (BLOCKLEN // 2)), basesize + blocksize)
- check(deque('a' * (BLOCKLEN // 2 + 1)), basesize + 2 * blocksize)
+ check(deque('a' * (BLOCKLEN - 1)), basesize + blocksize)
+ check(deque('a' * BLOCKLEN), basesize + 2 * blocksize)
check(deque('a' * (42 * BLOCKLEN)), basesize + 43 * blocksize)
class TestVariousIteratorArgs(unittest.TestCase):
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index 3776389ebb..595d540736 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -1789,6 +1789,7 @@ order (MRO) for bases """
("__trunc__", int, zero, set(), {}),
("__ceil__", math.ceil, zero, set(), {}),
("__dir__", dir, empty_seq, set(), {}),
+ ("__round__", round, zero, set(), {}),
]
class Checker(object):
@@ -2249,7 +2250,9 @@ order (MRO) for bases """
minstance = M("m")
minstance.b = 2
minstance.a = 1
- names = [x for x in dir(minstance) if x not in ["__name__", "__doc__"]]
+ default_attributes = ['__name__', '__doc__', '__package__',
+ '__loader__']
+ names = [x for x in dir(minstance) if x not in default_attributes]
self.assertEqual(names, ['a', 'b'])
class M2(M):
@@ -3733,18 +3736,8 @@ order (MRO) for bases """
# bug).
del c
- # If that didn't blow up, it's also interesting to see whether clearing
- # the last container slot works: that will attempt to delete c again,
- # which will cause c to get appended back to the container again
- # "during" the del. (On non-CPython implementations, however, __del__
- # is typically not called again.)
support.gc_collect()
self.assertEqual(len(C.container), 1)
- del C.container[-1]
- if support.check_impl_detail():
- support.gc_collect()
- self.assertEqual(len(C.container), 1)
- self.assertEqual(C.container[-1].attr, 42)
# Make c mortal again, so that the test framework with -l doesn't report
# it as a leak.
@@ -4524,6 +4517,13 @@ order (MRO) for bases """
self.assertRaises(TypeError, type.__dict__['__qualname__'].__set__,
str, 'Oink')
+ global Y
+ class Y:
+ class Inside:
+ pass
+ self.assertEqual(Y.__qualname__, 'Y')
+ self.assertEqual(Y.Inside.__qualname__, 'Y.Inside')
+
def test_qualname_dict(self):
ns = {'__qualname__': 'some.name'}
tp = type('Foo', (), ns)
diff --git a/Lib/test/test_devpoll.py b/Lib/test/test_devpoll.py
index bef4e1846b..40ebeee597 100644
--- a/Lib/test/test_devpoll.py
+++ b/Lib/test/test_devpoll.py
@@ -2,13 +2,17 @@
# Initial tests are copied as is from "test_poll.py"
-import os, select, random, unittest, sys
+import os
+import random
+import select
+import sys
+import unittest
from test.support import TESTFN, run_unittest
try:
select.devpoll
except AttributeError:
- raise unittest.SkipTest("select.devpoll not defined -- skipping test_devpoll")
+ raise unittest.SkipTest("select.devpoll not defined")
def find_ready_matching(ready, flag):
@@ -87,6 +91,36 @@ class DevPollTests(unittest.TestCase):
self.assertRaises(OverflowError, pollster.poll, 1 << 63)
self.assertRaises(OverflowError, pollster.poll, 1 << 64)
+ def test_close(self):
+ open_file = open(__file__, "rb")
+ self.addCleanup(open_file.close)
+ fd = open_file.fileno()
+ devpoll = select.devpoll()
+
+ # test fileno() method and closed attribute
+ self.assertIsInstance(devpoll.fileno(), int)
+ self.assertFalse(devpoll.closed)
+
+ # test close()
+ devpoll.close()
+ self.assertTrue(devpoll.closed)
+ self.assertRaises(ValueError, devpoll.fileno)
+
+ # close() can be called more than once
+ devpoll.close()
+
+ # operations must fail with ValueError("I/O operation on closed ...")
+ self.assertRaises(ValueError, devpoll.modify, fd, select.POLLIN)
+ self.assertRaises(ValueError, devpoll.poll)
+ self.assertRaises(ValueError, devpoll.register, fd, fd, select.POLLIN)
+ self.assertRaises(ValueError, devpoll.unregister, fd)
+
+ def test_fd_non_inheritable(self):
+ devpoll = select.devpoll()
+ self.addCleanup(devpoll.close)
+ self.assertEqual(os.get_inheritable(devpoll.fileno()), False)
+
+
def test_main():
run_unittest(DevPollTests)
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index d1355bbb96..2b546785a1 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -1,11 +1,14 @@
# Minimal tests for dis module
from test.support import run_unittest, captured_stdout
+from test.bytecode_helper import BytecodeTestCase
import difflib
import unittest
import sys
import dis
import io
+import types
+import contextlib
class _C:
def __init__(self, x):
@@ -22,12 +25,12 @@ dis_c_instance_method = """\
""" % (_C.__init__.__code__.co_firstlineno + 1,)
dis_c_instance_method_bytes = """\
- 0 LOAD_FAST 1 (1)
- 3 LOAD_CONST 1 (1)
- 6 COMPARE_OP 2 (==)
- 9 LOAD_FAST 0 (0)
- 12 STORE_ATTR 0 (0)
- 15 LOAD_CONST 0 (0)
+ 0 LOAD_FAST 1 (1)
+ 3 LOAD_CONST 1 (1)
+ 6 COMPARE_OP 2 (==)
+ 9 LOAD_FAST 0 (0)
+ 12 STORE_ATTR 0 (0)
+ 15 LOAD_CONST 0 (0)
18 RETURN_VALUE
"""
@@ -48,11 +51,11 @@ dis_f = """\
dis_f_co_code = """\
- 0 LOAD_GLOBAL 0 (0)
- 3 LOAD_FAST 0 (0)
- 6 CALL_FUNCTION 1 (1 positional, 0 keyword pair)
+ 0 LOAD_GLOBAL 0 (0)
+ 3 LOAD_FAST 0 (0)
+ 6 CALL_FUNCTION 1 (1 positional, 0 keyword pair)
9 POP_TOP
- 10 LOAD_CONST 1 (1)
+ 10 LOAD_CONST 1 (1)
13 RETURN_VALUE
"""
@@ -174,30 +177,20 @@ dis_compound_stmt_str = """\
class DisTests(unittest.TestCase):
def get_disassembly(self, func, lasti=-1, wrapper=True):
- s = io.StringIO()
- save_stdout = sys.stdout
- sys.stdout = s
- try:
+ # We want to test the default printing behaviour, not the file arg
+ output = io.StringIO()
+ with contextlib.redirect_stdout(output):
if wrapper:
dis.dis(func)
else:
dis.disassemble(func, lasti)
- finally:
- sys.stdout = save_stdout
- # Trim trailing blanks (if any).
- return [line.rstrip() for line in s.getvalue().splitlines()]
+ return output.getvalue()
def get_disassemble_as_string(self, func, lasti=-1):
- return '\n'.join(self.get_disassembly(func, lasti, False))
+ return self.get_disassembly(func, lasti, False)
def do_disassembly_test(self, func, expected):
- lines = self.get_disassembly(func)
- expected = expected.splitlines()
- if expected != lines:
- self.fail(
- "events did not match expectation:\n" +
- "\n".join(difflib.ndiff(expected,
- lines)))
+ self.assertEqual(self.get_disassembly(func), expected)
def test_opmap(self):
self.assertEqual(dis.opmap["NOP"], 9)
@@ -288,35 +281,35 @@ class DisTests(unittest.TestCase):
def test_dis_object(self):
self.assertRaises(TypeError, dis.dis, object())
+class DisWithFileTests(DisTests):
+
+ # Run the tests again, using the file arg instead of print
+ def get_disassembly(self, func, lasti=-1, wrapper=True):
+ output = io.StringIO()
+ if wrapper:
+ dis.dis(func, file=output)
+ else:
+ dis.disassemble(func, lasti, file=output)
+ return output.getvalue()
+
+
+
code_info_code_info = """\
Name: code_info
Filename: (.*)
Argument count: 1
Kw-only arguments: 0
Number of locals: 1
-Stack size: 4
+Stack size: 3
Flags: OPTIMIZED, NEWLOCALS, NOFREE
Constants:
0: %r
- 1: '__func__'
- 2: '__code__'
- 3: '<code_info>'
- 4: 'co_code'
- 5: "don't know how to disassemble %%s objects"
-%sNames:
- 0: hasattr
- 1: __func__
- 2: __code__
- 3: isinstance
- 4: str
- 5: _try_compile
- 6: _format_code_info
- 7: TypeError
- 8: type
- 9: __name__
+Names:
+ 0: _format_code_info
+ 1: _get_code_object
Variable names:
- 0: x""" % (('Formatted details of methods, functions, or code.', ' 6: None\n')
- if sys.flags.optimize < 2 else (None, ''))
+ 0: x""" % (('Formatted details of methods, functions, or code.',)
+ if sys.flags.optimize < 2 else (None,))
@staticmethod
def tricky(x, y, z=True, *args, c, d, e=[], **kwds):
@@ -380,7 +373,7 @@ Free variables:
code_info_expr_str = """\
Name: <module>
-Filename: <code_info>
+Filename: <disassembly>
Argument count: 0
Kw-only arguments: 0
Number of locals: 0
@@ -393,7 +386,7 @@ Names:
code_info_simple_stmt_str = """\
Name: <module>
-Filename: <code_info>
+Filename: <disassembly>
Argument count: 0
Kw-only arguments: 0
Number of locals: 0
@@ -407,7 +400,7 @@ Names:
code_info_compound_stmt_str = """\
Name: <module>
-Filename: <code_info>
+Filename: <disassembly>
Argument count: 0
Kw-only arguments: 0
Number of locals: 0
@@ -441,6 +434,9 @@ class CodeInfoTests(unittest.TestCase):
with captured_stdout() as output:
dis.show_code(x)
self.assertRegex(output.getvalue(), expected+"\n")
+ output = io.StringIO()
+ dis.show_code(x, file=output)
+ self.assertRegex(output.getvalue(), expected)
def test_code_info_object(self):
self.assertRaises(TypeError, dis.code_info, object())
@@ -449,8 +445,323 @@ class CodeInfoTests(unittest.TestCase):
self.assertEqual(dis.pretty_flags(0), '0x0')
+# Fodder for instruction introspection tests
+# Editing any of these may require recalculating the expected output
+def outer(a=1, b=2):
+ def f(c=3, d=4):
+ def inner(e=5, f=6):
+ print(a, b, c, d, e, f)
+ print(a, b, c, d)
+ return inner
+ print(a, b, '', 1, [], {}, "Hello world!")
+ return f
+
+def jumpy():
+ # This won't actually run (but that's OK, we only disassemble it)
+ for i in range(10):
+ print(i)
+ if i < 4:
+ continue
+ if i > 6:
+ break
+ else:
+ print("I can haz else clause?")
+ while i:
+ print(i)
+ i -= 1
+ if i > 6:
+ continue
+ if i < 4:
+ break
+ else:
+ print("Who let lolcatz into this test suite?")
+ try:
+ 1 / 0
+ except ZeroDivisionError:
+ print("Here we go, here we go, here we go...")
+ else:
+ with i as dodgy:
+ print("Never reach this")
+ finally:
+ print("OK, now we're done")
+
+# End fodder for opinfo generation tests
+expected_outer_line = 1
+_line_offset = outer.__code__.co_firstlineno - 1
+code_object_f = outer.__code__.co_consts[3]
+expected_f_line = code_object_f.co_firstlineno - _line_offset
+code_object_inner = code_object_f.co_consts[3]
+expected_inner_line = code_object_inner.co_firstlineno - _line_offset
+expected_jumpy_line = 1
+
+# The following lines are useful to regenerate the expected results after
+# either the fodder is modified or the bytecode generation changes
+# After regeneration, update the references to code_object_f and
+# code_object_inner before rerunning the tests
+
+#_instructions = dis.get_instructions(outer, first_line=expected_outer_line)
+#print('expected_opinfo_outer = [\n ',
+ #',\n '.join(map(str, _instructions)), ',\n]', sep='')
+#_instructions = dis.get_instructions(outer(), first_line=expected_outer_line)
+#print('expected_opinfo_f = [\n ',
+ #',\n '.join(map(str, _instructions)), ',\n]', sep='')
+#_instructions = dis.get_instructions(outer()(), first_line=expected_outer_line)
+#print('expected_opinfo_inner = [\n ',
+ #',\n '.join(map(str, _instructions)), ',\n]', sep='')
+#_instructions = dis.get_instructions(jumpy, first_line=expected_jumpy_line)
+#print('expected_opinfo_jumpy = [\n ',
+ #',\n '.join(map(str, _instructions)), ',\n]', sep='')
+
+
+Instruction = dis.Instruction
+expected_opinfo_outer = [
+ Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=3, argrepr='3', offset=0, starts_line=2, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=3, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CLOSURE', opcode=135, arg=0, argval='a', argrepr='a', offset=6, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CLOSURE', opcode=135, arg=1, argval='b', argrepr='b', offset=9, starts_line=None, is_jump_target=False),
+ Instruction(opname='BUILD_TUPLE', opcode=102, arg=2, argval=2, argrepr='', offset=12, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=code_object_f, argrepr=repr(code_object_f), offset=15, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='outer.<locals>.f', argrepr="'outer.<locals>.f'", offset=18, starts_line=None, is_jump_target=False),
+ Instruction(opname='MAKE_CLOSURE', opcode=134, arg=2, argval=2, argrepr='', offset=21, starts_line=None, is_jump_target=False),
+ Instruction(opname='STORE_FAST', opcode=125, arg=2, argval='f', argrepr='f', offset=24, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='print', argrepr='print', offset=27, starts_line=7, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=0, argval='a', argrepr='a', offset=30, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=1, argval='b', argrepr='b', offset=33, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval='', argrepr="''", offset=36, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=6, argval=1, argrepr='1', offset=39, starts_line=None, is_jump_target=False),
+ Instruction(opname='BUILD_LIST', opcode=103, arg=0, argval=0, argrepr='', offset=42, starts_line=None, is_jump_target=False),
+ Instruction(opname='BUILD_MAP', opcode=105, arg=0, argval=0, argrepr='', offset=45, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=7, argval='Hello world!', argrepr="'Hello world!'", offset=48, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=7, argval=7, argrepr='7 positional, 0 keyword pair', offset=51, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=54, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=2, argval='f', argrepr='f', offset=55, starts_line=8, is_jump_target=False),
+ Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=58, starts_line=None, is_jump_target=False),
+]
+
+expected_opinfo_f = [
+ Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=5, argrepr='5', offset=0, starts_line=3, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=6, argrepr='6', offset=3, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CLOSURE', opcode=135, arg=2, argval='a', argrepr='a', offset=6, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CLOSURE', opcode=135, arg=3, argval='b', argrepr='b', offset=9, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CLOSURE', opcode=135, arg=0, argval='c', argrepr='c', offset=12, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CLOSURE', opcode=135, arg=1, argval='d', argrepr='d', offset=15, starts_line=None, is_jump_target=False),
+ Instruction(opname='BUILD_TUPLE', opcode=102, arg=4, argval=4, argrepr='', offset=18, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=code_object_inner, argrepr=repr(code_object_inner), offset=21, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='outer.<locals>.f.<locals>.inner', argrepr="'outer.<locals>.f.<locals>.inner'", offset=24, starts_line=None, is_jump_target=False),
+ Instruction(opname='MAKE_CLOSURE', opcode=134, arg=2, argval=2, argrepr='', offset=27, starts_line=None, is_jump_target=False),
+ Instruction(opname='STORE_FAST', opcode=125, arg=2, argval='inner', argrepr='inner', offset=30, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='print', argrepr='print', offset=33, starts_line=5, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=2, argval='a', argrepr='a', offset=36, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=3, argval='b', argrepr='b', offset=39, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=0, argval='c', argrepr='c', offset=42, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=1, argval='d', argrepr='d', offset=45, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=4, argval=4, argrepr='4 positional, 0 keyword pair', offset=48, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=51, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=2, argval='inner', argrepr='inner', offset=52, starts_line=6, is_jump_target=False),
+ Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=55, starts_line=None, is_jump_target=False),
+]
+
+expected_opinfo_inner = [
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='print', argrepr='print', offset=0, starts_line=4, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=0, argval='a', argrepr='a', offset=3, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=1, argval='b', argrepr='b', offset=6, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=2, argval='c', argrepr='c', offset=9, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_DEREF', opcode=136, arg=3, argval='d', argrepr='d', offset=12, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='e', argrepr='e', offset=15, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=1, argval='f', argrepr='f', offset=18, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=6, argval=6, argrepr='6 positional, 0 keyword pair', offset=21, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=24, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=25, starts_line=None, is_jump_target=False),
+ Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=28, starts_line=None, is_jump_target=False),
+]
+
+expected_opinfo_jumpy = [
+ Instruction(opname='SETUP_LOOP', opcode=120, arg=74, argval=77, argrepr='to 77', offset=0, starts_line=3, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='range', argrepr='range', offset=3, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=10, argrepr='10', offset=6, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=9, starts_line=None, is_jump_target=False),
+ Instruction(opname='GET_ITER', opcode=68, arg=None, argval=None, argrepr='', offset=12, starts_line=None, is_jump_target=False),
+ Instruction(opname='FOR_ITER', opcode=93, arg=50, argval=66, argrepr='to 66', offset=13, starts_line=None, is_jump_target=True),
+ Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=16, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=19, starts_line=4, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=22, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=25, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=28, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=29, starts_line=5, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=32, starts_line=None, is_jump_target=False),
+ Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=35, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=47, argval=47, argrepr='', offset=38, starts_line=None, is_jump_target=False),
+ Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=41, starts_line=6, is_jump_target=False),
+ Instruction(opname='JUMP_FORWARD', opcode=110, arg=0, argval=47, argrepr='to 47', offset=44, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=47, starts_line=7, is_jump_target=True),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=50, starts_line=None, is_jump_target=False),
+ Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=53, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=13, argval=13, argrepr='', offset=56, starts_line=None, is_jump_target=False),
+ Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=59, starts_line=8, is_jump_target=False),
+ Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=60, starts_line=None, is_jump_target=False),
+ Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=63, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=66, starts_line=None, is_jump_target=True),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=67, starts_line=10, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=70, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=73, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=76, starts_line=None, is_jump_target=False),
+ Instruction(opname='SETUP_LOOP', opcode=120, arg=74, argval=154, argrepr='to 154', offset=77, starts_line=11, is_jump_target=True),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=80, starts_line=None, is_jump_target=True),
+ Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=143, argval=143, argrepr='', offset=83, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=86, starts_line=12, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=89, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=92, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=95, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=96, starts_line=13, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=99, starts_line=None, is_jump_target=False),
+ Instruction(opname='INPLACE_SUBTRACT', opcode=56, arg=None, argval=None, argrepr='', offset=102, starts_line=None, is_jump_target=False),
+ Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=103, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=106, starts_line=14, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=109, starts_line=None, is_jump_target=False),
+ Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=112, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=124, argval=124, argrepr='', offset=115, starts_line=None, is_jump_target=False),
+ Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=118, starts_line=15, is_jump_target=False),
+ Instruction(opname='JUMP_FORWARD', opcode=110, arg=0, argval=124, argrepr='to 124', offset=121, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=124, starts_line=16, is_jump_target=True),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=127, starts_line=None, is_jump_target=False),
+ Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=130, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=80, argval=80, argrepr='', offset=133, starts_line=None, is_jump_target=False),
+ Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=136, starts_line=17, is_jump_target=False),
+ Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=137, starts_line=None, is_jump_target=False),
+ Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=140, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=143, starts_line=None, is_jump_target=True),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=144, starts_line=19, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=6, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=147, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=150, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=153, starts_line=None, is_jump_target=False),
+ Instruction(opname='SETUP_FINALLY', opcode=122, arg=72, argval=229, argrepr='to 229', offset=154, starts_line=20, is_jump_target=True),
+ Instruction(opname='SETUP_EXCEPT', opcode=121, arg=12, argval=172, argrepr='to 172', offset=157, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=160, starts_line=21, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=7, argval=0, argrepr='0', offset=163, starts_line=None, is_jump_target=False),
+ Instruction(opname='BINARY_TRUE_DIVIDE', opcode=27, arg=None, argval=None, argrepr='', offset=166, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=167, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=168, starts_line=None, is_jump_target=False),
+ Instruction(opname='JUMP_FORWARD', opcode=110, arg=28, argval=200, argrepr='to 200', offset=169, starts_line=None, is_jump_target=False),
+ Instruction(opname='DUP_TOP', opcode=4, arg=None, argval=None, argrepr='', offset=172, starts_line=22, is_jump_target=True),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=2, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=173, starts_line=None, is_jump_target=False),
+ Instruction(opname='COMPARE_OP', opcode=107, arg=10, argval='exception match', argrepr='exception match', offset=176, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=199, argval=199, argrepr='', offset=179, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=182, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=183, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=184, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=185, starts_line=23, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=8, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=188, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=191, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=194, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_EXCEPT', opcode=89, arg=None, argval=None, argrepr='', offset=195, starts_line=None, is_jump_target=False),
+ Instruction(opname='JUMP_FORWARD', opcode=110, arg=26, argval=225, argrepr='to 225', offset=196, starts_line=None, is_jump_target=False),
+ Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=199, starts_line=None, is_jump_target=True),
+ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=200, starts_line=25, is_jump_target=True),
+ Instruction(opname='SETUP_WITH', opcode=143, arg=17, argval=223, argrepr='to 223', offset=203, starts_line=None, is_jump_target=False),
+ Instruction(opname='STORE_FAST', opcode=125, arg=1, argval='dodgy', argrepr='dodgy', offset=206, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=209, starts_line=26, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=9, argval='Never reach this', argrepr="'Never reach this'", offset=212, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=215, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=218, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=219, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=220, starts_line=None, is_jump_target=False),
+ Instruction(opname='WITH_CLEANUP', opcode=81, arg=None, argval=None, argrepr='', offset=223, starts_line=None, is_jump_target=True),
+ Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=224, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=225, starts_line=None, is_jump_target=True),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=226, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=229, starts_line=28, is_jump_target=True),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=10, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=232, starts_line=None, is_jump_target=False),
+ Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=235, starts_line=None, is_jump_target=False),
+ Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=238, starts_line=None, is_jump_target=False),
+ Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=239, starts_line=None, is_jump_target=False),
+ Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=240, starts_line=None, is_jump_target=False),
+ Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=243, starts_line=None, is_jump_target=False),
+]
+
+# One last piece of inspect fodder to check the default line number handling
+def simple(): pass
+expected_opinfo_simple = [
+ Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=0, starts_line=simple.__code__.co_firstlineno, is_jump_target=False),
+ Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=3, starts_line=None, is_jump_target=False)
+]
+
+
+class InstructionTests(BytecodeTestCase):
+
+ def test_default_first_line(self):
+ actual = dis.get_instructions(simple)
+ self.assertEqual(list(actual), expected_opinfo_simple)
+
+ def test_first_line_set_to_None(self):
+ actual = dis.get_instructions(simple, first_line=None)
+ self.assertEqual(list(actual), expected_opinfo_simple)
+
+ def test_outer(self):
+ actual = dis.get_instructions(outer, first_line=expected_outer_line)
+ self.assertEqual(list(actual), expected_opinfo_outer)
+
+ def test_nested(self):
+ with captured_stdout():
+ f = outer()
+ actual = dis.get_instructions(f, first_line=expected_f_line)
+ self.assertEqual(list(actual), expected_opinfo_f)
+
+ def test_doubly_nested(self):
+ with captured_stdout():
+ inner = outer()()
+ actual = dis.get_instructions(inner, first_line=expected_inner_line)
+ self.assertEqual(list(actual), expected_opinfo_inner)
+
+ def test_jumpy(self):
+ actual = dis.get_instructions(jumpy, first_line=expected_jumpy_line)
+ self.assertEqual(list(actual), expected_opinfo_jumpy)
+
+# get_instructions has its own tests above, so can rely on it to validate
+# the object oriented API
+class BytecodeTests(unittest.TestCase):
+ def test_instantiation(self):
+ # Test with function, method, code string and code object
+ for obj in [_f, _C(1).__init__, "a=1", _f.__code__]:
+ with self.subTest(obj=obj):
+ b = dis.Bytecode(obj)
+ self.assertIsInstance(b.codeobj, types.CodeType)
+
+ self.assertRaises(TypeError, dis.Bytecode, object())
+
+ def test_iteration(self):
+ for obj in [_f, _C(1).__init__, "a=1", _f.__code__]:
+ with self.subTest(obj=obj):
+ via_object = list(dis.Bytecode(obj))
+ via_generator = list(dis.get_instructions(obj))
+ self.assertEqual(via_object, via_generator)
+
+ def test_explicit_first_line(self):
+ actual = dis.Bytecode(outer, first_line=expected_outer_line)
+ self.assertEqual(list(actual), expected_opinfo_outer)
+
+ def test_source_line_in_disassembly(self):
+ # Use the line in the source code
+ actual = dis.Bytecode(simple).dis()[:3]
+ expected = "{:>3}".format(simple.__code__.co_firstlineno)
+ self.assertEqual(actual, expected)
+ # Use an explicit first line number
+ actual = dis.Bytecode(simple, first_line=350).dis()[:3]
+ self.assertEqual(actual, "350")
+
+ def test_info(self):
+ self.maxDiff = 1000
+ for x, expected in CodeInfoTests.test_pairs:
+ b = dis.Bytecode(x)
+ self.assertRegex(b.info(), expected)
+
+ def test_disassembled(self):
+ actual = dis.Bytecode(_f).dis()
+ self.assertEqual(actual, dis_f)
+
+
def test_main():
- run_unittest(DisTests, CodeInfoTests)
+ run_unittest(DisTests, DisWithFileTests, CodeInfoTests,
+ InstructionTests, BytecodeTests)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_doctest.py b/Lib/test/test_doctest.py
index 8f8c7c7e82..d4ff049fca 100644
--- a/Lib/test/test_doctest.py
+++ b/Lib/test/test_doctest.py
@@ -1409,8 +1409,40 @@ However, output from `report_start` is not suppressed:
2
TestResults(failed=3, attempted=5)
-For the purposes of REPORT_ONLY_FIRST_FAILURE, unexpected exceptions
-count as failures:
+The FAIL_FAST flag causes the runner to exit after the first failing example,
+so subsequent examples are not even attempted:
+
+ >>> flags = doctest.FAIL_FAST
+ >>> doctest.DocTestRunner(verbose=False, optionflags=flags).run(test)
+ ... # doctest: +ELLIPSIS
+ **********************************************************************
+ File ..., line 5, in f
+ Failed example:
+ print(2) # first failure
+ Expected:
+ 200
+ Got:
+ 2
+ TestResults(failed=1, attempted=2)
+
+Specifying both FAIL_FAST and REPORT_ONLY_FIRST_FAILURE is equivalent to
+FAIL_FAST only:
+
+ >>> flags = doctest.FAIL_FAST | doctest.REPORT_ONLY_FIRST_FAILURE
+ >>> doctest.DocTestRunner(verbose=False, optionflags=flags).run(test)
+ ... # doctest: +ELLIPSIS
+ **********************************************************************
+ File ..., line 5, in f
+ Failed example:
+ print(2) # first failure
+ Expected:
+ 200
+ Got:
+ 2
+ TestResults(failed=1, attempted=2)
+
+For the purposes of both REPORT_ONLY_FIRST_FAILURE and FAIL_FAST, unexpected
+exceptions count as failures:
>>> def f(x):
... r'''
@@ -1437,6 +1469,17 @@ count as failures:
...
ValueError: 2
TestResults(failed=3, attempted=5)
+ >>> flags = doctest.FAIL_FAST
+ >>> doctest.DocTestRunner(verbose=False, optionflags=flags).run(test)
+ ... # doctest: +ELLIPSIS
+ **********************************************************************
+ File ..., line 5, in f
+ Failed example:
+ raise ValueError(2) # first failure
+ Exception raised:
+ ...
+ ValueError: 2
+ TestResults(failed=1, attempted=2)
New option flags can also be registered, via register_optionflag(). Here
we reach into doctest's internals a bit.
@@ -2553,6 +2596,240 @@ Check doctest with a non-ascii filename:
TestResults(failed=1, attempted=1)
"""
+def test_CLI(): r"""
+The doctest module can be used to run doctests against an arbitrary file.
+These tests test this CLI functionality.
+
+We'll use the support module's script_helpers for this, and write a test files
+to a temp dir to run the command against. Due to a current limitation in
+script_helpers, though, we need a little utility function to turn the returned
+output into something we can doctest against:
+
+ >>> def normalize(s):
+ ... return '\n'.join(s.decode().splitlines())
+
+Note: we also pass TERM='' to all the assert_python calls to avoid a bug
+in the readline library that is triggered in these tests because we are
+running them in a new python process. See:
+
+ http://lists.gnu.org/archive/html/bug-readline/2013-06/msg00000.html
+
+With those preliminaries out of the way, we'll start with a file with two
+simple tests and no errors. We'll run both the unadorned doctest command, and
+the verbose version, and then check the output:
+
+ >>> from test import script_helper
+ >>> with script_helper.temp_dir() as tmpdir:
+ ... fn = os.path.join(tmpdir, 'myfile.doc')
+ ... with open(fn, 'w') as f:
+ ... _ = f.write('This is a very simple test file.\n')
+ ... _ = f.write(' >>> 1 + 1\n')
+ ... _ = f.write(' 2\n')
+ ... _ = f.write(' >>> "a"\n')
+ ... _ = f.write(" 'a'\n")
+ ... _ = f.write('\n')
+ ... _ = f.write('And that is it.\n')
+ ... rc1, out1, err1 = script_helper.assert_python_ok(
+ ... '-m', 'doctest', fn, TERM='')
+ ... rc2, out2, err2 = script_helper.assert_python_ok(
+ ... '-m', 'doctest', '-v', fn, TERM='')
+
+With no arguments and passing tests, we should get no output:
+
+ >>> rc1, out1, err1
+ (0, b'', b'')
+
+With the verbose flag, we should see the test output, but no error output:
+
+ >>> rc2, err2
+ (0, b'')
+ >>> print(normalize(out2))
+ Trying:
+ 1 + 1
+ Expecting:
+ 2
+ ok
+ Trying:
+ "a"
+ Expecting:
+ 'a'
+ ok
+ 1 items passed all tests:
+ 2 tests in myfile.doc
+ 2 tests in 1 items.
+ 2 passed and 0 failed.
+ Test passed.
+
+Now we'll write a couple files, one with three tests, the other a python module
+with two tests, both of the files having "errors" in the tests that can be made
+non-errors by applying the appropriate doctest options to the run (ELLIPSIS in
+the first file, NORMALIZE_WHITESPACE in the second). This combination will
+allow to thoroughly test the -f and -o flags, as well as the doctest command's
+ability to process more than one file on the command line and, since the second
+file ends in '.py', its handling of python module files (as opposed to straight
+text files).
+
+ >>> from test import script_helper
+ >>> with script_helper.temp_dir() as tmpdir:
+ ... fn = os.path.join(tmpdir, 'myfile.doc')
+ ... with open(fn, 'w') as f:
+ ... _ = f.write('This is another simple test file.\n')
+ ... _ = f.write(' >>> 1 + 1\n')
+ ... _ = f.write(' 2\n')
+ ... _ = f.write(' >>> "abcdef"\n')
+ ... _ = f.write(" 'a...f'\n")
+ ... _ = f.write(' >>> "ajkml"\n')
+ ... _ = f.write(" 'a...l'\n")
+ ... _ = f.write('\n')
+ ... _ = f.write('And that is it.\n')
+ ... fn2 = os.path.join(tmpdir, 'myfile2.py')
+ ... with open(fn2, 'w') as f:
+ ... _ = f.write('def test_func():\n')
+ ... _ = f.write(' \"\"\"\n')
+ ... _ = f.write(' This is simple python test function.\n')
+ ... _ = f.write(' >>> 1 + 1\n')
+ ... _ = f.write(' 2\n')
+ ... _ = f.write(' >>> "abc def"\n')
+ ... _ = f.write(" 'abc def'\n")
+ ... _ = f.write("\n")
+ ... _ = f.write(' \"\"\"\n')
+ ... import shutil
+ ... rc1, out1, err1 = script_helper.assert_python_failure(
+ ... '-m', 'doctest', fn, fn2, TERM='')
+ ... rc2, out2, err2 = script_helper.assert_python_ok(
+ ... '-m', 'doctest', '-o', 'ELLIPSIS', fn, TERM='')
+ ... rc3, out3, err3 = script_helper.assert_python_ok(
+ ... '-m', 'doctest', '-o', 'ELLIPSIS',
+ ... '-o', 'NORMALIZE_WHITESPACE', fn, fn2, TERM='')
+ ... rc4, out4, err4 = script_helper.assert_python_failure(
+ ... '-m', 'doctest', '-f', fn, fn2, TERM='')
+ ... rc5, out5, err5 = script_helper.assert_python_ok(
+ ... '-m', 'doctest', '-v', '-o', 'ELLIPSIS',
+ ... '-o', 'NORMALIZE_WHITESPACE', fn, fn2, TERM='')
+
+Our first test run will show the errors from the first file (doctest stops if a
+file has errors). Note that doctest test-run error output appears on stdout,
+not stderr:
+
+ >>> rc1, err1
+ (1, b'')
+ >>> print(normalize(out1)) # doctest: +ELLIPSIS
+ **********************************************************************
+ File "...myfile.doc", line 4, in myfile.doc
+ Failed example:
+ "abcdef"
+ Expected:
+ 'a...f'
+ Got:
+ 'abcdef'
+ **********************************************************************
+ File "...myfile.doc", line 6, in myfile.doc
+ Failed example:
+ "ajkml"
+ Expected:
+ 'a...l'
+ Got:
+ 'ajkml'
+ **********************************************************************
+ 1 items had failures:
+ 2 of 3 in myfile.doc
+ ***Test Failed*** 2 failures.
+
+With -o ELLIPSIS specified, the second run, against just the first file, should
+produce no errors, and with -o NORMALIZE_WHITESPACE also specified, neither
+should the third, which ran against both files:
+
+ >>> rc2, out2, err2
+ (0, b'', b'')
+ >>> rc3, out3, err3
+ (0, b'', b'')
+
+The fourth run uses FAIL_FAST, so we should see only one error:
+
+ >>> rc4, err4
+ (1, b'')
+ >>> print(normalize(out4)) # doctest: +ELLIPSIS
+ **********************************************************************
+ File "...myfile.doc", line 4, in myfile.doc
+ Failed example:
+ "abcdef"
+ Expected:
+ 'a...f'
+ Got:
+ 'abcdef'
+ **********************************************************************
+ 1 items had failures:
+ 1 of 2 in myfile.doc
+ ***Test Failed*** 1 failures.
+
+The fifth test uses verbose with the two options, so we should get verbose
+success output for the tests in both files:
+
+ >>> rc5, err5
+ (0, b'')
+ >>> print(normalize(out5))
+ Trying:
+ 1 + 1
+ Expecting:
+ 2
+ ok
+ Trying:
+ "abcdef"
+ Expecting:
+ 'a...f'
+ ok
+ Trying:
+ "ajkml"
+ Expecting:
+ 'a...l'
+ ok
+ 1 items passed all tests:
+ 3 tests in myfile.doc
+ 3 tests in 1 items.
+ 3 passed and 0 failed.
+ Test passed.
+ Trying:
+ 1 + 1
+ Expecting:
+ 2
+ ok
+ Trying:
+ "abc def"
+ Expecting:
+ 'abc def'
+ ok
+ 1 items had no tests:
+ myfile2
+ 1 items passed all tests:
+ 2 tests in myfile2.test_func
+ 2 tests in 2 items.
+ 2 passed and 0 failed.
+ Test passed.
+
+We should also check some typical error cases.
+
+Invalid file name:
+
+ >>> rc, out, err = script_helper.assert_python_failure(
+ ... '-m', 'doctest', 'nosuchfile', TERM='')
+ >>> rc, out
+ (1, b'')
+ >>> print(normalize(err)) # doctest: +ELLIPSIS
+ Traceback (most recent call last):
+ ...
+ FileNotFoundError: [Errno ...] No such file or directory: 'nosuchfile'
+
+Invalid doctest option:
+
+ >>> rc, out, err = script_helper.assert_python_failure(
+ ... '-m', 'doctest', '-o', 'nosuchoption', TERM='')
+ >>> rc, out
+ (2, b'')
+ >>> print(normalize(err)) # doctest: +ELLIPSIS
+ usage...invalid...nosuchoption...
+
+"""
+
######################################################################
## Main
######################################################################
diff --git a/Lib/test/test_dynamicclassattribute.py b/Lib/test/test_dynamicclassattribute.py
new file mode 100644
index 0000000000..079d6c3224
--- /dev/null
+++ b/Lib/test/test_dynamicclassattribute.py
@@ -0,0 +1,304 @@
+# Test case for DynamicClassAttribute
+# more tests are in test_descr
+
+import abc
+import sys
+import unittest
+from test.support import run_unittest
+from types import DynamicClassAttribute
+
+class PropertyBase(Exception):
+ pass
+
+class PropertyGet(PropertyBase):
+ pass
+
+class PropertySet(PropertyBase):
+ pass
+
+class PropertyDel(PropertyBase):
+ pass
+
+class BaseClass(object):
+ def __init__(self):
+ self._spam = 5
+
+ @DynamicClassAttribute
+ def spam(self):
+ """BaseClass.getter"""
+ return self._spam
+
+ @spam.setter
+ def spam(self, value):
+ self._spam = value
+
+ @spam.deleter
+ def spam(self):
+ del self._spam
+
+class SubClass(BaseClass):
+
+ spam = BaseClass.__dict__['spam']
+
+ @spam.getter
+ def spam(self):
+ """SubClass.getter"""
+ raise PropertyGet(self._spam)
+
+ @spam.setter
+ def spam(self, value):
+ raise PropertySet(self._spam)
+
+ @spam.deleter
+ def spam(self):
+ raise PropertyDel(self._spam)
+
+class PropertyDocBase(object):
+ _spam = 1
+ def _get_spam(self):
+ return self._spam
+ spam = DynamicClassAttribute(_get_spam, doc="spam spam spam")
+
+class PropertyDocSub(PropertyDocBase):
+ spam = PropertyDocBase.__dict__['spam']
+ @spam.getter
+ def spam(self):
+ """The decorator does not use this doc string"""
+ return self._spam
+
+class PropertySubNewGetter(BaseClass):
+ spam = BaseClass.__dict__['spam']
+ @spam.getter
+ def spam(self):
+ """new docstring"""
+ return 5
+
+class PropertyNewGetter(object):
+ @DynamicClassAttribute
+ def spam(self):
+ """original docstring"""
+ return 1
+ @spam.getter
+ def spam(self):
+ """new docstring"""
+ return 8
+
+class ClassWithAbstractVirtualProperty(metaclass=abc.ABCMeta):
+ @DynamicClassAttribute
+ @abc.abstractmethod
+ def color():
+ pass
+
+class ClassWithPropertyAbstractVirtual(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ @DynamicClassAttribute
+ def color():
+ pass
+
+class PropertyTests(unittest.TestCase):
+ def test_property_decorator_baseclass(self):
+ # see #1620
+ base = BaseClass()
+ self.assertEqual(base.spam, 5)
+ self.assertEqual(base._spam, 5)
+ base.spam = 10
+ self.assertEqual(base.spam, 10)
+ self.assertEqual(base._spam, 10)
+ delattr(base, "spam")
+ self.assertTrue(not hasattr(base, "spam"))
+ self.assertTrue(not hasattr(base, "_spam"))
+ base.spam = 20
+ self.assertEqual(base.spam, 20)
+ self.assertEqual(base._spam, 20)
+
+ def test_property_decorator_subclass(self):
+ # see #1620
+ sub = SubClass()
+ self.assertRaises(PropertyGet, getattr, sub, "spam")
+ self.assertRaises(PropertySet, setattr, sub, "spam", None)
+ self.assertRaises(PropertyDel, delattr, sub, "spam")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_property_decorator_subclass_doc(self):
+ sub = SubClass()
+ self.assertEqual(sub.__class__.__dict__['spam'].__doc__, "SubClass.getter")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_property_decorator_baseclass_doc(self):
+ base = BaseClass()
+ self.assertEqual(base.__class__.__dict__['spam'].__doc__, "BaseClass.getter")
+
+ def test_property_decorator_doc(self):
+ base = PropertyDocBase()
+ sub = PropertyDocSub()
+ self.assertEqual(base.__class__.__dict__['spam'].__doc__, "spam spam spam")
+ self.assertEqual(sub.__class__.__dict__['spam'].__doc__, "spam spam spam")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_property_getter_doc_override(self):
+ newgettersub = PropertySubNewGetter()
+ self.assertEqual(newgettersub.spam, 5)
+ self.assertEqual(newgettersub.__class__.__dict__['spam'].__doc__, "new docstring")
+ newgetter = PropertyNewGetter()
+ self.assertEqual(newgetter.spam, 8)
+ self.assertEqual(newgetter.__class__.__dict__['spam'].__doc__, "new docstring")
+
+ def test_property___isabstractmethod__descriptor(self):
+ for val in (True, False, [], [1], '', '1'):
+ class C(object):
+ def foo(self):
+ pass
+ foo.__isabstractmethod__ = val
+ foo = DynamicClassAttribute(foo)
+ self.assertIs(C.__dict__['foo'].__isabstractmethod__, bool(val))
+
+ # check that the DynamicClassAttribute's __isabstractmethod__ descriptor does the
+ # right thing when presented with a value that fails truth testing:
+ class NotBool(object):
+ def __nonzero__(self):
+ raise ValueError()
+ __len__ = __nonzero__
+ with self.assertRaises(ValueError):
+ class C(object):
+ def foo(self):
+ pass
+ foo.__isabstractmethod__ = NotBool()
+ foo = DynamicClassAttribute(foo)
+
+ def test_abstract_virtual(self):
+ self.assertRaises(TypeError, ClassWithAbstractVirtualProperty)
+ self.assertRaises(TypeError, ClassWithPropertyAbstractVirtual)
+ class APV(ClassWithPropertyAbstractVirtual):
+ pass
+ self.assertRaises(TypeError, APV)
+ class AVP(ClassWithAbstractVirtualProperty):
+ pass
+ self.assertRaises(TypeError, AVP)
+ class Okay1(ClassWithAbstractVirtualProperty):
+ @DynamicClassAttribute
+ def color(self):
+ return self._color
+ def __init__(self):
+ self._color = 'cyan'
+ with self.assertRaises(AttributeError):
+ Okay1.color
+ self.assertEqual(Okay1().color, 'cyan')
+ class Okay2(ClassWithAbstractVirtualProperty):
+ @DynamicClassAttribute
+ def color(self):
+ return self._color
+ def __init__(self):
+ self._color = 'magenta'
+ with self.assertRaises(AttributeError):
+ Okay2.color
+ self.assertEqual(Okay2().color, 'magenta')
+
+
+# Issue 5890: subclasses of DynamicClassAttribute do not preserve method __doc__ strings
+class PropertySub(DynamicClassAttribute):
+ """This is a subclass of DynamicClassAttribute"""
+
+class PropertySubSlots(DynamicClassAttribute):
+ """This is a subclass of DynamicClassAttribute that defines __slots__"""
+ __slots__ = ()
+
+class PropertySubclassTests(unittest.TestCase):
+
+ @unittest.skipIf(hasattr(PropertySubSlots, '__doc__'),
+ "__doc__ is already present, __slots__ will have no effect")
+ def test_slots_docstring_copy_exception(self):
+ try:
+ class Foo(object):
+ @PropertySubSlots
+ def spam(self):
+ """Trying to copy this docstring will raise an exception"""
+ return 1
+ print('\n',spam.__doc__)
+ except AttributeError:
+ pass
+ else:
+ raise Exception("AttributeError not raised")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_docstring_copy(self):
+ class Foo(object):
+ @PropertySub
+ def spam(self):
+ """spam wrapped in DynamicClassAttribute subclass"""
+ return 1
+ self.assertEqual(
+ Foo.__dict__['spam'].__doc__,
+ "spam wrapped in DynamicClassAttribute subclass")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_property_setter_copies_getter_docstring(self):
+ class Foo(object):
+ def __init__(self): self._spam = 1
+ @PropertySub
+ def spam(self):
+ """spam wrapped in DynamicClassAttribute subclass"""
+ return self._spam
+ @spam.setter
+ def spam(self, value):
+ """this docstring is ignored"""
+ self._spam = value
+ foo = Foo()
+ self.assertEqual(foo.spam, 1)
+ foo.spam = 2
+ self.assertEqual(foo.spam, 2)
+ self.assertEqual(
+ Foo.__dict__['spam'].__doc__,
+ "spam wrapped in DynamicClassAttribute subclass")
+ class FooSub(Foo):
+ spam = Foo.__dict__['spam']
+ @spam.setter
+ def spam(self, value):
+ """another ignored docstring"""
+ self._spam = 'eggs'
+ foosub = FooSub()
+ self.assertEqual(foosub.spam, 1)
+ foosub.spam = 7
+ self.assertEqual(foosub.spam, 'eggs')
+ self.assertEqual(
+ FooSub.__dict__['spam'].__doc__,
+ "spam wrapped in DynamicClassAttribute subclass")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_property_new_getter_new_docstring(self):
+
+ class Foo(object):
+ @PropertySub
+ def spam(self):
+ """a docstring"""
+ return 1
+ @spam.getter
+ def spam(self):
+ """a new docstring"""
+ return 2
+ self.assertEqual(Foo.__dict__['spam'].__doc__, "a new docstring")
+ class FooBase(object):
+ @PropertySub
+ def spam(self):
+ """a docstring"""
+ return 1
+ class Foo2(FooBase):
+ spam = FooBase.__dict__['spam']
+ @spam.getter
+ def spam(self):
+ """a new docstring"""
+ return 2
+ self.assertEqual(Foo.__dict__['spam'].__doc__, "a new docstring")
+
+
+
+def test_main():
+ run_unittest(PropertyTests, PropertySubclassTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_email/__init__.py b/Lib/test/test_email/__init__.py
index f206ace10c..d8896eed46 100644
--- a/Lib/test/test_email/__init__.py
+++ b/Lib/test/test_email/__init__.py
@@ -2,6 +2,7 @@ import os
import sys
import unittest
import test.support
+import collections
import email
from email.message import Message
from email._policybase import compat32
@@ -42,6 +43,8 @@ class TestEmailBase(unittest.TestCase):
# here we make minimal changes in the test_email tests compared to their
# pre-3.3 state.
policy = compat32
+ # Likewise, the default message object is Message.
+ message = Message
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
@@ -54,11 +57,23 @@ class TestEmailBase(unittest.TestCase):
with openfile(filename) as fp:
return email.message_from_file(fp, policy=self.policy)
- def _str_msg(self, string, message=Message, policy=None):
+ def _str_msg(self, string, message=None, policy=None):
if policy is None:
policy = self.policy
+ if message is None:
+ message = self.message
return email.message_from_string(string, message, policy=policy)
+ def _bytes_msg(self, bytestring, message=None, policy=None):
+ if policy is None:
+ policy = self.policy
+ if message is None:
+ message = self.message
+ return email.message_from_bytes(bytestring, message, policy=policy)
+
+ def _make_message(self):
+ return self.message(policy=self.policy)
+
def _bytes_repr(self, b):
return [repr(x) for x in b.splitlines(keepends=True)]
@@ -123,6 +138,7 @@ def parameterize(cls):
"""
paramdicts = {}
+ testers = collections.defaultdict(list)
for name, attr in cls.__dict__.items():
if name.endswith('_params'):
if not hasattr(attr, 'keys'):
@@ -134,7 +150,15 @@ def parameterize(cls):
d[n] = x
attr = d
paramdicts[name[:-7] + '_as_'] = attr
+ if '_as_' in name:
+ testers[name.split('_as_')[0] + '_as_'].append(name)
testfuncs = {}
+ for name in paramdicts:
+ if name not in testers:
+ raise ValueError("No tester found for {}".format(name))
+ for name in testers:
+ if name not in paramdicts:
+ raise ValueError("No params found for {}".format(name))
for name, attr in cls.__dict__.items():
for paramsname, paramsdict in paramdicts.items():
if name.startswith(paramsname):
diff --git a/Lib/test/test_email/test_contentmanager.py b/Lib/test/test_email/test_contentmanager.py
new file mode 100644
index 0000000000..1629e2d8fd
--- /dev/null
+++ b/Lib/test/test_email/test_contentmanager.py
@@ -0,0 +1,796 @@
+import unittest
+from test.test_email import TestEmailBase, parameterize
+import textwrap
+from email import policy
+from email.message import EmailMessage
+from email.contentmanager import ContentManager, raw_data_manager
+
+
+@parameterize
+class TestContentManager(TestEmailBase):
+
+ policy = policy.default
+ message = EmailMessage
+
+ get_key_params = {
+ 'full_type': (1, 'text/plain',),
+ 'maintype_only': (2, 'text',),
+ 'null_key': (3, '',),
+ }
+
+ def get_key_as_get_content_key(self, order, key):
+ def foo_getter(msg, foo=None):
+ bar = msg['X-Bar-Header']
+ return foo, bar
+ cm = ContentManager()
+ cm.add_get_handler(key, foo_getter)
+ m = self._make_message()
+ m['Content-Type'] = 'text/plain'
+ m['X-Bar-Header'] = 'foo'
+ self.assertEqual(cm.get_content(m, foo='bar'), ('bar', 'foo'))
+
+ def get_key_as_get_content_key_order(self, order, key):
+ def bar_getter(msg):
+ return msg['X-Bar-Header']
+ def foo_getter(msg):
+ return msg['X-Foo-Header']
+ cm = ContentManager()
+ cm.add_get_handler(key, foo_getter)
+ for precedence, key in self.get_key_params.values():
+ if precedence > order:
+ cm.add_get_handler(key, bar_getter)
+ m = self._make_message()
+ m['Content-Type'] = 'text/plain'
+ m['X-Bar-Header'] = 'bar'
+ m['X-Foo-Header'] = 'foo'
+ self.assertEqual(cm.get_content(m), ('foo'))
+
+ def test_get_content_raises_if_unknown_mimetype_and_no_default(self):
+ cm = ContentManager()
+ m = self._make_message()
+ m['Content-Type'] = 'text/plain'
+ with self.assertRaisesRegex(KeyError, 'text/plain'):
+ cm.get_content(m)
+
+ class BaseThing(str):
+ pass
+ baseobject_full_path = __name__ + '.' + 'TestContentManager.BaseThing'
+ class Thing(BaseThing):
+ pass
+ testobject_full_path = __name__ + '.' + 'TestContentManager.Thing'
+
+ set_key_params = {
+ 'type': (0, Thing,),
+ 'full_path': (1, testobject_full_path,),
+ 'qualname': (2, 'TestContentManager.Thing',),
+ 'name': (3, 'Thing',),
+ 'base_type': (4, BaseThing,),
+ 'base_full_path': (5, baseobject_full_path,),
+ 'base_qualname': (6, 'TestContentManager.BaseThing',),
+ 'base_name': (7, 'BaseThing',),
+ 'str_type': (8, str,),
+ 'str_full_path': (9, 'builtins.str',),
+ 'str_name': (10, 'str',), # str name and qualname are the same
+ 'null_key': (11, None,),
+ }
+
+ def set_key_as_set_content_key(self, order, key):
+ def foo_setter(msg, obj, foo=None):
+ msg['X-Foo-Header'] = foo
+ msg.set_payload(obj)
+ cm = ContentManager()
+ cm.add_set_handler(key, foo_setter)
+ m = self._make_message()
+ msg_obj = self.Thing()
+ cm.set_content(m, msg_obj, foo='bar')
+ self.assertEqual(m['X-Foo-Header'], 'bar')
+ self.assertEqual(m.get_payload(), msg_obj)
+
+ def set_key_as_set_content_key_order(self, order, key):
+ def foo_setter(msg, obj):
+ msg['X-FooBar-Header'] = 'foo'
+ msg.set_payload(obj)
+ def bar_setter(msg, obj):
+ msg['X-FooBar-Header'] = 'bar'
+ cm = ContentManager()
+ cm.add_set_handler(key, foo_setter)
+ for precedence, key in self.get_key_params.values():
+ if precedence > order:
+ cm.add_set_handler(key, bar_setter)
+ m = self._make_message()
+ msg_obj = self.Thing()
+ cm.set_content(m, msg_obj)
+ self.assertEqual(m['X-FooBar-Header'], 'foo')
+ self.assertEqual(m.get_payload(), msg_obj)
+
+ def test_set_content_raises_if_unknown_type_and_no_default(self):
+ cm = ContentManager()
+ m = self._make_message()
+ msg_obj = self.Thing()
+ with self.assertRaisesRegex(KeyError, self.testobject_full_path):
+ cm.set_content(m, msg_obj)
+
+ def test_set_content_raises_if_called_on_multipart(self):
+ cm = ContentManager()
+ m = self._make_message()
+ m['Content-Type'] = 'multipart/foo'
+ with self.assertRaises(TypeError):
+ cm.set_content(m, 'test')
+
+ def test_set_content_calls_clear_content(self):
+ m = self._make_message()
+ m['Content-Foo'] = 'bar'
+ m['Content-Type'] = 'text/html'
+ m['To'] = 'test'
+ m.set_payload('abc')
+ cm = ContentManager()
+ cm.add_set_handler(str, lambda *args, **kw: None)
+ m.set_content('xyz', content_manager=cm)
+ self.assertIsNone(m['Content-Foo'])
+ self.assertIsNone(m['Content-Type'])
+ self.assertEqual(m['To'], 'test')
+ self.assertIsNone(m.get_payload())
+
+
+@parameterize
+class TestRawDataManager(TestEmailBase):
+ # Note: these tests are dependent on the order in which headers are added
+ # to the message objects by the code. There's no defined ordering in
+ # RFC5322/MIME, so this makes the tests more fragile than the standards
+ # require. However, if the header order changes it is best to understand
+ # *why*, and make sure it isn't a subtle bug in whatever change was
+ # applied.
+
+ policy = policy.default.clone(max_line_length=60,
+ content_manager=raw_data_manager)
+ message = EmailMessage
+
+ def test_get_text_plain(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain
+
+ Basic text.
+ """))
+ self.assertEqual(raw_data_manager.get_content(m), "Basic text.\n")
+
+ def test_get_text_html(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/html
+
+ <p>Basic text.</p>
+ """))
+ self.assertEqual(raw_data_manager.get_content(m),
+ "<p>Basic text.</p>\n")
+
+ def test_get_text_plain_latin1(self):
+ m = self._bytes_msg(textwrap.dedent("""\
+ Content-Type: text/plain; charset=latin1
+
+ Basìc tëxt.
+ """).encode('latin1'))
+ self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt.\n")
+
+ def test_get_text_plain_latin1_quoted_printable(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain; charset="latin-1"
+ Content-Transfer-Encoding: quoted-printable
+
+ Bas=ECc t=EBxt.
+ """))
+ self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt.\n")
+
+ def test_get_text_plain_utf8_base64(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf8"
+ Content-Transfer-Encoding: base64
+
+ QmFzw6xjIHTDq3h0Lgo=
+ """))
+ self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt.\n")
+
+ def test_get_text_plain_bad_utf8_quoted_printable(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf8"
+ Content-Transfer-Encoding: quoted-printable
+
+ Bas=c3=acc t=c3=abxt=fd.
+ """))
+ self.assertEqual(raw_data_manager.get_content(m), "Basìc tëxt�.\n")
+
+ def test_get_text_plain_bad_utf8_quoted_printable_ignore_errors(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf8"
+ Content-Transfer-Encoding: quoted-printable
+
+ Bas=c3=acc t=c3=abxt=fd.
+ """))
+ self.assertEqual(raw_data_manager.get_content(m, errors='ignore'),
+ "Basìc tëxt.\n")
+
+ def test_get_text_plain_utf8_base64_recoverable_bad_CTE_data(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf8"
+ Content-Transfer-Encoding: base64
+
+ QmFzw6xjIHTDq3h0Lgo\xFF=
+ """))
+ self.assertEqual(raw_data_manager.get_content(m, errors='ignore'),
+ "Basìc tëxt.\n")
+
+ def test_get_text_invalid_keyword(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: text/plain
+
+ Basic text.
+ """))
+ with self.assertRaises(TypeError):
+ raw_data_manager.get_content(m, foo='ignore')
+
+ def test_get_non_text(self):
+ template = textwrap.dedent("""\
+ Content-Type: {}
+ Content-Transfer-Encoding: base64
+
+ Ym9ndXMgZGF0YQ==
+ """)
+ for maintype in 'audio image video application'.split():
+ with self.subTest(maintype=maintype):
+ m = self._str_msg(template.format(maintype+'/foo'))
+ self.assertEqual(raw_data_manager.get_content(m), b"bogus data")
+
+ def test_get_non_text_invalid_keyword(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: image/jpg
+ Content-Transfer-Encoding: base64
+
+ Ym9ndXMgZGF0YQ==
+ """))
+ with self.assertRaises(TypeError):
+ raw_data_manager.get_content(m, errors='ignore')
+
+ def test_get_raises_on_multipart(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ --===--
+ """))
+ with self.assertRaises(KeyError):
+ raw_data_manager.get_content(m)
+
+ def test_get_message_rfc822_and_external_body(self):
+ template = textwrap.dedent("""\
+ Content-Type: message/{}
+
+ To: foo@example.com
+ From: bar@example.com
+ Subject: example
+
+ an example message
+ """)
+ for subtype in 'rfc822 external-body'.split():
+ with self.subTest(subtype=subtype):
+ m = self._str_msg(template.format(subtype))
+ sub_msg = raw_data_manager.get_content(m)
+ self.assertIsInstance(sub_msg, self.message)
+ self.assertEqual(raw_data_manager.get_content(sub_msg),
+ "an example message\n")
+ self.assertEqual(sub_msg['to'], 'foo@example.com')
+ self.assertEqual(sub_msg['from'].addresses[0].username, 'bar')
+
+ def test_get_message_non_rfc822_or_external_body_yields_bytes(self):
+ m = self._str_msg(textwrap.dedent("""\
+ Content-Type: message/partial
+
+ To: foo@example.com
+ From: bar@example.com
+ Subject: example
+
+ The real body is in another message.
+ """))
+ self.assertEqual(raw_data_manager.get_content(m)[:10], b'To: foo@ex')
+
+ def test_set_text_plain(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 7bit
+
+ Simple message.
+ """))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_html(self):
+ m = self._make_message()
+ content = "<p>Simple message.</p>\n"
+ raw_data_manager.set_content(m, content, subtype='html')
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Content-Type: text/html; charset="utf-8"
+ Content-Transfer-Encoding: 7bit
+
+ <p>Simple message.</p>
+ """))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_charset_latin_1(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ raw_data_manager.set_content(m, content, charset='latin-1')
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="iso-8859-1"
+ Content-Transfer-Encoding: 7bit
+
+ Simple message.
+ """))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_short_line_minimal_non_ascii_heuristics(self):
+ m = self._make_message()
+ content = "et là il est monté sur moi et il commence à m'éto.\n"
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+
+ et là il est monté sur moi et il commence à m'éto.
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_long_line_minimal_non_ascii_heuristics(self):
+ m = self._make_message()
+ content = ("j'ai un problème de python. il est sorti de son"
+ " vivarium. et là il est monté sur moi et il commence"
+ " à m'éto.\n")
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: quoted-printable
+
+ j'ai un probl=C3=A8me de python. il est sorti de son vivari=
+ um. et l=C3=A0 il est mont=C3=A9 sur moi et il commence =
+ =C3=A0 m'=C3=A9to.
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_11_lines_long_line_minimal_non_ascii_heuristics(self):
+ m = self._make_message()
+ content = '\n'*10 + (
+ "j'ai un problème de python. il est sorti de son"
+ " vivarium. et là il est monté sur moi et il commence"
+ " à m'éto.\n")
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: quoted-printable
+ """ + '\n'*10 + """
+ j'ai un probl=C3=A8me de python. il est sorti de son vivari=
+ um. et l=C3=A0 il est mont=C3=A9 sur moi et il commence =
+ =C3=A0 m'=C3=A9to.
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_maximal_non_ascii_heuristics(self):
+ m = self._make_message()
+ content = "áàäéèęöő.\n"
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+
+ áàäéèęöő.
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_11_lines_maximal_non_ascii_heuristics(self):
+ m = self._make_message()
+ content = '\n'*10 + "áàäéèęöő.\n"
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+ """ + '\n'*10 + """
+ áàäéèęöő.
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_long_line_maximal_non_ascii_heuristics(self):
+ m = self._make_message()
+ content = ("áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő"
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő"
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő.\n")
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: base64
+
+ w6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOgw6TDqcOoxJnD
+ tsWRw6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOgw6TDqcOo
+ xJnDtsWRw6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOgw6TD
+ qcOoxJnDtsWRw6HDoMOkw6nDqMSZw7bFkcOhw6DDpMOpw6jEmcO2xZHDocOg
+ w6TDqcOoxJnDtsWRLgo=
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_11_lines_long_line_maximal_non_ascii_heuristics(self):
+ # Yes, it chooses "wrong" here. It's a heuristic. So this result
+ # could change if we come up with a better heuristic.
+ m = self._make_message()
+ content = ('\n'*10 +
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő"
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő"
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő.\n")
+ raw_data_manager.set_content(m, "\n"*10 +
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő"
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő"
+ "áàäéèęöőáàäéèęöőáàäéèęöőáàäéèęöő.\n")
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: quoted-printable
+ """ + '\n'*10 + """
+ =C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=
+ =A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=
+ =C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=
+ =A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=
+ =C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=
+ =91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=
+ =C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=
+ =A4=C3=A9=C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=
+ =C3=A8=C4=99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=
+ =99=C3=B6=C5=91=C3=A1=C3=A0=C3=A4=C3=A9=C3=A8=C4=99=C3=B6=
+ =C5=91.
+ """).encode('utf-8'))
+ self.assertEqual(m.get_payload(decode=True).decode('utf-8'), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_text_non_ascii_with_cte_7bit_raises(self):
+ m = self._make_message()
+ with self.assertRaises(UnicodeError):
+ raw_data_manager.set_content(m,"áàäéèęöő.\n", cte='7bit')
+
+ def test_set_text_non_ascii_with_charset_ascii_raises(self):
+ m = self._make_message()
+ with self.assertRaises(UnicodeError):
+ raw_data_manager.set_content(m,"áàäéèęöő.\n", charset='ascii')
+
+ def test_set_text_non_ascii_with_cte_7bit_and_charset_ascii_raises(self):
+ m = self._make_message()
+ with self.assertRaises(UnicodeError):
+ raw_data_manager.set_content(m,"áàäéèęöő.\n", cte='7bit', charset='ascii')
+
+ def test_set_message(self):
+ m = self._make_message()
+ m['Subject'] = "Forwarded message"
+ content = self._make_message()
+ content['To'] = 'python@vivarium.org'
+ content['From'] = 'police@monty.org'
+ content['Subject'] = "get back in your box"
+ content.set_content("Or face the comfy chair.")
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Subject: Forwarded message
+ Content-Type: message/rfc822
+ Content-Transfer-Encoding: 8bit
+
+ To: python@vivarium.org
+ From: police@monty.org
+ Subject: get back in your box
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 7bit
+ MIME-Version: 1.0
+
+ Or face the comfy chair.
+ """))
+ payload = m.get_payload(0)
+ self.assertIsInstance(payload, self.message)
+ self.assertEqual(str(payload), str(content))
+ self.assertIsInstance(m.get_content(), self.message)
+ self.assertEqual(str(m.get_content()), str(content))
+
+ def test_set_message_with_non_ascii_and_coercion_to_7bit(self):
+ m = self._make_message()
+ m['Subject'] = "Escape report"
+ content = self._make_message()
+ content['To'] = 'police@monty.org'
+ content['From'] = 'victim@monty.org'
+ content['Subject'] = "Help"
+ content.set_content("j'ai un problème de python. il est sorti de son"
+ " vivarium.")
+ raw_data_manager.set_content(m, content)
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Subject: Escape report
+ Content-Type: message/rfc822
+ Content-Transfer-Encoding: 8bit
+
+ To: police@monty.org
+ From: victim@monty.org
+ Subject: Help
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 8bit
+ MIME-Version: 1.0
+
+ j'ai un problème de python. il est sorti de son vivarium.
+ """).encode('utf-8'))
+ # The choice of base64 for the body encoding is because generator
+ # doesn't bother with heuristics and uses it unconditionally for utf-8
+ # text.
+ # XXX: the first cte should be 7bit, too...that's a generator bug.
+ # XXX: the line length in the body also looks like a generator bug.
+ self.assertEqual(m.as_string(maxheaderlen=self.policy.max_line_length),
+ textwrap.dedent("""\
+ Subject: Escape report
+ Content-Type: message/rfc822
+ Content-Transfer-Encoding: 8bit
+
+ To: police@monty.org
+ From: victim@monty.org
+ Subject: Help
+ Content-Type: text/plain; charset="utf-8"
+ MIME-Version: 1.0
+ Content-Transfer-Encoding: base64
+
+ aidhaSB1biBwcm9ibMOobWUgZGUgcHl0aG9uLiBpbCBlc3Qgc29ydGkgZGUgc29uIHZpdmFyaXVt
+ Lgo=
+ """))
+ self.assertIsInstance(m.get_content(), self.message)
+ self.assertEqual(str(m.get_content()), str(content))
+
+ def test_set_message_invalid_cte_raises(self):
+ m = self._make_message()
+ content = self._make_message()
+ for cte in 'quoted-printable base64'.split():
+ for subtype in 'rfc822 external-body'.split():
+ with self.subTest(cte=cte, subtype=subtype):
+ with self.assertRaises(ValueError) as ar:
+ m.set_content(content, subtype, cte=cte)
+ exc = str(ar.exception)
+ self.assertIn(cte, exc)
+ self.assertIn(subtype, exc)
+ subtype = 'external-body'
+ for cte in '8bit binary'.split():
+ with self.subTest(cte=cte, subtype=subtype):
+ with self.assertRaises(ValueError) as ar:
+ m.set_content(content, subtype, cte=cte)
+ exc = str(ar.exception)
+ self.assertIn(cte, exc)
+ self.assertIn(subtype, exc)
+
+ def test_set_image_jpg(self):
+ for content in (b"bogus content",
+ bytearray(b"bogus content"),
+ memoryview(b"bogus content")):
+ with self.subTest(content=content):
+ m = self._make_message()
+ raw_data_manager.set_content(m, content, 'image', 'jpeg')
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Content-Type: image/jpeg
+ Content-Transfer-Encoding: base64
+
+ Ym9ndXMgY29udGVudA==
+ """))
+ self.assertEqual(m.get_payload(decode=True), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_audio_aif_with_quoted_printable_cte(self):
+ # Why you would use qp, I don't know, but it is technically supported.
+ # XXX: the incorrect line length is because binascii.b2a_qp doesn't
+ # support a line length parameter, but we must use it to get newline
+ # encoding.
+ # XXX: what about that lack of tailing newline? Do we actually handle
+ # that correctly in all cases? That is, if the *source* has an
+ # unencoded newline, do we add an extra newline to the returned payload
+ # or not? And can that actually be disambiguated based on the RFC?
+ m = self._make_message()
+ content = b'b\xFFgus\tcon\nt\rent ' + b'z'*100
+ m.set_content(content, 'audio', 'aif', cte='quoted-printable')
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: audio/aif
+ Content-Transfer-Encoding: quoted-printable
+ MIME-Version: 1.0
+
+ b=FFgus=09con=0At=0Dent=20zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz=
+ zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz""").encode('latin-1'))
+ self.assertEqual(m.get_payload(decode=True), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_video_mpeg_with_binary_cte(self):
+ m = self._make_message()
+ content = b'b\xFFgus\tcon\nt\rent ' + b'z'*100
+ m.set_content(content, 'video', 'mpeg', cte='binary')
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: video/mpeg
+ Content-Transfer-Encoding: binary
+ MIME-Version: 1.0
+
+ """).encode('ascii') +
+ # XXX: the second \n ought to be a \r, but generator gets it wrong.
+ # THIS MEANS WE DON'T ACTUALLY SUPPORT THE 'binary' CTE.
+ b'b\xFFgus\tcon\nt\nent zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz' +
+ b'zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz')
+ self.assertEqual(m.get_payload(decode=True), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_application_octet_stream_with_8bit_cte(self):
+ # In 8bit mode, univeral line end logic applies. It is up to the
+ # application to make sure the lines are short enough; we don't check.
+ m = self._make_message()
+ content = b'b\xFFgus\tcon\nt\rent\n' + b'z'*60 + b'\n'
+ m.set_content(content, 'application', 'octet-stream', cte='8bit')
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: application/octet-stream
+ Content-Transfer-Encoding: 8bit
+ MIME-Version: 1.0
+
+ """).encode('ascii') +
+ b'b\xFFgus\tcon\nt\nent\n' +
+ b'zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz\n')
+ self.assertEqual(m.get_payload(decode=True), content)
+ self.assertEqual(m.get_content(), content)
+
+ def test_set_headers_from_header_objects(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ header_factory = self.policy.header_factory
+ raw_data_manager.set_content(m, content, headers=(
+ header_factory("To", "foo@example.com"),
+ header_factory("From", "foo@example.com"),
+ header_factory("Subject", "I'm talking to myself.")))
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ To: foo@example.com
+ From: foo@example.com
+ Subject: I'm talking to myself.
+ Content-Transfer-Encoding: 7bit
+
+ Simple message.
+ """))
+
+ def test_set_headers_from_strings(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ raw_data_manager.set_content(m, content, headers=(
+ "X-Foo-Header: foo",
+ "X-Bar-Header: bar",))
+ self.assertEqual(str(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ X-Foo-Header: foo
+ X-Bar-Header: bar
+ Content-Transfer-Encoding: 7bit
+
+ Simple message.
+ """))
+
+ def test_set_headers_with_invalid_duplicate_string_header_raises(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ with self.assertRaisesRegex(ValueError, 'Content-Type'):
+ raw_data_manager.set_content(m, content, headers=(
+ "Content-Type: foo/bar",)
+ )
+
+ def test_set_headers_with_invalid_duplicate_header_header_raises(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ header_factory = self.policy.header_factory
+ with self.assertRaisesRegex(ValueError, 'Content-Type'):
+ raw_data_manager.set_content(m, content, headers=(
+ header_factory("Content-Type", " foo/bar"),)
+ )
+
+ def test_set_headers_with_defective_string_header_raises(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ with self.assertRaisesRegex(ValueError, 'a@fairly@@invalid@address'):
+ raw_data_manager.set_content(m, content, headers=(
+ 'To: a@fairly@@invalid@address',)
+ )
+ print(m['To'].defects)
+
+ def test_set_headers_with_defective_header_header_raises(self):
+ m = self._make_message()
+ content = "Simple message.\n"
+ header_factory = self.policy.header_factory
+ with self.assertRaisesRegex(ValueError, 'a@fairly@@invalid@address'):
+ raw_data_manager.set_content(m, content, headers=(
+ header_factory('To', 'a@fairly@@invalid@address'),)
+ )
+ print(m['To'].defects)
+
+ def test_set_disposition_inline(self):
+ m = self._make_message()
+ m.set_content('foo', disposition='inline')
+ self.assertEqual(m['Content-Disposition'], 'inline')
+
+ def test_set_disposition_attachment(self):
+ m = self._make_message()
+ m.set_content('foo', disposition='attachment')
+ self.assertEqual(m['Content-Disposition'], 'attachment')
+
+ def test_set_disposition_foo(self):
+ m = self._make_message()
+ m.set_content('foo', disposition='foo')
+ self.assertEqual(m['Content-Disposition'], 'foo')
+
+ # XXX: we should have a 'strict' policy mode (beyond raise_on_defect) that
+ # would cause 'foo' above to raise.
+
+ def test_set_filename(self):
+ m = self._make_message()
+ m.set_content('foo', filename='bar.txt')
+ self.assertEqual(m['Content-Disposition'],
+ 'attachment; filename="bar.txt"')
+
+ def test_set_filename_and_disposition_inline(self):
+ m = self._make_message()
+ m.set_content('foo', disposition='inline', filename='bar.txt')
+ self.assertEqual(m['Content-Disposition'], 'inline; filename="bar.txt"')
+
+ def test_set_non_ascii_filename(self):
+ m = self._make_message()
+ m.set_content('foo', filename='ábárî.txt')
+ self.assertEqual(bytes(m), textwrap.dedent("""\
+ Content-Type: text/plain; charset="utf-8"
+ Content-Transfer-Encoding: 7bit
+ Content-Disposition: attachment;
+ filename*=utf-8''%C3%A1b%C3%A1r%C3%AE.txt
+ MIME-Version: 1.0
+
+ foo
+ """).encode('ascii'))
+
+ content_object_params = {
+ 'text_plain': ('content', ()),
+ 'text_html': ('content', ('html',)),
+ 'application_octet_stream': (b'content',
+ ('application', 'octet_stream')),
+ 'image_jpeg': (b'content', ('image', 'jpeg')),
+ 'message_rfc822': (message(), ()),
+ 'message_external_body': (message(), ('external-body',)),
+ }
+
+ def content_object_as_header_receiver(self, obj, mimetype):
+ m = self._make_message()
+ m.set_content(obj, *mimetype, headers=(
+ 'To: foo@example.com',
+ 'From: bar@simple.net'))
+ self.assertEqual(m['to'], 'foo@example.com')
+ self.assertEqual(m['from'], 'bar@simple.net')
+
+ def content_object_as_disposition_inline_receiver(self, obj, mimetype):
+ m = self._make_message()
+ m.set_content(obj, *mimetype, disposition='inline')
+ self.assertEqual(m['Content-Disposition'], 'inline')
+
+ def content_object_as_non_ascii_filename_receiver(self, obj, mimetype):
+ m = self._make_message()
+ m.set_content(obj, *mimetype, disposition='inline', filename='bár.txt')
+ self.assertEqual(m['Content-Disposition'], 'inline; filename="bár.txt"')
+ self.assertEqual(m.get_filename(), "bár.txt")
+ self.assertEqual(m['Content-Disposition'].params['filename'], "bár.txt")
+
+ def content_object_as_cid_receiver(self, obj, mimetype):
+ m = self._make_message()
+ m.set_content(obj, *mimetype, cid='some_random_stuff')
+ self.assertEqual(m['Content-ID'], 'some_random_stuff')
+
+ def content_object_as_params_receiver(self, obj, mimetype):
+ m = self._make_message()
+ params = {'foo': 'bár', 'abc': 'xyz'}
+ m.set_content(obj, *mimetype, params=params)
+ if isinstance(obj, str):
+ params['charset'] = 'utf-8'
+ self.assertEqual(m['Content-Type'].params, params)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py
index cd4f757c1f..57f12ae466 100644
--- a/Lib/test/test_email/test_email.py
+++ b/Lib/test/test_email/test_email.py
@@ -249,15 +249,42 @@ class TestMessageAPI(TestEmailBase):
self.assertIn('TO', msg)
def test_as_string(self):
- eq = self.ndiffAssertEqual
msg = self._msgobj('msg_01.txt')
with openfile('msg_01.txt') as fp:
text = fp.read()
- eq(text, str(msg))
+ self.assertEqual(text, str(msg))
fullrepr = msg.as_string(unixfrom=True)
lines = fullrepr.split('\n')
self.assertTrue(lines[0].startswith('From '))
- eq(text, NL.join(lines[1:]))
+ self.assertEqual(text, NL.join(lines[1:]))
+
+ def test_as_string_policy(self):
+ msg = self._msgobj('msg_01.txt')
+ newpolicy = msg.policy.clone(linesep='\r\n')
+ fullrepr = msg.as_string(policy=newpolicy)
+ s = StringIO()
+ g = Generator(s, policy=newpolicy)
+ g.flatten(msg)
+ self.assertEqual(fullrepr, s.getvalue())
+
+ def test_as_bytes(self):
+ msg = self._msgobj('msg_01.txt')
+ with openfile('msg_01.txt') as fp:
+ data = fp.read().encode('ascii')
+ self.assertEqual(data, bytes(msg))
+ fullrepr = msg.as_bytes(unixfrom=True)
+ lines = fullrepr.split(b'\n')
+ self.assertTrue(lines[0].startswith(b'From '))
+ self.assertEqual(data, b'\n'.join(lines[1:]))
+
+ def test_as_bytes_policy(self):
+ msg = self._msgobj('msg_01.txt')
+ newpolicy = msg.policy.clone(linesep='\r\n')
+ fullrepr = msg.as_bytes(policy=newpolicy)
+ s = BytesIO()
+ g = BytesGenerator(s,policy=newpolicy)
+ g.flatten(msg)
+ self.assertEqual(fullrepr, s.getvalue())
# test_headerregistry.TestContentTypeHeader.bad_params
def test_bad_param(self):
diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py
index f829f83e32..f2df3720eb 100644
--- a/Lib/test/test_email/test_headerregistry.py
+++ b/Lib/test/test_email/test_headerregistry.py
@@ -661,7 +661,7 @@ class TestContentTypeHeader(TestHeaderBase):
'text/plain; name="ascii_is_the_default"'),
'rfc2231_bad_character_in_charset_parameter_value': (
- "text/plain; charset*=ascii''utf-8%E2%80%9D",
+ "text/plain; charset*=ascii''utf-8%F1%F2%F3",
'text/plain',
'text',
'plain',
@@ -669,6 +669,18 @@ class TestContentTypeHeader(TestHeaderBase):
[errors.UndecodableBytesDefect],
'text/plain; charset="utf-8\uFFFD\uFFFD\uFFFD"'),
+ 'rfc2231_utf_8_in_supposedly_ascii_charset_parameter_value': (
+ "text/plain; charset*=ascii''utf-8%E2%80%9D",
+ 'text/plain',
+ 'text',
+ 'plain',
+ {'charset': 'utf-8”'},
+ [errors.UndecodableBytesDefect],
+ 'text/plain; charset="utf-8”"',
+ ),
+ # XXX: if the above were *re*folded, it would get tagged as utf-8
+ # instead of ascii in the param, since it now contains non-ASCII.
+
'rfc2231_encoded_then_unencoded_segments': (
('application/x-foo;'
'\tname*0*="us-ascii\'en-us\'My";'
diff --git a/Lib/test/test_email/test_message.py b/Lib/test/test_email/test_message.py
index 8cc3f80e46..e0ebb8ad6c 100644
--- a/Lib/test/test_email/test_message.py
+++ b/Lib/test/test_email/test_message.py
@@ -1,6 +1,13 @@
import unittest
+import textwrap
from email import policy
-from test.test_email import TestEmailBase
+from email.message import EmailMessage, MIMEPart
+from test.test_email import TestEmailBase, parameterize
+
+
+# Helper.
+def first(iterable):
+ return next(filter(lambda x: x is not None, iterable), None)
class Test(TestEmailBase):
@@ -14,5 +21,738 @@ class Test(TestEmailBase):
m['To'] = 'xyz@abc'
+@parameterize
+class TestEmailMessageBase:
+
+ policy = policy.default
+
+ # The first argument is a triple (related, html, plain) of indices into the
+ # list returned by 'walk' called on a Message constructed from the third.
+ # The indices indicate which part should match the corresponding part-type
+ # when passed to get_body (ie: the "first" part of that type in the
+ # message). The second argument is a list of indices into the 'walk' list
+ # of the attachments that should be returned by a call to
+ # 'iter_attachments'. The third argument is a list of indices into 'walk'
+ # that should be returned by a call to 'iter_parts'. Note that the first
+ # item returned by 'walk' is the Message itself.
+
+ message_params = {
+
+ 'empty_message': (
+ (None, None, 0),
+ (),
+ (),
+ ""),
+
+ 'non_mime_plain': (
+ (None, None, 0),
+ (),
+ (),
+ textwrap.dedent("""\
+ To: foo@example.com
+
+ simple text body
+ """)),
+
+ 'mime_non_text': (
+ (None, None, None),
+ (),
+ (),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: image/jpg
+
+ bogus body.
+ """)),
+
+ 'plain_html_alternative': (
+ (None, 2, 1),
+ (),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/alternative; boundary="==="
+
+ preamble
+
+ --===
+ Content-Type: text/plain
+
+ simple body
+
+ --===
+ Content-Type: text/html
+
+ <p>simple body</p>
+ --===--
+ """)),
+
+ 'plain_html_mixed': (
+ (None, 2, 1),
+ (),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ preamble
+
+ --===
+ Content-Type: text/plain
+
+ simple body
+
+ --===
+ Content-Type: text/html
+
+ <p>simple body</p>
+
+ --===--
+ """)),
+
+ 'plain_html_attachment_mixed': (
+ (None, None, 1),
+ (2,),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: text/plain
+
+ simple body
+
+ --===
+ Content-Type: text/html
+ Content-Disposition: attachment
+
+ <p>simple body</p>
+
+ --===--
+ """)),
+
+ 'html_text_attachment_mixed': (
+ (None, 2, None),
+ (1,),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: text/plain
+ Content-Disposition: AtTaChment
+
+ simple body
+
+ --===
+ Content-Type: text/html
+
+ <p>simple body</p>
+
+ --===--
+ """)),
+
+ 'html_text_attachment_inline_mixed': (
+ (None, 2, 1),
+ (),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: text/plain
+ Content-Disposition: InLine
+
+ simple body
+
+ --===
+ Content-Type: text/html
+ Content-Disposition: inline
+
+ <p>simple body</p>
+
+ --===--
+ """)),
+
+ # RFC 2387
+ 'related': (
+ (0, 1, None),
+ (2,),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/related; boundary="==="; type=text/html
+
+ --===
+ Content-Type: text/html
+
+ <p>simple body</p>
+
+ --===
+ Content-Type: image/jpg
+ Content-ID: <image1>
+
+ bogus data
+
+ --===--
+ """)),
+
+ # This message structure will probably never be seen in the wild, but
+ # it proves we distinguish between text parts based on 'start'. The
+ # content would not, of course, actually work :)
+ 'related_with_start': (
+ (0, 2, None),
+ (1,),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/related; boundary="==="; type=text/html;
+ start="<body>"
+
+ --===
+ Content-Type: text/html
+ Content-ID: <include>
+
+ useless text
+
+ --===
+ Content-Type: text/html
+ Content-ID: <body>
+
+ <p>simple body</p>
+ <!--#include file="<include>"-->
+
+ --===--
+ """)),
+
+
+ 'mixed_alternative_plain_related': (
+ (3, 4, 2),
+ (6, 7),
+ (1, 6, 7),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: multipart/alternative; boundary="+++"
+
+ --+++
+ Content-Type: text/plain
+
+ simple body
+
+ --+++
+ Content-Type: multipart/related; boundary="___"
+
+ --___
+ Content-Type: text/html
+
+ <p>simple body</p>
+
+ --___
+ Content-Type: image/jpg
+ Content-ID: <image1@cid>
+
+ bogus jpg body
+
+ --___--
+
+ --+++--
+
+ --===
+ Content-Type: image/jpg
+ Content-Disposition: attachment
+
+ bogus jpg body
+
+ --===
+ Content-Type: image/jpg
+ Content-Disposition: AttacHmenT
+
+ another bogus jpg body
+
+ --===--
+ """)),
+
+ # This structure suggested by Stephen J. Turnbull...may not exist/be
+ # supported in the wild, but we want to support it.
+ 'mixed_related_alternative_plain_html': (
+ (1, 4, 3),
+ (6, 7),
+ (1, 6, 7),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: multipart/related; boundary="+++"
+
+ --+++
+ Content-Type: multipart/alternative; boundary="___"
+
+ --___
+ Content-Type: text/plain
+
+ simple body
+
+ --___
+ Content-Type: text/html
+
+ <p>simple body</p>
+
+ --___--
+
+ --+++
+ Content-Type: image/jpg
+ Content-ID: <image1@cid>
+
+ bogus jpg body
+
+ --+++--
+
+ --===
+ Content-Type: image/jpg
+ Content-Disposition: attachment
+
+ bogus jpg body
+
+ --===
+ Content-Type: image/jpg
+ Content-Disposition: attachment
+
+ another bogus jpg body
+
+ --===--
+ """)),
+
+ # Same thing, but proving we only look at the root part, which is the
+ # first one if there isn't any start parameter. That is, this is a
+ # broken related.
+ 'mixed_related_alternative_plain_html_wrong_order': (
+ (1, None, None),
+ (6, 7),
+ (1, 6, 7),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: multipart/related; boundary="+++"
+
+ --+++
+ Content-Type: image/jpg
+ Content-ID: <image1@cid>
+
+ bogus jpg body
+
+ --+++
+ Content-Type: multipart/alternative; boundary="___"
+
+ --___
+ Content-Type: text/plain
+
+ simple body
+
+ --___
+ Content-Type: text/html
+
+ <p>simple body</p>
+
+ --___--
+
+ --+++--
+
+ --===
+ Content-Type: image/jpg
+ Content-Disposition: attachment
+
+ bogus jpg body
+
+ --===
+ Content-Type: image/jpg
+ Content-Disposition: attachment
+
+ another bogus jpg body
+
+ --===--
+ """)),
+
+ 'message_rfc822': (
+ (None, None, None),
+ (),
+ (),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: message/rfc822
+
+ To: bar@example.com
+ From: robot@examp.com
+
+ this is a message body.
+ """)),
+
+ 'mixed_text_message_rfc822': (
+ (None, None, 1),
+ (2,),
+ (1, 2),
+ textwrap.dedent("""\
+ To: foo@example.com
+ MIME-Version: 1.0
+ Content-Type: multipart/mixed; boundary="==="
+
+ --===
+ Content-Type: text/plain
+
+ Your message has bounced, ser.
+
+ --===
+ Content-Type: message/rfc822
+
+ To: bar@example.com
+ From: robot@examp.com
+
+ this is a message body.
+
+ --===--
+ """)),
+
+ }
+
+ def message_as_get_body(self, body_parts, attachments, parts, msg):
+ m = self._str_msg(msg)
+ allparts = list(m.walk())
+ expected = [None if n is None else allparts[n] for n in body_parts]
+ related = 0; html = 1; plain = 2
+ self.assertEqual(m.get_body(), first(expected))
+ self.assertEqual(m.get_body(preferencelist=(
+ 'related', 'html', 'plain')),
+ first(expected))
+ self.assertEqual(m.get_body(preferencelist=('related', 'html')),
+ first(expected[related:html+1]))
+ self.assertEqual(m.get_body(preferencelist=('related', 'plain')),
+ first([expected[related], expected[plain]]))
+ self.assertEqual(m.get_body(preferencelist=('html', 'plain')),
+ first(expected[html:plain+1]))
+ self.assertEqual(m.get_body(preferencelist=['related']),
+ expected[related])
+ self.assertEqual(m.get_body(preferencelist=['html']), expected[html])
+ self.assertEqual(m.get_body(preferencelist=['plain']), expected[plain])
+ self.assertEqual(m.get_body(preferencelist=('plain', 'html')),
+ first(expected[plain:html-1:-1]))
+ self.assertEqual(m.get_body(preferencelist=('plain', 'related')),
+ first([expected[plain], expected[related]]))
+ self.assertEqual(m.get_body(preferencelist=('html', 'related')),
+ first(expected[html::-1]))
+ self.assertEqual(m.get_body(preferencelist=('plain', 'html', 'related')),
+ first(expected[::-1]))
+ self.assertEqual(m.get_body(preferencelist=('html', 'plain', 'related')),
+ first([expected[html],
+ expected[plain],
+ expected[related]]))
+
+ def message_as_iter_attachment(self, body_parts, attachments, parts, msg):
+ m = self._str_msg(msg)
+ allparts = list(m.walk())
+ attachments = [allparts[n] for n in attachments]
+ self.assertEqual(list(m.iter_attachments()), attachments)
+
+ def message_as_iter_parts(self, body_parts, attachments, parts, msg):
+ m = self._str_msg(msg)
+ allparts = list(m.walk())
+ parts = [allparts[n] for n in parts]
+ self.assertEqual(list(m.iter_parts()), parts)
+
+ class _TestContentManager:
+ def get_content(self, msg, *args, **kw):
+ return msg, args, kw
+ def set_content(self, msg, *args, **kw):
+ self.msg = msg
+ self.args = args
+ self.kw = kw
+
+ def test_get_content_with_cm(self):
+ m = self._str_msg('')
+ cm = self._TestContentManager()
+ self.assertEqual(m.get_content(content_manager=cm), (m, (), {}))
+ msg, args, kw = m.get_content('foo', content_manager=cm, bar=1, k=2)
+ self.assertEqual(msg, m)
+ self.assertEqual(args, ('foo',))
+ self.assertEqual(kw, dict(bar=1, k=2))
+
+ def test_get_content_default_cm_comes_from_policy(self):
+ p = policy.default.clone(content_manager=self._TestContentManager())
+ m = self._str_msg('', policy=p)
+ self.assertEqual(m.get_content(), (m, (), {}))
+ msg, args, kw = m.get_content('foo', bar=1, k=2)
+ self.assertEqual(msg, m)
+ self.assertEqual(args, ('foo',))
+ self.assertEqual(kw, dict(bar=1, k=2))
+
+ def test_set_content_with_cm(self):
+ m = self._str_msg('')
+ cm = self._TestContentManager()
+ m.set_content(content_manager=cm)
+ self.assertEqual(cm.msg, m)
+ self.assertEqual(cm.args, ())
+ self.assertEqual(cm.kw, {})
+ m.set_content('foo', content_manager=cm, bar=1, k=2)
+ self.assertEqual(cm.msg, m)
+ self.assertEqual(cm.args, ('foo',))
+ self.assertEqual(cm.kw, dict(bar=1, k=2))
+
+ def test_set_content_default_cm_comes_from_policy(self):
+ cm = self._TestContentManager()
+ p = policy.default.clone(content_manager=cm)
+ m = self._str_msg('', policy=p)
+ m.set_content()
+ self.assertEqual(cm.msg, m)
+ self.assertEqual(cm.args, ())
+ self.assertEqual(cm.kw, {})
+ m.set_content('foo', bar=1, k=2)
+ self.assertEqual(cm.msg, m)
+ self.assertEqual(cm.args, ('foo',))
+ self.assertEqual(cm.kw, dict(bar=1, k=2))
+
+ # outcome is whether xxx_method should raise ValueError error when called
+ # on multipart/subtype. Blank outcome means it depends on xxx (add
+ # succeeds, make raises). Note: 'none' means there are content-type
+ # headers but payload is None...this happening in practice would be very
+ # unusual, so treating it as if there were content seems reasonable.
+ # method subtype outcome
+ subtype_params = (
+ ('related', 'no_content', 'succeeds'),
+ ('related', 'none', 'succeeds'),
+ ('related', 'plain', 'succeeds'),
+ ('related', 'related', ''),
+ ('related', 'alternative', 'raises'),
+ ('related', 'mixed', 'raises'),
+ ('alternative', 'no_content', 'succeeds'),
+ ('alternative', 'none', 'succeeds'),
+ ('alternative', 'plain', 'succeeds'),
+ ('alternative', 'related', 'succeeds'),
+ ('alternative', 'alternative', ''),
+ ('alternative', 'mixed', 'raises'),
+ ('mixed', 'no_content', 'succeeds'),
+ ('mixed', 'none', 'succeeds'),
+ ('mixed', 'plain', 'succeeds'),
+ ('mixed', 'related', 'succeeds'),
+ ('mixed', 'alternative', 'succeeds'),
+ ('mixed', 'mixed', ''),
+ )
+
+ def _make_subtype_test_message(self, subtype):
+ m = self.message()
+ payload = None
+ msg_headers = [
+ ('To', 'foo@bar.com'),
+ ('From', 'bar@foo.com'),
+ ]
+ if subtype != 'no_content':
+ ('content-shadow', 'Logrus'),
+ msg_headers.append(('X-Random-Header', 'Corwin'))
+ if subtype == 'text':
+ payload = ''
+ msg_headers.append(('Content-Type', 'text/plain'))
+ m.set_payload('')
+ elif subtype != 'no_content':
+ payload = []
+ msg_headers.append(('Content-Type', 'multipart/' + subtype))
+ msg_headers.append(('X-Trump', 'Random'))
+ m.set_payload(payload)
+ for name, value in msg_headers:
+ m[name] = value
+ return m, msg_headers, payload
+
+ def _check_disallowed_subtype_raises(self, m, method_name, subtype, method):
+ with self.assertRaises(ValueError) as ar:
+ getattr(m, method)()
+ exc_text = str(ar.exception)
+ self.assertIn(subtype, exc_text)
+ self.assertIn(method_name, exc_text)
+
+ def _check_make_multipart(self, m, msg_headers, payload):
+ count = 0
+ for name, value in msg_headers:
+ if not name.lower().startswith('content-'):
+ self.assertEqual(m[name], value)
+ count += 1
+ self.assertEqual(len(m), count+1) # +1 for new Content-Type
+ part = next(m.iter_parts())
+ count = 0
+ for name, value in msg_headers:
+ if name.lower().startswith('content-'):
+ self.assertEqual(part[name], value)
+ count += 1
+ self.assertEqual(len(part), count)
+ self.assertEqual(part.get_payload(), payload)
+
+ def subtype_as_make(self, method, subtype, outcome):
+ m, msg_headers, payload = self._make_subtype_test_message(subtype)
+ make_method = 'make_' + method
+ if outcome in ('', 'raises'):
+ self._check_disallowed_subtype_raises(m, method, subtype, make_method)
+ return
+ getattr(m, make_method)()
+ self.assertEqual(m.get_content_maintype(), 'multipart')
+ self.assertEqual(m.get_content_subtype(), method)
+ if subtype == 'no_content':
+ self.assertEqual(len(m.get_payload()), 0)
+ self.assertEqual(m.items(),
+ msg_headers + [('Content-Type',
+ 'multipart/'+method)])
+ else:
+ self.assertEqual(len(m.get_payload()), 1)
+ self._check_make_multipart(m, msg_headers, payload)
+
+ def subtype_as_make_with_boundary(self, method, subtype, outcome):
+ # Doing all variation is a bit of overkill...
+ m = self.message()
+ if outcome in ('', 'raises'):
+ m['Content-Type'] = 'multipart/' + subtype
+ with self.assertRaises(ValueError) as cm:
+ getattr(m, 'make_' + method)()
+ return
+ if subtype == 'plain':
+ m['Content-Type'] = 'text/plain'
+ elif subtype != 'no_content':
+ m['Content-Type'] = 'multipart/' + subtype
+ getattr(m, 'make_' + method)(boundary="abc")
+ self.assertTrue(m.is_multipart())
+ self.assertEqual(m.get_boundary(), 'abc')
+
+ def test_policy_on_part_made_by_make_comes_from_message(self):
+ for method in ('make_related', 'make_alternative', 'make_mixed'):
+ m = self.message(policy=self.policy.clone(content_manager='foo'))
+ m['Content-Type'] = 'text/plain'
+ getattr(m, method)()
+ self.assertEqual(m.get_payload(0).policy.content_manager, 'foo')
+
+ class _TestSetContentManager:
+ def set_content(self, msg, content, *args, **kw):
+ msg['Content-Type'] = 'text/plain'
+ msg.set_payload(content)
+
+ def subtype_as_add(self, method, subtype, outcome):
+ m, msg_headers, payload = self._make_subtype_test_message(subtype)
+ cm = self._TestSetContentManager()
+ add_method = 'add_attachment' if method=='mixed' else 'add_' + method
+ if outcome == 'raises':
+ self._check_disallowed_subtype_raises(m, method, subtype, add_method)
+ return
+ getattr(m, add_method)('test', content_manager=cm)
+ self.assertEqual(m.get_content_maintype(), 'multipart')
+ self.assertEqual(m.get_content_subtype(), method)
+ if method == subtype or subtype == 'no_content':
+ self.assertEqual(len(m.get_payload()), 1)
+ for name, value in msg_headers:
+ self.assertEqual(m[name], value)
+ part = m.get_payload()[0]
+ else:
+ self.assertEqual(len(m.get_payload()), 2)
+ self._check_make_multipart(m, msg_headers, payload)
+ part = m.get_payload()[1]
+ self.assertEqual(part.get_content_type(), 'text/plain')
+ self.assertEqual(part.get_payload(), 'test')
+ if method=='mixed':
+ self.assertEqual(part['Content-Disposition'], 'attachment')
+ elif method=='related':
+ self.assertEqual(part['Content-Disposition'], 'inline')
+ else:
+ # Otherwise we don't guess.
+ self.assertIsNone(part['Content-Disposition'])
+
+ class _TestSetRaisingContentManager:
+ def set_content(self, msg, content, *args, **kw):
+ raise Exception('test')
+
+ def test_default_content_manager_for_add_comes_from_policy(self):
+ cm = self._TestSetRaisingContentManager()
+ m = self.message(policy=self.policy.clone(content_manager=cm))
+ for method in ('add_related', 'add_alternative', 'add_attachment'):
+ with self.assertRaises(Exception) as ar:
+ getattr(m, method)('')
+ self.assertEqual(str(ar.exception), 'test')
+
+ def message_as_clear(self, body_parts, attachments, parts, msg):
+ m = self._str_msg(msg)
+ m.clear()
+ self.assertEqual(len(m), 0)
+ self.assertEqual(list(m.items()), [])
+ self.assertIsNone(m.get_payload())
+ self.assertEqual(list(m.iter_parts()), [])
+
+ def message_as_clear_content(self, body_parts, attachments, parts, msg):
+ m = self._str_msg(msg)
+ expected_headers = [h for h in m.keys()
+ if not h.lower().startswith('content-')]
+ m.clear_content()
+ self.assertEqual(list(m.keys()), expected_headers)
+ self.assertIsNone(m.get_payload())
+ self.assertEqual(list(m.iter_parts()), [])
+
+ def test_is_attachment(self):
+ m = self._make_message()
+ self.assertFalse(m.is_attachment)
+ m['Content-Disposition'] = 'inline'
+ self.assertFalse(m.is_attachment)
+ m.replace_header('Content-Disposition', 'attachment')
+ self.assertTrue(m.is_attachment)
+ m.replace_header('Content-Disposition', 'AtTachMent')
+ self.assertTrue(m.is_attachment)
+
+
+
+class TestEmailMessage(TestEmailMessageBase, TestEmailBase):
+ message = EmailMessage
+
+ def test_set_content_adds_MIME_Version(self):
+ m = self._str_msg('')
+ cm = self._TestContentManager()
+ self.assertNotIn('MIME-Version', m)
+ m.set_content(content_manager=cm)
+ self.assertEqual(m['MIME-Version'], '1.0')
+
+ class _MIME_Version_adding_CM:
+ def set_content(self, msg, *args, **kw):
+ msg['MIME-Version'] = '1.0'
+
+ def test_set_content_does_not_duplicate_MIME_Version(self):
+ m = self._str_msg('')
+ cm = self._MIME_Version_adding_CM()
+ self.assertNotIn('MIME-Version', m)
+ m.set_content(content_manager=cm)
+ self.assertEqual(m['MIME-Version'], '1.0')
+
+
+class TestMIMEPart(TestEmailMessageBase, TestEmailBase):
+ # Doing the full test run here may seem a bit redundant, since the two
+ # classes are almost identical. But what if they drift apart? So we do
+ # the full tests so that any future drift doesn't introduce bugs.
+ message = MIMEPart
+
+ def test_set_content_does_not_add_MIME_Version(self):
+ m = self._str_msg('')
+ cm = self._TestContentManager()
+ self.assertNotIn('MIME-Version', m)
+ m.set_content(content_manager=cm)
+ self.assertNotIn('MIME-Version', m)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_email/test_policy.py b/Lib/test/test_email/test_policy.py
index 983bd49a11..06ad5f24b8 100644
--- a/Lib/test/test_email/test_policy.py
+++ b/Lib/test/test_email/test_policy.py
@@ -30,6 +30,7 @@ class PolicyAPITests(unittest.TestCase):
'raise_on_defect': False,
'header_factory': email.policy.EmailPolicy.header_factory,
'refold_source': 'long',
+ 'content_manager': email.policy.EmailPolicy.content_manager,
})
# For each policy under test, we give here what we expect the defaults to
diff --git a/Lib/test/test_email/torture_test.py b/Lib/test/test_email/torture_test.py
index 544b1bbb39..19cf64f0c7 100644
--- a/Lib/test/test_email/torture_test.py
+++ b/Lib/test/test_email/torture_test.py
@@ -27,7 +27,7 @@ def openfile(filename):
# Prevent this test from running in the Python distro
try:
openfile('crispin-torture.txt')
-except IOError:
+except OSError:
raise TestSkipped
diff --git a/Lib/test/test_ensurepip.py b/Lib/test/test_ensurepip.py
new file mode 100644
index 0000000000..abf00fd188
--- /dev/null
+++ b/Lib/test/test_ensurepip.py
@@ -0,0 +1,128 @@
+import unittest
+import unittest.mock
+import ensurepip
+import test.support
+import os
+import os.path
+
+
+class TestEnsurePipVersion(unittest.TestCase):
+
+ def test_returns_version(self):
+ self.assertEqual(ensurepip._PIP_VERSION, ensurepip.version())
+
+
+class TestBootstrap(unittest.TestCase):
+
+ def setUp(self):
+ run_pip_patch = unittest.mock.patch("ensurepip._run_pip")
+ self.run_pip = run_pip_patch.start()
+ self.addCleanup(run_pip_patch.stop)
+
+ # Avoid side effects on the actual os module
+ os_patch = unittest.mock.patch("ensurepip.os")
+ patched_os = os_patch.start()
+ self.addCleanup(os_patch.stop)
+ patched_os.path = os.path
+ self.os_environ = patched_os.environ = os.environ.copy()
+
+ def test_basic_bootstrapping(self):
+ ensurepip.bootstrap()
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ additional_paths = self.run_pip.call_args[0][1]
+ self.assertEqual(len(additional_paths), 2)
+
+ def test_bootstrapping_with_root(self):
+ ensurepip.bootstrap(root="/foo/bar/")
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "--root", "/foo/bar/",
+ "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ def test_bootstrapping_with_user(self):
+ ensurepip.bootstrap(user=True)
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "--user", "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ def test_bootstrapping_with_upgrade(self):
+ ensurepip.bootstrap(upgrade=True)
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "--upgrade", "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ def test_bootstrapping_with_verbosity_1(self):
+ ensurepip.bootstrap(verbosity=1)
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "-v", "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ def test_bootstrapping_with_verbosity_2(self):
+ ensurepip.bootstrap(verbosity=2)
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "-vv", "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ def test_bootstrapping_with_verbosity_3(self):
+ ensurepip.bootstrap(verbosity=3)
+
+ self.run_pip.assert_called_once_with(
+ [
+ "install", "--no-index", "--find-links",
+ unittest.mock.ANY, "--pre", "-vvv", "setuptools", "pip",
+ ],
+ unittest.mock.ANY,
+ )
+
+ def test_bootstrapping_with_regular_install(self):
+ ensurepip.bootstrap()
+ self.assertEqual(self.os_environ["ENSUREPIP_OPTIONS"], "install")
+
+ def test_bootstrapping_with_alt_install(self):
+ ensurepip.bootstrap(altinstall=True)
+ self.assertEqual(self.os_environ["ENSUREPIP_OPTIONS"], "altinstall")
+
+ def test_bootstrapping_with_default_pip(self):
+ ensurepip.bootstrap(default_pip=True)
+ self.assertNotIn("ENSUREPIP_OPTIONS", self.os_environ)
+
+ def test_altinstall_default_pip_conflict(self):
+ with self.assertRaises(ValueError):
+ ensurepip.bootstrap(altinstall=True, default_pip=True)
+
+
+if __name__ == "__main__":
+ test.support.run_unittest(__name__)
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
new file mode 100644
index 0000000000..03f0e5d5cd
--- /dev/null
+++ b/Lib/test/test_enum.py
@@ -0,0 +1,1325 @@
+import enum
+import inspect
+import pydoc
+import unittest
+from collections import OrderedDict
+from enum import Enum, IntEnum, EnumMeta, unique
+from io import StringIO
+from pickle import dumps, loads, PicklingError
+
+# for pickle tests
+try:
+ class Stooges(Enum):
+ LARRY = 1
+ CURLY = 2
+ MOE = 3
+except Exception as exc:
+ Stooges = exc
+
+try:
+ class IntStooges(int, Enum):
+ LARRY = 1
+ CURLY = 2
+ MOE = 3
+except Exception as exc:
+ IntStooges = exc
+
+try:
+ class FloatStooges(float, Enum):
+ LARRY = 1.39
+ CURLY = 2.72
+ MOE = 3.142596
+except Exception as exc:
+ FloatStooges = exc
+
+# for pickle test and subclass tests
+try:
+ class StrEnum(str, Enum):
+ 'accepts only string values'
+ class Name(StrEnum):
+ BDFL = 'Guido van Rossum'
+ FLUFL = 'Barry Warsaw'
+except Exception as exc:
+ Name = exc
+
+try:
+ Question = Enum('Question', 'who what when where why', module=__name__)
+except Exception as exc:
+ Question = exc
+
+try:
+ Answer = Enum('Answer', 'him this then there because')
+except Exception as exc:
+ Answer = exc
+
+# for doctests
+try:
+ class Fruit(Enum):
+ tomato = 1
+ banana = 2
+ cherry = 3
+except Exception:
+ pass
+
+
+class TestHelpers(unittest.TestCase):
+ # _is_descriptor, _is_sunder, _is_dunder
+
+ def test_is_descriptor(self):
+ class foo:
+ pass
+ for attr in ('__get__','__set__','__delete__'):
+ obj = foo()
+ self.assertFalse(enum._is_descriptor(obj))
+ setattr(obj, attr, 1)
+ self.assertTrue(enum._is_descriptor(obj))
+
+ def test_is_sunder(self):
+ for s in ('_a_', '_aa_'):
+ self.assertTrue(enum._is_sunder(s))
+
+ for s in ('a', 'a_', '_a', '__a', 'a__', '__a__', '_a__', '__a_', '_',
+ '__', '___', '____', '_____',):
+ self.assertFalse(enum._is_sunder(s))
+
+ def test_is_dunder(self):
+ for s in ('__a__', '__aa__'):
+ self.assertTrue(enum._is_dunder(s))
+ for s in ('a', 'a_', '_a', '__a', 'a__', '_a_', '_a__', '__a_', '_',
+ '__', '___', '____', '_____',):
+ self.assertFalse(enum._is_dunder(s))
+
+
+class TestEnum(unittest.TestCase):
+ def setUp(self):
+ class Season(Enum):
+ SPRING = 1
+ SUMMER = 2
+ AUTUMN = 3
+ WINTER = 4
+ self.Season = Season
+
+ class Konstants(float, Enum):
+ E = 2.7182818
+ PI = 3.1415926
+ TAU = 2 * PI
+ self.Konstants = Konstants
+
+ class Grades(IntEnum):
+ A = 5
+ B = 4
+ C = 3
+ D = 2
+ F = 0
+ self.Grades = Grades
+
+ class Directional(str, Enum):
+ EAST = 'east'
+ WEST = 'west'
+ NORTH = 'north'
+ SOUTH = 'south'
+ self.Directional = Directional
+
+ from datetime import date
+ class Holiday(date, Enum):
+ NEW_YEAR = 2013, 1, 1
+ IDES_OF_MARCH = 2013, 3, 15
+ self.Holiday = Holiday
+
+ def test_dir_on_class(self):
+ Season = self.Season
+ self.assertEqual(
+ set(dir(Season)),
+ set(['__class__', '__doc__', '__members__', '__module__',
+ 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']),
+ )
+
+ def test_dir_on_item(self):
+ Season = self.Season
+ self.assertEqual(
+ set(dir(Season.WINTER)),
+ set(['__class__', '__doc__', '__module__', 'name', 'value']),
+ )
+
+ def test_dir_with_added_behavior(self):
+ class Test(Enum):
+ this = 'that'
+ these = 'those'
+ def wowser(self):
+ return ("Wowser! I'm %s!" % self.name)
+ self.assertEqual(
+ set(dir(Test)),
+ set(['__class__', '__doc__', '__members__', '__module__', 'this', 'these']),
+ )
+ self.assertEqual(
+ set(dir(Test.this)),
+ set(['__class__', '__doc__', '__module__', 'name', 'value', 'wowser']),
+ )
+
+ def test_enum_in_enum_out(self):
+ Season = self.Season
+ self.assertIs(Season(Season.WINTER), Season.WINTER)
+
+ def test_enum_value(self):
+ Season = self.Season
+ self.assertEqual(Season.SPRING.value, 1)
+
+ def test_intenum_value(self):
+ self.assertEqual(IntStooges.CURLY.value, 2)
+
+ def test_enum(self):
+ Season = self.Season
+ lst = list(Season)
+ self.assertEqual(len(lst), len(Season))
+ self.assertEqual(len(Season), 4, Season)
+ self.assertEqual(
+ [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst)
+
+ for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split(), 1):
+ e = Season(i)
+ self.assertEqual(e, getattr(Season, season))
+ self.assertEqual(e.value, i)
+ self.assertNotEqual(e, i)
+ self.assertEqual(e.name, season)
+ self.assertIn(e, Season)
+ self.assertIs(type(e), Season)
+ self.assertIsInstance(e, Season)
+ self.assertEqual(str(e), 'Season.' + season)
+ self.assertEqual(
+ repr(e),
+ '<Season.{0}: {1}>'.format(season, i),
+ )
+
+ def test_value_name(self):
+ Season = self.Season
+ self.assertEqual(Season.SPRING.name, 'SPRING')
+ self.assertEqual(Season.SPRING.value, 1)
+ with self.assertRaises(AttributeError):
+ Season.SPRING.name = 'invierno'
+ with self.assertRaises(AttributeError):
+ Season.SPRING.value = 2
+
+ def test_changing_member(self):
+ Season = self.Season
+ with self.assertRaises(AttributeError):
+ Season.WINTER = 'really cold'
+
+ def test_attribute_deletion(self):
+ class Season(Enum):
+ SPRING = 1
+ SUMMER = 2
+ AUTUMN = 3
+ WINTER = 4
+
+ def spam(cls):
+ pass
+
+ self.assertTrue(hasattr(Season, 'spam'))
+ del Season.spam
+ self.assertFalse(hasattr(Season, 'spam'))
+
+ with self.assertRaises(AttributeError):
+ del Season.SPRING
+ with self.assertRaises(AttributeError):
+ del Season.DRY
+ with self.assertRaises(AttributeError):
+ del Season.SPRING.name
+
+ def test_invalid_names(self):
+ with self.assertRaises(ValueError):
+ class Wrong(Enum):
+ mro = 9
+ with self.assertRaises(ValueError):
+ class Wrong(Enum):
+ _create_= 11
+ with self.assertRaises(ValueError):
+ class Wrong(Enum):
+ _get_mixins_ = 9
+ with self.assertRaises(ValueError):
+ class Wrong(Enum):
+ _find_new_ = 1
+ with self.assertRaises(ValueError):
+ class Wrong(Enum):
+ _any_name_ = 9
+
+ def test_contains(self):
+ Season = self.Season
+ self.assertIn(Season.AUTUMN, Season)
+ self.assertNotIn(3, Season)
+
+ val = Season(3)
+ self.assertIn(val, Season)
+
+ class OtherEnum(Enum):
+ one = 1; two = 2
+ self.assertNotIn(OtherEnum.two, Season)
+
+ def test_comparisons(self):
+ Season = self.Season
+ with self.assertRaises(TypeError):
+ Season.SPRING < Season.WINTER
+ with self.assertRaises(TypeError):
+ Season.SPRING > 4
+
+ self.assertNotEqual(Season.SPRING, 1)
+
+ class Part(Enum):
+ SPRING = 1
+ CLIP = 2
+ BARREL = 3
+
+ self.assertNotEqual(Season.SPRING, Part.SPRING)
+ with self.assertRaises(TypeError):
+ Season.SPRING < Part.CLIP
+
+ def test_enum_duplicates(self):
+ class Season(Enum):
+ SPRING = 1
+ SUMMER = 2
+ AUTUMN = FALL = 3
+ WINTER = 4
+ ANOTHER_SPRING = 1
+ lst = list(Season)
+ self.assertEqual(
+ lst,
+ [Season.SPRING, Season.SUMMER,
+ Season.AUTUMN, Season.WINTER,
+ ])
+ self.assertIs(Season.FALL, Season.AUTUMN)
+ self.assertEqual(Season.FALL.value, 3)
+ self.assertEqual(Season.AUTUMN.value, 3)
+ self.assertIs(Season(3), Season.AUTUMN)
+ self.assertIs(Season(1), Season.SPRING)
+ self.assertEqual(Season.FALL.name, 'AUTUMN')
+ self.assertEqual(
+ [k for k,v in Season.__members__.items() if v.name != k],
+ ['FALL', 'ANOTHER_SPRING'],
+ )
+
+ def test_duplicate_name(self):
+ with self.assertRaises(TypeError):
+ class Color(Enum):
+ red = 1
+ green = 2
+ blue = 3
+ red = 4
+
+ with self.assertRaises(TypeError):
+ class Color(Enum):
+ red = 1
+ green = 2
+ blue = 3
+ def red(self):
+ return 'red'
+
+ with self.assertRaises(TypeError):
+ class Color(Enum):
+ @property
+ def red(self):
+ return 'redder'
+ red = 1
+ green = 2
+ blue = 3
+
+
+ def test_enum_with_value_name(self):
+ class Huh(Enum):
+ name = 1
+ value = 2
+ self.assertEqual(
+ list(Huh),
+ [Huh.name, Huh.value],
+ )
+ self.assertIs(type(Huh.name), Huh)
+ self.assertEqual(Huh.name.name, 'name')
+ self.assertEqual(Huh.name.value, 1)
+
+ def test_format_enum(self):
+ Season = self.Season
+ self.assertEqual('{}'.format(Season.SPRING),
+ '{}'.format(str(Season.SPRING)))
+ self.assertEqual( '{:}'.format(Season.SPRING),
+ '{:}'.format(str(Season.SPRING)))
+ self.assertEqual('{:20}'.format(Season.SPRING),
+ '{:20}'.format(str(Season.SPRING)))
+ self.assertEqual('{:^20}'.format(Season.SPRING),
+ '{:^20}'.format(str(Season.SPRING)))
+ self.assertEqual('{:>20}'.format(Season.SPRING),
+ '{:>20}'.format(str(Season.SPRING)))
+ self.assertEqual('{:<20}'.format(Season.SPRING),
+ '{:<20}'.format(str(Season.SPRING)))
+
+ def test_format_enum_custom(self):
+ class TestFloat(float, Enum):
+ one = 1.0
+ two = 2.0
+ def __format__(self, spec):
+ return 'TestFloat success!'
+ self.assertEqual('{}'.format(TestFloat.one), 'TestFloat success!')
+
+ def assertFormatIsValue(self, spec, member):
+ self.assertEqual(spec.format(member), spec.format(member.value))
+
+ def test_format_enum_date(self):
+ Holiday = self.Holiday
+ self.assertFormatIsValue('{}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:20}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:^20}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:>20}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:<20}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:%Y %m}', Holiday.IDES_OF_MARCH)
+ self.assertFormatIsValue('{:%Y %m %M:00}', Holiday.IDES_OF_MARCH)
+
+ def test_format_enum_float(self):
+ Konstants = self.Konstants
+ self.assertFormatIsValue('{}', Konstants.TAU)
+ self.assertFormatIsValue('{:}', Konstants.TAU)
+ self.assertFormatIsValue('{:20}', Konstants.TAU)
+ self.assertFormatIsValue('{:^20}', Konstants.TAU)
+ self.assertFormatIsValue('{:>20}', Konstants.TAU)
+ self.assertFormatIsValue('{:<20}', Konstants.TAU)
+ self.assertFormatIsValue('{:n}', Konstants.TAU)
+ self.assertFormatIsValue('{:5.2}', Konstants.TAU)
+ self.assertFormatIsValue('{:f}', Konstants.TAU)
+
+ def test_format_enum_int(self):
+ Grades = self.Grades
+ self.assertFormatIsValue('{}', Grades.C)
+ self.assertFormatIsValue('{:}', Grades.C)
+ self.assertFormatIsValue('{:20}', Grades.C)
+ self.assertFormatIsValue('{:^20}', Grades.C)
+ self.assertFormatIsValue('{:>20}', Grades.C)
+ self.assertFormatIsValue('{:<20}', Grades.C)
+ self.assertFormatIsValue('{:+}', Grades.C)
+ self.assertFormatIsValue('{:08X}', Grades.C)
+ self.assertFormatIsValue('{:b}', Grades.C)
+
+ def test_format_enum_str(self):
+ Directional = self.Directional
+ self.assertFormatIsValue('{}', Directional.WEST)
+ self.assertFormatIsValue('{:}', Directional.WEST)
+ self.assertFormatIsValue('{:20}', Directional.WEST)
+ self.assertFormatIsValue('{:^20}', Directional.WEST)
+ self.assertFormatIsValue('{:>20}', Directional.WEST)
+ self.assertFormatIsValue('{:<20}', Directional.WEST)
+
+ def test_hash(self):
+ Season = self.Season
+ dates = {}
+ dates[Season.WINTER] = '1225'
+ dates[Season.SPRING] = '0315'
+ dates[Season.SUMMER] = '0704'
+ dates[Season.AUTUMN] = '1031'
+ self.assertEqual(dates[Season.AUTUMN], '1031')
+
+ def test_intenum_from_scratch(self):
+ class phy(int, Enum):
+ pi = 3
+ tau = 2 * pi
+ self.assertTrue(phy.pi < phy.tau)
+
+ def test_intenum_inherited(self):
+ class IntEnum(int, Enum):
+ pass
+ class phy(IntEnum):
+ pi = 3
+ tau = 2 * pi
+ self.assertTrue(phy.pi < phy.tau)
+
+ def test_floatenum_from_scratch(self):
+ class phy(float, Enum):
+ pi = 3.1415926
+ tau = 2 * pi
+ self.assertTrue(phy.pi < phy.tau)
+
+ def test_floatenum_inherited(self):
+ class FloatEnum(float, Enum):
+ pass
+ class phy(FloatEnum):
+ pi = 3.1415926
+ tau = 2 * pi
+ self.assertTrue(phy.pi < phy.tau)
+
+ def test_strenum_from_scratch(self):
+ class phy(str, Enum):
+ pi = 'Pi'
+ tau = 'Tau'
+ self.assertTrue(phy.pi < phy.tau)
+
+ def test_strenum_inherited(self):
+ class StrEnum(str, Enum):
+ pass
+ class phy(StrEnum):
+ pi = 'Pi'
+ tau = 'Tau'
+ self.assertTrue(phy.pi < phy.tau)
+
+
+ def test_intenum(self):
+ class WeekDay(IntEnum):
+ SUNDAY = 1
+ MONDAY = 2
+ TUESDAY = 3
+ WEDNESDAY = 4
+ THURSDAY = 5
+ FRIDAY = 6
+ SATURDAY = 7
+
+ self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c')
+ self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2])
+
+ lst = list(WeekDay)
+ self.assertEqual(len(lst), len(WeekDay))
+ self.assertEqual(len(WeekDay), 7)
+ target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY'
+ target = target.split()
+ for i, weekday in enumerate(target, 1):
+ e = WeekDay(i)
+ self.assertEqual(e, i)
+ self.assertEqual(int(e), i)
+ self.assertEqual(e.name, weekday)
+ self.assertIn(e, WeekDay)
+ self.assertEqual(lst.index(e)+1, i)
+ self.assertTrue(0 < e < 8)
+ self.assertIs(type(e), WeekDay)
+ self.assertIsInstance(e, int)
+ self.assertIsInstance(e, Enum)
+
+ def test_intenum_duplicates(self):
+ class WeekDay(IntEnum):
+ SUNDAY = 1
+ MONDAY = 2
+ TUESDAY = TEUSDAY = 3
+ WEDNESDAY = 4
+ THURSDAY = 5
+ FRIDAY = 6
+ SATURDAY = 7
+ self.assertIs(WeekDay.TEUSDAY, WeekDay.TUESDAY)
+ self.assertEqual(WeekDay(3).name, 'TUESDAY')
+ self.assertEqual([k for k,v in WeekDay.__members__.items()
+ if v.name != k], ['TEUSDAY', ])
+
+ def test_pickle_enum(self):
+ if isinstance(Stooges, Exception):
+ raise Stooges
+ self.assertIs(Stooges.CURLY, loads(dumps(Stooges.CURLY)))
+ self.assertIs(Stooges, loads(dumps(Stooges)))
+
+ def test_pickle_int(self):
+ if isinstance(IntStooges, Exception):
+ raise IntStooges
+ self.assertIs(IntStooges.CURLY, loads(dumps(IntStooges.CURLY)))
+ self.assertIs(IntStooges, loads(dumps(IntStooges)))
+
+ def test_pickle_float(self):
+ if isinstance(FloatStooges, Exception):
+ raise FloatStooges
+ self.assertIs(FloatStooges.CURLY, loads(dumps(FloatStooges.CURLY)))
+ self.assertIs(FloatStooges, loads(dumps(FloatStooges)))
+
+ def test_pickle_enum_function(self):
+ if isinstance(Answer, Exception):
+ raise Answer
+ self.assertIs(Answer.him, loads(dumps(Answer.him)))
+ self.assertIs(Answer, loads(dumps(Answer)))
+
+ def test_pickle_enum_function_with_module(self):
+ if isinstance(Question, Exception):
+ raise Question
+ self.assertIs(Question.who, loads(dumps(Question.who)))
+ self.assertIs(Question, loads(dumps(Question)))
+
+ def test_exploding_pickle(self):
+ BadPickle = Enum('BadPickle', 'dill sweet bread-n-butter')
+ enum._make_class_unpicklable(BadPickle)
+ globals()['BadPickle'] = BadPickle
+ with self.assertRaises(TypeError):
+ dumps(BadPickle.dill)
+ with self.assertRaises(PicklingError):
+ dumps(BadPickle)
+
+ def test_string_enum(self):
+ class SkillLevel(str, Enum):
+ master = 'what is the sound of one hand clapping?'
+ journeyman = 'why did the chicken cross the road?'
+ apprentice = 'knock, knock!'
+ self.assertEqual(SkillLevel.apprentice, 'knock, knock!')
+
+ def test_getattr_getitem(self):
+ class Period(Enum):
+ morning = 1
+ noon = 2
+ evening = 3
+ night = 4
+ self.assertIs(Period(2), Period.noon)
+ self.assertIs(getattr(Period, 'night'), Period.night)
+ self.assertIs(Period['morning'], Period.morning)
+
+ def test_getattr_dunder(self):
+ Season = self.Season
+ self.assertTrue(getattr(Season, '__eq__'))
+
+ def test_iteration_order(self):
+ class Season(Enum):
+ SUMMER = 2
+ WINTER = 4
+ AUTUMN = 3
+ SPRING = 1
+ self.assertEqual(
+ list(Season),
+ [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING],
+ )
+
+ def test_reversed_iteration_order(self):
+ self.assertEqual(
+ list(reversed(self.Season)),
+ [self.Season.WINTER, self.Season.AUTUMN, self.Season.SUMMER,
+ self.Season.SPRING]
+ )
+
+ def test_programatic_function_string(self):
+ SummerMonth = Enum('SummerMonth', 'june july august')
+ lst = list(SummerMonth)
+ self.assertEqual(len(lst), len(SummerMonth))
+ self.assertEqual(len(SummerMonth), 3, SummerMonth)
+ self.assertEqual(
+ [SummerMonth.june, SummerMonth.july, SummerMonth.august],
+ lst,
+ )
+ for i, month in enumerate('june july august'.split(), 1):
+ e = SummerMonth(i)
+ self.assertEqual(int(e.value), i)
+ self.assertNotEqual(e, i)
+ self.assertEqual(e.name, month)
+ self.assertIn(e, SummerMonth)
+ self.assertIs(type(e), SummerMonth)
+
+ def test_programatic_function_string_list(self):
+ SummerMonth = Enum('SummerMonth', ['june', 'july', 'august'])
+ lst = list(SummerMonth)
+ self.assertEqual(len(lst), len(SummerMonth))
+ self.assertEqual(len(SummerMonth), 3, SummerMonth)
+ self.assertEqual(
+ [SummerMonth.june, SummerMonth.july, SummerMonth.august],
+ lst,
+ )
+ for i, month in enumerate('june july august'.split(), 1):
+ e = SummerMonth(i)
+ self.assertEqual(int(e.value), i)
+ self.assertNotEqual(e, i)
+ self.assertEqual(e.name, month)
+ self.assertIn(e, SummerMonth)
+ self.assertIs(type(e), SummerMonth)
+
+ def test_programatic_function_iterable(self):
+ SummerMonth = Enum(
+ 'SummerMonth',
+ (('june', 1), ('july', 2), ('august', 3))
+ )
+ lst = list(SummerMonth)
+ self.assertEqual(len(lst), len(SummerMonth))
+ self.assertEqual(len(SummerMonth), 3, SummerMonth)
+ self.assertEqual(
+ [SummerMonth.june, SummerMonth.july, SummerMonth.august],
+ lst,
+ )
+ for i, month in enumerate('june july august'.split(), 1):
+ e = SummerMonth(i)
+ self.assertEqual(int(e.value), i)
+ self.assertNotEqual(e, i)
+ self.assertEqual(e.name, month)
+ self.assertIn(e, SummerMonth)
+ self.assertIs(type(e), SummerMonth)
+
+ def test_programatic_function_from_dict(self):
+ SummerMonth = Enum(
+ 'SummerMonth',
+ OrderedDict((('june', 1), ('july', 2), ('august', 3)))
+ )
+ lst = list(SummerMonth)
+ self.assertEqual(len(lst), len(SummerMonth))
+ self.assertEqual(len(SummerMonth), 3, SummerMonth)
+ self.assertEqual(
+ [SummerMonth.june, SummerMonth.july, SummerMonth.august],
+ lst,
+ )
+ for i, month in enumerate('june july august'.split(), 1):
+ e = SummerMonth(i)
+ self.assertEqual(int(e.value), i)
+ self.assertNotEqual(e, i)
+ self.assertEqual(e.name, month)
+ self.assertIn(e, SummerMonth)
+ self.assertIs(type(e), SummerMonth)
+
+ def test_programatic_function_type(self):
+ SummerMonth = Enum('SummerMonth', 'june july august', type=int)
+ lst = list(SummerMonth)
+ self.assertEqual(len(lst), len(SummerMonth))
+ self.assertEqual(len(SummerMonth), 3, SummerMonth)
+ self.assertEqual(
+ [SummerMonth.june, SummerMonth.july, SummerMonth.august],
+ lst,
+ )
+ for i, month in enumerate('june july august'.split(), 1):
+ e = SummerMonth(i)
+ self.assertEqual(e, i)
+ self.assertEqual(e.name, month)
+ self.assertIn(e, SummerMonth)
+ self.assertIs(type(e), SummerMonth)
+
+ def test_programatic_function_type_from_subclass(self):
+ SummerMonth = IntEnum('SummerMonth', 'june july august')
+ lst = list(SummerMonth)
+ self.assertEqual(len(lst), len(SummerMonth))
+ self.assertEqual(len(SummerMonth), 3, SummerMonth)
+ self.assertEqual(
+ [SummerMonth.june, SummerMonth.july, SummerMonth.august],
+ lst,
+ )
+ for i, month in enumerate('june july august'.split(), 1):
+ e = SummerMonth(i)
+ self.assertEqual(e, i)
+ self.assertEqual(e.name, month)
+ self.assertIn(e, SummerMonth)
+ self.assertIs(type(e), SummerMonth)
+
+ def test_subclassing(self):
+ if isinstance(Name, Exception):
+ raise Name
+ self.assertEqual(Name.BDFL, 'Guido van Rossum')
+ self.assertTrue(Name.BDFL, Name('Guido van Rossum'))
+ self.assertIs(Name.BDFL, getattr(Name, 'BDFL'))
+ self.assertIs(Name.BDFL, loads(dumps(Name.BDFL)))
+
+ def test_extending(self):
+ class Color(Enum):
+ red = 1
+ green = 2
+ blue = 3
+ with self.assertRaises(TypeError):
+ class MoreColor(Color):
+ cyan = 4
+ magenta = 5
+ yellow = 6
+
+ def test_exclude_methods(self):
+ class whatever(Enum):
+ this = 'that'
+ these = 'those'
+ def really(self):
+ return 'no, not %s' % self.value
+ self.assertIsNot(type(whatever.really), whatever)
+ self.assertEqual(whatever.this.really(), 'no, not that')
+
+ def test_wrong_inheritance_order(self):
+ with self.assertRaises(TypeError):
+ class Wrong(Enum, str):
+ NotHere = 'error before this point'
+
+ def test_intenum_transitivity(self):
+ class number(IntEnum):
+ one = 1
+ two = 2
+ three = 3
+ class numero(IntEnum):
+ uno = 1
+ dos = 2
+ tres = 3
+ self.assertEqual(number.one, numero.uno)
+ self.assertEqual(number.two, numero.dos)
+ self.assertEqual(number.three, numero.tres)
+
+ def test_wrong_enum_in_call(self):
+ class Monochrome(Enum):
+ black = 0
+ white = 1
+ class Gender(Enum):
+ male = 0
+ female = 1
+ self.assertRaises(ValueError, Monochrome, Gender.male)
+
+ def test_wrong_enum_in_mixed_call(self):
+ class Monochrome(IntEnum):
+ black = 0
+ white = 1
+ class Gender(Enum):
+ male = 0
+ female = 1
+ self.assertRaises(ValueError, Monochrome, Gender.male)
+
+ def test_mixed_enum_in_call_1(self):
+ class Monochrome(IntEnum):
+ black = 0
+ white = 1
+ class Gender(IntEnum):
+ male = 0
+ female = 1
+ self.assertIs(Monochrome(Gender.female), Monochrome.white)
+
+ def test_mixed_enum_in_call_2(self):
+ class Monochrome(Enum):
+ black = 0
+ white = 1
+ class Gender(IntEnum):
+ male = 0
+ female = 1
+ self.assertIs(Monochrome(Gender.male), Monochrome.black)
+
+ def test_flufl_enum(self):
+ class Fluflnum(Enum):
+ def __int__(self):
+ return int(self.value)
+ class MailManOptions(Fluflnum):
+ option1 = 1
+ option2 = 2
+ option3 = 3
+ self.assertEqual(int(MailManOptions.option1), 1)
+
+ def test_introspection(self):
+ class Number(IntEnum):
+ one = 100
+ two = 200
+ self.assertIs(Number.one._member_type_, int)
+ self.assertIs(Number._member_type_, int)
+ class String(str, Enum):
+ yarn = 'soft'
+ rope = 'rough'
+ wire = 'hard'
+ self.assertIs(String.yarn._member_type_, str)
+ self.assertIs(String._member_type_, str)
+ class Plain(Enum):
+ vanilla = 'white'
+ one = 1
+ self.assertIs(Plain.vanilla._member_type_, object)
+ self.assertIs(Plain._member_type_, object)
+
+ def test_no_such_enum_member(self):
+ class Color(Enum):
+ red = 1
+ green = 2
+ blue = 3
+ with self.assertRaises(ValueError):
+ Color(4)
+ with self.assertRaises(KeyError):
+ Color['chartreuse']
+
+ def test_new_repr(self):
+ class Color(Enum):
+ red = 1
+ green = 2
+ blue = 3
+ def __repr__(self):
+ return "don't you just love shades of %s?" % self.name
+ self.assertEqual(
+ repr(Color.blue),
+ "don't you just love shades of blue?",
+ )
+
+ def test_inherited_repr(self):
+ class MyEnum(Enum):
+ def __repr__(self):
+ return "My name is %s." % self.name
+ class MyIntEnum(int, MyEnum):
+ this = 1
+ that = 2
+ theother = 3
+ self.assertEqual(repr(MyIntEnum.that), "My name is that.")
+
+ def test_multiple_mixin_mro(self):
+ class auto_enum(type(Enum)):
+ def __new__(metacls, cls, bases, classdict):
+ temp = type(classdict)()
+ names = set(classdict._member_names)
+ i = 0
+ for k in classdict._member_names:
+ v = classdict[k]
+ if v is Ellipsis:
+ v = i
+ else:
+ i = v
+ i += 1
+ temp[k] = v
+ for k, v in classdict.items():
+ if k not in names:
+ temp[k] = v
+ return super(auto_enum, metacls).__new__(
+ metacls, cls, bases, temp)
+
+ class AutoNumberedEnum(Enum, metaclass=auto_enum):
+ pass
+
+ class AutoIntEnum(IntEnum, metaclass=auto_enum):
+ pass
+
+ class TestAutoNumber(AutoNumberedEnum):
+ a = ...
+ b = 3
+ c = ...
+
+ class TestAutoInt(AutoIntEnum):
+ a = ...
+ b = 3
+ c = ...
+
+ def test_subclasses_with_getnewargs(self):
+ class NamedInt(int):
+ def __new__(cls, *args):
+ _args = args
+ name, *args = args
+ if len(args) == 0:
+ raise TypeError("name and value must be specified")
+ self = int.__new__(cls, *args)
+ self._intname = name
+ self._args = _args
+ return self
+ def __getnewargs__(self):
+ return self._args
+ @property
+ def __name__(self):
+ return self._intname
+ def __repr__(self):
+ # repr() is updated to include the name and type info
+ return "{}({!r}, {})".format(type(self).__name__,
+ self.__name__,
+ int.__repr__(self))
+ def __str__(self):
+ # str() is unchanged, even if it relies on the repr() fallback
+ base = int
+ base_str = base.__str__
+ if base_str.__objclass__ is object:
+ return base.__repr__(self)
+ return base_str(self)
+ # for simplicity, we only define one operator that
+ # propagates expressions
+ def __add__(self, other):
+ temp = int(self) + int( other)
+ if isinstance(self, NamedInt) and isinstance(other, NamedInt):
+ return NamedInt(
+ '({0} + {1})'.format(self.__name__, other.__name__),
+ temp )
+ else:
+ return temp
+
+ class NEI(NamedInt, Enum):
+ x = ('the-x', 1)
+ y = ('the-y', 2)
+
+
+ self.assertIs(NEI.__new__, Enum.__new__)
+ self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)")
+ globals()['NamedInt'] = NamedInt
+ globals()['NEI'] = NEI
+ NI5 = NamedInt('test', 5)
+ self.assertEqual(NI5, 5)
+ self.assertEqual(loads(dumps(NI5)), 5)
+ self.assertEqual(NEI.y.value, 2)
+ self.assertIs(loads(dumps(NEI.y)), NEI.y)
+
+ def test_subclasses_without_getnewargs(self):
+ class NamedInt(int):
+ def __new__(cls, *args):
+ _args = args
+ name, *args = args
+ if len(args) == 0:
+ raise TypeError("name and value must be specified")
+ self = int.__new__(cls, *args)
+ self._intname = name
+ self._args = _args
+ return self
+ @property
+ def __name__(self):
+ return self._intname
+ def __repr__(self):
+ # repr() is updated to include the name and type info
+ return "{}({!r}, {})".format(type(self).__name__,
+ self.__name__,
+ int.__repr__(self))
+ def __str__(self):
+ # str() is unchanged, even if it relies on the repr() fallback
+ base = int
+ base_str = base.__str__
+ if base_str.__objclass__ is object:
+ return base.__repr__(self)
+ return base_str(self)
+ # for simplicity, we only define one operator that
+ # propagates expressions
+ def __add__(self, other):
+ temp = int(self) + int( other)
+ if isinstance(self, NamedInt) and isinstance(other, NamedInt):
+ return NamedInt(
+ '({0} + {1})'.format(self.__name__, other.__name__),
+ temp )
+ else:
+ return temp
+
+ class NEI(NamedInt, Enum):
+ x = ('the-x', 1)
+ y = ('the-y', 2)
+
+ self.assertIs(NEI.__new__, Enum.__new__)
+ self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)")
+ globals()['NamedInt'] = NamedInt
+ globals()['NEI'] = NEI
+ NI5 = NamedInt('test', 5)
+ self.assertEqual(NI5, 5)
+ self.assertEqual(NEI.y.value, 2)
+ with self.assertRaises(TypeError):
+ dumps(NEI.x)
+ with self.assertRaises(PicklingError):
+ dumps(NEI)
+
+ def test_tuple_subclass(self):
+ class SomeTuple(tuple, Enum):
+ first = (1, 'for the money')
+ second = (2, 'for the show')
+ third = (3, 'for the music')
+ self.assertIs(type(SomeTuple.first), SomeTuple)
+ self.assertIsInstance(SomeTuple.second, tuple)
+ self.assertEqual(SomeTuple.third, (3, 'for the music'))
+ globals()['SomeTuple'] = SomeTuple
+ self.assertIs(loads(dumps(SomeTuple.first)), SomeTuple.first)
+
+ def test_duplicate_values_give_unique_enum_items(self):
+ class AutoNumber(Enum):
+ first = ()
+ second = ()
+ third = ()
+ def __new__(cls):
+ value = len(cls.__members__) + 1
+ obj = object.__new__(cls)
+ obj._value_ = value
+ return obj
+ def __int__(self):
+ return int(self._value_)
+ self.assertEqual(
+ list(AutoNumber),
+ [AutoNumber.first, AutoNumber.second, AutoNumber.third],
+ )
+ self.assertEqual(int(AutoNumber.second), 2)
+ self.assertEqual(AutoNumber.third.value, 3)
+ self.assertIs(AutoNumber(1), AutoNumber.first)
+
+ def test_inherited_new_from_enhanced_enum(self):
+ class AutoNumber(Enum):
+ def __new__(cls):
+ value = len(cls.__members__) + 1
+ obj = object.__new__(cls)
+ obj._value_ = value
+ return obj
+ def __int__(self):
+ return int(self._value_)
+ class Color(AutoNumber):
+ red = ()
+ green = ()
+ blue = ()
+ self.assertEqual(list(Color), [Color.red, Color.green, Color.blue])
+ self.assertEqual(list(map(int, Color)), [1, 2, 3])
+
+ def test_inherited_new_from_mixed_enum(self):
+ class AutoNumber(IntEnum):
+ def __new__(cls):
+ value = len(cls.__members__) + 1
+ obj = int.__new__(cls, value)
+ obj._value_ = value
+ return obj
+ class Color(AutoNumber):
+ red = ()
+ green = ()
+ blue = ()
+ self.assertEqual(list(Color), [Color.red, Color.green, Color.blue])
+ self.assertEqual(list(map(int, Color)), [1, 2, 3])
+
+ def test_equality(self):
+ class AlwaysEqual:
+ def __eq__(self, other):
+ return True
+ class OrdinaryEnum(Enum):
+ a = 1
+ self.assertEqual(AlwaysEqual(), OrdinaryEnum.a)
+ self.assertEqual(OrdinaryEnum.a, AlwaysEqual())
+
+ def test_ordered_mixin(self):
+ class OrderedEnum(Enum):
+ def __ge__(self, other):
+ if self.__class__ is other.__class__:
+ return self._value_ >= other._value_
+ return NotImplemented
+ def __gt__(self, other):
+ if self.__class__ is other.__class__:
+ return self._value_ > other._value_
+ return NotImplemented
+ def __le__(self, other):
+ if self.__class__ is other.__class__:
+ return self._value_ <= other._value_
+ return NotImplemented
+ def __lt__(self, other):
+ if self.__class__ is other.__class__:
+ return self._value_ < other._value_
+ return NotImplemented
+ class Grade(OrderedEnum):
+ A = 5
+ B = 4
+ C = 3
+ D = 2
+ F = 1
+ self.assertGreater(Grade.A, Grade.B)
+ self.assertLessEqual(Grade.F, Grade.C)
+ self.assertLess(Grade.D, Grade.A)
+ self.assertGreaterEqual(Grade.B, Grade.B)
+ self.assertEqual(Grade.B, Grade.B)
+ self.assertNotEqual(Grade.C, Grade.D)
+
+ def test_extending2(self):
+ class Shade(Enum):
+ def shade(self):
+ print(self.name)
+ class Color(Shade):
+ red = 1
+ green = 2
+ blue = 3
+ with self.assertRaises(TypeError):
+ class MoreColor(Color):
+ cyan = 4
+ magenta = 5
+ yellow = 6
+
+ def test_extending3(self):
+ class Shade(Enum):
+ def shade(self):
+ return self.name
+ class Color(Shade):
+ def hex(self):
+ return '%s hexlified!' % self.value
+ class MoreColor(Color):
+ cyan = 4
+ magenta = 5
+ yellow = 6
+ self.assertEqual(MoreColor.magenta.hex(), '5 hexlified!')
+
+
+ def test_no_duplicates(self):
+ class UniqueEnum(Enum):
+ def __init__(self, *args):
+ cls = self.__class__
+ if any(self.value == e.value for e in cls):
+ a = self.name
+ e = cls(self.value).name
+ raise ValueError(
+ "aliases not allowed in UniqueEnum: %r --> %r"
+ % (a, e)
+ )
+ class Color(UniqueEnum):
+ red = 1
+ green = 2
+ blue = 3
+ with self.assertRaises(ValueError):
+ class Color(UniqueEnum):
+ red = 1
+ green = 2
+ blue = 3
+ grene = 2
+
+ def test_init(self):
+ class Planet(Enum):
+ MERCURY = (3.303e+23, 2.4397e6)
+ VENUS = (4.869e+24, 6.0518e6)
+ EARTH = (5.976e+24, 6.37814e6)
+ MARS = (6.421e+23, 3.3972e6)
+ JUPITER = (1.9e+27, 7.1492e7)
+ SATURN = (5.688e+26, 6.0268e7)
+ URANUS = (8.686e+25, 2.5559e7)
+ NEPTUNE = (1.024e+26, 2.4746e7)
+ def __init__(self, mass, radius):
+ self.mass = mass # in kilograms
+ self.radius = radius # in meters
+ @property
+ def surface_gravity(self):
+ # universal gravitational constant (m3 kg-1 s-2)
+ G = 6.67300E-11
+ return G * self.mass / (self.radius * self.radius)
+ self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80)
+ self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6))
+
+ def test_nonhash_value(self):
+ class AutoNumberInAList(Enum):
+ def __new__(cls):
+ value = [len(cls.__members__) + 1]
+ obj = object.__new__(cls)
+ obj._value_ = value
+ return obj
+ class ColorInAList(AutoNumberInAList):
+ red = ()
+ green = ()
+ blue = ()
+ self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue])
+ for enum, value in zip(ColorInAList, range(3)):
+ value += 1
+ self.assertEqual(enum.value, [value])
+ self.assertIs(ColorInAList([value]), enum)
+
+ def test_conflicting_types_resolved_in_new(self):
+ class LabelledIntEnum(int, Enum):
+ def __new__(cls, *args):
+ value, label = args
+ obj = int.__new__(cls, value)
+ obj.label = label
+ obj._value_ = value
+ return obj
+
+ class LabelledList(LabelledIntEnum):
+ unprocessed = (1, "Unprocessed")
+ payment_complete = (2, "Payment Complete")
+
+ self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete])
+ self.assertEqual(LabelledList.unprocessed, 1)
+ self.assertEqual(LabelledList(1), LabelledList.unprocessed)
+
+
+class TestUnique(unittest.TestCase):
+
+ def test_unique_clean(self):
+ @unique
+ class Clean(Enum):
+ one = 1
+ two = 'dos'
+ tres = 4.0
+ @unique
+ class Cleaner(IntEnum):
+ single = 1
+ double = 2
+ triple = 3
+
+ def test_unique_dirty(self):
+ with self.assertRaisesRegex(ValueError, 'tres.*one'):
+ @unique
+ class Dirty(Enum):
+ one = 1
+ two = 'dos'
+ tres = 1
+ with self.assertRaisesRegex(
+ ValueError,
+ 'double.*single.*turkey.*triple',
+ ):
+ @unique
+ class Dirtier(IntEnum):
+ single = 1
+ double = 1
+ triple = 3
+ turkey = 3
+
+
+expected_help_output = """
+Help on class Color in module %s:
+
+class Color(enum.Enum)
+ | Method resolution order:
+ | Color
+ | enum.Enum
+ | builtins.object
+ |\x20\x20
+ | Data and other attributes defined here:
+ |\x20\x20
+ | blue = <Color.blue: 3>
+ |\x20\x20
+ | green = <Color.green: 2>
+ |\x20\x20
+ | red = <Color.red: 1>
+ |\x20\x20
+ | ----------------------------------------------------------------------
+ | Data descriptors inherited from enum.Enum:
+ |\x20\x20
+ | name
+ | The name of the Enum member.
+ |\x20\x20
+ | value
+ | The value of the Enum member.
+ |\x20\x20
+ | ----------------------------------------------------------------------
+ | Data descriptors inherited from enum.EnumMeta:
+ |\x20\x20
+ | __members__
+ | Returns a mapping of member name->value.
+ |\x20\x20\x20\x20\x20\x20
+ | This mapping lists all enum members, including aliases. Note that this
+ | is a read-only view of the internal mapping.
+""".strip()
+
+class TestStdLib(unittest.TestCase):
+
+ class Color(Enum):
+ red = 1
+ green = 2
+ blue = 3
+
+ def test_pydoc(self):
+ # indirectly test __objclass__
+ expected_text = expected_help_output % __name__
+ output = StringIO()
+ helper = pydoc.Helper(output=output)
+ helper(self.Color)
+ result = output.getvalue().strip()
+ if result != expected_text:
+ print_diffs(expected_text, result)
+ self.fail("outputs are not equal, see diff above")
+
+ def test_inspect_getmembers(self):
+ values = dict((
+ ('__class__', EnumMeta),
+ ('__doc__', None),
+ ('__members__', self.Color.__members__),
+ ('__module__', __name__),
+ ('blue', self.Color.blue),
+ ('green', self.Color.green),
+ ('name', Enum.__dict__['name']),
+ ('red', self.Color.red),
+ ('value', Enum.__dict__['value']),
+ ))
+ result = dict(inspect.getmembers(self.Color))
+ self.assertEqual(values.keys(), result.keys())
+ failed = False
+ for k in values.keys():
+ if result[k] != values[k]:
+ print()
+ print('\n%s\n key: %s\n result: %s\nexpected: %s\n%s\n' %
+ ('=' * 75, k, result[k], values[k], '=' * 75), sep='')
+ failed = True
+ if failed:
+ self.fail("result does not equal expected, see print above")
+
+ def test_inspect_classify_class_attrs(self):
+ # indirectly test __objclass__
+ from inspect import Attribute
+ values = [
+ Attribute(name='__class__', kind='data',
+ defining_class=object, object=EnumMeta),
+ Attribute(name='__doc__', kind='data',
+ defining_class=self.Color, object=None),
+ Attribute(name='__members__', kind='property',
+ defining_class=EnumMeta, object=EnumMeta.__members__),
+ Attribute(name='__module__', kind='data',
+ defining_class=self.Color, object=__name__),
+ Attribute(name='blue', kind='data',
+ defining_class=self.Color, object=self.Color.blue),
+ Attribute(name='green', kind='data',
+ defining_class=self.Color, object=self.Color.green),
+ Attribute(name='red', kind='data',
+ defining_class=self.Color, object=self.Color.red),
+ Attribute(name='name', kind='data',
+ defining_class=Enum, object=Enum.__dict__['name']),
+ Attribute(name='value', kind='data',
+ defining_class=Enum, object=Enum.__dict__['value']),
+ ]
+ values.sort(key=lambda item: item.name)
+ result = list(inspect.classify_class_attrs(self.Color))
+ result.sort(key=lambda item: item.name)
+ failed = False
+ for v, r in zip(values, result):
+ if r != v:
+ print('\n%s\n%s\n%s\n%s\n' % ('=' * 75, r, v, '=' * 75), sep='')
+ failed = True
+ if failed:
+ self.fail("result does not equal expected, see print above")
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_enumerate.py b/Lib/test/test_enumerate.py
index 4af217b733..8742afcea8 100644
--- a/Lib/test/test_enumerate.py
+++ b/Lib/test/test_enumerate.py
@@ -1,4 +1,5 @@
import unittest
+import operator
import sys
import pickle
@@ -168,15 +169,12 @@ class TestReversed(unittest.TestCase, PickleTest):
x = range(1)
self.assertEqual(type(reversed(x)), type(iter(x)))
- @support.cpython_only
def test_len(self):
- # This is an implementation detail, not an interface requirement
- from test.test_iterlen import len
for s in ('hello', tuple('hello'), list('hello'), range(5)):
- self.assertEqual(len(reversed(s)), len(s))
+ self.assertEqual(operator.length_hint(reversed(s)), len(s))
r = reversed(s)
list(r)
- self.assertEqual(len(r), 0)
+ self.assertEqual(operator.length_hint(r), 0)
class SeqWithWeirdLen:
called = False
def __len__(self):
@@ -187,7 +185,7 @@ class TestReversed(unittest.TestCase, PickleTest):
def __getitem__(self, index):
return index
r = reversed(SeqWithWeirdLen())
- self.assertRaises(ZeroDivisionError, len, r)
+ self.assertRaises(ZeroDivisionError, operator.length_hint, r)
def test_gc(self):
diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py
index 7f9547ff95..6459fbafc8 100644
--- a/Lib/test/test_epoll.py
+++ b/Lib/test/test_epoll.py
@@ -21,10 +21,11 @@
"""
Tests for epoll wrapper.
"""
-import socket
import errno
-import time
+import os
import select
+import socket
+import time
import unittest
from test import support
@@ -33,7 +34,7 @@ if not hasattr(select, "epoll"):
try:
select.epoll()
-except IOError as e:
+except OSError as e:
if e.errno == errno.ENOSYS:
raise unittest.SkipTest("kernel doesn't support epoll()")
raise
@@ -56,7 +57,7 @@ class TestEPoll(unittest.TestCase):
client.setblocking(False)
try:
client.connect(('127.0.0.1', self.serverSocket.getsockname()[1]))
- except socket.error as e:
+ except OSError as e:
self.assertEqual(e.args[0], errno.EINPROGRESS)
else:
raise AssertionError("Connect should have raised EINPROGRESS")
@@ -87,6 +88,13 @@ class TestEPoll(unittest.TestCase):
self.assertRaises(TypeError, select.epoll, ['foo'])
self.assertRaises(TypeError, select.epoll, {})
+ def test_context_manager(self):
+ with select.epoll(16) as ep:
+ self.assertGreater(ep.fileno(), 0)
+ self.assertFalse(ep.closed)
+ self.assertTrue(ep.closed)
+ self.assertRaises(ValueError, ep.fileno)
+
def test_add(self):
server, client = self._connected_pair()
@@ -115,12 +123,12 @@ class TestEPoll(unittest.TestCase):
# ValueError: file descriptor cannot be a negative integer (-1)
self.assertRaises(ValueError, ep.register, -1,
select.EPOLLIN | select.EPOLLOUT)
- # IOError: [Errno 9] Bad file descriptor
- self.assertRaises(IOError, ep.register, 10000,
+ # OSError: [Errno 9] Bad file descriptor
+ self.assertRaises(OSError, ep.register, 10000,
select.EPOLLIN | select.EPOLLOUT)
# registering twice also raises an exception
ep.register(server, select.EPOLLIN | select.EPOLLOUT)
- self.assertRaises(IOError, ep.register, server,
+ self.assertRaises(OSError, ep.register, server,
select.EPOLLIN | select.EPOLLOUT)
finally:
ep.close()
@@ -142,7 +150,7 @@ class TestEPoll(unittest.TestCase):
ep.close()
try:
ep2.poll(1, 4)
- except IOError as e:
+ except OSError as e:
self.assertEqual(e.args[0], errno.EBADF, e)
else:
self.fail("epoll on closed fd didn't raise EBADF")
@@ -218,6 +226,36 @@ class TestEPoll(unittest.TestCase):
server.close()
ep.unregister(fd)
+ def test_close(self):
+ open_file = open(__file__, "rb")
+ self.addCleanup(open_file.close)
+ fd = open_file.fileno()
+ epoll = select.epoll()
+
+ # test fileno() method and closed attribute
+ self.assertIsInstance(epoll.fileno(), int)
+ self.assertFalse(epoll.closed)
+
+ # test close()
+ epoll.close()
+ self.assertTrue(epoll.closed)
+ self.assertRaises(ValueError, epoll.fileno)
+
+ # close() can be called more than once
+ epoll.close()
+
+ # operations must fail with ValueError("I/O operation on closed ...")
+ self.assertRaises(ValueError, epoll.modify, fd, select.EPOLLIN)
+ self.assertRaises(ValueError, epoll.poll, 1.0)
+ self.assertRaises(ValueError, epoll.register, fd, select.EPOLLIN)
+ self.assertRaises(ValueError, epoll.unregister, fd)
+
+ def test_fd_non_inheritable(self):
+ epoll = select.epoll()
+ self.addCleanup(epoll.close)
+ self.assertEqual(os.get_inheritable(epoll.fileno()), False)
+
+
def test_main():
support.run_unittest(TestEPoll)
diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py
index 1ad7f97b74..f0851bdbe2 100644
--- a/Lib/test/test_exceptions.py
+++ b/Lib/test/test_exceptions.py
@@ -8,8 +8,8 @@ import weakref
import errno
from test.support import (TESTFN, captured_output, check_impl_detail,
- cpython_only, gc_collect, run_unittest, no_tracing,
- unlink)
+ check_warnings, cpython_only, gc_collect, run_unittest,
+ no_tracing, unlink)
class NaiveException(Exception):
def __init__(self, x):
@@ -244,22 +244,22 @@ class ExceptionTests(unittest.TestCase):
{'args' : ('foo', 1)}),
(SystemExit, ('foo',),
{'args' : ('foo',), 'code' : 'foo'}),
- (IOError, ('foo',),
+ (OSError, ('foo',),
{'args' : ('foo',), 'filename' : None,
'errno' : None, 'strerror' : None}),
- (IOError, ('foo', 'bar'),
+ (OSError, ('foo', 'bar'),
{'args' : ('foo', 'bar'), 'filename' : None,
'errno' : 'foo', 'strerror' : 'bar'}),
- (IOError, ('foo', 'bar', 'baz'),
+ (OSError, ('foo', 'bar', 'baz'),
{'args' : ('foo', 'bar'), 'filename' : 'baz',
'errno' : 'foo', 'strerror' : 'bar'}),
- (IOError, ('foo', 'bar', 'baz', 'quux'),
+ (OSError, ('foo', 'bar', 'baz', 'quux'),
{'args' : ('foo', 'bar', 'baz', 'quux')}),
- (EnvironmentError, ('errnoStr', 'strErrorStr', 'filenameStr'),
+ (OSError, ('errnoStr', 'strErrorStr', 'filenameStr'),
{'args' : ('errnoStr', 'strErrorStr'),
'strerror' : 'strErrorStr', 'errno' : 'errnoStr',
'filename' : 'filenameStr'}),
- (EnvironmentError, (1, 'strErrorStr', 'filenameStr'),
+ (OSError, (1, 'strErrorStr', 'filenameStr'),
{'args' : (1, 'strErrorStr'), 'errno' : 1,
'strerror' : 'strErrorStr', 'filename' : 'filenameStr'}),
(SyntaxError, (), {'msg' : None, 'text' : None,
@@ -409,7 +409,7 @@ class ExceptionTests(unittest.TestCase):
self.assertIsNone(e.__context__)
self.assertIsNone(e.__cause__)
- class MyException(EnvironmentError):
+ class MyException(OSError):
pass
e = MyException()
@@ -947,13 +947,11 @@ class ImportErrorTests(unittest.TestCase):
def test_non_str_argument(self):
# Issue #15778
- arg = b'abc'
- exc = ImportError(arg)
- self.assertEqual(str(arg), str(exc))
+ with check_warnings(('', BytesWarning), quiet=True):
+ arg = b'abc'
+ exc = ImportError(arg)
+ self.assertEqual(str(arg), str(exc))
-def test_main():
- run_unittest(ExceptionTests, ImportErrorTests)
-
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py
index 770e70c77d..ebd99edc8f 100644
--- a/Lib/test/test_faulthandler.py
+++ b/Lib/test/test_faulthandler.py
@@ -19,18 +19,6 @@ except ImportError:
TIMEOUT = 0.5
-try:
- from resource import setrlimit, RLIMIT_CORE, error as resource_error
-except ImportError:
- prepare_subprocess = None
-else:
- def prepare_subprocess():
- # don't create core file
- try:
- setrlimit(RLIMIT_CORE, (0, 0))
- except (ValueError, resource_error):
- pass
-
def expected_traceback(lineno1, lineno2, header, min_count=1):
regex = header
regex += ' File "<string>", line %s in func\n' % lineno1
@@ -59,10 +47,8 @@ class FaultHandlerTests(unittest.TestCase):
build, and replace "Current thread 0x00007f8d8fbd9700" by "Current
thread XXX".
"""
- options = {}
- if prepare_subprocess:
- options['preexec_fn'] = prepare_subprocess
- process = script_helper.spawn_python('-c', code, **options)
+ with support.SuppressCrashReport():
+ process = script_helper.spawn_python('-c', code)
stdout, stderr = process.communicate()
exitcode = process.wait()
output = support.strip_python_stderr(stdout)
@@ -86,9 +72,9 @@ class FaultHandlerTests(unittest.TestCase):
Raise an error if the output doesn't match the expected format.
"""
if all_threads:
- header = 'Current thread XXX'
+ header = 'Current thread XXX (most recent call first)'
else:
- header = 'Traceback (most recent call first)'
+ header = 'Stack (most recent call first)'
regex = """
^Fatal Python error: {name}
@@ -101,8 +87,7 @@ class FaultHandlerTests(unittest.TestCase):
header=re.escape(header))
if other_regex:
regex += '|' + other_regex
- with support.suppress_crash_popup():
- output, exitcode = self.get_output(code, filename)
+ output, exitcode = self.get_output(code, filename)
output = '\n'.join(output)
self.assertRegex(output, regex)
self.assertNotEqual(exitcode, 0)
@@ -232,8 +217,7 @@ faulthandler.disable()
faulthandler._sigsegv()
""".strip()
not_expected = 'Fatal Python error'
- with support.suppress_crash_popup():
- stderr, exitcode = self.get_output(code)
+ stderr, exitcode = self.get_output(code)
stder = '\n'.join(stderr)
self.assertTrue(not_expected not in stderr,
"%r is present in %r" % (not_expected, stderr))
@@ -264,16 +248,34 @@ faulthandler._sigsegv()
def test_disabled_by_default(self):
# By default, the module should be disabled
code = "import faulthandler; print(faulthandler.is_enabled())"
- rc, stdout, stderr = assert_python_ok("-c", code)
- stdout = (stdout + stderr).strip()
- self.assertEqual(stdout, b"False")
+ args = (sys.executable, '-E', '-c', code)
+ # don't use assert_python_ok() because it always enable faulthandler
+ output = subprocess.check_output(args)
+ self.assertEqual(output.rstrip(), b"False")
def test_sys_xoptions(self):
# Test python -X faulthandler
code = "import faulthandler; print(faulthandler.is_enabled())"
- rc, stdout, stderr = assert_python_ok("-X", "faulthandler", "-c", code)
- stdout = (stdout + stderr).strip()
- self.assertEqual(stdout, b"True")
+ args = (sys.executable, "-E", "-X", "faulthandler", "-c", code)
+ # don't use assert_python_ok() because it always enable faulthandler
+ output = subprocess.check_output(args)
+ self.assertEqual(output.rstrip(), b"True")
+
+ def test_env_var(self):
+ # empty env var
+ code = "import faulthandler; print(faulthandler.is_enabled())"
+ args = (sys.executable, "-c", code)
+ env = os.environ.copy()
+ env['PYTHONFAULTHANDLER'] = ''
+ # don't use assert_python_ok() because it always enable faulthandler
+ output = subprocess.check_output(args, env=env)
+ self.assertEqual(output.rstrip(), b"False")
+
+ # non-empty env var
+ env = os.environ.copy()
+ env['PYTHONFAULTHANDLER'] = '1'
+ output = subprocess.check_output(args, env=env)
+ self.assertEqual(output.rstrip(), b"True")
def check_dump_traceback(self, filename):
"""
@@ -304,7 +306,7 @@ funcA()
else:
lineno = 8
expected = [
- 'Traceback (most recent call first):',
+ 'Stack (most recent call first):',
' File "<string>", line %s in funcB' % lineno,
' File "<string>", line 11 in funcA',
' File "<string>", line 13 in <module>'
@@ -336,7 +338,7 @@ def {func_name}():
func_name=func_name,
)
expected = [
- 'Traceback (most recent call first):',
+ 'Stack (most recent call first):',
' File "<string>", line 4 in %s' % truncated,
' File "<string>", line 6 in <module>'
]
@@ -390,13 +392,13 @@ waiter.join()
else:
lineno = 10
regex = """
-^Thread 0x[0-9a-f]+:
+^Thread 0x[0-9a-f]+ \(most recent call first\):
(?: File ".*threading.py", line [0-9]+ in [_a-z]+
){{1,3}} File "<string>", line 23 in run
File ".*threading.py", line [0-9]+ in _bootstrap_inner
File ".*threading.py", line [0-9]+ in _bootstrap
-Current thread XXX:
+Current thread XXX \(most recent call first\):
File "<string>", line {lineno} in dump
File "<string>", line 28 in <module>$
""".strip()
@@ -459,7 +461,7 @@ if file is not None:
count = loops
if repeat:
count *= 2
- header = r'Timeout \(%s\)!\nThread 0x[0-9a-f]+:\n' % timeout_str
+ header = r'Timeout \(%s\)!\nThread 0x[0-9a-f]+ \(most recent call first\):\n' % timeout_str
regex = expected_traceback(9, 20, header, min_count=count)
self.assertRegex(trace, regex)
else:
@@ -561,9 +563,9 @@ sys.exit(exitcode)
trace = '\n'.join(trace)
if not unregister:
if all_threads:
- regex = 'Current thread XXX:\n'
+ regex = 'Current thread XXX \(most recent call first\):\n'
else:
- regex = 'Traceback \(most recent call first\):\n'
+ regex = 'Stack \(most recent call first\):\n'
regex = expected_traceback(7, 28, regex)
self.assertRegex(trace, regex)
else:
@@ -590,8 +592,5 @@ sys.exit(exitcode)
self.check_register(chain=True)
-def test_main():
- support.run_unittest(FaultHandlerTests)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_fcntl.py b/Lib/test/test_fcntl.py
index b8cda2f108..c816d970d7 100644
--- a/Lib/test/test_fcntl.py
+++ b/Lib/test/test_fcntl.py
@@ -1,7 +1,4 @@
"""Test program for the fcntl C module.
-
-OS/2+EMX doesn't support the file locking operations.
-
"""
import platform
import os
@@ -39,8 +36,6 @@ def get_lockdata():
lockdata = struct.pack('qqihhi', 0, 0, 0, fcntl.F_WRLCK, 0, 0)
elif sys.platform in ['aix3', 'aix4', 'hp-uxB', 'unixware7']:
lockdata = struct.pack('hhlllii', fcntl.F_WRLCK, 0, 0, 0, 0, 0, 0)
- elif sys.platform in ['os2emx']:
- lockdata = None
else:
lockdata = struct.pack('hh'+start_len+'hh', fcntl.F_WRLCK, 0, 0, 0, 0, 0)
if lockdata:
@@ -66,18 +61,20 @@ class TestFcntl(unittest.TestCase):
rv = fcntl.fcntl(self.f.fileno(), fcntl.F_SETFL, os.O_NONBLOCK)
if verbose:
print('Status from fcntl with O_NONBLOCK: ', rv)
- if sys.platform not in ['os2emx']:
- rv = fcntl.fcntl(self.f.fileno(), fcntl.F_SETLKW, lockdata)
- if verbose:
- print('String from fcntl with F_SETLKW: ', repr(rv))
+ rv = fcntl.fcntl(self.f.fileno(), fcntl.F_SETLKW, lockdata)
+ if verbose:
+ print('String from fcntl with F_SETLKW: ', repr(rv))
self.f.close()
def test_fcntl_file_descriptor(self):
# again, but pass the file rather than numeric descriptor
self.f = open(TESTFN, 'wb')
rv = fcntl.fcntl(self.f, fcntl.F_SETFL, os.O_NONBLOCK)
- if sys.platform not in ['os2emx']:
- rv = fcntl.fcntl(self.f, fcntl.F_SETLKW, lockdata)
+ if verbose:
+ print('Status from fcntl with O_NONBLOCK: ', rv)
+ rv = fcntl.fcntl(self.f, fcntl.F_SETLKW, lockdata)
+ if verbose:
+ print('String from fcntl with F_SETLKW: ', repr(rv))
self.f.close()
def test_fcntl_bad_file(self):
diff --git a/Lib/test/test_file.py b/Lib/test/test_file.py
index 76b1694e98..d54e976143 100644
--- a/Lib/test/test_file.py
+++ b/Lib/test/test_file.py
@@ -87,7 +87,7 @@ class AutoFileTests:
self.assertTrue(not f.closed)
if hasattr(f, "readinto"):
- self.assertRaises((IOError, TypeError), f.readinto, "")
+ self.assertRaises((OSError, TypeError), f.readinto, "")
f.close()
self.assertTrue(f.closed)
@@ -126,7 +126,7 @@ class AutoFileTests:
self.assertEqual(self.f.__exit__(*sys.exc_info()), None)
def testReadWhenWriting(self):
- self.assertRaises(IOError, self.f.read)
+ self.assertRaises(OSError, self.f.read)
class CAutoFileTests(AutoFileTests, unittest.TestCase):
open = io.open
@@ -177,7 +177,7 @@ class OtherFileTests:
d = int(f.read().decode("ascii"))
f.close()
f.close()
- except IOError as msg:
+ except OSError as msg:
self.fail('error setting buffer size %d: %s' % (s, str(msg)))
self.assertEqual(d, s)
diff --git a/Lib/test/test_filecmp.py b/Lib/test/test_filecmp.py
index 09599794ad..c7cb3fcf36 100644
--- a/Lib/test/test_filecmp.py
+++ b/Lib/test/test_filecmp.py
@@ -1,4 +1,3 @@
-
import os, filecmp, shutil, tempfile
import unittest
from test import support
@@ -40,15 +39,27 @@ class FileCompareTestCase(unittest.TestCase):
self.assertFalse(filecmp.cmp(self.name, self.dir),
"File and directory compare as equal")
+ def test_cache_clear(self):
+ first_compare = filecmp.cmp(self.name, self.name_same, shallow=False)
+ second_compare = filecmp.cmp(self.name, self.name_diff, shallow=False)
+ filecmp.clear_cache()
+ self.assertTrue(len(filecmp._cache) == 0,
+ "Cache not cleared after calling clear_cache")
+
class DirCompareTestCase(unittest.TestCase):
def setUp(self):
tmpdir = tempfile.gettempdir()
self.dir = os.path.join(tmpdir, 'dir')
self.dir_same = os.path.join(tmpdir, 'dir-same')
self.dir_diff = os.path.join(tmpdir, 'dir-diff')
+
+ # Another dir is created under dir_same, but it has a name from the
+ # ignored list so it should not affect testing results.
+ self.dir_ignored = os.path.join(self.dir_same, '.hg')
+
self.caseinsensitive = os.path.normcase('A') == os.path.normcase('a')
data = 'Contents of file go here.\n'
- for dir in [self.dir, self.dir_same, self.dir_diff]:
+ for dir in (self.dir, self.dir_same, self.dir_diff, self.dir_ignored):
shutil.rmtree(dir, True)
os.mkdir(dir)
if self.caseinsensitive and dir is self.dir_same:
@@ -64,9 +75,11 @@ class DirCompareTestCase(unittest.TestCase):
output.close()
def tearDown(self):
- shutil.rmtree(self.dir)
- shutil.rmtree(self.dir_same)
- shutil.rmtree(self.dir_diff)
+ for dir in (self.dir, self.dir_same, self.dir_diff):
+ shutil.rmtree(dir)
+
+ def test_default_ignores(self):
+ self.assertIn('.hg', filecmp.DEFAULT_IGNORES)
def test_cmpfiles(self):
self.assertTrue(filecmp.cmpfiles(self.dir, self.dir, ['file']) ==
diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py
index 1e70641150..c5e57d4715 100644
--- a/Lib/test/test_fileinput.py
+++ b/Lib/test/test_fileinput.py
@@ -275,8 +275,8 @@ class FileInputTests(unittest.TestCase):
try:
t1 = writeTmp(1, [""])
with FileInput(files=t1) as fi:
- raise IOError
- except IOError:
+ raise OSError
+ except OSError:
self.assertEqual(fi._files, ())
finally:
remove_tempfiles(t1)
@@ -835,22 +835,6 @@ class Test_hook_encoded(unittest.TestCase):
self.assertIs(kwargs.pop('encoding'), encoding)
self.assertFalse(kwargs)
-def test_main():
- run_unittest(
- BufferSizesTests,
- FileInputTests,
- Test_fileinput_input,
- Test_fileinput_close,
- Test_fileinput_nextfile,
- Test_fileinput_filename,
- Test_fileinput_lineno,
- Test_fileinput_filelineno,
- Test_fileinput_fileno,
- Test_fileinput_isfirstline,
- Test_fileinput_isstdin,
- Test_hook_compressed,
- Test_hook_encoded,
- )
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py
index f91ccaa427..e67876985e 100644
--- a/Lib/test/test_fileio.py
+++ b/Lib/test/test_fileio.py
@@ -145,16 +145,16 @@ class AutoFileTests(unittest.TestCase):
# Unix calls dircheck() and returns "[Errno 21]: Is a directory"
try:
_FileIO('.', 'r')
- except IOError as e:
+ except OSError as e:
self.assertNotEqual(e.errno, 0)
self.assertEqual(e.filename, ".")
else:
- self.fail("Should have raised IOError")
+ self.fail("Should have raised OSError")
@unittest.skipIf(os.name == 'nt', "test only works on a POSIX-like system")
def testOpenDirFD(self):
fd = os.open('.', os.O_RDONLY)
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(OSError) as cm:
_FileIO(fd, 'r')
os.close(fd)
self.assertEqual(cm.exception.errno, errno.EISDIR)
@@ -172,7 +172,7 @@ class AutoFileTests(unittest.TestCase):
finally:
try:
self.f.close()
- except IOError:
+ except OSError:
pass
return wrapper
@@ -184,14 +184,14 @@ class AutoFileTests(unittest.TestCase):
os.close(f.fileno())
try:
func(self, f)
- except IOError as e:
+ except OSError as e:
self.assertEqual(e.errno, errno.EBADF)
else:
- self.fail("Should have raised IOError")
+ self.fail("Should have raised OSError")
finally:
try:
self.f.close()
- except IOError:
+ except OSError:
pass
return wrapper
@@ -238,7 +238,7 @@ class AutoFileTests(unittest.TestCase):
def ReopenForRead(self):
try:
self.f.close()
- except IOError:
+ except OSError:
pass
self.f = _FileIO(TESTFN, 'r')
os.close(self.f.fileno())
@@ -286,7 +286,7 @@ class OtherFileTests(unittest.TestCase):
if sys.platform != "win32":
try:
f = _FileIO("/dev/tty", "a")
- except EnvironmentError:
+ except OSError:
# When run in a cron job there just aren't any
# ttys, so skip the test. This also handles other
# OS'es that don't support /dev/tty.
@@ -362,7 +362,7 @@ class OtherFileTests(unittest.TestCase):
self.assertRaises(OSError, _FileIO, make_bad_fd())
if sys.platform == 'win32':
import msvcrt
- self.assertRaises(IOError, msvcrt.get_osfhandle, make_bad_fd())
+ self.assertRaises(OSError, msvcrt.get_osfhandle, make_bad_fd())
# Issue 15989
self.assertRaises(TypeError, _FileIO, _testcapi.INT_MAX + 1)
self.assertRaises(TypeError, _FileIO, _testcapi.INT_MIN - 1)
diff --git a/Lib/test/test_finalization.py b/Lib/test/test_finalization.py
new file mode 100644
index 0000000000..80c9b87988
--- /dev/null
+++ b/Lib/test/test_finalization.py
@@ -0,0 +1,513 @@
+"""
+Tests for object finalization semantics, as outlined in PEP 442.
+"""
+
+import contextlib
+import gc
+import unittest
+import weakref
+
+import _testcapi
+from test import support
+
+
+class NonGCSimpleBase:
+ """
+ The base class for all the objects under test, equipped with various
+ testing features.
+ """
+
+ survivors = []
+ del_calls = []
+ tp_del_calls = []
+ errors = []
+
+ _cleaning = False
+
+ __slots__ = ()
+
+ @classmethod
+ def _cleanup(cls):
+ cls.survivors.clear()
+ cls.errors.clear()
+ gc.garbage.clear()
+ gc.collect()
+ cls.del_calls.clear()
+ cls.tp_del_calls.clear()
+
+ @classmethod
+ @contextlib.contextmanager
+ def test(cls):
+ """
+ A context manager to use around all finalization tests.
+ """
+ with support.disable_gc():
+ cls.del_calls.clear()
+ cls.tp_del_calls.clear()
+ NonGCSimpleBase._cleaning = False
+ try:
+ yield
+ if cls.errors:
+ raise cls.errors[0]
+ finally:
+ NonGCSimpleBase._cleaning = True
+ cls._cleanup()
+
+ def check_sanity(self):
+ """
+ Check the object is sane (non-broken).
+ """
+
+ def __del__(self):
+ """
+ PEP 442 finalizer. Record that this was called, check the
+ object is in a sane state, and invoke a side effect.
+ """
+ try:
+ if not self._cleaning:
+ self.del_calls.append(id(self))
+ self.check_sanity()
+ self.side_effect()
+ except Exception as e:
+ self.errors.append(e)
+
+ def side_effect(self):
+ """
+ A side effect called on destruction.
+ """
+
+
+class SimpleBase(NonGCSimpleBase):
+
+ def __init__(self):
+ self.id_ = id(self)
+
+ def check_sanity(self):
+ assert self.id_ == id(self)
+
+
+class NonGC(NonGCSimpleBase):
+ __slots__ = ()
+
+class NonGCResurrector(NonGCSimpleBase):
+ __slots__ = ()
+
+ def side_effect(self):
+ """
+ Resurrect self by storing self in a class-wide list.
+ """
+ self.survivors.append(self)
+
+class Simple(SimpleBase):
+ pass
+
+class SimpleResurrector(NonGCResurrector, SimpleBase):
+ pass
+
+
+class TestBase:
+
+ def setUp(self):
+ self.old_garbage = gc.garbage[:]
+ gc.garbage[:] = []
+
+ def tearDown(self):
+ # None of the tests here should put anything in gc.garbage
+ try:
+ self.assertEqual(gc.garbage, [])
+ finally:
+ del self.old_garbage
+ gc.collect()
+
+ def assert_del_calls(self, ids):
+ self.assertEqual(sorted(SimpleBase.del_calls), sorted(ids))
+
+ def assert_tp_del_calls(self, ids):
+ self.assertEqual(sorted(SimpleBase.tp_del_calls), sorted(ids))
+
+ def assert_survivors(self, ids):
+ self.assertEqual(sorted(id(x) for x in SimpleBase.survivors), sorted(ids))
+
+ def assert_garbage(self, ids):
+ self.assertEqual(sorted(id(x) for x in gc.garbage), sorted(ids))
+
+ def clear_survivors(self):
+ SimpleBase.survivors.clear()
+
+
+class SimpleFinalizationTest(TestBase, unittest.TestCase):
+ """
+ Test finalization without refcycles.
+ """
+
+ def test_simple(self):
+ with SimpleBase.test():
+ s = Simple()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+
+ def test_simple_resurrect(self):
+ with SimpleBase.test():
+ s = SimpleResurrector()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors(ids)
+ self.assertIsNot(wr(), None)
+ self.clear_survivors()
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+
+ def test_non_gc(self):
+ with SimpleBase.test():
+ s = NonGC()
+ self.assertFalse(gc.is_tracked(s))
+ ids = [id(s)]
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+
+ def test_non_gc_resurrect(self):
+ with SimpleBase.test():
+ s = NonGCResurrector()
+ self.assertFalse(gc.is_tracked(s))
+ ids = [id(s)]
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors(ids)
+ self.clear_survivors()
+ gc.collect()
+ self.assert_del_calls(ids * 2)
+ self.assert_survivors(ids)
+
+
+class SelfCycleBase:
+
+ def __init__(self):
+ super().__init__()
+ self.ref = self
+
+ def check_sanity(self):
+ super().check_sanity()
+ assert self.ref is self
+
+class SimpleSelfCycle(SelfCycleBase, Simple):
+ pass
+
+class SelfCycleResurrector(SelfCycleBase, SimpleResurrector):
+ pass
+
+class SuicidalSelfCycle(SelfCycleBase, Simple):
+
+ def side_effect(self):
+ """
+ Explicitly break the reference cycle.
+ """
+ self.ref = None
+
+
+class SelfCycleFinalizationTest(TestBase, unittest.TestCase):
+ """
+ Test finalization of an object having a single cyclic reference to
+ itself.
+ """
+
+ def test_simple(self):
+ with SimpleBase.test():
+ s = SimpleSelfCycle()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+
+ def test_simple_resurrect(self):
+ # Test that __del__ can resurrect the object being finalized.
+ with SimpleBase.test():
+ s = SelfCycleResurrector()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors(ids)
+ # XXX is this desirable?
+ self.assertIs(wr(), None)
+ # When trying to destroy the object a second time, __del__
+ # isn't called anymore (and the object isn't resurrected).
+ self.clear_survivors()
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+
+ def test_simple_suicide(self):
+ # Test the GC is able to deal with an object that kills its last
+ # reference during __del__.
+ with SimpleBase.test():
+ s = SuicidalSelfCycle()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+
+
+class ChainedBase:
+
+ def chain(self, left):
+ self.suicided = False
+ self.left = left
+ left.right = self
+
+ def check_sanity(self):
+ super().check_sanity()
+ if self.suicided:
+ assert self.left is None
+ assert self.right is None
+ else:
+ left = self.left
+ if left.suicided:
+ assert left.right is None
+ else:
+ assert left.right is self
+ right = self.right
+ if right.suicided:
+ assert right.left is None
+ else:
+ assert right.left is self
+
+class SimpleChained(ChainedBase, Simple):
+ pass
+
+class ChainedResurrector(ChainedBase, SimpleResurrector):
+ pass
+
+class SuicidalChained(ChainedBase, Simple):
+
+ def side_effect(self):
+ """
+ Explicitly break the reference cycle.
+ """
+ self.suicided = True
+ self.left = None
+ self.right = None
+
+
+class CycleChainFinalizationTest(TestBase, unittest.TestCase):
+ """
+ Test finalization of a cyclic chain. These tests are similar in
+ spirit to the self-cycle tests above, but the collectable object
+ graph isn't trivial anymore.
+ """
+
+ def build_chain(self, classes):
+ nodes = [cls() for cls in classes]
+ for i in range(len(nodes)):
+ nodes[i].chain(nodes[i-1])
+ return nodes
+
+ def check_non_resurrecting_chain(self, classes):
+ N = len(classes)
+ with SimpleBase.test():
+ nodes = self.build_chain(classes)
+ ids = [id(s) for s in nodes]
+ wrs = [weakref.ref(s) for s in nodes]
+ del nodes
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+ self.assertEqual([wr() for wr in wrs], [None] * N)
+ gc.collect()
+ self.assert_del_calls(ids)
+
+ def check_resurrecting_chain(self, classes):
+ N = len(classes)
+ with SimpleBase.test():
+ nodes = self.build_chain(classes)
+ N = len(nodes)
+ ids = [id(s) for s in nodes]
+ survivor_ids = [id(s) for s in nodes if isinstance(s, SimpleResurrector)]
+ wrs = [weakref.ref(s) for s in nodes]
+ del nodes
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors(survivor_ids)
+ # XXX desirable?
+ self.assertEqual([wr() for wr in wrs], [None] * N)
+ self.clear_survivors()
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_survivors([])
+
+ def test_homogenous(self):
+ self.check_non_resurrecting_chain([SimpleChained] * 3)
+
+ def test_homogenous_resurrect(self):
+ self.check_resurrecting_chain([ChainedResurrector] * 3)
+
+ def test_homogenous_suicidal(self):
+ self.check_non_resurrecting_chain([SuicidalChained] * 3)
+
+ def test_heterogenous_suicidal_one(self):
+ self.check_non_resurrecting_chain([SuicidalChained, SimpleChained] * 2)
+
+ def test_heterogenous_suicidal_two(self):
+ self.check_non_resurrecting_chain(
+ [SuicidalChained] * 2 + [SimpleChained] * 2)
+
+ def test_heterogenous_resurrect_one(self):
+ self.check_resurrecting_chain([ChainedResurrector, SimpleChained] * 2)
+
+ def test_heterogenous_resurrect_two(self):
+ self.check_resurrecting_chain(
+ [ChainedResurrector, SimpleChained, SuicidalChained] * 2)
+
+ def test_heterogenous_resurrect_three(self):
+ self.check_resurrecting_chain(
+ [ChainedResurrector] * 2 + [SimpleChained] * 2 + [SuicidalChained] * 2)
+
+
+# NOTE: the tp_del slot isn't automatically inherited, so we have to call
+# with_tp_del() for each instantiated class.
+
+class LegacyBase(SimpleBase):
+
+ def __del__(self):
+ try:
+ # Do not invoke side_effect here, since we are now exercising
+ # the tp_del slot.
+ if not self._cleaning:
+ self.del_calls.append(id(self))
+ self.check_sanity()
+ except Exception as e:
+ self.errors.append(e)
+
+ def __tp_del__(self):
+ """
+ Legacy (pre-PEP 442) finalizer, mapped to a tp_del slot.
+ """
+ try:
+ if not self._cleaning:
+ self.tp_del_calls.append(id(self))
+ self.check_sanity()
+ self.side_effect()
+ except Exception as e:
+ self.errors.append(e)
+
+@_testcapi.with_tp_del
+class Legacy(LegacyBase):
+ pass
+
+@_testcapi.with_tp_del
+class LegacyResurrector(LegacyBase):
+
+ def side_effect(self):
+ """
+ Resurrect self by storing self in a class-wide list.
+ """
+ self.survivors.append(self)
+
+@_testcapi.with_tp_del
+class LegacySelfCycle(SelfCycleBase, LegacyBase):
+ pass
+
+
+class LegacyFinalizationTest(TestBase, unittest.TestCase):
+ """
+ Test finalization of objects with a tp_del.
+ """
+
+ def tearDown(self):
+ # These tests need to clean up a bit more, since they create
+ # uncollectable objects.
+ gc.garbage.clear()
+ gc.collect()
+ super().tearDown()
+
+ def test_legacy(self):
+ with SimpleBase.test():
+ s = Legacy()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_tp_del_calls(ids)
+ self.assert_survivors([])
+ self.assertIs(wr(), None)
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_tp_del_calls(ids)
+
+ def test_legacy_resurrect(self):
+ with SimpleBase.test():
+ s = LegacyResurrector()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_tp_del_calls(ids)
+ self.assert_survivors(ids)
+ # weakrefs are cleared before tp_del is called.
+ self.assertIs(wr(), None)
+ self.clear_survivors()
+ gc.collect()
+ self.assert_del_calls(ids)
+ self.assert_tp_del_calls(ids * 2)
+ self.assert_survivors(ids)
+ self.assertIs(wr(), None)
+
+ def test_legacy_self_cycle(self):
+ # Self-cycles with legacy finalizers end up in gc.garbage.
+ with SimpleBase.test():
+ s = LegacySelfCycle()
+ ids = [id(s)]
+ wr = weakref.ref(s)
+ del s
+ gc.collect()
+ self.assert_del_calls([])
+ self.assert_tp_del_calls([])
+ self.assert_survivors([])
+ self.assert_garbage(ids)
+ self.assertIsNot(wr(), None)
+ # Break the cycle to allow collection
+ gc.garbage[0].ref = None
+ self.assert_garbage([])
+ self.assertIs(wr(), None)
+
+
+def test_main():
+ support.run_unittest(__name__)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py
index 502292f615..5c2dc3ff3b 100644
--- a/Lib/test/test_float.py
+++ b/Lib/test/test_float.py
@@ -41,6 +41,7 @@ class GeneralFloatCases(unittest.TestCase):
self.assertRaises(ValueError, float, "-.")
self.assertRaises(ValueError, float, b"-")
self.assertRaises(TypeError, float, {})
+ self.assertRaisesRegex(TypeError, "not 'dict'", float, {})
# Lone surrogate
self.assertRaises(UnicodeEncodeError, float, '\uD8F0')
# check that we don't accept alternate exponent markers
diff --git a/Lib/test/test_fork1.py b/Lib/test/test_fork1.py
index 8192c38a44..e0626dffdc 100644
--- a/Lib/test/test_fork1.py
+++ b/Lib/test/test_fork1.py
@@ -1,7 +1,7 @@
"""This test checks for correct fork() behavior.
"""
-import imp
+import _imp as imp
import os
import signal
import sys
diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py
index bd159f5689..29330f9171 100644
--- a/Lib/test/test_format.py
+++ b/Lib/test/test_format.py
@@ -307,27 +307,41 @@ class FormatTest(unittest.TestCase):
finally:
locale.setlocale(locale.LC_ALL, oldloc)
+ @support.cpython_only
+ def test_optimisations(self):
+ text = "abcde" # 5 characters
+ self.assertIs("%s" % text, text)
+ self.assertIs("%.5s" % text, text)
+ self.assertIs("%.10s" % text, text)
+ self.assertIs("%1s" % text, text)
+ self.assertIs("%5s" % text, text)
-def test_main():
- support.run_unittest(FormatTest)
+ self.assertIs("{0}".format(text), text)
+ self.assertIs("{0:s}".format(text), text)
+ self.assertIs("{0:.5s}".format(text), text)
+ self.assertIs("{0:.10s}".format(text), text)
+ self.assertIs("{0:1s}".format(text), text)
+ self.assertIs("{0:5s}".format(text), text)
+ self.assertIs(text % (), text)
+ self.assertIs(text.format(), text)
+
+ @support.cpython_only
def test_precision(self):
- INT_MAX = 2147483647
+ from _testcapi import INT_MAX
f = 1.2
self.assertEqual(format(f, ".0f"), "1")
self.assertEqual(format(f, ".3f"), "1.200")
with self.assertRaises(ValueError) as cm:
format(f, ".%sf" % (INT_MAX + 1))
- self.assertEqual(str(cm.exception), "precision too big")
c = complex(f)
- self.assertEqual(format(f, ".0f"), "1")
- self.assertEqual(format(f, ".3f"), "1.200")
+ self.assertEqual(format(c, ".0f"), "1+0j")
+ self.assertEqual(format(c, ".3f"), "1.200+0.000j")
with self.assertRaises(ValueError) as cm:
- format(f, ".%sf" % (INT_MAX + 1))
- self.assertEqual(str(cm.exception), "precision too big")
+ format(c, ".%sf" % (INT_MAX + 1))
if __name__ == "__main__":
diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py
index 1fad9215c0..33365322b9 100644
--- a/Lib/test/test_fractions.py
+++ b/Lib/test/test_fractions.py
@@ -146,9 +146,10 @@ class FractionTest(unittest.TestCase):
self.assertEqual((0, 1), _components(F(-0.0)))
self.assertEqual((3602879701896397, 36028797018963968),
_components(F(0.1)))
- self.assertRaises(TypeError, F, float('nan'))
- self.assertRaises(TypeError, F, float('inf'))
- self.assertRaises(TypeError, F, float('-inf'))
+ # bug 16469: error types should be consistent with float -> int
+ self.assertRaises(ValueError, F, float('nan'))
+ self.assertRaises(OverflowError, F, float('inf'))
+ self.assertRaises(OverflowError, F, float('-inf'))
def testInitFromDecimal(self):
self.assertEqual((11, 10),
@@ -157,10 +158,11 @@ class FractionTest(unittest.TestCase):
_components(F(Decimal('3.5e-2'))))
self.assertEqual((0, 1),
_components(F(Decimal('.000e20'))))
- self.assertRaises(TypeError, F, Decimal('nan'))
- self.assertRaises(TypeError, F, Decimal('snan'))
- self.assertRaises(TypeError, F, Decimal('inf'))
- self.assertRaises(TypeError, F, Decimal('-inf'))
+ # bug 16469: error types should be consistent with decimal -> int
+ self.assertRaises(ValueError, F, Decimal('nan'))
+ self.assertRaises(ValueError, F, Decimal('snan'))
+ self.assertRaises(OverflowError, F, Decimal('inf'))
+ self.assertRaises(OverflowError, F, Decimal('-inf'))
def testFromString(self):
self.assertEqual((5, 1), _components(F("5")))
@@ -248,14 +250,15 @@ class FractionTest(unittest.TestCase):
inf = 1e1000
nan = inf - inf
+ # bug 16469: error types should be consistent with float -> int
self.assertRaisesMessage(
- TypeError, "Cannot convert inf to Fraction.",
+ OverflowError, "Cannot convert inf to Fraction.",
F.from_float, inf)
self.assertRaisesMessage(
- TypeError, "Cannot convert -inf to Fraction.",
+ OverflowError, "Cannot convert -inf to Fraction.",
F.from_float, -inf)
self.assertRaisesMessage(
- TypeError, "Cannot convert nan to Fraction.",
+ ValueError, "Cannot convert nan to Fraction.",
F.from_float, nan)
def testFromDecimal(self):
@@ -268,17 +271,18 @@ class FractionTest(unittest.TestCase):
self.assertEqual(1 - F(1, 10**30),
F.from_decimal(Decimal("0." + "9" * 30)))
+ # bug 16469: error types should be consistent with decimal -> int
self.assertRaisesMessage(
- TypeError, "Cannot convert Infinity to Fraction.",
+ OverflowError, "Cannot convert Infinity to Fraction.",
F.from_decimal, Decimal("inf"))
self.assertRaisesMessage(
- TypeError, "Cannot convert -Infinity to Fraction.",
+ OverflowError, "Cannot convert -Infinity to Fraction.",
F.from_decimal, Decimal("-inf"))
self.assertRaisesMessage(
- TypeError, "Cannot convert NaN to Fraction.",
+ ValueError, "Cannot convert NaN to Fraction.",
F.from_decimal, Decimal("nan"))
self.assertRaisesMessage(
- TypeError, "Cannot convert sNaN to Fraction.",
+ ValueError, "Cannot convert sNaN to Fraction.",
F.from_decimal, Decimal("snan"))
def testLimitDenominator(self):
diff --git a/Lib/test/test_frame.py b/Lib/test/test_frame.py
new file mode 100644
index 0000000000..2dd5780c88
--- /dev/null
+++ b/Lib/test/test_frame.py
@@ -0,0 +1,116 @@
+import gc
+import sys
+import unittest
+import weakref
+
+from test import support
+
+
+class ClearTest(unittest.TestCase):
+ """
+ Tests for frame.clear().
+ """
+
+ def inner(self, x=5, **kwargs):
+ 1/0
+
+ def outer(self, **kwargs):
+ try:
+ self.inner(**kwargs)
+ except ZeroDivisionError as e:
+ exc = e
+ return exc
+
+ def clear_traceback_frames(self, tb):
+ """
+ Clear all frames in a traceback.
+ """
+ while tb is not None:
+ tb.tb_frame.clear()
+ tb = tb.tb_next
+
+ def test_clear_locals(self):
+ class C:
+ pass
+ c = C()
+ wr = weakref.ref(c)
+ exc = self.outer(c=c)
+ del c
+ support.gc_collect()
+ # A reference to c is held through the frames
+ self.assertIsNot(None, wr())
+ self.clear_traceback_frames(exc.__traceback__)
+ support.gc_collect()
+ # The reference was released by .clear()
+ self.assertIs(None, wr())
+
+ def test_clear_generator(self):
+ endly = False
+ def g():
+ nonlocal endly
+ try:
+ yield
+ inner()
+ finally:
+ endly = True
+ gen = g()
+ next(gen)
+ self.assertFalse(endly)
+ # Clearing the frame closes the generator
+ gen.gi_frame.clear()
+ self.assertTrue(endly)
+
+ def test_clear_executing(self):
+ # Attempting to clear an executing frame is forbidden.
+ try:
+ 1/0
+ except ZeroDivisionError as e:
+ f = e.__traceback__.tb_frame
+ with self.assertRaises(RuntimeError):
+ f.clear()
+ with self.assertRaises(RuntimeError):
+ f.f_back.clear()
+
+ def test_clear_executing_generator(self):
+ # Attempting to clear an executing generator frame is forbidden.
+ endly = False
+ def g():
+ nonlocal endly
+ try:
+ 1/0
+ except ZeroDivisionError as e:
+ f = e.__traceback__.tb_frame
+ with self.assertRaises(RuntimeError):
+ f.clear()
+ with self.assertRaises(RuntimeError):
+ f.f_back.clear()
+ yield f
+ finally:
+ endly = True
+ gen = g()
+ f = next(gen)
+ self.assertFalse(endly)
+ # Clearing the frame closes the generator
+ f.clear()
+ self.assertTrue(endly)
+
+ @support.cpython_only
+ def test_clear_refcycles(self):
+ # .clear() doesn't leave any refcycle behind
+ with support.disable_gc():
+ class C:
+ pass
+ c = C()
+ wr = weakref.ref(c)
+ exc = self.outer(c=c)
+ del c
+ self.assertIsNot(None, wr())
+ self.clear_traceback_frames(exc.__traceback__)
+ self.assertIs(None, wr())
+
+
+def test_main():
+ support.run_unittest(__name__)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_frozen.py b/Lib/test/test_frozen.py
index fd6761c651..4c50cb7510 100644
--- a/Lib/test/test_frozen.py
+++ b/Lib/test/test_frozen.py
@@ -36,7 +36,7 @@ class FrozenTests(unittest.TestCase):
else:
expect.add('spam')
self.assertEqual(set(dir(__phello__)), expect)
- self.assertEqual(__phello__.__path__, [__phello__.__name__])
+ self.assertEqual(__phello__.__path__, [])
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
with captured_stdout() as stdout:
@@ -72,8 +72,6 @@ class FrozenTests(unittest.TestCase):
del sys.modules['__phello__']
del sys.modules['__phello__.spam']
-def test_main():
- run_unittest(FrozenTests)
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
index 64cd80bc14..41463e2a5b 100644
--- a/Lib/test/test_ftplib.py
+++ b/Lib/test/test_ftplib.py
@@ -21,6 +21,7 @@ from test import support
from test.support import HOST, HOSTv6
threading = support.import_module('threading')
+TIMEOUT = 3
# the dummy data returned by server over the data channel when
# RETR, LIST, NLST, MLSD commands are issued
RETR_DATA = 'abcde12345\r\n' * 1000
@@ -126,7 +127,7 @@ class DummyFTPHandler(asynchat.async_chat):
addr = list(map(int, arg.split(',')))
ip = '%d.%d.%d.%d' %tuple(addr[:4])
port = (addr[4] * 256) + addr[5]
- s = socket.create_connection((ip, port), timeout=2)
+ s = socket.create_connection((ip, port), timeout=TIMEOUT)
self.dtp = self.dtp_handler(s, baseclass=self)
self.push('200 active data connection established')
@@ -134,7 +135,7 @@ class DummyFTPHandler(asynchat.async_chat):
with socket.socket() as sock:
sock.bind((self.socket.getsockname()[0], 0))
sock.listen(5)
- sock.settimeout(10)
+ sock.settimeout(TIMEOUT)
ip, port = sock.getsockname()[:2]
ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256
self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2))
@@ -144,7 +145,7 @@ class DummyFTPHandler(asynchat.async_chat):
def cmd_eprt(self, arg):
af, ip, port = arg.split(arg[0])[1:-1]
port = int(port)
- s = socket.create_connection((ip, port), timeout=2)
+ s = socket.create_connection((ip, port), timeout=TIMEOUT)
self.dtp = self.dtp_handler(s, baseclass=self)
self.push('200 active data connection established')
@@ -152,7 +153,7 @@ class DummyFTPHandler(asynchat.async_chat):
with socket.socket(socket.AF_INET6) as sock:
sock.bind((self.socket.getsockname()[0], 0))
sock.listen(5)
- sock.settimeout(10)
+ sock.settimeout(TIMEOUT)
port = sock.getsockname()[1]
self.push('229 entering extended passive mode (|||%d|)' %port)
conn, addr = sock.accept()
@@ -327,7 +328,7 @@ if ssl is not None:
elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close()
raise
- except socket.error as err:
+ except OSError as err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
@@ -341,7 +342,7 @@ if ssl is not None:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
- except socket.error as err:
+ except OSError as err:
# Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
# from OpenSSL's SSL_shutdown(), corresponding to a
# closed socket condition. See also:
@@ -460,7 +461,7 @@ class TestFTPClass(TestCase):
def setUp(self):
self.server = DummyFTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP(timeout=2)
+ self.client = ftplib.FTP(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
@@ -488,7 +489,7 @@ class TestFTPClass(TestCase):
def test_all_errors(self):
exceptions = (ftplib.error_reply, ftplib.error_temp, ftplib.error_perm,
- ftplib.error_proto, ftplib.Error, IOError, EOFError)
+ ftplib.error_proto, ftplib.Error, OSError, EOFError)
for x in exceptions:
try:
raise x('exception not included in all_errors set')
@@ -678,7 +679,7 @@ class TestFTPClass(TestCase):
def test_makepasv(self):
host, port = self.client.makepasv()
- conn = socket.create_connection((host, port), 10)
+ conn = socket.create_connection((host, port), timeout=TIMEOUT)
conn.close()
# IPv4 is in use, just make sure send_epsv has not been used
self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv')
@@ -691,12 +692,12 @@ class TestFTPClass(TestCase):
return False
try:
self.client.sendcmd('noop')
- except (socket.error, EOFError):
+ except (OSError, EOFError):
return False
return True
# base test
- with ftplib.FTP(timeout=2) as self.client:
+ with ftplib.FTP(timeout=TIMEOUT) as self.client:
self.client.connect(self.server.host, self.server.port)
self.client.sendcmd('noop')
self.assertTrue(is_client_connected())
@@ -704,7 +705,7 @@ class TestFTPClass(TestCase):
self.assertFalse(is_client_connected())
# QUIT sent inside the with block
- with ftplib.FTP(timeout=2) as self.client:
+ with ftplib.FTP(timeout=TIMEOUT) as self.client:
self.client.connect(self.server.host, self.server.port)
self.client.sendcmd('noop')
self.client.quit()
@@ -714,7 +715,7 @@ class TestFTPClass(TestCase):
# force a wrong response code to be sent on QUIT: error_perm
# is expected and the connection is supposed to be closed
try:
- with ftplib.FTP(timeout=2) as self.client:
+ with ftplib.FTP(timeout=TIMEOUT) as self.client:
self.client.connect(self.server.host, self.server.port)
self.client.sendcmd('noop')
self.server.handler_instance.next_response = '550 error on quit'
@@ -736,7 +737,7 @@ class TestFTPClass(TestCase):
source_address=(HOST, port))
self.assertEqual(self.client.sock.getsockname()[1], port)
self.client.quit()
- except IOError as e:
+ except OSError as e:
if e.errno == errno.EADDRINUSE:
self.skipTest("couldn't bind to port %d" % port)
raise
@@ -747,7 +748,7 @@ class TestFTPClass(TestCase):
try:
with self.client.transfercmd('list') as sock:
self.assertEqual(sock.getsockname()[1], port)
- except IOError as e:
+ except OSError as e:
if e.errno == errno.EADDRINUSE:
self.skipTest("couldn't bind to port %d" % port)
raise
@@ -785,7 +786,7 @@ class TestIPv6Environment(TestCase):
def setUp(self):
self.server = DummyFTPServer((HOSTv6, 0), af=socket.AF_INET6)
self.server.start()
- self.client = ftplib.FTP()
+ self.client = ftplib.FTP(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
@@ -802,7 +803,7 @@ class TestIPv6Environment(TestCase):
def test_makepasv(self):
host, port = self.client.makepasv()
- conn = socket.create_connection((host, port), 10)
+ conn = socket.create_connection((host, port), timeout=TIMEOUT)
conn.close()
self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv')
@@ -829,7 +830,7 @@ class TestTLS_FTPClassMixin(TestFTPClass):
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP_TLS(timeout=2)
+ self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
# enable TLS
self.client.auth()
@@ -843,7 +844,7 @@ class TestTLS_FTPClass(TestCase):
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP_TLS(timeout=2)
+ self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
@@ -903,7 +904,7 @@ class TestTLS_FTPClass(TestCase):
self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
keyfile=CERTFILE, context=ctx)
- self.client = ftplib.FTP_TLS(context=ctx, timeout=2)
+ self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
self.client.auth()
@@ -1017,8 +1018,19 @@ class TestTimeouts(TestCase):
ftp.close()
+class TestNetrcDeprecation(TestCase):
+
+ def test_deprecation(self):
+ with support.temp_cwd(), support.EnvironmentVarGuard() as env:
+ env['HOME'] = os.getcwd()
+ open('.netrc', 'w').close()
+ with self.assertWarns(DeprecationWarning):
+ ftplib.Netrc()
+
+
+
def test_main():
- tests = [TestFTPClass, TestTimeouts,
+ tests = [TestFTPClass, TestTimeouts, TestNetrcDeprecation,
TestIPv6Environment,
TestTLS_FTPClassMixin, TestTLS_FTPClass]
diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py
index c8ed83020e..5094f7ba1f 100644
--- a/Lib/test/test_funcattrs.py
+++ b/Lib/test/test_funcattrs.py
@@ -7,6 +7,11 @@ def global_function():
def inner_function():
class LocalClass:
pass
+ global inner_global_function
+ def inner_global_function():
+ def inner_function2():
+ pass
+ return inner_function2
return LocalClass
return lambda: inner_function
@@ -116,6 +121,8 @@ class FunctionPropertiesTest(FuncAttrsTest):
'global_function.<locals>.inner_function')
self.assertEqual(global_function()()().__qualname__,
'global_function.<locals>.inner_function.<locals>.LocalClass')
+ self.assertEqual(inner_global_function.__qualname__, 'inner_global_function')
+ self.assertEqual(inner_global_function().__qualname__, 'inner_global_function.<locals>.inner_function2')
self.b.__qualname__ = 'c'
self.assertEqual(self.b.__qualname__, 'c')
self.b.__qualname__ = 'd'
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index db1e9348dd..31c093b1f0 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -1,57 +1,55 @@
-import functools
+import abc
import collections
+from itertools import permutations
+import pickle
+from random import choice
import sys
-import unittest
from test import support
+import unittest
from weakref import proxy
-import pickle
-from random import choice
-@staticmethod
-def PythonPartial(func, *args, **keywords):
- 'Pure Python approximation of partial()'
- def newfunc(*fargs, **fkeywords):
- newkeywords = keywords.copy()
- newkeywords.update(fkeywords)
- return func(*(args + fargs), **newkeywords)
- newfunc.func = func
- newfunc.args = args
- newfunc.keywords = keywords
- return newfunc
+import functools
+
+py_functools = support.import_fresh_module('functools', blocked=['_functools'])
+c_functools = support.import_fresh_module('functools', fresh=['_functools'])
+
+decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
+
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
return args, kw
+
def signature(part):
""" return the signature of a partial object """
return (part.func, part.args, part.keywords, part.__dict__)
-class TestPartial(unittest.TestCase):
- thetype = functools.partial
+class TestPartial:
def test_basic_examples(self):
- p = self.thetype(capture, 1, 2, a=10, b=20)
+ p = self.partial(capture, 1, 2, a=10, b=20)
+ self.assertTrue(callable(p))
self.assertEqual(p(3, 4, b=30, c=40),
((1, 2, 3, 4), dict(a=10, b=30, c=40)))
- p = self.thetype(map, lambda x: x*10)
+ p = self.partial(map, lambda x: x*10)
self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
def test_attributes(self):
- p = self.thetype(capture, 1, 2, a=10, b=20)
+ p = self.partial(capture, 1, 2, a=10, b=20)
# attributes should be readable
self.assertEqual(p.func, capture)
self.assertEqual(p.args, (1, 2))
self.assertEqual(p.keywords, dict(a=10, b=20))
# attributes should not be writable
- if not isinstance(self.thetype, type):
+ if not isinstance(self.partial, type):
return
self.assertRaises(AttributeError, setattr, p, 'func', map)
self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
- p = self.thetype(hex)
+ p = self.partial(hex)
try:
del p.__dict__
except TypeError:
@@ -60,9 +58,9 @@ class TestPartial(unittest.TestCase):
self.fail('partial object allowed __dict__ to be deleted')
def test_argument_checking(self):
- self.assertRaises(TypeError, self.thetype) # need at least a func arg
+ self.assertRaises(TypeError, self.partial) # need at least a func arg
try:
- self.thetype(2)()
+ self.partial(2)()
except TypeError:
pass
else:
@@ -73,7 +71,7 @@ class TestPartial(unittest.TestCase):
def func(a=10, b=20):
return a
d = {'a':3}
- p = self.thetype(func, a=5)
+ p = self.partial(func, a=5)
self.assertEqual(p(**d), 3)
self.assertEqual(d, {'a':3})
p(b=7)
@@ -82,20 +80,20 @@ class TestPartial(unittest.TestCase):
def test_arg_combinations(self):
# exercise special code paths for zero args in either partial
# object or the caller
- p = self.thetype(capture)
+ p = self.partial(capture)
self.assertEqual(p(), ((), {}))
self.assertEqual(p(1,2), ((1,2), {}))
- p = self.thetype(capture, 1, 2)
+ p = self.partial(capture, 1, 2)
self.assertEqual(p(), ((1,2), {}))
self.assertEqual(p(3,4), ((1,2,3,4), {}))
def test_kw_combinations(self):
# exercise special code paths for no keyword args in
# either the partial object or the caller
- p = self.thetype(capture)
+ p = self.partial(capture)
self.assertEqual(p(), ((), {}))
self.assertEqual(p(a=1), ((), {'a':1}))
- p = self.thetype(capture, a=1)
+ p = self.partial(capture, a=1)
self.assertEqual(p(), ((), {'a':1}))
self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
# keyword args in the call override those in the partial object
@@ -104,7 +102,7 @@ class TestPartial(unittest.TestCase):
def test_positional(self):
# make sure positional arguments are captured correctly
for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
- p = self.thetype(capture, *args)
+ p = self.partial(capture, *args)
expected = args + ('x',)
got, empty = p('x')
self.assertTrue(expected == got and empty == {})
@@ -112,14 +110,14 @@ class TestPartial(unittest.TestCase):
def test_keyword(self):
# make sure keyword arguments are captured correctly
for a in ['a', 0, None, 3.5]:
- p = self.thetype(capture, a=a)
+ p = self.partial(capture, a=a)
expected = {'a':a,'x':None}
empty, got = p(x=None)
self.assertTrue(expected == got and empty == ())
def test_no_side_effects(self):
# make sure there are no side effects that affect subsequent calls
- p = self.thetype(capture, 0, a=1)
+ p = self.partial(capture, 0, a=1)
args1, kw1 = p(1, b=2)
self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
args2, kw2 = p()
@@ -128,13 +126,13 @@ class TestPartial(unittest.TestCase):
def test_error_propagation(self):
def f(x, y):
x / y
- self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
- self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
- self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
- self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
+ self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
+ self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
+ self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
+ self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
def test_weakref(self):
- f = self.thetype(int, base=16)
+ f = self.partial(int, base=16)
p = proxy(f)
self.assertEqual(f.func, p.func)
f = None
@@ -142,39 +140,45 @@ class TestPartial(unittest.TestCase):
def test_with_bound_and_unbound_methods(self):
data = list(map(str, range(10)))
- join = self.thetype(str.join, '')
+ join = self.partial(str.join, '')
self.assertEqual(join(data), '0123456789')
- join = self.thetype(''.join)
+ join = self.partial(''.join)
self.assertEqual(join(data), '0123456789')
+
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialC(TestPartial, unittest.TestCase):
+ if c_functools:
+ partial = c_functools.partial
+
def test_repr(self):
args = (object(), object())
args_repr = ', '.join(repr(a) for a in args)
kwargs = {'a': object(), 'b': object()}
kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
- if self.thetype is functools.partial:
+ if self.partial is c_functools.partial:
name = 'functools.partial'
else:
- name = self.thetype.__name__
+ name = self.partial.__name__
- f = self.thetype(capture)
+ f = self.partial(capture)
self.assertEqual('{}({!r})'.format(name, capture),
repr(f))
- f = self.thetype(capture, *args)
+ f = self.partial(capture, *args)
self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
repr(f))
- f = self.thetype(capture, **kwargs)
+ f = self.partial(capture, **kwargs)
self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
repr(f))
- f = self.thetype(capture, *args, **kwargs)
+ f = self.partial(capture, *args, **kwargs)
self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
repr(f))
def test_pickle(self):
- f = self.thetype(signature, 'asdf', bar=True)
+ f = self.partial(signature, 'asdf', bar=True)
f.add_something_to__dict__ = True
f_copy = pickle.loads(pickle.dumps(f))
self.assertEqual(signature(f), signature(f_copy))
@@ -193,28 +197,140 @@ class TestPartial(unittest.TestCase):
return {}
raise IndexError
- f = self.thetype(object)
+ f = self.partial(object)
self.assertRaisesRegex(SystemError,
"new style getargs format but argument is not a tuple",
f.__setstate__, BadSequence())
-class PartialSubclass(functools.partial):
- pass
-class TestPartialSubclass(TestPartial):
+class TestPartialPy(TestPartial, unittest.TestCase):
+ partial = staticmethod(py_functools.partial)
+
+
+if c_functools:
+ class PartialSubclass(c_functools.partial):
+ pass
+
- thetype = PartialSubclass
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialCSubclass(TestPartialC):
+ if c_functools:
+ partial = PartialSubclass
-class TestPythonPartial(TestPartial):
- thetype = PythonPartial
+class TestPartialMethod(unittest.TestCase):
- # the python version hasn't a nice repr
- def test_repr(self): pass
+ class A(object):
+ nothing = functools.partialmethod(capture)
+ positional = functools.partialmethod(capture, 1)
+ keywords = functools.partialmethod(capture, a=2)
+ both = functools.partialmethod(capture, 3, b=4)
+
+ nested = functools.partialmethod(positional, 5)
+
+ over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
+
+ static = functools.partialmethod(staticmethod(capture), 8)
+ cls = functools.partialmethod(classmethod(capture), d=9)
+
+ a = A()
+
+ def test_arg_combinations(self):
+ self.assertEqual(self.a.nothing(), ((self.a,), {}))
+ self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
+ self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
+ self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
+
+ self.assertEqual(self.a.positional(), ((self.a, 1), {}))
+ self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
+ self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
+ self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
+
+ self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
+ self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
+ self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
+ self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
+
+ self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
+ self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
+ self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
+ self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
+
+ self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
+
+ def test_nested(self):
+ self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
+ self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
+ self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
+ self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
+
+ self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
+
+ def test_over_partial(self):
+ self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
+ self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
+ self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
+ self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
+
+ self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
+
+ def test_bound_method_introspection(self):
+ obj = self.a
+ self.assertIs(obj.both.__self__, obj)
+ self.assertIs(obj.nested.__self__, obj)
+ self.assertIs(obj.over_partial.__self__, obj)
+ self.assertIs(obj.cls.__self__, self.A)
+ self.assertIs(self.A.cls.__self__, self.A)
+
+ def test_unbound_method_retrieval(self):
+ obj = self.A
+ self.assertFalse(hasattr(obj.both, "__self__"))
+ self.assertFalse(hasattr(obj.nested, "__self__"))
+ self.assertFalse(hasattr(obj.over_partial, "__self__"))
+ self.assertFalse(hasattr(obj.static, "__self__"))
+ self.assertFalse(hasattr(self.a.static, "__self__"))
+
+ def test_descriptors(self):
+ for obj in [self.A, self.a]:
+ with self.subTest(obj=obj):
+ self.assertEqual(obj.static(), ((8,), {}))
+ self.assertEqual(obj.static(5), ((8, 5), {}))
+ self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
+ self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
+
+ self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
+ self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
+ self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
+ self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
+
+ def test_overriding_keywords(self):
+ self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
+ self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
+
+ def test_invalid_args(self):
+ with self.assertRaises(TypeError):
+ class B(object):
+ method = functools.partialmethod(None, 1)
+
+ def test_repr(self):
+ self.assertEqual(repr(vars(self.A)['both']),
+ 'functools.partialmethod({}, 3, b=4)'.format(capture))
+
+ def test_abstract(self):
+ class Abstract(abc.ABCMeta):
+
+ @abc.abstractmethod
+ def add(self, x, y):
+ pass
+
+ add5 = functools.partialmethod(add, 5)
+
+ self.assertTrue(Abstract.add.__isabstractmethod__)
+ self.assertTrue(Abstract.add5.__isabstractmethod__)
+
+ for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
+ self.assertFalse(getattr(func, '__isabstractmethod__', False))
- # the python version isn't picklable
- def test_pickle(self): pass
- def test_setstate_refcount(self): pass
class TestUpdateWrapper(unittest.TestCase):
@@ -223,19 +339,26 @@ class TestUpdateWrapper(unittest.TestCase):
updated=functools.WRAPPER_UPDATES):
# Check attributes were assigned
for name in assigned:
- self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
+ self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
# Check attributes were updated
for name in updated:
wrapper_attr = getattr(wrapper, name)
wrapped_attr = getattr(wrapped, name)
for key in wrapped_attr:
- self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
+ if name == "__dict__" and key == "__wrapped__":
+ # __wrapped__ is overwritten by the update code
+ continue
+ self.assertIs(wrapped_attr[key], wrapper_attr[key])
+ # Check __wrapped__
+ self.assertIs(wrapper.__wrapped__, wrapped)
+
def _default_update(self):
def f(a:'This is a new annotation'):
"""This is a test"""
pass
f.attr = 'This is also a test'
+ f.__wrapped__ = "This is a bald faced lie"
def wrapper(b:'This is the prior annotation'):
pass
functools.update_wrapper(wrapper, f)
@@ -322,6 +445,7 @@ class TestUpdateWrapper(unittest.TestCase):
self.assertTrue(wrapper.__doc__.startswith('max('))
self.assertEqual(wrapper.__annotations__, {})
+
class TestWraps(TestUpdateWrapper):
def _default_update(self):
@@ -329,14 +453,15 @@ class TestWraps(TestUpdateWrapper):
"""This is a test"""
pass
f.attr = 'This is also a test'
+ f.__wrapped__ = "This is still a bald faced lie"
@functools.wraps(f)
def wrapper():
pass
- self.check_wrapper(wrapper, f)
return wrapper, f
def test_default_update(self):
wrapper, f = self._default_update()
+ self.check_wrapper(wrapper, f)
self.assertEqual(wrapper.__name__, 'f')
self.assertEqual(wrapper.__qualname__, f.__qualname__)
self.assertEqual(wrapper.attr, 'This is also a test')
@@ -382,6 +507,7 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
+
class TestReduce(unittest.TestCase):
func = functools.reduce
@@ -462,24 +588,29 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
-class TestCmpToKey(unittest.TestCase):
+
+class TestCmpToKey:
def test_cmp_to_key(self):
def cmp1(x, y):
return (x > y) - (x < y)
- key = functools.cmp_to_key(cmp1)
+ key = self.cmp_to_key(cmp1)
self.assertEqual(key(3), key(3))
self.assertGreater(key(3), key(1))
+ self.assertGreaterEqual(key(3), key(3))
+
def cmp2(x, y):
return int(x) - int(y)
- key = functools.cmp_to_key(cmp2)
+ key = self.cmp_to_key(cmp2)
self.assertEqual(key(4.0), key('4'))
self.assertLess(key(2), key('35'))
+ self.assertLessEqual(key(2), key('35'))
+ self.assertNotEqual(key(2), key('35'))
def test_cmp_to_key_arguments(self):
def cmp1(x, y):
return (x > y) - (x < y)
- key = functools.cmp_to_key(mycmp=cmp1)
+ key = self.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(obj=3), key(obj=3))
self.assertGreater(key(obj=3), key(obj=1))
with self.assertRaises((TypeError, AttributeError)):
@@ -487,10 +618,10 @@ class TestCmpToKey(unittest.TestCase):
with self.assertRaises((TypeError, AttributeError)):
1 < key(3) # lhs is not a K object
with self.assertRaises(TypeError):
- key = functools.cmp_to_key() # too few args
+ key = self.cmp_to_key() # too few args
with self.assertRaises(TypeError):
- key = functools.cmp_to_key(cmp1, None) # too many args
- key = functools.cmp_to_key(cmp1)
+ key = self.cmp_to_key(cmp1, None) # too many args
+ key = self.cmp_to_key(cmp1)
with self.assertRaises(TypeError):
key() # too few args
with self.assertRaises(TypeError):
@@ -499,7 +630,7 @@ class TestCmpToKey(unittest.TestCase):
def test_bad_cmp(self):
def cmp1(x, y):
raise ZeroDivisionError
- key = functools.cmp_to_key(cmp1)
+ key = self.cmp_to_key(cmp1)
with self.assertRaises(ZeroDivisionError):
key(3) > key(1)
@@ -514,13 +645,13 @@ class TestCmpToKey(unittest.TestCase):
def test_obj_field(self):
def cmp1(x, y):
return (x > y) - (x < y)
- key = functools.cmp_to_key(mycmp=cmp1)
+ key = self.cmp_to_key(mycmp=cmp1)
self.assertEqual(key(50).obj, 50)
def test_sort_int(self):
def mycmp(x, y):
return y - x
- self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
+ self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
def test_sort_int_str(self):
@@ -528,18 +659,29 @@ class TestCmpToKey(unittest.TestCase):
x, y = int(x), int(y)
return (x > y) - (x < y)
values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
- values = sorted(values, key=functools.cmp_to_key(mycmp))
+ values = sorted(values, key=self.cmp_to_key(mycmp))
self.assertEqual([int(value) for value in values],
[0, 1, 1, 2, 3, 4, 5, 7, 10])
def test_hash(self):
def mycmp(x, y):
return y - x
- key = functools.cmp_to_key(mycmp)
+ key = self.cmp_to_key(mycmp)
k = key(10)
self.assertRaises(TypeError, hash, k)
self.assertNotIsInstance(k, collections.Hashable)
+
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
+ if c_functools:
+ cmp_to_key = c_functools.cmp_to_key
+
+
+class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
+ cmp_to_key = staticmethod(py_functools.cmp_to_key)
+
+
class TestTotalOrdering(unittest.TestCase):
def test_total_ordering_lt(self):
@@ -557,6 +699,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2))
+ self.assertFalse(A(1) > A(2))
def test_total_ordering_le(self):
@functools.total_ordering
@@ -573,6 +716,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2))
+ self.assertFalse(A(1) >= A(2))
def test_total_ordering_gt(self):
@functools.total_ordering
@@ -589,6 +733,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2))
+ self.assertFalse(A(2) < A(1))
def test_total_ordering_ge(self):
@functools.total_ordering
@@ -605,6 +750,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2))
+ self.assertFalse(A(2) <= A(1))
def test_total_ordering_no_overwrite(self):
# new methods should not overwrite existing
@@ -624,27 +770,118 @@ class TestTotalOrdering(unittest.TestCase):
class A:
pass
- def test_bug_10042(self):
+ def test_type_error_when_not_implemented(self):
+ # bug 10042; ensure stack overflow does not occur
+ # when decorated types return NotImplemented
@functools.total_ordering
- class TestTO:
+ class ImplementsLessThan:
def __init__(self, value):
self.value = value
def __eq__(self, other):
- if isinstance(other, TestTO):
+ if isinstance(other, ImplementsLessThan):
return self.value == other.value
return False
def __lt__(self, other):
- if isinstance(other, TestTO):
+ if isinstance(other, ImplementsLessThan):
return self.value < other.value
- raise TypeError
- with self.assertRaises(TypeError):
- TestTO(8) <= ()
+ return NotImplemented
+
+ @functools.total_ordering
+ class ImplementsGreaterThan:
+ def __init__(self, value):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, ImplementsGreaterThan):
+ return self.value == other.value
+ return False
+ def __gt__(self, other):
+ if isinstance(other, ImplementsGreaterThan):
+ return self.value > other.value
+ return NotImplemented
+
+ @functools.total_ordering
+ class ImplementsLessThanEqualTo:
+ def __init__(self, value):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, ImplementsLessThanEqualTo):
+ return self.value == other.value
+ return False
+ def __le__(self, other):
+ if isinstance(other, ImplementsLessThanEqualTo):
+ return self.value <= other.value
+ return NotImplemented
+
+ @functools.total_ordering
+ class ImplementsGreaterThanEqualTo:
+ def __init__(self, value):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, ImplementsGreaterThanEqualTo):
+ return self.value == other.value
+ return False
+ def __ge__(self, other):
+ if isinstance(other, ImplementsGreaterThanEqualTo):
+ return self.value >= other.value
+ return NotImplemented
+
+ @functools.total_ordering
+ class ComparatorNotImplemented:
+ def __init__(self, value):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, ComparatorNotImplemented):
+ return self.value == other.value
+ return False
+ def __lt__(self, other):
+ return NotImplemented
+
+ with self.subTest("LT < 1"), self.assertRaises(TypeError):
+ ImplementsLessThan(-1) < 1
+
+ with self.subTest("LT < LE"), self.assertRaises(TypeError):
+ ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
+
+ with self.subTest("LT < GT"), self.assertRaises(TypeError):
+ ImplementsLessThan(1) < ImplementsGreaterThan(1)
+
+ with self.subTest("LE <= LT"), self.assertRaises(TypeError):
+ ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
+
+ with self.subTest("LE <= GE"), self.assertRaises(TypeError):
+ ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
+
+ with self.subTest("GT > GE"), self.assertRaises(TypeError):
+ ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
+
+ with self.subTest("GT > LT"), self.assertRaises(TypeError):
+ ImplementsGreaterThan(5) > ImplementsLessThan(5)
+
+ with self.subTest("GE >= GT"), self.assertRaises(TypeError):
+ ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
+
+ with self.subTest("GE >= LE"), self.assertRaises(TypeError):
+ ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
+
+ with self.subTest("GE when equal"):
+ a = ComparatorNotImplemented(8)
+ b = ComparatorNotImplemented(8)
+ self.assertEqual(a, b)
+ with self.assertRaises(TypeError):
+ a >= b
+
+ with self.subTest("LE when equal"):
+ a = ComparatorNotImplemented(9)
+ b = ComparatorNotImplemented(9)
+ self.assertEqual(a, b)
+ with self.assertRaises(TypeError):
+ a <= b
class TestLRU(unittest.TestCase):
def test_lru(self):
def orig(x, y):
- return 3*x+y
+ return 3 * x + y
f = functools.lru_cache(maxsize=20)(orig)
hits, misses, maxsize, currsize = f.cache_info()
self.assertEqual(maxsize, 20)
@@ -749,7 +986,7 @@ class TestLRU(unittest.TestCase):
# Verify that user_function exceptions get passed through without
# creating a hard-to-read chained exception.
# http://bugs.python.org/issue13177
- for maxsize in (None, 100):
+ for maxsize in (None, 128):
@functools.lru_cache(maxsize)
def func(i):
return 'abc'[i]
@@ -762,7 +999,7 @@ class TestLRU(unittest.TestCase):
func(15)
def test_lru_with_types(self):
- for maxsize in (None, 100):
+ for maxsize in (None, 128):
@functools.lru_cache(maxsize=maxsize, typed=True)
def square(x):
return x * x
@@ -777,6 +1014,36 @@ class TestLRU(unittest.TestCase):
self.assertEqual(square.cache_info().hits, 4)
self.assertEqual(square.cache_info().misses, 4)
+ def test_lru_with_keyword_args(self):
+ @functools.lru_cache()
+ def fib(n):
+ if n < 2:
+ return n
+ return fib(n=n-1) + fib(n=n-2)
+ self.assertEqual(
+ [fib(n=number) for number in range(16)],
+ [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
+ )
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
+ fib.cache_clear()
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
+
+ def test_lru_with_keyword_args_maxsize_none(self):
+ @functools.lru_cache(maxsize=None)
+ def fib(n):
+ if n < 2:
+ return n
+ return fib(n=n-1) + fib(n=n-2)
+ self.assertEqual([fib(n=number) for number in range(16)],
+ [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
+ fib.cache_clear()
+ self.assertEqual(fib.cache_info(),
+ functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
+
def test_need_for_rlock(self):
# This will deadlock on an LRU cache that uses a regular lock
@@ -802,17 +1069,494 @@ class TestLRU(unittest.TestCase):
DoubleEq(2)) # Verify the correct return value
+class TestSingleDispatch(unittest.TestCase):
+ def test_simple_overloads(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ def g_int(i):
+ return "integer"
+ g.register(int, g_int)
+ self.assertEqual(g("str"), "base")
+ self.assertEqual(g(1), "integer")
+ self.assertEqual(g([1,2,3]), "base")
+
+ def test_mro(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ class A:
+ pass
+ class C(A):
+ pass
+ class B(A):
+ pass
+ class D(C, B):
+ pass
+ def g_A(a):
+ return "A"
+ def g_B(b):
+ return "B"
+ g.register(A, g_A)
+ g.register(B, g_B)
+ self.assertEqual(g(A()), "A")
+ self.assertEqual(g(B()), "B")
+ self.assertEqual(g(C()), "A")
+ self.assertEqual(g(D()), "B")
+
+ def test_register_decorator(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ @g.register(int)
+ def g_int(i):
+ return "int %s" % (i,)
+ self.assertEqual(g(""), "base")
+ self.assertEqual(g(12), "int 12")
+ self.assertIs(g.dispatch(int), g_int)
+ self.assertIs(g.dispatch(object), g.dispatch(str))
+ # Note: in the assert above this is not g.
+ # @singledispatch returns the wrapper.
+
+ def test_wrapping_attributes(self):
+ @functools.singledispatch
+ def g(obj):
+ "Simple test"
+ return "Test"
+ self.assertEqual(g.__name__, "g")
+ self.assertEqual(g.__doc__, "Simple test")
+
+ @unittest.skipUnless(decimal, 'requires _decimal')
+ @support.cpython_only
+ def test_c_classes(self):
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ @g.register(decimal.DecimalException)
+ def _(obj):
+ return obj.args
+ subn = decimal.Subnormal("Exponent < Emin")
+ rnd = decimal.Rounded("Number got rounded")
+ self.assertEqual(g(subn), ("Exponent < Emin",))
+ self.assertEqual(g(rnd), ("Number got rounded",))
+ @g.register(decimal.Subnormal)
+ def _(obj):
+ return "Too small to care."
+ self.assertEqual(g(subn), "Too small to care.")
+ self.assertEqual(g(rnd), ("Number got rounded",))
+
+ def test_compose_mro(self):
+ # None of the examples in this test depend on haystack ordering.
+ c = collections
+ mro = functools._compose_mro
+ bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
+ for haystack in permutations(bases):
+ m = mro(dict, haystack)
+ self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
+ c.Iterable, c.Container, object])
+ bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
+ for haystack in permutations(bases):
+ m = mro(c.ChainMap, haystack)
+ self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
+ c.Sized, c.Iterable, c.Container, object])
+
+ # If there's a generic function with implementations registered for
+ # both Sized and Container, passing a defaultdict to it results in an
+ # ambiguous dispatch which will cause a RuntimeError (see
+ # test_mro_conflicts).
+ bases = [c.Container, c.Sized, str]
+ for haystack in permutations(bases):
+ m = mro(c.defaultdict, [c.Sized, c.Container, str])
+ self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
+ object])
+
+ # MutableSequence below is registered directly on D. In other words, it
+ # preceeds MutableMapping which means single dispatch will always
+ # choose MutableSequence here.
+ class D(c.defaultdict):
+ pass
+ c.MutableSequence.register(D)
+ bases = [c.MutableSequence, c.MutableMapping]
+ for haystack in permutations(bases):
+ m = mro(D, bases)
+ self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
+ c.defaultdict, dict, c.MutableMapping,
+ c.Mapping, c.Sized, c.Iterable, c.Container,
+ object])
+
+ # Container and Callable are registered on different base classes and
+ # a generic function supporting both should always pick the Callable
+ # implementation if a C instance is passed.
+ class C(c.defaultdict):
+ def __call__(self):
+ pass
+ bases = [c.Sized, c.Callable, c.Container, c.Mapping]
+ for haystack in permutations(bases):
+ m = mro(C, haystack)
+ self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
+ c.Sized, c.Iterable, c.Container, object])
+
+ def test_register_abc(self):
+ c = collections
+ d = {"a": "b"}
+ l = [1, 2, 3]
+ s = {object(), None}
+ f = frozenset(s)
+ t = (1, 2, 3)
+ @functools.singledispatch
+ def g(obj):
+ return "base"
+ self.assertEqual(g(d), "base")
+ self.assertEqual(g(l), "base")
+ self.assertEqual(g(s), "base")
+ self.assertEqual(g(f), "base")
+ self.assertEqual(g(t), "base")
+ g.register(c.Sized, lambda obj: "sized")
+ self.assertEqual(g(d), "sized")
+ self.assertEqual(g(l), "sized")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.MutableMapping, lambda obj: "mutablemapping")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "sized")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.ChainMap, lambda obj: "chainmap")
+ self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
+ self.assertEqual(g(l), "sized")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.MutableSequence, lambda obj: "mutablesequence")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "sized")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.MutableSet, lambda obj: "mutableset")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.Mapping, lambda obj: "mapping")
+ self.assertEqual(g(d), "mutablemapping") # not specific enough
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sized")
+ g.register(c.Sequence, lambda obj: "sequence")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "sized")
+ self.assertEqual(g(t), "sequence")
+ g.register(c.Set, lambda obj: "set")
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(dict, lambda obj: "dict")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "mutablesequence")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(list, lambda obj: "list")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "mutableset")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(set, lambda obj: "concrete-set")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "concrete-set")
+ self.assertEqual(g(f), "set")
+ self.assertEqual(g(t), "sequence")
+ g.register(frozenset, lambda obj: "frozen-set")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "concrete-set")
+ self.assertEqual(g(f), "frozen-set")
+ self.assertEqual(g(t), "sequence")
+ g.register(tuple, lambda obj: "tuple")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(s), "concrete-set")
+ self.assertEqual(g(f), "frozen-set")
+ self.assertEqual(g(t), "tuple")
+
+ def test_c3_abc(self):
+ c = collections
+ mro = functools._c3_mro
+ class A(object):
+ pass
+ class B(A):
+ def __len__(self):
+ return 0 # implies Sized
+ @c.Container.register
+ class C(object):
+ pass
+ class D(object):
+ pass # unrelated
+ class X(D, C, B):
+ def __call__(self):
+ pass # implies Callable
+ expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
+ for abcs in permutations([c.Sized, c.Callable, c.Container]):
+ self.assertEqual(mro(X, abcs=abcs), expected)
+ # unrelated ABCs don't appear in the resulting MRO
+ many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
+ self.assertEqual(mro(X, abcs=many_abcs), expected)
+
+ def test_mro_conflicts(self):
+ c = collections
+ @functools.singledispatch
+ def g(arg):
+ return "base"
+ class O(c.Sized):
+ def __len__(self):
+ return 0
+ o = O()
+ self.assertEqual(g(o), "base")
+ g.register(c.Iterable, lambda arg: "iterable")
+ g.register(c.Container, lambda arg: "container")
+ g.register(c.Sized, lambda arg: "sized")
+ g.register(c.Set, lambda arg: "set")
+ self.assertEqual(g(o), "sized")
+ c.Iterable.register(O)
+ self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
+ c.Container.register(O)
+ self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
+ c.Set.register(O)
+ self.assertEqual(g(o), "set") # because c.Set is a subclass of
+ # c.Sized and c.Container
+ class P:
+ pass
+ p = P()
+ self.assertEqual(g(p), "base")
+ c.Iterable.register(P)
+ self.assertEqual(g(p), "iterable")
+ c.Container.register(P)
+ with self.assertRaises(RuntimeError) as re_one:
+ g(p)
+ self.assertIn(
+ str(re_one.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Iterable'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class Q(c.Sized):
+ def __len__(self):
+ return 0
+ q = Q()
+ self.assertEqual(g(q), "sized")
+ c.Iterable.register(Q)
+ self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
+ c.Set.register(Q)
+ self.assertEqual(g(q), "set") # because c.Set is a subclass of
+ # c.Sized and c.Iterable
+ @functools.singledispatch
+ def h(arg):
+ return "base"
+ @h.register(c.Sized)
+ def _(arg):
+ return "sized"
+ @h.register(c.Container)
+ def _(arg):
+ return "container"
+ # Even though Sized and Container are explicit bases of MutableMapping,
+ # this ABC is implicitly registered on defaultdict which makes all of
+ # MutableMapping's bases implicit as well from defaultdict's
+ # perspective.
+ with self.assertRaises(RuntimeError) as re_two:
+ h(c.defaultdict(lambda: 0))
+ self.assertIn(
+ str(re_two.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Sized'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class R(c.defaultdict):
+ pass
+ c.MutableSequence.register(R)
+ @functools.singledispatch
+ def i(arg):
+ return "base"
+ @i.register(c.MutableMapping)
+ def _(arg):
+ return "mapping"
+ @i.register(c.MutableSequence)
+ def _(arg):
+ return "sequence"
+ r = R()
+ self.assertEqual(i(r), "sequence")
+ class S:
+ pass
+ class T(S, c.Sized):
+ def __len__(self):
+ return 0
+ t = T()
+ self.assertEqual(h(t), "sized")
+ c.Container.register(T)
+ self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
+ class U:
+ def __len__(self):
+ return 0
+ u = U()
+ self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
+ # from the existence of __len__()
+ c.Container.register(U)
+ # There is no preference for registered versus inferred ABCs.
+ with self.assertRaises(RuntimeError) as re_three:
+ h(u)
+ self.assertIn(
+ str(re_three.exception),
+ (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+ "or <class 'collections.abc.Sized'>"),
+ ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+ "or <class 'collections.abc.Container'>")),
+ )
+ class V(c.Sized, S):
+ def __len__(self):
+ return 0
+ @functools.singledispatch
+ def j(arg):
+ return "base"
+ @j.register(S)
+ def _(arg):
+ return "s"
+ @j.register(c.Container)
+ def _(arg):
+ return "container"
+ v = V()
+ self.assertEqual(j(v), "s")
+ c.Container.register(V)
+ self.assertEqual(j(v), "container") # because it ends up right after
+ # Sized in the MRO
+
+ def test_cache_invalidation(self):
+ from collections import UserDict
+ class TracingDict(UserDict):
+ def __init__(self, *args, **kwargs):
+ super(TracingDict, self).__init__(*args, **kwargs)
+ self.set_ops = []
+ self.get_ops = []
+ def __getitem__(self, key):
+ result = self.data[key]
+ self.get_ops.append(key)
+ return result
+ def __setitem__(self, key, value):
+ self.set_ops.append(key)
+ self.data[key] = value
+ def clear(self):
+ self.data.clear()
+ _orig_wkd = functools.WeakKeyDictionary
+ td = TracingDict()
+ functools.WeakKeyDictionary = lambda: td
+ c = collections
+ @functools.singledispatch
+ def g(arg):
+ return "base"
+ d = {}
+ l = []
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "base")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(td.get_ops, [])
+ self.assertEqual(td.set_ops, [dict])
+ self.assertEqual(td.data[dict], g.registry[object])
+ self.assertEqual(g(l), "base")
+ self.assertEqual(len(td), 2)
+ self.assertEqual(td.get_ops, [])
+ self.assertEqual(td.set_ops, [dict, list])
+ self.assertEqual(td.data[dict], g.registry[object])
+ self.assertEqual(td.data[list], g.registry[object])
+ self.assertEqual(td.data[dict], td.data[list])
+ self.assertEqual(g(l), "base")
+ self.assertEqual(g(d), "base")
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(td.set_ops, [dict, list])
+ g.register(list, lambda arg: "list")
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "base")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict])
+ self.assertEqual(td.data[dict],
+ functools._find_impl(dict, g.registry))
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 2)
+ self.assertEqual(td.get_ops, [list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list])
+ self.assertEqual(td.data[list],
+ functools._find_impl(list, g.registry))
+ class X:
+ pass
+ c.MutableMapping.register(X) # Will not invalidate the cache,
+ # not using ABCs yet.
+ self.assertEqual(g(d), "base")
+ self.assertEqual(g(l), "list")
+ self.assertEqual(td.get_ops, [list, dict, dict, list])
+ self.assertEqual(td.set_ops, [dict, list, dict, list])
+ g.register(c.Sized, lambda arg: "sized")
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "sized")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(td.get_ops, [list, dict, dict, list])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 2)
+ self.assertEqual(td.get_ops, [list, dict, dict, list])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ self.assertEqual(g(l), "list")
+ self.assertEqual(g(d), "sized")
+ self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ g.dispatch(list)
+ g.dispatch(dict)
+ self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
+ list, dict])
+ self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
+ c.MutableSet.register(X) # Will invalidate the cache.
+ self.assertEqual(len(td), 2) # Stale cache.
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 1)
+ g.register(c.MutableMapping, lambda arg: "mutablemapping")
+ self.assertEqual(len(td), 0)
+ self.assertEqual(g(d), "mutablemapping")
+ self.assertEqual(len(td), 1)
+ self.assertEqual(g(l), "list")
+ self.assertEqual(len(td), 2)
+ g.register(dict, lambda arg: "dict")
+ self.assertEqual(g(d), "dict")
+ self.assertEqual(g(l), "list")
+ g._clear_cache()
+ self.assertEqual(len(td), 0)
+ functools.WeakKeyDictionary = _orig_wkd
+
+
def test_main(verbose=None):
test_classes = (
- TestPartial,
- TestPartialSubclass,
- TestPythonPartial,
+ TestPartialC,
+ TestPartialPy,
+ TestPartialCSubclass,
+ TestPartialMethod,
TestUpdateWrapper,
TestTotalOrdering,
- TestCmpToKey,
+ TestCmpToKeyC,
+ TestCmpToKeyPy,
TestWraps,
TestReduce,
TestLRU,
+ TestSingleDispatch,
)
support.run_unittest(*test_classes)
diff --git a/Lib/test/test_future.py b/Lib/test/test_future.py
index a0c156f5f7..beac993e4d 100644
--- a/Lib/test/test_future.py
+++ b/Lib/test/test_future.py
@@ -82,6 +82,14 @@ class FutureTest(unittest.TestCase):
else:
self.fail("expected exception didn't occur")
+ def test_badfuture10(self):
+ try:
+ from test import badsyntax_future10
+ except SyntaxError as msg:
+ self.assertEqual(get_error_location(msg), ("badsyntax_future10", '3'))
+ else:
+ self.fail("expected exception didn't occur")
+
def test_parserhack(self):
# test that the parser.c::future_hack function works as expected
# Note: although this test must pass, it's not testing the original
diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py
index c59b72eacf..e8f52a508f 100644
--- a/Lib/test/test_gc.py
+++ b/Lib/test/test_gc.py
@@ -1,6 +1,9 @@
+import _testcapi
import unittest
from test.support import (verbose, refcount_test, run_unittest,
strip_python_stderr)
+from test.script_helper import assert_python_ok, make_script, temp_dir
+
import sys
import time
import gc
@@ -38,6 +41,7 @@ class GC_Detector(object):
# gc collects it.
self.wr = weakref.ref(C1055820(666), it_happened)
+@_testcapi.with_tp_del
class Uncollectable(object):
"""Create a reference cycle with multiple __del__ methods.
@@ -50,7 +54,7 @@ class Uncollectable(object):
self.partner = Uncollectable(partner=self)
else:
self.partner = partner
- def __del__(self):
+ def __tp_del__(self):
pass
### Tests
@@ -139,11 +143,12 @@ class GCTests(unittest.TestCase):
del a
self.assertNotEqual(gc.collect(), 0)
- def test_finalizer(self):
+ def test_legacy_finalizer(self):
# A() is uncollectable if it is part of a cycle, make sure it shows up
# in gc.garbage.
+ @_testcapi.with_tp_del
class A:
- def __del__(self): pass
+ def __tp_del__(self): pass
class B:
pass
a = A()
@@ -163,11 +168,12 @@ class GCTests(unittest.TestCase):
self.fail("didn't find obj in garbage (finalizer)")
gc.garbage.remove(obj)
- def test_finalizer_newclass(self):
+ def test_legacy_finalizer_newclass(self):
# A() is uncollectable if it is part of a cycle, make sure it shows up
# in gc.garbage.
+ @_testcapi.with_tp_del
class A(object):
- def __del__(self): pass
+ def __tp_del__(self): pass
class B(object):
pass
a = A()
@@ -568,12 +574,14 @@ class GCTests(unittest.TestCase):
import subprocess
code = """if 1:
import gc
+ import _testcapi
+ @_testcapi.with_tp_del
class X:
def __init__(self, name):
self.name = name
def __repr__(self):
return "<X %%r>" %% self.name
- def __del__(self):
+ def __tp_del__(self):
pass
x = X('first')
@@ -610,6 +618,66 @@ class GCTests(unittest.TestCase):
stderr = run_command(code % "gc.DEBUG_SAVEALL")
self.assertNotIn(b"uncollectable objects at shutdown", stderr)
+ def test_gc_main_module_at_shutdown(self):
+ # Create a reference cycle through the __main__ module and check
+ # it gets collected at interpreter shutdown.
+ code = """if 1:
+ import weakref
+ class C:
+ def __del__(self):
+ print('__del__ called')
+ l = [C()]
+ l.append(l)
+ """
+ rc, out, err = assert_python_ok('-c', code)
+ self.assertEqual(out.strip(), b'__del__ called')
+
+ def test_gc_ordinary_module_at_shutdown(self):
+ # Same as above, but with a non-__main__ module.
+ with temp_dir() as script_dir:
+ module = """if 1:
+ import weakref
+ class C:
+ def __del__(self):
+ print('__del__ called')
+ l = [C()]
+ l.append(l)
+ """
+ code = """if 1:
+ import sys
+ sys.path.insert(0, %r)
+ import gctest
+ """ % (script_dir,)
+ make_script(script_dir, 'gctest', module)
+ rc, out, err = assert_python_ok('-c', code)
+ self.assertEqual(out.strip(), b'__del__ called')
+
+ def test_get_stats(self):
+ stats = gc.get_stats()
+ self.assertEqual(len(stats), 3)
+ for st in stats:
+ self.assertIsInstance(st, dict)
+ self.assertEqual(set(st),
+ {"collected", "collections", "uncollectable"})
+ self.assertGreaterEqual(st["collected"], 0)
+ self.assertGreaterEqual(st["collections"], 0)
+ self.assertGreaterEqual(st["uncollectable"], 0)
+ # Check that collection counts are incremented correctly
+ if gc.isenabled():
+ self.addCleanup(gc.enable)
+ gc.disable()
+ old = gc.get_stats()
+ gc.collect(0)
+ new = gc.get_stats()
+ self.assertEqual(new[0]["collections"], old[0]["collections"] + 1)
+ self.assertEqual(new[1]["collections"], old[1]["collections"])
+ self.assertEqual(new[2]["collections"], old[2]["collections"])
+ gc.collect(2)
+ new = gc.get_stats()
+ self.assertEqual(new[0]["collections"], old[0]["collections"] + 1)
+ self.assertEqual(new[1]["collections"], old[1]["collections"])
+ self.assertEqual(new[2]["collections"], old[2]["collections"] + 1)
+
class GCCallbackTests(unittest.TestCase):
def setUp(self):
diff --git a/Lib/test/test_gdb.py b/Lib/test/test_gdb.py
index 284c8d8fd9..624e3d3db0 100644
--- a/Lib/test/test_gdb.py
+++ b/Lib/test/test_gdb.py
@@ -5,6 +5,7 @@
import os
import re
+import pprint
import subprocess
import sys
import sysconfig
@@ -17,6 +18,7 @@ try:
except ImportError:
_thread = None
+from test import support
from test.support import run_unittest, findfile, python_is_optimized
try:
@@ -165,6 +167,7 @@ class DebuggerTests(unittest.TestCase):
'linux-gate.so',
'Do you need "set solib-search-path" or '
'"set sysroot"?',
+ 'warning: Source file is more recent than executable.',
)
for line in errlines:
if not line.startswith(ignore_patterns):
@@ -307,6 +310,8 @@ class PrettyPrintTests(DebuggerTests):
def test_sets(self):
'Verify the pretty-printing of sets'
+ if (gdb_major_version, gdb_minor_version) < (7, 3):
+ self.skipTest("pretty-printing of sets needs gdb 7.3 or later")
self.assertGdbRepr(set())
self.assertGdbRepr(set(['a', 'b']), "{'a', 'b'}")
self.assertGdbRepr(set([4, 5, 6]), "{4, 5, 6}")
@@ -320,6 +325,8 @@ id(s)''')
def test_frozensets(self):
'Verify the pretty-printing of frozensets'
+ if (gdb_major_version, gdb_minor_version) < (7, 3):
+ self.skipTest("pretty-printing of frozensets needs gdb 7.3 or later")
self.assertGdbRepr(frozenset())
self.assertGdbRepr(frozenset(['a', 'b']), "frozenset({'a', 'b'})")
self.assertGdbRepr(frozenset([4, 5, 6]), "frozenset({4, 5, 6})")
@@ -837,6 +844,10 @@ class PyLocalsTests(DebuggerTests):
r".*\na = 1\nb = 2\nc = 3\n.*")
def test_main():
+ if support.verbose:
+ print("GDB version:")
+ for line in os.fsdecode(gdb_version).splitlines():
+ print(" " * 4 + line)
run_unittest(PrettyPrintTests,
PyListTests,
StackNavigationTests,
diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py
index 958054aef5..4e921177a5 100644
--- a/Lib/test/test_generators.py
+++ b/Lib/test/test_generators.py
@@ -1,3 +1,55 @@
+import gc
+import sys
+import unittest
+import weakref
+
+from test import support
+
+
+class FinalizationTest(unittest.TestCase):
+
+ def test_frame_resurrect(self):
+ # A generator frame can be resurrected by a generator's finalization.
+ def gen():
+ nonlocal frame
+ try:
+ yield
+ finally:
+ frame = sys._getframe()
+
+ g = gen()
+ wr = weakref.ref(g)
+ next(g)
+ del g
+ support.gc_collect()
+ self.assertIs(wr(), None)
+ self.assertTrue(frame)
+ del frame
+ support.gc_collect()
+
+ def test_refcycle(self):
+ # A generator caught in a refcycle gets finalized anyway.
+ old_garbage = gc.garbage[:]
+ finalized = False
+ def gen():
+ nonlocal finalized
+ try:
+ g = yield
+ yield 1
+ finally:
+ finalized = True
+
+ g = gen()
+ next(g)
+ g.send(g)
+ self.assertGreater(sys.getrefcount(g), 2)
+ self.assertFalse(finalized)
+ del g
+ support.gc_collect()
+ self.assertTrue(finalized)
+ self.assertEqual(gc.garbage, old_garbage)
+
+
tutorial_tests = """
Let's try a simple generator:
@@ -1729,9 +1781,7 @@ Our ill-behaved code should be invoked during GC:
>>> g = f()
>>> next(g)
>>> del g
->>> sys.stderr.getvalue().startswith(
-... "Exception RuntimeError: 'generator ignored GeneratorExit' in "
-... )
+>>> "RuntimeError: generator ignored GeneratorExit" in sys.stderr.getvalue()
True
>>> sys.stderr = old
@@ -1841,22 +1891,23 @@ to test.
... sys.stderr = io.StringIO()
... class Leaker:
... def __del__(self):
-... raise RuntimeError
+... def invoke(message):
+... raise RuntimeError(message)
+... invoke("test")
...
... l = Leaker()
... del l
... err = sys.stderr.getvalue().strip()
-... err.startswith(
-... "Exception RuntimeError: RuntimeError() in <"
-... )
-... err.endswith("> ignored")
-... len(err.splitlines())
+... "Exception ignored in" in err
+... "RuntimeError: test" in err
+... "Traceback" in err
+... "in invoke" in err
... finally:
... sys.stderr = old
True
True
-1
-
+True
+True
These refleak tests should perhaps be in a testfile of their own,
@@ -1881,6 +1932,7 @@ __test__ = {"tut": tutorial_tests,
# so this works as expected in both ways of running regrtest.
def test_main(verbose=None):
from test import support, test_generators
+ support.run_unittest(__name__)
support.run_doctest(test_generators, verbose)
# This part isn't needed for regrtest, but for running the test directly.
diff --git a/Lib/test/test_genericpath.py b/Lib/test/test_genericpath.py
index fd8bc577ca..e967897d75 100644
--- a/Lib/test/test_genericpath.py
+++ b/Lib/test/test_genericpath.py
@@ -188,12 +188,93 @@ class GenericTest:
support.unlink(support.TESTFN)
safe_rmdir(support.TESTFN)
+ @staticmethod
+ def _create_file(filename):
+ with open(filename, 'wb') as f:
+ f.write(b'foo')
+
+ def test_samefile(self):
+ try:
+ test_fn = support.TESTFN + "1"
+ self._create_file(test_fn)
+ self.assertTrue(self.pathmodule.samefile(test_fn, test_fn))
+ self.assertRaises(TypeError, self.pathmodule.samefile)
+ finally:
+ os.remove(test_fn)
+
+ @support.skip_unless_symlink
+ def test_samefile_on_symlink(self):
+ self._test_samefile_on_link_func(os.symlink)
+
+ def test_samefile_on_link(self):
+ self._test_samefile_on_link_func(os.link)
+
+ def _test_samefile_on_link_func(self, func):
+ try:
+ test_fn1 = support.TESTFN + "1"
+ test_fn2 = support.TESTFN + "2"
+ self._create_file(test_fn1)
+
+ func(test_fn1, test_fn2)
+ self.assertTrue(self.pathmodule.samefile(test_fn1, test_fn2))
+ os.remove(test_fn2)
+
+ self._create_file(test_fn2)
+ self.assertFalse(self.pathmodule.samefile(test_fn1, test_fn2))
+ finally:
+ os.remove(test_fn1)
+ os.remove(test_fn2)
+
+ def test_samestat(self):
+ try:
+ test_fn = support.TESTFN + "1"
+ self._create_file(test_fn)
+ test_fns = [test_fn]*2
+ stats = map(os.stat, test_fns)
+ self.assertTrue(self.pathmodule.samestat(*stats))
+ finally:
+ os.remove(test_fn)
+
+ @support.skip_unless_symlink
+ def test_samestat_on_symlink(self):
+ self._test_samestat_on_link_func(os.symlink)
+
+ def test_samestat_on_link(self):
+ self._test_samestat_on_link_func(os.link)
+
+ def _test_samestat_on_link_func(self, func):
+ try:
+ test_fn1 = support.TESTFN + "1"
+ test_fn2 = support.TESTFN + "2"
+ self._create_file(test_fn1)
+ test_fns = (test_fn1, test_fn2)
+ func(*test_fns)
+ stats = map(os.stat, test_fns)
+ self.assertTrue(self.pathmodule.samestat(*stats))
+ os.remove(test_fn2)
+
+ self._create_file(test_fn2)
+ stats = map(os.stat, test_fns)
+ self.assertFalse(self.pathmodule.samestat(*stats))
+
+ self.assertRaises(TypeError, self.pathmodule.samestat)
+ finally:
+ os.remove(test_fn1)
+ os.remove(test_fn2)
+
+ def test_sameopenfile(self):
+ fname = support.TESTFN + "1"
+ with open(fname, "wb") as a, open(fname, "wb") as b:
+ self.assertTrue(self.pathmodule.sameopenfile(
+ a.fileno(), b.fileno()))
+
class TestGenericTest(GenericTest, unittest.TestCase):
# Issue 16852: GenericTest can't inherit from unittest.TestCase
# for test discovery purposes; CommonTest inherits from GenericTest
# and is only meant to be inherited by others.
pathmodule = genericpath
+
# Following TestCase is not supposed to be run from test_genericpath.
# It is inherited by other test modules (macpath, ntpath, posixpath).
@@ -322,7 +403,6 @@ class CommonTest(GenericTest):
else:
self.skipTest("need support.TESTFN_NONASCII")
- # Test non-ASCII, non-UTF8 bytes in the path.
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
with support.temp_cwd(name):
diff --git a/Lib/test/test_getargs2.py b/Lib/test/test_getargs2.py
index 48ca94ee38..beea76a920 100644
--- a/Lib/test/test_getargs2.py
+++ b/Lib/test/test_getargs2.py
@@ -1,38 +1,40 @@
import unittest
from test import support
from _testcapi import getargs_keywords, getargs_keyword_only
-
-"""
-> How about the following counterproposal. This also changes some of
-> the other format codes to be a little more regular.
->
-> Code C type Range check
->
-> b unsigned char 0..UCHAR_MAX
-> h signed short SHRT_MIN..SHRT_MAX
-> B unsigned char none **
-> H unsigned short none **
-> k * unsigned long none
-> I * unsigned int 0..UINT_MAX
-
-
-> i int INT_MIN..INT_MAX
-> l long LONG_MIN..LONG_MAX
-
-> K * unsigned long long none
-> L long long LLONG_MIN..LLONG_MAX
-
-> Notes:
->
-> * New format codes.
->
-> ** Changed from previous "range-and-a-half" to "none"; the
-> range-and-a-half checking wasn't particularly useful.
-
-Plus a C API or two, e.g. PyInt_AsLongMask() ->
-unsigned long and PyInt_AsLongLongMask() -> unsigned
-long long (if that exists).
-"""
+try:
+ from _testcapi import getargs_L, getargs_K
+except ImportError:
+ getargs_L = None # PY_LONG_LONG not available
+
+# > How about the following counterproposal. This also changes some of
+# > the other format codes to be a little more regular.
+# >
+# > Code C type Range check
+# >
+# > b unsigned char 0..UCHAR_MAX
+# > h signed short SHRT_MIN..SHRT_MAX
+# > B unsigned char none **
+# > H unsigned short none **
+# > k * unsigned long none
+# > I * unsigned int 0..UINT_MAX
+#
+#
+# > i int INT_MIN..INT_MAX
+# > l long LONG_MIN..LONG_MAX
+#
+# > K * unsigned long long none
+# > L long long LLONG_MIN..LLONG_MAX
+#
+# > Notes:
+# >
+# > * New format codes.
+# >
+# > ** Changed from previous "range-and-a-half" to "none"; the
+# > range-and-a-half checking wasn't particularly useful.
+#
+# Plus a C API or two, e.g. PyInt_AsLongMask() ->
+# unsigned long and PyInt_AsLongLongMask() -> unsigned
+# long long (if that exists).
LARGE = 0x7FFFFFFF
VERY_LARGE = 0xFF0000121212121212121242
@@ -184,6 +186,7 @@ class Signed_TestCase(unittest.TestCase):
self.assertRaises(OverflowError, getargs_n, VERY_LARGE)
+@unittest.skipIf(getargs_L is None, 'PY_LONG_LONG is not available')
class LongLong_TestCase(unittest.TestCase):
def test_L(self):
from _testcapi import getargs_L
@@ -536,24 +539,5 @@ class Unicode_TestCase(unittest.TestCase):
self.assertIsNone(getargs_Z_hash(None))
-def test_main():
- tests = [
- Signed_TestCase,
- Unsigned_TestCase,
- Boolean_TestCase,
- Tuple_TestCase,
- Keywords_TestCase,
- KeywordOnly_TestCase,
- Bytes_TestCase,
- Unicode_TestCase,
- ]
- try:
- from _testcapi import getargs_L, getargs_K
- except ImportError:
- pass # PY_LONG_LONG not available
- else:
- tests.append(LongLong_TestCase)
- support.run_unittest(*tests)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_getpass.py b/Lib/test/test_getpass.py
new file mode 100644
index 0000000000..1731bd44a6
--- /dev/null
+++ b/Lib/test/test_getpass.py
@@ -0,0 +1,155 @@
+import getpass
+import os
+import unittest
+from io import BytesIO, StringIO
+from unittest import mock
+from test import support
+
+try:
+ import termios
+except ImportError:
+ termios = None
+try:
+ import pwd
+except ImportError:
+ pwd = None
+
+@mock.patch('os.environ')
+class GetpassGetuserTest(unittest.TestCase):
+
+ def test_username_takes_username_from_env(self, environ):
+ expected_name = 'some_name'
+ environ.get.return_value = expected_name
+ self.assertEqual(expected_name, getpass.getuser())
+
+ def test_username_priorities_of_env_values(self, environ):
+ environ.get.return_value = None
+ try:
+ getpass.getuser()
+ except ImportError: # in case there's no pwd module
+ pass
+ self.assertEqual(
+ environ.get.call_args_list,
+ [mock.call(x) for x in ('LOGNAME', 'USER', 'LNAME', 'USERNAME')])
+
+ def test_username_falls_back_to_pwd(self, environ):
+ expected_name = 'some_name'
+ environ.get.return_value = None
+ if pwd:
+ with mock.patch('os.getuid') as uid, \
+ mock.patch('pwd.getpwuid') as getpw:
+ uid.return_value = 42
+ getpw.return_value = [expected_name]
+ self.assertEqual(expected_name,
+ getpass.getuser())
+ getpw.assert_called_once_with(42)
+ else:
+ self.assertRaises(ImportError, getpass.getuser)
+
+
+class GetpassRawinputTest(unittest.TestCase):
+
+ def test_flushes_stream_after_prompt(self):
+ # see issue 1703
+ stream = mock.Mock(spec=StringIO)
+ input = StringIO('input_string')
+ getpass._raw_input('some_prompt', stream, input=input)
+ stream.flush.assert_called_once_with()
+
+ def test_uses_stderr_as_default(self):
+ input = StringIO('input_string')
+ prompt = 'some_prompt'
+ with mock.patch('sys.stderr') as stderr:
+ getpass._raw_input(prompt, input=input)
+ stderr.write.assert_called_once_with(prompt)
+
+ @mock.patch('sys.stdin')
+ def test_uses_stdin_as_default_input(self, mock_input):
+ mock_input.readline.return_value = 'input_string'
+ getpass._raw_input(stream=StringIO())
+ mock_input.readline.assert_called_once_with()
+
+ def test_raises_on_empty_input(self):
+ input = StringIO('')
+ self.assertRaises(EOFError, getpass._raw_input, input=input)
+
+ def test_trims_trailing_newline(self):
+ input = StringIO('test\n')
+ self.assertEqual('test', getpass._raw_input(input=input))
+
+
+# Some of these tests are a bit white-box. The functional requirement is that
+# the password input be taken directly from the tty, and that it not be echoed
+# on the screen, unless we are falling back to stderr/stdin.
+
+# Some of these might run on platforms without termios, but play it safe.
+@unittest.skipUnless(termios, 'tests require system with termios')
+class UnixGetpassTest(unittest.TestCase):
+
+ def test_uses_tty_directly(self):
+ with mock.patch('os.open') as open, \
+ mock.patch('io.FileIO') as fileio, \
+ mock.patch('io.TextIOWrapper') as textio:
+ # By setting open's return value to None the implementation will
+ # skip code we don't care about in this test. We can mock this out
+ # fully if an alternate implementation works differently.
+ open.return_value = None
+ getpass.unix_getpass()
+ open.assert_called_once_with('/dev/tty',
+ os.O_RDWR | os.O_NOCTTY)
+ fileio.assert_called_once_with(open.return_value, 'w+')
+ textio.assert_called_once_with(fileio.return_value)
+
+ def test_resets_termios(self):
+ with mock.patch('os.open') as open, \
+ mock.patch('io.FileIO'), \
+ mock.patch('io.TextIOWrapper'), \
+ mock.patch('termios.tcgetattr') as tcgetattr, \
+ mock.patch('termios.tcsetattr') as tcsetattr:
+ open.return_value = 3
+ fake_attrs = [255, 255, 255, 255, 255]
+ tcgetattr.return_value = list(fake_attrs)
+ getpass.unix_getpass()
+ tcsetattr.assert_called_with(3, mock.ANY, fake_attrs)
+
+ def test_falls_back_to_fallback_if_termios_raises(self):
+ with mock.patch('os.open') as open, \
+ mock.patch('io.FileIO') as fileio, \
+ mock.patch('io.TextIOWrapper') as textio, \
+ mock.patch('termios.tcgetattr'), \
+ mock.patch('termios.tcsetattr') as tcsetattr, \
+ mock.patch('getpass.fallback_getpass') as fallback:
+ open.return_value = 3
+ fileio.return_value = BytesIO()
+ tcsetattr.side_effect = termios.error
+ getpass.unix_getpass()
+ fallback.assert_called_once_with('Password: ',
+ textio.return_value)
+
+ def test_flushes_stream_after_input(self):
+ # issue 7208
+ with mock.patch('os.open') as open, \
+ mock.patch('io.FileIO'), \
+ mock.patch('io.TextIOWrapper'), \
+ mock.patch('termios.tcgetattr'), \
+ mock.patch('termios.tcsetattr'):
+ open.return_value = 3
+ mock_stream = mock.Mock(spec=StringIO)
+ getpass.unix_getpass(stream=mock_stream)
+ mock_stream.flush.assert_called_with()
+
+ def test_falls_back_to_stdin(self):
+ with mock.patch('os.open') as os_open, \
+ mock.patch('sys.stdin', spec=StringIO) as stdin:
+ os_open.side_effect = IOError
+ stdin.fileno.side_effect = AttributeError
+ with support.captured_stderr() as stderr:
+ with self.assertWarns(getpass.GetPassWarning):
+ getpass.unix_getpass()
+ stdin.readline.assert_called_once_with()
+ self.assertIn('Warning', stderr.getvalue())
+ self.assertIn('Password:', stderr.getvalue())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py
index 5eac9217b2..aeafc6536d 100755..100644
--- a/Lib/test/test_gzip.py
+++ b/Lib/test/test_gzip.py
@@ -131,6 +131,14 @@ class TestGzip(BaseTest):
if not ztxt: break
self.assertEqual(contents, b'a'*201)
+ def test_exclusive_write(self):
+ with gzip.GzipFile(self.filename, 'xb') as f:
+ f.write(data1 * 50)
+ with gzip.GzipFile(self.filename, 'rb') as f:
+ self.assertEqual(f.read(), data1 * 50)
+ with self.assertRaises(FileExistsError):
+ gzip.GzipFile(self.filename, 'xb')
+
def test_buffered_reader(self):
# Issue #7471: a GzipFile can be wrapped in a BufferedReader for
# performance.
@@ -206,6 +214,9 @@ class TestGzip(BaseTest):
self.test_write()
with gzip.GzipFile(self.filename, 'r') as f:
self.assertEqual(f.myfileobj.mode, 'rb')
+ support.unlink(self.filename)
+ with gzip.GzipFile(self.filename, 'x') as f:
+ self.assertEqual(f.myfileobj.mode, 'xb')
def test_1647484(self):
for mode in ('wb', 'rb'):
@@ -389,6 +400,20 @@ class TestGzip(BaseTest):
datac = gzip.compress(data)
self.assertEqual(gzip.decompress(datac), data)
+ def test_read_truncated(self):
+ data = data1*50
+ # Drop the CRC (4 bytes) and file size (4 bytes).
+ truncated = gzip.compress(data)[:-8]
+ with gzip.GzipFile(fileobj=io.BytesIO(truncated)) as f:
+ self.assertRaises(EOFError, f.read)
+ with gzip.GzipFile(fileobj=io.BytesIO(truncated)) as f:
+ self.assertEqual(f.read(len(data)), data)
+ self.assertRaises(EOFError, f.read, 1)
+ # Incomplete 10-byte header.
+ for i in range(2, 10):
+ with gzip.GzipFile(fileobj=io.BytesIO(truncated[:i])) as f:
+ self.assertRaises(EOFError, f.read, 1)
+
def test_read_with_extra(self):
# Gzip data with an extra field
gzdata = (b'\x1f\x8b\x08\x04\xb2\x17cQ\x02\xff'
@@ -400,35 +425,59 @@ class TestGzip(BaseTest):
class TestOpen(BaseTest):
def test_binary_modes(self):
uncompressed = data1 * 50
+
with gzip.open(self.filename, "wb") as f:
f.write(uncompressed)
with open(self.filename, "rb") as f:
file_data = gzip.decompress(f.read())
self.assertEqual(file_data, uncompressed)
+
with gzip.open(self.filename, "rb") as f:
self.assertEqual(f.read(), uncompressed)
+
with gzip.open(self.filename, "ab") as f:
f.write(uncompressed)
with open(self.filename, "rb") as f:
file_data = gzip.decompress(f.read())
self.assertEqual(file_data, uncompressed * 2)
+ with self.assertRaises(FileExistsError):
+ gzip.open(self.filename, "xb")
+ support.unlink(self.filename)
+ with gzip.open(self.filename, "xb") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read())
+ self.assertEqual(file_data, uncompressed)
+
def test_implicit_binary_modes(self):
# Test implicit binary modes (no "b" or "t" in mode string).
uncompressed = data1 * 50
+
with gzip.open(self.filename, "w") as f:
f.write(uncompressed)
with open(self.filename, "rb") as f:
file_data = gzip.decompress(f.read())
self.assertEqual(file_data, uncompressed)
+
with gzip.open(self.filename, "r") as f:
self.assertEqual(f.read(), uncompressed)
+
with gzip.open(self.filename, "a") as f:
f.write(uncompressed)
with open(self.filename, "rb") as f:
file_data = gzip.decompress(f.read())
self.assertEqual(file_data, uncompressed * 2)
+ with self.assertRaises(FileExistsError):
+ gzip.open(self.filename, "x")
+ support.unlink(self.filename)
+ with gzip.open(self.filename, "x") as f:
+ f.write(uncompressed)
+ with open(self.filename, "rb") as f:
+ file_data = gzip.decompress(f.read())
+ self.assertEqual(file_data, uncompressed)
+
def test_text_modes(self):
uncompressed = data1.decode("ascii") * 50
uncompressed_raw = uncompressed.replace("\n", os.linesep)
@@ -463,6 +512,8 @@ class TestOpen(BaseTest):
with self.assertRaises(ValueError):
gzip.open(self.filename, "wbt")
with self.assertRaises(ValueError):
+ gzip.open(self.filename, "xbt")
+ with self.assertRaises(ValueError):
gzip.open(self.filename, "rb", encoding="utf-8")
with self.assertRaises(ValueError):
gzip.open(self.filename, "rb", errors="ignore")
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
index f3385f6d6f..4a5ea7f1a9 100644
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -18,11 +18,13 @@ except ImportError:
import unittest
import warnings
from test import support
-from test.support import _4G, bigmemtest
+from test.support import _4G, bigmemtest, import_fresh_module
# Were we compiled --with-pydebug or with #define Py_DEBUG?
COMPILED_WITH_PYDEBUG = hasattr(sys, 'gettotalrefcount')
+c_hashlib = import_fresh_module('hashlib', fresh=['_hashlib'])
+py_hashlib = import_fresh_module('hashlib', blocked=['_hashlib'])
def hexstr(s):
assert isinstance(s, bytes), repr(s)
@@ -36,7 +38,10 @@ def hexstr(s):
class HashLibTestCase(unittest.TestCase):
supported_hash_names = ( 'md5', 'MD5', 'sha1', 'SHA1',
'sha224', 'SHA224', 'sha256', 'SHA256',
- 'sha384', 'SHA384', 'sha512', 'SHA512' )
+ 'sha384', 'SHA384', 'sha512', 'SHA512',
+ 'sha3_224', 'sha3_256', 'sha3_384',
+ 'sha3_512', 'SHA3_224', 'SHA3_256',
+ 'SHA3_384', 'SHA3_512' )
# Issue #14693: fallback modules are always compiled under POSIX
_warn_on_extension_import = os.name == 'posix' or COMPILED_WITH_PYDEBUG
@@ -72,27 +77,37 @@ class HashLibTestCase(unittest.TestCase):
if _hashlib:
# These two algorithms should always be present when this module
# is compiled. If not, something was compiled wrong.
- assert hasattr(_hashlib, 'openssl_md5')
- assert hasattr(_hashlib, 'openssl_sha1')
+ self.assertTrue(hasattr(_hashlib, 'openssl_md5'))
+ self.assertTrue(hasattr(_hashlib, 'openssl_sha1'))
for algorithm, constructors in self.constructors_to_test.items():
constructor = getattr(_hashlib, 'openssl_'+algorithm, None)
if constructor:
constructors.add(constructor)
+ def add_builtin_constructor(name):
+ constructor = getattr(hashlib, "__get_builtin_constructor")(name)
+ self.constructors_to_test[name].add(constructor)
+
_md5 = self._conditional_import_module('_md5')
if _md5:
- self.constructors_to_test['md5'].add(_md5.md5)
+ add_builtin_constructor('md5')
_sha1 = self._conditional_import_module('_sha1')
if _sha1:
- self.constructors_to_test['sha1'].add(_sha1.sha1)
+ add_builtin_constructor('sha1')
_sha256 = self._conditional_import_module('_sha256')
if _sha256:
- self.constructors_to_test['sha224'].add(_sha256.sha224)
- self.constructors_to_test['sha256'].add(_sha256.sha256)
+ add_builtin_constructor('sha224')
+ add_builtin_constructor('sha256')
_sha512 = self._conditional_import_module('_sha512')
if _sha512:
- self.constructors_to_test['sha384'].add(_sha512.sha384)
- self.constructors_to_test['sha512'].add(_sha512.sha512)
+ add_builtin_constructor('sha384')
+ add_builtin_constructor('sha512')
+ _sha3 = self._conditional_import_module('_sha3')
+ if _sha3:
+ add_builtin_constructor('sha3_224')
+ add_builtin_constructor('sha3_256')
+ add_builtin_constructor('sha3_384')
+ add_builtin_constructor('sha3_512')
super(HashLibTestCase, self).__init__(*args, **kwargs)
@@ -121,8 +136,10 @@ class HashLibTestCase(unittest.TestCase):
self.assertRaises(TypeError, hashlib.new, 1)
def test_get_builtin_constructor(self):
- get_builtin_constructor = hashlib.__dict__[
- '__get_builtin_constructor']
+ get_builtin_constructor = getattr(hashlib,
+ '__get_builtin_constructor')
+ builtin_constructor_cache = getattr(hashlib,
+ '__builtin_constructor_cache')
self.assertRaises(ValueError, get_builtin_constructor, 'test')
try:
import _md5
@@ -130,6 +147,8 @@ class HashLibTestCase(unittest.TestCase):
pass
# This forces an ImportError for "import _md5" statements
sys.modules['_md5'] = None
+ # clear the cache
+ builtin_constructor_cache.clear()
try:
self.assertRaises(ValueError, get_builtin_constructor, 'md5')
finally:
@@ -138,13 +157,23 @@ class HashLibTestCase(unittest.TestCase):
else:
del sys.modules['_md5']
self.assertRaises(TypeError, get_builtin_constructor, 3)
+ constructor = get_builtin_constructor('md5')
+ self.assertIs(constructor, _md5.md5)
+ self.assertEqual(sorted(builtin_constructor_cache), ['MD5', 'md5'])
def test_hexdigest(self):
for cons in self.hash_constructors:
h = cons()
- assert isinstance(h.digest(), bytes), name
+ self.assertIsInstance(h.digest(), bytes)
self.assertEqual(hexstr(h.digest()), h.hexdigest())
+ def test_name_attribute(self):
+ for cons in self.hash_constructors:
+ h = cons()
+ self.assertIsInstance(h.name, str)
+ self.assertIn(h.name, self.supported_hash_names)
+ self.assertEqual(h.name, hashlib.new(h.name).name)
+
def test_large_update(self):
aas = b'a' * 128
bees = b'b' * 127
@@ -205,6 +234,10 @@ class HashLibTestCase(unittest.TestCase):
self.check_no_unicode('sha256')
self.check_no_unicode('sha384')
self.check_no_unicode('sha512')
+ self.check_no_unicode('sha3_224')
+ self.check_no_unicode('sha3_256')
+ self.check_no_unicode('sha3_384')
+ self.check_no_unicode('sha3_512')
def check_blocksize_name(self, name, block_size=0, digest_size=0):
constructors = self.constructors_to_test[name]
@@ -213,8 +246,9 @@ class HashLibTestCase(unittest.TestCase):
self.assertEqual(m.block_size, block_size)
self.assertEqual(m.digest_size, digest_size)
self.assertEqual(len(m.digest()), digest_size)
- self.assertEqual(m.name.lower(), name.lower())
- self.assertIn(name.split("_")[0], repr(m).lower())
+ self.assertEqual(m.name, name)
+ # split for sha3_512 / _sha3.sha3 object
+ self.assertIn(name.split("_")[0], repr(m))
def test_blocksize_name(self):
self.check_blocksize_name('md5', 64, 16)
@@ -223,6 +257,10 @@ class HashLibTestCase(unittest.TestCase):
self.check_blocksize_name('sha256', 64, 32)
self.check_blocksize_name('sha384', 128, 48)
self.check_blocksize_name('sha512', 128, 64)
+ self.check_blocksize_name('sha3_224', NotImplemented, 28)
+ self.check_blocksize_name('sha3_256', NotImplemented, 32)
+ self.check_blocksize_name('sha3_384', NotImplemented, 48)
+ self.check_blocksize_name('sha3_512', NotImplemented, 64)
def test_case_md5_0(self):
self.check('md5', b'', 'd41d8cd98f00b204e9800998ecf8427e')
@@ -358,6 +396,108 @@ class HashLibTestCase(unittest.TestCase):
"e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973eb"+
"de0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b")
+ # SHA-3 family
+ def test_case_sha3_224_0(self):
+ self.check('sha3_224', b"",
+ "F71837502BA8E10837BDD8D365ADB85591895602FC552B48B7390ABD")
+
+ def test_case_sha3_224_1(self):
+ self.check('sha3_224', bytes.fromhex("CC"),
+ "A9CAB59EB40A10B246290F2D6086E32E3689FAF1D26B470C899F2802")
+
+ def test_case_sha3_224_2(self):
+ self.check('sha3_224', bytes.fromhex("41FB"),
+ "615BA367AFDC35AAC397BC7EB5D58D106A734B24986D5D978FEFD62C")
+
+ def test_case_sha3_224_3(self):
+ self.check('sha3_224', bytes.fromhex(
+ "433C5303131624C0021D868A30825475E8D0BD3052A022180398F4CA4423B9"+
+ "8214B6BEAAC21C8807A2C33F8C93BD42B092CC1B06CEDF3224D5ED1EC29784"+
+ "444F22E08A55AA58542B524B02CD3D5D5F6907AFE71C5D7462224A3F9D9E53"+
+ "E7E0846DCBB4CE"),
+ "62B10F1B6236EBC2DA72957742A8D4E48E213B5F8934604BFD4D2C3A")
+
+ @bigmemtest(size=_4G + 5, memuse=1)
+ def test_case_sha3_224_huge(self, size):
+ if size == _4G + 5:
+ try:
+ self.check('sha3_224', b'A'*size,
+ '58ef60057c9dddb6a87477e9ace5a26f0d9db01881cf9b10a9f8c224')
+ except OverflowError:
+ pass # 32-bit arch
+
+
+ def test_case_sha3_256_0(self):
+ self.check('sha3_256', b"",
+ "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470")
+
+ def test_case_sha3_256_1(self):
+ self.check('sha3_256', bytes.fromhex("CC"),
+ "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A")
+
+ def test_case_sha3_256_2(self):
+ self.check('sha3_256', bytes.fromhex("41FB"),
+ "A8EACEDA4D47B3281A795AD9E1EA2122B407BAF9AABCB9E18B5717B7873537D2")
+
+ def test_case_sha3_256_3(self):
+ self.check('sha3_256', bytes.fromhex(
+ "433C5303131624C0021D868A30825475E8D0BD3052A022180398F4CA4423B9"+
+ "8214B6BEAAC21C8807A2C33F8C93BD42B092CC1B06CEDF3224D5ED1EC29784"+
+ "444F22E08A55AA58542B524B02CD3D5D5F6907AFE71C5D7462224A3F9D9E53"+
+ "E7E0846DCBB4CE"),
+ "CE87A5173BFFD92399221658F801D45C294D9006EE9F3F9D419C8D427748DC41")
+
+
+ def test_case_sha3_384_0(self):
+ self.check('sha3_384', b"",
+ "2C23146A63A29ACF99E73B88F8C24EAA7DC60AA771780CCC006AFBFA8FE2479B"+
+ "2DD2B21362337441AC12B515911957FF")
+
+ def test_case_sha3_384_1(self):
+ self.check('sha3_384', bytes.fromhex("CC"),
+ "1B84E62A46E5A201861754AF5DC95C4A1A69CAF4A796AE405680161E29572641"+
+ "F5FA1E8641D7958336EE7B11C58F73E9")
+
+ def test_case_sha3_384_2(self):
+ self.check('sha3_384', bytes.fromhex("41FB"),
+ "495CCE2714CD72C8C53C3363D22C58B55960FE26BE0BF3BBC7A3316DD563AD1D"+
+ "B8410E75EEFEA655E39D4670EC0B1792")
+
+ def test_case_sha3_384_3(self):
+ self.check('sha3_384', bytes.fromhex(
+ "433C5303131624C0021D868A30825475E8D0BD3052A022180398F4CA4423B9"+
+ "8214B6BEAAC21C8807A2C33F8C93BD42B092CC1B06CEDF3224D5ED1EC29784"+
+ "444F22E08A55AA58542B524B02CD3D5D5F6907AFE71C5D7462224A3F9D9E53"+
+ "E7E0846DCBB4CE"),
+ "135114508DD63E279E709C26F7817C0482766CDE49132E3EDF2EEDD8996F4E35"+
+ "96D184100B384868249F1D8B8FDAA2C9")
+
+
+ def test_case_sha3_512_0(self):
+ self.check('sha3_512', b"",
+ "0EAB42DE4C3CEB9235FC91ACFFE746B29C29A8C366B7C60E4E67C466F36A4304"+
+ "C00FA9CAF9D87976BA469BCBE06713B435F091EF2769FB160CDAB33D3670680E")
+
+ def test_case_sha3_512_1(self):
+ self.check('sha3_512', bytes.fromhex("CC"),
+ "8630C13CBD066EA74BBE7FE468FEC1DEE10EDC1254FB4C1B7C5FD69B646E4416"+
+ "0B8CE01D05A0908CA790DFB080F4B513BC3B6225ECE7A810371441A5AC666EB9")
+
+ def test_case_sha3_512_2(self):
+ self.check('sha3_512', bytes.fromhex("41FB"),
+ "551DA6236F8B96FCE9F97F1190E901324F0B45E06DBBB5CDB8355D6ED1DC34B3"+
+ "F0EAE7DCB68622FF232FA3CECE0D4616CDEB3931F93803662A28DF1CD535B731")
+
+ def test_case_sha3_512_3(self):
+ self.check('sha3_512', bytes.fromhex(
+ "433C5303131624C0021D868A30825475E8D0BD3052A022180398F4CA4423B9"+
+ "8214B6BEAAC21C8807A2C33F8C93BD42B092CC1B06CEDF3224D5ED1EC29784"+
+ "444F22E08A55AA58542B524B02CD3D5D5F6907AFE71C5D7462224A3F9D9E53"+
+ "E7E0846DCBB4CE"),
+ "527D28E341E6B14F4684ADB4B824C496C6482E51149565D3D17226828884306B"+
+ "51D6148A72622C2B75F5D3510B799D8BDC03EAEDE453676A6EC8FE03A1AD0EAB")
+
+
def test_gil(self):
# Check things work fine with an input larger than the size required
# for multithreaded operation (which is hardwired to 2048).
@@ -406,8 +546,8 @@ class HashLibTestCase(unittest.TestCase):
events = []
for threadnum in range(num_threads):
chunk_size = len(data) // (10**threadnum)
- assert chunk_size > 0
- assert chunk_size % len(smallest_data) == 0
+ self.assertGreater(chunk_size, 0)
+ self.assertEqual(chunk_size % len(smallest_data), 0)
event = threading.Event()
events.append(event)
threading.Thread(target=hash_in_chunks,
@@ -418,8 +558,95 @@ class HashLibTestCase(unittest.TestCase):
self.assertEqual(expected_hash, hasher.hexdigest())
-def test_main():
- support.run_unittest(HashLibTestCase)
+
+class KDFTests(unittest.TestCase):
+
+ pbkdf2_test_vectors = [
+ (b'password', b'salt', 1, None),
+ (b'password', b'salt', 2, None),
+ (b'password', b'salt', 4096, None),
+ # too slow, it takes over a minute on a fast CPU.
+ #(b'password', b'salt', 16777216, None),
+ (b'passwordPASSWORDpassword', b'saltSALTsaltSALTsaltSALTsaltSALTsalt',
+ 4096, -1),
+ (b'pass\0word', b'sa\0lt', 4096, 16),
+ ]
+
+ pbkdf2_results = {
+ "sha1": [
+ # offical test vectors from RFC 6070
+ (bytes.fromhex('0c60c80f961f0e71f3a9b524af6012062fe037a6'), None),
+ (bytes.fromhex('ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957'), None),
+ (bytes.fromhex('4b007901b765489abead49d926f721d065a429c1'), None),
+ #(bytes.fromhex('eefe3d61cd4da4e4e9945b3d6ba2158c2634e984'), None),
+ (bytes.fromhex('3d2eec4fe41c849b80c8d83662c0e44a8b291a964c'
+ 'f2f07038'), 25),
+ (bytes.fromhex('56fa6aa75548099dcc37d7f03425e0c3'), None),],
+ "sha256": [
+ (bytes.fromhex('120fb6cffcf8b32c43e7225256c4f837'
+ 'a86548c92ccc35480805987cb70be17b'), None),
+ (bytes.fromhex('ae4d0c95af6b46d32d0adff928f06dd0'
+ '2a303f8ef3c251dfd6e2d85a95474c43'), None),
+ (bytes.fromhex('c5e478d59288c841aa530db6845c4c8d'
+ '962893a001ce4e11a4963873aa98134a'), None),
+ #(bytes.fromhex('cf81c66fe8cfc04d1f31ecb65dab4089'
+ # 'f7f179e89b3b0bcb17ad10e3ac6eba46'), None),
+ (bytes.fromhex('348c89dbcbd32b2f32d814b8116e84cf2b17'
+ '347ebc1800181c4e2a1fb8dd53e1c635518c7dac47e9'), 40),
+ (bytes.fromhex('89b69d0516f829893c696226650a8687'), None),],
+ "sha512": [
+ (bytes.fromhex('867f70cf1ade02cff3752599a3a53dc4af34c7a669815ae5'
+ 'd513554e1c8cf252c02d470a285a0501bad999bfe943c08f'
+ '050235d7d68b1da55e63f73b60a57fce'), None),
+ (bytes.fromhex('e1d9c16aa681708a45f5c7c4e215ceb66e011a2e9f004071'
+ '3f18aefdb866d53cf76cab2868a39b9f7840edce4fef5a82'
+ 'be67335c77a6068e04112754f27ccf4e'), None),
+ (bytes.fromhex('d197b1b33db0143e018b12f3d1d1479e6cdebdcc97c5c0f8'
+ '7f6902e072f457b5143f30602641b3d55cd335988cb36b84'
+ '376060ecd532e039b742a239434af2d5'), None),
+ (bytes.fromhex('8c0511f4c6e597c6ac6315d8f0362e225f3c501495ba23b8'
+ '68c005174dc4ee71115b59f9e60cd9532fa33e0f75aefe30'
+ '225c583a186cd82bd4daea9724a3d3b8'), 64),
+ (bytes.fromhex('9d9e9c4cd21fe4be24d5b8244c759665'), None),],
+ }
+
+ def _test_pbkdf2_hmac(self, pbkdf2):
+ for digest_name, results in self.pbkdf2_results.items():
+ for i, vector in enumerate(self.pbkdf2_test_vectors):
+ password, salt, rounds, dklen = vector
+ expected, overwrite_dklen = results[i]
+ if overwrite_dklen:
+ dklen = overwrite_dklen
+ out = pbkdf2(digest_name, password, salt, rounds, dklen)
+ self.assertEqual(out, expected,
+ (digest_name, password, salt, rounds, dklen))
+ out = pbkdf2(digest_name, memoryview(password),
+ memoryview(salt), rounds, dklen)
+ out = pbkdf2(digest_name, bytearray(password),
+ bytearray(salt), rounds, dklen)
+ self.assertEqual(out, expected)
+ if dklen is None:
+ out = pbkdf2(digest_name, password, salt, rounds)
+ self.assertEqual(out, expected,
+ (digest_name, password, salt, rounds))
+
+ self.assertRaises(TypeError, pbkdf2, b'sha1', b'pass', b'salt', 1)
+ self.assertRaises(TypeError, pbkdf2, 'sha1', 'pass', 'salt', 1)
+ self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', 0)
+ self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', -1)
+ self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', 1, 0)
+ self.assertRaises(ValueError, pbkdf2, 'sha1', b'pass', b'salt', 1, -1)
+ with self.assertRaisesRegex(ValueError, 'unsupported hash type'):
+ pbkdf2('unknown', b'pass', b'salt', 1)
+
+ def test_pbkdf2_hmac_py(self):
+ self._test_pbkdf2_hmac(py_hashlib.pbkdf2_hmac)
+
+ @unittest.skipUnless(hasattr(c_hashlib, 'pbkdf2_hmac'),
+ ' test requires OpenSSL > 1.0')
+ def test_pbkdf2_hmac_c(self):
+ self._test_pbkdf2_hmac(c_hashlib.pbkdf2_hmac)
+
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
index 4ca7cec44c..efd63ad7be 100644
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -253,6 +253,20 @@ class ConstructorTestCase(unittest.TestCase):
except:
self.fail("Constructor call with text argument raised exception.")
+ def test_with_bytearray(self):
+ try:
+ h = hmac.HMAC(bytearray(b"key"), bytearray(b"hash this!"))
+ self.assertEqual(h.hexdigest(), '34325b639da4cfd95735b381e28cb864')
+ except:
+ self.fail("Constructor call with bytearray arguments raised exception.")
+
+ def test_with_memoryview_msg(self):
+ try:
+ h = hmac.HMAC(b"key", memoryview(b"hash this!"))
+ self.assertEqual(h.hexdigest(), '34325b639da4cfd95735b381e28cb864')
+ except:
+ self.fail("Constructor call with memoryview msg raised exception.")
+
def test_withmodule(self):
# Constructor call with text and digest module.
try:
diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py
index c977a9dd4d..8863316a0f 100644
--- a/Lib/test/test_htmlparser.py
+++ b/Lib/test/test_htmlparser.py
@@ -96,7 +96,9 @@ class TestCaseBase(unittest.TestCase):
parser = self.get_collector()
parser.feed(source)
parser.close()
- self.assertRaises(html.parser.HTMLParseError, parse)
+ with self.assertRaises(html.parser.HTMLParseError):
+ with self.assertWarns(DeprecationWarning):
+ parse()
class HTMLParserStrictTestCase(TestCaseBase):
@@ -365,7 +367,16 @@ text
class HTMLParserTolerantTestCase(HTMLParserStrictTestCase):
def get_collector(self):
- return EventCollector(strict=False)
+ return EventCollector()
+
+ def test_deprecation_warnings(self):
+ with self.assertWarns(DeprecationWarning):
+ EventCollector(strict=True)
+ with self.assertWarns(DeprecationWarning):
+ EventCollector(strict=False)
+ with self.assertRaises(html.parser.HTMLParseError):
+ with self.assertWarns(DeprecationWarning):
+ EventCollector().error('test')
def test_tolerant_parsing(self):
self._run_check('<html <html>te>>xt&a<<bc</a></html>\n'
@@ -685,7 +696,7 @@ class AttributesStrictTestCase(TestCaseBase):
class AttributesTolerantTestCase(AttributesStrictTestCase):
def get_collector(self):
- return EventCollector(strict=False)
+ return EventCollector()
def test_attr_funky_names2(self):
self._run_check(
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
index f3c27c2df6..31c0b6a6ca 100644
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -27,8 +27,10 @@ class FakeSocket:
self.text = text
self.fileclass = fileclass
self.data = b''
+ self.sendall_calls = 0
def sendall(self, data):
+ self.sendall_calls += 1
self.data += data
def makefile(self, mode, bufsize=None):
@@ -45,7 +47,7 @@ class EPipeSocket(FakeSocket):
def sendall(self, data):
if self.pipe_trigger in data:
- raise socket.error(errno.EPIPE, "gotcha")
+ raise OSError(errno.EPIPE, "gotcha")
self.data += data
def close(self):
@@ -597,7 +599,7 @@ class BasicTest(TestCase):
b"Content-Length")
conn = client.HTTPConnection("example.com")
conn.sock = sock
- self.assertRaises(socket.error,
+ self.assertRaises(OSError,
lambda: conn.request("PUT", "/url", "body"))
resp = conn.getresponse()
self.assertEqual(401, resp.status)
@@ -643,6 +645,28 @@ class BasicTest(TestCase):
resp.close()
self.assertTrue(resp.closed)
+ def test_delayed_ack_opt(self):
+ # Test that Nagle/delayed_ack optimistaion works correctly.
+
+ # For small payloads, it should coalesce the body with
+ # headers, resulting in a single sendall() call
+ conn = client.HTTPConnection('example.com')
+ sock = FakeSocket(None)
+ conn.sock = sock
+ body = b'x' * (conn.mss - 1)
+ conn.request('POST', '/', body)
+ self.assertEqual(sock.sendall_calls, 1)
+
+ # For large payloads, it should send the headers and
+ # then the body, resulting in more than one sendall()
+ # call
+ conn = client.HTTPConnection('example.com')
+ sock = FakeSocket(None)
+ conn.sock = sock
+ body = b'x' * conn.mss
+ conn.request('POST', '/', body)
+ self.assertGreater(sock.sendall_calls, 1)
+
class OfflineTest(TestCase):
def test_responses(self):
self.assertEqual(client.responses[client.NOT_FOUND], "Not Found")
@@ -733,7 +757,7 @@ class HTTPSTest(TestCase):
def make_server(self, certfile):
from test.ssl_servers import make_https_server
- return make_https_server(self, certfile)
+ return make_https_server(self, certfile=certfile)
def test_attributes(self):
# simple test to check it's storing the timeout
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index 9dd27788f7..b2bce1c27a 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -92,6 +92,13 @@ class BaseHTTPServerTestCase(BaseTestCase):
def do_KEYERROR(self):
self.send_error(999)
+ def do_NOTFOUND(self):
+ self.send_error(404)
+
+ def do_EXPLAINERROR(self):
+ self.send_error(999, "Short Message",
+ "This is a long \n explaination")
+
def do_CUSTOM(self):
self.send_response(999)
self.send_header('Content-Type', 'text/html')
@@ -203,6 +210,12 @@ class BaseHTTPServerTestCase(BaseTestCase):
res = self.con.getresponse()
self.assertEqual(res.status, 999)
+ def test_return_explain_error(self):
+ self.con.request('EXPLAINERROR', '/')
+ res = self.con.getresponse()
+ self.assertEqual(res.status, 999)
+ self.assertTrue(int(res.getheader('Content-Length')))
+
def test_latin1_header(self):
self.con.request('LATINONEHEADER', '/', headers={
'X-Special-Incoming': 'Ärger mit Unicode'
@@ -211,6 +224,14 @@ class BaseHTTPServerTestCase(BaseTestCase):
self.assertEqual(res.getheader('X-Special'), 'Dängerous Mind')
self.assertEqual(res.read(), 'Ärger mit Unicode'.encode('utf-8'))
+ def test_error_content_length(self):
+ # Issue #16088: standard error responses should have a content-length
+ self.con.request('NOTFOUND', '/')
+ res = self.con.getresponse()
+ self.assertEqual(res.status, 404)
+ data = res.read()
+ self.assertEqual(int(res.getheader('Content-Length')), len(data))
+
class SimpleHTTPServerTestCase(BaseTestCase):
class request_handler(NoLogRequestHandler, SimpleHTTPRequestHandler):
diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py
index daa8afeec5..81bfd1fbc2 100644
--- a/Lib/test/test_imaplib.py
+++ b/Lib/test/test_imaplib.py
@@ -125,7 +125,7 @@ class SimpleIMAPHandler(socketserver.StreamRequestHandler):
# Naked sockets return empty strings..
return
line += part
- except IOError:
+ except OSError:
# ..but SSLSockets raise exceptions.
return
if line.endswith(b'\r\n'):
diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py
index b56efe3bea..cf27a712d5 100644
--- a/Lib/test/test_imp.py
+++ b/Lib/test/test_imp.py
@@ -1,14 +1,29 @@
-import imp
+try:
+ import _thread
+except ImportError:
+ _thread = None
import importlib
import os
import os.path
import shutil
import sys
from test import support
-from test.test_importlib import util
import unittest
import warnings
+with warnings.catch_warnings():
+ warnings.simplefilter('ignore', PendingDeprecationWarning)
+ import imp
+
+def requires_load_dynamic(meth):
+ """Decorator to skip a test if not running under CPython or lacking
+ imp.load_dynamic()."""
+ meth = support.cpython_only(meth)
+ return unittest.skipIf(not hasattr(imp, 'load_dynamic'),
+ 'imp.load_dynamic() required')(meth)
+
+
+@unittest.skipIf(_thread is None, '_thread module is required')
class LockTests(unittest.TestCase):
"""Very basic test of import lock functions."""
@@ -208,9 +223,7 @@ class ImportTests(unittest.TestCase):
self.assertIs(orig_path, new_os.path)
self.assertIsNot(orig_getenv, new_os.getenv)
- @support.cpython_only
- @unittest.skipIf(not hasattr(imp, 'load_dynamic'),
- 'imp.load_dynamic() required')
+ @requires_load_dynamic
def test_issue15828_load_extensions(self):
# Issue 15828 picked up that the adapter between the old imp API
# and importlib couldn't handle C extensions
@@ -222,6 +235,22 @@ class ImportTests(unittest.TestCase):
mod = imp.load_module(example, *x)
self.assertEqual(mod.__name__, example)
+ @requires_load_dynamic
+ def test_issue16421_multiple_modules_in_one_dll(self):
+ # Issue 16421: loading several modules from the same compiled file fails
+ m = '_testimportmultiple'
+ fileobj, pathname, description = imp.find_module(m)
+ fileobj.close()
+ mod0 = imp.load_dynamic(m, pathname)
+ mod1 = imp.load_dynamic('_testimportmultiple_foo', pathname)
+ mod2 = imp.load_dynamic('_testimportmultiple_bar', pathname)
+ self.assertEqual(mod0.__name__, m)
+ self.assertEqual(mod1.__name__, '_testimportmultiple_foo')
+ self.assertEqual(mod2.__name__, '_testimportmultiple_bar')
+ with self.assertRaises(ImportError):
+ imp.load_dynamic('nonexistent', pathname)
+
+ @requires_load_dynamic
def test_load_dynamic_ImportError_path(self):
# Issue #1559549 added `name` and `path` attributes to ImportError
# in order to provide better detail. Issue #10854 implemented those
@@ -233,14 +262,12 @@ class ImportTests(unittest.TestCase):
self.assertIn(path, err.exception.path)
self.assertEqual(name, err.exception.name)
- @support.cpython_only
- @unittest.skipIf(not hasattr(imp, 'load_dynamic'),
- 'imp.load_dynamic() required')
+ @requires_load_dynamic
def test_load_module_extension_file_is_None(self):
# When loading an extension module and the file is None, open one
# on the behalf of imp.load_dynamic().
# Issue #15902
- name = '_heapq'
+ name = '_testimportmultiple'
found = imp.find_module(name)
if found[0] is not None:
found[0].close()
@@ -248,6 +275,15 @@ class ImportTests(unittest.TestCase):
return
imp.load_module(name, None, *found[1:])
+ @unittest.skipIf(sys.dont_write_bytecode,
+ "test meaningful only when writing bytecode")
+ def test_bug7732(self):
+ with support.temp_cwd():
+ source = support.TESTFN + '.py'
+ os.mkdir(source)
+ self.assertRaisesRegex(ImportError, '^No module',
+ imp.find_module, support.TESTFN, ["."])
+
def test_multiple_calls_to_get_data(self):
# Issue #18755: make sure multiple calls to get_data() can succeed.
loader = imp._LoadSourceCompatibility('imp', imp.__file__,
@@ -293,22 +329,6 @@ class ReloadTests(unittest.TestCase):
with self.assertRaisesRegex(ImportError, 'html'):
imp.reload(parser)
- def test_module_replaced(self):
- # see #18698
- def code():
- module = type(sys)('top_level')
- module.spam = 3
- sys.modules['top_level'] = module
- mock = util.mock_modules('top_level',
- module_code={'top_level': code})
- with mock:
- with util.import_state(meta_path=[mock]):
- module = importlib.import_module('top_level')
- reloaded = imp.reload(module)
- actual = sys.modules['top_level']
- self.assertEqual(actual.spam, 3)
- self.assertEqual(reloaded.spam, 3)
-
class PEP3147Tests(unittest.TestCase):
"""Tests of PEP 3147."""
@@ -461,20 +481,5 @@ class NullImporterTests(unittest.TestCase):
os.rmdir(name)
-def test_main():
- tests = [
- ImportTests,
- PEP3147Tests,
- ReloadTests,
- NullImporterTests,
- ]
- try:
- import _thread
- except ImportError:
- pass
- else:
- tests.append(LockTests)
- support.run_unittest(*tests)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py
index df08e6a86f..ae8e1601fe 100644
--- a/Lib/test/test_import.py
+++ b/Lib/test/test_import.py
@@ -1,9 +1,8 @@
# We import importlib *ASAP* in order to test #15386
import importlib
+import importlib.util
from importlib._bootstrap import _get_sourcefile
import builtins
-import imp
-from test.test_importlib.import_ import util as importlib_util
import marshal
import os
import platform
@@ -22,7 +21,7 @@ import test.support
from test.support import (
EnvironmentVarGuard, TESTFN, check_warnings, forget, is_jython,
make_legacy_pyc, rmtree, run_unittest, swap_attr, swap_item, temp_umask,
- unlink, unload, create_empty_file, cpython_only)
+ unlink, unload, create_empty_file, cpython_only, TESTFN_UNENCODABLE)
from test import script_helper
@@ -70,8 +69,6 @@ class ImportTests(unittest.TestCase):
def tearDown(self):
unload(TESTFN)
- setUp = tearDown
-
def test_case_sensitivity(self):
# Brief digression to test that import is case-sensitive: if we got
# this far, we know for sure that "random" exists.
@@ -129,16 +126,6 @@ class ImportTests(unittest.TestCase):
finally:
del sys.path[0]
- @skip_if_dont_write_bytecode
- def test_bug7732(self):
- source = TESTFN + '.py'
- os.mkdir(source)
- try:
- self.assertRaisesRegex(ImportError, '^No module',
- imp.find_module, TESTFN, ["."])
- finally:
- os.rmdir(source)
-
def test_module_with_large_stack(self, module='longlist'):
# Regression test for http://bugs.python.org/issue561858.
filename = module + '.py'
@@ -161,16 +148,24 @@ class ImportTests(unittest.TestCase):
sys.path.append('')
importlib.invalidate_caches()
+ namespace = {}
try:
make_legacy_pyc(filename)
# This used to crash.
- exec('import ' + module)
+ exec('import ' + module, None, namespace)
finally:
# Cleanup.
del sys.path[-1]
unlink(filename + 'c')
unlink(filename + 'o')
+ # Remove references to the module (unload the module)
+ namespace.clear()
+ try:
+ del sys.modules[module]
+ except KeyError:
+ pass
+
def test_failing_import_sticks(self):
source = TESTFN + ".py"
with open(source, "w") as f:
@@ -225,7 +220,7 @@ class ImportTests(unittest.TestCase):
with open(source, "w") as f:
f.write("a = 10\nb=20//0\n")
- self.assertRaises(ZeroDivisionError, imp.reload, mod)
+ self.assertRaises(ZeroDivisionError, importlib.reload, mod)
# But we still expect the module to be in sys.modules.
mod = sys.modules.get(TESTFN)
self.assertIsNot(mod, None, "expected module to be in sys.modules")
@@ -280,7 +275,7 @@ class ImportTests(unittest.TestCase):
import sys
class C:
def __del__(self):
- import imp
+ import importlib
sys.argv.insert(0, C())
"""))
script_helper.assert_python_ok(testfn)
@@ -291,7 +286,7 @@ class ImportTests(unittest.TestCase):
sys.path.insert(0, os.curdir)
try:
source = TESTFN + ".py"
- compiled = imp.cache_from_source(source)
+ compiled = importlib.util.cache_from_source(source)
with open(source, 'w') as f:
pass
try:
@@ -322,6 +317,14 @@ class ImportTests(unittest.TestCase):
stdout, stderr = popen.communicate()
self.assertIn(b"ImportError", stdout)
+ def test_from_import_message_for_nonexistent_module(self):
+ with self.assertRaisesRegex(ImportError, "^No module named 'bogus'"):
+ from bogus import foo
+
+ def test_from_import_message_for_existing_module(self):
+ with self.assertRaisesRegex(ImportError, "^cannot import name 'bogus'"):
+ from re import bogus
+
@skip_if_dont_write_bytecode
class FilePermissionTests(unittest.TestCase):
@@ -332,7 +335,7 @@ class FilePermissionTests(unittest.TestCase):
def test_creation_mode(self):
mask = 0o022
with temp_umask(mask), _ready_to_import() as (name, path):
- cached_path = imp.cache_from_source(path)
+ cached_path = importlib.util.cache_from_source(path)
module = __import__(name)
if not os.path.exists(cached_path):
self.fail("__import__ did not result in creation of "
@@ -350,7 +353,7 @@ class FilePermissionTests(unittest.TestCase):
# permissions of .pyc should match those of .py, regardless of mask
mode = 0o600
with temp_umask(0o022), _ready_to_import() as (name, path):
- cached_path = imp.cache_from_source(path)
+ cached_path = importlib.util.cache_from_source(path)
os.chmod(path, mode)
__import__(name)
if not os.path.exists(cached_path):
@@ -365,7 +368,7 @@ class FilePermissionTests(unittest.TestCase):
def test_cached_readonly(self):
mode = 0o400
with temp_umask(0o022), _ready_to_import() as (name, path):
- cached_path = imp.cache_from_source(path)
+ cached_path = importlib.util.cache_from_source(path)
os.chmod(path, mode)
__import__(name)
if not os.path.exists(cached_path):
@@ -405,7 +408,7 @@ class FilePermissionTests(unittest.TestCase):
bytecode_only = path + "c"
else:
bytecode_only = path + "o"
- os.rename(imp.cache_from_source(path), bytecode_only)
+ os.rename(importlib.util.cache_from_source(path), bytecode_only)
m = __import__(name)
self.assertEqual(m.x, 'rewritten')
@@ -427,7 +430,7 @@ func_filename = func.__code__.co_filename
"""
dir_name = os.path.abspath(TESTFN)
file_name = os.path.join(dir_name, module_name) + os.extsep + "py"
- compiled_name = imp.cache_from_source(file_name)
+ compiled_name = importlib.util.cache_from_source(file_name)
def setUp(self):
self.sys_path = sys.path[:]
@@ -488,7 +491,7 @@ func_filename = func.__code__.co_filename
header = f.read(12)
code = marshal.load(f)
constants = list(code.co_consts)
- foreign_code = test_main.__code__
+ foreign_code = importlib.import_module.__code__
pos = constants.index(1)
constants[pos] = foreign_code
code = type(code)(code.co_argcount, code.co_kwonlyargcount,
@@ -630,7 +633,7 @@ class OverridingImportBuiltinTests(unittest.TestCase):
class PycacheTests(unittest.TestCase):
# Test the various PEP 3147 related behaviors.
- tag = imp.get_tag()
+ tag = sys.implementation.cache_tag
def _clean(self):
forget(TESTFN)
@@ -678,10 +681,11 @@ class PycacheTests(unittest.TestCase):
# With PEP 3147 cache layout, removing the source but leaving the pyc
# file does not satisfy the import.
__import__(TESTFN)
- pyc_file = imp.cache_from_source(self.source)
+ pyc_file = importlib.util.cache_from_source(self.source)
self.assertTrue(os.path.exists(pyc_file))
os.remove(self.source)
forget(TESTFN)
+ importlib.invalidate_caches()
self.assertRaises(ImportError, __import__, TESTFN)
@skip_if_dont_write_bytecode
@@ -703,7 +707,7 @@ class PycacheTests(unittest.TestCase):
def test___cached__(self):
# Modules now also have an __cached__ that points to the pyc file.
m = __import__(TESTFN)
- pyc_file = imp.cache_from_source(TESTFN + '.py')
+ pyc_file = importlib.util.cache_from_source(TESTFN + '.py')
self.assertEqual(m.__cached__, os.path.join(os.curdir, pyc_file))
@skip_if_dont_write_bytecode
@@ -738,10 +742,10 @@ class PycacheTests(unittest.TestCase):
pass
importlib.invalidate_caches()
m = __import__('pep3147.foo')
- init_pyc = imp.cache_from_source(
+ init_pyc = importlib.util.cache_from_source(
os.path.join('pep3147', '__init__.py'))
self.assertEqual(m.__cached__, os.path.join(os.curdir, init_pyc))
- foo_pyc = imp.cache_from_source(os.path.join('pep3147', 'foo.py'))
+ foo_pyc = importlib.util.cache_from_source(os.path.join('pep3147', 'foo.py'))
self.assertEqual(sys.modules['pep3147.foo'].__cached__,
os.path.join(os.curdir, foo_pyc))
@@ -765,10 +769,10 @@ class PycacheTests(unittest.TestCase):
unload('pep3147')
importlib.invalidate_caches()
m = __import__('pep3147.foo')
- init_pyc = imp.cache_from_source(
+ init_pyc = importlib.util.cache_from_source(
os.path.join('pep3147', '__init__.py'))
self.assertEqual(m.__cached__, os.path.join(os.curdir, init_pyc))
- foo_pyc = imp.cache_from_source(os.path.join('pep3147', 'foo.py'))
+ foo_pyc = importlib.util.cache_from_source(os.path.join('pep3147', 'foo.py'))
self.assertEqual(sys.modules['pep3147.foo'].__cached__,
os.path.join(os.curdir, foo_pyc))
@@ -852,7 +856,6 @@ class ImportlibBootstrapTests(unittest.TestCase):
from importlib import machinery
mod = sys.modules['_frozen_importlib']
self.assertIs(machinery.FileFinder, mod.FileFinder)
- self.assertIs(imp.new_module, mod.new_module)
@cpython_only
@@ -1048,17 +1051,16 @@ class ImportTracebackTests(unittest.TestCase):
finally:
importlib.SourceLoader.load_module = old_load_module
-
-def test_main(verbose=None):
- run_unittest(ImportTests, PycacheTests, FilePermissionTests,
- PycRewritingTests, PathsTests, RelativeImportTests,
- OverridingImportBuiltinTests,
- ImportlibBootstrapTests, GetSourcefileTests,
- TestSymbolicallyLinkedPackage,
- ImportTracebackTests)
+ @unittest.skipUnless(TESTFN_UNENCODABLE, 'need TESTFN_UNENCODABLE')
+ def test_unencodable_filename(self):
+ # Issue #11619: The Python parser and the import machinery must not
+ # encode filenames, especially on Windows
+ pyname = script_helper.make_script('', TESTFN_UNENCODABLE, 'pass')
+ name = pyname[:-3]
+ script_helper.assert_python_ok("-c", "mod = __import__(%a)" % name,
+ __isolated=False)
if __name__ == '__main__':
# Test needs to be a package, so we can do relative imports.
- from test.test_import import test_main
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importhooks.py b/Lib/test/test_importhooks.py
deleted file mode 100644
index 2a22d1a186..0000000000
--- a/Lib/test/test_importhooks.py
+++ /dev/null
@@ -1,250 +0,0 @@
-import sys
-import imp
-import os
-import unittest
-from test import support
-
-
-test_src = """\
-def get_name():
- return __name__
-def get_file():
- return __file__
-"""
-
-absimp = "import sub\n"
-relimp = "from . import sub\n"
-deeprelimp = "from .... import sub\n"
-futimp = "from __future__ import absolute_import\n"
-
-reload_src = test_src+"""\
-reloaded = True
-"""
-
-test_co = compile(test_src, "<???>", "exec")
-reload_co = compile(reload_src, "<???>", "exec")
-
-test2_oldabs_co = compile(absimp + test_src, "<???>", "exec")
-test2_newabs_co = compile(futimp + absimp + test_src, "<???>", "exec")
-test2_newrel_co = compile(relimp + test_src, "<???>", "exec")
-test2_deeprel_co = compile(deeprelimp + test_src, "<???>", "exec")
-test2_futrel_co = compile(futimp + relimp + test_src, "<???>", "exec")
-
-test_path = "!!!_test_!!!"
-
-
-class TestImporter:
-
- modules = {
- "hooktestmodule": (False, test_co),
- "hooktestpackage": (True, test_co),
- "hooktestpackage.sub": (True, test_co),
- "hooktestpackage.sub.subber": (True, test_co),
- "hooktestpackage.oldabs": (False, test2_oldabs_co),
- "hooktestpackage.newabs": (False, test2_newabs_co),
- "hooktestpackage.newrel": (False, test2_newrel_co),
- "hooktestpackage.sub.subber.subest": (True, test2_deeprel_co),
- "hooktestpackage.futrel": (False, test2_futrel_co),
- "sub": (False, test_co),
- "reloadmodule": (False, test_co),
- }
-
- def __init__(self, path=test_path):
- if path != test_path:
- # if our class is on sys.path_hooks, we must raise
- # ImportError for any path item that we can't handle.
- raise ImportError
- self.path = path
-
- def _get__path__(self):
- raise NotImplementedError
-
- def find_module(self, fullname, path=None):
- if fullname in self.modules:
- return self
- else:
- return None
-
- def load_module(self, fullname):
- ispkg, code = self.modules[fullname]
- mod = sys.modules.setdefault(fullname,imp.new_module(fullname))
- mod.__file__ = "<%s>" % self.__class__.__name__
- mod.__loader__ = self
- if ispkg:
- mod.__path__ = self._get__path__()
- exec(code, mod.__dict__)
- return mod
-
-
-class MetaImporter(TestImporter):
- def _get__path__(self):
- return []
-
-class PathImporter(TestImporter):
- def _get__path__(self):
- return [self.path]
-
-
-class ImportBlocker:
- """Place an ImportBlocker instance on sys.meta_path and you
- can be sure the modules you specified can't be imported, even
- if it's a builtin."""
- def __init__(self, *namestoblock):
- self.namestoblock = dict.fromkeys(namestoblock)
- def find_module(self, fullname, path=None):
- if fullname in self.namestoblock:
- return self
- return None
- def load_module(self, fullname):
- raise ImportError("I dare you")
-
-
-class ImpWrapper:
-
- def __init__(self, path=None):
- if path is not None and not os.path.isdir(path):
- raise ImportError
- self.path = path
-
- def find_module(self, fullname, path=None):
- subname = fullname.split(".")[-1]
- if subname != fullname and self.path is None:
- return None
- if self.path is None:
- path = None
- else:
- path = [self.path]
- try:
- file, filename, stuff = imp.find_module(subname, path)
- except ImportError:
- return None
- return ImpLoader(file, filename, stuff)
-
-
-class ImpLoader:
-
- def __init__(self, file, filename, stuff):
- self.file = file
- self.filename = filename
- self.stuff = stuff
-
- def load_module(self, fullname):
- mod = imp.load_module(fullname, self.file, self.filename, self.stuff)
- if self.file:
- self.file.close()
- mod.__loader__ = self # for introspection
- return mod
-
-
-class ImportHooksBaseTestCase(unittest.TestCase):
-
- def setUp(self):
- self.path = sys.path[:]
- self.meta_path = sys.meta_path[:]
- self.path_hooks = sys.path_hooks[:]
- sys.path_importer_cache.clear()
- self.modules_before = support.modules_setup()
-
- def tearDown(self):
- sys.path[:] = self.path
- sys.meta_path[:] = self.meta_path
- sys.path_hooks[:] = self.path_hooks
- sys.path_importer_cache.clear()
- support.modules_cleanup(*self.modules_before)
-
-
-class ImportHooksTestCase(ImportHooksBaseTestCase):
-
- def doTestImports(self, importer=None):
- import hooktestmodule
- import hooktestpackage
- import hooktestpackage.sub
- import hooktestpackage.sub.subber
- self.assertEqual(hooktestmodule.get_name(),
- "hooktestmodule")
- self.assertEqual(hooktestpackage.get_name(),
- "hooktestpackage")
- self.assertEqual(hooktestpackage.sub.get_name(),
- "hooktestpackage.sub")
- self.assertEqual(hooktestpackage.sub.subber.get_name(),
- "hooktestpackage.sub.subber")
- if importer:
- self.assertEqual(hooktestmodule.__loader__, importer)
- self.assertEqual(hooktestpackage.__loader__, importer)
- self.assertEqual(hooktestpackage.sub.__loader__, importer)
- self.assertEqual(hooktestpackage.sub.subber.__loader__, importer)
-
- TestImporter.modules['reloadmodule'] = (False, test_co)
- import reloadmodule
- self.assertFalse(hasattr(reloadmodule,'reloaded'))
-
- import hooktestpackage.newrel
- self.assertEqual(hooktestpackage.newrel.get_name(),
- "hooktestpackage.newrel")
- self.assertEqual(hooktestpackage.newrel.sub,
- hooktestpackage.sub)
-
- import hooktestpackage.sub.subber.subest as subest
- self.assertEqual(subest.get_name(),
- "hooktestpackage.sub.subber.subest")
- self.assertEqual(subest.sub,
- hooktestpackage.sub)
-
- import hooktestpackage.futrel
- self.assertEqual(hooktestpackage.futrel.get_name(),
- "hooktestpackage.futrel")
- self.assertEqual(hooktestpackage.futrel.sub,
- hooktestpackage.sub)
-
- import sub
- self.assertEqual(sub.get_name(), "sub")
-
- import hooktestpackage.oldabs
- self.assertEqual(hooktestpackage.oldabs.get_name(),
- "hooktestpackage.oldabs")
- self.assertEqual(hooktestpackage.oldabs.sub, sub)
-
- import hooktestpackage.newabs
- self.assertEqual(hooktestpackage.newabs.get_name(),
- "hooktestpackage.newabs")
- self.assertEqual(hooktestpackage.newabs.sub, sub)
-
- def testMetaPath(self):
- i = MetaImporter()
- sys.meta_path.append(i)
- self.doTestImports(i)
-
- def testPathHook(self):
- sys.path_hooks.insert(0, PathImporter)
- sys.path.append(test_path)
- self.doTestImports()
-
- def testBlocker(self):
- mname = "exceptions" # an arbitrary harmless builtin module
- support.unload(mname)
- sys.meta_path.append(ImportBlocker(mname))
- self.assertRaises(ImportError, __import__, mname)
-
- def testImpWrapper(self):
- i = ImpWrapper()
- sys.meta_path.append(i)
- sys.path_hooks.insert(0, ImpWrapper)
- mnames = (
- "colorsys", "urllib.parse", "distutils.core", "sys",
- )
- for mname in mnames:
- parent = mname.split(".")[0]
- for n in list(sys.modules):
- if n.startswith(parent):
- del sys.modules[n]
- for mname in mnames:
- m = __import__(mname, globals(), locals(), ["__dummy__"])
- # to make sure we actually handled the import
- self.assertTrue(hasattr(m, "__loader__"))
-
-
-def test_main():
- support.run_unittest(ImportHooksTestCase)
-
-if __name__ == "__main__":
- test_main()
diff --git a/Lib/test/test_importlib/__main__.py b/Lib/test/test_importlib/__main__.py
index c39712871f..14bd5bc017 100644
--- a/Lib/test/test_importlib/__main__.py
+++ b/Lib/test/test_importlib/__main__.py
@@ -4,17 +4,6 @@ Specifying the ``--builtin`` flag will run tests, where applicable, with
builtins.__import__ instead of importlib.__import__.
"""
-from . import test_main
-
-
if __name__ == '__main__':
- import argparse
-
- parser = argparse.ArgumentParser(description='Execute the importlib test '
- 'suite')
- parser.add_argument('-b', '--builtin', action='store_true', default=False,
- help='use builtins.__import__() instead of importlib')
- args = parser.parse_args()
- if args.builtin:
- util.using___import__ = True
+ from . import test_main
test_main()
diff --git a/Lib/test/test_importlib/abc.py b/Lib/test/test_importlib/abc.py
index 2c17ac329b..b9ff3448f7 100644
--- a/Lib/test/test_importlib/abc.py
+++ b/Lib/test/test_importlib/abc.py
@@ -2,7 +2,7 @@ import abc
import unittest
-class FinderTests(unittest.TestCase, metaclass=abc.ABCMeta):
+class FinderTests(metaclass=abc.ABCMeta):
"""Basic tests for a finder to pass."""
@@ -39,7 +39,7 @@ class FinderTests(unittest.TestCase, metaclass=abc.ABCMeta):
pass
-class LoaderTests(unittest.TestCase, metaclass=abc.ABCMeta):
+class LoaderTests(metaclass=abc.ABCMeta):
@abc.abstractmethod
def test_module(self):
diff --git a/Lib/test/test_importlib/builtin/test_finder.py b/Lib/test/test_importlib/builtin/test_finder.py
index 146538d891..aa1ea423d0 100644
--- a/Lib/test/test_importlib/builtin/test_finder.py
+++ b/Lib/test/test_importlib/builtin/test_finder.py
@@ -1,11 +1,13 @@
-from importlib import machinery
from .. import abc
from .. import util
from . import util as builtin_util
+frozen_machinery, source_machinery = util.import_importlib('importlib.machinery')
+
import sys
import unittest
+
class FinderTests(abc.FinderTests):
"""Test find_module() for built-in modules."""
@@ -13,7 +15,7 @@ class FinderTests(abc.FinderTests):
def test_module(self):
# Common case.
with util.uncache(builtin_util.NAME):
- found = machinery.BuiltinImporter.find_module(builtin_util.NAME)
+ found = self.machinery.BuiltinImporter.find_module(builtin_util.NAME)
self.assertTrue(found)
def test_package(self):
@@ -34,22 +36,19 @@ class FinderTests(abc.FinderTests):
def test_failure(self):
assert 'importlib' not in sys.builtin_module_names
- loader = machinery.BuiltinImporter.find_module('importlib')
+ loader = self.machinery.BuiltinImporter.find_module('importlib')
self.assertIsNone(loader)
def test_ignore_path(self):
# The value for 'path' should always trigger a failed import.
with util.uncache(builtin_util.NAME):
- loader = machinery.BuiltinImporter.find_module(builtin_util.NAME,
+ loader = self.machinery.BuiltinImporter.find_module(builtin_util.NAME,
['pkg'])
self.assertIsNone(loader)
-
-
-def test_main():
- from test.support import run_unittest
- run_unittest(FinderTests)
+Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests,
+ machinery=[frozen_machinery, source_machinery])
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/builtin/test_loader.py b/Lib/test/test_importlib/builtin/test_loader.py
index 8e186e7156..4ee8dc488a 100644
--- a/Lib/test/test_importlib/builtin/test_loader.py
+++ b/Lib/test/test_importlib/builtin/test_loader.py
@@ -1,9 +1,9 @@
-import importlib
-from importlib import machinery
from .. import abc
from .. import util
from . import util as builtin_util
+frozen_machinery, source_machinery = util.import_importlib('importlib.machinery')
+
import sys
import types
import unittest
@@ -13,8 +13,9 @@ class LoaderTests(abc.LoaderTests):
"""Test load_module() for built-in modules."""
- verification = {'__name__': 'errno', '__package__': '',
- '__loader__': machinery.BuiltinImporter}
+ def setUp(self):
+ self.verification = {'__name__': 'errno', '__package__': '',
+ '__loader__': self.machinery.BuiltinImporter}
def verify(self, module):
"""Verify that the module matches against what it should have."""
@@ -23,8 +24,8 @@ class LoaderTests(abc.LoaderTests):
self.assertEqual(getattr(module, attr), value)
self.assertIn(module.__name__, sys.modules)
- load_module = staticmethod(lambda name:
- machinery.BuiltinImporter.load_module(name))
+ def load_module(self, name):
+ return self.machinery.BuiltinImporter.load_module(name)
def test_module(self):
# Common case.
@@ -61,45 +62,47 @@ class LoaderTests(abc.LoaderTests):
def test_already_imported(self):
# Using the name of a module already imported but not a built-in should
# still fail.
- assert hasattr(importlib, '__file__') # Not a built-in.
+ assert hasattr(unittest, '__file__') # Not a built-in.
with self.assertRaises(ImportError) as cm:
- self.load_module('importlib')
- self.assertEqual(cm.exception.name, 'importlib')
+ self.load_module('unittest')
+ self.assertEqual(cm.exception.name, 'unittest')
+
+
+Frozen_LoaderTests, Source_LoaderTests = util.test_both(LoaderTests,
+ machinery=[frozen_machinery, source_machinery])
-class InspectLoaderTests(unittest.TestCase):
+class InspectLoaderTests:
"""Tests for InspectLoader methods for BuiltinImporter."""
def test_get_code(self):
# There is no code object.
- result = machinery.BuiltinImporter.get_code(builtin_util.NAME)
+ result = self.machinery.BuiltinImporter.get_code(builtin_util.NAME)
self.assertIsNone(result)
def test_get_source(self):
# There is no source.
- result = machinery.BuiltinImporter.get_source(builtin_util.NAME)
+ result = self.machinery.BuiltinImporter.get_source(builtin_util.NAME)
self.assertIsNone(result)
def test_is_package(self):
# Cannot be a package.
- result = machinery.BuiltinImporter.is_package(builtin_util.NAME)
+ result = self.machinery.BuiltinImporter.is_package(builtin_util.NAME)
self.assertTrue(not result)
def test_not_builtin(self):
# Modules not built-in should raise ImportError.
for meth_name in ('get_code', 'get_source', 'is_package'):
- method = getattr(machinery.BuiltinImporter, meth_name)
+ method = getattr(self.machinery.BuiltinImporter, meth_name)
with self.assertRaises(ImportError) as cm:
method(builtin_util.BAD_NAME)
self.assertRaises(builtin_util.BAD_NAME)
-
-
-def test_main():
- from test.support import run_unittest
- run_unittest(LoaderTests, InspectLoaderTests)
+Frozen_InspectLoaderTests, Source_InspectLoaderTests = util.test_both(
+ InspectLoaderTests,
+ machinery=[frozen_machinery, source_machinery])
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/extension/test_case_sensitivity.py b/Lib/test/test_importlib/extension/test_case_sensitivity.py
index 76c53e4fd6..bb743212e9 100644
--- a/Lib/test/test_importlib/extension/test_case_sensitivity.py
+++ b/Lib/test/test_importlib/extension/test_case_sensitivity.py
@@ -1,22 +1,25 @@
-import imp
+from importlib import _bootstrap
import sys
from test import support
import unittest
-from importlib import _bootstrap
+
from .. import util
from . import util as ext_util
+frozen_machinery, source_machinery = util.import_importlib('importlib.machinery')
+
+@unittest.skipIf(ext_util.FILENAME is None, '_testcapi not available')
@util.case_insensitive_tests
-class ExtensionModuleCaseSensitivityTest(unittest.TestCase):
+class ExtensionModuleCaseSensitivityTest:
def find_module(self):
good_name = ext_util.NAME
bad_name = good_name.upper()
assert good_name != bad_name
- finder = _bootstrap.FileFinder(ext_util.PATH,
- (_bootstrap.ExtensionFileLoader,
- _bootstrap.EXTENSION_SUFFIXES))
+ finder = self.machinery.FileFinder(ext_util.PATH,
+ (self.machinery.ExtensionFileLoader,
+ self.machinery.EXTENSION_SUFFIXES))
return finder.find_module(bad_name)
def test_case_sensitive(self):
@@ -37,14 +40,10 @@ class ExtensionModuleCaseSensitivityTest(unittest.TestCase):
loader = self.find_module()
self.assertTrue(hasattr(loader, 'load_module'))
-
-
-
-def test_main():
- if ext_util.FILENAME is None:
- return
- support.run_unittest(ExtensionModuleCaseSensitivityTest)
+Frozen_ExtensionCaseSensitivity, Source_ExtensionCaseSensitivity = util.test_both(
+ ExtensionModuleCaseSensitivityTest,
+ machinery=[frozen_machinery, source_machinery])
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/extension/test_finder.py b/Lib/test/test_importlib/extension/test_finder.py
index a63cfdb9b4..10e78cc522 100644
--- a/Lib/test/test_importlib/extension/test_finder.py
+++ b/Lib/test/test_importlib/extension/test_finder.py
@@ -1,17 +1,20 @@
-from importlib import machinery
from .. import abc
+from .. import util as test_util
from . import util
+machinery = test_util.import_importlib('importlib.machinery')
+
import unittest
+
class FinderTests(abc.FinderTests):
"""Test the finder for extension modules."""
def find_module(self, fullname):
- importer = machinery.FileFinder(util.PATH,
- (machinery.ExtensionFileLoader,
- machinery.EXTENSION_SUFFIXES))
+ importer = self.machinery.FileFinder(util.PATH,
+ (self.machinery.ExtensionFileLoader,
+ self.machinery.EXTENSION_SUFFIXES))
return importer.find_module(fullname)
def test_module(self):
@@ -36,11 +39,9 @@ class FinderTests(abc.FinderTests):
def test_failure(self):
self.assertIsNone(self.find_module('asdfjkl;'))
-
-def test_main():
- from test.support import run_unittest
- run_unittest(FinderTests)
+Frozen_FinderTests, Source_FinderTests = test_util.test_both(
+ FinderTests, machinery=machinery)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/extension/test_loader.py b/Lib/test/test_importlib/extension/test_loader.py
index ca5af201c4..1e8afba702 100644
--- a/Lib/test/test_importlib/extension/test_loader.py
+++ b/Lib/test/test_importlib/extension/test_loader.py
@@ -1,8 +1,9 @@
-from importlib import machinery
from . import util as ext_util
from .. import abc
from .. import util
+machinery = util.import_importlib('importlib.machinery')
+
import os.path
import sys
import unittest
@@ -13,8 +14,8 @@ class LoaderTests(abc.LoaderTests):
"""Test load_module() for extension modules."""
def setUp(self):
- self.loader = machinery.ExtensionFileLoader(ext_util.NAME,
- ext_util.FILEPATH)
+ self.loader = self.machinery.ExtensionFileLoader(ext_util.NAME,
+ ext_util.FILEPATH)
def load_module(self, fullname):
return self.loader.load_module(fullname)
@@ -36,7 +37,7 @@ class LoaderTests(abc.LoaderTests):
self.assertEqual(getattr(module, attr), value)
self.assertIn(ext_util.NAME, sys.modules)
self.assertIsInstance(module.__loader__,
- machinery.ExtensionFileLoader)
+ self.machinery.ExtensionFileLoader)
def test_package(self):
# No extension module as __init__ available for testing.
@@ -64,16 +65,15 @@ class LoaderTests(abc.LoaderTests):
def test_is_package(self):
self.assertFalse(self.loader.is_package(ext_util.NAME))
- for suffix in machinery.EXTENSION_SUFFIXES:
+ for suffix in self.machinery.EXTENSION_SUFFIXES:
path = os.path.join('some', 'path', 'pkg', '__init__' + suffix)
- loader = machinery.ExtensionFileLoader('pkg', path)
+ loader = self.machinery.ExtensionFileLoader('pkg', path)
self.assertTrue(loader.is_package('pkg'))
+Frozen_LoaderTests, Source_LoaderTests = util.test_both(
+ LoaderTests, machinery=machinery)
-def test_main():
- from test.support import run_unittest
- run_unittest(LoaderTests)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/extension/test_path_hook.py b/Lib/test/test_importlib/extension/test_path_hook.py
index 1d969a1b28..49d6734711 100644
--- a/Lib/test/test_importlib/extension/test_path_hook.py
+++ b/Lib/test/test_importlib/extension/test_path_hook.py
@@ -1,32 +1,32 @@
-from importlib import machinery
+from .. import util as test_util
from . import util
+machinery = test_util.import_importlib('importlib.machinery')
+
import collections
-import imp
import sys
import unittest
-class PathHookTests(unittest.TestCase):
+class PathHookTests:
"""Test the path hook for extension modules."""
# XXX Should it only succeed for pre-existing directories?
# XXX Should it only work for directories containing an extension module?
def hook(self, entry):
- return machinery.FileFinder.path_hook((machinery.ExtensionFileLoader,
- machinery.EXTENSION_SUFFIXES))(entry)
+ return self.machinery.FileFinder.path_hook(
+ (self.machinery.ExtensionFileLoader,
+ self.machinery.EXTENSION_SUFFIXES))(entry)
def test_success(self):
# Path hook should handle a directory where a known extension module
# exists.
self.assertTrue(hasattr(self.hook(util.PATH), 'find_module'))
-
-def test_main():
- from test.support import run_unittest
- run_unittest(PathHookTests)
+Frozen_PathHooksTests, Source_PathHooksTests = test_util.test_both(
+ PathHookTests, machinery=machinery)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/extension/util.py b/Lib/test/test_importlib/extension/util.py
index a266dd98c8..8d089f0d7c 100644
--- a/Lib/test/test_importlib/extension/util.py
+++ b/Lib/test/test_importlib/extension/util.py
@@ -1,4 +1,3 @@
-import imp
from importlib import machinery
import os
import sys
diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py
index fa0c2a037e..9e629bf9d0 100644
--- a/Lib/test/test_importlib/frozen/test_finder.py
+++ b/Lib/test/test_importlib/frozen/test_finder.py
@@ -1,5 +1,7 @@
-from importlib import machinery
from .. import abc
+from .. import util
+
+machinery = util.import_importlib('importlib.machinery')
import unittest
@@ -9,7 +11,7 @@ class FinderTests(abc.FinderTests):
"""Test finding frozen modules."""
def find(self, name, path=None):
- finder = machinery.FrozenImporter
+ finder = self.machinery.FrozenImporter
return finder.find_module(name, path)
def test_module(self):
@@ -37,11 +39,9 @@ class FinderTests(abc.FinderTests):
loader = self.find('<not real>')
self.assertIsNone(loader)
-
-def test_main():
- from test.support import run_unittest
- run_unittest(FinderTests)
+Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests,
+ machinery=machinery)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py
index 4b8ec15554..be8dc3c07a 100644
--- a/Lib/test/test_importlib/frozen/test_loader.py
+++ b/Lib/test/test_importlib/frozen/test_loader.py
@@ -1,18 +1,21 @@
-from importlib import machinery
-import imp
-import unittest
from .. import abc
from .. import util
+
+machinery = util.import_importlib('importlib.machinery')
+
+import unittest
from test.support import captured_stdout
+import types
+
class LoaderTests(abc.LoaderTests):
def test_module(self):
with util.uncache('__hello__'), captured_stdout() as stdout:
- module = machinery.FrozenImporter.load_module('__hello__')
+ module = self.machinery.FrozenImporter.load_module('__hello__')
check = {'__name__': '__hello__',
'__package__': '',
- '__loader__': machinery.FrozenImporter,
+ '__loader__': self.machinery.FrozenImporter,
}
for attr, value in check.items():
self.assertEqual(getattr(module, attr), value)
@@ -21,11 +24,11 @@ class LoaderTests(abc.LoaderTests):
def test_package(self):
with util.uncache('__phello__'), captured_stdout() as stdout:
- module = machinery.FrozenImporter.load_module('__phello__')
+ module = self.machinery.FrozenImporter.load_module('__phello__')
check = {'__name__': '__phello__',
'__package__': '__phello__',
- '__path__': ['__phello__'],
- '__loader__': machinery.FrozenImporter,
+ '__path__': [],
+ '__loader__': self.machinery.FrozenImporter,
}
for attr, value in check.items():
attr_value = getattr(module, attr)
@@ -38,10 +41,10 @@ class LoaderTests(abc.LoaderTests):
def test_lacking_parent(self):
with util.uncache('__phello__', '__phello__.spam'), \
captured_stdout() as stdout:
- module = machinery.FrozenImporter.load_module('__phello__.spam')
+ module = self.machinery.FrozenImporter.load_module('__phello__.spam')
check = {'__name__': '__phello__.spam',
'__package__': '__phello__',
- '__loader__': machinery.FrozenImporter,
+ '__loader__': self.machinery.FrozenImporter,
}
for attr, value in check.items():
attr_value = getattr(module, attr)
@@ -53,15 +56,15 @@ class LoaderTests(abc.LoaderTests):
def test_module_reuse(self):
with util.uncache('__hello__'), captured_stdout() as stdout:
- module1 = machinery.FrozenImporter.load_module('__hello__')
- module2 = machinery.FrozenImporter.load_module('__hello__')
+ module1 = self.machinery.FrozenImporter.load_module('__hello__')
+ module2 = self.machinery.FrozenImporter.load_module('__hello__')
self.assertIs(module1, module2)
self.assertEqual(stdout.getvalue(),
'Hello world!\nHello world!\n')
def test_module_repr(self):
with util.uncache('__hello__'), captured_stdout():
- module = machinery.FrozenImporter.load_module('__hello__')
+ module = self.machinery.FrozenImporter.load_module('__hello__')
self.assertEqual(repr(module),
"<module '__hello__' (frozen)>")
@@ -70,13 +73,16 @@ class LoaderTests(abc.LoaderTests):
pass
def test_unloadable(self):
- assert machinery.FrozenImporter.find_module('_not_real') is None
+ assert self.machinery.FrozenImporter.find_module('_not_real') is None
with self.assertRaises(ImportError) as cm:
- machinery.FrozenImporter.load_module('_not_real')
+ self.machinery.FrozenImporter.load_module('_not_real')
self.assertEqual(cm.exception.name, '_not_real')
+Frozen_LoaderTests, Source_LoaderTests = util.test_both(LoaderTests,
+ machinery=machinery)
+
-class InspectLoaderTests(unittest.TestCase):
+class InspectLoaderTests:
"""Tests for the InspectLoader methods for FrozenImporter."""
@@ -84,15 +90,15 @@ class InspectLoaderTests(unittest.TestCase):
# Make sure that the code object is good.
name = '__hello__'
with captured_stdout() as stdout:
- code = machinery.FrozenImporter.get_code(name)
- mod = imp.new_module(name)
+ code = self.machinery.FrozenImporter.get_code(name)
+ mod = types.ModuleType(name)
exec(code, mod.__dict__)
self.assertTrue(hasattr(mod, 'initialized'))
self.assertEqual(stdout.getvalue(), 'Hello world!\n')
def test_get_source(self):
# Should always return None.
- result = machinery.FrozenImporter.get_source('__hello__')
+ result = self.machinery.FrozenImporter.get_source('__hello__')
self.assertIsNone(result)
def test_is_package(self):
@@ -100,22 +106,20 @@ class InspectLoaderTests(unittest.TestCase):
test_for = (('__hello__', False), ('__phello__', True),
('__phello__.spam', False))
for name, is_package in test_for:
- result = machinery.FrozenImporter.is_package(name)
+ result = self.machinery.FrozenImporter.is_package(name)
self.assertEqual(bool(result), is_package)
def test_failure(self):
# Raise ImportError for modules that are not frozen.
for meth_name in ('get_code', 'get_source', 'is_package'):
- method = getattr(machinery.FrozenImporter, meth_name)
+ method = getattr(self.machinery.FrozenImporter, meth_name)
with self.assertRaises(ImportError) as cm:
method('importlib')
self.assertEqual(cm.exception.name, 'importlib')
-
-def test_main():
- from test.support import run_unittest
- run_unittest(LoaderTests, InspectLoaderTests)
+Frozen_ILTests, Source_ILTests = util.test_both(InspectLoaderTests,
+ machinery=machinery)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test___loader__.py b/Lib/test/test_importlib/import_/test___loader__.py
new file mode 100644
index 0000000000..9c18d19709
--- /dev/null
+++ b/Lib/test/test_importlib/import_/test___loader__.py
@@ -0,0 +1,48 @@
+import sys
+import types
+import unittest
+
+from .. import util
+from . import util as import_util
+
+
+class LoaderMock:
+
+ def find_module(self, fullname, path=None):
+ return self
+
+ def load_module(self, fullname):
+ sys.modules[fullname] = self.module
+ return self.module
+
+
+class LoaderAttributeTests:
+
+ def test___loader___missing(self):
+ module = types.ModuleType('blah')
+ try:
+ del module.__loader__
+ except AttributeError:
+ pass
+ loader = LoaderMock()
+ loader.module = module
+ with util.uncache('blah'), util.import_state(meta_path=[loader]):
+ module = self.__import__('blah')
+ self.assertEqual(loader, module.__loader__)
+
+ def test___loader___is_None(self):
+ module = types.ModuleType('blah')
+ module.__loader__ = None
+ loader = LoaderMock()
+ loader.module = module
+ with util.uncache('blah'), util.import_state(meta_path=[loader]):
+ returned_module = self.__import__('blah')
+ self.assertEqual(loader, module.__loader__)
+
+
+Frozen_Tests, Source_Tests = util.test_both(LoaderAttributeTests,
+ __import__=import_util.__import__)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test___package__.py b/Lib/test/test_importlib/import_/test___package__.py
index 783cde1729..5fa432f2e3 100644
--- a/Lib/test/test_importlib/import_/test___package__.py
+++ b/Lib/test/test_importlib/import_/test___package__.py
@@ -9,7 +9,7 @@ from .. import util
from . import util as import_util
-class Using__package__(unittest.TestCase):
+class Using__package__:
"""Use of __package__ supercedes the use of __name__/__path__ to calculate
what package a module belongs to. The basic algorithm is [__package__]::
@@ -38,8 +38,8 @@ class Using__package__(unittest.TestCase):
# [__package__]
with util.mock_modules('pkg.__init__', 'pkg.fake') as importer:
with util.import_state(meta_path=[importer]):
- import_util.import_('pkg.fake')
- module = import_util.import_('',
+ self.__import__('pkg.fake')
+ module = self.__import__('',
globals={'__package__': 'pkg.fake'},
fromlist=['attr'], level=2)
self.assertEqual(module.__name__, 'pkg')
@@ -51,8 +51,8 @@ class Using__package__(unittest.TestCase):
globals_['__package__'] = None
with util.mock_modules('pkg.__init__', 'pkg.fake') as importer:
with util.import_state(meta_path=[importer]):
- import_util.import_('pkg.fake')
- module = import_util.import_('', globals= globals_,
+ self.__import__('pkg.fake')
+ module = self.__import__('', globals= globals_,
fromlist=['attr'], level=2)
self.assertEqual(module.__name__, 'pkg')
@@ -63,15 +63,17 @@ class Using__package__(unittest.TestCase):
def test_bad__package__(self):
globals = {'__package__': '<not real>'}
with self.assertRaises(SystemError):
- import_util.import_('', globals, {}, ['relimport'], 1)
+ self.__import__('', globals, {}, ['relimport'], 1)
def test_bunk__package__(self):
globals = {'__package__': 42}
with self.assertRaises(TypeError):
- import_util.import_('', globals, {}, ['relimport'], 1)
+ self.__import__('', globals, {}, ['relimport'], 1)
+
+Frozen_UsingPackage, Source_UsingPackage = util.test_both(
+ Using__package__, __import__=import_util.__import__)
-@import_util.importlib_only
class Setting__package__(unittest.TestCase):
"""Because __package__ is a new feature, it is not always set by a loader.
@@ -84,12 +86,14 @@ class Setting__package__(unittest.TestCase):
"""
+ __import__ = import_util.__import__[1]
+
# [top-level]
def test_top_level(self):
with util.mock_modules('top_level') as mock:
with util.import_state(meta_path=[mock]):
del mock['top_level'].__package__
- module = import_util.import_('top_level')
+ module = self.__import__('top_level')
self.assertEqual(module.__package__, '')
# [package]
@@ -97,7 +101,7 @@ class Setting__package__(unittest.TestCase):
with util.mock_modules('pkg.__init__') as mock:
with util.import_state(meta_path=[mock]):
del mock['pkg'].__package__
- module = import_util.import_('pkg')
+ module = self.__import__('pkg')
self.assertEqual(module.__package__, 'pkg')
# [submodule]
@@ -105,15 +109,10 @@ class Setting__package__(unittest.TestCase):
with util.mock_modules('pkg.__init__', 'pkg.mod') as mock:
with util.import_state(meta_path=[mock]):
del mock['pkg.mod'].__package__
- pkg = import_util.import_('pkg.mod')
+ pkg = self.__import__('pkg.mod')
module = getattr(pkg, 'mod')
self.assertEqual(module.__package__, 'pkg')
-def test_main():
- from test.support import run_unittest
- run_unittest(Using__package__, Setting__package__)
-
-
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_api.py b/Lib/test/test_importlib/import_/test_api.py
index 3d4cd94889..dc8b8a8f38 100644
--- a/Lib/test/test_importlib/import_/test_api.py
+++ b/Lib/test/test_importlib/import_/test_api.py
@@ -1,7 +1,7 @@
-from .. import util as importlib_test_util
-from . import util
-import imp
+from .. import util
+from . import util as import_util
import sys
+import types
import unittest
@@ -17,7 +17,7 @@ class BadLoaderFinder:
raise ImportError('I cannot be loaded!')
-class APITest(unittest.TestCase):
+class APITest:
"""Test API-specific details for __import__ (e.g. raising the right
exception when passing in an int for the module name)."""
@@ -25,43 +25,40 @@ class APITest(unittest.TestCase):
def test_name_requires_rparition(self):
# Raise TypeError if a non-string is passed in for the module name.
with self.assertRaises(TypeError):
- util.import_(42)
+ self.__import__(42)
def test_negative_level(self):
# Raise ValueError when a negative level is specified.
# PEP 328 did away with sys.module None entries and the ambiguity of
# absolute/relative imports.
with self.assertRaises(ValueError):
- util.import_('os', globals(), level=-1)
+ self.__import__('os', globals(), level=-1)
def test_nonexistent_fromlist_entry(self):
# If something in fromlist doesn't exist, that's okay.
# issue15715
- mod = imp.new_module('fine')
+ mod = types.ModuleType('fine')
mod.__path__ = ['XXX']
- with importlib_test_util.import_state(meta_path=[BadLoaderFinder]):
- with importlib_test_util.uncache('fine'):
+ with util.import_state(meta_path=[BadLoaderFinder]):
+ with util.uncache('fine'):
sys.modules['fine'] = mod
- util.import_('fine', fromlist=['not here'])
+ self.__import__('fine', fromlist=['not here'])
def test_fromlist_load_error_propagates(self):
# If something in fromlist triggers an exception not related to not
# existing, let that exception propagate.
# issue15316
- mod = imp.new_module('fine')
+ mod = types.ModuleType('fine')
mod.__path__ = ['XXX']
- with importlib_test_util.import_state(meta_path=[BadLoaderFinder]):
- with importlib_test_util.uncache('fine'):
+ with util.import_state(meta_path=[BadLoaderFinder]):
+ with util.uncache('fine'):
sys.modules['fine'] = mod
with self.assertRaises(ImportError):
- util.import_('fine', fromlist=['bogus'])
+ self.__import__('fine', fromlist=['bogus'])
-
-
-def test_main():
- from test.support import run_unittest
- run_unittest(APITest)
+Frozen_APITests, Source_APITests = util.test_both(
+ APITest, __import__=import_util.__import__)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_caching.py b/Lib/test/test_importlib/import_/test_caching.py
index 207378a6fb..4d082b8eeb 100644
--- a/Lib/test/test_importlib/import_/test_caching.py
+++ b/Lib/test/test_importlib/import_/test_caching.py
@@ -6,7 +6,7 @@ from types import MethodType
import unittest
-class UseCache(unittest.TestCase):
+class UseCache:
"""When it comes to sys.modules, import prefers it over anything else.
@@ -21,12 +21,13 @@ class UseCache(unittest.TestCase):
ImportError is raised [None in cache].
"""
+
def test_using_cache(self):
# [use cache]
module_to_use = "some module found!"
with util.uncache('some_module'):
sys.modules['some_module'] = module_to_use
- module = import_util.import_('some_module')
+ module = self.__import__('some_module')
self.assertEqual(id(module_to_use), id(module))
def test_None_in_cache(self):
@@ -35,7 +36,7 @@ class UseCache(unittest.TestCase):
with util.uncache(name):
sys.modules[name] = None
with self.assertRaises(ImportError) as cm:
- import_util.import_(name)
+ self.__import__(name)
self.assertEqual(cm.exception.name, name)
def create_mock(self, *names, return_=None):
@@ -47,42 +48,43 @@ class UseCache(unittest.TestCase):
mock.load_module = MethodType(load_module, mock)
return mock
+Frozen_UseCache, Source_UseCache = util.test_both(
+ UseCache, __import__=import_util.__import__)
+
+
+class ImportlibUseCache(UseCache, unittest.TestCase):
+
+ __import__ = import_util.__import__[1]
+
# __import__ inconsistent between loaders and built-in import when it comes
# to when to use the module in sys.modules and when not to.
- @import_util.importlib_only
def test_using_cache_after_loader(self):
# [from cache on return]
with self.create_mock('module') as mock:
with util.import_state(meta_path=[mock]):
- module = import_util.import_('module')
+ module = self.__import__('module')
self.assertEqual(id(module), id(sys.modules['module']))
# See test_using_cache_after_loader() for reasoning.
- @import_util.importlib_only
def test_using_cache_for_assigning_to_attribute(self):
# [from cache to attribute]
with self.create_mock('pkg.__init__', 'pkg.module') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('pkg.module')
+ module = self.__import__('pkg.module')
self.assertTrue(hasattr(module, 'module'))
self.assertEqual(id(module.module),
id(sys.modules['pkg.module']))
# See test_using_cache_after_loader() for reasoning.
- @import_util.importlib_only
def test_using_cache_for_fromlist(self):
# [from cache for fromlist]
with self.create_mock('pkg.__init__', 'pkg.module') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('pkg', fromlist=['module'])
+ module = self.__import__('pkg', fromlist=['module'])
self.assertTrue(hasattr(module, 'module'))
self.assertEqual(id(module.module),
id(sys.modules['pkg.module']))
-def test_main():
- from test.support import run_unittest
- run_unittest(UseCache)
-
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_fromlist.py b/Lib/test/test_importlib/import_/test_fromlist.py
index c16c33710f..fc29ce667d 100644
--- a/Lib/test/test_importlib/import_/test_fromlist.py
+++ b/Lib/test/test_importlib/import_/test_fromlist.py
@@ -1,10 +1,10 @@
"""Test that the semantics relating to the 'fromlist' argument are correct."""
from .. import util
from . import util as import_util
-import imp
import unittest
-class ReturnValue(unittest.TestCase):
+
+class ReturnValue:
"""The use of fromlist influences what import returns.
@@ -19,18 +19,21 @@ class ReturnValue(unittest.TestCase):
# [import return]
with util.mock_modules('pkg.__init__', 'pkg.module') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('pkg.module')
+ module = self.__import__('pkg.module')
self.assertEqual(module.__name__, 'pkg')
def test_return_from_from_import(self):
# [from return]
with util.mock_modules('pkg.__init__', 'pkg.module')as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('pkg.module', fromlist=['attr'])
+ module = self.__import__('pkg.module', fromlist=['attr'])
self.assertEqual(module.__name__, 'pkg.module')
+Frozen_ReturnValue, Source_ReturnValue = util.test_both(
+ ReturnValue, __import__=import_util.__import__)
+
-class HandlingFromlist(unittest.TestCase):
+class HandlingFromlist:
"""Using fromlist triggers different actions based on what is being asked
of it.
@@ -49,14 +52,14 @@ class HandlingFromlist(unittest.TestCase):
# [object case]
with util.mock_modules('module') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('module', fromlist=['attr'])
+ module = self.__import__('module', fromlist=['attr'])
self.assertEqual(module.__name__, 'module')
def test_nonexistent_object(self):
# [bad object]
with util.mock_modules('module') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('module', fromlist=['non_existent'])
+ module = self.__import__('module', fromlist=['non_existent'])
self.assertEqual(module.__name__, 'module')
self.assertTrue(not hasattr(module, 'non_existent'))
@@ -64,7 +67,7 @@ class HandlingFromlist(unittest.TestCase):
# [module]
with util.mock_modules('pkg.__init__', 'pkg.module') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('pkg', fromlist=['module'])
+ module = self.__import__('pkg', fromlist=['module'])
self.assertEqual(module.__name__, 'pkg')
self.assertTrue(hasattr(module, 'module'))
self.assertEqual(module.module.__name__, 'pkg.module')
@@ -79,13 +82,13 @@ class HandlingFromlist(unittest.TestCase):
module_code={'pkg.mod': module_code}) as importer:
with util.import_state(meta_path=[importer]):
with self.assertRaises(ImportError) as exc:
- import_util.import_('pkg', fromlist=['mod'])
+ self.__import__('pkg', fromlist=['mod'])
self.assertEqual('i_do_not_exist', exc.exception.name)
def test_empty_string(self):
with util.mock_modules('pkg.__init__', 'pkg.mod') as importer:
with util.import_state(meta_path=[importer]):
- module = import_util.import_('pkg.mod', fromlist=[''])
+ module = self.__import__('pkg.mod', fromlist=[''])
self.assertEqual(module.__name__, 'pkg.mod')
def basic_star_test(self, fromlist=['*']):
@@ -93,7 +96,7 @@ class HandlingFromlist(unittest.TestCase):
with util.mock_modules('pkg.__init__', 'pkg.module') as mock:
with util.import_state(meta_path=[mock]):
mock['pkg'].__all__ = ['module']
- module = import_util.import_('pkg', fromlist=fromlist)
+ module = self.__import__('pkg', fromlist=fromlist)
self.assertEqual(module.__name__, 'pkg')
self.assertTrue(hasattr(module, 'module'))
self.assertEqual(module.module.__name__, 'pkg.module')
@@ -111,17 +114,16 @@ class HandlingFromlist(unittest.TestCase):
with context as mock:
with util.import_state(meta_path=[mock]):
mock['pkg'].__all__ = ['module1']
- module = import_util.import_('pkg', fromlist=['module2', '*'])
+ module = self.__import__('pkg', fromlist=['module2', '*'])
self.assertEqual(module.__name__, 'pkg')
self.assertTrue(hasattr(module, 'module1'))
self.assertTrue(hasattr(module, 'module2'))
self.assertEqual(module.module1.__name__, 'pkg.module1')
self.assertEqual(module.module2.__name__, 'pkg.module2')
+Frozen_FromList, Source_FromList = util.test_both(
+ HandlingFromlist, __import__=import_util.__import__)
-def test_main():
- from test.support import run_unittest
- run_unittest(ReturnValue, HandlingFromlist)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_meta_path.py b/Lib/test/test_importlib/import_/test_meta_path.py
index 4d85f80e67..74afd1b704 100644
--- a/Lib/test/test_importlib/import_/test_meta_path.py
+++ b/Lib/test/test_importlib/import_/test_meta_path.py
@@ -7,7 +7,7 @@ import unittest
import warnings
-class CallingOrder(unittest.TestCase):
+class CallingOrder:
"""Calls to the importers on sys.meta_path happen in order that they are
specified in the sequence, starting with the first importer
@@ -24,7 +24,7 @@ class CallingOrder(unittest.TestCase):
first.modules[mod] = 42
second.modules[mod] = -13
with util.import_state(meta_path=[first, second]):
- self.assertEqual(import_util.import_(mod), 42)
+ self.assertEqual(self.__import__(mod), 42)
def test_continuing(self):
# [continuing]
@@ -34,7 +34,7 @@ class CallingOrder(unittest.TestCase):
first.find_module = lambda self, fullname, path=None: None
second.modules[mod_name] = 42
with util.import_state(meta_path=[first, second]):
- self.assertEqual(import_util.import_(mod_name), 42)
+ self.assertEqual(self.__import__(mod_name), 42)
def test_empty(self):
# Raise an ImportWarning if sys.meta_path is empty.
@@ -51,8 +51,11 @@ class CallingOrder(unittest.TestCase):
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[-1].category, ImportWarning))
+Frozen_CallingOrder, Source_CallingOrder = util.test_both(
+ CallingOrder, __import__=import_util.__import__)
-class CallSignature(unittest.TestCase):
+
+class CallSignature:
"""If there is no __path__ entry on the parent module, then 'path' is None
[no path]. Otherwise, the value for __path__ is passed in for the 'path'
@@ -74,7 +77,7 @@ class CallSignature(unittest.TestCase):
log, wrapped_call = self.log(importer.find_module)
importer.find_module = MethodType(wrapped_call, importer)
with util.import_state(meta_path=[importer]):
- import_util.import_(mod_name)
+ self.__import__(mod_name)
assert len(log) == 1
args = log[0][0]
kwargs = log[0][1]
@@ -95,7 +98,7 @@ class CallSignature(unittest.TestCase):
log, wrapped_call = self.log(importer.find_module)
importer.find_module = MethodType(wrapped_call, importer)
with util.import_state(meta_path=[importer]):
- import_util.import_(mod_name)
+ self.__import__(mod_name)
assert len(log) == 2
args = log[1][0]
kwargs = log[1][1]
@@ -104,12 +107,9 @@ class CallSignature(unittest.TestCase):
self.assertEqual(args[0], mod_name)
self.assertIs(args[1], path)
-
-
-def test_main():
- from test.support import run_unittest
- run_unittest(CallingOrder, CallSignature)
+Frozen_CallSignature, Source_CallSignature = util.test_both(
+ CallSignature, __import__=import_util.__import__)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_packages.py b/Lib/test/test_importlib/import_/test_packages.py
index bfa18dc217..2df421d78e 100644
--- a/Lib/test/test_importlib/import_/test_packages.py
+++ b/Lib/test/test_importlib/import_/test_packages.py
@@ -6,21 +6,21 @@ import importlib
from test import support
-class ParentModuleTests(unittest.TestCase):
+class ParentModuleTests:
"""Importing a submodule should import the parent modules."""
def test_import_parent(self):
with util.mock_modules('pkg.__init__', 'pkg.module') as mock:
with util.import_state(meta_path=[mock]):
- module = import_util.import_('pkg.module')
+ module = self.__import__('pkg.module')
self.assertIn('pkg', sys.modules)
def test_bad_parent(self):
with util.mock_modules('pkg.module') as mock:
with util.import_state(meta_path=[mock]):
with self.assertRaises(ImportError) as cm:
- import_util.import_('pkg.module')
+ self.__import__('pkg.module')
self.assertEqual(cm.exception.name, 'pkg')
def test_raising_parent_after_importing_child(self):
@@ -32,11 +32,11 @@ class ParentModuleTests(unittest.TestCase):
with mock:
with util.import_state(meta_path=[mock]):
with self.assertRaises(ZeroDivisionError):
- import_util.import_('pkg')
+ self.__import__('pkg')
self.assertNotIn('pkg', sys.modules)
self.assertIn('pkg.module', sys.modules)
with self.assertRaises(ZeroDivisionError):
- import_util.import_('pkg.module')
+ self.__import__('pkg.module')
self.assertNotIn('pkg', sys.modules)
self.assertIn('pkg.module', sys.modules)
@@ -51,10 +51,10 @@ class ParentModuleTests(unittest.TestCase):
with self.assertRaises((ZeroDivisionError, ImportError)):
# This raises ImportError on the "from . import module"
# line, not sure why.
- import_util.import_('pkg')
+ self.__import__('pkg')
self.assertNotIn('pkg', sys.modules)
with self.assertRaises((ZeroDivisionError, ImportError)):
- import_util.import_('pkg.module')
+ self.__import__('pkg.module')
self.assertNotIn('pkg', sys.modules)
# XXX False
#self.assertIn('pkg.module', sys.modules)
@@ -71,10 +71,10 @@ class ParentModuleTests(unittest.TestCase):
with self.assertRaises((ZeroDivisionError, ImportError)):
# This raises ImportError on the "from ..subpkg import module"
# line, not sure why.
- import_util.import_('pkg.subpkg')
+ self.__import__('pkg.subpkg')
self.assertNotIn('pkg.subpkg', sys.modules)
with self.assertRaises((ZeroDivisionError, ImportError)):
- import_util.import_('pkg.subpkg.module')
+ self.__import__('pkg.subpkg.module')
self.assertNotIn('pkg.subpkg', sys.modules)
# XXX False
#self.assertIn('pkg.subpkg.module', sys.modules)
@@ -83,7 +83,7 @@ class ParentModuleTests(unittest.TestCase):
# Try to import a submodule from a non-package should raise ImportError.
assert not hasattr(sys, '__path__')
with self.assertRaises(ImportError) as cm:
- import_util.import_('sys.no_submodules_here')
+ self.__import__('sys.no_submodules_here')
self.assertEqual(cm.exception.name, 'sys.no_submodules_here')
def test_module_not_package_but_side_effects(self):
@@ -98,15 +98,13 @@ class ParentModuleTests(unittest.TestCase):
with mock_modules as mock:
with util.import_state(meta_path=[mock]):
try:
- submodule = import_util.import_(subname)
+ submodule = self.__import__(subname)
finally:
support.unload(subname)
-
-def test_main():
- from test.support import run_unittest
- run_unittest(ParentModuleTests)
+Frozen_ParentTests, Source_ParentTests = util.test_both(
+ ParentModuleTests, __import__=import_util.__import__)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_path.py b/Lib/test/test_importlib/import_/test_path.py
index d82b7f6b0c..14fdaa339e 100644
--- a/Lib/test/test_importlib/import_/test_path.py
+++ b/Lib/test/test_importlib/import_/test_path.py
@@ -1,8 +1,9 @@
-from importlib import _bootstrap
-from importlib import machinery
-from importlib import import_module
from .. import util
from . import util as import_util
+
+importlib = util.import_importlib('importlib')
+machinery = util.import_importlib('importlib.machinery')
+
import os
import sys
from types import ModuleType
@@ -11,7 +12,7 @@ import warnings
import zipimport
-class FinderTests(unittest.TestCase):
+class FinderTests:
"""Tests for PathFinder."""
@@ -19,7 +20,7 @@ class FinderTests(unittest.TestCase):
# Test None returned upon not finding a suitable finder.
module = '<test module>'
with util.import_state():
- self.assertIsNone(machinery.PathFinder.find_module(module))
+ self.assertIsNone(self.machinery.PathFinder.find_module(module))
def test_sys_path(self):
# Test that sys.path is used when 'path' is None.
@@ -29,7 +30,7 @@ class FinderTests(unittest.TestCase):
importer = util.mock_modules(module)
with util.import_state(path_importer_cache={path: importer},
path=[path]):
- loader = machinery.PathFinder.find_module(module)
+ loader = self.machinery.PathFinder.find_module(module)
self.assertIs(loader, importer)
def test_path(self):
@@ -39,7 +40,7 @@ class FinderTests(unittest.TestCase):
path = '<test path>'
importer = util.mock_modules(module)
with util.import_state(path_importer_cache={path: importer}):
- loader = machinery.PathFinder.find_module(module, [path])
+ loader = self.machinery.PathFinder.find_module(module, [path])
self.assertIs(loader, importer)
def test_empty_list(self):
@@ -49,7 +50,7 @@ class FinderTests(unittest.TestCase):
importer = util.mock_modules(module)
with util.import_state(path_importer_cache={path: importer},
path=[path]):
- self.assertIsNone(machinery.PathFinder.find_module('module', []))
+ self.assertIsNone(self.machinery.PathFinder.find_module('module', []))
def test_path_hooks(self):
# Test that sys.path_hooks is used.
@@ -59,7 +60,7 @@ class FinderTests(unittest.TestCase):
importer = util.mock_modules(module)
hook = import_util.mock_path_hook(path, importer=importer)
with util.import_state(path_hooks=[hook]):
- loader = machinery.PathFinder.find_module(module, [path])
+ loader = self.machinery.PathFinder.find_module(module, [path])
self.assertIs(loader, importer)
self.assertIn(path, sys.path_importer_cache)
self.assertIs(sys.path_importer_cache[path], importer)
@@ -72,7 +73,7 @@ class FinderTests(unittest.TestCase):
path=[path_entry]):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
- self.assertIsNone(machinery.PathFinder.find_module('os'))
+ self.assertIsNone(self.machinery.PathFinder.find_module('os'))
self.assertIsNone(sys.path_importer_cache[path_entry])
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[-1].category, ImportWarning))
@@ -82,11 +83,11 @@ class FinderTests(unittest.TestCase):
path = ''
module = '<test module>'
importer = util.mock_modules(module)
- hook = import_util.mock_path_hook(os.curdir, importer=importer)
+ hook = import_util.mock_path_hook(os.getcwd(), importer=importer)
with util.import_state(path=[path], path_hooks=[hook]):
- loader = machinery.PathFinder.find_module(module)
+ loader = self.machinery.PathFinder.find_module(module)
self.assertIs(loader, importer)
- self.assertIn(os.curdir, sys.path_importer_cache)
+ self.assertIn(os.getcwd(), sys.path_importer_cache)
def test_None_on_sys_path(self):
# Putting None in sys.path[0] caused an import regression from Python
@@ -96,8 +97,8 @@ class FinderTests(unittest.TestCase):
new_path_importer_cache = sys.path_importer_cache.copy()
new_path_importer_cache.pop(None, None)
new_path_hooks = [zipimport.zipimporter,
- _bootstrap.FileFinder.path_hook(
- *_bootstrap._get_supported_file_loaders())]
+ self.machinery.FileFinder.path_hook(
+ *self.importlib._bootstrap._get_supported_file_loaders())]
missing = object()
email = sys.modules.pop('email', missing)
try:
@@ -105,16 +106,15 @@ class FinderTests(unittest.TestCase):
path=new_path,
path_importer_cache=new_path_importer_cache,
path_hooks=new_path_hooks):
- module = import_module('email')
+ module = self.importlib.import_module('email')
self.assertIsInstance(module, ModuleType)
finally:
if email is not missing:
sys.modules['email'] = email
+Frozen_FinderTests, Source_FinderTests = util.test_both(
+ FinderTests, importlib=importlib, machinery=machinery)
-def test_main():
- from test.support import run_unittest
- run_unittest(FinderTests)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/test_relative_imports.py b/Lib/test/test_importlib/import_/test_relative_imports.py
index 4569c26424..fc6d25a5e5 100644
--- a/Lib/test/test_importlib/import_/test_relative_imports.py
+++ b/Lib/test/test_importlib/import_/test_relative_imports.py
@@ -4,7 +4,7 @@ from . import util as import_util
import sys
import unittest
-class RelativeImports(unittest.TestCase):
+class RelativeImports:
"""PEP 328 introduced relative imports. This allows for imports to occur
from within a package without having to specify the actual package name.
@@ -76,8 +76,8 @@ class RelativeImports(unittest.TestCase):
create = 'pkg.__init__', 'pkg.mod2'
globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.mod1'}
def callback(global_):
- import_util.import_('pkg') # For __import__().
- module = import_util.import_('', global_, fromlist=['mod2'], level=1)
+ self.__import__('pkg') # For __import__().
+ module = self.__import__('', global_, fromlist=['mod2'], level=1)
self.assertEqual(module.__name__, 'pkg')
self.assertTrue(hasattr(module, 'mod2'))
self.assertEqual(module.mod2.attr, 'pkg.mod2')
@@ -88,8 +88,8 @@ class RelativeImports(unittest.TestCase):
create = 'pkg.__init__', 'pkg.mod2'
globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.mod1'}
def callback(global_):
- import_util.import_('pkg') # For __import__().
- module = import_util.import_('mod2', global_, fromlist=['attr'],
+ self.__import__('pkg') # For __import__().
+ module = self.__import__('mod2', global_, fromlist=['attr'],
level=1)
self.assertEqual(module.__name__, 'pkg.mod2')
self.assertEqual(module.attr, 'pkg.mod2')
@@ -101,8 +101,8 @@ class RelativeImports(unittest.TestCase):
globals_ = ({'__package__': 'pkg'},
{'__name__': 'pkg', '__path__': ['blah']})
def callback(global_):
- import_util.import_('pkg') # For __import__().
- module = import_util.import_('', global_, fromlist=['module'],
+ self.__import__('pkg') # For __import__().
+ module = self.__import__('', global_, fromlist=['module'],
level=1)
self.assertEqual(module.__name__, 'pkg')
self.assertTrue(hasattr(module, 'module'))
@@ -114,8 +114,8 @@ class RelativeImports(unittest.TestCase):
create = 'pkg.__init__', 'pkg.module'
globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.module'}
def callback(global_):
- import_util.import_('pkg') # For __import__().
- module = import_util.import_('', global_, fromlist=['attr'], level=1)
+ self.__import__('pkg') # For __import__().
+ module = self.__import__('', global_, fromlist=['attr'], level=1)
self.assertEqual(module.__name__, 'pkg')
self.relative_import_test(create, globals_, callback)
@@ -126,7 +126,7 @@ class RelativeImports(unittest.TestCase):
globals_ = ({'__package__': 'pkg.subpkg1'},
{'__name__': 'pkg.subpkg1', '__path__': ['blah']})
def callback(global_):
- module = import_util.import_('', global_, fromlist=['subpkg2'],
+ module = self.__import__('', global_, fromlist=['subpkg2'],
level=2)
self.assertEqual(module.__name__, 'pkg')
self.assertTrue(hasattr(module, 'subpkg2'))
@@ -142,8 +142,8 @@ class RelativeImports(unittest.TestCase):
{'__name__': 'pkg.pkg1.pkg2.pkg3.pkg4.pkg5',
'__path__': ['blah']})
def callback(global_):
- import_util.import_(globals_[0]['__package__'])
- module = import_util.import_('', global_, fromlist=['attr'], level=6)
+ self.__import__(globals_[0]['__package__'])
+ module = self.__import__('', global_, fromlist=['attr'], level=6)
self.assertEqual(module.__name__, 'pkg')
self.relative_import_test(create, globals_, callback)
@@ -153,9 +153,9 @@ class RelativeImports(unittest.TestCase):
globals_ = ({'__package__': 'pkg'},
{'__name__': 'pkg', '__path__': ['blah']})
def callback(global_):
- import_util.import_('pkg')
+ self.__import__('pkg')
with self.assertRaises(ValueError):
- import_util.import_('', global_, fromlist=['top_level'],
+ self.__import__('', global_, fromlist=['top_level'],
level=2)
self.relative_import_test(create, globals_, callback)
@@ -164,16 +164,16 @@ class RelativeImports(unittest.TestCase):
create = ['top_level', 'pkg.__init__', 'pkg.module']
globals_ = {'__package__': 'pkg'}, {'__name__': 'pkg.module'}
def callback(global_):
- import_util.import_('pkg')
+ self.__import__('pkg')
with self.assertRaises(ValueError):
- import_util.import_('', global_, fromlist=['top_level'],
+ self.__import__('', global_, fromlist=['top_level'],
level=2)
self.relative_import_test(create, globals_, callback)
def test_empty_name_w_level_0(self):
# [empty name]
with self.assertRaises(ValueError):
- import_util.import_('')
+ self.__import__('')
def test_import_from_different_package(self):
# Test importing from a different package than the caller.
@@ -186,8 +186,8 @@ class RelativeImports(unittest.TestCase):
'__runpy_pkg__.uncle.cousin.nephew']
globals_ = {'__package__': '__runpy_pkg__.__runpy_pkg__'}
def callback(global_):
- import_util.import_('__runpy_pkg__.__runpy_pkg__')
- module = import_util.import_('uncle.cousin', globals_, {},
+ self.__import__('__runpy_pkg__.__runpy_pkg__')
+ module = self.__import__('uncle.cousin', globals_, {},
fromlist=['nephew'],
level=2)
self.assertEqual(module.__name__, '__runpy_pkg__.uncle.cousin')
@@ -198,20 +198,19 @@ class RelativeImports(unittest.TestCase):
create = ['crash.__init__', 'crash.mod']
globals_ = [{'__package__': 'crash', '__name__': 'crash'}]
def callback(global_):
- import_util.import_('crash')
- mod = import_util.import_('mod', global_, {}, [], 1)
+ self.__import__('crash')
+ mod = self.__import__('mod', global_, {}, [], 1)
self.assertEqual(mod.__name__, 'crash.mod')
self.relative_import_test(create, globals_, callback)
def test_relative_import_no_globals(self):
# No globals for a relative import is an error.
with self.assertRaises(KeyError):
- import_util.import_('sys', level=1)
+ self.__import__('sys', level=1)
+Frozen_RelativeImports, Source_RelativeImports = util.test_both(
+ RelativeImports, __import__=import_util.__import__)
-def test_main():
- from test.support import run_unittest
- run_unittest(RelativeImports)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/import_/util.py b/Lib/test/test_importlib/import_/util.py
index 86ac065e64..dcb490f967 100644
--- a/Lib/test/test_importlib/import_/util.py
+++ b/Lib/test/test_importlib/import_/util.py
@@ -1,22 +1,14 @@
+from .. import util
+
+frozen_importlib, source_importlib = util.import_importlib('importlib')
+
+import builtins
import functools
import importlib
import unittest
-using___import__ = False
-
-
-def import_(*args, **kwargs):
- """Delegate to allow for injecting different implementations of import."""
- if using___import__:
- return __import__(*args, **kwargs)
- else:
- return importlib.__import__(*args, **kwargs)
-
-
-def importlib_only(fxn):
- """Decorator to skip a test if using __builtins__.__import__."""
- return unittest.skipIf(using___import__, "importlib-specific test")(fxn)
+__import__ = staticmethod(builtins.__import__), staticmethod(source_importlib.__import__)
def mock_path_hook(*entries, importer):
diff --git a/Lib/test/test_importlib/source/test_abc_loader.py b/Lib/test/test_importlib/source/test_abc_loader.py
deleted file mode 100644
index 0d912b6469..0000000000
--- a/Lib/test/test_importlib/source/test_abc_loader.py
+++ /dev/null
@@ -1,906 +0,0 @@
-import importlib
-from importlib import abc
-
-from .. import abc as testing_abc
-from .. import util
-from . import util as source_util
-
-import imp
-import inspect
-import io
-import marshal
-import os
-import sys
-import types
-import unittest
-import warnings
-
-
-class SourceOnlyLoaderMock(abc.SourceLoader):
-
- # Globals that should be defined for all modules.
- source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, "
- b"repr(__loader__)])")
-
- def __init__(self, path):
- self.path = path
-
- def get_data(self, path):
- assert self.path == path
- return self.source
-
- def get_filename(self, fullname):
- return self.path
-
- def module_repr(self, module):
- return '<module>'
-
-
-class SourceLoaderMock(SourceOnlyLoaderMock):
-
- source_mtime = 1
-
- def __init__(self, path, magic=imp.get_magic()):
- super().__init__(path)
- self.bytecode_path = imp.cache_from_source(self.path)
- self.source_size = len(self.source)
- data = bytearray(magic)
- data.extend(importlib._w_long(self.source_mtime))
- data.extend(importlib._w_long(self.source_size))
- code_object = compile(self.source, self.path, 'exec',
- dont_inherit=True)
- data.extend(marshal.dumps(code_object))
- self.bytecode = bytes(data)
- self.written = {}
-
- def get_data(self, path):
- if path == self.path:
- return super().get_data(path)
- elif path == self.bytecode_path:
- return self.bytecode
- else:
- raise IOError
-
- def path_stats(self, path):
- assert path == self.path
- return {'mtime': self.source_mtime, 'size': self.source_size}
-
- def set_data(self, path, data):
- self.written[path] = bytes(data)
- return path == self.bytecode_path
-
-
-class PyLoaderMock(abc.PyLoader):
-
- # Globals that should be defined for all modules.
- source = (b"_ = '::'.join([__name__, __file__, __package__, "
- b"repr(__loader__)])")
-
- def __init__(self, data):
- """Take a dict of 'module_name: path' pairings.
-
- Paths should have no file extension, allowing packages to be denoted by
- ending in '__init__'.
-
- """
- self.module_paths = data
- self.path_to_module = {val:key for key,val in data.items()}
-
- def get_data(self, path):
- if path not in self.path_to_module:
- raise IOError
- return self.source
-
- def is_package(self, name):
- filename = os.path.basename(self.get_filename(name))
- return os.path.splitext(filename)[0] == '__init__'
-
- def source_path(self, name):
- try:
- return self.module_paths[name]
- except KeyError:
- raise ImportError
-
- def get_filename(self, name):
- """Silence deprecation warning."""
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- path = super().get_filename(name)
- assert len(w) == 1
- assert issubclass(w[0].category, DeprecationWarning)
- return path
-
- def module_repr(self):
- return '<module>'
-
-
-class PyLoaderCompatMock(PyLoaderMock):
-
- """Mock that matches what is suggested to have a loader that is compatible
- from Python 3.1 onwards."""
-
- def get_filename(self, fullname):
- try:
- return self.module_paths[fullname]
- except KeyError:
- raise ImportError
-
- def source_path(self, fullname):
- try:
- return self.get_filename(fullname)
- except ImportError:
- return None
-
-
-class PyPycLoaderMock(abc.PyPycLoader, PyLoaderMock):
-
- default_mtime = 1
-
- def __init__(self, source, bc={}):
- """Initialize mock.
-
- 'bc' is a dict keyed on a module's name. The value is dict with
- possible keys of 'path', 'mtime', 'magic', and 'bc'. Except for 'path',
- each of those keys control if any part of created bytecode is to
- deviate from default values.
-
- """
- super().__init__(source)
- self.module_bytecode = {}
- self.path_to_bytecode = {}
- self.bytecode_to_path = {}
- for name, data in bc.items():
- self.path_to_bytecode[data['path']] = name
- self.bytecode_to_path[name] = data['path']
- magic = data.get('magic', imp.get_magic())
- mtime = importlib._w_long(data.get('mtime', self.default_mtime))
- source_size = importlib._w_long(len(self.source) & 0xFFFFFFFF)
- if 'bc' in data:
- bc = data['bc']
- else:
- bc = self.compile_bc(name)
- self.module_bytecode[name] = magic + mtime + source_size + bc
-
- def compile_bc(self, name):
- source_path = self.module_paths.get(name, '<test>') or '<test>'
- code = compile(self.source, source_path, 'exec')
- return marshal.dumps(code)
-
- def source_mtime(self, name):
- if name in self.module_paths:
- return self.default_mtime
- elif name in self.module_bytecode:
- return None
- else:
- raise ImportError
-
- def bytecode_path(self, name):
- try:
- return self.bytecode_to_path[name]
- except KeyError:
- if name in self.module_paths:
- return None
- else:
- raise ImportError
-
- def write_bytecode(self, name, bytecode):
- self.module_bytecode[name] = bytecode
- return True
-
- def get_data(self, path):
- if path in self.path_to_module:
- return super().get_data(path)
- elif path in self.path_to_bytecode:
- name = self.path_to_bytecode[path]
- return self.module_bytecode[name]
- else:
- raise IOError
-
- def is_package(self, name):
- try:
- return super().is_package(name)
- except TypeError:
- return '__init__' in self.bytecode_to_path[name]
-
- def get_code(self, name):
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- code_object = super().get_code(name)
- assert len(w) == 1
- assert issubclass(w[0].category, DeprecationWarning)
- return code_object
-
-class PyLoaderTests(testing_abc.LoaderTests):
-
- """Tests for importlib.abc.PyLoader."""
-
- mocker = PyLoaderMock
-
- def eq_attrs(self, ob, **kwargs):
- for attr, val in kwargs.items():
- found = getattr(ob, attr)
- self.assertEqual(found, val,
- "{} attribute: {} != {}".format(attr, found, val))
-
- def test_module(self):
- name = '<module>'
- path = os.path.join('', 'path', 'to', 'module')
- mock = self.mocker({name: path})
- with util.uncache(name):
- module = mock.load_module(name)
- self.assertIn(name, sys.modules)
- self.eq_attrs(module, __name__=name, __file__=path, __package__='',
- __loader__=mock)
- self.assertTrue(not hasattr(module, '__path__'))
- return mock, name
-
- def test_package(self):
- name = '<pkg>'
- path = os.path.join('path', 'to', name, '__init__')
- mock = self.mocker({name: path})
- with util.uncache(name):
- module = mock.load_module(name)
- self.assertIn(name, sys.modules)
- self.eq_attrs(module, __name__=name, __file__=path,
- __path__=[os.path.dirname(path)], __package__=name,
- __loader__=mock)
- return mock, name
-
- def test_lacking_parent(self):
- name = 'pkg.mod'
- path = os.path.join('path', 'to', 'pkg', 'mod')
- mock = self.mocker({name: path})
- with util.uncache(name):
- module = mock.load_module(name)
- self.assertIn(name, sys.modules)
- self.eq_attrs(module, __name__=name, __file__=path, __package__='pkg',
- __loader__=mock)
- self.assertFalse(hasattr(module, '__path__'))
- return mock, name
-
- def test_module_reuse(self):
- name = 'mod'
- path = os.path.join('path', 'to', 'mod')
- module = imp.new_module(name)
- mock = self.mocker({name: path})
- with util.uncache(name):
- sys.modules[name] = module
- loaded_module = mock.load_module(name)
- self.assertIs(loaded_module, module)
- self.assertIs(sys.modules[name], module)
- return mock, name
-
- def test_state_after_failure(self):
- name = "mod"
- module = imp.new_module(name)
- module.blah = None
- mock = self.mocker({name: os.path.join('path', 'to', 'mod')})
- mock.source = b"1/0"
- with util.uncache(name):
- sys.modules[name] = module
- with self.assertRaises(ZeroDivisionError):
- mock.load_module(name)
- self.assertIs(sys.modules[name], module)
- self.assertTrue(hasattr(module, 'blah'))
- return mock
-
- def test_unloadable(self):
- name = "mod"
- mock = self.mocker({name: os.path.join('path', 'to', 'mod')})
- mock.source = b"1/0"
- with util.uncache(name):
- with self.assertRaises(ZeroDivisionError):
- mock.load_module(name)
- self.assertNotIn(name, sys.modules)
- return mock
-
-
-class PyLoaderCompatTests(PyLoaderTests):
-
- """Test that the suggested code to make a loader that is compatible from
- Python 3.1 forward works."""
-
- mocker = PyLoaderCompatMock
-
-
-class PyLoaderInterfaceTests(unittest.TestCase):
-
- """Tests for importlib.abc.PyLoader to make sure that when source_path()
- doesn't return a path everything works as expected."""
-
- def test_no_source_path(self):
- # No source path should lead to ImportError.
- name = 'mod'
- mock = PyLoaderMock({})
- with util.uncache(name), self.assertRaises(ImportError):
- mock.load_module(name)
-
- def test_source_path_is_None(self):
- name = 'mod'
- mock = PyLoaderMock({name: None})
- with util.uncache(name), self.assertRaises(ImportError):
- mock.load_module(name)
-
- def test_get_filename_with_source_path(self):
- # get_filename() should return what source_path() returns.
- name = 'mod'
- path = os.path.join('path', 'to', 'source')
- mock = PyLoaderMock({name: path})
- with util.uncache(name):
- self.assertEqual(mock.get_filename(name), path)
-
- def test_get_filename_no_source_path(self):
- # get_filename() should raise ImportError if source_path returns None.
- name = 'mod'
- mock = PyLoaderMock({name: None})
- with util.uncache(name), self.assertRaises(ImportError):
- mock.get_filename(name)
-
-
-class PyPycLoaderTests(PyLoaderTests):
-
- """Tests for importlib.abc.PyPycLoader."""
-
- mocker = PyPycLoaderMock
-
- @source_util.writes_bytecode_files
- def verify_bytecode(self, mock, name):
- assert name in mock.module_paths
- self.assertIn(name, mock.module_bytecode)
- magic = mock.module_bytecode[name][:4]
- self.assertEqual(magic, imp.get_magic())
- mtime = importlib._r_long(mock.module_bytecode[name][4:8])
- self.assertEqual(mtime, 1)
- source_size = mock.module_bytecode[name][8:12]
- self.assertEqual(len(mock.source) & 0xFFFFFFFF,
- importlib._r_long(source_size))
- bc = mock.module_bytecode[name][12:]
- self.assertEqual(bc, mock.compile_bc(name))
-
- def test_module(self):
- mock, name = super().test_module()
- self.verify_bytecode(mock, name)
-
- def test_package(self):
- mock, name = super().test_package()
- self.verify_bytecode(mock, name)
-
- def test_lacking_parent(self):
- mock, name = super().test_lacking_parent()
- self.verify_bytecode(mock, name)
-
- def test_module_reuse(self):
- mock, name = super().test_module_reuse()
- self.verify_bytecode(mock, name)
-
- def test_state_after_failure(self):
- super().test_state_after_failure()
-
- def test_unloadable(self):
- super().test_unloadable()
-
-
-class PyPycLoaderInterfaceTests(unittest.TestCase):
-
- """Test for the interface of importlib.abc.PyPycLoader."""
-
- def get_filename_check(self, src_path, bc_path, expect):
- name = 'mod'
- mock = PyPycLoaderMock({name: src_path}, {name: {'path': bc_path}})
- with util.uncache(name):
- assert mock.source_path(name) == src_path
- assert mock.bytecode_path(name) == bc_path
- self.assertEqual(mock.get_filename(name), expect)
-
- def test_filename_with_source_bc(self):
- # When source and bytecode paths present, return the source path.
- self.get_filename_check('source_path', 'bc_path', 'source_path')
-
- def test_filename_with_source_no_bc(self):
- # With source but no bc, return source path.
- self.get_filename_check('source_path', None, 'source_path')
-
- def test_filename_with_no_source_bc(self):
- # With not source but bc, return the bc path.
- self.get_filename_check(None, 'bc_path', 'bc_path')
-
- def test_filename_with_no_source_or_bc(self):
- # With no source or bc, raise ImportError.
- name = 'mod'
- mock = PyPycLoaderMock({name: None}, {name: {'path': None}})
- with util.uncache(name), self.assertRaises(ImportError):
- mock.get_filename(name)
-
-
-class SkipWritingBytecodeTests(unittest.TestCase):
-
- """Test that bytecode is properly handled based on
- sys.dont_write_bytecode."""
-
- @source_util.writes_bytecode_files
- def run_test(self, dont_write_bytecode):
- name = 'mod'
- mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')})
- sys.dont_write_bytecode = dont_write_bytecode
- with util.uncache(name):
- mock.load_module(name)
- self.assertIsNot(name in mock.module_bytecode, dont_write_bytecode)
-
- def test_no_bytecode_written(self):
- self.run_test(True)
-
- def test_bytecode_written(self):
- self.run_test(False)
-
-
-class RegeneratedBytecodeTests(unittest.TestCase):
-
- """Test that bytecode is regenerated as expected."""
-
- @source_util.writes_bytecode_files
- def test_different_magic(self):
- # A different magic number should lead to new bytecode.
- name = 'mod'
- bad_magic = b'\x00\x00\x00\x00'
- assert bad_magic != imp.get_magic()
- mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')},
- {name: {'path': os.path.join('path', 'to',
- 'mod.bytecode'),
- 'magic': bad_magic}})
- with util.uncache(name):
- mock.load_module(name)
- self.assertIn(name, mock.module_bytecode)
- magic = mock.module_bytecode[name][:4]
- self.assertEqual(magic, imp.get_magic())
-
- @source_util.writes_bytecode_files
- def test_old_mtime(self):
- # Bytecode with an older mtime should be regenerated.
- name = 'mod'
- old_mtime = PyPycLoaderMock.default_mtime - 1
- mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')},
- {name: {'path': 'path/to/mod.bytecode', 'mtime': old_mtime}})
- with util.uncache(name):
- mock.load_module(name)
- self.assertIn(name, mock.module_bytecode)
- mtime = importlib._r_long(mock.module_bytecode[name][4:8])
- self.assertEqual(mtime, PyPycLoaderMock.default_mtime)
-
-
-class BadBytecodeFailureTests(unittest.TestCase):
-
- """Test import failures when there is no source and parts of the bytecode
- is bad."""
-
- def test_bad_magic(self):
- # A bad magic number should lead to an ImportError.
- name = 'mod'
- bad_magic = b'\x00\x00\x00\x00'
- bc = {name:
- {'path': os.path.join('path', 'to', 'mod'),
- 'magic': bad_magic}}
- mock = PyPycLoaderMock({name: None}, bc)
- with util.uncache(name), self.assertRaises(ImportError) as cm:
- mock.load_module(name)
- self.assertEqual(cm.exception.name, name)
-
- def test_no_bytecode(self):
- # Missing code object bytecode should lead to an EOFError.
- name = 'mod'
- bc = {name: {'path': os.path.join('path', 'to', 'mod'), 'bc': b''}}
- mock = PyPycLoaderMock({name: None}, bc)
- with util.uncache(name), self.assertRaises(EOFError):
- mock.load_module(name)
-
- def test_bad_bytecode(self):
- # Malformed code object bytecode should lead to a ValueError.
- name = 'mod'
- bc = {name: {'path': os.path.join('path', 'to', 'mod'), 'bc': b'1234'}}
- mock = PyPycLoaderMock({name: None}, bc)
- with util.uncache(name), self.assertRaises(ValueError):
- mock.load_module(name)
-
-
-def raise_ImportError(*args, **kwargs):
- raise ImportError
-
-class MissingPathsTests(unittest.TestCase):
-
- """Test what happens when a source or bytecode path does not exist (either
- from *_path returning None or raising ImportError)."""
-
- def test_source_path_None(self):
- # Bytecode should be used when source_path returns None, along with
- # __file__ being set to the bytecode path.
- name = 'mod'
- bytecode_path = 'path/to/mod'
- mock = PyPycLoaderMock({name: None}, {name: {'path': bytecode_path}})
- with util.uncache(name):
- module = mock.load_module(name)
- self.assertEqual(module.__file__, bytecode_path)
-
- # Testing for bytecode_path returning None handled by all tests where no
- # bytecode initially exists.
-
- def test_all_paths_None(self):
- # If all *_path methods return None, raise ImportError.
- name = 'mod'
- mock = PyPycLoaderMock({name: None})
- with util.uncache(name), self.assertRaises(ImportError) as cm:
- mock.load_module(name)
- self.assertEqual(cm.exception.name, name)
-
- def test_source_path_ImportError(self):
- # An ImportError from source_path should trigger an ImportError.
- name = 'mod'
- mock = PyPycLoaderMock({}, {name: {'path': os.path.join('path', 'to',
- 'mod')}})
- with util.uncache(name), self.assertRaises(ImportError):
- mock.load_module(name)
-
- def test_bytecode_path_ImportError(self):
- # An ImportError from bytecode_path should trigger an ImportError.
- name = 'mod'
- mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')})
- bad_meth = types.MethodType(raise_ImportError, mock)
- mock.bytecode_path = bad_meth
- with util.uncache(name), self.assertRaises(ImportError) as cm:
- mock.load_module(name)
-
-
-class SourceLoaderTestHarness(unittest.TestCase):
-
- def setUp(self, *, is_package=True, **kwargs):
- self.package = 'pkg'
- if is_package:
- self.path = os.path.join(self.package, '__init__.py')
- self.name = self.package
- else:
- module_name = 'mod'
- self.path = os.path.join(self.package, '.'.join(['mod', 'py']))
- self.name = '.'.join([self.package, module_name])
- self.cached = imp.cache_from_source(self.path)
- self.loader = self.loader_mock(self.path, **kwargs)
-
- def verify_module(self, module):
- self.assertEqual(module.__name__, self.name)
- self.assertEqual(module.__file__, self.path)
- self.assertEqual(module.__cached__, self.cached)
- self.assertEqual(module.__package__, self.package)
- self.assertEqual(module.__loader__, self.loader)
- values = module._.split('::')
- self.assertEqual(values[0], self.name)
- self.assertEqual(values[1], self.path)
- self.assertEqual(values[2], self.cached)
- self.assertEqual(values[3], self.package)
- self.assertEqual(values[4], repr(self.loader))
-
- def verify_code(self, code_object):
- module = imp.new_module(self.name)
- module.__file__ = self.path
- module.__cached__ = self.cached
- module.__package__ = self.package
- module.__loader__ = self.loader
- module.__path__ = []
- exec(code_object, module.__dict__)
- self.verify_module(module)
-
-
-class SourceOnlyLoaderTests(SourceLoaderTestHarness):
-
- """Test importlib.abc.SourceLoader for source-only loading.
-
- Reload testing is subsumed by the tests for
- importlib.util.module_for_loader.
-
- """
-
- loader_mock = SourceOnlyLoaderMock
-
- def test_get_source(self):
- # Verify the source code is returned as a string.
- # If an IOError is raised by get_data then raise ImportError.
- expected_source = self.loader.source.decode('utf-8')
- self.assertEqual(self.loader.get_source(self.name), expected_source)
- def raise_IOError(path):
- raise IOError
- self.loader.get_data = raise_IOError
- with self.assertRaises(ImportError) as cm:
- self.loader.get_source(self.name)
- self.assertEqual(cm.exception.name, self.name)
-
- def test_is_package(self):
- # Properly detect when loading a package.
- self.setUp(is_package=False)
- self.assertFalse(self.loader.is_package(self.name))
- self.setUp(is_package=True)
- self.assertTrue(self.loader.is_package(self.name))
- self.assertFalse(self.loader.is_package(self.name + '.__init__'))
-
- def test_get_code(self):
- # Verify the code object is created.
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object)
-
- def test_load_module(self):
- # Loading a module should set __name__, __loader__, __package__,
- # __path__ (for packages), __file__, and __cached__.
- # The module should also be put into sys.modules.
- with util.uncache(self.name):
- module = self.loader.load_module(self.name)
- self.verify_module(module)
- self.assertEqual(module.__path__, [os.path.dirname(self.path)])
- self.assertIn(self.name, sys.modules)
-
- def test_package_settings(self):
- # __package__ needs to be set, while __path__ is set on if the module
- # is a package.
- # Testing the values for a package are covered by test_load_module.
- self.setUp(is_package=False)
- with util.uncache(self.name):
- module = self.loader.load_module(self.name)
- self.verify_module(module)
- self.assertTrue(not hasattr(module, '__path__'))
-
- def test_get_source_encoding(self):
- # Source is considered encoded in UTF-8 by default unless otherwise
- # specified by an encoding line.
- source = "_ = 'ü'"
- self.loader.source = source.encode('utf-8')
- returned_source = self.loader.get_source(self.name)
- self.assertEqual(returned_source, source)
- source = "# coding: latin-1\n_ = ü"
- self.loader.source = source.encode('latin-1')
- returned_source = self.loader.get_source(self.name)
- self.assertEqual(returned_source, source)
-
-
-@unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true")
-class SourceLoaderBytecodeTests(SourceLoaderTestHarness):
-
- """Test importlib.abc.SourceLoader's use of bytecode.
-
- Source-only testing handled by SourceOnlyLoaderTests.
-
- """
-
- loader_mock = SourceLoaderMock
-
- def verify_code(self, code_object, *, bytecode_written=False):
- super().verify_code(code_object)
- if bytecode_written:
- self.assertIn(self.cached, self.loader.written)
- data = bytearray(imp.get_magic())
- data.extend(importlib._w_long(self.loader.source_mtime))
- data.extend(importlib._w_long(self.loader.source_size))
- data.extend(marshal.dumps(code_object))
- self.assertEqual(self.loader.written[self.cached], bytes(data))
-
- def test_code_with_everything(self):
- # When everything should work.
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object)
-
- def test_no_bytecode(self):
- # If no bytecode exists then move on to the source.
- self.loader.bytecode_path = "<does not exist>"
- # Sanity check
- with self.assertRaises(IOError):
- bytecode_path = imp.cache_from_source(self.path)
- self.loader.get_data(bytecode_path)
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object, bytecode_written=True)
-
- def test_code_bad_timestamp(self):
- # Bytecode is only used when the timestamp matches the source EXACTLY.
- for source_mtime in (0, 2):
- assert source_mtime != self.loader.source_mtime
- original = self.loader.source_mtime
- self.loader.source_mtime = source_mtime
- # If bytecode is used then EOFError would be raised by marshal.
- self.loader.bytecode = self.loader.bytecode[8:]
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object, bytecode_written=True)
- self.loader.source_mtime = original
-
- def test_code_bad_magic(self):
- # Skip over bytecode with a bad magic number.
- self.setUp(magic=b'0000')
- # If bytecode is used then EOFError would be raised by marshal.
- self.loader.bytecode = self.loader.bytecode[8:]
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object, bytecode_written=True)
-
- def test_dont_write_bytecode(self):
- # Bytecode is not written if sys.dont_write_bytecode is true.
- # Can assume it is false already thanks to the skipIf class decorator.
- try:
- sys.dont_write_bytecode = True
- self.loader.bytecode_path = "<does not exist>"
- code_object = self.loader.get_code(self.name)
- self.assertNotIn(self.cached, self.loader.written)
- finally:
- sys.dont_write_bytecode = False
-
- def test_no_set_data(self):
- # If set_data is not defined, one can still read bytecode.
- self.setUp(magic=b'0000')
- original_set_data = self.loader.__class__.set_data
- try:
- del self.loader.__class__.set_data
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object)
- finally:
- self.loader.__class__.set_data = original_set_data
-
- def test_set_data_raises_exceptions(self):
- # Raising NotImplementedError or IOError is okay for set_data.
- def raise_exception(exc):
- def closure(*args, **kwargs):
- raise exc
- return closure
-
- self.setUp(magic=b'0000')
- self.loader.set_data = raise_exception(NotImplementedError)
- code_object = self.loader.get_code(self.name)
- self.verify_code(code_object)
-
-
-class SourceLoaderGetSourceTests(unittest.TestCase):
-
- """Tests for importlib.abc.SourceLoader.get_source()."""
-
- def test_default_encoding(self):
- # Should have no problems with UTF-8 text.
- name = 'mod'
- mock = SourceOnlyLoaderMock('mod.file')
- source = 'x = "ü"'
- mock.source = source.encode('utf-8')
- returned_source = mock.get_source(name)
- self.assertEqual(returned_source, source)
-
- def test_decoded_source(self):
- # Decoding should work.
- name = 'mod'
- mock = SourceOnlyLoaderMock("mod.file")
- source = "# coding: Latin-1\nx='ü'"
- assert source.encode('latin-1') != source.encode('utf-8')
- mock.source = source.encode('latin-1')
- returned_source = mock.get_source(name)
- self.assertEqual(returned_source, source)
-
- def test_universal_newlines(self):
- # PEP 302 says universal newlines should be used.
- name = 'mod'
- mock = SourceOnlyLoaderMock('mod.file')
- source = "x = 42\r\ny = -13\r\n"
- mock.source = source.encode('utf-8')
- expect = io.IncrementalNewlineDecoder(None, True).decode(source)
- self.assertEqual(mock.get_source(name), expect)
-
-
-class AbstractMethodImplTests(unittest.TestCase):
-
- """Test the concrete abstractmethod implementations."""
-
- class MetaPathFinder(abc.MetaPathFinder):
- def find_module(self, fullname, path):
- super().find_module(fullname, path)
-
- class PathEntryFinder(abc.PathEntryFinder):
- def find_module(self, _):
- super().find_module(_)
-
- def find_loader(self, _):
- super().find_loader(_)
-
- class Finder(abc.Finder):
- def find_module(self, fullname, path):
- super().find_module(fullname, path)
-
- class Loader(abc.Loader):
- def load_module(self, fullname):
- super().load_module(fullname)
- def module_repr(self, module):
- super().module_repr(module)
-
- class ResourceLoader(Loader, abc.ResourceLoader):
- def get_data(self, _):
- super().get_data(_)
-
- class InspectLoader(Loader, abc.InspectLoader):
- def is_package(self, _):
- super().is_package(_)
-
- def get_code(self, _):
- super().get_code(_)
-
- def get_source(self, _):
- super().get_source(_)
-
- class ExecutionLoader(InspectLoader, abc.ExecutionLoader):
- def get_filename(self, _):
- super().get_filename(_)
-
- class SourceLoader(ResourceLoader, ExecutionLoader, abc.SourceLoader):
- pass
-
- class PyLoader(ResourceLoader, InspectLoader, abc.PyLoader):
- def source_path(self, _):
- super().source_path(_)
-
- class PyPycLoader(PyLoader, abc.PyPycLoader):
- def bytecode_path(self, _):
- super().bytecode_path(_)
-
- def source_mtime(self, _):
- super().source_mtime(_)
-
- def write_bytecode(self, _, _2):
- super().write_bytecode(_, _2)
-
- def raises_NotImplementedError(self, ins, *args):
- for method_name in args:
- method = getattr(ins, method_name)
- arg_count = len(inspect.getfullargspec(method)[0]) - 1
- args = [''] * arg_count
- try:
- method(*args)
- except NotImplementedError:
- pass
- else:
- msg = "{}.{} did not raise NotImplementedError"
- self.fail(msg.format(ins.__class__.__name__, method_name))
-
- def test_Loader(self):
- self.raises_NotImplementedError(self.Loader(), 'load_module')
-
- # XXX misplaced; should be somewhere else
- def test_Finder(self):
- self.raises_NotImplementedError(self.Finder(), 'find_module')
-
- def test_ResourceLoader(self):
- self.raises_NotImplementedError(self.ResourceLoader(), 'load_module',
- 'get_data')
-
- def test_InspectLoader(self):
- self.raises_NotImplementedError(self.InspectLoader(), 'load_module',
- 'is_package', 'get_code', 'get_source')
-
- def test_ExecutionLoader(self):
- self.raises_NotImplementedError(self.ExecutionLoader(), 'load_module',
- 'is_package', 'get_code', 'get_source',
- 'get_filename')
-
- def test_SourceLoader(self):
- ins = self.SourceLoader()
- # Required abstractmethods.
- self.raises_NotImplementedError(ins, 'get_filename', 'get_data')
- # Optional abstractmethods.
- self.raises_NotImplementedError(ins,'path_stats', 'set_data')
-
- def test_PyLoader(self):
- self.raises_NotImplementedError(self.PyLoader(), 'source_path',
- 'get_data', 'is_package')
-
- def test_PyPycLoader(self):
- self.raises_NotImplementedError(self.PyPycLoader(), 'source_path',
- 'source_mtime', 'bytecode_path',
- 'write_bytecode')
-
-
-def test_main():
- from test.support import run_unittest
- run_unittest(PyLoaderTests, PyLoaderCompatTests,
- PyLoaderInterfaceTests,
- PyPycLoaderTests, PyPycLoaderInterfaceTests,
- SkipWritingBytecodeTests, RegeneratedBytecodeTests,
- BadBytecodeFailureTests, MissingPathsTests,
- SourceOnlyLoaderTests,
- SourceLoaderBytecodeTests,
- SourceLoaderGetSourceTests,
- AbstractMethodImplTests)
-
-
-if __name__ == '__main__':
- test_main()
diff --git a/Lib/test/test_importlib/source/test_case_sensitivity.py b/Lib/test/test_importlib/source/test_case_sensitivity.py
index 241173fb44..b3e9d25bb2 100644
--- a/Lib/test/test_importlib/source/test_case_sensitivity.py
+++ b/Lib/test/test_importlib/source/test_case_sensitivity.py
@@ -1,9 +1,10 @@
"""Test case-sensitivity (PEP 235)."""
-from importlib import _bootstrap
-from importlib import machinery
from .. import util
from . import util as source_util
-import imp
+
+importlib = util.import_importlib('importlib')
+machinery = util.import_importlib('importlib.machinery')
+
import os
import sys
from test import support as test_support
@@ -11,7 +12,7 @@ import unittest
@util.case_insensitive_tests
-class CaseSensitivityTest(unittest.TestCase):
+class CaseSensitivityTest:
"""PEP 235 dictates that on case-preserving, case-insensitive file systems
that imports are case-sensitive unless the PYTHONCASEOK environment
@@ -21,11 +22,11 @@ class CaseSensitivityTest(unittest.TestCase):
assert name != name.lower()
def find(self, path):
- finder = machinery.FileFinder(path,
- (machinery.SourceFileLoader,
- machinery.SOURCE_SUFFIXES),
- (machinery.SourcelessFileLoader,
- machinery.BYTECODE_SUFFIXES))
+ finder = self.machinery.FileFinder(path,
+ (self.machinery.SourceFileLoader,
+ self.machinery.SOURCE_SUFFIXES),
+ (self.machinery.SourcelessFileLoader,
+ self.machinery.BYTECODE_SUFFIXES))
return finder.find_module(self.name)
def sensitivity_test(self):
@@ -41,7 +42,7 @@ class CaseSensitivityTest(unittest.TestCase):
def test_sensitive(self):
with test_support.EnvironmentVarGuard() as env:
env.unset('PYTHONCASEOK')
- if b'PYTHONCASEOK' in _bootstrap._os.environ:
+ if b'PYTHONCASEOK' in self.importlib._bootstrap._os.environ:
self.skipTest('os.environ changes not reflected in '
'_os.environ')
sensitive, insensitive = self.sensitivity_test()
@@ -52,7 +53,7 @@ class CaseSensitivityTest(unittest.TestCase):
def test_insensitive(self):
with test_support.EnvironmentVarGuard() as env:
env.set('PYTHONCASEOK', '1')
- if b'PYTHONCASEOK' not in _bootstrap._os.environ:
+ if b'PYTHONCASEOK' not in self.importlib._bootstrap._os.environ:
self.skipTest('os.environ changes not reflected in '
'_os.environ')
sensitive, insensitive = self.sensitivity_test()
@@ -61,10 +62,9 @@ class CaseSensitivityTest(unittest.TestCase):
self.assertTrue(hasattr(insensitive, 'load_module'))
self.assertIn(self.name, insensitive.get_filename(self.name))
-
-def test_main():
- test_support.run_unittest(CaseSensitivityTest)
+Frozen_CaseSensitivityTest, Source_CaseSensitivityTest = util.test_both(
+ CaseSensitivityTest, importlib=importlib, machinery=machinery)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py
index d59969a8ce..f1e2713a97 100644
--- a/Lib/test/test_importlib/source/test_file_loader.py
+++ b/Lib/test/test_importlib/source/test_file_loader.py
@@ -1,24 +1,26 @@
-from importlib import machinery
-import importlib
-import importlib.abc
from .. import abc
from .. import util
from . import util as source_util
+importlib = util.import_importlib('importlib')
+importlib_abc = util.import_importlib('importlib.abc')
+machinery = util.import_importlib('importlib.machinery')
+importlib_util = util.import_importlib('importlib.util')
+
import errno
-import imp
import marshal
import os
import py_compile
import shutil
import stat
import sys
+import types
import unittest
-from test.support import make_legacy_pyc
+from test.support import make_legacy_pyc, unload
-class SimpleTest(unittest.TestCase):
+class SimpleTest(abc.LoaderTests):
"""Should have no issue importing a source module [basic]. And if there is
a syntax error, it should raise a SyntaxError [syntax error].
@@ -26,27 +28,17 @@ class SimpleTest(unittest.TestCase):
"""
def test_load_module_API(self):
- # If fullname is not specified that assume self.name is desired.
- class TesterMixin(importlib.abc.Loader):
- def load_module(self, fullname): return fullname
- def module_repr(self, module): return '<module>'
-
- class Tester(importlib.abc.FileLoader, TesterMixin):
- def get_code(self, _): pass
- def get_source(self, _): pass
- def is_package(self, _): pass
+ class Tester(self.abc.FileLoader):
+ def get_source(self, _): return 'attr = 42'
+ def is_package(self, _): return False
- name = 'mod_name'
- loader = Tester(name, 'some_path')
- self.assertEqual(name, loader.load_module())
- self.assertEqual(name, loader.load_module(None))
- self.assertEqual(name, loader.load_module(name))
- with self.assertRaises(ImportError):
- loader.load_module(loader.name + 'XXX')
+ loader = Tester('blah', 'blah.py')
+ self.addCleanup(unload, 'blah')
+ module = loader.load_module() # Should not raise an exception.
def test_get_filename_API(self):
# If fullname is not set then assume self.path is desired.
- class Tester(importlib.abc.FileLoader):
+ class Tester(self.abc.FileLoader):
def get_code(self, _): pass
def get_source(self, _): pass
def is_package(self, _): pass
@@ -64,7 +56,7 @@ class SimpleTest(unittest.TestCase):
# [basic]
def test_module(self):
with source_util.create_modules('_temp') as mapping:
- loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ loader = self.machinery.SourceFileLoader('_temp', mapping['_temp'])
module = loader.load_module('_temp')
self.assertIn('_temp', sys.modules)
check = {'__name__': '_temp', '__file__': mapping['_temp'],
@@ -74,7 +66,7 @@ class SimpleTest(unittest.TestCase):
def test_package(self):
with source_util.create_modules('_pkg.__init__') as mapping:
- loader = machinery.SourceFileLoader('_pkg',
+ loader = self.machinery.SourceFileLoader('_pkg',
mapping['_pkg.__init__'])
module = loader.load_module('_pkg')
self.assertIn('_pkg', sys.modules)
@@ -87,7 +79,7 @@ class SimpleTest(unittest.TestCase):
def test_lacking_parent(self):
with source_util.create_modules('_pkg.__init__', '_pkg.mod')as mapping:
- loader = machinery.SourceFileLoader('_pkg.mod',
+ loader = self.machinery.SourceFileLoader('_pkg.mod',
mapping['_pkg.mod'])
module = loader.load_module('_pkg.mod')
self.assertIn('_pkg.mod', sys.modules)
@@ -102,7 +94,7 @@ class SimpleTest(unittest.TestCase):
def test_module_reuse(self):
with source_util.create_modules('_temp') as mapping:
- loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ loader = self.machinery.SourceFileLoader('_temp', mapping['_temp'])
module = loader.load_module('_temp')
module_id = id(module)
module_dict_id = id(module.__dict__)
@@ -122,12 +114,12 @@ class SimpleTest(unittest.TestCase):
value = '<test>'
name = '_temp'
with source_util.create_modules(name) as mapping:
- orig_module = imp.new_module(name)
+ orig_module = types.ModuleType(name)
for attr in attributes:
setattr(orig_module, attr, value)
with open(mapping[name], 'w') as file:
file.write('+++ bad syntax +++')
- loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ loader = self.machinery.SourceFileLoader('_temp', mapping['_temp'])
with self.assertRaises(SyntaxError):
loader.load_module(name)
for attr in attributes:
@@ -138,7 +130,7 @@ class SimpleTest(unittest.TestCase):
with source_util.create_modules('_temp') as mapping:
with open(mapping['_temp'], 'w') as file:
file.write('=')
- loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ loader = self.machinery.SourceFileLoader('_temp', mapping['_temp'])
with self.assertRaises(SyntaxError):
loader.load_module('_temp')
self.assertNotIn('_temp', sys.modules)
@@ -151,14 +143,14 @@ class SimpleTest(unittest.TestCase):
file.write("# test file for importlib")
try:
with util.uncache('_temp'):
- loader = machinery.SourceFileLoader('_temp', file_path)
+ loader = self.machinery.SourceFileLoader('_temp', file_path)
mod = loader.load_module('_temp')
self.assertEqual(file_path, mod.__file__)
- self.assertEqual(imp.cache_from_source(file_path),
+ self.assertEqual(self.util.cache_from_source(file_path),
mod.__cached__)
finally:
os.unlink(file_path)
- pycache = os.path.dirname(imp.cache_from_source(file_path))
+ pycache = os.path.dirname(self.util.cache_from_source(file_path))
if os.path.exists(pycache):
shutil.rmtree(pycache)
@@ -167,7 +159,7 @@ class SimpleTest(unittest.TestCase):
# truncated rather than raise an OverflowError.
with source_util.create_modules('_temp') as mapping:
source = mapping['_temp']
- compiled = imp.cache_from_source(source)
+ compiled = self.util.cache_from_source(source)
with open(source, 'w') as f:
f.write("x = 5")
try:
@@ -178,7 +170,7 @@ class SimpleTest(unittest.TestCase):
if e.errno != getattr(errno, 'EOVERFLOW', None):
raise
self.skipTest("cannot set modification time to large integer ({})".format(e))
- loader = machinery.SourceFileLoader('_temp', mapping['_temp'])
+ loader = self.machinery.SourceFileLoader('_temp', mapping['_temp'])
mod = loader.load_module('_temp')
# Sanity checks.
self.assertEqual(mod.__cached__, compiled)
@@ -186,8 +178,17 @@ class SimpleTest(unittest.TestCase):
# The pyc file was created.
os.stat(compiled)
+ def test_unloadable(self):
+ loader = self.machinery.SourceFileLoader('good name', {})
+ with self.assertRaises(ImportError):
+ loader.load_module('bad name')
+
+Frozen_SimpleTest, Source_SimpleTest = util.test_both(
+ SimpleTest, importlib=importlib, machinery=machinery, abc=importlib_abc,
+ util=importlib_util)
+
-class BadBytecodeTest(unittest.TestCase):
+class BadBytecodeTest:
def import_(self, file, module_name):
loader = self.loader(module_name, file)
@@ -204,7 +205,7 @@ class BadBytecodeTest(unittest.TestCase):
pass
py_compile.compile(mapping[name])
if not del_source:
- bytecode_path = imp.cache_from_source(mapping[name])
+ bytecode_path = self.util.cache_from_source(mapping[name])
else:
os.unlink(mapping[name])
bytecode_path = make_legacy_pyc(mapping[name])
@@ -293,7 +294,9 @@ class BadBytecodeTest(unittest.TestCase):
class SourceLoaderBadBytecodeTest(BadBytecodeTest):
- loader = machinery.SourceFileLoader
+ @classmethod
+ def setUpClass(cls):
+ cls.loader = cls.machinery.SourceFileLoader
@source_util.writes_bytecode_files
def test_empty_file(self):
@@ -332,7 +335,8 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
def test(name, mapping, bytecode_path):
self.import_(mapping[name], name)
with open(bytecode_path, 'rb') as bytecode_file:
- self.assertEqual(bytecode_file.read(4), imp.get_magic())
+ self.assertEqual(bytecode_file.read(4),
+ self.util.MAGIC_NUMBER)
self._test_bad_magic(test)
@@ -382,13 +386,13 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
zeros = b'\x00\x00\x00\x00'
with source_util.create_modules('_temp') as mapping:
py_compile.compile(mapping['_temp'])
- bytecode_path = imp.cache_from_source(mapping['_temp'])
+ bytecode_path = self.util.cache_from_source(mapping['_temp'])
with open(bytecode_path, 'r+b') as bytecode_file:
bytecode_file.seek(4)
bytecode_file.write(zeros)
self.import_(mapping['_temp'], '_temp')
source_mtime = os.path.getmtime(mapping['_temp'])
- source_timestamp = importlib._w_long(source_mtime)
+ source_timestamp = self.importlib._w_long(source_mtime)
with open(bytecode_path, 'rb') as bytecode_file:
bytecode_file.seek(4)
self.assertEqual(bytecode_file.read(4), source_timestamp)
@@ -400,7 +404,7 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
with source_util.create_modules('_temp') as mapping:
# Create bytecode that will need to be re-created.
py_compile.compile(mapping['_temp'])
- bytecode_path = imp.cache_from_source(mapping['_temp'])
+ bytecode_path = self.util.cache_from_source(mapping['_temp'])
with open(bytecode_path, 'r+b') as bytecode_file:
bytecode_file.seek(0)
bytecode_file.write(b'\x00\x00\x00\x00')
@@ -408,16 +412,22 @@ class SourceLoaderBadBytecodeTest(BadBytecodeTest):
os.chmod(bytecode_path,
stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
try:
- # Should not raise IOError!
+ # Should not raise OSError!
self.import_(mapping['_temp'], '_temp')
finally:
# Make writable for eventual clean-up.
os.chmod(bytecode_path, stat.S_IWUSR)
+Frozen_SourceBadBytecode, Source_SourceBadBytecode = util.test_both(
+ SourceLoaderBadBytecodeTest, importlib=importlib, machinery=machinery,
+ abc=importlib_abc, util=importlib_util)
+
class SourcelessLoaderBadBytecodeTest(BadBytecodeTest):
- loader = machinery.SourcelessFileLoader
+ @classmethod
+ def setUpClass(cls):
+ cls.loader = cls.machinery.SourcelessFileLoader
def test_empty_file(self):
def test(name, mapping, bytecode_path):
@@ -472,14 +482,10 @@ class SourcelessLoaderBadBytecodeTest(BadBytecodeTest):
def test_non_code_marshal(self):
self._test_non_code_marshal(del_source=True)
-
-def test_main():
- from test.support import run_unittest
- run_unittest(SimpleTest,
- SourceLoaderBadBytecodeTest,
- SourcelessLoaderBadBytecodeTest
- )
+Frozen_SourcelessBadBytecode, Source_SourcelessBadBytecode = util.test_both(
+ SourcelessLoaderBadBytecodeTest, importlib=importlib,
+ machinery=machinery, abc=importlib_abc, util=importlib_util)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py
index 8e4986835d..8bf8cb7880 100644
--- a/Lib/test/test_importlib/source/test_finder.py
+++ b/Lib/test/test_importlib/source/test_finder.py
@@ -1,9 +1,10 @@
from .. import abc
+from .. import util
from . import util as source_util
-from importlib import machinery
+machinery = util.import_importlib('importlib.machinery')
+
import errno
-import imp
import os
import py_compile
import stat
@@ -39,11 +40,11 @@ class FinderTests(abc.FinderTests):
"""
def get_finder(self, root):
- loader_details = [(machinery.SourceFileLoader,
- machinery.SOURCE_SUFFIXES),
- (machinery.SourcelessFileLoader,
- machinery.BYTECODE_SUFFIXES)]
- return machinery.FileFinder(root, *loader_details)
+ loader_details = [(self.machinery.SourceFileLoader,
+ self.machinery.SOURCE_SUFFIXES),
+ (self.machinery.SourcelessFileLoader,
+ self.machinery.BYTECODE_SUFFIXES)]
+ return self.machinery.FileFinder(root, *loader_details)
def import_(self, root, module):
return self.get_finder(root).find_module(module)
@@ -124,8 +125,8 @@ class FinderTests(abc.FinderTests):
def test_empty_string_for_dir(self):
# The empty string from sys.path means to search in the cwd.
- finder = machinery.FileFinder('', (machinery.SourceFileLoader,
- machinery.SOURCE_SUFFIXES))
+ finder = self.machinery.FileFinder('', (self.machinery.SourceFileLoader,
+ self.machinery.SOURCE_SUFFIXES))
with open('mod.py', 'w') as file:
file.write("# test file for importlib")
try:
@@ -136,8 +137,8 @@ class FinderTests(abc.FinderTests):
def test_invalidate_caches(self):
# invalidate_caches() should reset the mtime.
- finder = machinery.FileFinder('', (machinery.SourceFileLoader,
- machinery.SOURCE_SUFFIXES))
+ finder = self.machinery.FileFinder('', (self.machinery.SourceFileLoader,
+ self.machinery.SOURCE_SUFFIXES))
finder._path_mtime = 42
finder.invalidate_caches()
self.assertEqual(finder._path_mtime, -1)
@@ -181,11 +182,9 @@ class FinderTests(abc.FinderTests):
finder = self.get_finder(file_obj.name)
self.assertEqual((None, []), finder.find_loader('doesnotexist'))
+Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests, machinery=machinery)
-def test_main():
- from test.support import run_unittest
- run_unittest(FinderTests)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/source/test_path_hook.py b/Lib/test/test_importlib/source/test_path_hook.py
index 6a78792f07..92da77265b 100644
--- a/Lib/test/test_importlib/source/test_path_hook.py
+++ b/Lib/test/test_importlib/source/test_path_hook.py
@@ -1,17 +1,18 @@
+from .. import util
from . import util as source_util
-from importlib import machinery
-import imp
+machinery = util.import_importlib('importlib.machinery')
+
import unittest
-class PathHookTest(unittest.TestCase):
+class PathHookTest:
"""Test the path hook for source."""
def path_hook(self):
- return machinery.FileFinder.path_hook((machinery.SourceFileLoader,
- machinery.SOURCE_SUFFIXES))
+ return self.machinery.FileFinder.path_hook((self.machinery.SourceFileLoader,
+ self.machinery.SOURCE_SUFFIXES))
def test_success(self):
with source_util.create_modules('dummy') as mapping:
@@ -22,11 +23,8 @@ class PathHookTest(unittest.TestCase):
# The empty string represents the cwd.
self.assertTrue(hasattr(self.path_hook()(''), 'find_module'))
-
-def test_main():
- from test.support import run_unittest
- run_unittest(PathHookTest)
+Frozen_PathHookTest, Source_PathHooktest = util.test_both(PathHookTest, machinery=machinery)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/source/test_source_encoding.py b/Lib/test/test_importlib/source/test_source_encoding.py
index ba02b44274..654f4c2b2f 100644
--- a/Lib/test/test_importlib/source/test_source_encoding.py
+++ b/Lib/test/test_importlib/source/test_source_encoding.py
@@ -1,6 +1,8 @@
+from .. import util
from . import util as source_util
-from importlib import _bootstrap
+machinery = util.import_importlib('importlib.machinery')
+
import codecs
import re
import sys
@@ -13,7 +15,7 @@ import unittest
CODING_RE = re.compile(r'^[ \t\f]*#.*coding[:=][ \t]*([-\w.]+)', re.ASCII)
-class EncodingTest(unittest.TestCase):
+class EncodingTest:
"""PEP 3120 makes UTF-8 the default encoding for source code
[default encoding].
@@ -35,7 +37,7 @@ class EncodingTest(unittest.TestCase):
with source_util.create_modules(self.module_name) as mapping:
with open(mapping[self.module_name], 'wb') as file:
file.write(source)
- loader = _bootstrap.SourceFileLoader(self.module_name,
+ loader = self.machinery.SourceFileLoader(self.module_name,
mapping[self.module_name])
return loader.load_module(self.module_name)
@@ -84,8 +86,10 @@ class EncodingTest(unittest.TestCase):
with self.assertRaises(SyntaxError):
self.run_test(source)
+Frozen_EncodingTest, Source_EncodingTest = util.test_both(EncodingTest, machinery=machinery)
+
-class LineEndingTest(unittest.TestCase):
+class LineEndingTest:
r"""Source written with the three types of line endings (\n, \r\n, \r)
need to be readable [cr][crlf][lf]."""
@@ -97,7 +101,7 @@ class LineEndingTest(unittest.TestCase):
with source_util.create_modules(module_name) as mapping:
with open(mapping[module_name], 'wb') as file:
file.write(source)
- loader = _bootstrap.SourceFileLoader(module_name,
+ loader = self.machinery.SourceFileLoader(module_name,
mapping[module_name])
return loader.load_module(module_name)
@@ -113,11 +117,9 @@ class LineEndingTest(unittest.TestCase):
def test_lf(self):
self.run_test(b'\n')
+Frozen_LineEndings, Source_LineEndings = util.test_both(LineEndingTest, machinery=machinery)
-def test_main():
- from test.support import run_unittest
- run_unittest(EncodingTest, LineEndingTest)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/source/util.py b/Lib/test/test_importlib/source/util.py
index ae65663a67..63cd25adc4 100644
--- a/Lib/test/test_importlib/source/util.py
+++ b/Lib/test/test_importlib/source/util.py
@@ -2,7 +2,6 @@ from .. import util
import contextlib
import errno
import functools
-import imp
import os
import os.path
import sys
diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py
index c620c3771b..979a481b6b 100644
--- a/Lib/test/test_importlib/test_abc.py
+++ b/Lib/test/test_importlib/test_abc.py
@@ -1,9 +1,21 @@
-from importlib import abc
-from importlib import machinery
+import contextlib
import inspect
+import io
+import marshal
+import os
+import sys
+from test import support
+import types
import unittest
+from unittest import mock
+from . import util
+frozen_init, source_init = util.import_importlib('importlib')
+frozen_abc, source_abc = util.import_importlib('importlib.abc')
+frozen_util, source_util = util.import_importlib('importlib.util')
+
+##### Inheritance ##############################################################
class InheritanceTests:
"""Test that the specified class is a subclass/superclass of the expected
@@ -14,8 +26,19 @@ class InheritanceTests:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ self.superclasses = [getattr(self.abc, class_name)
+ for class_name in self.superclass_names]
+ if hasattr(self, 'subclass_names'):
+ # Because test.support.import_fresh_module() creates a new
+ # importlib._bootstrap per module, inheritance checks fail when
+ # checking across module boundaries (i.e. the _bootstrap in abc is
+ # not the same as the one in machinery). That means stealing one of
+ # the modules from the other to make sure the same instance is used.
+ self.subclasses = [getattr(self.abc.machinery, class_name)
+ for class_name in self.subclass_names]
assert self.subclasses or self.superclasses, self.__class__
- self.__test = getattr(abc, self.__class__.__name__)
+ testing = self.__class__.__name__.partition('_')[2]
+ self.__test = getattr(self.abc, testing)
def test_subclasses(self):
# Test that the expected subclasses inherit.
@@ -29,75 +52,958 @@ class InheritanceTests:
self.assertTrue(issubclass(self.__test, superclass),
"{0} is not a superclass of {1}".format(superclass, self.__test))
+def create_inheritance_tests(base_class):
+ def set_frozen(ns):
+ ns['abc'] = frozen_abc
+ def set_source(ns):
+ ns['abc'] = source_abc
-class MetaPathFinder(InheritanceTests, unittest.TestCase):
+ classes = []
+ for prefix, ns_set in [('Frozen', set_frozen), ('Source', set_source)]:
+ classes.append(types.new_class('_'.join([prefix, base_class.__name__]),
+ (base_class, unittest.TestCase),
+ exec_body=ns_set))
+ return classes
- superclasses = [abc.Finder]
- subclasses = [machinery.BuiltinImporter, machinery.FrozenImporter,
- machinery.PathFinder, machinery.WindowsRegistryFinder]
+class MetaPathFinder(InheritanceTests):
+ superclass_names = ['Finder']
+ subclass_names = ['BuiltinImporter', 'FrozenImporter', 'PathFinder',
+ 'WindowsRegistryFinder']
-class PathEntryFinder(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(MetaPathFinder)
+Frozen_MetaPathFinderInheritanceTests, Source_MetaPathFinderInheritanceTests = tests
- superclasses = [abc.Finder]
- subclasses = [machinery.FileFinder]
+class PathEntryFinder(InheritanceTests):
+ superclass_names = ['Finder']
+ subclass_names = ['FileFinder']
-class Loader(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(PathEntryFinder)
+Frozen_PathEntryFinderInheritanceTests, Source_PathEntryFinderInheritanceTests = tests
- subclasses = [abc.PyLoader]
+class ResourceLoader(InheritanceTests):
+ superclass_names = ['Loader']
-class ResourceLoader(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(ResourceLoader)
+Frozen_ResourceLoaderInheritanceTests, Source_ResourceLoaderInheritanceTests = tests
- superclasses = [abc.Loader]
+class InspectLoader(InheritanceTests):
+ superclass_names = ['Loader']
+ subclass_names = ['BuiltinImporter', 'FrozenImporter', 'ExtensionFileLoader']
-class InspectLoader(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(InspectLoader)
+Frozen_InspectLoaderInheritanceTests, Source_InspectLoaderInheritanceTests = tests
- superclasses = [abc.Loader]
- subclasses = [abc.PyLoader, machinery.BuiltinImporter,
- machinery.FrozenImporter, machinery.ExtensionFileLoader]
+class ExecutionLoader(InheritanceTests):
+ superclass_names = ['InspectLoader']
+ subclass_names = ['ExtensionFileLoader']
-class ExecutionLoader(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(ExecutionLoader)
+Frozen_ExecutionLoaderInheritanceTests, Source_ExecutionLoaderInheritanceTests = tests
- superclasses = [abc.InspectLoader]
- subclasses = [abc.PyLoader]
+class FileLoader(InheritanceTests):
+ superclass_names = ['ResourceLoader', 'ExecutionLoader']
+ subclass_names = ['SourceFileLoader', 'SourcelessFileLoader']
-class FileLoader(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(FileLoader)
+Frozen_FileLoaderInheritanceTests, Source_FileLoaderInheritanceTests = tests
- superclasses = [abc.ResourceLoader, abc.ExecutionLoader]
- subclasses = [machinery.SourceFileLoader, machinery.SourcelessFileLoader]
+class SourceLoader(InheritanceTests):
+ superclass_names = ['ResourceLoader', 'ExecutionLoader']
+ subclass_names = ['SourceFileLoader']
-class SourceLoader(InheritanceTests, unittest.TestCase):
+tests = create_inheritance_tests(SourceLoader)
+Frozen_SourceLoaderInheritanceTests, Source_SourceLoaderInheritanceTests = tests
- superclasses = [abc.ResourceLoader, abc.ExecutionLoader]
- subclasses = [machinery.SourceFileLoader]
+##### Default return values ####################################################
+def make_abc_subclasses(base_class):
+ classes = []
+ for kind, abc in [('Frozen', frozen_abc), ('Source', source_abc)]:
+ name = '_'.join([kind, base_class.__name__])
+ base_classes = base_class, getattr(abc, base_class.__name__)
+ classes.append(types.new_class(name, base_classes))
+ return classes
+def make_return_value_tests(base_class, test_class):
+ frozen_class, source_class = make_abc_subclasses(base_class)
+ tests = []
+ for prefix, class_in_test in [('Frozen', frozen_class), ('Source', source_class)]:
+ def set_ns(ns):
+ ns['ins'] = class_in_test()
+ tests.append(types.new_class('_'.join([prefix, test_class.__name__]),
+ (test_class, unittest.TestCase),
+ exec_body=set_ns))
+ return tests
-class PyLoader(InheritanceTests, unittest.TestCase):
- superclasses = [abc.Loader, abc.ResourceLoader, abc.ExecutionLoader]
+class MetaPathFinder:
+ def find_module(self, fullname, path):
+ return super().find_module(fullname, path)
-class PyPycLoader(InheritanceTests, unittest.TestCase):
+Frozen_MPF, Source_MPF = make_abc_subclasses(MetaPathFinder)
- superclasses = [abc.PyLoader]
+class MetaPathFinderDefaultsTests:
+
+ def test_find_module(self):
+ # Default should return None.
+ self.assertIsNone(self.ins.find_module('something', None))
+
+ def test_invalidate_caches(self):
+ # Calling the method is a no-op.
+ self.ins.invalidate_caches()
+
+
+tests = make_return_value_tests(MetaPathFinder, MetaPathFinderDefaultsTests)
+Frozen_MPFDefaultTests, Source_MPFDefaultTests = tests
+
+
+class PathEntryFinder:
+
+ def find_loader(self, fullname):
+ return super().find_loader(fullname)
+
+Frozen_PEF, Source_PEF = make_abc_subclasses(PathEntryFinder)
+
+
+class PathEntryFinderDefaultsTests:
+
+ def test_find_loader(self):
+ self.assertEqual((None, []), self.ins.find_loader('something'))
+
+ def find_module(self):
+ self.assertEqual(None, self.ins.find_module('something'))
+
+ def test_invalidate_caches(self):
+ # Should be a no-op.
+ self.ins.invalidate_caches()
+
+
+tests = make_return_value_tests(PathEntryFinder, PathEntryFinderDefaultsTests)
+Frozen_PEFDefaultTests, Source_PEFDefaultTests = tests
+
+
+class Loader:
+
+ def load_module(self, fullname):
+ return super().load_module(fullname)
+
+
+Frozen_L, Source_L = make_abc_subclasses(Loader)
+
+
+class LoaderDefaultsTests:
+
+ def test_load_module(self):
+ with self.assertRaises(ImportError):
+ self.ins.load_module('something')
+
+ def test_module_repr(self):
+ mod = types.ModuleType('blah')
+ with self.assertRaises(NotImplementedError):
+ self.ins.module_repr(mod)
+ original_repr = repr(mod)
+ mod.__loader__ = self.ins
+ # Should still return a proper repr.
+ self.assertTrue(repr(mod))
+
+
+tests = make_return_value_tests(Loader, LoaderDefaultsTests)
+Frozen_LDefaultTests, SourceLDefaultTests = tests
+
+
+class ResourceLoader(Loader):
+
+ def get_data(self, path):
+ return super().get_data(path)
+
+
+Frozen_RL, Source_RL = make_abc_subclasses(ResourceLoader)
+
+
+class ResourceLoaderDefaultsTests:
+
+ def test_get_data(self):
+ with self.assertRaises(IOError):
+ self.ins.get_data('/some/path')
+
+
+tests = make_return_value_tests(ResourceLoader, ResourceLoaderDefaultsTests)
+Frozen_RLDefaultTests, Source_RLDefaultTests = tests
+
+
+class InspectLoader(Loader):
+
+ def is_package(self, fullname):
+ return super().is_package(fullname)
+
+ def get_source(self, fullname):
+ return super().get_source(fullname)
+
+
+Frozen_IL, Source_IL = make_abc_subclasses(InspectLoader)
+
+
+class InspectLoaderDefaultsTests:
+
+ def test_is_package(self):
+ with self.assertRaises(ImportError):
+ self.ins.is_package('blah')
+
+ def test_get_source(self):
+ with self.assertRaises(ImportError):
+ self.ins.get_source('blah')
+
+
+tests = make_return_value_tests(InspectLoader, InspectLoaderDefaultsTests)
+Frozen_ILDefaultTests, Source_ILDefaultTests = tests
+
+
+class ExecutionLoader(InspectLoader):
+
+ def get_filename(self, fullname):
+ return super().get_filename(fullname)
+
+Frozen_EL, Source_EL = make_abc_subclasses(ExecutionLoader)
+
+
+class ExecutionLoaderDefaultsTests:
+
+ def test_get_filename(self):
+ with self.assertRaises(ImportError):
+ self.ins.get_filename('blah')
+
+
+tests = make_return_value_tests(ExecutionLoader, InspectLoaderDefaultsTests)
+Frozen_ELDefaultTests, Source_ELDefaultsTests = tests
+
+##### Loader concrete methods ##################################################
+class LoaderConcreteMethodTests:
+
+ def test_init_module_attrs(self):
+ loader = self.LoaderSubclass()
+ module = types.ModuleType('blah')
+ loader.init_module_attrs(module)
+ self.assertEqual(module.__loader__, loader)
+
+
+class Frozen_LoaderConcreateMethodTests(LoaderConcreteMethodTests, unittest.TestCase):
+ LoaderSubclass = Frozen_L
+
+class Source_LoaderConcreateMethodTests(LoaderConcreteMethodTests, unittest.TestCase):
+ LoaderSubclass = Source_L
+
+
+##### InspectLoader concrete methods ###########################################
+class InspectLoaderSourceToCodeTests:
+
+ def source_to_module(self, data, path=None):
+ """Help with source_to_code() tests."""
+ module = types.ModuleType('blah')
+ loader = self.InspectLoaderSubclass()
+ if path is None:
+ code = loader.source_to_code(data)
+ else:
+ code = loader.source_to_code(data, path)
+ exec(code, module.__dict__)
+ return module
+
+ def test_source_to_code_source(self):
+ # Since compile() can handle strings, so should source_to_code().
+ source = 'attr = 42'
+ module = self.source_to_module(source)
+ self.assertTrue(hasattr(module, 'attr'))
+ self.assertEqual(module.attr, 42)
+
+ def test_source_to_code_bytes(self):
+ # Since compile() can handle bytes, so should source_to_code().
+ source = b'attr = 42'
+ module = self.source_to_module(source)
+ self.assertTrue(hasattr(module, 'attr'))
+ self.assertEqual(module.attr, 42)
+
+ def test_source_to_code_path(self):
+ # Specifying a path should set it for the code object.
+ path = 'path/to/somewhere'
+ loader = self.InspectLoaderSubclass()
+ code = loader.source_to_code('', path)
+ self.assertEqual(code.co_filename, path)
+
+ def test_source_to_code_no_path(self):
+ # Not setting a path should still work and be set to <string> since that
+ # is a pre-existing practice as a default to compile().
+ loader = self.InspectLoaderSubclass()
+ code = loader.source_to_code('')
+ self.assertEqual(code.co_filename, '<string>')
+
+
+class Frozen_ILSourceToCodeTests(InspectLoaderSourceToCodeTests, unittest.TestCase):
+ InspectLoaderSubclass = Frozen_IL
+
+class Source_ILSourceToCodeTests(InspectLoaderSourceToCodeTests, unittest.TestCase):
+ InspectLoaderSubclass = Source_IL
+
+
+class InspectLoaderGetCodeTests:
+
+ def test_get_code(self):
+ # Test success.
+ module = types.ModuleType('blah')
+ with mock.patch.object(self.InspectLoaderSubclass, 'get_source') as mocked:
+ mocked.return_value = 'attr = 42'
+ loader = self.InspectLoaderSubclass()
+ code = loader.get_code('blah')
+ exec(code, module.__dict__)
+ self.assertEqual(module.attr, 42)
+
+ def test_get_code_source_is_None(self):
+ # If get_source() is None then this should be None.
+ with mock.patch.object(self.InspectLoaderSubclass, 'get_source') as mocked:
+ mocked.return_value = None
+ loader = self.InspectLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertIsNone(code)
+
+ def test_get_code_source_not_found(self):
+ # If there is no source then there is no code object.
+ loader = self.InspectLoaderSubclass()
+ with self.assertRaises(ImportError):
+ loader.get_code('blah')
+
+
+class Frozen_ILGetCodeTests(InspectLoaderGetCodeTests, unittest.TestCase):
+ InspectLoaderSubclass = Frozen_IL
+
+class Source_ILGetCodeTests(InspectLoaderGetCodeTests, unittest.TestCase):
+ InspectLoaderSubclass = Source_IL
+
+
+class InspectLoaderInitModuleTests:
+
+ def mock_is_package(self, return_value):
+ return mock.patch.object(self.InspectLoaderSubclass, 'is_package',
+ return_value=return_value)
+
+ def init_module_attrs(self, name):
+ loader = self.InspectLoaderSubclass()
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertEqual(module.__loader__, loader)
+ return module
+
+ def test_package(self):
+ # If a package, then __package__ == __name__, __path__ == []
+ with self.mock_is_package(True):
+ name = 'blah'
+ module = self.init_module_attrs(name)
+ self.assertEqual(module.__package__, name)
+ self.assertEqual(module.__path__, [])
+
+ def test_toplevel(self):
+ # If a module is top-level, __package__ == ''
+ with self.mock_is_package(False):
+ name = 'blah'
+ module = self.init_module_attrs(name)
+ self.assertEqual(module.__package__, '')
+
+ def test_submodule(self):
+ # If a module is contained within a package then set __package__ to the
+ # package name.
+ with self.mock_is_package(False):
+ name = 'pkg.mod'
+ module = self.init_module_attrs(name)
+ self.assertEqual(module.__package__, 'pkg')
+
+ def test_is_package_ImportError(self):
+ # If is_package() raises ImportError, __package__ should be None and
+ # __path__ should not be set.
+ with self.mock_is_package(False) as mocked_method:
+ mocked_method.side_effect = ImportError
+ name = 'mod'
+ module = self.init_module_attrs(name)
+ self.assertIsNone(module.__package__)
+ self.assertFalse(hasattr(module, '__path__'))
+
+
+class Frozen_ILInitModuleTests(InspectLoaderInitModuleTests, unittest.TestCase):
+ InspectLoaderSubclass = Frozen_IL
+
+class Source_ILInitModuleTests(InspectLoaderInitModuleTests, unittest.TestCase):
+ InspectLoaderSubclass = Source_IL
+
+
+class InspectLoaderLoadModuleTests:
+
+ """Test InspectLoader.load_module()."""
+
+ module_name = 'blah'
+
+ def setUp(self):
+ support.unload(self.module_name)
+ self.addCleanup(support.unload, self.module_name)
+
+ def mock_get_code(self):
+ return mock.patch.object(self.InspectLoaderSubclass, 'get_code')
+
+ def test_get_code_ImportError(self):
+ # If get_code() raises ImportError, it should propagate.
+ with self.mock_get_code() as mocked_get_code:
+ mocked_get_code.side_effect = ImportError
+ with self.assertRaises(ImportError):
+ loader = self.InspectLoaderSubclass()
+ loader.load_module(self.module_name)
+
+ def test_get_code_None(self):
+ # If get_code() returns None, raise ImportError.
+ with self.mock_get_code() as mocked_get_code:
+ mocked_get_code.return_value = None
+ with self.assertRaises(ImportError):
+ loader = self.InspectLoaderSubclass()
+ loader.load_module(self.module_name)
+
+ def test_module_returned(self):
+ # The loaded module should be returned.
+ code = compile('attr = 42', '<string>', 'exec')
+ with self.mock_get_code() as mocked_get_code:
+ mocked_get_code.return_value = code
+ loader = self.InspectLoaderSubclass()
+ module = loader.load_module(self.module_name)
+ self.assertEqual(module, sys.modules[self.module_name])
+
+
+class Frozen_ILLoadModuleTests(InspectLoaderLoadModuleTests, unittest.TestCase):
+ InspectLoaderSubclass = Frozen_IL
+
+class Source_ILLoadModuleTests(InspectLoaderLoadModuleTests, unittest.TestCase):
+ InspectLoaderSubclass = Source_IL
+
+
+##### ExecutionLoader concrete methods #########################################
+class ExecutionLoaderGetCodeTests:
+
+ def mock_methods(self, *, get_source=False, get_filename=False):
+ source_mock_context, filename_mock_context = None, None
+ if get_source:
+ source_mock_context = mock.patch.object(self.ExecutionLoaderSubclass,
+ 'get_source')
+ if get_filename:
+ filename_mock_context = mock.patch.object(self.ExecutionLoaderSubclass,
+ 'get_filename')
+ return source_mock_context, filename_mock_context
+
+ def test_get_code(self):
+ path = 'blah.py'
+ source_mock_context, filename_mock_context = self.mock_methods(
+ get_source=True, get_filename=True)
+ with source_mock_context as source_mock, filename_mock_context as name_mock:
+ source_mock.return_value = 'attr = 42'
+ name_mock.return_value = path
+ loader = self.ExecutionLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertEqual(code.co_filename, path)
+ module = types.ModuleType('blah')
+ exec(code, module.__dict__)
+ self.assertEqual(module.attr, 42)
+
+ def test_get_code_source_is_None(self):
+ # If get_source() is None then this should be None.
+ source_mock_context, _ = self.mock_methods(get_source=True)
+ with source_mock_context as mocked:
+ mocked.return_value = None
+ loader = self.ExecutionLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertIsNone(code)
+
+ def test_get_code_source_not_found(self):
+ # If there is no source then there is no code object.
+ loader = self.ExecutionLoaderSubclass()
+ with self.assertRaises(ImportError):
+ loader.get_code('blah')
+
+ def test_get_code_no_path(self):
+ # If get_filename() raises ImportError then simply skip setting the path
+ # on the code object.
+ source_mock_context, filename_mock_context = self.mock_methods(
+ get_source=True, get_filename=True)
+ with source_mock_context as source_mock, filename_mock_context as name_mock:
+ source_mock.return_value = 'attr = 42'
+ name_mock.side_effect = ImportError
+ loader = self.ExecutionLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertEqual(code.co_filename, '<string>')
+ module = types.ModuleType('blah')
+ exec(code, module.__dict__)
+ self.assertEqual(module.attr, 42)
+
+
+class Frozen_ELGetCodeTests(ExecutionLoaderGetCodeTests, unittest.TestCase):
+ ExecutionLoaderSubclass = Frozen_EL
+
+class Source_ELGetCodeTests(ExecutionLoaderGetCodeTests, unittest.TestCase):
+ ExecutionLoaderSubclass = Source_EL
+
+
+class ExecutionLoaderInitModuleTests:
+
+ def mock_is_package(self, return_value):
+ return mock.patch.object(self.ExecutionLoaderSubclass, 'is_package',
+ return_value=return_value)
+
+ @contextlib.contextmanager
+ def mock_methods(self, is_package, filename):
+ is_package_manager = self.mock_is_package(is_package)
+ get_filename_manager = mock.patch.object(self.ExecutionLoaderSubclass,
+ 'get_filename', return_value=filename)
+ with is_package_manager as mock_is_package:
+ with get_filename_manager as mock_get_filename:
+ yield {'is_package': mock_is_package,
+ 'get_filename': mock_get_filename}
+
+ def test_toplevel(self):
+ # Verify __loader__, __file__, and __package__; no __path__.
+ name = 'blah'
+ path = os.path.join('some', 'path', '{}.py'.format(name))
+ with self.mock_methods(False, path):
+ loader = self.ExecutionLoaderSubclass()
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertIs(module.__loader__, loader)
+ self.assertEqual(module.__file__, path)
+ self.assertEqual(module.__package__, '')
+ self.assertFalse(hasattr(module, '__path__'))
+
+ def test_package(self):
+ # Verify __loader__, __file__, __package__, and __path__.
+ name = 'pkg'
+ path = os.path.join('some', 'pkg', '__init__.py')
+ with self.mock_methods(True, path):
+ loader = self.ExecutionLoaderSubclass()
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertIs(module.__loader__, loader)
+ self.assertEqual(module.__file__, path)
+ self.assertEqual(module.__package__, 'pkg')
+ self.assertEqual(module.__path__, [os.path.dirname(path)])
+
+ def test_submodule(self):
+ # Verify __package__ and not __path__; test_toplevel() takes care of
+ # other attributes.
+ name = 'pkg.submodule'
+ path = os.path.join('some', 'pkg', 'submodule.py')
+ with self.mock_methods(False, path):
+ loader = self.ExecutionLoaderSubclass()
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertEqual(module.__package__, 'pkg')
+ self.assertEqual(module.__file__, path)
+ self.assertFalse(hasattr(module, '__path__'))
+
+ def test_get_filename_ImportError(self):
+ # If get_filename() raises ImportError, don't set __file__.
+ name = 'blah'
+ path = 'blah.py'
+ with self.mock_methods(False, path) as mocked_methods:
+ mocked_methods['get_filename'].side_effect = ImportError
+ loader = self.ExecutionLoaderSubclass()
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertFalse(hasattr(module, '__file__'))
+
+
+class Frozen_ELInitModuleTests(ExecutionLoaderInitModuleTests, unittest.TestCase):
+ ExecutionLoaderSubclass = Frozen_EL
+
+class Source_ELInitModuleTests(ExecutionLoaderInitModuleTests, unittest.TestCase):
+ ExecutionLoaderSubclass = Source_EL
+
+
+##### SourceLoader concrete methods ############################################
+class SourceLoader:
+
+ # Globals that should be defined for all modules.
+ source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, "
+ b"repr(__loader__)])")
+
+ def __init__(self, path):
+ self.path = path
+
+ def get_data(self, path):
+ if path != self.path:
+ raise IOError
+ return self.source
+
+ def get_filename(self, fullname):
+ return self.path
+
+ def module_repr(self, module):
+ return '<module>'
+
+
+Frozen_SourceOnlyL, Source_SourceOnlyL = make_abc_subclasses(SourceLoader)
+
+
+class SourceLoader(SourceLoader):
+
+ source_mtime = 1
+
+ def __init__(self, path, magic=None):
+ super().__init__(path)
+ self.bytecode_path = self.util.cache_from_source(self.path)
+ self.source_size = len(self.source)
+ if magic is None:
+ magic = self.util.MAGIC_NUMBER
+ data = bytearray(magic)
+ data.extend(self.init._w_long(self.source_mtime))
+ data.extend(self.init._w_long(self.source_size))
+ code_object = compile(self.source, self.path, 'exec',
+ dont_inherit=True)
+ data.extend(marshal.dumps(code_object))
+ self.bytecode = bytes(data)
+ self.written = {}
+
+ def get_data(self, path):
+ if path == self.path:
+ return super().get_data(path)
+ elif path == self.bytecode_path:
+ return self.bytecode
+ else:
+ raise OSError
+
+ def path_stats(self, path):
+ if path != self.path:
+ raise IOError
+ return {'mtime': self.source_mtime, 'size': self.source_size}
+
+ def set_data(self, path, data):
+ self.written[path] = bytes(data)
+ return path == self.bytecode_path
+
+
+Frozen_SL, Source_SL = make_abc_subclasses(SourceLoader)
+Frozen_SL.util = frozen_util
+Source_SL.util = source_util
+Frozen_SL.init = frozen_init
+Source_SL.init = source_init
+
+
+class SourceLoaderTestHarness:
+
+ def setUp(self, *, is_package=True, **kwargs):
+ self.package = 'pkg'
+ if is_package:
+ self.path = os.path.join(self.package, '__init__.py')
+ self.name = self.package
+ else:
+ module_name = 'mod'
+ self.path = os.path.join(self.package, '.'.join(['mod', 'py']))
+ self.name = '.'.join([self.package, module_name])
+ self.cached = self.util.cache_from_source(self.path)
+ self.loader = self.loader_mock(self.path, **kwargs)
+
+ def verify_module(self, module):
+ self.assertEqual(module.__name__, self.name)
+ self.assertEqual(module.__file__, self.path)
+ self.assertEqual(module.__cached__, self.cached)
+ self.assertEqual(module.__package__, self.package)
+ self.assertEqual(module.__loader__, self.loader)
+ values = module._.split('::')
+ self.assertEqual(values[0], self.name)
+ self.assertEqual(values[1], self.path)
+ self.assertEqual(values[2], self.cached)
+ self.assertEqual(values[3], self.package)
+ self.assertEqual(values[4], repr(self.loader))
+
+ def verify_code(self, code_object):
+ module = types.ModuleType(self.name)
+ module.__file__ = self.path
+ module.__cached__ = self.cached
+ module.__package__ = self.package
+ module.__loader__ = self.loader
+ module.__path__ = []
+ exec(code_object, module.__dict__)
+ self.verify_module(module)
+
+
+class SourceOnlyLoaderTests(SourceLoaderTestHarness):
+
+ """Test importlib.abc.SourceLoader for source-only loading.
+
+ Reload testing is subsumed by the tests for
+ importlib.util.module_for_loader.
+
+ """
+
+ def test_get_source(self):
+ # Verify the source code is returned as a string.
+ # If an OSError is raised by get_data then raise ImportError.
+ expected_source = self.loader.source.decode('utf-8')
+ self.assertEqual(self.loader.get_source(self.name), expected_source)
+ def raise_OSError(path):
+ raise OSError
+ self.loader.get_data = raise_OSError
+ with self.assertRaises(ImportError) as cm:
+ self.loader.get_source(self.name)
+ self.assertEqual(cm.exception.name, self.name)
+
+ def test_is_package(self):
+ # Properly detect when loading a package.
+ self.setUp(is_package=False)
+ self.assertFalse(self.loader.is_package(self.name))
+ self.setUp(is_package=True)
+ self.assertTrue(self.loader.is_package(self.name))
+ self.assertFalse(self.loader.is_package(self.name + '.__init__'))
+
+ def test_get_code(self):
+ # Verify the code object is created.
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+
+ def test_source_to_code(self):
+ # Verify the compiled code object.
+ code = self.loader.source_to_code(self.loader.source, self.path)
+ self.verify_code(code)
+
+ def test_load_module(self):
+ # Loading a module should set __name__, __loader__, __package__,
+ # __path__ (for packages), __file__, and __cached__.
+ # The module should also be put into sys.modules.
+ with util.uncache(self.name):
+ module = self.loader.load_module(self.name)
+ self.verify_module(module)
+ self.assertEqual(module.__path__, [os.path.dirname(self.path)])
+ self.assertIn(self.name, sys.modules)
+
+ def test_package_settings(self):
+ # __package__ needs to be set, while __path__ is set on if the module
+ # is a package.
+ # Testing the values for a package are covered by test_load_module.
+ self.setUp(is_package=False)
+ with util.uncache(self.name):
+ module = self.loader.load_module(self.name)
+ self.verify_module(module)
+ self.assertTrue(not hasattr(module, '__path__'))
+
+ def test_get_source_encoding(self):
+ # Source is considered encoded in UTF-8 by default unless otherwise
+ # specified by an encoding line.
+ source = "_ = 'ü'"
+ self.loader.source = source.encode('utf-8')
+ returned_source = self.loader.get_source(self.name)
+ self.assertEqual(returned_source, source)
+ source = "# coding: latin-1\n_ = ü"
+ self.loader.source = source.encode('latin-1')
+ returned_source = self.loader.get_source(self.name)
+ self.assertEqual(returned_source, source)
+
+
+class Frozen_SourceOnlyLTests(SourceOnlyLoaderTests, unittest.TestCase):
+ loader_mock = Frozen_SourceOnlyL
+ util = frozen_util
+
+class Source_SourceOnlyLTests(SourceOnlyLoaderTests, unittest.TestCase):
+ loader_mock = Source_SourceOnlyL
+ util = source_util
+
+
+@unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true")
+class SourceLoaderBytecodeTests(SourceLoaderTestHarness):
+
+ """Test importlib.abc.SourceLoader's use of bytecode.
+
+ Source-only testing handled by SourceOnlyLoaderTests.
+
+ """
+
+ def verify_code(self, code_object, *, bytecode_written=False):
+ super().verify_code(code_object)
+ if bytecode_written:
+ self.assertIn(self.cached, self.loader.written)
+ data = bytearray(self.util.MAGIC_NUMBER)
+ data.extend(self.init._w_long(self.loader.source_mtime))
+ data.extend(self.init._w_long(self.loader.source_size))
+ data.extend(marshal.dumps(code_object))
+ self.assertEqual(self.loader.written[self.cached], bytes(data))
+
+ def test_code_with_everything(self):
+ # When everything should work.
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+
+ def test_no_bytecode(self):
+ # If no bytecode exists then move on to the source.
+ self.loader.bytecode_path = "<does not exist>"
+ # Sanity check
+ with self.assertRaises(OSError):
+ bytecode_path = self.util.cache_from_source(self.path)
+ self.loader.get_data(bytecode_path)
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object, bytecode_written=True)
+
+ def test_code_bad_timestamp(self):
+ # Bytecode is only used when the timestamp matches the source EXACTLY.
+ for source_mtime in (0, 2):
+ assert source_mtime != self.loader.source_mtime
+ original = self.loader.source_mtime
+ self.loader.source_mtime = source_mtime
+ # If bytecode is used then EOFError would be raised by marshal.
+ self.loader.bytecode = self.loader.bytecode[8:]
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object, bytecode_written=True)
+ self.loader.source_mtime = original
+
+ def test_code_bad_magic(self):
+ # Skip over bytecode with a bad magic number.
+ self.setUp(magic=b'0000')
+ # If bytecode is used then EOFError would be raised by marshal.
+ self.loader.bytecode = self.loader.bytecode[8:]
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object, bytecode_written=True)
+
+ def test_dont_write_bytecode(self):
+ # Bytecode is not written if sys.dont_write_bytecode is true.
+ # Can assume it is false already thanks to the skipIf class decorator.
+ try:
+ sys.dont_write_bytecode = True
+ self.loader.bytecode_path = "<does not exist>"
+ code_object = self.loader.get_code(self.name)
+ self.assertNotIn(self.cached, self.loader.written)
+ finally:
+ sys.dont_write_bytecode = False
+
+ def test_no_set_data(self):
+ # If set_data is not defined, one can still read bytecode.
+ self.setUp(magic=b'0000')
+ original_set_data = self.loader.__class__.mro()[1].set_data
+ try:
+ del self.loader.__class__.mro()[1].set_data
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+ finally:
+ self.loader.__class__.mro()[1].set_data = original_set_data
+
+ def test_set_data_raises_exceptions(self):
+ # Raising NotImplementedError or OSError is okay for set_data.
+ def raise_exception(exc):
+ def closure(*args, **kwargs):
+ raise exc
+ return closure
+
+ self.setUp(magic=b'0000')
+ self.loader.set_data = raise_exception(NotImplementedError)
+ code_object = self.loader.get_code(self.name)
+ self.verify_code(code_object)
+
+
+class Frozen_SLBytecodeTests(SourceLoaderBytecodeTests, unittest.TestCase):
+ loader_mock = Frozen_SL
+ init = frozen_init
+ util = frozen_util
+
+class SourceSLBytecodeTests(SourceLoaderBytecodeTests, unittest.TestCase):
+ loader_mock = Source_SL
+ init = source_init
+ util = source_util
+
+
+class SourceLoaderGetSourceTests:
+
+ """Tests for importlib.abc.SourceLoader.get_source()."""
+
+ def test_default_encoding(self):
+ # Should have no problems with UTF-8 text.
+ name = 'mod'
+ mock = self.SourceOnlyLoaderMock('mod.file')
+ source = 'x = "ü"'
+ mock.source = source.encode('utf-8')
+ returned_source = mock.get_source(name)
+ self.assertEqual(returned_source, source)
+
+ def test_decoded_source(self):
+ # Decoding should work.
+ name = 'mod'
+ mock = self.SourceOnlyLoaderMock("mod.file")
+ source = "# coding: Latin-1\nx='ü'"
+ assert source.encode('latin-1') != source.encode('utf-8')
+ mock.source = source.encode('latin-1')
+ returned_source = mock.get_source(name)
+ self.assertEqual(returned_source, source)
+
+ def test_universal_newlines(self):
+ # PEP 302 says universal newlines should be used.
+ name = 'mod'
+ mock = self.SourceOnlyLoaderMock('mod.file')
+ source = "x = 42\r\ny = -13\r\n"
+ mock.source = source.encode('utf-8')
+ expect = io.IncrementalNewlineDecoder(None, True).decode(source)
+ self.assertEqual(mock.get_source(name), expect)
+
+
+class Frozen_SourceOnlyLGetSourceTests(SourceLoaderGetSourceTests, unittest.TestCase):
+ SourceOnlyLoaderMock = Frozen_SourceOnlyL
+
+class Source_SourceOnlyLGetSourceTests(SourceLoaderGetSourceTests, unittest.TestCase):
+ SourceOnlyLoaderMock = Source_SourceOnlyL
+
+
+class SourceLoaderInitModuleAttrTests:
+
+ """Tests for importlib.abc.SourceLoader.init_module_attrs()."""
+
+ def test_init_module_attrs(self):
+ # If __file__ set, __cached__ == importlib.util.cached_from_source(__file__).
+ name = 'blah'
+ path = 'blah.py'
+ loader = self.SourceOnlyLoaderMock(path)
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertEqual(module.__loader__, loader)
+ self.assertEqual(module.__package__, '')
+ self.assertEqual(module.__file__, path)
+ self.assertEqual(module.__cached__, self.util.cache_from_source(path))
+
+ def test_no_get_filename(self):
+ # No __file__, no __cached__.
+ with mock.patch.object(self.SourceOnlyLoaderMock, 'get_filename') as mocked:
+ mocked.side_effect = ImportError
+ name = 'blah'
+ loader = self.SourceOnlyLoaderMock('blah.py')
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertFalse(hasattr(module, '__file__'))
+ self.assertFalse(hasattr(module, '__cached__'))
+
+
+class Frozen_SLInitModuleAttrTests(SourceLoaderInitModuleAttrTests, unittest.TestCase):
+ SourceOnlyLoaderMock = Frozen_SourceOnlyL
+ util = frozen_util
+
+ # Difficult to test under source thanks to cross-module mocking needs.
+ @mock.patch('importlib._bootstrap.cache_from_source')
+ def test_cache_from_source_NotImplementedError(self, mock_cache_from_source):
+ # If importlib.util.cache_from_source() raises NotImplementedError don't set
+ # __cached__.
+ mock_cache_from_source.side_effect = NotImplementedError
+ name = 'blah'
+ path = 'blah.py'
+ loader = self.SourceOnlyLoaderMock(path)
+ module = types.ModuleType(name)
+ loader.init_module_attrs(module)
+ self.assertEqual(module.__file__, path)
+ self.assertFalse(hasattr(module, '__cached__'))
+
+
+class Source_SLInitModuleAttrTests(SourceLoaderInitModuleAttrTests, unittest.TestCase):
+ SourceOnlyLoaderMock = Source_SourceOnlyL
+ util = source_util
-def test_main():
- from test.support import run_unittest
- classes = []
- for class_ in globals().values():
- if (inspect.isclass(class_) and
- issubclass(class_, unittest.TestCase) and
- issubclass(class_, InheritanceTests)):
- classes.append(class_)
- run_unittest(*classes)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py
index b1a58944f3..51266345bc 100644
--- a/Lib/test/test_importlib/test_api.py
+++ b/Lib/test/test_importlib/test_api.py
@@ -1,14 +1,17 @@
from . import util
-import imp
-import importlib
-from importlib import machinery
+
+frozen_init, source_init = util.import_importlib('importlib')
+frozen_util, source_util = util.import_importlib('importlib.util')
+frozen_machinery, source_machinery = util.import_importlib('importlib.machinery')
+
+import os.path
import sys
from test import support
import types
import unittest
-class ImportModuleTests(unittest.TestCase):
+class ImportModuleTests:
"""Test importlib.import_module."""
@@ -16,7 +19,7 @@ class ImportModuleTests(unittest.TestCase):
# Test importing a top-level module.
with util.mock_modules('top_level') as mock:
with util.import_state(meta_path=[mock]):
- module = importlib.import_module('top_level')
+ module = self.init.import_module('top_level')
self.assertEqual(module.__name__, 'top_level')
def test_absolute_package_import(self):
@@ -26,7 +29,7 @@ class ImportModuleTests(unittest.TestCase):
name = '{0}.mod'.format(pkg_name)
with util.mock_modules(pkg_long_name, name) as mock:
with util.import_state(meta_path=[mock]):
- module = importlib.import_module(name)
+ module = self.init.import_module(name)
self.assertEqual(module.__name__, name)
def test_shallow_relative_package_import(self):
@@ -38,17 +41,17 @@ class ImportModuleTests(unittest.TestCase):
relative_name = '.{0}'.format(module_name)
with util.mock_modules(pkg_long_name, absolute_name) as mock:
with util.import_state(meta_path=[mock]):
- importlib.import_module(pkg_name)
- module = importlib.import_module(relative_name, pkg_name)
+ self.init.import_module(pkg_name)
+ module = self.init.import_module(relative_name, pkg_name)
self.assertEqual(module.__name__, absolute_name)
def test_deep_relative_package_import(self):
modules = ['a.__init__', 'a.b.__init__', 'a.c']
with util.mock_modules(*modules) as mock:
with util.import_state(meta_path=[mock]):
- importlib.import_module('a')
- importlib.import_module('a.b')
- module = importlib.import_module('..c', 'a.b')
+ self.init.import_module('a')
+ self.init.import_module('a.b')
+ module = self.init.import_module('..c', 'a.b')
self.assertEqual(module.__name__, 'a.c')
def test_absolute_import_with_package(self):
@@ -59,15 +62,15 @@ class ImportModuleTests(unittest.TestCase):
name = '{0}.mod'.format(pkg_name)
with util.mock_modules(pkg_long_name, name) as mock:
with util.import_state(meta_path=[mock]):
- importlib.import_module(pkg_name)
- module = importlib.import_module(name, pkg_name)
+ self.init.import_module(pkg_name)
+ module = self.init.import_module(name, pkg_name)
self.assertEqual(module.__name__, name)
def test_relative_import_wo_package(self):
# Relative imports cannot happen without the 'package' argument being
# set.
with self.assertRaises(TypeError):
- importlib.import_module('.support')
+ self.init.import_module('.support')
def test_loaded_once(self):
@@ -76,7 +79,7 @@ class ImportModuleTests(unittest.TestCase):
# module currently being imported.
b_load_count = 0
def load_a():
- importlib.import_module('a.b')
+ self.init.import_module('a.b')
def load_b():
nonlocal b_load_count
b_load_count += 1
@@ -84,11 +87,17 @@ class ImportModuleTests(unittest.TestCase):
modules = ['a.__init__', 'a.b']
with util.mock_modules(*modules, module_code=code) as mock:
with util.import_state(meta_path=[mock]):
- importlib.import_module('a.b')
+ self.init.import_module('a.b')
self.assertEqual(b_load_count, 1)
+class Frozen_ImportModuleTests(ImportModuleTests, unittest.TestCase):
+ init = frozen_init
+
+class Source_ImportModuleTests(ImportModuleTests, unittest.TestCase):
+ init = source_init
-class FindLoaderTests(unittest.TestCase):
+
+class FindLoaderTests:
class FakeMetaFinder:
@staticmethod
@@ -98,29 +107,43 @@ class FindLoaderTests(unittest.TestCase):
# If a module with __loader__ is in sys.modules, then return it.
name = 'some_mod'
with util.uncache(name):
- module = imp.new_module(name)
+ module = types.ModuleType(name)
loader = 'a loader!'
module.__loader__ = loader
sys.modules[name] = module
- found = importlib.find_loader(name)
+ found = self.init.find_loader(name)
self.assertEqual(loader, found)
def test_sys_modules_loader_is_None(self):
# If sys.modules[name].__loader__ is None, raise ValueError.
name = 'some_mod'
with util.uncache(name):
- module = imp.new_module(name)
+ module = types.ModuleType(name)
module.__loader__ = None
sys.modules[name] = module
with self.assertRaises(ValueError):
- importlib.find_loader(name)
+ self.init.find_loader(name)
+
+ def test_sys_modules_loader_is_not_set(self):
+ # Should raise ValueError
+ # Issue #17099
+ name = 'some_mod'
+ with util.uncache(name):
+ module = types.ModuleType(name)
+ try:
+ del module.__loader__
+ except AttributeError:
+ pass
+ sys.modules[name] = module
+ with self.assertRaises(ValueError):
+ self.init.find_loader(name)
def test_success(self):
# Return the loader found on sys.meta_path.
name = 'some_mod'
with util.uncache(name):
with util.import_state(meta_path=[self.FakeMetaFinder]):
- self.assertEqual((name, None), importlib.find_loader(name))
+ self.assertEqual((name, None), self.init.find_loader(name))
def test_success_path(self):
# Searching on a path should work.
@@ -129,14 +152,172 @@ class FindLoaderTests(unittest.TestCase):
with util.uncache(name):
with util.import_state(meta_path=[self.FakeMetaFinder]):
self.assertEqual((name, path),
- importlib.find_loader(name, path))
+ self.init.find_loader(name, path))
def test_nothing(self):
# None is returned upon failure to find a loader.
- self.assertIsNone(importlib.find_loader('nevergoingtofindthismodule'))
+ self.assertIsNone(self.init.find_loader('nevergoingtofindthismodule'))
+
+class Frozen_FindLoaderTests(FindLoaderTests, unittest.TestCase):
+ init = frozen_init
+class Source_FindLoaderTests(FindLoaderTests, unittest.TestCase):
+ init = source_init
-class InvalidateCacheTests(unittest.TestCase):
+
+class ReloadTests:
+
+ """Test module reloading for builtin and extension modules."""
+
+ def test_reload_modules(self):
+ for mod in ('tokenize', 'time', 'marshal'):
+ with self.subTest(module=mod):
+ with support.CleanImport(mod):
+ module = self.init.import_module(mod)
+ self.init.reload(module)
+
+ def test_module_replaced(self):
+ def code():
+ import sys
+ module = type(sys)('top_level')
+ module.spam = 3
+ sys.modules['top_level'] = module
+ mock = util.mock_modules('top_level',
+ module_code={'top_level': code})
+ with mock:
+ with util.import_state(meta_path=[mock]):
+ module = self.init.import_module('top_level')
+ reloaded = self.init.reload(module)
+ actual = sys.modules['top_level']
+ self.assertEqual(actual.spam, 3)
+ self.assertEqual(reloaded.spam, 3)
+
+ def test_reload_missing_loader(self):
+ with support.CleanImport('types'):
+ import types
+ loader = types.__loader__
+ del types.__loader__
+ reloaded = self.init.reload(types)
+
+ self.assertIs(reloaded, types)
+ self.assertIs(sys.modules['types'], types)
+ self.assertEqual(reloaded.__loader__.path, loader.path)
+
+ def test_reload_loader_replaced(self):
+ with support.CleanImport('types'):
+ import types
+ types.__loader__ = None
+ self.init.invalidate_caches()
+ reloaded = self.init.reload(types)
+
+ self.assertIsNot(reloaded.__loader__, None)
+ self.assertIs(reloaded, types)
+ self.assertIs(sys.modules['types'], types)
+
+ def test_reload_location_changed(self):
+ name = 'spam'
+ with support.temp_cwd(None) as cwd:
+ with util.uncache('spam'):
+ with support.DirsOnSysPath(cwd):
+ self.init.invalidate_caches()
+ path = os.path.join(cwd, name + '.py')
+ cached = self.util.cache_from_source(path)
+ expected = {'__name__': name,
+ '__package__': '',
+ '__file__': path,
+ '__cached__': cached,
+ '__doc__': None,
+ '__builtins__': __builtins__,
+ }
+ support.create_empty_file(path)
+ module = self.init.import_module(name)
+ ns = vars(module)
+ del ns['__initializing__']
+ loader = ns.pop('__loader__')
+ self.assertEqual(loader.path, path)
+ self.assertEqual(ns, expected)
+
+ self.init.invalidate_caches()
+ init_path = os.path.join(cwd, name, '__init__.py')
+ cached = self.util.cache_from_source(init_path)
+ expected = {'__name__': name,
+ '__package__': name,
+ '__file__': init_path,
+ '__cached__': cached,
+ '__path__': [os.path.dirname(init_path)],
+ '__doc__': None,
+ '__builtins__': __builtins__,
+ }
+ os.mkdir(name)
+ os.rename(path, init_path)
+ reloaded = self.init.reload(module)
+ ns = vars(reloaded)
+ del ns['__initializing__']
+ loader = ns.pop('__loader__')
+ self.assertIs(reloaded, module)
+ self.assertEqual(loader.path, init_path)
+ self.assertEqual(ns, expected)
+
+ def test_reload_namespace_changed(self):
+ self.maxDiff = None
+ name = 'spam'
+ with support.temp_cwd(None) as cwd:
+ with util.uncache('spam'):
+ with support.DirsOnSysPath(cwd):
+ self.init.invalidate_caches()
+ bad_path = os.path.join(cwd, name, '__init.py')
+ cached = self.util.cache_from_source(bad_path)
+ expected = {'__name__': name,
+ '__package__': name,
+ '__doc__': None,
+ }
+ os.mkdir(name)
+ with open(bad_path, 'w') as init_file:
+ init_file.write('eggs = None')
+ module = self.init.import_module(name)
+ ns = vars(module)
+ del ns['__initializing__']
+ loader = ns.pop('__loader__')
+ path = ns.pop('__path__')
+ self.assertEqual(set(path),
+ set([os.path.dirname(bad_path)]))
+ with self.assertRaises(AttributeError):
+ # a NamespaceLoader
+ loader.path
+ self.assertEqual(ns, expected)
+
+ self.init.invalidate_caches()
+ init_path = os.path.join(cwd, name, '__init__.py')
+ cached = self.util.cache_from_source(init_path)
+ expected = {'__name__': name,
+ '__package__': name,
+ '__file__': init_path,
+ '__cached__': cached,
+ '__path__': [os.path.dirname(init_path)],
+ '__doc__': None,
+ '__builtins__': __builtins__,
+ 'eggs': None,
+ }
+ os.rename(bad_path, init_path)
+ reloaded = self.init.reload(module)
+ ns = vars(reloaded)
+ del ns['__initializing__']
+ loader = ns.pop('__loader__')
+ self.assertIs(reloaded, module)
+ self.assertEqual(loader.path, init_path)
+ self.assertEqual(ns, expected)
+
+
+class Frozen_ReloadTests(ReloadTests, unittest.TestCase):
+ init = frozen_init
+ util = frozen_util
+
+class Source_ReloadTests(ReloadTests, unittest.TestCase):
+ init = source_init
+ util = source_util
+
+
+class InvalidateCacheTests:
def test_method_called(self):
# If defined the method should be called.
@@ -155,48 +336,54 @@ class InvalidateCacheTests(unittest.TestCase):
self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key))
sys.path_importer_cache[key] = path_ins
self.addCleanup(lambda: sys.meta_path.remove(meta_ins))
- importlib.invalidate_caches()
+ self.init.invalidate_caches()
self.assertTrue(meta_ins.called)
self.assertTrue(path_ins.called)
def test_method_lacking(self):
# There should be no issues if the method is not defined.
key = 'gobbledeegook'
- sys.path_importer_cache[key] = imp.NullImporter('abc')
+ sys.path_importer_cache[key] = None
self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key))
- importlib.invalidate_caches() # Shouldn't trigger an exception.
+ self.init.invalidate_caches() # Shouldn't trigger an exception.
+
+class Frozen_InvalidateCacheTests(InvalidateCacheTests, unittest.TestCase):
+ init = frozen_init
+
+class Source_InvalidateCacheTests(InvalidateCacheTests, unittest.TestCase):
+ init = source_init
class FrozenImportlibTests(unittest.TestCase):
def test_no_frozen_importlib(self):
# Should be able to import w/o _frozen_importlib being defined.
- module = support.import_fresh_module('importlib', blocked=['_frozen_importlib'])
- self.assertFalse(isinstance(module.__loader__,
- machinery.FrozenImporter))
+ # Can't do an isinstance() check since separate copies of importlib
+ # may have been used for import, so just check the name is not for the
+ # frozen loader.
+ self.assertNotEqual(source_init.__loader__.__class__.__name__,
+ 'FrozenImporter')
-class StartupTests(unittest.TestCase):
+class StartupTests:
def test_everyone_has___loader__(self):
# Issue #17098: all modules should have __loader__ defined.
for name, module in sys.modules.items():
if isinstance(module, types.ModuleType):
- if name in sys.builtin_module_names:
- self.assertEqual(importlib.machinery.BuiltinImporter,
- module.__loader__)
- elif imp.is_frozen(name):
- self.assertEqual(importlib.machinery.FrozenImporter,
- module.__loader__)
-
-def test_main():
- from test.support import run_unittest
- run_unittest(ImportModuleTests,
- FindLoaderTests,
- InvalidateCacheTests,
- FrozenImportlibTests,
- StartupTests)
+ self.assertTrue(hasattr(module, '__loader__'),
+ '{!r} lacks a __loader__ attribute'.format(name))
+ if self.machinery.BuiltinImporter.find_module(name):
+ self.assertIsNot(module.__loader__, None)
+ elif self.machinery.FrozenImporter.find_module(name):
+ self.assertIsNot(module.__loader__, None)
+
+class Frozen_StartupTests(StartupTests, unittest.TestCase):
+ machinery = frozen_machinery
+
+class Source_StartupTests(StartupTests, unittest.TestCase):
+ machinery = source_machinery
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py
index c373b11256..dc97ba1567 100644
--- a/Lib/test/test_importlib/test_locks.py
+++ b/Lib/test/test_importlib/test_locks.py
@@ -1,4 +1,8 @@
-from importlib import _bootstrap
+from . import util
+frozen_init, source_init = util.import_importlib('importlib')
+frozen_bootstrap = frozen_init._bootstrap
+source_bootstrap = source_init._bootstrap
+
import sys
import time
import unittest
@@ -13,14 +17,9 @@ except ImportError:
else:
from test import lock_tests
-
-LockType = _bootstrap._ModuleLock
-DeadlockError = _bootstrap._DeadlockError
-
-
if threading is not None:
- class ModuleLockAsRLockTests(lock_tests.RLockTests):
- locktype = staticmethod(lambda: LockType("some_lock"))
+ class ModuleLockAsRLockTests:
+ locktype = classmethod(lambda cls: cls.LockType("some_lock"))
# _is_owned() unsupported
test__is_owned = None
@@ -34,13 +33,21 @@ if threading is not None:
# _release_save() unsupported
test_release_save_unacquired = None
+ class Frozen_ModuleLockAsRLockTests(ModuleLockAsRLockTests, lock_tests.RLockTests):
+ LockType = frozen_bootstrap._ModuleLock
+
+ class Source_ModuleLockAsRLockTests(ModuleLockAsRLockTests, lock_tests.RLockTests):
+ LockType = source_bootstrap._ModuleLock
+
else:
- class ModuleLockAsRLockTests(unittest.TestCase):
+ class Frozen_ModuleLockAsRLockTests(unittest.TestCase):
pass
+ class Source_ModuleLockAsRLockTests(unittest.TestCase):
+ pass
-@unittest.skipUnless(threading, "threads needed for this test")
-class DeadlockAvoidanceTests(unittest.TestCase):
+
+class DeadlockAvoidanceTests:
def setUp(self):
try:
@@ -55,7 +62,7 @@ class DeadlockAvoidanceTests(unittest.TestCase):
def run_deadlock_avoidance_test(self, create_deadlock):
NLOCKS = 10
- locks = [LockType(str(i)) for i in range(NLOCKS)]
+ locks = [self.LockType(str(i)) for i in range(NLOCKS)]
pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)]
if create_deadlock:
NTHREADS = NLOCKS
@@ -67,7 +74,7 @@ class DeadlockAvoidanceTests(unittest.TestCase):
"""Try to acquire the lock. Return True on success, False on deadlock."""
try:
lock.acquire()
- except DeadlockError:
+ except self.DeadlockError:
return False
else:
return True
@@ -99,30 +106,50 @@ class DeadlockAvoidanceTests(unittest.TestCase):
self.assertEqual(results.count((True, False)), 0)
self.assertEqual(results.count((True, True)), len(results))
+@unittest.skipUnless(threading, "threads needed for this test")
+class Frozen_DeadlockAvoidanceTests(DeadlockAvoidanceTests, unittest.TestCase):
+ LockType = frozen_bootstrap._ModuleLock
+ DeadlockError = frozen_bootstrap._DeadlockError
+
+@unittest.skipUnless(threading, "threads needed for this test")
+class Source_DeadlockAvoidanceTests(DeadlockAvoidanceTests, unittest.TestCase):
+ LockType = source_bootstrap._ModuleLock
+ DeadlockError = source_bootstrap._DeadlockError
-class LifetimeTests(unittest.TestCase):
+
+class LifetimeTests:
def test_lock_lifetime(self):
name = "xyzzy"
- self.assertNotIn(name, _bootstrap._module_locks)
- lock = _bootstrap._get_module_lock(name)
- self.assertIn(name, _bootstrap._module_locks)
+ self.assertNotIn(name, self.bootstrap._module_locks)
+ lock = self.bootstrap._get_module_lock(name)
+ self.assertIn(name, self.bootstrap._module_locks)
wr = weakref.ref(lock)
del lock
support.gc_collect()
- self.assertNotIn(name, _bootstrap._module_locks)
+ self.assertNotIn(name, self.bootstrap._module_locks)
self.assertIsNone(wr())
def test_all_locks(self):
support.gc_collect()
- self.assertEqual(0, len(_bootstrap._module_locks), _bootstrap._module_locks)
+ self.assertEqual(0, len(self.bootstrap._module_locks),
+ self.bootstrap._module_locks)
+
+class Frozen_LifetimeTests(LifetimeTests, unittest.TestCase):
+ bootstrap = frozen_bootstrap
+
+class Source_LifetimeTests(LifetimeTests, unittest.TestCase):
+ bootstrap = source_bootstrap
@support.reap_threads
def test_main():
- support.run_unittest(ModuleLockAsRLockTests,
- DeadlockAvoidanceTests,
- LifetimeTests)
+ support.run_unittest(Frozen_ModuleLockAsRLockTests,
+ Source_ModuleLockAsRLockTests,
+ Frozen_DeadlockAvoidanceTests,
+ Source_DeadlockAvoidanceTests,
+ Frozen_LifetimeTests,
+ Source_LifetimeTests)
if __name__ == '__main__':
diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py
index efc8977fb4..2ac57df3ac 100644
--- a/Lib/test/test_importlib/test_util.py
+++ b/Lib/test/test_importlib/test_util.py
@@ -1,23 +1,129 @@
from importlib import util
from . import util as test_util
-import imp
+frozen_util, source_util = test_util.import_importlib('importlib.util')
+
+import os
import sys
+from test import support
import types
import unittest
+import warnings
+
+
+class DecodeSourceBytesTests:
+
+ source = "string ='ü'"
+
+ def test_ut8_default(self):
+ source_bytes = self.source.encode('utf-8')
+ self.assertEqual(self.util.decode_source(source_bytes), self.source)
+
+ def test_specified_encoding(self):
+ source = '# coding=latin-1\n' + self.source
+ source_bytes = source.encode('latin-1')
+ assert source_bytes != source.encode('utf-8')
+ self.assertEqual(self.util.decode_source(source_bytes), source)
+
+ def test_universal_newlines(self):
+ source = '\r\n'.join([self.source, self.source])
+ source_bytes = source.encode('utf-8')
+ self.assertEqual(self.util.decode_source(source_bytes),
+ '\n'.join([self.source, self.source]))
+
+Frozen_DecodeSourceBytesTests, Source_DecodeSourceBytesTests = test_util.test_both(
+ DecodeSourceBytesTests, util=[frozen_util, source_util])
+
+
+class ModuleToLoadTests:
+ module_name = 'ModuleManagerTest_module'
-class ModuleForLoaderTests(unittest.TestCase):
+ def setUp(self):
+ support.unload(self.module_name)
+ self.addCleanup(support.unload, self.module_name)
+
+ def test_new_module(self):
+ # Test a new module is created, inserted into sys.modules, has
+ # __initializing__ set to True after entering the context manager,
+ # and __initializing__ set to False after exiting.
+ with self.util.module_to_load(self.module_name) as module:
+ self.assertIn(self.module_name, sys.modules)
+ self.assertIs(sys.modules[self.module_name], module)
+ self.assertTrue(module.__initializing__)
+ self.assertFalse(module.__initializing__)
+
+ def test_new_module_failed(self):
+ # Test the module is removed from sys.modules.
+ try:
+ with self.util.module_to_load(self.module_name) as module:
+ self.assertIn(self.module_name, sys.modules)
+ raise exception
+ except Exception:
+ self.assertNotIn(self.module_name, sys.modules)
+ else:
+ self.fail('importlib.util.module_to_load swallowed an exception')
+
+ def test_reload(self):
+ # Test that the same module is in sys.modules.
+ created_module = types.ModuleType(self.module_name)
+ sys.modules[self.module_name] = created_module
+ with self.util.module_to_load(self.module_name) as module:
+ self.assertIs(module, created_module)
+
+ def test_reload_failed(self):
+ # Test that the module was left in sys.modules.
+ created_module = types.ModuleType(self.module_name)
+ sys.modules[self.module_name] = created_module
+ try:
+ with self.util.module_to_load(self.module_name) as module:
+ raise Exception
+ except Exception:
+ self.assertIn(self.module_name, sys.modules)
+ else:
+ self.fail('importlib.util.module_to_load swallowed an exception')
+
+ def test_reset_name(self):
+ # If reset_name is true then module.__name__ = name, else leave it be.
+ odd_name = 'not your typical name'
+ created_module = types.ModuleType(self.module_name)
+ created_module.__name__ = odd_name
+ sys.modules[self.module_name] = created_module
+ with self.util.module_to_load(self.module_name) as module:
+ self.assertEqual(module.__name__, self.module_name)
+ created_module.__name__ = odd_name
+ with self.util.module_to_load(self.module_name, reset_name=False) as module:
+ self.assertEqual(module.__name__, odd_name)
+
+Frozen_ModuleToLoadTests, Source_ModuleToLoadTests = test_util.test_both(
+ ModuleToLoadTests,
+ util=[frozen_util, source_util])
+
+
+class ModuleForLoaderTests:
"""Tests for importlib.util.module_for_loader."""
+ @classmethod
+ def module_for_loader(cls, func):
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', PendingDeprecationWarning)
+ return cls.util.module_for_loader(func)
+
+ def test_warning(self):
+ # Should raise a PendingDeprecationWarning when used.
+ with warnings.catch_warnings():
+ warnings.simplefilter('error', PendingDeprecationWarning)
+ with self.assertRaises(PendingDeprecationWarning):
+ func = self.util.module_for_loader(lambda x: x)
+
def return_module(self, name):
- fxn = util.module_for_loader(lambda self, module: module)
+ fxn = self.module_for_loader(lambda self, module: module)
return fxn(self, name)
def raise_exception(self, name):
def to_wrap(self, module):
raise ImportError
- fxn = util.module_for_loader(to_wrap)
+ fxn = self.module_for_loader(to_wrap)
try:
fxn(self, name)
except ImportError:
@@ -35,12 +141,23 @@ class ModuleForLoaderTests(unittest.TestCase):
def test_reload(self):
# Test that a module is reused if already in sys.modules.
+ class FakeLoader:
+ def is_package(self, name):
+ return True
+ @self.module_for_loader
+ def load_module(self, module):
+ return module
name = 'a.b.c'
- module = imp.new_module('a.b.c')
+ module = types.ModuleType('a.b.c')
+ module.__loader__ = 42
+ module.__package__ = 42
with test_util.uncache(name):
sys.modules[name] = module
- returned_module = self.return_module(name)
+ loader = FakeLoader()
+ returned_module = loader.load_module(name)
self.assertIs(returned_module, sys.modules[name])
+ self.assertEqual(module.__loader__, loader)
+ self.assertEqual(module.__package__, name)
def test_new_module_failure(self):
# Test that a module is removed from sys.modules if added but an
@@ -53,7 +170,7 @@ class ModuleForLoaderTests(unittest.TestCase):
def test_reload_failure(self):
# Test that a failure on reload leaves the module in-place.
name = 'a.b.c'
- module = imp.new_module(name)
+ module = types.ModuleType(name)
with test_util.uncache(name):
sys.modules[name] = module
self.raise_exception(name)
@@ -61,7 +178,7 @@ class ModuleForLoaderTests(unittest.TestCase):
def test_decorator_attrs(self):
def fxn(self, module): pass
- wrapped = util.module_for_loader(fxn)
+ wrapped = self.module_for_loader(fxn)
self.assertEqual(wrapped.__name__, fxn.__name__)
self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
@@ -87,7 +204,7 @@ class ModuleForLoaderTests(unittest.TestCase):
self._pkg = is_package
def is_package(self, name):
return self._pkg
- @util.module_for_loader
+ @self.module_for_loader
def load_module(self, module):
return module
@@ -107,8 +224,11 @@ class ModuleForLoaderTests(unittest.TestCase):
self.assertIs(module.__loader__, loader)
self.assertEqual(module.__package__, name)
+Frozen_ModuleForLoaderTests, Source_ModuleForLoaderTests = test_util.test_both(
+ ModuleForLoaderTests, util=[frozen_util, source_util])
+
-class SetPackageTests(unittest.TestCase):
+class SetPackageTests:
"""Tests for importlib.util.set_package."""
@@ -116,7 +236,7 @@ class SetPackageTests(unittest.TestCase):
"""Verify the module has the expected value for __package__ after
passing through set_package."""
fxn = lambda: module
- wrapped = util.set_package(fxn)
+ wrapped = self.util.set_package(fxn)
wrapped()
self.assertTrue(hasattr(module, '__package__'))
self.assertEqual(expect, module.__package__)
@@ -124,26 +244,26 @@ class SetPackageTests(unittest.TestCase):
def test_top_level(self):
# __package__ should be set to the empty string if a top-level module.
# Implicitly tests when package is set to None.
- module = imp.new_module('module')
+ module = types.ModuleType('module')
module.__package__ = None
self.verify(module, '')
def test_package(self):
# Test setting __package__ for a package.
- module = imp.new_module('pkg')
+ module = types.ModuleType('pkg')
module.__path__ = ['<path>']
module.__package__ = None
self.verify(module, 'pkg')
def test_submodule(self):
# Test __package__ for a module in a package.
- module = imp.new_module('pkg.mod')
+ module = types.ModuleType('pkg.mod')
module.__package__ = None
self.verify(module, 'pkg')
def test_setting_if_missing(self):
# __package__ should be set if it is missing.
- module = imp.new_module('mod')
+ module = types.ModuleType('mod')
if hasattr(module, '__package__'):
delattr(module, '__package__')
self.verify(module, '')
@@ -151,58 +271,230 @@ class SetPackageTests(unittest.TestCase):
def test_leaving_alone(self):
# If __package__ is set and not None then leave it alone.
for value in (True, False):
- module = imp.new_module('mod')
+ module = types.ModuleType('mod')
module.__package__ = value
self.verify(module, value)
def test_decorator_attrs(self):
def fxn(module): pass
- wrapped = util.set_package(fxn)
+ wrapped = self.util.set_package(fxn)
self.assertEqual(wrapped.__name__, fxn.__name__)
self.assertEqual(wrapped.__qualname__, fxn.__qualname__)
+Frozen_SetPackageTests, Source_SetPackageTests = test_util.test_both(
+ SetPackageTests, util=[frozen_util, source_util])
-class ResolveNameTests(unittest.TestCase):
+
+class SetLoaderTests:
+
+ """Tests importlib.util.set_loader()."""
+
+ class DummyLoader:
+ @util.set_loader
+ def load_module(self, module):
+ return self.module
+
+ def test_no_attribute(self):
+ loader = self.DummyLoader()
+ loader.module = types.ModuleType('blah')
+ try:
+ del loader.module.__loader__
+ except AttributeError:
+ pass
+ self.assertEqual(loader, loader.load_module('blah').__loader__)
+
+ def test_attribute_is_None(self):
+ loader = self.DummyLoader()
+ loader.module = types.ModuleType('blah')
+ loader.module.__loader__ = None
+ self.assertEqual(loader, loader.load_module('blah').__loader__)
+
+ def test_not_reset(self):
+ loader = self.DummyLoader()
+ loader.module = types.ModuleType('blah')
+ loader.module.__loader__ = 42
+ self.assertEqual(42, loader.load_module('blah').__loader__)
+
+class Frozen_SetLoaderTests(SetLoaderTests, unittest.TestCase):
+ class DummyLoader:
+ @frozen_util.set_loader
+ def load_module(self, module):
+ return self.module
+
+class Source_SetLoaderTests(SetLoaderTests, unittest.TestCase):
+ class DummyLoader:
+ @source_util.set_loader
+ def load_module(self, module):
+ return self.module
+
+
+class ResolveNameTests:
"""Tests importlib.util.resolve_name()."""
def test_absolute(self):
# bacon
- self.assertEqual('bacon', util.resolve_name('bacon', None))
+ self.assertEqual('bacon', self.util.resolve_name('bacon', None))
def test_aboslute_within_package(self):
# bacon in spam
- self.assertEqual('bacon', util.resolve_name('bacon', 'spam'))
+ self.assertEqual('bacon', self.util.resolve_name('bacon', 'spam'))
def test_no_package(self):
# .bacon in ''
with self.assertRaises(ValueError):
- util.resolve_name('.bacon', '')
+ self.util.resolve_name('.bacon', '')
def test_in_package(self):
# .bacon in spam
self.assertEqual('spam.eggs.bacon',
- util.resolve_name('.bacon', 'spam.eggs'))
+ self.util.resolve_name('.bacon', 'spam.eggs'))
def test_other_package(self):
# ..bacon in spam.bacon
self.assertEqual('spam.bacon',
- util.resolve_name('..bacon', 'spam.eggs'))
+ self.util.resolve_name('..bacon', 'spam.eggs'))
def test_escape(self):
# ..bacon in spam
with self.assertRaises(ValueError):
- util.resolve_name('..bacon', 'spam')
-
-
-def test_main():
- from test import support
- support.run_unittest(
- ModuleForLoaderTests,
- SetPackageTests,
- ResolveNameTests
- )
+ self.util.resolve_name('..bacon', 'spam')
+
+Frozen_ResolveNameTests, Source_ResolveNameTests = test_util.test_both(
+ ResolveNameTests,
+ util=[frozen_util, source_util])
+
+
+class MagicNumberTests:
+
+ def test_length(self):
+ # Should be 4 bytes.
+ self.assertEqual(len(self.util.MAGIC_NUMBER), 4)
+
+ def test_incorporates_rn(self):
+ # The magic number uses \r\n to come out wrong when splitting on lines.
+ self.assertTrue(self.util.MAGIC_NUMBER.endswith(b'\r\n'))
+
+Frozen_MagicNumberTests, Source_MagicNumberTests = test_util.test_both(
+ MagicNumberTests, util=[frozen_util, source_util])
+
+
+class PEP3147Tests:
+
+ """Tests of PEP 3147-related functions: cache_from_source and source_from_cache."""
+
+ tag = sys.implementation.cache_tag
+
+ @unittest.skipUnless(sys.implementation.cache_tag is not None,
+ 'requires sys.implementation.cache_tag not be None')
+ def test_cache_from_source(self):
+ # Given the path to a .py file, return the path to its PEP 3147
+ # defined .pyc file (i.e. under __pycache__).
+ path = os.path.join('foo', 'bar', 'baz', 'qux.py')
+ expect = os.path.join('foo', 'bar', 'baz', '__pycache__',
+ 'qux.{}.pyc'.format(self.tag))
+ self.assertEqual(self.util.cache_from_source(path, True), expect)
+
+ def test_cache_from_source_no_cache_tag(self):
+ # No cache tag means NotImplementedError.
+ with support.swap_attr(sys.implementation, 'cache_tag', None):
+ with self.assertRaises(NotImplementedError):
+ self.util.cache_from_source('whatever.py')
+
+ def test_cache_from_source_no_dot(self):
+ # Directory with a dot, filename without dot.
+ path = os.path.join('foo.bar', 'file')
+ expect = os.path.join('foo.bar', '__pycache__',
+ 'file{}.pyc'.format(self.tag))
+ self.assertEqual(self.util.cache_from_source(path, True), expect)
+
+ def test_cache_from_source_optimized(self):
+ # Given the path to a .py file, return the path to its PEP 3147
+ # defined .pyo file (i.e. under __pycache__).
+ path = os.path.join('foo', 'bar', 'baz', 'qux.py')
+ expect = os.path.join('foo', 'bar', 'baz', '__pycache__',
+ 'qux.{}.pyo'.format(self.tag))
+ self.assertEqual(self.util.cache_from_source(path, False), expect)
+
+ def test_cache_from_source_cwd(self):
+ path = 'foo.py'
+ expect = os.path.join('__pycache__', 'foo.{}.pyc'.format(self.tag))
+ self.assertEqual(self.util.cache_from_source(path, True), expect)
+
+ def test_cache_from_source_override(self):
+ # When debug_override is not None, it can be any true-ish or false-ish
+ # value.
+ path = os.path.join('foo', 'bar', 'baz.py')
+ partial_expect = os.path.join('foo', 'bar', '__pycache__',
+ 'baz.{}.py'.format(self.tag))
+ self.assertEqual(self.util.cache_from_source(path, []), partial_expect + 'o')
+ self.assertEqual(self.util.cache_from_source(path, [17]),
+ partial_expect + 'c')
+ # However if the bool-ishness can't be determined, the exception
+ # propagates.
+ class Bearish:
+ def __bool__(self): raise RuntimeError
+ with self.assertRaises(RuntimeError):
+ self.util.cache_from_source('/foo/bar/baz.py', Bearish())
+
+ @unittest.skipUnless(os.sep == '\\' and os.altsep == '/',
+ 'test meaningful only where os.altsep is defined')
+ def test_sep_altsep_and_sep_cache_from_source(self):
+ # Windows path and PEP 3147 where sep is right of altsep.
+ self.assertEqual(
+ self.util.cache_from_source('\\foo\\bar\\baz/qux.py', True),
+ '\\foo\\bar\\baz\\__pycache__\\qux.{}.pyc'.format(self.tag))
+
+ @unittest.skipUnless(sys.implementation.cache_tag is not None,
+ 'requires sys.implementation.cache_tag to not be '
+ 'None')
+ def test_source_from_cache(self):
+ # Given the path to a PEP 3147 defined .pyc file, return the path to
+ # its source. This tests the good path.
+ path = os.path.join('foo', 'bar', 'baz', '__pycache__',
+ 'qux.{}.pyc'.format(self.tag))
+ expect = os.path.join('foo', 'bar', 'baz', 'qux.py')
+ self.assertEqual(self.util.source_from_cache(path), expect)
+
+ def test_source_from_cache_no_cache_tag(self):
+ # If sys.implementation.cache_tag is None, raise NotImplementedError.
+ path = os.path.join('blah', '__pycache__', 'whatever.pyc')
+ with support.swap_attr(sys.implementation, 'cache_tag', None):
+ with self.assertRaises(NotImplementedError):
+ self.util.source_from_cache(path)
+
+ def test_source_from_cache_bad_path(self):
+ # When the path to a pyc file is not in PEP 3147 format, a ValueError
+ # is raised.
+ self.assertRaises(
+ ValueError, self.util.source_from_cache, '/foo/bar/bazqux.pyc')
+
+ def test_source_from_cache_no_slash(self):
+ # No slashes at all in path -> ValueError
+ self.assertRaises(
+ ValueError, self.util.source_from_cache, 'foo.cpython-32.pyc')
+
+ def test_source_from_cache_too_few_dots(self):
+ # Too few dots in final path component -> ValueError
+ self.assertRaises(
+ ValueError, self.util.source_from_cache, '__pycache__/foo.pyc')
+
+ def test_source_from_cache_too_many_dots(self):
+ # Too many dots in final path component -> ValueError
+ self.assertRaises(
+ ValueError, self.util.source_from_cache,
+ '__pycache__/foo.cpython-32.foo.pyc')
+
+ def test_source_from_cache_no__pycache__(self):
+ # Another problem with the path -> ValueError
+ self.assertRaises(
+ ValueError, self.util.source_from_cache,
+ '/foo/bar/foo.cpython-32.foo.pyc')
+
+Frozen_PEP3147Tests, Source_PEP3147Tests = test_util.test_both(
+ PEP3147Tests,
+ util=[frozen_util, source_util])
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py
index ef32f7d690..cb08ce6e52 100644
--- a/Lib/test/test_importlib/util.py
+++ b/Lib/test/test_importlib/util.py
@@ -1,9 +1,29 @@
from contextlib import contextmanager
-import imp
import os.path
from test import support
import unittest
import sys
+import types
+
+
+def import_importlib(module_name):
+ """Import a module from importlib both w/ and w/o _frozen_importlib."""
+ fresh = ('importlib',) if '.' in module_name else ()
+ frozen = support.import_fresh_module(module_name)
+ source = support.import_fresh_module(module_name, fresh=fresh,
+ blocked=('_frozen_importlib',))
+ return frozen, source
+
+
+def test_both(test_class, **kwargs):
+ frozen_tests = types.new_class('Frozen_'+test_class.__name__,
+ (test_class, unittest.TestCase))
+ source_tests = types.new_class('Source_'+test_class.__name__,
+ (test_class, unittest.TestCase))
+ for attr, (frozen_value, source_value) in kwargs.items():
+ setattr(frozen_tests, attr, frozen_value)
+ setattr(source_tests, attr, source_value)
+ return frozen_tests, source_tests
CASE_INSENSITIVE_FS = True
@@ -98,7 +118,7 @@ class mock_modules:
package = name.rsplit('.', 1)[0]
else:
package = import_name
- module = imp.new_module(import_name)
+ module = types.ModuleType(import_name)
module.__loader__ = self
module.__file__ = '<mock __file__>'
module.__package__ = package
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 5cbec9bb17..9d3490449a 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -8,10 +8,16 @@ import datetime
import collections
import os
import shutil
+import functools
+import importlib
from os.path import normcase
+try:
+ from concurrent.futures import ThreadPoolExecutor
+except ImportError:
+ ThreadPoolExecutor = None
from test.support import run_unittest, TESTFN, DirsOnSysPath
-
+from test.script_helper import assert_python_ok, assert_python_failure
from test import inspect_fodder as mod
from test import inspect_fodder2 as mod2
@@ -120,7 +126,6 @@ class TestPredicates(IsTestBase):
def test_get_slot_members(self):
class C(object):
__slots__ = ("a", "b")
-
x = C()
x.a = 42
members = dict(inspect.getmembers(x))
@@ -418,14 +423,14 @@ class TestBuggyCases(GetSourceBase):
unicodedata.__file__[-4:] in (".pyc", ".pyo"),
"unicodedata is not an external binary module")
def test_findsource_binary(self):
- self.assertRaises(IOError, inspect.getsource, unicodedata)
- self.assertRaises(IOError, inspect.findsource, unicodedata)
+ self.assertRaises(OSError, inspect.getsource, unicodedata)
+ self.assertRaises(OSError, inspect.findsource, unicodedata)
def test_findsource_code_in_linecache(self):
lines = ["x=1"]
co = compile(lines[0], "_dynamically_created_file", "exec")
- self.assertRaises(IOError, inspect.findsource, co)
- self.assertRaises(IOError, inspect.getsource, co)
+ self.assertRaises(OSError, inspect.findsource, co)
+ self.assertRaises(OSError, inspect.getsource, co)
linecache.cache[co.co_filename] = (1, None, lines, co.co_filename)
try:
self.assertEqual(inspect.findsource(co), (lines,0))
@@ -463,13 +468,13 @@ class _BrokenDataDescriptor(object):
A broken data descriptor. See bug #1785.
"""
def __get__(*args):
- raise AssertionError("should not __get__ data descriptors")
+ raise AttributeError("broken data descriptor")
def __set__(*args):
raise RuntimeError
def __getattr__(*args):
- raise AssertionError("should not __getattr__ data descriptors")
+ raise AttributeError("broken data descriptor")
class _BrokenMethodDescriptor(object):
@@ -477,10 +482,10 @@ class _BrokenMethodDescriptor(object):
A broken method descriptor. See bug #1785.
"""
def __get__(*args):
- raise AssertionError("should not __get__ method descriptors")
+ raise AttributeError("broken method descriptor")
def __getattr__(*args):
- raise AssertionError("should not __getattr__ method descriptors")
+ raise AttributeError("broken method descriptor")
# Helper for testing classify_class_attrs.
@@ -650,6 +655,88 @@ class TestClassesAndFunctions(unittest.TestCase):
if isinstance(builtin, type):
inspect.classify_class_attrs(builtin)
+ def test_classify_DynamicClassAttribute(self):
+ class Meta(type):
+ def __getattr__(self, name):
+ if name == 'ham':
+ return 'spam'
+ return super().__getattr__(name)
+ class VA(metaclass=Meta):
+ @types.DynamicClassAttribute
+ def ham(self):
+ return 'eggs'
+ should_find_dca = inspect.Attribute('ham', 'data', VA, VA.__dict__['ham'])
+ self.assertIn(should_find_dca, inspect.classify_class_attrs(VA))
+ should_find_ga = inspect.Attribute('ham', 'data', Meta, 'spam')
+ self.assertIn(should_find_ga, inspect.classify_class_attrs(VA))
+
+ def test_classify_metaclass_class_attribute(self):
+ class Meta(type):
+ fish = 'slap'
+ def __dir__(self):
+ return ['__class__', '__modules__', '__name__', 'fish']
+ class Class(metaclass=Meta):
+ pass
+ should_find = inspect.Attribute('fish', 'data', Meta, 'slap')
+ self.assertIn(should_find, inspect.classify_class_attrs(Class))
+
+ def test_classify_VirtualAttribute(self):
+ class Meta(type):
+ def __dir__(cls):
+ return ['__class__', '__module__', '__name__', 'BOOM']
+ def __getattr__(self, name):
+ if name =='BOOM':
+ return 42
+ return super().__getattr(name)
+ class Class(metaclass=Meta):
+ pass
+ should_find = inspect.Attribute('BOOM', 'data', Meta, 42)
+ self.assertIn(should_find, inspect.classify_class_attrs(Class))
+
+ def test_classify_VirtualAttribute_multi_classes(self):
+ class Meta1(type):
+ def __dir__(cls):
+ return ['__class__', '__module__', '__name__', 'one']
+ def __getattr__(self, name):
+ if name =='one':
+ return 1
+ return super().__getattr__(name)
+ class Meta2(type):
+ def __dir__(cls):
+ return ['__class__', '__module__', '__name__', 'two']
+ def __getattr__(self, name):
+ if name =='two':
+ return 2
+ return super().__getattr__(name)
+ class Meta3(Meta1, Meta2):
+ def __dir__(cls):
+ return list(sorted(set(['__class__', '__module__', '__name__', 'three'] +
+ Meta1.__dir__(cls) + Meta2.__dir__(cls))))
+ def __getattr__(self, name):
+ if name =='three':
+ return 3
+ return super().__getattr__(name)
+ class Class1(metaclass=Meta1):
+ pass
+ class Class2(Class1, metaclass=Meta3):
+ pass
+
+ should_find1 = inspect.Attribute('one', 'data', Meta1, 1)
+ should_find2 = inspect.Attribute('two', 'data', Meta2, 2)
+ should_find3 = inspect.Attribute('three', 'data', Meta3, 3)
+ cca = inspect.classify_class_attrs(Class2)
+ for sf in (should_find1, should_find2, should_find3):
+ self.assertIn(sf, cca)
+
+ def test_classify_class_attrs_with_buggy_dir(self):
+ class M(type):
+ def __dir__(cls):
+ return ['__class__', '__name__', 'missing']
+ class C(metaclass=M):
+ pass
+ attrs = [a[0] for a in inspect.classify_class_attrs(C)]
+ self.assertNotIn('missing', attrs)
+
def test_getmembers_descriptors(self):
class A(object):
dd = _BrokenDataDescriptor()
@@ -693,6 +780,28 @@ class TestClassesAndFunctions(unittest.TestCase):
self.assertIn(('f', b.f), inspect.getmembers(b))
self.assertIn(('f', b.f), inspect.getmembers(b, inspect.ismethod))
+ def test_getmembers_VirtualAttribute(self):
+ class M(type):
+ def __getattr__(cls, name):
+ if name == 'eggs':
+ return 'scrambled'
+ return super().__getattr__(name)
+ class A(metaclass=M):
+ @types.DynamicClassAttribute
+ def eggs(self):
+ return 'spam'
+ self.assertIn(('eggs', 'scrambled'), inspect.getmembers(A))
+ self.assertIn(('eggs', 'spam'), inspect.getmembers(A()))
+
+ def test_getmembers_with_buggy_dir(self):
+ class M(type):
+ def __dir__(cls):
+ return ['__class__', '__name__', 'missing']
+ class C(metaclass=M):
+ pass
+ attrs = [a[0] for a in inspect.getmembers(C)]
+ self.assertNotIn('missing', attrs)
+
_global_ref = object()
class TestGetClosureVars(unittest.TestCase):
@@ -1080,6 +1189,15 @@ class TestGetattrStatic(unittest.TestCase):
self.assertEqual(inspect.getattr_static(Thing, 'x'), Thing.x)
+ def test_classVirtualAttribute(self):
+ class Thing(object):
+ @types.DynamicClassAttribute
+ def x(self):
+ return self._x
+ _x = object()
+
+ self.assertEqual(inspect.getattr_static(Thing, 'x'), Thing.__dict__['x'])
+
def test_inherited_classattribute(self):
class Thing(object):
x = object()
@@ -1736,6 +1854,17 @@ class TestSignatureObject(unittest.TestCase):
((('b', ..., ..., "positional_or_keyword"),),
...))
+ # Test we handle __signature__ partway down the wrapper stack
+ def wrapped_foo_call():
+ pass
+ wrapped_foo_call.__wrapped__ = Foo.__call__
+
+ self.assertEqual(self.signature(wrapped_foo_call),
+ ((('a', ..., ..., "positional_or_keyword"),
+ ('b', ..., ..., "positional_or_keyword")),
+ ...))
+
+
def test_signature_on_class(self):
class C:
def __init__(self, a):
@@ -1850,6 +1979,10 @@ class TestSignatureObject(unittest.TestCase):
self.assertEqual(self.signature(Wrapped),
((('a', ..., ..., "positional_or_keyword"),),
...))
+ # wrapper loop:
+ Wrapped.__wrapped__ = Wrapped
+ with self.assertRaisesRegex(ValueError, 'wrapper loop'):
+ self.signature(Wrapped)
def test_signature_on_lambdas(self):
self.assertEqual(self.signature((lambda a=10: a)),
@@ -2301,6 +2434,103 @@ class TestBoundArguments(unittest.TestCase):
self.assertNotEqual(ba, ba4)
+class TestUnwrap(unittest.TestCase):
+
+ def test_unwrap_one(self):
+ def func(a, b):
+ return a + b
+ wrapper = functools.lru_cache(maxsize=20)(func)
+ self.assertIs(inspect.unwrap(wrapper), func)
+
+ def test_unwrap_several(self):
+ def func(a, b):
+ return a + b
+ wrapper = func
+ for __ in range(10):
+ @functools.wraps(wrapper)
+ def wrapper():
+ pass
+ self.assertIsNot(wrapper.__wrapped__, func)
+ self.assertIs(inspect.unwrap(wrapper), func)
+
+ def test_stop(self):
+ def func1(a, b):
+ return a + b
+ @functools.wraps(func1)
+ def func2():
+ pass
+ @functools.wraps(func2)
+ def wrapper():
+ pass
+ func2.stop_here = 1
+ unwrapped = inspect.unwrap(wrapper,
+ stop=(lambda f: hasattr(f, "stop_here")))
+ self.assertIs(unwrapped, func2)
+
+ def test_cycle(self):
+ def func1(): pass
+ func1.__wrapped__ = func1
+ with self.assertRaisesRegex(ValueError, 'wrapper loop'):
+ inspect.unwrap(func1)
+
+ def func2(): pass
+ func2.__wrapped__ = func1
+ func1.__wrapped__ = func2
+ with self.assertRaisesRegex(ValueError, 'wrapper loop'):
+ inspect.unwrap(func1)
+ with self.assertRaisesRegex(ValueError, 'wrapper loop'):
+ inspect.unwrap(func2)
+
+ def test_unhashable(self):
+ def func(): pass
+ func.__wrapped__ = None
+ class C:
+ __hash__ = None
+ __wrapped__ = func
+ self.assertIsNone(inspect.unwrap(C()))
+
+class TestMain(unittest.TestCase):
+ def test_only_source(self):
+ module = importlib.import_module('unittest')
+ rc, out, err = assert_python_ok('-m', 'inspect',
+ 'unittest')
+ lines = out.decode().splitlines()
+ # ignore the final newline
+ self.assertEqual(lines[:-1], inspect.getsource(module).splitlines())
+ self.assertEqual(err, b'')
+
+ @unittest.skipIf(ThreadPoolExecutor is None,
+ 'threads required to test __qualname__ for source files')
+ def test_qualname_source(self):
+ rc, out, err = assert_python_ok('-m', 'inspect',
+ 'concurrent.futures:ThreadPoolExecutor')
+ lines = out.decode().splitlines()
+ # ignore the final newline
+ self.assertEqual(lines[:-1],
+ inspect.getsource(ThreadPoolExecutor).splitlines())
+ self.assertEqual(err, b'')
+
+ def test_builtins(self):
+ module = importlib.import_module('unittest')
+ _, out, err = assert_python_failure('-m', 'inspect',
+ 'sys')
+ lines = err.decode().splitlines()
+ self.assertEqual(lines, ["Can't get info for builtin modules."])
+
+ def test_details(self):
+ module = importlib.import_module('unittest')
+ rc, out, err = assert_python_ok('-m', 'inspect',
+ 'unittest', '--details')
+ output = out.decode()
+ # Just a quick sanity check on the output
+ self.assertIn(module.__name__, output)
+ self.assertIn(module.__file__, output)
+ self.assertIn(module.__cached__, output)
+ self.assertEqual(err, b'')
+
+
+
+
def test_main():
run_unittest(
TestDecorators, TestRetrievingSourceCode, TestOneliners, TestBuggyCases,
@@ -2308,7 +2538,7 @@ def test_main():
TestGetcallargsFunctions, TestGetcallargsMethods,
TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState,
TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject,
- TestBoundArguments, TestGetClosureVars
+ TestBoundArguments, TestGetClosureVars, TestUnwrap, TestMain
)
if __name__ == "__main__":
diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py
index c198bcc740..4e00622325 100644
--- a/Lib/test/test_int.py
+++ b/Lib/test/test_int.py
@@ -228,6 +228,47 @@ class IntTestCases(unittest.TestCase):
self.assertRaises(TypeError, int, base=10)
self.assertRaises(TypeError, int, base=0)
+ def test_int_base_limits(self):
+ """Testing the supported limits of the int() base parameter."""
+ self.assertEqual(int('0', 5), 0)
+ with self.assertRaises(ValueError):
+ int('0', 1)
+ with self.assertRaises(ValueError):
+ int('0', 37)
+ with self.assertRaises(ValueError):
+ int('0', -909) # An old magic value base from Python 2.
+ with self.assertRaises(ValueError):
+ int('0', base=0-(2**234))
+ with self.assertRaises(ValueError):
+ int('0', base=2**234)
+ # Bases 2 through 36 are supported.
+ for base in range(2,37):
+ self.assertEqual(int('0', base=base), 0)
+
+ def test_int_base_bad_types(self):
+ """Not integer types are not valid bases; issue16772."""
+ with self.assertRaises(TypeError):
+ int('0', 5.5)
+ with self.assertRaises(TypeError):
+ int('0', 5.0)
+
+ def test_int_base_indexable(self):
+ class MyIndexable(object):
+ def __init__(self, value):
+ self.value = value
+ def __index__(self):
+ return self.value
+
+ # Check out of range bases.
+ for base in 2**100, -2**100, 1, 37:
+ with self.assertRaises(ValueError):
+ int('43', base)
+
+ # Check in-range bases.
+ self.assertEqual(int('101', base=MyIndexable(2)), 5)
+ self.assertEqual(int('101', base=MyIndexable(10)), 101)
+ self.assertEqual(int('101', base=MyIndexable(36)), 1 + 36**2)
+
def test_non_numeric_input_types(self):
# Test possible non-numeric types for the argument x, including
# subclasses of the explicitly documented accepted types.
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index 9b8920211e..c1ea6f245f 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -165,7 +165,7 @@ class CloseFailureIO(MockRawIO):
def close(self):
if not self.closed:
self.closed = 1
- raise IOError
+ raise OSError
class CCloseFailureIO(CloseFailureIO, io.RawIOBase):
pass
@@ -601,9 +601,9 @@ class IOTest(unittest.TestCase):
def test_flush_error_on_close(self):
f = self.open(support.TESTFN, "wb", buffering=0)
def bad_flush():
- raise IOError()
+ raise OSError()
f.flush = bad_flush
- self.assertRaises(IOError, f.close) # exception not swallowed
+ self.assertRaises(OSError, f.close) # exception not swallowed
self.assertTrue(f.closed)
def test_multi_close(self):
@@ -762,7 +762,7 @@ class CommonBufferedTests:
if s:
# The destructor *may* have printed an unraisable error, check it
self.assertEqual(len(s.splitlines()), 1)
- self.assertTrue(s.startswith("Exception IOError: "), s)
+ self.assertTrue(s.startswith("Exception OSError: "), s)
self.assertTrue(s.endswith(" ignored"), s)
def test_repr(self):
@@ -778,22 +778,22 @@ class CommonBufferedTests:
def test_flush_error_on_close(self):
raw = self.MockRawIO()
def bad_flush():
- raise IOError()
+ raise OSError()
raw.flush = bad_flush
b = self.tp(raw)
- self.assertRaises(IOError, b.close) # exception not swallowed
+ self.assertRaises(OSError, b.close) # exception not swallowed
self.assertTrue(b.closed)
def test_close_error_on_close(self):
raw = self.MockRawIO()
def bad_flush():
- raise IOError('flush')
+ raise OSError('flush')
def bad_close():
- raise IOError('close')
+ raise OSError('close')
raw.close = bad_close
b = self.tp(raw)
b.flush = bad_flush
- with self.assertRaises(IOError) as err: # exception not swallowed
+ with self.assertRaises(OSError) as err: # exception not swallowed
b.close()
self.assertEqual(err.exception.args, ('close',))
self.assertEqual(err.exception.__context__.args, ('flush',))
@@ -833,6 +833,14 @@ class SizeofTest:
bufio = self.tp(rawio, buffer_size=bufsize2)
self.assertEqual(sys.getsizeof(bufio), size + bufsize2)
+ @support.cpython_only
+ def test_buffer_freeing(self) :
+ bufsize = 4096
+ rawio = self.MockRawIO()
+ bufio = self.tp(rawio, buffer_size=bufsize)
+ size = sys.getsizeof(bufio) - bufsize
+ bufio.close()
+ self.assertEqual(sys.getsizeof(bufio), size)
class BufferedReaderTest(unittest.TestCase, CommonBufferedTests):
read_mode = "rb"
@@ -1007,8 +1015,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests):
def test_misbehaved_io(self):
rawio = self.MisbehavedRawIO((b"abc", b"d", b"efg"))
bufio = self.tp(rawio)
- self.assertRaises(IOError, bufio.seek, 0)
- self.assertRaises(IOError, bufio.tell)
+ self.assertRaises(OSError, bufio.seek, 0)
+ self.assertRaises(OSError, bufio.tell)
def test_no_extraneous_read(self):
# Issue #9550; when the raw IO object has satisfied the read request,
@@ -1059,17 +1067,18 @@ class CBufferedReaderTest(BufferedReaderTest, SizeofTest):
bufio = self.tp(rawio)
# _pyio.BufferedReader seems to implement reading different, so that
# checking this is not so easy.
- self.assertRaises(IOError, bufio.read, 10)
+ self.assertRaises(OSError, bufio.read, 10)
def test_garbage_collection(self):
# C BufferedReader objects are collected.
# The Python version has __del__, so it ends into gc.garbage instead
- rawio = self.FileIO(support.TESTFN, "w+b")
- f = self.tp(rawio)
- f.f = f
- wr = weakref.ref(f)
- del f
- support.gc_collect()
+ with support.check_warnings(('', ResourceWarning)):
+ rawio = self.FileIO(support.TESTFN, "w+b")
+ f = self.tp(rawio)
+ f.f = f
+ wr = weakref.ref(f)
+ del f
+ support.gc_collect()
self.assertTrue(wr() is None, wr)
def test_args_error(self):
@@ -1312,9 +1321,9 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
def test_misbehaved_io(self):
rawio = self.MisbehavedRawIO()
bufio = self.tp(rawio, 5)
- self.assertRaises(IOError, bufio.seek, 0)
- self.assertRaises(IOError, bufio.tell)
- self.assertRaises(IOError, bufio.write, b"abcdef")
+ self.assertRaises(OSError, bufio.seek, 0)
+ self.assertRaises(OSError, bufio.tell)
+ self.assertRaises(OSError, bufio.write, b"abcdef")
def test_max_buffer_size_removal(self):
with self.assertRaises(TypeError):
@@ -1323,11 +1332,11 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
def test_write_error_on_close(self):
raw = self.MockRawIO()
def bad_write(b):
- raise IOError()
+ raise OSError()
raw.write = bad_write
b = self.tp(raw)
b.write(b'spam')
- self.assertRaises(IOError, b.close) # exception not swallowed
+ self.assertRaises(OSError, b.close) # exception not swallowed
self.assertTrue(b.closed)
@@ -1358,13 +1367,14 @@ class CBufferedWriterTest(BufferedWriterTest, SizeofTest):
# C BufferedWriter objects are collected, and collecting them flushes
# all data to disk.
# The Python version has __del__, so it ends into gc.garbage instead
- rawio = self.FileIO(support.TESTFN, "w+b")
- f = self.tp(rawio)
- f.write(b"123xxx")
- f.x = f
- wr = weakref.ref(f)
- del f
- support.gc_collect()
+ with support.check_warnings(('', ResourceWarning)):
+ rawio = self.FileIO(support.TESTFN, "w+b")
+ f = self.tp(rawio)
+ f.write(b"123xxx")
+ f.x = f
+ wr = weakref.ref(f)
+ del f
+ support.gc_collect()
self.assertTrue(wr() is None, wr)
with self.open(support.TESTFN, "rb") as f:
self.assertEqual(f.read(), b"123xxx")
@@ -1397,14 +1407,14 @@ class BufferedRWPairTest(unittest.TestCase):
def readable(self):
return False
- self.assertRaises(IOError, self.tp, NotReadable(), self.MockRawIO())
+ self.assertRaises(OSError, self.tp, NotReadable(), self.MockRawIO())
def test_constructor_with_not_writeable(self):
class NotWriteable(MockRawIO):
def writable(self):
return False
- self.assertRaises(IOError, self.tp, self.MockRawIO(), NotWriteable())
+ self.assertRaises(OSError, self.tp, self.MockRawIO(), NotWriteable())
def test_read(self):
pair = self.tp(self.BytesIO(b"abcdef"), self.MockRawIO())
@@ -2165,7 +2175,7 @@ class TextIOWrapperTest(unittest.TestCase):
if s:
# The destructor *may* have printed an unraisable error, check it
self.assertEqual(len(s.splitlines()), 1)
- self.assertTrue(s.startswith("Exception IOError: "), s)
+ self.assertTrue(s.startswith("Exception OSError: "), s)
self.assertTrue(s.endswith(" ignored"), s)
# Systematic tests of the text I/O API
@@ -2237,7 +2247,7 @@ class TextIOWrapperTest(unittest.TestCase):
f.seek(0)
for line in f:
self.assertEqual(line, "\xff\n")
- self.assertRaises(IOError, f.tell)
+ self.assertRaises(OSError, f.tell)
self.assertEqual(f.tell(), p2)
f.close()
@@ -2341,7 +2351,7 @@ class TextIOWrapperTest(unittest.TestCase):
def readable(self):
return False
txt = self.TextIOWrapper(UnReadable())
- self.assertRaises(IOError, txt.read)
+ self.assertRaises(OSError, txt.read)
def test_read_one_by_one(self):
txt = self.TextIOWrapper(self.BytesIO(b"AA\r\nBB"))
@@ -2516,9 +2526,9 @@ class TextIOWrapperTest(unittest.TestCase):
def test_flush_error_on_close(self):
txt = self.TextIOWrapper(self.BytesIO(self.testdata), encoding="ascii")
def bad_flush():
- raise IOError()
+ raise OSError()
txt.flush = bad_flush
- self.assertRaises(IOError, txt.close) # exception not swallowed
+ self.assertRaises(OSError, txt.close) # exception not swallowed
self.assertTrue(txt.closed)
def test_multi_close(self):
@@ -2599,14 +2609,15 @@ class CTextIOWrapperTest(TextIOWrapperTest):
# C TextIOWrapper objects are collected, and collecting them flushes
# all data to disk.
# The Python version has __del__, so it ends in gc.garbage instead.
- rawio = io.FileIO(support.TESTFN, "wb")
- b = self.BufferedWriter(rawio)
- t = self.TextIOWrapper(b, encoding="ascii")
- t.write("456def")
- t.x = t
- wr = weakref.ref(t)
- del t
- support.gc_collect()
+ with support.check_warnings(('', ResourceWarning)):
+ rawio = io.FileIO(support.TESTFN, "wb")
+ b = self.BufferedWriter(rawio)
+ t = self.TextIOWrapper(b, encoding="ascii")
+ t.write("456def")
+ t.x = t
+ wr = weakref.ref(t)
+ del t
+ support.gc_collect()
self.assertTrue(wr() is None, wr)
with self.open(support.TESTFN, "rb") as f:
self.assertEqual(f.read(), b"456def")
@@ -2896,7 +2907,7 @@ class MiscIOTest(unittest.TestCase):
for fd in fds:
try:
os.close(fd)
- except EnvironmentError as e:
+ except OSError as e:
if e.errno != errno.EBADF:
raise
self.addCleanup(cleanup_fds)
@@ -3062,15 +3073,18 @@ class SignalsTest(unittest.TestCase):
try:
wio = self.io.open(w, **fdopen_kwargs)
t.start()
- signal.alarm(1)
# Fill the pipe enough that the write will be blocking.
# It will be interrupted by the timer armed above. Since the
# other thread has read one byte, the low-level write will
# return with a successful (partial) result rather than an EINTR.
# The buffered IO layer must check for pending signal
# handlers, which in this case will invoke alarm_interrupt().
- self.assertRaises(ZeroDivisionError,
- wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1))
+ signal.alarm(1)
+ try:
+ self.assertRaises(ZeroDivisionError,
+ wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1))
+ finally:
+ signal.alarm(0)
t.join()
# We got one byte, get another one and check that it isn't a
# repeat of the first one.
@@ -3084,7 +3098,7 @@ class SignalsTest(unittest.TestCase):
# buffer, and block again.
try:
wio.close()
- except IOError as e:
+ except OSError as e:
if e.errno != errno.EBADF:
raise
@@ -3212,7 +3226,7 @@ class SignalsTest(unittest.TestCase):
# buffer, and could block (in case of failure).
try:
wio.close()
- except IOError as e:
+ except OSError as e:
if e.errno != errno.EBADF:
raise
diff --git a/Lib/test/test_ioctl.py b/Lib/test/test_ioctl.py
index 531c9afbb5..efe9f516cc 100644
--- a/Lib/test/test_ioctl.py
+++ b/Lib/test/test_ioctl.py
@@ -8,7 +8,7 @@ get_attribute(termios, 'TIOCGPGRP') #Can't run tests without this feature
try:
tty = open("/dev/tty", "rb")
-except IOError:
+except OSError:
raise unittest.SkipTest("Unable to open /dev/tty")
else:
# Skip if another process is in foreground
@@ -86,8 +86,6 @@ class IoctlTests(unittest.TestCase):
os.close(mfd)
os.close(sfd)
-def test_main():
- run_unittest(IoctlTests)
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py
index 99c54f1614..f3b1565744 100644
--- a/Lib/test/test_ipaddress.py
+++ b/Lib/test/test_ipaddress.py
@@ -1319,6 +1319,14 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(True, ipaddress.ip_network(
'127.42.0.0/16').is_loopback)
self.assertEqual(False, ipaddress.ip_network('128.0.0.0').is_loopback)
+ self.assertEqual(False,
+ ipaddress.ip_network('100.64.0.0/10').is_private)
+ self.assertEqual(False, ipaddress.ip_network('100.64.0.0/10').is_global)
+
+ self.assertEqual(True,
+ ipaddress.ip_network('192.0.2.128/25').is_private)
+ self.assertEqual(True,
+ ipaddress.ip_network('192.0.3.0/24').is_global)
# test addresses
self.assertEqual(True, ipaddress.ip_address('0.0.0.0').is_unspecified)
@@ -1384,6 +1392,10 @@ class IpaddrUnitTest(unittest.TestCase):
self.assertEqual(False, ipaddress.ip_network('::1').is_unspecified)
self.assertEqual(False, ipaddress.ip_network('::/127').is_unspecified)
+ self.assertEqual(True,
+ ipaddress.ip_network('2001::1/128').is_private)
+ self.assertEqual(True,
+ ipaddress.ip_network('200::1/128').is_global)
# test addresses
self.assertEqual(True, ipaddress.ip_address('ffff::').is_multicast)
self.assertEqual(True, ipaddress.ip_address(2**128 - 1).is_multicast)
diff --git a/Lib/test/test_iterlen.py b/Lib/test/test_iterlen.py
index 9101f6c884..152f5fc0cb 100644
--- a/Lib/test/test_iterlen.py
+++ b/Lib/test/test_iterlen.py
@@ -45,31 +45,21 @@ import unittest
from test import support
from itertools import repeat
from collections import deque
-from builtins import len as _len
+from operator import length_hint
n = 10
-def len(obj):
- try:
- return _len(obj)
- except TypeError:
- try:
- # note: this is an internal undocumented API,
- # don't rely on it in your own programs
- return obj.__length_hint__()
- except AttributeError:
- raise TypeError
class TestInvariantWithoutMutations:
def test_invariant(self):
it = self.it
for i in reversed(range(1, n+1)):
- self.assertEqual(len(it), i)
+ self.assertEqual(length_hint(it), i)
next(it)
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
self.assertRaises(StopIteration, next, it)
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
class TestTemporarilyImmutable(TestInvariantWithoutMutations):
@@ -78,12 +68,12 @@ class TestTemporarilyImmutable(TestInvariantWithoutMutations):
# length immutability during iteration
it = self.it
- self.assertEqual(len(it), n)
+ self.assertEqual(length_hint(it), n)
next(it)
- self.assertEqual(len(it), n-1)
+ self.assertEqual(length_hint(it), n-1)
self.mutate()
self.assertRaises(RuntimeError, next, it)
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
## ------- Concrete Type Tests -------
@@ -92,10 +82,6 @@ class TestRepeat(TestInvariantWithoutMutations, unittest.TestCase):
def setUp(self):
self.it = repeat(None, n)
- def test_no_len_for_infinite_repeat(self):
- # The repeat() object can also be infinite
- self.assertRaises(TypeError, len, repeat(None))
-
class TestXrange(TestInvariantWithoutMutations, unittest.TestCase):
def setUp(self):
@@ -167,14 +153,15 @@ class TestList(TestInvariantWithoutMutations, unittest.TestCase):
it = iter(d)
next(it)
next(it)
- self.assertEqual(len(it), n-2)
+ self.assertEqual(length_hint(it), n - 2)
d.append(n)
- self.assertEqual(len(it), n-1) # grow with append
+ self.assertEqual(length_hint(it), n - 1) # grow with append
d[1:] = []
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
self.assertEqual(list(it), [])
d.extend(range(20))
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
+
class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase):
@@ -186,32 +173,41 @@ class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase):
it = reversed(d)
next(it)
next(it)
- self.assertEqual(len(it), n-2)
+ self.assertEqual(length_hint(it), n - 2)
d.append(n)
- self.assertEqual(len(it), n-2) # ignore append
+ self.assertEqual(length_hint(it), n - 2) # ignore append
d[1:] = []
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
self.assertEqual(list(it), []) # confirm invariant
d.extend(range(20))
- self.assertEqual(len(it), 0)
+ self.assertEqual(length_hint(it), 0)
## -- Check to make sure exceptions are not suppressed by __length_hint__()
class BadLen(object):
- def __iter__(self): return iter(range(10))
+ def __iter__(self):
+ return iter(range(10))
+
def __len__(self):
raise RuntimeError('hello')
+
class BadLengthHint(object):
- def __iter__(self): return iter(range(10))
+ def __iter__(self):
+ return iter(range(10))
+
def __length_hint__(self):
raise RuntimeError('hello')
+
class NoneLengthHint(object):
- def __iter__(self): return iter(range(10))
+ def __iter__(self):
+ return iter(range(10))
+
def __length_hint__(self):
- return None
+ return NotImplemented
+
class TestLengthHintExceptions(unittest.TestCase):
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index 514a6b7a20..2672b00c56 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -1729,9 +1729,8 @@ class TestVariousIteratorArgs(unittest.TestCase):
class LengthTransparency(unittest.TestCase):
def test_repeat(self):
- from test.test_iterlen import len
- self.assertEqual(len(repeat(None, 50)), 50)
- self.assertRaises(TypeError, len, repeat(None))
+ self.assertEqual(operator.length_hint(repeat(None, 50)), 50)
+ self.assertEqual(operator.length_hint(repeat(None), 12), 12)
class RegressionTests(unittest.TestCase):
diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py
index d23e285909..35c02de88c 100644
--- a/Lib/test/test_json/test_decode.py
+++ b/Lib/test/test_json/test_decode.py
@@ -1,5 +1,5 @@
import decimal
-from io import StringIO
+from io import StringIO, BytesIO
from collections import OrderedDict
from test.test_json import PyTest, CTest
@@ -70,5 +70,26 @@ class TestDecode:
msg = 'escape'
self.assertRaisesRegex(ValueError, msg, self.loads, s)
+ def test_invalid_input_type(self):
+ msg = 'the JSON object must be str'
+ for value in [1, 3.14, b'bytes', b'\xff\x00', [], {}, None]:
+ self.assertRaisesRegex(TypeError, msg, self.loads, value)
+ with self.assertRaisesRegex(TypeError, msg):
+ self.json.load(BytesIO(b'[1,2,3]'))
+
+ def test_string_with_utf8_bom(self):
+ # see #18958
+ bom_json = "[1,2,3]".encode('utf-8-sig').decode('utf-8')
+ with self.assertRaises(ValueError) as cm:
+ self.loads(bom_json)
+ self.assertIn('BOM', str(cm.exception))
+ with self.assertRaises(ValueError) as cm:
+ self.json.load(StringIO(bom_json))
+ self.assertIn('BOM', str(cm.exception))
+ # make sure that the BOM is not detected in the middle of a string
+ bom_in_str = '"{}"'.format(''.encode('utf-8-sig').decode('utf-8'))
+ self.assertEqual(self.loads(bom_in_str), '\ufeff')
+ self.assertEqual(self.json.load(StringIO(bom_in_str)), '\ufeff')
+
class TestPyDecode(TestDecode, PyTest): pass
class TestCDecode(TestDecode, CTest): pass
diff --git a/Lib/test/test_json/test_enum.py b/Lib/test/test_json/test_enum.py
new file mode 100644
index 0000000000..10f414898b
--- /dev/null
+++ b/Lib/test/test_json/test_enum.py
@@ -0,0 +1,120 @@
+from enum import Enum, IntEnum
+from math import isnan
+from test.test_json import PyTest, CTest
+
+SMALL = 1
+BIG = 1<<32
+HUGE = 1<<64
+REALLY_HUGE = 1<<96
+
+class BigNum(IntEnum):
+ small = SMALL
+ big = BIG
+ huge = HUGE
+ really_huge = REALLY_HUGE
+
+E = 2.718281
+PI = 3.141593
+TAU = 2 * PI
+
+class FloatNum(float, Enum):
+ e = E
+ pi = PI
+ tau = TAU
+
+INF = float('inf')
+NEG_INF = float('-inf')
+NAN = float('nan')
+
+class WierdNum(float, Enum):
+ inf = INF
+ neg_inf = NEG_INF
+ nan = NAN
+
+class TestEnum:
+
+ def test_floats(self):
+ for enum in FloatNum:
+ self.assertEqual(self.dumps(enum), repr(enum.value))
+ self.assertEqual(float(self.dumps(enum)), enum)
+ self.assertEqual(self.loads(self.dumps(enum)), enum)
+
+ def test_weird_floats(self):
+ for enum, expected in zip(WierdNum, ('Infinity', '-Infinity', 'NaN')):
+ self.assertEqual(self.dumps(enum), expected)
+ if not isnan(enum):
+ self.assertEqual(float(self.dumps(enum)), enum)
+ self.assertEqual(self.loads(self.dumps(enum)), enum)
+ else:
+ self.assertTrue(isnan(float(self.dumps(enum))))
+ self.assertTrue(isnan(self.loads(self.dumps(enum))))
+
+ def test_ints(self):
+ for enum in BigNum:
+ self.assertEqual(self.dumps(enum), str(enum.value))
+ self.assertEqual(int(self.dumps(enum)), enum)
+ self.assertEqual(self.loads(self.dumps(enum)), enum)
+
+ def test_list(self):
+ self.assertEqual(self.dumps(list(BigNum)),
+ str([SMALL, BIG, HUGE, REALLY_HUGE]))
+ self.assertEqual(self.loads(self.dumps(list(BigNum))),
+ list(BigNum))
+ self.assertEqual(self.dumps(list(FloatNum)),
+ str([E, PI, TAU]))
+ self.assertEqual(self.loads(self.dumps(list(FloatNum))),
+ list(FloatNum))
+ self.assertEqual(self.dumps(list(WierdNum)),
+ '[Infinity, -Infinity, NaN]')
+ self.assertEqual(self.loads(self.dumps(list(WierdNum)))[:2],
+ list(WierdNum)[:2])
+ self.assertTrue(isnan(self.loads(self.dumps(list(WierdNum)))[2]))
+
+ def test_dict_keys(self):
+ s, b, h, r = BigNum
+ e, p, t = FloatNum
+ i, j, n = WierdNum
+ d = {
+ s:'tiny', b:'large', h:'larger', r:'largest',
+ e:"Euler's number", p:'pi', t:'tau',
+ i:'Infinity', j:'-Infinity', n:'NaN',
+ }
+ nd = self.loads(self.dumps(d))
+ self.assertEqual(nd[str(SMALL)], 'tiny')
+ self.assertEqual(nd[str(BIG)], 'large')
+ self.assertEqual(nd[str(HUGE)], 'larger')
+ self.assertEqual(nd[str(REALLY_HUGE)], 'largest')
+ self.assertEqual(nd[repr(E)], "Euler's number")
+ self.assertEqual(nd[repr(PI)], 'pi')
+ self.assertEqual(nd[repr(TAU)], 'tau')
+ self.assertEqual(nd['Infinity'], 'Infinity')
+ self.assertEqual(nd['-Infinity'], '-Infinity')
+ self.assertEqual(nd['NaN'], 'NaN')
+
+ def test_dict_values(self):
+ d = dict(
+ tiny=BigNum.small,
+ large=BigNum.big,
+ larger=BigNum.huge,
+ largest=BigNum.really_huge,
+ e=FloatNum.e,
+ pi=FloatNum.pi,
+ tau=FloatNum.tau,
+ i=WierdNum.inf,
+ j=WierdNum.neg_inf,
+ n=WierdNum.nan,
+ )
+ nd = self.loads(self.dumps(d))
+ self.assertEqual(nd['tiny'], SMALL)
+ self.assertEqual(nd['large'], BIG)
+ self.assertEqual(nd['larger'], HUGE)
+ self.assertEqual(nd['largest'], REALLY_HUGE)
+ self.assertEqual(nd['e'], E)
+ self.assertEqual(nd['pi'], PI)
+ self.assertEqual(nd['tau'], TAU)
+ self.assertEqual(nd['i'], INF)
+ self.assertEqual(nd['j'], NEG_INF)
+ self.assertTrue(isnan(nd['n']))
+
+class TestPyEnum(TestEnum, PyTest): pass
+class TestCEnum(TestEnum, CTest): pass
diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py
index 3dd877afdb..7caafdbddd 100644
--- a/Lib/test/test_json/test_fail.py
+++ b/Lib/test/test_json/test_fail.py
@@ -1,4 +1,5 @@
from test.test_json import PyTest, CTest
+import re
# 2007-10-05
JSONDOCS = [
@@ -100,6 +101,94 @@ class TestFail:
#This is for python encoder
self.assertRaises(TypeError, self.dumps, data, indent=True)
+ def test_truncated_input(self):
+ test_cases = [
+ ('', 'Expecting value', 0),
+ ('[', 'Expecting value', 1),
+ ('[42', "Expecting ',' delimiter", 3),
+ ('[42,', 'Expecting value', 4),
+ ('["', 'Unterminated string starting at', 1),
+ ('["spam', 'Unterminated string starting at', 1),
+ ('["spam"', "Expecting ',' delimiter", 7),
+ ('["spam",', 'Expecting value', 8),
+ ('{', 'Expecting property name enclosed in double quotes', 1),
+ ('{"', 'Unterminated string starting at', 1),
+ ('{"spam', 'Unterminated string starting at', 1),
+ ('{"spam"', "Expecting ':' delimiter", 7),
+ ('{"spam":', 'Expecting value', 8),
+ ('{"spam":42', "Expecting ',' delimiter", 10),
+ ('{"spam":42,', 'Expecting property name enclosed in double quotes', 11),
+ ]
+ test_cases += [
+ ('"', 'Unterminated string starting at', 0),
+ ('"spam', 'Unterminated string starting at', 0),
+ ]
+ for data, msg, idx in test_cases:
+ self.assertRaisesRegex(ValueError,
+ r'^{0}: line 1 column {1} \(char {2}\)'.format(
+ re.escape(msg), idx + 1, idx),
+ self.loads, data)
+
+ def test_unexpected_data(self):
+ test_cases = [
+ ('[,', 'Expecting value', 1),
+ ('{"spam":[}', 'Expecting value', 9),
+ ('[42:', "Expecting ',' delimiter", 3),
+ ('[42 "spam"', "Expecting ',' delimiter", 4),
+ ('[42,]', 'Expecting value', 4),
+ ('{"spam":[42}', "Expecting ',' delimiter", 11),
+ ('["]', 'Unterminated string starting at', 1),
+ ('["spam":', "Expecting ',' delimiter", 7),
+ ('["spam",]', 'Expecting value', 8),
+ ('{:', 'Expecting property name enclosed in double quotes', 1),
+ ('{,', 'Expecting property name enclosed in double quotes', 1),
+ ('{42', 'Expecting property name enclosed in double quotes', 1),
+ ('[{]', 'Expecting property name enclosed in double quotes', 2),
+ ('{"spam",', "Expecting ':' delimiter", 7),
+ ('{"spam"}', "Expecting ':' delimiter", 7),
+ ('[{"spam"]', "Expecting ':' delimiter", 8),
+ ('{"spam":}', 'Expecting value', 8),
+ ('[{"spam":]', 'Expecting value', 9),
+ ('{"spam":42 "ham"', "Expecting ',' delimiter", 11),
+ ('[{"spam":42]', "Expecting ',' delimiter", 11),
+ ('{"spam":42,}', 'Expecting property name enclosed in double quotes', 11),
+ ]
+ for data, msg, idx in test_cases:
+ self.assertRaisesRegex(ValueError,
+ r'^{0}: line 1 column {1} \(char {2}\)'.format(
+ re.escape(msg), idx + 1, idx),
+ self.loads, data)
+
+ def test_extra_data(self):
+ test_cases = [
+ ('[]]', 'Extra data', 2),
+ ('{}}', 'Extra data', 2),
+ ('[],[]', 'Extra data', 2),
+ ('{},{}', 'Extra data', 2),
+ ]
+ test_cases += [
+ ('42,"spam"', 'Extra data', 2),
+ ('"spam",42', 'Extra data', 6),
+ ]
+ for data, msg, idx in test_cases:
+ self.assertRaisesRegex(ValueError,
+ r'^{0}: line 1 column {1} - line 1 column {2}'
+ r' \(char {3} - {4}\)'.format(
+ re.escape(msg), idx + 1, len(data) + 1, idx, len(data)),
+ self.loads, data)
+
+ def test_linecol(self):
+ test_cases = [
+ ('!', 1, 1, 0),
+ (' !', 1, 2, 1),
+ ('\n!', 2, 1, 1),
+ ('\n \n\n !', 4, 6, 10),
+ ]
+ for data, line, col, idx in test_cases:
+ self.assertRaisesRegex(ValueError,
+ r'^Expecting value: line {0} column {1}'
+ r' \(char {2}\)$'.format(line, col, idx),
+ self.loads, data)
class TestPyFail(TestFail, PyTest): pass
class TestCFail(TestFail, CTest): pass
diff --git a/Lib/test/test_json/test_indent.py b/Lib/test/test_json/test_indent.py
index a4d4d20142..e07856f33c 100644
--- a/Lib/test/test_json/test_indent.py
+++ b/Lib/test/test_json/test_indent.py
@@ -32,6 +32,8 @@ class TestIndent:
d1 = self.dumps(h)
d2 = self.dumps(h, indent=2, sort_keys=True, separators=(',', ': '))
d3 = self.dumps(h, indent='\t', sort_keys=True, separators=(',', ': '))
+ d4 = self.dumps(h, indent=2, sort_keys=True)
+ d5 = self.dumps(h, indent='\t', sort_keys=True)
h1 = self.loads(d1)
h2 = self.loads(d2)
@@ -42,6 +44,8 @@ class TestIndent:
self.assertEqual(h3, h)
self.assertEqual(d2, expect.expandtabs(2))
self.assertEqual(d3, expect)
+ self.assertEqual(d4, d2)
+ self.assertEqual(d5, d3)
def test_indent0(self):
h = {3: 1}
diff --git a/Lib/test/test_keyword.py b/Lib/test/test_keyword.py
new file mode 100644
index 0000000000..af99f52c63
--- /dev/null
+++ b/Lib/test/test_keyword.py
@@ -0,0 +1,138 @@
+import keyword
+import unittest
+from test import support
+import filecmp
+import os
+import sys
+import subprocess
+import shutil
+import textwrap
+
+KEYWORD_FILE = support.findfile('keyword.py')
+GRAMMAR_FILE = os.path.join(os.path.split(__file__)[0],
+ '..', '..', 'Python', 'graminit.c')
+TEST_PY_FILE = 'keyword_test.py'
+GRAMMAR_TEST_FILE = 'graminit_test.c'
+PY_FILE_WITHOUT_KEYWORDS = 'minimal_keyword.py'
+NONEXISTENT_FILE = 'not_here.txt'
+
+
+class Test_iskeyword(unittest.TestCase):
+ def test_true_is_a_keyword(self):
+ self.assertTrue(keyword.iskeyword('True'))
+
+ def test_uppercase_true_is_not_a_keyword(self):
+ self.assertFalse(keyword.iskeyword('TRUE'))
+
+ def test_none_value_is_not_a_keyword(self):
+ self.assertFalse(keyword.iskeyword(None))
+
+ # This is probably an accident of the current implementation, but should be
+ # preserved for backward compatibility.
+ def test_changing_the_kwlist_does_not_affect_iskeyword(self):
+ oldlist = keyword.kwlist
+ self.addCleanup(setattr, keyword, 'kwlist', oldlist)
+ keyword.kwlist = ['its', 'all', 'eggs', 'beans', 'and', 'a', 'slice']
+ self.assertFalse(keyword.iskeyword('eggs'))
+
+
+class TestKeywordGeneration(unittest.TestCase):
+
+ def _copy_file_without_generated_keywords(self, source_file, dest_file):
+ with open(source_file, 'rb') as fp:
+ lines = fp.readlines()
+ nl = lines[0][len(lines[0].strip()):]
+ with open(dest_file, 'wb') as fp:
+ fp.writelines(lines[:lines.index(b"#--start keywords--" + nl) + 1])
+ fp.writelines(lines[lines.index(b"#--end keywords--" + nl):])
+
+ def _generate_keywords(self, grammar_file, target_keyword_py_file):
+ proc = subprocess.Popen([sys.executable,
+ KEYWORD_FILE,
+ grammar_file,
+ target_keyword_py_file], stderr=subprocess.PIPE)
+ stderr = proc.communicate()[1]
+ return proc.returncode, stderr
+
+ @unittest.skipIf(not os.path.exists(GRAMMAR_FILE),
+ 'test only works from source build directory')
+ def test_real_grammar_and_keyword_file(self):
+ self._copy_file_without_generated_keywords(KEYWORD_FILE, TEST_PY_FILE)
+ self.addCleanup(support.unlink, TEST_PY_FILE)
+ self.assertFalse(filecmp.cmp(KEYWORD_FILE, TEST_PY_FILE))
+ self.assertEqual((0, b''), self._generate_keywords(GRAMMAR_FILE,
+ TEST_PY_FILE))
+ self.assertTrue(filecmp.cmp(KEYWORD_FILE, TEST_PY_FILE))
+
+ def test_grammar(self):
+ self._copy_file_without_generated_keywords(KEYWORD_FILE, TEST_PY_FILE)
+ self.addCleanup(support.unlink, TEST_PY_FILE)
+ with open(GRAMMAR_TEST_FILE, 'w') as fp:
+ # Some of these are probably implementation accidents.
+ fp.writelines(textwrap.dedent("""\
+ {2, 1},
+ {11, "encoding_decl", 0, 2, states_79,
+ "\000\000\040\000\000\000\000\000\000\000\000\000"
+ "\000\000\000\000\000\000\000\000\000"},
+ {1, "jello"},
+ {326, 0},
+ {1, "turnip"},
+ \t{1, "This one is tab indented"
+ {278, 0},
+ {1, "crazy but legal"
+ "also legal" {1, "
+ {1, "continue"},
+ {1, "lemon"},
+ {1, "tomato"},
+ {1, "wigii"},
+ {1, 'no good'}
+ {283, 0},
+ {1, "too many spaces"}"""))
+ self.addCleanup(support.unlink, GRAMMAR_TEST_FILE)
+ self._generate_keywords(GRAMMAR_TEST_FILE, TEST_PY_FILE)
+ expected = [
+ " 'This one is tab indented',",
+ " 'also legal',",
+ " 'continue',",
+ " 'crazy but legal',",
+ " 'jello',",
+ " 'lemon',",
+ " 'tomato',",
+ " 'turnip',",
+ " 'wigii',",
+ ]
+ with open(TEST_PY_FILE) as fp:
+ lines = fp.read().splitlines()
+ start = lines.index("#--start keywords--") + 1
+ end = lines.index("#--end keywords--")
+ actual = lines[start:end]
+ self.assertEqual(actual, expected)
+
+ def test_empty_grammar_results_in_no_keywords(self):
+ self._copy_file_without_generated_keywords(KEYWORD_FILE,
+ PY_FILE_WITHOUT_KEYWORDS)
+ self.addCleanup(support.unlink, PY_FILE_WITHOUT_KEYWORDS)
+ shutil.copyfile(KEYWORD_FILE, TEST_PY_FILE)
+ self.addCleanup(support.unlink, TEST_PY_FILE)
+ self.assertEqual((0, b''), self._generate_keywords(os.devnull,
+ TEST_PY_FILE))
+ self.assertTrue(filecmp.cmp(TEST_PY_FILE, PY_FILE_WITHOUT_KEYWORDS))
+
+ def test_keywords_py_without_markers_produces_error(self):
+ rc, stderr = self._generate_keywords(os.devnull, os.devnull)
+ self.assertNotEqual(rc, 0)
+ self.assertRegex(stderr, b'does not contain format markers')
+
+ def test_missing_grammar_file_produces_error(self):
+ rc, stderr = self._generate_keywords(NONEXISTENT_FILE, KEYWORD_FILE)
+ self.assertNotEqual(rc, 0)
+ self.assertRegex(stderr, b'(?ms)' + NONEXISTENT_FILE.encode())
+
+ def test_missing_keywords_py_file_produces_error(self):
+ rc, stderr = self._generate_keywords(os.devnull, NONEXISTENT_FILE)
+ self.assertNotEqual(rc, 0)
+ self.assertRegex(stderr, b'(?ms)' + NONEXISTENT_FILE.encode())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_keywordonlyarg.py b/Lib/test/test_keywordonlyarg.py
index 0503a7fc6a..e4a44c1d8f 100644
--- a/Lib/test/test_keywordonlyarg.py
+++ b/Lib/test/test_keywordonlyarg.py
@@ -176,6 +176,18 @@ class KeywordOnlyArgTestCase(unittest.TestCase):
return __a
self.assertEqual(X().f(), 42)
+ def test_default_evaluation_order(self):
+ # See issue 16967
+ a = 42
+ with self.assertRaises(NameError) as err:
+ def f(v=a, x=b, *, y=c, z=d):
+ pass
+ self.assertEqual(str(err.exception), "name 'b' is not defined")
+ with self.assertRaises(NameError) as err:
+ f = lambda v=a, x=b, *, y=c, z=d: None
+ self.assertEqual(str(err.exception), "name 'b' is not defined")
+
+
def test_main():
run_unittest(KeywordOnlyArgTestCase)
diff --git a/Lib/test/test_kqueue.py b/Lib/test/test_kqueue.py
index e5e6058334..bafdeba6ff 100644
--- a/Lib/test/test_kqueue.py
+++ b/Lib/test/test_kqueue.py
@@ -94,7 +94,7 @@ class TestKQueue(unittest.TestCase):
client.setblocking(False)
try:
client.connect(('127.0.0.1', serverSocket.getsockname()[1]))
- except socket.error as e:
+ except OSError as e:
self.assertEqual(e.args[0], errno.EINPROGRESS)
else:
#raise AssertionError("Connect should have raised EINPROGRESS")
@@ -185,6 +185,33 @@ class TestKQueue(unittest.TestCase):
b.close()
kq.close()
+ def test_close(self):
+ open_file = open(__file__, "rb")
+ self.addCleanup(open_file.close)
+ fd = open_file.fileno()
+ kqueue = select.kqueue()
+
+ # test fileno() method and closed attribute
+ self.assertIsInstance(kqueue.fileno(), int)
+ self.assertFalse(kqueue.closed)
+
+ # test close()
+ kqueue.close()
+ self.assertTrue(kqueue.closed)
+ self.assertRaises(ValueError, kqueue.fileno)
+
+ # close() can be called more than once
+ kqueue.close()
+
+ # operations must fail with ValueError("I/O operation on closed ...")
+ self.assertRaises(ValueError, kqueue.control, None, 4)
+
+ def test_fd_non_inheritable(self):
+ kqueue = select.kqueue()
+ self.addCleanup(kqueue.close)
+ self.assertEqual(os.get_inheritable(kqueue.fileno()), False)
+
+
def test_main():
support.run_unittest(TestKQueue)
diff --git a/Lib/test/test_largefile.py b/Lib/test/test_largefile.py
index 63ee697250..5b276e76ff 100644
--- a/Lib/test/test_largefile.py
+++ b/Lib/test/test_largefile.py
@@ -159,7 +159,7 @@ def setUpModule():
# Seeking is not enough of a test: you must write and flush, too!
f.write(b'x')
f.flush()
- except (IOError, OverflowError):
+ except (OSError, OverflowError):
raise unittest.SkipTest("filesystem does not have "
"largefile support")
finally:
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index ae4ca18444..ffd0c8bf2b 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -26,6 +26,7 @@ import logging.handlers
import logging.config
import codecs
+import configparser
import datetime
import pickle
import io
@@ -58,7 +59,9 @@ try:
import smtpd
from urllib.parse import urlparse, parse_qs
from socketserver import (ThreadingUDPServer, DatagramRequestHandler,
- ThreadingTCPServer, StreamRequestHandler)
+ ThreadingTCPServer, StreamRequestHandler,
+ ThreadingUnixStreamServer,
+ ThreadingUnixDatagramServer)
except ImportError:
threading = None
try:
@@ -75,13 +78,12 @@ try:
except ImportError:
pass
-
class BaseTest(unittest.TestCase):
"""Base class for logging tests."""
log_format = "%(name)s -> %(levelname)s: %(message)s"
- expected_log_pat = r"^([\w.]+) -> ([\w]+): ([\d]+)$"
+ expected_log_pat = r"^([\w.]+) -> (\w+): (\d+)$"
message_num = 0
def setUp(self):
@@ -93,7 +95,8 @@ class BaseTest(unittest.TestCase):
self.saved_handlers = logging._handlers.copy()
self.saved_handler_list = logging._handlerList[:]
self.saved_loggers = saved_loggers = logger_dict.copy()
- self.saved_level_names = logging._levelNames.copy()
+ self.saved_name_to_level = logging._nameToLevel.copy()
+ self.saved_level_to_name = logging._levelToName.copy()
self.logger_states = logger_states = {}
for name in saved_loggers:
logger_states[name] = getattr(saved_loggers[name],
@@ -135,8 +138,10 @@ class BaseTest(unittest.TestCase):
self.root_logger.setLevel(self.original_logging_level)
logging._acquireLock()
try:
- logging._levelNames.clear()
- logging._levelNames.update(self.saved_level_names)
+ logging._levelToName.clear()
+ logging._levelToName.update(self.saved_level_to_name)
+ logging._nameToLevel.clear()
+ logging._nameToLevel.update(self.saved_name_to_level)
logging._handlers.clear()
logging._handlers.update(self.saved_handlers)
logging._handlerList[:] = self.saved_handler_list
@@ -150,12 +155,12 @@ class BaseTest(unittest.TestCase):
finally:
logging._releaseLock()
- def assert_log_lines(self, expected_values, stream=None):
+ def assert_log_lines(self, expected_values, stream=None, pat=None):
"""Match the collected log lines against the regular expression
self.expected_log_pat, and compare the extracted group values to
the expected_values list of tuples."""
stream = stream or self.stream
- pat = re.compile(self.expected_log_pat)
+ pat = re.compile(pat or self.expected_log_pat)
actual_lines = stream.getvalue().splitlines()
self.assertEqual(len(actual_lines), len(expected_values))
for actual, expected in zip(actual_lines, expected_values):
@@ -430,7 +435,7 @@ class CustomLevelsAndFiltersTest(BaseTest):
"""Test various filtering possibilities with custom logging levels."""
# Skip the logger name group.
- expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
+ expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$"
def setUp(self):
BaseTest.setUp(self)
@@ -560,7 +565,7 @@ class HandlerTest(BaseTest):
self.assertEqual(h.facility, h.LOG_USER)
self.assertTrue(h.unixsocket)
h.close()
- except socket.error: # syslogd might not be available
+ except OSError: # syslogd might not be available
pass
for method in ('GET', 'POST', 'PUT'):
if method == 'PUT':
@@ -647,41 +652,6 @@ class StreamHandlerTest(BaseTest):
# -- if it proves to be of wider utility than just test_logging
if threading:
- class TestSMTPChannel(smtpd.SMTPChannel):
- """
- This derived class has had to be created because smtpd does not
- support use of custom channel maps, although they are allowed by
- asyncore's design. Issue #11959 has been raised to address this,
- and if resolved satisfactorily, some of this code can be removed.
- """
- def __init__(self, server, conn, addr, sockmap):
- asynchat.async_chat.__init__(self, conn, sockmap)
- self.smtp_server = server
- self.conn = conn
- self.addr = addr
- self.data_size_limit = None
- self.received_lines = []
- self.smtp_state = self.COMMAND
- self.seen_greeting = ''
- self.mailfrom = None
- self.rcpttos = []
- self.received_data = ''
- self.fqdn = socket.getfqdn()
- self.num_bytes = 0
- try:
- self.peer = conn.getpeername()
- except socket.error as err:
- # a race condition may occur if the other end is closing
- # before we can get the peername
- self.close()
- if err.args[0] != errno.ENOTCONN:
- raise
- return
- self.push('220 %s %s' % (self.fqdn, smtpd.__version__))
- self.set_terminator(b'\r\n')
- self.extended_smtp = False
-
-
class TestSMTPServer(smtpd.SMTPServer):
"""
This class implements a test SMTP server.
@@ -702,37 +672,14 @@ if threading:
:func:`asyncore.loop`. This avoids changing the
:mod:`asyncore` module's global state.
"""
- channel_class = TestSMTPChannel
def __init__(self, addr, handler, poll_interval, sockmap):
- self._localaddr = addr
- self._remoteaddr = None
- self.data_size_limit = None
- self.sockmap = sockmap
- asyncore.dispatcher.__init__(self, map=sockmap)
- try:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setblocking(0)
- self.set_socket(sock, map=sockmap)
- # try to re-use a server port if possible
- self.set_reuse_addr()
- self.bind(addr)
- self.port = sock.getsockname()[1]
- self.listen(5)
- except:
- self.close()
- raise
+ smtpd.SMTPServer.__init__(self, addr, None, map=sockmap)
+ self.port = self.socket.getsockname()[1]
self._handler = handler
self._thread = None
self.poll_interval = poll_interval
- def handle_accepted(self, conn, addr):
- """
- Redefined only because the base class does not pass in a
- map, forcing use of a global in :mod:`asyncore`.
- """
- channel = self.channel_class(self, conn, addr, self.sockmap)
-
def process_message(self, peer, mailfrom, rcpttos, data):
"""
Delegates to the handler passed in to the server's constructor.
@@ -763,8 +710,8 @@ if threading:
:func:`asyncore.loop`.
"""
try:
- asyncore.loop(poll_interval, map=self.sockmap)
- except select.error:
+ asyncore.loop(poll_interval, map=self._map)
+ except OSError:
# On FreeBSD 8, closing the server repeatably
# raises this error. We swallow it if the
# server has been closed.
@@ -871,7 +818,7 @@ if threading:
sock, addr = self.socket.accept()
if self.sslctx:
sock = self.sslctx.wrap_socket(sock, server_side=True)
- except socket.error as e:
+ except OSError as e:
# socket errors are silenced by the caller, print them here
sys.stderr.write("Got an error:\n%s\n" % e)
raise
@@ -908,6 +855,9 @@ if threading:
super(TestTCPServer, self).server_bind()
self.port = self.socket.getsockname()[1]
+ class TestUnixStreamServer(TestTCPServer):
+ address_family = socket.AF_UNIX
+
class TestUDPServer(ControlMixin, ThreadingUDPServer):
"""
A UDP server which is controllable using :class:`ControlMixin`.
@@ -937,7 +887,7 @@ if threading:
if data:
try:
super(DelegatingUDPRequestHandler, self).finish()
- except socket.error:
+ except OSError:
if not self.server._closed:
raise
@@ -955,6 +905,9 @@ if threading:
super(TestUDPServer, self).server_close()
self._closed = True
+ class TestUnixDatagramServer(TestUDPServer):
+ address_family = socket.AF_UNIX
+
# - end of server_helper section
@unittest.skipUnless(threading, 'Threading required for this test.')
@@ -991,7 +944,7 @@ class MemoryHandlerTest(BaseTest):
"""Tests for the MemoryHandler."""
# Do not bother with a logger name group.
- expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
+ expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$"
def setUp(self):
BaseTest.setUp(self)
@@ -1044,7 +997,7 @@ class ConfigFileTest(BaseTest):
"""Reading logging config from a .ini-style config file."""
- expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$"
+ expected_log_pat = r"^(\w+) \+\+ (\w+)$"
# config0 is a standard configuration.
config0 = """
@@ -1292,6 +1245,24 @@ class ConfigFileTest(BaseTest):
# Original logger output is empty.
self.assert_log_lines([])
+ def test_config0_using_cp_ok(self):
+ # A simple config file which overrides the default settings.
+ with captured_stdout() as output:
+ file = io.StringIO(textwrap.dedent(self.config0))
+ cp = configparser.ConfigParser()
+ cp.read_file(file)
+ logging.config.fileConfig(cp)
+ logger = logging.getLogger()
+ # Won't output anything
+ logger.info(self.next_message())
+ # Outputs a message
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('ERROR', '2'),
+ ], stream=output)
+ # Original logger output is empty.
+ self.assert_log_lines([])
+
def test_config1_ok(self, config=config1):
# A config file defining a sub-parser as well.
with captured_stdout() as output:
@@ -1381,7 +1352,7 @@ class ConfigFileTest(BaseTest):
def test_logger_disabling(self):
self.apply_config(self.disable_test)
- logger = logging.getLogger('foo')
+ logger = logging.getLogger('some_pristine_logger')
self.assertFalse(logger.disabled)
self.apply_config(self.disable_test)
self.assertTrue(logger.disabled)
@@ -1394,17 +1365,23 @@ class SocketHandlerTest(BaseTest):
"""Test for SocketHandler objects."""
+ if threading:
+ server_class = TestTCPServer
+ address = ('localhost', 0)
+
def setUp(self):
"""Set up a TCP server to receive log messages, and a SocketHandler
pointing to that server's address and port."""
BaseTest.setUp(self)
- addr = ('localhost', 0)
- self.server = server = TestTCPServer(addr, self.handle_socket,
- 0.01)
+ self.server = server = self.server_class(self.address,
+ self.handle_socket, 0.01)
server.start()
server.ready.wait()
- self.sock_hdlr = logging.handlers.SocketHandler('localhost',
- server.port)
+ hcls = logging.handlers.SocketHandler
+ if isinstance(server.server_address, tuple):
+ self.sock_hdlr = hcls('localhost', server.port)
+ else:
+ self.sock_hdlr = hcls(server.server_address, None)
self.log_output = ''
self.root_logger.removeHandler(self.root_logger.handlers[0])
self.root_logger.addHandler(self.sock_hdlr)
@@ -1444,35 +1421,69 @@ class SocketHandlerTest(BaseTest):
self.assertEqual(self.log_output, "spam\neggs\n")
def test_noserver(self):
+ # Avoid timing-related failures due to SocketHandler's own hard-wired
+ # one-second timeout on socket.create_connection() (issue #16264).
+ self.sock_hdlr.retryStart = 2.5
# Kill the server
self.server.stop(2.0)
- #The logging call should try to connect, which should fail
+ # The logging call should try to connect, which should fail
try:
raise RuntimeError('Deliberate mistake')
except RuntimeError:
self.root_logger.exception('Never sent')
self.root_logger.error('Never sent, either')
now = time.time()
- self.assertTrue(self.sock_hdlr.retryTime > now)
+ self.assertGreater(self.sock_hdlr.retryTime, now)
time.sleep(self.sock_hdlr.retryTime - now + 0.001)
self.root_logger.error('Nor this')
+def _get_temp_domain_socket():
+ fd, fn = tempfile.mkstemp(prefix='test_logging_', suffix='.sock')
+ os.close(fd)
+ # just need a name - file can't be present, or we'll get an
+ # 'address already in use' error.
+ os.remove(fn)
+ return fn
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class UnixSocketHandlerTest(SocketHandlerTest):
+
+ """Test for SocketHandler with unix sockets."""
+
+ if threading:
+ server_class = TestUnixStreamServer
+
+ def setUp(self):
+ # override the definition in the base class
+ self.address = _get_temp_domain_socket()
+ SocketHandlerTest.setUp(self)
+
+ def tearDown(self):
+ SocketHandlerTest.tearDown(self)
+ os.remove(self.address)
@unittest.skipUnless(threading, 'Threading required for this test.')
class DatagramHandlerTest(BaseTest):
"""Test for DatagramHandler."""
+ if threading:
+ server_class = TestUDPServer
+ address = ('localhost', 0)
+
def setUp(self):
"""Set up a UDP server to receive log messages, and a DatagramHandler
pointing to that server's address and port."""
BaseTest.setUp(self)
- addr = ('localhost', 0)
- self.server = server = TestUDPServer(addr, self.handle_datagram, 0.01)
+ self.server = server = self.server_class(self.address,
+ self.handle_datagram, 0.01)
server.start()
server.ready.wait()
- self.sock_hdlr = logging.handlers.DatagramHandler('localhost',
- server.port)
+ hcls = logging.handlers.DatagramHandler
+ if isinstance(server.server_address, tuple):
+ self.sock_hdlr = hcls('localhost', server.port)
+ else:
+ self.sock_hdlr = hcls(server.server_address, None)
self.log_output = ''
self.root_logger.removeHandler(self.root_logger.handlers[0])
self.root_logger.addHandler(self.sock_hdlr)
@@ -1507,21 +1518,44 @@ class DatagramHandlerTest(BaseTest):
@unittest.skipUnless(threading, 'Threading required for this test.')
+class UnixDatagramHandlerTest(DatagramHandlerTest):
+
+ """Test for DatagramHandler using Unix sockets."""
+
+ if threading:
+ server_class = TestUnixDatagramServer
+
+ def setUp(self):
+ # override the definition in the base class
+ self.address = _get_temp_domain_socket()
+ DatagramHandlerTest.setUp(self)
+
+ def tearDown(self):
+ DatagramHandlerTest.tearDown(self)
+ os.remove(self.address)
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
class SysLogHandlerTest(BaseTest):
"""Test for SysLogHandler using UDP."""
+ if threading:
+ server_class = TestUDPServer
+ address = ('localhost', 0)
+
def setUp(self):
"""Set up a UDP server to receive log messages, and a SysLogHandler
pointing to that server's address and port."""
BaseTest.setUp(self)
- addr = ('localhost', 0)
- self.server = server = TestUDPServer(addr, self.handle_datagram,
- 0.01)
+ self.server = server = self.server_class(self.address,
+ self.handle_datagram, 0.01)
server.start()
server.ready.wait()
- self.sl_hdlr = logging.handlers.SysLogHandler(('localhost',
- server.port))
+ hcls = logging.handlers.SysLogHandler
+ if isinstance(server.server_address, tuple):
+ self.sl_hdlr = hcls(('localhost', server.port))
+ else:
+ self.sl_hdlr = hcls(server.server_address)
self.log_output = ''
self.root_logger.removeHandler(self.root_logger.handlers[0])
self.root_logger.addHandler(self.sl_hdlr)
@@ -1559,6 +1593,23 @@ class SysLogHandlerTest(BaseTest):
@unittest.skipUnless(threading, 'Threading required for this test.')
+class UnixSysLogHandlerTest(SysLogHandlerTest):
+
+ """Test for SysLogHandler with Unix sockets."""
+
+ if threading:
+ server_class = TestUnixDatagramServer
+
+ def setUp(self):
+ # override the definition in the base class
+ self.address = _get_temp_domain_socket()
+ SysLogHandlerTest.setUp(self)
+
+ def tearDown(self):
+ SysLogHandlerTest.tearDown(self)
+ os.remove(self.address)
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
class HTTPHandlerTest(BaseTest):
"""Test for HTTPHandler."""
@@ -1779,7 +1830,7 @@ class WarningsTest(BaseTest):
logger.removeHandler(h)
s = stream.getvalue()
h.close()
- self.assertTrue(s.find("UserWarning: I'm warning you...\n") > 0)
+ self.assertGreater(s.find("UserWarning: I'm warning you...\n"), 0)
#See if an explicit file uses the original implementation
a_file = io.StringIO()
@@ -1817,7 +1868,7 @@ class ConfigDictTest(BaseTest):
"""Reading logging config from a dictionary."""
- expected_log_pat = r"^([\w]+) \+\+ ([\w]+)$"
+ expected_log_pat = r"^(\w+) \+\+ (\w+)$"
# config0 is a standard configuration.
config0 = {
@@ -2389,6 +2440,32 @@ class ConfigDictTest(BaseTest):
},
}
+ # As config0, but with properties
+ config14 = {
+ 'version': 1,
+ 'formatters': {
+ 'form1' : {
+ 'format' : '%(levelname)s ++ %(message)s',
+ },
+ },
+ 'handlers' : {
+ 'hand1' : {
+ 'class' : 'logging.StreamHandler',
+ 'formatter' : 'form1',
+ 'level' : 'NOTSET',
+ 'stream' : 'ext://sys.stdout',
+ '.': {
+ 'foo': 'bar',
+ 'terminator': '!\n',
+ }
+ },
+ },
+ 'root' : {
+ 'level' : 'WARNING',
+ 'handlers' : ['hand1'],
+ },
+ }
+
out_of_order = {
"version": 1,
"formatters": {
@@ -2656,11 +2733,20 @@ class ConfigDictTest(BaseTest):
def test_config13_failure(self):
self.assertRaises(Exception, self.apply_config, self.config13)
+ def test_config14_ok(self):
+ with captured_stdout() as output:
+ self.apply_config(self.config14)
+ h = logging._handlers['hand1']
+ self.assertEqual(h.foo, 'bar')
+ self.assertEqual(h.terminator, '!\n')
+ logging.warning('Exclamation')
+ self.assertTrue(output.getvalue().endswith('Exclamation!\n'))
+
@unittest.skipUnless(threading, 'listen() needs threading to work')
- def setup_via_listener(self, text):
+ def setup_via_listener(self, text, verify=None):
text = text.encode("utf-8")
# Ask for a randomly assigned port (by using port 0)
- t = logging.config.listen(0)
+ t = logging.config.listen(0, verify)
t.start()
t.ready.wait()
# Now get the port allocated
@@ -2720,6 +2806,69 @@ class ConfigDictTest(BaseTest):
# Original logger output is empty.
self.assert_log_lines([])
+ @unittest.skipUnless(threading, 'Threading required for this test.')
+ def test_listen_verify(self):
+
+ def verify_fail(stuff):
+ return None
+
+ def verify_reverse(stuff):
+ return stuff[::-1]
+
+ logger = logging.getLogger("compiler.parser")
+ to_send = textwrap.dedent(ConfigFileTest.config1)
+ # First, specify a verification function that will fail.
+ # We expect to see no output, since our configuration
+ # never took effect.
+ with captured_stdout() as output:
+ self.setup_via_listener(to_send, verify_fail)
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([], stream=output)
+ # Original logger output has the stuff we logged.
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], pat=r"^[\w.]+ -> (\w+): (\d+)$")
+
+ # Now, perform no verification. Our configuration
+ # should take effect.
+
+ with captured_stdout() as output:
+ self.setup_via_listener(to_send) # no verify callable specified
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '3'),
+ ('ERROR', '4'),
+ ], stream=output)
+ # Original logger output still has the stuff we logged before.
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], pat=r"^[\w.]+ -> (\w+): (\d+)$")
+
+ # Now, perform verification which transforms the bytes.
+
+ with captured_stdout() as output:
+ self.setup_via_listener(to_send[::-1], verify_reverse)
+ logger = logging.getLogger("compiler.parser")
+ # Both will output a message
+ logger.info(self.next_message())
+ logger.error(self.next_message())
+ self.assert_log_lines([
+ ('INFO', '5'),
+ ('ERROR', '6'),
+ ], stream=output)
+ # Original logger output still has the stuff we logged before.
+ self.assert_log_lines([
+ ('INFO', '1'),
+ ('ERROR', '2'),
+ ], pat=r"^[\w.]+ -> (\w+): (\d+)$")
+
def test_out_of_order(self):
self.apply_config(self.out_of_order)
handler = logging.getLogger('mymodule').handlers[0]
@@ -2779,14 +2928,14 @@ class ChildLoggerTest(BaseTest):
l2 = logging.getLogger('def.ghi')
c1 = r.getChild('xyz')
c2 = r.getChild('uvw.xyz')
- self.assertTrue(c1 is logging.getLogger('xyz'))
- self.assertTrue(c2 is logging.getLogger('uvw.xyz'))
+ self.assertIs(c1, logging.getLogger('xyz'))
+ self.assertIs(c2, logging.getLogger('uvw.xyz'))
c1 = l1.getChild('def')
c2 = c1.getChild('ghi')
c3 = l1.getChild('def.ghi')
- self.assertTrue(c1 is logging.getLogger('abc.def'))
- self.assertTrue(c2 is logging.getLogger('abc.def.ghi'))
- self.assertTrue(c2 is c3)
+ self.assertIs(c1, logging.getLogger('abc.def'))
+ self.assertIs(c2, logging.getLogger('abc.def.ghi'))
+ self.assertIs(c2, c3)
class DerivedLogRecord(logging.LogRecord):
@@ -2829,7 +2978,7 @@ class LogRecordFactoryTest(BaseTest):
class QueueHandlerTest(BaseTest):
# Do not bother with a logger name group.
- expected_log_pat = r"^[\w.]+ -> ([\w]+): ([\d]+)$"
+ expected_log_pat = r"^[\w.]+ -> (\w+): (\d+)$"
def setUp(self):
BaseTest.setUp(self)
@@ -3123,13 +3272,13 @@ class ShutdownTest(BaseTest):
self.assertEqual('0 - release', self.called[-1])
def test_with_ioerror_in_acquire(self):
- self._test_with_failure_in_method('acquire', IOError)
+ self._test_with_failure_in_method('acquire', OSError)
def test_with_ioerror_in_flush(self):
- self._test_with_failure_in_method('flush', IOError)
+ self._test_with_failure_in_method('flush', OSError)
def test_with_ioerror_in_close(self):
- self._test_with_failure_in_method('close', IOError)
+ self._test_with_failure_in_method('close', OSError)
def test_with_valueerror_in_acquire(self):
self._test_with_failure_in_method('acquire', ValueError)
@@ -3336,6 +3485,12 @@ class BasicConfigTest(unittest.TestCase):
self.assertEqual(logging.root.level, self.original_logging_level)
def test_filename(self):
+
+ def cleanup(h1, h2, fn):
+ h1.close()
+ h2.close()
+ os.remove(fn)
+
logging.basicConfig(filename='test.log')
self.assertEqual(len(logging.root.handlers), 1)
@@ -3343,17 +3498,23 @@ class BasicConfigTest(unittest.TestCase):
self.assertIsInstance(handler, logging.FileHandler)
expected = logging.FileHandler('test.log', 'a')
- self.addCleanup(expected.close)
self.assertEqual(handler.stream.mode, expected.stream.mode)
self.assertEqual(handler.stream.name, expected.stream.name)
+ self.addCleanup(cleanup, handler, expected, 'test.log')
def test_filemode(self):
+
+ def cleanup(h1, h2, fn):
+ h1.close()
+ h2.close()
+ os.remove(fn)
+
logging.basicConfig(filename='test.log', filemode='wb')
handler = logging.root.handlers[0]
expected = logging.FileHandler('test.log', 'wb')
- self.addCleanup(expected.close)
self.assertEqual(handler.stream.mode, expected.stream.mode)
+ self.addCleanup(cleanup, handler, expected, 'test.log')
def test_stream(self):
stream = io.StringIO()
@@ -3809,6 +3970,63 @@ class TimedRotatingFileHandlerTest(BaseFileTest):
assertRaises(ValueError, logging.handlers.TimedRotatingFileHandler,
self.fn, 'W7', delay=True)
+ def test_compute_rollover_daily_attime(self):
+ currentTime = 0
+ atTime = datetime.time(12, 0, 0)
+ rh = logging.handlers.TimedRotatingFileHandler(
+ self.fn, when='MIDNIGHT', interval=1, backupCount=0, utc=True,
+ atTime=atTime)
+ try:
+ actual = rh.computeRollover(currentTime)
+ self.assertEqual(actual, currentTime + 12 * 60 * 60)
+
+ actual = rh.computeRollover(currentTime + 13 * 60 * 60)
+ self.assertEqual(actual, currentTime + 36 * 60 * 60)
+ finally:
+ rh.close()
+
+ #@unittest.skipIf(True, 'Temporarily skipped while failures investigated.')
+ def test_compute_rollover_weekly_attime(self):
+ currentTime = int(time.time())
+ today = currentTime - currentTime % 86400
+
+ atTime = datetime.time(12, 0, 0)
+
+ wday = time.gmtime(today).tm_wday
+ for day in range(7):
+ rh = logging.handlers.TimedRotatingFileHandler(
+ self.fn, when='W%d' % day, interval=1, backupCount=0, utc=True,
+ atTime=atTime)
+ try:
+ if wday > day:
+ # The rollover day has already passed this week, so we
+ # go over into next week
+ expected = (7 - wday + day)
+ else:
+ expected = (day - wday)
+ # At this point expected is in days from now, convert to seconds
+ expected *= 24 * 60 * 60
+ # Add in the rollover time
+ expected += 12 * 60 * 60
+ # Add in adjustment for today
+ expected += today
+ actual = rh.computeRollover(today)
+ if actual != expected:
+ print('failed in timezone: %d' % time.timezone)
+ print('local vars: %s' % locals())
+ self.assertEqual(actual, expected)
+ if day == wday:
+ # goes into following week
+ expected += 7 * 24 * 60 * 60
+ actual = rh.computeRollover(today + 13 * 60 * 60)
+ if actual != expected:
+ print('failed in timezone: %d' % time.timezone)
+ print('local vars: %s' % locals())
+ self.assertEqual(actual, expected)
+ finally:
+ rh.close()
+
+
def secs(**kw):
return datetime.timedelta(**kw) // datetime.timedelta(seconds=1)
@@ -3867,7 +4085,7 @@ class NTEventLogHandlerTest(BaseTest):
h.handle(r)
h.close()
# Now see if the event is recorded
- self.assertTrue(num_recs < win32evtlog.GetNumberOfEventLogRecords(elh))
+ self.assertLess(num_recs, win32evtlog.GetNumberOfEventLogRecords(elh))
flags = win32evtlog.EVENTLOG_BACKWARDS_READ | \
win32evtlog.EVENTLOG_SEQUENTIAL_READ
found = False
@@ -3900,7 +4118,8 @@ def test_main():
SMTPHandlerTest, FileHandlerTest, RotatingFileHandlerTest,
LastResortTest, LogRecordTest, ExceptionTest,
SysLogHandlerTest, HTTPHandlerTest, NTEventLogHandlerTest,
- TimedRotatingFileHandlerTest
+ TimedRotatingFileHandlerTest, UnixSocketHandlerTest,
+ UnixDatagramHandlerTest, UnixSysLogHandlerTest
)
if __name__ == "__main__":
diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py
index baf1d6a3b2..6c30fed7c9 100644
--- a/Lib/test/test_long.py
+++ b/Lib/test/test_long.py
@@ -1079,7 +1079,7 @@ class LongTest(unittest.TestCase):
self.assertRaises(OverflowError, (256).to_bytes, 1, 'big', signed=True)
self.assertRaises(OverflowError, (256).to_bytes, 1, 'little', signed=False)
self.assertRaises(OverflowError, (256).to_bytes, 1, 'little', signed=True)
- self.assertRaises(OverflowError, (-1).to_bytes, 2, 'big', signed=False),
+ self.assertRaises(OverflowError, (-1).to_bytes, 2, 'big', signed=False)
self.assertRaises(OverflowError, (-1).to_bytes, 2, 'little', signed=False)
self.assertEqual((0).to_bytes(0, 'big'), b'')
self.assertEqual((1).to_bytes(5, 'big'), b'\x00\x00\x00\x00\x01')
diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py
index ad9045604e..26d19da5e9 100644
--- a/Lib/test/test_lzma.py
+++ b/Lib/test/test_lzma.py
@@ -371,6 +371,8 @@ class FileTestCase(unittest.TestCase):
pass
with LZMAFile(BytesIO(), "w") as f:
pass
+ with LZMAFile(BytesIO(), "x") as f:
+ pass
with LZMAFile(BytesIO(), "a") as f:
pass
@@ -398,13 +400,29 @@ class FileTestCase(unittest.TestCase):
with LZMAFile(TESTFN, "ab"):
pass
+ def test_init_with_x_mode(self):
+ self.addCleanup(unlink, TESTFN)
+ for mode in ("x", "xb"):
+ unlink(TESTFN)
+ with LZMAFile(TESTFN, mode):
+ pass
+ with self.assertRaises(FileExistsError):
+ with LZMAFile(TESTFN, mode):
+ pass
+
def test_init_bad_mode(self):
with self.assertRaises(ValueError):
LZMAFile(BytesIO(COMPRESSED_XZ), (3, "x"))
with self.assertRaises(ValueError):
LZMAFile(BytesIO(COMPRESSED_XZ), "")
with self.assertRaises(ValueError):
- LZMAFile(BytesIO(COMPRESSED_XZ), "x")
+ LZMAFile(BytesIO(COMPRESSED_XZ), "xt")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "x+")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "rx")
+ with self.assertRaises(ValueError):
+ LZMAFile(BytesIO(COMPRESSED_XZ), "wx")
with self.assertRaises(ValueError):
LZMAFile(BytesIO(COMPRESSED_XZ), "rt")
with self.assertRaises(ValueError):
@@ -678,6 +696,20 @@ class FileTestCase(unittest.TestCase):
with LZMAFile(BytesIO(COMPRESSED_XZ[:128])) as f:
self.assertRaises(EOFError, f.read)
+ def test_read_truncated(self):
+ # Drop stream footer: CRC (4 bytes), index size (4 bytes),
+ # flags (2 bytes) and magic number (2 bytes).
+ truncated = COMPRESSED_XZ[:-12]
+ with LZMAFile(BytesIO(truncated)) as f:
+ self.assertRaises(EOFError, f.read)
+ with LZMAFile(BytesIO(truncated)) as f:
+ self.assertEqual(f.read(len(INPUT)), INPUT)
+ self.assertRaises(EOFError, f.read, 1)
+ # Incomplete 12-byte header.
+ for i in range(12):
+ with LZMAFile(BytesIO(truncated[:i])) as f:
+ self.assertRaises(EOFError, f.read, 1)
+
def test_read_bad_args(self):
f = LZMAFile(BytesIO(COMPRESSED_XZ))
f.close()
@@ -1017,8 +1049,6 @@ class OpenTestCase(unittest.TestCase):
with self.assertRaises(ValueError):
lzma.open(TESTFN, "")
with self.assertRaises(ValueError):
- lzma.open(TESTFN, "x")
- with self.assertRaises(ValueError):
lzma.open(TESTFN, "rbt")
with self.assertRaises(ValueError):
lzma.open(TESTFN, "rb", encoding="utf-8")
@@ -1067,6 +1097,16 @@ class OpenTestCase(unittest.TestCase):
with lzma.open(bio, "rt", newline="\r") as f:
self.assertEqual(f.readlines(), [text])
+ def test_x_mode(self):
+ self.addCleanup(unlink, TESTFN)
+ for mode in ("x", "xb", "xt"):
+ unlink(TESTFN)
+ with lzma.open(TESTFN, mode):
+ pass
+ with self.assertRaises(FileExistsError):
+ with lzma.open(TESTFN, mode):
+ pass
+
class MiscellaneousTestCase(unittest.TestCase):
diff --git a/Lib/test/test_mailbox.py b/Lib/test/test_mailbox.py
index f2e4c63240..78e2a30dbc 100644
--- a/Lib/test/test_mailbox.py
+++ b/Lib/test/test_mailbox.py
@@ -596,7 +596,7 @@ class TestMaildir(TestMailbox, unittest.TestCase):
def setUp(self):
TestMailbox.setUp(self)
- if os.name in ('nt', 'os2') or sys.platform == 'cygwin':
+ if (os.name == 'nt') or (sys.platform == 'cygwin'):
self._box.colon = '!'
def assertMailboxEmpty(self):
diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py
index 7e37f39c6a..255eb87bb3 100644
--- a/Lib/test/test_marshal.py
+++ b/Lib/test/test_marshal.py
@@ -24,37 +24,13 @@ class HelperMixin:
class IntTestCase(unittest.TestCase, HelperMixin):
def test_ints(self):
- # Test the full range of Python ints.
- n = sys.maxsize
+ # Test a range of Python ints larger than the machine word size.
+ n = sys.maxsize ** 2
while n:
for expected in (-n, n):
self.helper(expected)
n = n >> 1
- def test_int64(self):
- # Simulate int marshaling on a 64-bit box. This is most interesting if
- # we're running the test on a 32-bit box, of course.
-
- def to_little_endian_string(value, nbytes):
- b = bytearray()
- for i in range(nbytes):
- b.append(value & 0xff)
- value >>= 8
- return b
-
- maxint64 = (1 << 63) - 1
- minint64 = -maxint64-1
-
- for base in maxint64, minint64, -maxint64, -(minint64 >> 1):
- while base:
- s = b'I' + to_little_endian_string(base, 8)
- got = marshal.loads(s)
- self.assertEqual(base, got)
- if base == -1: # a fixed-point for shifting right 1
- base = 0
- else:
- base >>= 1
-
def test_bool(self):
for b in (True, False):
self.helper(b)
@@ -201,10 +177,14 @@ class BugsTestCase(unittest.TestCase):
except Exception:
pass
- def test_loads_recursion(self):
+ def test_loads_2x_code(self):
s = b'c' + (b'X' * 4*4) + b'{' * 2**20
self.assertRaises(ValueError, marshal.loads, s)
+ def test_loads_recursion(self):
+ s = b'c' + (b'X' * 4*5) + b'{' * 2**20
+ self.assertRaises(ValueError, marshal.loads, s)
+
def test_recursion_limit(self):
# Create a deeply nested structure.
head = last = []
@@ -282,15 +262,20 @@ class BugsTestCase(unittest.TestCase):
def test_bad_reader(self):
class BadReader(io.BytesIO):
- def read(self, n=-1):
- b = super().read(n)
+ def readinto(self, buf):
+ n = super().readinto(buf)
if n is not None and n > 4:
- b += b' ' * 10**6
- return b
+ n += 10**6
+ return n
for value in (1.0, 1j, b'0123456789', '0123456789'):
self.assertRaises(ValueError, marshal.load,
BadReader(marshal.dumps(value)))
+ def _test_eof(self):
+ data = marshal.dumps(("hello", "dolly", None))
+ for i in range(len(data)):
+ self.assertRaises(EOFError, marshal.loads, data[0: i])
+
LARGE_SIZE = 2**31
pointer_size = 8 if sys.maxsize > 0xFFFFFFFF else 4
@@ -335,6 +320,122 @@ class LargeValuesTestCase(unittest.TestCase):
def test_bytearray(self, size):
self.check_unmarshallable(bytearray(size))
+def CollectObjectIDs(ids, obj):
+ """Collect object ids seen in a structure"""
+ if id(obj) in ids:
+ return
+ ids.add(id(obj))
+ if isinstance(obj, (list, tuple, set, frozenset)):
+ for e in obj:
+ CollectObjectIDs(ids, e)
+ elif isinstance(obj, dict):
+ for k, v in obj.items():
+ CollectObjectIDs(ids, k)
+ CollectObjectIDs(ids, v)
+ return len(ids)
+
+class InstancingTestCase(unittest.TestCase, HelperMixin):
+ intobj = 123321
+ floatobj = 1.2345
+ strobj = "abcde"*3
+ dictobj = {"hello":floatobj, "goodbye":floatobj, floatobj:"hello"}
+
+ def helper3(self, rsample, recursive=False, simple=False):
+ #we have two instances
+ sample = (rsample, rsample)
+
+ n0 = CollectObjectIDs(set(), sample)
+
+ s3 = marshal.dumps(sample, 3)
+ n3 = CollectObjectIDs(set(), marshal.loads(s3))
+
+ #same number of instances generated
+ self.assertEqual(n3, n0)
+
+ if not recursive:
+ #can compare with version 2
+ s2 = marshal.dumps(sample, 2)
+ n2 = CollectObjectIDs(set(), marshal.loads(s2))
+ #old format generated more instances
+ self.assertGreater(n2, n0)
+
+ #if complex objects are in there, old format is larger
+ if not simple:
+ self.assertGreater(len(s2), len(s3))
+ else:
+ self.assertGreaterEqual(len(s2), len(s3))
+
+ def testInt(self):
+ self.helper(self.intobj)
+ self.helper3(self.intobj, simple=True)
+
+ def testFloat(self):
+ self.helper(self.floatobj)
+ self.helper3(self.floatobj)
+
+ def testStr(self):
+ self.helper(self.strobj)
+ self.helper3(self.strobj)
+
+ def testDict(self):
+ self.helper(self.dictobj)
+ self.helper3(self.dictobj)
+
+ def testModule(self):
+ with open(__file__, "rb") as f:
+ code = f.read()
+ if __file__.endswith(".py"):
+ code = compile(code, __file__, "exec")
+ self.helper(code)
+ self.helper3(code)
+
+ def testRecursion(self):
+ d = dict(self.dictobj)
+ d["self"] = d
+ self.helper3(d, recursive=True)
+ l = [self.dictobj]
+ l.append(l)
+ self.helper3(l, recursive=True)
+
+class CompatibilityTestCase(unittest.TestCase):
+ def _test(self, version):
+ with open(__file__, "rb") as f:
+ code = f.read()
+ if __file__.endswith(".py"):
+ code = compile(code, __file__, "exec")
+ data = marshal.dumps(code, version)
+ marshal.loads(data)
+
+ def test0To3(self):
+ self._test(0)
+
+ def test1To3(self):
+ self._test(1)
+
+ def test2To3(self):
+ self._test(2)
+
+ def test3To3(self):
+ self._test(3)
+
+class InterningTestCase(unittest.TestCase, HelperMixin):
+ strobj = "this is an interned string"
+ strobj = sys.intern(strobj)
+
+ def testIntern(self):
+ s = marshal.loads(marshal.dumps(self.strobj))
+ self.assertEqual(s, self.strobj)
+ self.assertEqual(id(s), id(self.strobj))
+ s2 = sys.intern(s)
+ self.assertEqual(id(s2), id(s))
+
+ def testNoIntern(self):
+ s = marshal.loads(marshal.dumps(self.strobj, 2))
+ self.assertEqual(s, self.strobj)
+ self.assertNotEqual(id(s), id(self.strobj))
+ s2 = sys.intern(s)
+ self.assertNotEqual(id(s2), id(s))
+
def test_main():
support.run_unittest(IntTestCase,
diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py
index d611a3138b..4de4f65329 100644
--- a/Lib/test/test_memoryio.py
+++ b/Lib/test/test_memoryio.py
@@ -520,12 +520,12 @@ class TextIOTestMixin:
def test_relative_seek(self):
memio = self.ioclass()
- self.assertRaises(IOError, memio.seek, -1, 1)
- self.assertRaises(IOError, memio.seek, 3, 1)
- self.assertRaises(IOError, memio.seek, -3, 1)
- self.assertRaises(IOError, memio.seek, -1, 2)
- self.assertRaises(IOError, memio.seek, 1, 1)
- self.assertRaises(IOError, memio.seek, 1, 2)
+ self.assertRaises(OSError, memio.seek, -1, 1)
+ self.assertRaises(OSError, memio.seek, 3, 1)
+ self.assertRaises(OSError, memio.seek, -3, 1)
+ self.assertRaises(OSError, memio.seek, -1, 2)
+ self.assertRaises(OSError, memio.seek, 1, 1)
+ self.assertRaises(OSError, memio.seek, 1, 2)
def test_textio_properties(self):
memio = self.ioclass()
diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py
index ee6b15ac14..ffd4f5851a 100644
--- a/Lib/test/test_memoryview.py
+++ b/Lib/test/test_memoryview.py
@@ -352,6 +352,15 @@ class AbstractMemoryTests:
self.assertIs(wr(), None)
self.assertIs(L[0], b)
+ def test_reversed(self):
+ for tp in self._types:
+ b = tp(self._source)
+ m = self._view(b)
+ aslist = list(reversed(m.tolist()))
+ self.assertEqual(list(reversed(m)), aslist)
+ self.assertEqual(list(reversed(m)), list(m[::-1]))
+
+
# Variations on source objects for the buffer: bytes-like objects, then arrays
# with itemsize > 1.
# NOTE: support for multi-dimensional objects is unimplemented.
diff --git a/Lib/test/test_mmap.py b/Lib/test/test_mmap.py
index 899df8d818..6ca5e1b730 100644
--- a/Lib/test/test_mmap.py
+++ b/Lib/test/test_mmap.py
@@ -1,11 +1,12 @@
from test.support import (TESTFN, run_unittest, import_module, unlink,
- requires, _2G, _4G)
+ requires, _2G, _4G, gc_collect)
import unittest
import os
import re
import itertools
import socket
import sys
+import weakref
# Skip test if we can't import mmap.
mmap = import_module('mmap')
@@ -245,7 +246,7 @@ class MmapTests(unittest.TestCase):
def test_bad_file_desc(self):
# Try opening a bad file descriptor...
- self.assertRaises(mmap.error, mmap.mmap, -2, 4096)
+ self.assertRaises(OSError, mmap.mmap, -2, 4096)
def test_tougher_find(self):
# Do a tougher .find() test. SF bug 515943 pointed out that, in 2.2,
@@ -655,7 +656,7 @@ class MmapTests(unittest.TestCase):
m = mmap.mmap(f.fileno(), 0)
f.close()
try:
- m.resize(0) # will raise WindowsError
+ m.resize(0) # will raise OSError
except:
pass
try:
@@ -671,7 +672,7 @@ class MmapTests(unittest.TestCase):
# parameters to _get_osfhandle.
s = socket.socket()
try:
- with self.assertRaises(mmap.error):
+ with self.assertRaises(OSError):
m = mmap.mmap(s.fileno(), 10)
finally:
s.close()
@@ -682,14 +683,23 @@ class MmapTests(unittest.TestCase):
self.assertTrue(m.closed)
def test_context_manager_exception(self):
- # Test that the IOError gets passed through
+ # Test that the OSError gets passed through
with self.assertRaises(Exception) as exc:
with mmap.mmap(-1, 10) as m:
- raise IOError
- self.assertIsInstance(exc.exception, IOError,
+ raise OSError
+ self.assertIsInstance(exc.exception, OSError,
"wrong exception raised in context manager")
self.assertTrue(m.closed, "context manager failed")
+ def test_weakref(self):
+ # Check mmap objects are weakrefable
+ mm = mmap.mmap(-1, 16)
+ wr = weakref.ref(mm)
+ self.assertIs(wr(), mm)
+ del mm
+ gc_collect()
+ self.assertIs(wr(), None)
+
class LargeMmapTests(unittest.TestCase):
def setUp(self):
@@ -707,7 +717,7 @@ class LargeMmapTests(unittest.TestCase):
f.seek(num_zeroes)
f.write(tail)
f.flush()
- except (IOError, OverflowError):
+ except (OSError, OverflowError):
f.close()
raise unittest.SkipTest("filesystem does not have largefile support")
return f
diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py
index e5a2525d18..5a9b503485 100644
--- a/Lib/test/test_module.py
+++ b/Lib/test/test_module.py
@@ -1,6 +1,8 @@
# Test the module type
import unittest
+import weakref
from test.support import run_unittest, gc_collect
+from test.script_helper import assert_python_ok
import sys
ModuleType = type(sys)
@@ -33,7 +35,10 @@ class ModuleTests(unittest.TestCase):
foo = ModuleType("foo")
self.assertEqual(foo.__name__, "foo")
self.assertEqual(foo.__doc__, None)
- self.assertEqual(foo.__dict__, {"__name__": "foo", "__doc__": None})
+ self.assertIs(foo.__loader__, None)
+ self.assertIs(foo.__package__, None)
+ self.assertEqual(foo.__dict__, {"__name__": "foo", "__doc__": None,
+ "__loader__": None, "__package__": None})
def test_ascii_docstring(self):
# ASCII docstring
@@ -41,7 +46,8 @@ class ModuleTests(unittest.TestCase):
self.assertEqual(foo.__name__, "foo")
self.assertEqual(foo.__doc__, "foodoc")
self.assertEqual(foo.__dict__,
- {"__name__": "foo", "__doc__": "foodoc"})
+ {"__name__": "foo", "__doc__": "foodoc",
+ "__loader__": None, "__package__": None})
def test_unicode_docstring(self):
# Unicode docstring
@@ -49,7 +55,8 @@ class ModuleTests(unittest.TestCase):
self.assertEqual(foo.__name__, "foo")
self.assertEqual(foo.__doc__, "foodoc\u1234")
self.assertEqual(foo.__dict__,
- {"__name__": "foo", "__doc__": "foodoc\u1234"})
+ {"__name__": "foo", "__doc__": "foodoc\u1234",
+ "__loader__": None, "__package__": None})
def test_reinit(self):
# Reinitialization should not replace the __dict__
@@ -61,10 +68,10 @@ class ModuleTests(unittest.TestCase):
self.assertEqual(foo.__doc__, "foodoc")
self.assertEqual(foo.bar, 42)
self.assertEqual(foo.__dict__,
- {"__name__": "foo", "__doc__": "foodoc", "bar": 42})
+ {"__name__": "foo", "__doc__": "foodoc", "bar": 42,
+ "__loader__": None, "__package__": None})
self.assertTrue(foo.__dict__ is d)
- @unittest.expectedFailure
def test_dont_clear_dict(self):
# See issue 7140.
def f():
@@ -89,6 +96,14 @@ a = A(destroyed)"""
gc_collect()
self.assertEqual(destroyed, [1])
+ def test_weakref(self):
+ m = ModuleType("foo")
+ wr = weakref.ref(m)
+ self.assertIs(wr(), m)
+ del m
+ gc_collect()
+ self.assertIs(wr(), None)
+
def test_module_repr_minimal(self):
# reprs when modules have no __file__, __name__, or __loader__
m = ModuleType('foo')
@@ -110,13 +125,19 @@ a = A(destroyed)"""
m.__file__ = '/tmp/foo.py'
self.assertEqual(repr(m), "<module '?' from '/tmp/foo.py'>")
+ def test_module_repr_with_loader_as_None(self):
+ m = ModuleType('foo')
+ assert m.__loader__ is None
+ self.assertEqual(repr(m), "<module 'foo'>")
+
def test_module_repr_with_bare_loader_but_no_name(self):
m = ModuleType('foo')
del m.__name__
# Yes, a class not an instance.
m.__loader__ = BareLoader
+ loader_repr = repr(BareLoader)
self.assertEqual(
- repr(m), "<module '?' (<class 'test.test_module.BareLoader'>)>")
+ repr(m), "<module '?' ({})>".format(loader_repr))
def test_module_repr_with_full_loader_but_no_name(self):
# m.__loader__.module_repr() will fail because the module has no
@@ -126,15 +147,17 @@ a = A(destroyed)"""
del m.__name__
# Yes, a class not an instance.
m.__loader__ = FullLoader
+ loader_repr = repr(FullLoader)
self.assertEqual(
- repr(m), "<module '?' (<class 'test.test_module.FullLoader'>)>")
+ repr(m), "<module '?' ({})>".format(loader_repr))
def test_module_repr_with_bare_loader(self):
m = ModuleType('foo')
# Yes, a class not an instance.
m.__loader__ = BareLoader
+ module_repr = repr(BareLoader)
self.assertEqual(
- repr(m), "<module 'foo' (<class 'test.test_module.BareLoader'>)>")
+ repr(m), "<module 'foo' ({})>".format(module_repr))
def test_module_repr_with_full_loader(self):
m = ModuleType('foo')
@@ -167,6 +190,19 @@ a = A(destroyed)"""
self.assertEqual(r[:25], "<module 'unittest' from '")
self.assertEqual(r[-13:], "__init__.py'>")
+ def test_module_finalization_at_shutdown(self):
+ # Module globals and builtins should still be available during shutdown
+ rc, out, err = assert_python_ok("-c", "from test import final_a")
+ self.assertFalse(err)
+ lines = out.splitlines()
+ self.assertEqual(set(lines), {
+ b"x = a",
+ b"x = b",
+ b"final_a.x = a",
+ b"final_b.x = b",
+ b"len = len",
+ b"shutil.rmtree = rmtree"})
+
# frozen and namespace module reprs are tested in importlib.
diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py
index feb7bd595a..91148a6fc5 100644
--- a/Lib/test/test_multibytecodec.py
+++ b/Lib/test/test_multibytecodec.py
@@ -176,57 +176,28 @@ class Test_StreamReader(unittest.TestCase):
support.unlink(TESTFN)
class Test_StreamWriter(unittest.TestCase):
- if len('\U00012345') == 2: # UCS2
- def test_gb18030(self):
- s= io.BytesIO()
- c = codecs.getwriter('gb18030')(s)
- c.write('123')
- self.assertEqual(s.getvalue(), b'123')
- c.write('\U00012345')
- self.assertEqual(s.getvalue(), b'123\x907\x959')
- c.write('\U00012345'[0])
- self.assertEqual(s.getvalue(), b'123\x907\x959')
- c.write('\U00012345'[1] + '\U00012345' + '\uac00\u00ac')
- self.assertEqual(s.getvalue(),
- b'123\x907\x959\x907\x959\x907\x959\x827\xcf5\x810\x851')
- c.write('\U00012345'[0])
- self.assertEqual(s.getvalue(),
- b'123\x907\x959\x907\x959\x907\x959\x827\xcf5\x810\x851')
- self.assertRaises(UnicodeError, c.reset)
- self.assertEqual(s.getvalue(),
- b'123\x907\x959\x907\x959\x907\x959\x827\xcf5\x810\x851')
-
- def test_utf_8(self):
- s= io.BytesIO()
- c = codecs.getwriter('utf-8')(s)
- c.write('123')
- self.assertEqual(s.getvalue(), b'123')
- c.write('\U00012345')
- self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85')
-
- # Python utf-8 codec can't buffer surrogate pairs yet.
- if 0:
- c.write('\U00012345'[0])
- self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85')
- c.write('\U00012345'[1] + '\U00012345' + '\uac00\u00ac')
- self.assertEqual(s.getvalue(),
- b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85'
- b'\xea\xb0\x80\xc2\xac')
- c.write('\U00012345'[0])
- self.assertEqual(s.getvalue(),
- b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85'
- b'\xea\xb0\x80\xc2\xac')
- c.reset()
- self.assertEqual(s.getvalue(),
- b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85'
- b'\xea\xb0\x80\xc2\xac\xed\xa0\x88')
- c.write('\U00012345'[1])
- self.assertEqual(s.getvalue(),
- b'123\xf0\x92\x8d\x85\xf0\x92\x8d\x85\xf0\x92\x8d\x85'
- b'\xea\xb0\x80\xc2\xac\xed\xa0\x88\xed\xbd\x85')
-
- else: # UCS4
- pass
+ def test_gb18030(self):
+ s= io.BytesIO()
+ c = codecs.getwriter('gb18030')(s)
+ c.write('123')
+ self.assertEqual(s.getvalue(), b'123')
+ c.write('\U00012345')
+ self.assertEqual(s.getvalue(), b'123\x907\x959')
+ c.write('\uac00\u00ac')
+ self.assertEqual(s.getvalue(),
+ b'123\x907\x959\x827\xcf5\x810\x851')
+
+ def test_utf_8(self):
+ s= io.BytesIO()
+ c = codecs.getwriter('utf-8')(s)
+ c.write('123')
+ self.assertEqual(s.getvalue(), b'123')
+ c.write('\U00012345')
+ self.assertEqual(s.getvalue(), b'123\xf0\x92\x8d\x85')
+ c.write('\uac00\u00ac')
+ self.assertEqual(s.getvalue(),
+ b'123\xf0\x92\x8d\x85'
+ b'\xea\xb0\x80\xc2\xac')
def test_streamwriter_strwrite(self):
s = io.BytesIO()
diff --git a/Lib/test/test_multiprocessing_fork.py b/Lib/test/test_multiprocessing_fork.py
new file mode 100644
index 0000000000..2bf4e75644
--- /dev/null
+++ b/Lib/test/test_multiprocessing_fork.py
@@ -0,0 +1,7 @@
+import unittest
+import test._test_multiprocessing
+
+test._test_multiprocessing.install_tests_in_module_dict(globals(), 'fork')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_multiprocessing_forkserver.py b/Lib/test/test_multiprocessing_forkserver.py
new file mode 100644
index 0000000000..193a04a5fc
--- /dev/null
+++ b/Lib/test/test_multiprocessing_forkserver.py
@@ -0,0 +1,7 @@
+import unittest
+import test._test_multiprocessing
+
+test._test_multiprocessing.install_tests_in_module_dict(globals(), 'forkserver')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_multiprocessing_spawn.py b/Lib/test/test_multiprocessing_spawn.py
new file mode 100644
index 0000000000..334ae9e8f7
--- /dev/null
+++ b/Lib/test/test_multiprocessing_spawn.py
@@ -0,0 +1,7 @@
+import unittest
+import test._test_multiprocessing
+
+test._test_multiprocessing.install_tests_in_module_dict(globals(), 'spawn')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_namespace_pkgs.py b/Lib/test/test_namespace_pkgs.py
index 7067b12e8f..4570bee568 100644
--- a/Lib/test/test_namespace_pkgs.py
+++ b/Lib/test/test_namespace_pkgs.py
@@ -1,7 +1,11 @@
-import sys
import contextlib
-import unittest
+from importlib._bootstrap import NamespaceLoader
+import importlib.abc
+import importlib.machinery
import os
+import sys
+import types
+import unittest
from test.test_importlib import util
from test.support import run_unittest
@@ -286,9 +290,24 @@ class ModuleAndNamespacePackageInSameDir(NamespacePackageTest):
self.assertEqual(a_test.attr, 'in module')
-def test_main():
- run_unittest(*NamespacePackageTest.__subclasses__())
+class ABCTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loader = NamespaceLoader('foo', ['pkg'],
+ importlib.machinery.PathFinder)
+
+ def test_is_package(self):
+ self.assertTrue(self.loader.is_package('foo'))
+
+ def test_get_code(self):
+ self.assertTrue(isinstance(self.loader.get_code('foo'), types.CodeType))
+
+ def test_get_source(self):
+ self.assertEqual(self.loader.get_source('foo'), '')
+
+ def test_abc_isinstance(self):
+ self.assertTrue(isinstance(self.loader, importlib.abc.InspectLoader))
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py
index 1d52713f07..71a4ec022b 100644
--- a/Lib/test/test_nntplib.py
+++ b/Lib/test/test_nntplib.py
@@ -266,7 +266,7 @@ class NetworkedNNTPTestsMixin:
return False
try:
server.help()
- except (socket.error, EOFError):
+ except (OSError, EOFError):
return False
return True
diff --git a/Lib/test/test_normalization.py b/Lib/test/test_normalization.py
index 28ede34cd7..ab2eeb77da 100644
--- a/Lib/test/test_normalization.py
+++ b/Lib/test/test_normalization.py
@@ -43,7 +43,7 @@ class NormalizationTest(unittest.TestCase):
try:
testdata = open_urlresource(TESTDATAURL, encoding="utf-8",
check=check_version)
- except (IOError, HTTPException):
+ except (OSError, HTTPException):
self.skipTest("Could not retrieve " + TESTDATAURL)
self.addCleanup(testdata.close)
for line in testdata:
diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py
index f8098767ce..285ef62a46 100644
--- a/Lib/test/test_ntpath.py
+++ b/Lib/test/test_ntpath.py
@@ -256,6 +256,40 @@ class TestNtpath(unittest.TestCase):
# dialogs (#4804)
ntpath.sameopenfile(-1, -1)
+ def test_ismount(self):
+ self.assertTrue(ntpath.ismount("c:\\"))
+ self.assertTrue(ntpath.ismount("C:\\"))
+ self.assertTrue(ntpath.ismount("c:/"))
+ self.assertTrue(ntpath.ismount("C:/"))
+ self.assertTrue(ntpath.ismount("\\\\.\\c:\\"))
+ self.assertTrue(ntpath.ismount("\\\\.\\C:\\"))
+
+ self.assertTrue(ntpath.ismount(b"c:\\"))
+ self.assertTrue(ntpath.ismount(b"C:\\"))
+ self.assertTrue(ntpath.ismount(b"c:/"))
+ self.assertTrue(ntpath.ismount(b"C:/"))
+ self.assertTrue(ntpath.ismount(b"\\\\.\\c:\\"))
+ self.assertTrue(ntpath.ismount(b"\\\\.\\C:\\"))
+
+ with support.temp_dir() as d:
+ self.assertFalse(ntpath.ismount(d))
+
+ if sys.platform == "win32":
+ #
+ # Make sure the current folder isn't the root folder
+ # (or any other volume root). The drive-relative
+ # locations below cannot then refer to mount points
+ #
+ drive, path = ntpath.splitdrive(sys.executable)
+ with support.change_cwd(os.path.dirname(sys.executable)):
+ self.assertFalse(ntpath.ismount(drive.lower()))
+ self.assertFalse(ntpath.ismount(drive.upper()))
+
+ self.assertTrue(ntpath.ismount("\\\\localhost\\c$"))
+ self.assertTrue(ntpath.ismount("\\\\localhost\\c$\\"))
+
+ self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$"))
+ self.assertTrue(ntpath.ismount(b"\\\\localhost\\c$\\"))
class NtCommonTest(test_genericpath.CommonTest, unittest.TestCase):
pathmodule = ntpath
diff --git a/Lib/test/test_openpty.py b/Lib/test/test_openpty.py
index 63843705b1..47851072c1 100644
--- a/Lib/test/test_openpty.py
+++ b/Lib/test/test_openpty.py
@@ -4,7 +4,7 @@ import os, unittest
from test.support import run_unittest
if not hasattr(os, "openpty"):
- raise unittest.SkipTest("No openpty() available.")
+ raise unittest.SkipTest("os.openpty() not available.")
class OpenptyTest(unittest.TestCase):
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
index fa608b9a52..ab58a98365 100644
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -1,8 +1,10 @@
-import operator
import unittest
from test import support
+py_operator = support.import_fresh_module('operator', blocked=['_operator'])
+c_operator = support.import_fresh_module('operator', fresh=['_operator'])
+
class Seq1:
def __init__(self, lst):
self.lst = lst
@@ -32,8 +34,9 @@ class Seq2(object):
return other * self.lst
-class OperatorTestCase(unittest.TestCase):
+class OperatorTestCase:
def test_lt(self):
+ operator = self.module
self.assertRaises(TypeError, operator.lt)
self.assertRaises(TypeError, operator.lt, 1j, 2j)
self.assertFalse(operator.lt(1, 0))
@@ -44,6 +47,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertTrue(operator.lt(1, 2.0))
def test_le(self):
+ operator = self.module
self.assertRaises(TypeError, operator.le)
self.assertRaises(TypeError, operator.le, 1j, 2j)
self.assertFalse(operator.le(1, 0))
@@ -54,6 +58,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertTrue(operator.le(1, 2.0))
def test_eq(self):
+ operator = self.module
class C(object):
def __eq__(self, other):
raise SyntaxError
@@ -67,6 +72,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertFalse(operator.eq(1, 2.0))
def test_ne(self):
+ operator = self.module
class C(object):
def __ne__(self, other):
raise SyntaxError
@@ -80,6 +86,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertTrue(operator.ne(1, 2.0))
def test_ge(self):
+ operator = self.module
self.assertRaises(TypeError, operator.ge)
self.assertRaises(TypeError, operator.ge, 1j, 2j)
self.assertTrue(operator.ge(1, 0))
@@ -90,6 +97,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertFalse(operator.ge(1, 2.0))
def test_gt(self):
+ operator = self.module
self.assertRaises(TypeError, operator.gt)
self.assertRaises(TypeError, operator.gt, 1j, 2j)
self.assertTrue(operator.gt(1, 0))
@@ -100,22 +108,26 @@ class OperatorTestCase(unittest.TestCase):
self.assertFalse(operator.gt(1, 2.0))
def test_abs(self):
+ operator = self.module
self.assertRaises(TypeError, operator.abs)
self.assertRaises(TypeError, operator.abs, None)
self.assertEqual(operator.abs(-1), 1)
self.assertEqual(operator.abs(1), 1)
def test_add(self):
+ operator = self.module
self.assertRaises(TypeError, operator.add)
self.assertRaises(TypeError, operator.add, None, None)
self.assertTrue(operator.add(3, 4) == 7)
def test_bitwise_and(self):
+ operator = self.module
self.assertRaises(TypeError, operator.and_)
self.assertRaises(TypeError, operator.and_, None, None)
self.assertTrue(operator.and_(0xf, 0xa) == 0xa)
def test_concat(self):
+ operator = self.module
self.assertRaises(TypeError, operator.concat)
self.assertRaises(TypeError, operator.concat, None, None)
self.assertTrue(operator.concat('py', 'thon') == 'python')
@@ -125,12 +137,14 @@ class OperatorTestCase(unittest.TestCase):
self.assertRaises(TypeError, operator.concat, 13, 29)
def test_countOf(self):
+ operator = self.module
self.assertRaises(TypeError, operator.countOf)
self.assertRaises(TypeError, operator.countOf, None, None)
self.assertTrue(operator.countOf([1, 2, 1, 3, 1, 4], 3) == 1)
self.assertTrue(operator.countOf([1, 2, 1, 3, 1, 4], 5) == 0)
def test_delitem(self):
+ operator = self.module
a = [4, 3, 2, 1]
self.assertRaises(TypeError, operator.delitem, a)
self.assertRaises(TypeError, operator.delitem, a, None)
@@ -138,33 +152,39 @@ class OperatorTestCase(unittest.TestCase):
self.assertTrue(a == [4, 2, 1])
def test_floordiv(self):
+ operator = self.module
self.assertRaises(TypeError, operator.floordiv, 5)
self.assertRaises(TypeError, operator.floordiv, None, None)
self.assertTrue(operator.floordiv(5, 2) == 2)
def test_truediv(self):
+ operator = self.module
self.assertRaises(TypeError, operator.truediv, 5)
self.assertRaises(TypeError, operator.truediv, None, None)
self.assertTrue(operator.truediv(5, 2) == 2.5)
def test_getitem(self):
+ operator = self.module
a = range(10)
self.assertRaises(TypeError, operator.getitem)
self.assertRaises(TypeError, operator.getitem, a, None)
self.assertTrue(operator.getitem(a, 2) == 2)
def test_indexOf(self):
+ operator = self.module
self.assertRaises(TypeError, operator.indexOf)
self.assertRaises(TypeError, operator.indexOf, None, None)
self.assertTrue(operator.indexOf([4, 3, 2, 1], 3) == 1)
self.assertRaises(ValueError, operator.indexOf, [4, 3, 2, 1], 0)
def test_invert(self):
+ operator = self.module
self.assertRaises(TypeError, operator.invert)
self.assertRaises(TypeError, operator.invert, None)
self.assertEqual(operator.inv(4), -5)
def test_lshift(self):
+ operator = self.module
self.assertRaises(TypeError, operator.lshift)
self.assertRaises(TypeError, operator.lshift, None, 42)
self.assertTrue(operator.lshift(5, 1) == 10)
@@ -172,16 +192,19 @@ class OperatorTestCase(unittest.TestCase):
self.assertRaises(ValueError, operator.lshift, 2, -1)
def test_mod(self):
+ operator = self.module
self.assertRaises(TypeError, operator.mod)
self.assertRaises(TypeError, operator.mod, None, 42)
self.assertTrue(operator.mod(5, 2) == 1)
def test_mul(self):
+ operator = self.module
self.assertRaises(TypeError, operator.mul)
self.assertRaises(TypeError, operator.mul, None, None)
self.assertTrue(operator.mul(5, 2) == 10)
def test_neg(self):
+ operator = self.module
self.assertRaises(TypeError, operator.neg)
self.assertRaises(TypeError, operator.neg, None)
self.assertEqual(operator.neg(5), -5)
@@ -190,11 +213,13 @@ class OperatorTestCase(unittest.TestCase):
self.assertEqual(operator.neg(-0), 0)
def test_bitwise_or(self):
+ operator = self.module
self.assertRaises(TypeError, operator.or_)
self.assertRaises(TypeError, operator.or_, None, None)
self.assertTrue(operator.or_(0xa, 0x5) == 0xf)
def test_pos(self):
+ operator = self.module
self.assertRaises(TypeError, operator.pos)
self.assertRaises(TypeError, operator.pos, None)
self.assertEqual(operator.pos(5), 5)
@@ -203,14 +228,15 @@ class OperatorTestCase(unittest.TestCase):
self.assertEqual(operator.pos(-0), 0)
def test_pow(self):
+ operator = self.module
self.assertRaises(TypeError, operator.pow)
self.assertRaises(TypeError, operator.pow, None, None)
self.assertEqual(operator.pow(3,5), 3**5)
- self.assertEqual(operator.__pow__(3,5), 3**5)
self.assertRaises(TypeError, operator.pow, 1)
self.assertRaises(TypeError, operator.pow, 1, 2, 3)
def test_rshift(self):
+ operator = self.module
self.assertRaises(TypeError, operator.rshift)
self.assertRaises(TypeError, operator.rshift, None, 42)
self.assertTrue(operator.rshift(5, 1) == 2)
@@ -218,12 +244,14 @@ class OperatorTestCase(unittest.TestCase):
self.assertRaises(ValueError, operator.rshift, 2, -1)
def test_contains(self):
+ operator = self.module
self.assertRaises(TypeError, operator.contains)
self.assertRaises(TypeError, operator.contains, None, None)
self.assertTrue(operator.contains(range(4), 2))
self.assertFalse(operator.contains(range(4), 5))
def test_setitem(self):
+ operator = self.module
a = list(range(3))
self.assertRaises(TypeError, operator.setitem, a)
self.assertRaises(TypeError, operator.setitem, a, None, None)
@@ -232,11 +260,13 @@ class OperatorTestCase(unittest.TestCase):
self.assertRaises(IndexError, operator.setitem, a, 4, 2)
def test_sub(self):
+ operator = self.module
self.assertRaises(TypeError, operator.sub)
self.assertRaises(TypeError, operator.sub, None, None)
self.assertTrue(operator.sub(5, 2) == 3)
def test_truth(self):
+ operator = self.module
class C(object):
def __bool__(self):
raise SyntaxError
@@ -248,11 +278,13 @@ class OperatorTestCase(unittest.TestCase):
self.assertFalse(operator.truth([]))
def test_bitwise_xor(self):
+ operator = self.module
self.assertRaises(TypeError, operator.xor)
self.assertRaises(TypeError, operator.xor, None, None)
self.assertTrue(operator.xor(0xb, 0xc) == 0x7)
def test_is(self):
+ operator = self.module
a = b = 'xyzpdq'
c = a[:3] + b[3:]
self.assertRaises(TypeError, operator.is_)
@@ -260,6 +292,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertFalse(operator.is_(a,c))
def test_is_not(self):
+ operator = self.module
a = b = 'xyzpdq'
c = a[:3] + b[3:]
self.assertRaises(TypeError, operator.is_not)
@@ -267,6 +300,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertTrue(operator.is_not(a,c))
def test_attrgetter(self):
+ operator = self.module
class A:
pass
a = A()
@@ -316,6 +350,7 @@ class OperatorTestCase(unittest.TestCase):
self.assertEqual(f(a), ('arthur', 'thomas', 'johnson'))
def test_itemgetter(self):
+ operator = self.module
a = 'ABCDE'
f = operator.itemgetter(2)
self.assertEqual(f(a), 'C')
@@ -350,12 +385,15 @@ class OperatorTestCase(unittest.TestCase):
self.assertRaises(TypeError, operator.itemgetter(2, 'x', 5), data)
def test_methodcaller(self):
+ operator = self.module
self.assertRaises(TypeError, operator.methodcaller)
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]
def bar(self, f=42):
return f
+ def baz(*args, **kwds):
+ return kwds['name'], kwds['self']
a = A()
f = operator.methodcaller('foo')
self.assertRaises(IndexError, f, a)
@@ -366,8 +404,11 @@ class OperatorTestCase(unittest.TestCase):
self.assertRaises(TypeError, f, a, a)
f = operator.methodcaller('bar', f=5)
self.assertEqual(f(a), 5)
+ f = operator.methodcaller('baz', name='spam', self='eggs')
+ self.assertEqual(f(a), ('spam', 'eggs'))
def test_inplace(self):
+ operator = self.module
class C(object):
def __iadd__ (self, other): return "iadd"
def __iand__ (self, other): return "iand"
@@ -396,37 +437,48 @@ class OperatorTestCase(unittest.TestCase):
self.assertEqual(operator.itruediv (c, 5), "itruediv")
self.assertEqual(operator.ixor (c, 5), "ixor")
self.assertEqual(operator.iconcat (c, c), "iadd")
- self.assertEqual(operator.__iadd__ (c, 5), "iadd")
- self.assertEqual(operator.__iand__ (c, 5), "iand")
- self.assertEqual(operator.__ifloordiv__(c, 5), "ifloordiv")
- self.assertEqual(operator.__ilshift__ (c, 5), "ilshift")
- self.assertEqual(operator.__imod__ (c, 5), "imod")
- self.assertEqual(operator.__imul__ (c, 5), "imul")
- self.assertEqual(operator.__ior__ (c, 5), "ior")
- self.assertEqual(operator.__ipow__ (c, 5), "ipow")
- self.assertEqual(operator.__irshift__ (c, 5), "irshift")
- self.assertEqual(operator.__isub__ (c, 5), "isub")
- self.assertEqual(operator.__itruediv__ (c, 5), "itruediv")
- self.assertEqual(operator.__ixor__ (c, 5), "ixor")
- self.assertEqual(operator.__iconcat__ (c, c), "iadd")
-
-def test_main(verbose=None):
- import sys
- test_classes = (
- OperatorTestCase,
- )
-
- support.run_unittest(*test_classes)
-
- # verify reference counting
- if verbose and hasattr(sys, "gettotalrefcount"):
- import gc
- counts = [None] * 5
- for i in range(len(counts)):
- support.run_unittest(*test_classes)
- gc.collect()
- counts[i] = sys.gettotalrefcount()
- print(counts)
+
+ def test_length_hint(self):
+ operator = self.module
+ class X(object):
+ def __init__(self, value):
+ self.value = value
+
+ def __length_hint__(self):
+ if type(self.value) is type:
+ raise self.value
+ else:
+ return self.value
+
+ self.assertEqual(operator.length_hint([], 2), 0)
+ self.assertEqual(operator.length_hint(iter([1, 2, 3])), 3)
+
+ self.assertEqual(operator.length_hint(X(2)), 2)
+ self.assertEqual(operator.length_hint(X(NotImplemented), 4), 4)
+ self.assertEqual(operator.length_hint(X(TypeError), 12), 12)
+ with self.assertRaises(TypeError):
+ operator.length_hint(X("abc"))
+ with self.assertRaises(ValueError):
+ operator.length_hint(X(-2))
+ with self.assertRaises(LookupError):
+ operator.length_hint(X(LookupError))
+
+ def test_dunder_is_original(self):
+ operator = self.module
+
+ names = [name for name in dir(operator) if not name.startswith('_')]
+ for name in names:
+ orig = getattr(operator, name)
+ dunder = getattr(operator, '__' + name.strip('_') + '__', None)
+ if dunder:
+ self.assertIs(dunder, orig)
+
+class PyOperatorTestCase(OperatorTestCase, unittest.TestCase):
+ module = py_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class COperatorTestCase(OperatorTestCase, unittest.TestCase):
+ module = c_operator
if __name__ == "__main__":
- test_main(verbose=True)
+ unittest.main()
diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py
index 78de278f76..94730119e9 100644
--- a/Lib/test/test_optparse.py
+++ b/Lib/test/test_optparse.py
@@ -730,7 +730,7 @@ class TestStandard(BaseTest):
def test_short_and_long_option_split(self):
self.assertParseOK(["-a", "xyz", "--foo", "bar"],
{'a': 'xyz', 'boo': None, 'foo': ["bar"]},
- []),
+ [])
def test_short_option_split_long_option_append(self):
self.assertParseOK(["--foo=bar", "-b", "123", "--foo", "baz"],
@@ -740,15 +740,15 @@ class TestStandard(BaseTest):
def test_short_option_split_one_positional_arg(self):
self.assertParseOK(["-a", "foo", "bar"],
{'a': "foo", 'boo': None, 'foo': None},
- ["bar"]),
+ ["bar"])
def test_short_option_consumes_separator(self):
self.assertParseOK(["-a", "--", "foo", "bar"],
{'a': "--", 'boo': None, 'foo': None},
- ["foo", "bar"]),
+ ["foo", "bar"])
self.assertParseOK(["-a", "--", "--foo", "bar"],
{'a': "--", 'boo': None, 'foo': ["bar"]},
- []),
+ [])
def test_short_option_joined_and_separator(self):
self.assertParseOK(["-ab", "--", "--foo", "bar"],
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index 601c6b2e97..1ffa7dafeb 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -24,6 +24,9 @@ import itertools
import stat
import locale
import codecs
+import decimal
+import fractions
+import pickle
try:
import threading
except ImportError:
@@ -32,6 +35,10 @@ try:
import resource
except ImportError:
resource = None
+try:
+ import fcntl
+except ImportError:
+ fcntl = None
from test.script_helper import assert_python_ok
@@ -256,6 +263,13 @@ class StatAttributeTests(unittest.TestCase):
warnings.simplefilter("ignore", DeprecationWarning)
self.check_stat_attributes(fname)
+ def test_stat_result_pickle(self):
+ result = os.stat(self.fname)
+ p = pickle.dumps(result)
+ self.assertIn(b'\x03cos\nstat_result\n', p)
+ unpickled = pickle.loads(p)
+ self.assertEqual(result, unpickled)
+
@unittest.skipUnless(hasattr(os, 'statvfs'), 'test needs os.statvfs()')
def test_statvfs_attributes(self):
try:
@@ -263,7 +277,7 @@ class StatAttributeTests(unittest.TestCase):
except OSError as e:
# On AtheOS, glibc always returns ENOSYS
if e.errno == errno.ENOSYS:
- return
+ self.skipTest('os.statvfs() failed with ENOSYS')
# Make sure direct access works
self.assertEqual(result.f_bfree, result[3])
@@ -300,6 +314,21 @@ class StatAttributeTests(unittest.TestCase):
except TypeError:
pass
+ @unittest.skipUnless(hasattr(os, 'statvfs'),
+ "need os.statvfs()")
+ def test_statvfs_result_pickle(self):
+ try:
+ result = os.statvfs(self.fname)
+ except OSError as e:
+ # On AtheOS, glibc always returns ENOSYS
+ if e.errno == errno.ENOSYS:
+ self.skipTest('os.statvfs() failed with ENOSYS')
+
+ p = pickle.dumps(result)
+ self.assertIn(b'\x03cos\nstatvfs_result\n', p)
+ unpickled = pickle.loads(p)
+ self.assertEqual(result, unpickled)
+
def test_utime_dir(self):
delta = 1000000
st = os.stat(support.TESTFN)
@@ -478,9 +507,9 @@ class StatAttributeTests(unittest.TestCase):
# Verify that an open file can be stat'ed
try:
os.stat(r"c:\pagefile.sys")
- except WindowsError as e:
- if e.errno == 2: # file does not exist; cannot run test
- return
+ except FileNotFoundError:
+ pass # file does not exist; cannot run test
+ except OSError as e:
self.fail("Could not stat pagefile.sys")
@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests")
@@ -876,6 +905,17 @@ class MakedirTests(unittest.TestCase):
os.makedirs(path, mode=mode, exist_ok=True)
os.umask(old_mask)
+ @unittest.skipUnless(hasattr(os, 'chown'), 'test needs os.chown')
+ def test_chown_uid_gid_arguments_must_be_index(self):
+ stat = os.stat(support.TESTFN)
+ uid = stat.st_uid
+ gid = stat.st_gid
+ for value in (-1.0, -1j, decimal.Decimal(-1), fractions.Fraction(-2, 2)):
+ self.assertRaises(TypeError, os.chown, support.TESTFN, value, gid)
+ self.assertRaises(TypeError, os.chown, support.TESTFN, uid, value)
+ self.assertIsNone(os.chown(support.TESTFN, uid, gid))
+ self.assertIsNone(os.chown(support.TESTFN, -1, -1))
+
def test_exist_ok_s_isgid_directory(self):
path = os.path.join(support.TESTFN, 'dir1')
S_ISGID = stat.S_ISGID
@@ -1133,27 +1173,27 @@ class ExecTests(unittest.TestCase):
@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests")
class Win32ErrorTests(unittest.TestCase):
def test_rename(self):
- self.assertRaises(WindowsError, os.rename, support.TESTFN, support.TESTFN+".bak")
+ self.assertRaises(OSError, os.rename, support.TESTFN, support.TESTFN+".bak")
def test_remove(self):
- self.assertRaises(WindowsError, os.remove, support.TESTFN)
+ self.assertRaises(OSError, os.remove, support.TESTFN)
def test_chdir(self):
- self.assertRaises(WindowsError, os.chdir, support.TESTFN)
+ self.assertRaises(OSError, os.chdir, support.TESTFN)
def test_mkdir(self):
f = open(support.TESTFN, "w")
try:
- self.assertRaises(WindowsError, os.mkdir, support.TESTFN)
+ self.assertRaises(OSError, os.mkdir, support.TESTFN)
finally:
f.close()
os.unlink(support.TESTFN)
def test_utime(self):
- self.assertRaises(WindowsError, os.utime, support.TESTFN, None)
+ self.assertRaises(OSError, os.utime, support.TESTFN, None)
def test_chmod(self):
- self.assertRaises(WindowsError, os.chmod, support.TESTFN, 0)
+ self.assertRaises(OSError, os.chmod, support.TESTFN, 0)
class TestInvalidFD(unittest.TestCase):
singles = ["fchdir", "dup", "fdopen", "fdatasync", "fstat",
@@ -1278,41 +1318,57 @@ class PosixUidGidTests(unittest.TestCase):
@unittest.skipUnless(hasattr(os, 'setuid'), 'test needs os.setuid()')
def test_setuid(self):
if os.getuid() != 0:
- self.assertRaises(os.error, os.setuid, 0)
+ self.assertRaises(OSError, os.setuid, 0)
self.assertRaises(OverflowError, os.setuid, 1<<32)
@unittest.skipUnless(hasattr(os, 'setgid'), 'test needs os.setgid()')
def test_setgid(self):
if os.getuid() != 0 and not HAVE_WHEEL_GROUP:
- self.assertRaises(os.error, os.setgid, 0)
+ self.assertRaises(OSError, os.setgid, 0)
self.assertRaises(OverflowError, os.setgid, 1<<32)
@unittest.skipUnless(hasattr(os, 'seteuid'), 'test needs os.seteuid()')
def test_seteuid(self):
if os.getuid() != 0:
- self.assertRaises(os.error, os.seteuid, 0)
+ self.assertRaises(OSError, os.seteuid, 0)
self.assertRaises(OverflowError, os.seteuid, 1<<32)
@unittest.skipUnless(hasattr(os, 'setegid'), 'test needs os.setegid()')
def test_setegid(self):
if os.getuid() != 0 and not HAVE_WHEEL_GROUP:
- self.assertRaises(os.error, os.setegid, 0)
+ self.assertRaises(OSError, os.setegid, 0)
self.assertRaises(OverflowError, os.setegid, 1<<32)
@unittest.skipUnless(hasattr(os, 'setreuid'), 'test needs os.setreuid()')
def test_setreuid(self):
if os.getuid() != 0:
- self.assertRaises(os.error, os.setreuid, 0, 0)
+ self.assertRaises(OSError, os.setreuid, 0, 0)
self.assertRaises(OverflowError, os.setreuid, 1<<32, 0)
self.assertRaises(OverflowError, os.setreuid, 0, 1<<32)
+ @unittest.skipUnless(hasattr(os, 'setreuid'), 'test needs os.setreuid()')
+ def test_setreuid_neg1(self):
+ # Needs to accept -1. We run this in a subprocess to avoid
+ # altering the test runner's process state (issue8045).
+ subprocess.check_call([
+ sys.executable, '-c',
+ 'import os,sys;os.setreuid(-1,-1);sys.exit(0)'])
+
@unittest.skipUnless(hasattr(os, 'setregid'), 'test needs os.setregid()')
def test_setregid(self):
if os.getuid() != 0 and not HAVE_WHEEL_GROUP:
- self.assertRaises(os.error, os.setregid, 0, 0)
+ self.assertRaises(OSError, os.setregid, 0, 0)
self.assertRaises(OverflowError, os.setregid, 1<<32, 0)
self.assertRaises(OverflowError, os.setregid, 0, 1<<32)
+ @unittest.skipUnless(hasattr(os, 'setregid'), 'test needs os.setregid()')
+ def test_setregid_neg1(self):
+ # Needs to accept -1. We run this in a subprocess to avoid
+ # altering the test runner's process state (issue8045).
+ subprocess.check_call([
+ sys.executable, '-c',
+ 'import os,sys;os.setregid(-1,-1);sys.exit(0)'])
+
@unittest.skipIf(sys.platform == "win32", "Posix specific tests")
class Pep383Tests(unittest.TestCase):
def setUp(self):
@@ -1502,6 +1558,52 @@ class Win32KillTests(unittest.TestCase):
@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests")
+class Win32ListdirTests(unittest.TestCase):
+ """Test listdir on Windows."""
+
+ def setUp(self):
+ self.created_paths = []
+ for i in range(2):
+ dir_name = 'SUB%d' % i
+ dir_path = os.path.join(support.TESTFN, dir_name)
+ file_name = 'FILE%d' % i
+ file_path = os.path.join(support.TESTFN, file_name)
+ os.makedirs(dir_path)
+ with open(file_path, 'w') as f:
+ f.write("I'm %s and proud of it. Blame test_os.\n" % file_path)
+ self.created_paths.extend([dir_name, file_name])
+ self.created_paths.sort()
+
+ def tearDown(self):
+ shutil.rmtree(support.TESTFN)
+
+ def test_listdir_no_extended_path(self):
+ """Test when the path is not an "extended" path."""
+ # unicode
+ self.assertEqual(
+ sorted(os.listdir(support.TESTFN)),
+ self.created_paths)
+ # bytes
+ self.assertEqual(
+ sorted(os.listdir(os.fsencode(support.TESTFN))),
+ [os.fsencode(path) for path in self.created_paths])
+
+ def test_listdir_extended_path(self):
+ """Test when the path starts with '\\\\?\\'."""
+ # See: http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247(v=vs.85).aspx#maxpath
+ # unicode
+ path = '\\\\?\\' + os.path.abspath(support.TESTFN)
+ self.assertEqual(
+ sorted(os.listdir(path)),
+ self.created_paths)
+ # bytes
+ path = b'\\\\?\\' + os.fsencode(os.path.abspath(support.TESTFN))
+ self.assertEqual(
+ sorted(os.listdir(path)),
+ [os.fsencode(path) for path in self.created_paths])
+
+
+@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests")
@support.skip_unless_symlink
class Win32SymlinkTests(unittest.TestCase):
filelink = 'filelinktest'
@@ -2166,6 +2268,197 @@ class TermsizeTests(unittest.TestCase):
self.assertEqual(expected, actual)
+class OSErrorTests(unittest.TestCase):
+ def setUp(self):
+ class Str(str):
+ pass
+
+ self.bytes_filenames = []
+ self.unicode_filenames = []
+ if support.TESTFN_UNENCODABLE is not None:
+ decoded = support.TESTFN_UNENCODABLE
+ else:
+ decoded = support.TESTFN
+ self.unicode_filenames.append(decoded)
+ self.unicode_filenames.append(Str(decoded))
+ if support.TESTFN_UNDECODABLE is not None:
+ encoded = support.TESTFN_UNDECODABLE
+ else:
+ encoded = os.fsencode(support.TESTFN)
+ self.bytes_filenames.append(encoded)
+ self.bytes_filenames.append(memoryview(encoded))
+
+ self.filenames = self.bytes_filenames + self.unicode_filenames
+
+ def test_oserror_filename(self):
+ funcs = [
+ (self.filenames, os.chdir,),
+ (self.filenames, os.chmod, 0o777),
+ (self.filenames, os.lstat,),
+ (self.filenames, os.open, os.O_RDONLY),
+ (self.filenames, os.rmdir,),
+ (self.filenames, os.stat,),
+ (self.filenames, os.unlink,),
+ ]
+ if sys.platform == "win32":
+ funcs.extend((
+ (self.bytes_filenames, os.rename, b"dst"),
+ (self.bytes_filenames, os.replace, b"dst"),
+ (self.unicode_filenames, os.rename, "dst"),
+ (self.unicode_filenames, os.replace, "dst"),
+ # Issue #16414: Don't test undecodable names with listdir()
+ # because of a Windows bug.
+ #
+ # With the ANSI code page 932, os.listdir(b'\xe7') return an
+ # empty list (instead of failing), whereas os.listdir(b'\xff')
+ # raises a FileNotFoundError. It looks like a Windows bug:
+ # b'\xe7' directory does not exist, FindFirstFileA(b'\xe7')
+ # fails with ERROR_FILE_NOT_FOUND (2), instead of
+ # ERROR_PATH_NOT_FOUND (3).
+ (self.unicode_filenames, os.listdir,),
+ ))
+ else:
+ funcs.extend((
+ (self.filenames, os.listdir,),
+ (self.filenames, os.rename, "dst"),
+ (self.filenames, os.replace, "dst"),
+ ))
+ if hasattr(os, "chown"):
+ funcs.append((self.filenames, os.chown, 0, 0))
+ if hasattr(os, "lchown"):
+ funcs.append((self.filenames, os.lchown, 0, 0))
+ if hasattr(os, "truncate"):
+ funcs.append((self.filenames, os.truncate, 0))
+ if hasattr(os, "chflags"):
+ funcs.append((self.filenames, os.chflags, 0))
+ if hasattr(os, "lchflags"):
+ funcs.append((self.filenames, os.lchflags, 0))
+ if hasattr(os, "chroot"):
+ funcs.append((self.filenames, os.chroot,))
+ if hasattr(os, "link"):
+ if sys.platform == "win32":
+ funcs.append((self.bytes_filenames, os.link, b"dst"))
+ funcs.append((self.unicode_filenames, os.link, "dst"))
+ else:
+ funcs.append((self.filenames, os.link, "dst"))
+ if hasattr(os, "listxattr"):
+ funcs.extend((
+ (self.filenames, os.listxattr,),
+ (self.filenames, os.getxattr, "user.test"),
+ (self.filenames, os.setxattr, "user.test", b'user'),
+ (self.filenames, os.removexattr, "user.test"),
+ ))
+ if hasattr(os, "lchmod"):
+ funcs.append((self.filenames, os.lchmod, 0o777))
+ if hasattr(os, "readlink"):
+ if sys.platform == "win32":
+ funcs.append((self.unicode_filenames, os.readlink,))
+ else:
+ funcs.append((self.filenames, os.readlink,))
+
+ for filenames, func, *func_args in funcs:
+ for name in filenames:
+ try:
+ func(name, *func_args)
+ except OSError as err:
+ self.assertIs(err.filename, name)
+ else:
+ self.fail("No exception thrown by {}".format(func))
+
+class CPUCountTests(unittest.TestCase):
+ def test_cpu_count(self):
+ cpus = os.cpu_count()
+ if cpus is not None:
+ self.assertIsInstance(cpus, int)
+ self.assertGreater(cpus, 0)
+ else:
+ self.skipTest("Could not determine the number of CPUs")
+
+
+class FDInheritanceTests(unittest.TestCase):
+ def test_get_set_inheritable(self):
+ fd = os.open(__file__, os.O_RDONLY)
+ self.addCleanup(os.close, fd)
+ self.assertEqual(os.get_inheritable(fd), False)
+
+ os.set_inheritable(fd, True)
+ self.assertEqual(os.get_inheritable(fd), True)
+
+ @unittest.skipIf(fcntl is None, "need fcntl")
+ def test_get_inheritable_cloexec(self):
+ fd = os.open(__file__, os.O_RDONLY)
+ self.addCleanup(os.close, fd)
+ self.assertEqual(os.get_inheritable(fd), False)
+
+ # clear FD_CLOEXEC flag
+ flags = fcntl.fcntl(fd, fcntl.F_GETFD)
+ flags &= ~fcntl.FD_CLOEXEC
+ fcntl.fcntl(fd, fcntl.F_SETFD, flags)
+
+ self.assertEqual(os.get_inheritable(fd), True)
+
+ @unittest.skipIf(fcntl is None, "need fcntl")
+ def test_set_inheritable_cloexec(self):
+ fd = os.open(__file__, os.O_RDONLY)
+ self.addCleanup(os.close, fd)
+ self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC,
+ fcntl.FD_CLOEXEC)
+
+ os.set_inheritable(fd, True)
+ self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC,
+ 0)
+
+ def test_open(self):
+ fd = os.open(__file__, os.O_RDONLY)
+ self.addCleanup(os.close, fd)
+ self.assertEqual(os.get_inheritable(fd), False)
+
+ @unittest.skipUnless(hasattr(os, 'pipe'), "need os.pipe()")
+ def test_pipe(self):
+ rfd, wfd = os.pipe()
+ self.addCleanup(os.close, rfd)
+ self.addCleanup(os.close, wfd)
+ self.assertEqual(os.get_inheritable(rfd), False)
+ self.assertEqual(os.get_inheritable(wfd), False)
+
+ def test_dup(self):
+ fd1 = os.open(__file__, os.O_RDONLY)
+ self.addCleanup(os.close, fd1)
+
+ fd2 = os.dup(fd1)
+ self.addCleanup(os.close, fd2)
+ self.assertEqual(os.get_inheritable(fd2), False)
+
+ @unittest.skipUnless(hasattr(os, 'dup2'), "need os.dup2()")
+ def test_dup2(self):
+ fd = os.open(__file__, os.O_RDONLY)
+ self.addCleanup(os.close, fd)
+
+ # inheritable by default
+ fd2 = os.open(__file__, os.O_RDONLY)
+ try:
+ os.dup2(fd, fd2)
+ self.assertEqual(os.get_inheritable(fd2), True)
+ finally:
+ os.close(fd2)
+
+ # force non-inheritable
+ fd3 = os.open(__file__, os.O_RDONLY)
+ try:
+ os.dup2(fd, fd3, inheritable=False)
+ self.assertEqual(os.get_inheritable(fd3), False)
+ finally:
+ os.close(fd3)
+
+ @unittest.skipUnless(hasattr(os, 'openpty'), "need os.openpty()")
+ def test_openpty(self):
+ master_fd, slave_fd = os.openpty()
+ self.addCleanup(os.close, master_fd)
+ self.addCleanup(os.close, slave_fd)
+ self.assertEqual(os.get_inheritable(master_fd), False)
+ self.assertEqual(os.get_inheritable(slave_fd), False)
+
+
@support.reap_threads
def test_main():
support.run_unittest(
@@ -2183,6 +2476,7 @@ def test_main():
PosixUidGidTests,
Pep383Tests,
Win32KillTests,
+ Win32ListdirTests,
Win32SymlinkTests,
NonLocalSymlinkTests,
FSEncodingTests,
@@ -2195,7 +2489,10 @@ def test_main():
ExtendedAttributeTests,
Win32DeprecatedBytesAPI,
TermsizeTests,
+ OSErrorTests,
RemoveDirsTests,
+ CPUCountTests,
+ FDInheritanceTests,
)
if __name__ == "__main__":
diff --git a/Lib/test/test_ossaudiodev.py b/Lib/test/test_ossaudiodev.py
index 3908a0506e..c9e2a24767 100644
--- a/Lib/test/test_ossaudiodev.py
+++ b/Lib/test/test_ossaudiodev.py
@@ -44,7 +44,7 @@ class OSSAudioDevTests(unittest.TestCase):
def play_sound_file(self, data, rate, ssize, nchannels):
try:
dsp = ossaudiodev.open('w')
- except IOError as msg:
+ except OSError as msg:
if msg.args[0] in (errno.EACCES, errno.ENOENT,
errno.ENODEV, errno.EBUSY):
raise unittest.SkipTest(msg)
@@ -190,7 +190,7 @@ class OSSAudioDevTests(unittest.TestCase):
def test_main():
try:
dsp = ossaudiodev.open('w')
- except (ossaudiodev.error, IOError) as msg:
+ except (ossaudiodev.error, OSError) as msg:
if msg.args[0] in (errno.EACCES, errno.ENOENT,
errno.ENODEV, errno.EBUSY):
raise unittest.SkipTest(msg)
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
index 03084e4d12..7993d02ba1 100644
--- a/Lib/test/test_pdb.py
+++ b/Lib/test/test_pdb.py
@@ -1,9 +1,9 @@
# A test suite for pdb; not very comprehensive at the moment.
import doctest
-import imp
import pdb
import sys
+import types
import unittest
import subprocess
import textwrap
@@ -205,7 +205,8 @@ def test_pdb_breakpoint_commands():
... 'enable 1',
... 'clear 1',
... 'commands 2',
- ... 'print 42',
+ ... 'p "42"',
+ ... 'print("42", 7*6)', # Issue 18764 (not about breakpoints)
... 'end',
... 'continue', # will stop at breakpoint 2 (line 4)
... 'clear', # clear all!
@@ -252,11 +253,13 @@ def test_pdb_breakpoint_commands():
(Pdb) clear 1
Deleted breakpoint 1 at <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>:3
(Pdb) commands 2
- (com) print 42
+ (com) p "42"
+ (com) print("42", 7*6)
(com) end
(Pdb) continue
1
- 42
+ '42'
+ 42 42
> <doctest test.test_pdb.test_pdb_breakpoint_commands[0]>(4)test_function()
-> print(2)
(Pdb) clear
@@ -464,7 +467,7 @@ def test_pdb_skip_modules():
# Module for testing skipping of module that makes a callback
-mod = imp.new_module('module_to_skip')
+mod = types.ModuleType('module_to_skip')
exec('def foo_pony(callback): x = 1; callback(); return None', mod.__dict__)
@@ -617,6 +620,36 @@ class PdbTestCase(unittest.TestCase):
stderr = stderr and bytes.decode(stderr)
return stdout, stderr
+ def _assert_find_function(self, file_content, func_name, expected):
+ file_content = textwrap.dedent(file_content)
+
+ with open(support.TESTFN, 'w') as f:
+ f.write(file_content)
+
+ expected = None if not expected else (
+ expected[0], support.TESTFN, expected[1])
+ self.assertEqual(
+ expected, pdb.find_function(func_name, support.TESTFN))
+
+ def test_find_function_empty_file(self):
+ self._assert_find_function('', 'foo', None)
+
+ def test_find_function_found(self):
+ self._assert_find_function(
+ """\
+ def foo():
+ pass
+
+ def bar():
+ pass
+
+ def quux():
+ pass
+ """,
+ 'bar',
+ ('bar', 4),
+ )
+
def test_issue7964(self):
# open the file as binary so we can force \r\n newline
with open(support.TESTFN, 'wb') as f:
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index 1cacdea569..50257923fe 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -5,44 +5,28 @@ from io import StringIO
import unittest
from math import copysign
-def disassemble(func):
- f = StringIO()
- tmp = sys.stdout
- sys.stdout = f
- try:
- dis.dis(func)
- finally:
- sys.stdout = tmp
- result = f.getvalue()
- f.close()
- return result
+from test.bytecode_helper import BytecodeTestCase
-def dis_single(line):
- return disassemble(compile(line, '', 'single'))
-
-
-class TestTranforms(unittest.TestCase):
+class TestTranforms(BytecodeTestCase):
def test_unot(self):
# UNARY_NOT POP_JUMP_IF_FALSE --> POP_JUMP_IF_TRUE'
def unot(x):
if not x == 2:
del x
- asm = disassemble(unot)
- for elem in ('UNARY_NOT', 'POP_JUMP_IF_FALSE'):
- self.assertNotIn(elem, asm)
- for elem in ('POP_JUMP_IF_TRUE',):
- self.assertIn(elem, asm)
+ self.assertNotInBytecode(unot, 'UNARY_NOT')
+ self.assertNotInBytecode(unot, 'POP_JUMP_IF_FALSE')
+ self.assertInBytecode(unot, 'POP_JUMP_IF_TRUE')
def test_elim_inversion_of_is_or_in(self):
- for line, elem in (
- ('not a is b', '(is not)',),
- ('not a in b', '(not in)',),
- ('not a is not b', '(is)',),
- ('not a not in b', '(in)',),
+ for line, cmp_op in (
+ ('not a is b', 'is not',),
+ ('not a in b', 'not in',),
+ ('not a is not b', 'is',),
+ ('not a not in b', 'in',),
):
- asm = dis_single(line)
- self.assertIn(elem, asm)
+ code = compile(line, '', 'single')
+ self.assertInBytecode(code, 'COMPARE_OP', cmp_op)
def test_global_as_constant(self):
# LOAD_GLOBAL None/True/False --> LOAD_CONST None/True/False
@@ -56,17 +40,14 @@ class TestTranforms(unittest.TestCase):
def h(x):
False
return x
- for func, name in ((f, 'None'), (g, 'True'), (h, 'False')):
- asm = disassemble(func)
- for elem in ('LOAD_GLOBAL',):
- self.assertNotIn(elem, asm)
- for elem in ('LOAD_CONST', '('+name+')'):
- self.assertIn(elem, asm)
+ for func, elem in ((f, None), (g, True), (h, False)):
+ self.assertNotInBytecode(func, 'LOAD_GLOBAL')
+ self.assertInBytecode(func, 'LOAD_CONST', elem)
def f():
'Adding a docstring made this test fail in Py2.5.0'
return None
- self.assertIn('LOAD_CONST', disassemble(f))
- self.assertNotIn('LOAD_GLOBAL', disassemble(f))
+ self.assertNotInBytecode(f, 'LOAD_GLOBAL')
+ self.assertInBytecode(f, 'LOAD_CONST', None)
def test_while_one(self):
# Skip over: LOAD_CONST trueconst POP_JUMP_IF_FALSE xx
@@ -74,11 +55,10 @@ class TestTranforms(unittest.TestCase):
while 1:
pass
return list
- asm = disassemble(f)
for elem in ('LOAD_CONST', 'POP_JUMP_IF_FALSE'):
- self.assertNotIn(elem, asm)
+ self.assertNotInBytecode(f, elem)
for elem in ('JUMP_ABSOLUTE',):
- self.assertIn(elem, asm)
+ self.assertInBytecode(f, elem)
def test_pack_unpack(self):
for line, elem in (
@@ -86,28 +66,30 @@ class TestTranforms(unittest.TestCase):
('a, b = a, b', 'ROT_TWO',),
('a, b, c = a, b, c', 'ROT_THREE',),
):
- asm = dis_single(line)
- self.assertIn(elem, asm)
- self.assertNotIn('BUILD_TUPLE', asm)
- self.assertNotIn('UNPACK_TUPLE', asm)
+ code = compile(line,'','single')
+ self.assertInBytecode(code, elem)
+ self.assertNotInBytecode(code, 'BUILD_TUPLE')
+ self.assertNotInBytecode(code, 'UNPACK_TUPLE')
def test_folding_of_tuples_of_constants(self):
for line, elem in (
- ('a = 1,2,3', '((1, 2, 3))'),
- ('("a","b","c")', "(('a', 'b', 'c'))"),
- ('a,b,c = 1,2,3', '((1, 2, 3))'),
- ('(None, 1, None)', '((None, 1, None))'),
- ('((1, 2), 3, 4)', '(((1, 2), 3, 4))'),
+ ('a = 1,2,3', (1, 2, 3)),
+ ('("a","b","c")', ('a', 'b', 'c')),
+ ('a,b,c = 1,2,3', (1, 2, 3)),
+ ('(None, 1, None)', (None, 1, None)),
+ ('((1, 2), 3, 4)', ((1, 2), 3, 4)),
):
- asm = dis_single(line)
- self.assertIn(elem, asm)
- self.assertNotIn('BUILD_TUPLE', asm)
+ code = compile(line,'','single')
+ self.assertInBytecode(code, 'LOAD_CONST', elem)
+ self.assertNotInBytecode(code, 'BUILD_TUPLE')
# Long tuples should be folded too.
- asm = dis_single(repr(tuple(range(10000))))
+ code = compile(repr(tuple(range(10000))),'','single')
+ self.assertNotInBytecode(code, 'BUILD_TUPLE')
# One LOAD_CONST for the tuple, one for the None return value
- self.assertEqual(asm.count('LOAD_CONST'), 2)
- self.assertNotIn('BUILD_TUPLE', asm)
+ load_consts = [instr for instr in dis.get_instructions(code)
+ if instr.opname == 'LOAD_CONST']
+ self.assertEqual(len(load_consts), 2)
# Bug 1053819: Tuple of constants misidentified when presented with:
# . . . opcode_with_arg 100 unary_opcode BUILD_TUPLE 1 . . .
@@ -129,14 +111,14 @@ class TestTranforms(unittest.TestCase):
def test_folding_of_lists_of_constants(self):
for line, elem in (
# in/not in constants with BUILD_LIST should be folded to a tuple:
- ('a in [1,2,3]', '(1, 2, 3)'),
- ('a not in ["a","b","c"]', "(('a', 'b', 'c'))"),
- ('a in [None, 1, None]', '((None, 1, None))'),
- ('a not in [(1, 2), 3, 4]', '(((1, 2), 3, 4))'),
+ ('a in [1,2,3]', (1, 2, 3)),
+ ('a not in ["a","b","c"]', ('a', 'b', 'c')),
+ ('a in [None, 1, None]', (None, 1, None)),
+ ('a not in [(1, 2), 3, 4]', ((1, 2), 3, 4)),
):
- asm = dis_single(line)
- self.assertIn(elem, asm)
- self.assertNotIn('BUILD_LIST', asm)
+ code = compile(line, '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', elem)
+ self.assertNotInBytecode(code, 'BUILD_LIST')
def test_folding_of_sets_of_constants(self):
for line, elem in (
@@ -147,18 +129,9 @@ class TestTranforms(unittest.TestCase):
('a not in {(1, 2), 3, 4}', frozenset({(1, 2), 3, 4})),
('a in {1, 2, 3, 3, 2, 1}', frozenset({1, 2, 3})),
):
- asm = dis_single(line)
- self.assertNotIn('BUILD_SET', asm)
-
- # Verify that the frozenset 'elem' is in the disassembly
- # The ordering of the elements in repr( frozenset ) isn't
- # guaranteed, so we jump through some hoops to ensure that we have
- # the frozenset we expect:
- self.assertIn('frozenset', asm)
- # Extract the frozenset literal from the disassembly:
- m = re.match(r'.*(frozenset\({.*}\)).*', asm, re.DOTALL)
- self.assertTrue(m)
- self.assertEqual(eval(m.group(1)), elem)
+ code = compile(line, '', 'single')
+ self.assertNotInBytecode(code, 'BUILD_SET')
+ self.assertInBytecode(code, 'LOAD_CONST', elem)
# Ensure that the resulting code actually works:
def f(a):
@@ -176,98 +149,103 @@ class TestTranforms(unittest.TestCase):
def test_folding_of_binops_on_constants(self):
for line, elem in (
- ('a = 2+3+4', '(9)'), # chained fold
- ('"@"*4', "('@@@@')"), # check string ops
- ('a="abc" + "def"', "('abcdef')"), # check string ops
- ('a = 3**4', '(81)'), # binary power
- ('a = 3*4', '(12)'), # binary multiply
- ('a = 13//4', '(3)'), # binary floor divide
- ('a = 14%4', '(2)'), # binary modulo
- ('a = 2+3', '(5)'), # binary add
- ('a = 13-4', '(9)'), # binary subtract
- ('a = (12,13)[1]', '(13)'), # binary subscr
- ('a = 13 << 2', '(52)'), # binary lshift
- ('a = 13 >> 2', '(3)'), # binary rshift
- ('a = 13 & 7', '(5)'), # binary and
- ('a = 13 ^ 7', '(10)'), # binary xor
- ('a = 13 | 7', '(15)'), # binary or
+ ('a = 2+3+4', 9), # chained fold
+ ('"@"*4', '@@@@'), # check string ops
+ ('a="abc" + "def"', 'abcdef'), # check string ops
+ ('a = 3**4', 81), # binary power
+ ('a = 3*4', 12), # binary multiply
+ ('a = 13//4', 3), # binary floor divide
+ ('a = 14%4', 2), # binary modulo
+ ('a = 2+3', 5), # binary add
+ ('a = 13-4', 9), # binary subtract
+ ('a = (12,13)[1]', 13), # binary subscr
+ ('a = 13 << 2', 52), # binary lshift
+ ('a = 13 >> 2', 3), # binary rshift
+ ('a = 13 & 7', 5), # binary and
+ ('a = 13 ^ 7', 10), # binary xor
+ ('a = 13 | 7', 15), # binary or
):
- asm = dis_single(line)
- self.assertIn(elem, asm, asm)
- self.assertNotIn('BINARY_', asm)
+ code = compile(line, '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', elem)
+ for instr in dis.get_instructions(code):
+ self.assertFalse(instr.opname.startswith('BINARY_'))
# Verify that unfoldables are skipped
- asm = dis_single('a=2+"b"')
- self.assertIn('(2)', asm)
- self.assertIn("('b')", asm)
+ code = compile('a=2+"b"', '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', 2)
+ self.assertInBytecode(code, 'LOAD_CONST', 'b')
# Verify that large sequences do not result from folding
- asm = dis_single('a="x"*1000')
- self.assertIn('(1000)', asm)
+ code = compile('a="x"*1000', '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', 1000)
def test_binary_subscr_on_unicode(self):
# valid code get optimized
- asm = dis_single('"foo"[0]')
- self.assertIn("('f')", asm)
- self.assertNotIn('BINARY_SUBSCR', asm)
- asm = dis_single('"\u0061\uffff"[1]')
- self.assertIn("('\\uffff')", asm)
- self.assertNotIn('BINARY_SUBSCR', asm)
- asm = dis_single('"\U00012345abcdef"[3]')
- self.assertIn("('c')", asm)
- self.assertNotIn('BINARY_SUBSCR', asm)
+ code = compile('"foo"[0]', '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', 'f')
+ self.assertNotInBytecode(code, 'BINARY_SUBSCR')
+ code = compile('"\u0061\uffff"[1]', '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', '\uffff')
+ self.assertNotInBytecode(code,'BINARY_SUBSCR')
+
+ # With PEP 393, non-BMP char get optimized
+ code = compile('"\U00012345"[0]', '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', '\U00012345')
+ self.assertNotInBytecode(code, 'BINARY_SUBSCR')
# invalid code doesn't get optimized
# out of range
- asm = dis_single('"fuu"[10]')
- self.assertIn('BINARY_SUBSCR', asm)
+ code = compile('"fuu"[10]', '', 'single')
+ self.assertInBytecode(code, 'BINARY_SUBSCR')
def test_folding_of_unaryops_on_constants(self):
for line, elem in (
- ('-0.5', '(-0.5)'), # unary negative
- ('-0.0', '(-0.0)'), # -0.0
- ('-(1.0-1.0)','(-0.0)'), # -0.0 after folding
- ('-0', '(0)'), # -0
- ('~-2', '(1)'), # unary invert
- ('+1', '(1)'), # unary positive
+ ('-0.5', -0.5), # unary negative
+ ('-0.0', -0.0), # -0.0
+ ('-(1.0-1.0)', -0.0), # -0.0 after folding
+ ('-0', 0), # -0
+ ('~-2', 1), # unary invert
+ ('+1', 1), # unary positive
):
- asm = dis_single(line)
- self.assertIn(elem, asm, asm)
- self.assertNotIn('UNARY_', asm)
+ code = compile(line, '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', elem)
+ for instr in dis.get_instructions(code):
+ self.assertFalse(instr.opname.startswith('UNARY_'))
# Check that -0.0 works after marshaling
def negzero():
return -(1.0-1.0)
- self.assertNotIn('UNARY_', disassemble(negzero))
- self.assertTrue(copysign(1.0, negzero()) < 0)
+ for instr in dis.get_instructions(code):
+ self.assertFalse(instr.opname.startswith('UNARY_'))
# Verify that unfoldables are skipped
- for line, elem in (
- ('-"abc"', "('abc')"), # unary negative
- ('~"abc"', "('abc')"), # unary invert
+ for line, elem, opname in (
+ ('-"abc"', 'abc', 'UNARY_NEGATIVE'),
+ ('~"abc"', 'abc', 'UNARY_INVERT'),
):
- asm = dis_single(line)
- self.assertIn(elem, asm, asm)
- self.assertIn('UNARY_', asm)
+ code = compile(line, '', 'single')
+ self.assertInBytecode(code, 'LOAD_CONST', elem)
+ self.assertInBytecode(code, opname)
def test_elim_extra_return(self):
# RETURN LOAD_CONST None RETURN --> RETURN
def f(x):
return x
- asm = disassemble(f)
- self.assertNotIn('LOAD_CONST', asm)
- self.assertNotIn('(None)', asm)
- self.assertEqual(asm.split().count('RETURN_VALUE'), 1)
+ self.assertNotInBytecode(f, 'LOAD_CONST', None)
+ returns = [instr for instr in dis.get_instructions(f)
+ if instr.opname == 'RETURN_VALUE']
+ self.assertEqual(len(returns), 1)
def test_elim_jump_to_return(self):
# JUMP_FORWARD to RETURN --> RETURN
def f(cond, true_value, false_value):
return true_value if cond else false_value
- asm = disassemble(f)
- self.assertNotIn('JUMP_FORWARD', asm)
- self.assertNotIn('JUMP_ABSOLUTE', asm)
- self.assertEqual(asm.split().count('RETURN_VALUE'), 2)
+ self.assertNotInBytecode(f, 'JUMP_FORWARD')
+ self.assertNotInBytecode(f, 'JUMP_ABSOLUTE')
+ returns = [instr for instr in dis.get_instructions(f)
+ if instr.opname == 'RETURN_VALUE']
+ self.assertEqual(len(returns), 2)
def test_elim_jump_after_return1(self):
# Eliminate dead code: jumps immediately after returns can't be reached
@@ -280,48 +258,53 @@ class TestTranforms(unittest.TestCase):
if cond1: return 4
return 5
return 6
- asm = disassemble(f)
- self.assertNotIn('JUMP_FORWARD', asm)
- self.assertNotIn('JUMP_ABSOLUTE', asm)
- self.assertEqual(asm.split().count('RETURN_VALUE'), 6)
+ self.assertNotInBytecode(f, 'JUMP_FORWARD')
+ self.assertNotInBytecode(f, 'JUMP_ABSOLUTE')
+ returns = [instr for instr in dis.get_instructions(f)
+ if instr.opname == 'RETURN_VALUE']
+ self.assertEqual(len(returns), 6)
def test_elim_jump_after_return2(self):
# Eliminate dead code: jumps immediately after returns can't be reached
def f(cond1, cond2):
while 1:
if cond1: return 4
- asm = disassemble(f)
- self.assertNotIn('JUMP_FORWARD', asm)
+ self.assertNotInBytecode(f, 'JUMP_FORWARD')
# There should be one jump for the while loop.
- self.assertEqual(asm.split().count('JUMP_ABSOLUTE'), 1)
- self.assertEqual(asm.split().count('RETURN_VALUE'), 2)
+ returns = [instr for instr in dis.get_instructions(f)
+ if instr.opname == 'JUMP_ABSOLUTE']
+ self.assertEqual(len(returns), 1)
+ returns = [instr for instr in dis.get_instructions(f)
+ if instr.opname == 'RETURN_VALUE']
+ self.assertEqual(len(returns), 2)
def test_make_function_doesnt_bail(self):
def f():
def g()->1+1:
pass
return g
- asm = disassemble(f)
- self.assertNotIn('BINARY_ADD', asm)
+ self.assertNotInBytecode(f, 'BINARY_ADD')
def test_constant_folding(self):
# Issue #11244: aggressive constant folding.
exprs = [
- "3 * -5",
- "-3 * 5",
- "2 * (3 * 4)",
- "(2 * 3) * 4",
- "(-1, 2, 3)",
- "(1, -2, 3)",
- "(1, 2, -3)",
- "(1, 2, -3) * 6",
- "lambda x: x in {(3 * -5) + (-1 - 6), (1, -2, 3) * 2, None}",
+ '3 * -5',
+ '-3 * 5',
+ '2 * (3 * 4)',
+ '(2 * 3) * 4',
+ '(-1, 2, 3)',
+ '(1, -2, 3)',
+ '(1, 2, -3)',
+ '(1, 2, -3) * 6',
+ 'lambda x: x in {(3 * -5) + (-1 - 6), (1, -2, 3) * 2, None}',
]
for e in exprs:
- asm = dis_single(e)
- self.assertNotIn('UNARY_', asm, e)
- self.assertNotIn('BINARY_', asm, e)
- self.assertNotIn('BUILD_', asm, e)
+ code = compile(e, '', 'single')
+ for instr in dis.get_instructions(code):
+ self.assertFalse(instr.opname.startswith('UNARY_'))
+ self.assertFalse(instr.opname.startswith('BINARY_'))
+ self.assertFalse(instr.opname.startswith('BUILD_'))
+
class TestBuglets(unittest.TestCase):
@@ -343,7 +326,7 @@ def test_main(verbose=None):
support.run_unittest(*test_classes)
# verify reference counting
- if verbose and hasattr(sys, "gettotalrefcount"):
+ if verbose and hasattr(sys, 'gettotalrefcount'):
import gc
counts = [None] * 5
for i in range(len(counts)):
diff --git a/Lib/test/test_pep277.py b/Lib/test/test_pep277.py
index 4b16cbb383..9bae6dcad7 100644
--- a/Lib/test/test_pep277.py
+++ b/Lib/test/test_pep277.py
@@ -99,10 +99,6 @@ class UnicodeFileTests(unittest.TestCase):
with self.assertRaises(expected_exception) as c:
fn(filename)
exc_filename = c.exception.filename
- # listdir may append a wildcard to the filename
- if fn is os.listdir and sys.platform == 'win32':
- exc_filename, _, wildcard = exc_filename.rpartition(os.sep)
- self.assertEqual(wildcard, '*.*')
if check_filename:
self.assertEqual(exc_filename, filename, "Function '%s(%a) failed "
"with bad filename in the exception: %a" %
diff --git a/Lib/test/test_pep352.py b/Lib/test/test_pep352.py
index 558cdb56d2..7c98c460b9 100644
--- a/Lib/test/test_pep352.py
+++ b/Lib/test/test_pep352.py
@@ -1,7 +1,6 @@
import unittest
import builtins
import warnings
-from test.support import run_unittest
import os
from platform import system as platform_system
@@ -180,8 +179,6 @@ class UsageTests(unittest.TestCase):
# Catching a string is bad.
self.catch_fails("spam")
-def test_main():
- run_unittest(ExceptionClassTests, UsageTests)
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_pkgimport.py b/Lib/test/test_pkgimport.py
index a8426b5c9a..370b2aae28 100644
--- a/Lib/test/test_pkgimport.py
+++ b/Lib/test/test_pkgimport.py
@@ -6,7 +6,7 @@ import random
import tempfile
import unittest
-from imp import cache_from_source
+from importlib.util import cache_from_source
from test.support import run_unittest, create_empty_file
class TestImport(unittest.TestCase):
diff --git a/Lib/test/test_pkgutil.py b/Lib/test/test_pkgutil.py
index fd0661450a..1f488534d7 100644
--- a/Lib/test/test_pkgutil.py
+++ b/Lib/test/test_pkgutil.py
@@ -1,12 +1,12 @@
from test.support import run_unittest, unload, check_warnings
import unittest
import sys
-import imp
import importlib
import pkgutil
import os
import os.path
import tempfile
+import types
import shutil
import zipfile
@@ -105,7 +105,7 @@ class PkgutilPEP302Tests(unittest.TestCase):
class MyTestLoader(object):
def load_module(self, fullname):
# Create an empty module
- mod = sys.modules.setdefault(fullname, imp.new_module(fullname))
+ mod = sys.modules.setdefault(fullname, types.ModuleType(fullname))
mod.__file__ = "<%s>" % self.__class__.__name__
mod.__loader__ = self
# Make it a package
diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py
index f98a280e9a..a1e5c3dbf8 100644
--- a/Lib/test/test_poll.py
+++ b/Lib/test/test_poll.py
@@ -1,6 +1,7 @@
# Test case for the os.poll() function
import os
+import subprocess
import random
import select
import _testcapi
@@ -15,7 +16,7 @@ from test.support import TESTFN, run_unittest, reap_threads
try:
select.poll
except AttributeError:
- raise unittest.SkipTest("select.poll not defined -- skipping test_poll")
+ raise unittest.SkipTest("select.poll not defined")
def find_ready_matching(ready, flag):
@@ -76,13 +77,11 @@ class PollTests(unittest.TestCase):
self.assertEqual(bufs, [MSG] * NUM_PIPES)
- def poll_unit_tests(self):
+ def test_poll_unit_tests(self):
# returns NVAL for invalid file descriptor
- FD = 42
- try:
- os.close(FD)
- except OSError:
- pass
+ FD, w = os.pipe()
+ os.close(FD)
+ os.close(w)
p = select.poll()
p.register(FD)
r = p.poll()
@@ -125,7 +124,9 @@ class PollTests(unittest.TestCase):
def test_poll2(self):
cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
- p = os.popen(cmd, 'r')
+ proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
+ bufsize=0)
+ p = proc.stdout
pollster = select.poll()
pollster.register( p, select.POLLIN )
for tout in (0, 1000, 2000, 4000, 8000, 16000) + (-1,)*10:
@@ -135,7 +136,7 @@ class PollTests(unittest.TestCase):
fd, flags = fdlist[0]
if flags & select.POLLHUP:
line = p.readline()
- if line != "":
+ if line != b"":
self.fail('error: pipe seems to be closed, but still returns data')
continue
@@ -143,6 +144,7 @@ class PollTests(unittest.TestCase):
line = p.readline()
if not line:
break
+ self.assertEqual(line, b'testing...\n')
continue
else:
self.fail('Unexpected return value from select.poll: %s' % fdlist)
diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py
index cbce8524eb..70fe4265c5 100644
--- a/Lib/test/test_poplib.py
+++ b/Lib/test/test_poplib.py
@@ -18,6 +18,14 @@ threading = test_support.import_module('threading')
HOST = test_support.HOST
PORT = 0
+SUPPORTS_SSL = False
+if hasattr(poplib, 'POP3_SSL'):
+ import ssl
+
+ SUPPORTS_SSL = True
+ CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "keycert.pem")
+requires_ssl = skipUnless(SUPPORTS_SSL, 'SSL not supported')
+
# the dummy data returned by server when LIST and RETR commands are issued
LIST_RESP = b'1 1\r\n2 2\r\n3 3\r\n4 4\r\n5 5\r\n.\r\n'
RETR_RESP = b"""From: postmaster@python.org\
@@ -33,11 +41,15 @@ line3\r\n\
class DummyPOP3Handler(asynchat.async_chat):
+ CAPAS = {'UIDL': [], 'IMPLEMENTATION': ['python-testlib-pop-server']}
+
def __init__(self, conn):
asynchat.async_chat.__init__(self, conn)
self.set_terminator(b"\r\n")
self.in_buffer = []
self.push('+OK dummy pop3 server ready. <timestamp>')
+ self.tls_active = False
+ self.tls_starting = False
def collect_incoming_data(self, data):
self.in_buffer.append(data)
@@ -112,6 +124,65 @@ class DummyPOP3Handler(asynchat.async_chat):
self.push('+OK closing.')
self.close_when_done()
+ def _get_capas(self):
+ _capas = dict(self.CAPAS)
+ if not self.tls_active and SUPPORTS_SSL:
+ _capas['STLS'] = []
+ return _capas
+
+ def cmd_capa(self, arg):
+ self.push('+OK Capability list follows')
+ if self._get_capas():
+ for cap, params in self._get_capas().items():
+ _ln = [cap]
+ if params:
+ _ln.extend(params)
+ self.push(' '.join(_ln))
+ self.push('.')
+
+ if SUPPORTS_SSL:
+
+ def cmd_stls(self, arg):
+ if self.tls_active is False:
+ self.push('+OK Begin TLS negotiation')
+ tls_sock = ssl.wrap_socket(self.socket, certfile=CERTFILE,
+ server_side=True,
+ do_handshake_on_connect=False,
+ suppress_ragged_eofs=False)
+ self.del_channel()
+ self.set_socket(tls_sock)
+ self.tls_active = True
+ self.tls_starting = True
+ self.in_buffer = []
+ self._do_tls_handshake()
+ else:
+ self.push('-ERR Command not permitted when TLS active')
+
+ def _do_tls_handshake(self):
+ try:
+ self.socket.do_handshake()
+ except ssl.SSLError as err:
+ if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE):
+ return
+ elif err.args[0] == ssl.SSL_ERROR_EOF:
+ return self.handle_close()
+ raise
+ except OSError as err:
+ if err.args[0] == errno.ECONNABORTED:
+ return self.handle_close()
+ else:
+ self.tls_active = True
+ self.tls_starting = False
+
+ def handle_read(self):
+ if self.tls_starting:
+ self._do_tls_handshake()
+ else:
+ try:
+ asynchat.async_chat.handle_read(self)
+ except ssl.SSLEOFError:
+ self.handle_close()
class DummyPOP3Server(asyncore.dispatcher, threading.Thread):
@@ -236,19 +307,36 @@ class TestPOP3Class(TestCase):
self.client.uidl()
self.client.uidl('foo')
+ def test_capa(self):
+ capa = self.client.capa()
+ self.assertTrue('IMPLEMENTATION' in capa.keys())
+
def test_quit(self):
resp = self.client.quit()
self.assertTrue(resp)
self.assertIsNone(self.client.sock)
self.assertIsNone(self.client.file)
+ @requires_ssl
+ def test_stls_capa(self):
+ capa = self.client.capa()
+ self.assertTrue('STLS' in capa.keys())
-SUPPORTS_SSL = False
-if hasattr(poplib, 'POP3_SSL'):
- import ssl
+ @requires_ssl
+ def test_stls(self):
+ expected = b'+OK Begin TLS negotiation'
+ resp = self.client.stls()
+ self.assertEqual(resp, expected)
- SUPPORTS_SSL = True
- CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "keycert.pem")
+ @requires_ssl
+ def test_stls_context(self):
+ expected = b'+OK Begin TLS negotiation'
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ resp = self.client.stls(context=ctx)
+ self.assertEqual(resp, expected)
+
+
+if SUPPORTS_SSL:
class DummyPOP3_SSLHandler(DummyPOP3Handler):
@@ -260,35 +348,13 @@ if hasattr(poplib, 'POP3_SSL'):
self.del_channel()
self.set_socket(ssl_socket)
# Must try handshake before calling push()
- self._ssl_accepting = True
- self._do_ssl_handshake()
+ self.tls_active = True
+ self.tls_starting = True
+ self._do_tls_handshake()
self.set_terminator(b"\r\n")
self.in_buffer = []
self.push('+OK dummy pop3 server ready. <timestamp>')
- def _do_ssl_handshake(self):
- try:
- self.socket.do_handshake()
- except ssl.SSLError as err:
- if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE):
- return
- elif err.args[0] == ssl.SSL_ERROR_EOF:
- return self.handle_close()
- raise
- except socket.error as err:
- if err.args[0] == errno.ECONNABORTED:
- return self.handle_close()
- else:
- self._ssl_accepting = False
-
- def handle_read(self):
- if self._ssl_accepting:
- self._do_ssl_handshake()
- else:
- DummyPOP3Handler.handle_read(self)
-
-requires_ssl = skipUnless(SUPPORTS_SSL, 'SSL not supported')
@requires_ssl
class TestPOP3_SSLClass(TestPOP3Class):
@@ -306,20 +372,60 @@ class TestPOP3_SSLClass(TestPOP3Class):
def test_context(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host,
- self.server.port, keyfile=CERTFILE, context=ctx)
+ self.server.port, keyfile=CERTFILE, context=ctx)
self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host,
- self.server.port, certfile=CERTFILE, context=ctx)
+ self.server.port, certfile=CERTFILE, context=ctx)
self.assertRaises(ValueError, poplib.POP3_SSL, self.server.host,
- self.server.port, keyfile=CERTFILE,
- certfile=CERTFILE, context=ctx)
+ self.server.port, keyfile=CERTFILE,
+ certfile=CERTFILE, context=ctx)
self.client.quit()
self.client = poplib.POP3_SSL(self.server.host, self.server.port,
- context=ctx)
+ context=ctx)
self.assertIsInstance(self.client.sock, ssl.SSLSocket)
self.assertIs(self.client.sock.context, ctx)
self.assertTrue(self.client.noop().startswith(b'+OK'))
+ def test_stls(self):
+ self.assertRaises(poplib.error_proto, self.client.stls)
+
+ test_stls_context = test_stls
+
+ def test_stls_capa(self):
+ capa = self.client.capa()
+ self.assertFalse('STLS' in capa.keys())
+
+
+@requires_ssl
+class TestPOP3_TLSClass(TestPOP3Class):
+ # repeat previous tests by using poplib.POP3.stls()
+
+ def setUp(self):
+ self.server = DummyPOP3Server((HOST, PORT))
+ self.server.start()
+ self.client = poplib.POP3(self.server.host, self.server.port, timeout=3)
+ self.client.stls()
+
+ def tearDown(self):
+ if self.client.file is not None and self.client.sock is not None:
+ try:
+ self.client.quit()
+ except poplib.error_proto:
+ # happens in the test_too_long_lines case; the overlong
+ # response will be treated as response to QUIT and raise
+ # this exception
+ pass
+ self.server.stop()
+
+ def test_stls(self):
+ self.assertRaises(poplib.error_proto, self.client.stls)
+
+ test_stls_context = test_stls
+
+ def test_stls_capa(self):
+ capa = self.client.capa()
+ self.assertFalse(b'STLS' in capa.keys())
+
class TestTimeouts(TestCase):
@@ -377,7 +483,7 @@ class TestTimeouts(TestCase):
def test_main():
tests = [TestPOP3Class, TestTimeouts,
- TestPOP3_SSLClass]
+ TestPOP3_SSLClass, TestPOP3_TLSClass]
thread_info = test_support.threading_setup()
try:
test_support.run_unittest(*tests)
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index 02bb6acb82..1eceebe7ac 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -582,8 +582,10 @@ class PosixTester(unittest.TestCase):
r, w = os.pipe2(os.O_CLOEXEC|os.O_NONBLOCK)
self.addCleanup(os.close, r)
self.addCleanup(os.close, w)
- self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
- self.assertTrue(fcntl.fcntl(w, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
+ self.assertFalse(os.get_inheritable(r))
+ self.assertFalse(os.get_inheritable(w))
+ self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK)
+ self.assertTrue(fcntl.fcntl(w, fcntl.F_GETFL) & os.O_NONBLOCK)
# try reading from an empty pipe: this should fail, not block
self.assertRaises(OSError, os.read, r, 1)
# try a write big enough to fill-up the pipe: this should either
@@ -626,7 +628,7 @@ class PosixTester(unittest.TestCase):
self.assertEqual(st.st_flags | stat.UF_IMMUTABLE, new_st.st_flags)
try:
fd = open(target_file, 'w+')
- except IOError as e:
+ except OSError as e:
self.assertEqual(e.errno, errno.EPERM)
finally:
posix.chflags(target_file, st.st_flags)
@@ -684,41 +686,39 @@ class PosixTester(unittest.TestCase):
@unittest.skipUnless(hasattr(posix, 'getcwd'), 'test needs posix.getcwd()')
def test_getcwd_long_pathnames(self):
- if hasattr(posix, 'getcwd'):
- dirname = 'getcwd-test-directory-0123456789abcdef-01234567890abcdef'
- curdir = os.getcwd()
- base_path = os.path.abspath(support.TESTFN) + '.getcwd'
+ dirname = 'getcwd-test-directory-0123456789abcdef-01234567890abcdef'
+ curdir = os.getcwd()
+ base_path = os.path.abspath(support.TESTFN) + '.getcwd'
- try:
- os.mkdir(base_path)
- os.chdir(base_path)
- except:
-# Just returning nothing instead of the SkipTest exception,
-# because the test results in Error in that case.
-# Is that ok?
-# raise unittest.SkipTest("cannot create directory for testing")
- return
-
- def _create_and_do_getcwd(dirname, current_path_length = 0):
- try:
- os.mkdir(dirname)
- except:
- raise unittest.SkipTest("mkdir cannot create directory sufficiently deep for getcwd test")
-
- os.chdir(dirname)
- try:
- os.getcwd()
- if current_path_length < 1027:
- _create_and_do_getcwd(dirname, current_path_length + len(dirname) + 1)
- finally:
- os.chdir('..')
- os.rmdir(dirname)
-
- _create_and_do_getcwd(dirname)
+ try:
+ os.mkdir(base_path)
+ os.chdir(base_path)
+ except:
+ # Just returning nothing instead of the SkipTest exception, because
+ # the test results in Error in that case. Is that ok?
+ # raise unittest.SkipTest("cannot create directory for testing")
+ return
- finally:
- os.chdir(curdir)
- support.rmtree(base_path)
+ def _create_and_do_getcwd(dirname, current_path_length = 0):
+ try:
+ os.mkdir(dirname)
+ except:
+ raise unittest.SkipTest("mkdir cannot create directory sufficiently deep for getcwd test")
+
+ os.chdir(dirname)
+ try:
+ os.getcwd()
+ if current_path_length < 1027:
+ _create_and_do_getcwd(dirname, current_path_length + len(dirname) + 1)
+ finally:
+ os.chdir('..')
+ os.rmdir(dirname)
+
+ _create_and_do_getcwd(dirname)
+
+ finally:
+ os.chdir(curdir)
+ support.rmtree(base_path)
@unittest.skipUnless(hasattr(posix, 'getgrouplist'), "test needs posix.getgrouplist()")
@unittest.skipUnless(hasattr(pwd, 'getpwuid'), "test needs pwd.getpwuid()")
diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py
index 0e7d866485..412849cff3 100644
--- a/Lib/test/test_posixpath.py
+++ b/Lib/test/test_posixpath.py
@@ -186,63 +186,6 @@ class PosixPathTest(unittest.TestCase):
if not f.close():
f.close()
- @staticmethod
- def _create_file(filename):
- with open(filename, 'wb') as f:
- f.write(b'foo')
-
- def test_samefile(self):
- test_fn = support.TESTFN + "1"
- self._create_file(test_fn)
- self.assertTrue(posixpath.samefile(test_fn, test_fn))
- self.assertRaises(TypeError, posixpath.samefile)
-
- @unittest.skipIf(
- sys.platform.startswith('win'),
- "posixpath.samefile does not work on links in Windows")
- @unittest.skipUnless(hasattr(os, "symlink"),
- "Missing symlink implementation")
- def test_samefile_on_links(self):
- test_fn1 = support.TESTFN + "1"
- test_fn2 = support.TESTFN + "2"
- self._create_file(test_fn1)
-
- os.symlink(test_fn1, test_fn2)
- self.assertTrue(posixpath.samefile(test_fn1, test_fn2))
- os.remove(test_fn2)
-
- self._create_file(test_fn2)
- self.assertFalse(posixpath.samefile(test_fn1, test_fn2))
-
-
- def test_samestat(self):
- test_fn = support.TESTFN + "1"
- self._create_file(test_fn)
- test_fns = [test_fn]*2
- stats = map(os.stat, test_fns)
- self.assertTrue(posixpath.samestat(*stats))
-
- @unittest.skipIf(
- sys.platform.startswith('win'),
- "posixpath.samestat does not work on links in Windows")
- @unittest.skipUnless(hasattr(os, "symlink"),
- "Missing symlink implementation")
- def test_samestat_on_links(self):
- test_fn1 = support.TESTFN + "1"
- test_fn2 = support.TESTFN + "2"
- self._create_file(test_fn1)
- test_fns = (test_fn1, test_fn2)
- os.symlink(*test_fns)
- stats = map(os.stat, test_fns)
- self.assertTrue(posixpath.samestat(*stats))
- os.remove(test_fn2)
-
- self._create_file(test_fn2)
- stats = map(os.stat, test_fns)
- self.assertFalse(posixpath.samestat(*stats))
-
- self.assertRaises(TypeError, posixpath.samestat)
-
def test_ismount(self):
self.assertIs(posixpath.ismount("/"), True)
with warnings.catch_warnings():
@@ -595,11 +538,6 @@ class PosixPathTest(unittest.TestCase):
finally:
os.getcwdb = real_getcwdb
- def test_sameopenfile(self):
- fname = support.TESTFN + "1"
- with open(fname, "wb") as a, open(fname, "wb") as b:
- self.assertTrue(posixpath.sameopenfile(a.fileno(), b.fileno()))
-
class PosixCommonTest(test_genericpath.CommonTest, unittest.TestCase):
pathmodule = posixpath
diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py
index a85298e26f..3d364c4595 100644
--- a/Lib/test/test_pprint.py
+++ b/Lib/test/test_pprint.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
import pprint
import test.support
import unittest
@@ -530,6 +532,54 @@ frozenset2({0,
self.assertEqual(pprint.pformat(dict.fromkeys(keys, 0)),
'{%r: 0, %r: 0}' % tuple(sorted(keys, key=id)))
+ def test_str_wrap(self):
+ # pprint tries to wrap strings intelligently
+ fox = 'the quick brown fox jumped over a lazy dog'
+ self.assertEqual(pprint.pformat(fox, width=20), """\
+'the quick brown '
+'fox jumped over '
+'a lazy dog'""")
+ self.assertEqual(pprint.pformat({'a': 1, 'b': fox, 'c': 2},
+ width=26), """\
+{'a': 1,
+ 'b': 'the quick brown '
+ 'fox jumped over '
+ 'a lazy dog',
+ 'c': 2}""")
+ # With some special characters
+ # - \n always triggers a new line in the pprint
+ # - \t and \n are escaped
+ # - non-ASCII is allowed
+ # - an apostrophe doesn't disrupt the pprint
+ special = "Portons dix bons \"whiskys\"\nà l'avocat goujat\t qui fumait au zoo"
+ self.assertEqual(pprint.pformat(special, width=20), """\
+'Portons dix bons '
+'"whiskys"\\n'
+"à l'avocat "
+'goujat\\t qui '
+'fumait au zoo'""")
+ # An unwrappable string is formatted as its repr
+ unwrappable = "x" * 100
+ self.assertEqual(pprint.pformat(unwrappable, width=80), repr(unwrappable))
+ self.assertEqual(pprint.pformat(''), "''")
+ # Check that the pprint is a usable repr
+ special *= 10
+ for width in range(3, 40):
+ formatted = pprint.pformat(special, width=width)
+ self.assertEqual(eval("(" + formatted + ")"), special)
+
+ def test_compact(self):
+ o = ([list(range(i * i)) for i in range(5)] +
+ [list(range(i)) for i in range(6)])
+ expected = """\
+[[], [0], [0, 1, 2, 3],
+ [0, 1, 2, 3, 4, 5, 6, 7, 8],
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15],
+ [], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3],
+ [0, 1, 2, 3, 4]]"""
+ self.assertEqual(pprint.pformat(o, width=48, compact=True), expected)
+
class DottedPrettyPrinter(pprint.PrettyPrinter):
diff --git a/Lib/test/test_print.py b/Lib/test/test_print.py
index 9d6dbea46b..7eea349115 100644
--- a/Lib/test/test_print.py
+++ b/Lib/test/test_print.py
@@ -1,62 +1,55 @@
-"""Test correct operation of the print function.
-"""
-
-# In 2.6, this gives us the behavior we want. In 3.0, it has
-# no function, but it still must parse correctly.
-from __future__ import print_function
-
import unittest
-from test import support
+from io import StringIO
-try:
- # 3.x
- from io import StringIO
-except ImportError:
- # 2.x
- from StringIO import StringIO
+from test import support
NotDefined = object()
# A dispatch table all 8 combinations of providing
-# sep, end, and file
+# sep, end, and file.
# I use this machinery so that I'm not just passing default
-# values to print, I'm either passing or not passing in the
-# arguments
+# values to print, I'm either passing or not passing in the
+# arguments.
dispatch = {
(False, False, False):
- lambda args, sep, end, file: print(*args),
+ lambda args, sep, end, file: print(*args),
(False, False, True):
- lambda args, sep, end, file: print(file=file, *args),
+ lambda args, sep, end, file: print(file=file, *args),
(False, True, False):
- lambda args, sep, end, file: print(end=end, *args),
+ lambda args, sep, end, file: print(end=end, *args),
(False, True, True):
- lambda args, sep, end, file: print(end=end, file=file, *args),
+ lambda args, sep, end, file: print(end=end, file=file, *args),
(True, False, False):
- lambda args, sep, end, file: print(sep=sep, *args),
+ lambda args, sep, end, file: print(sep=sep, *args),
(True, False, True):
- lambda args, sep, end, file: print(sep=sep, file=file, *args),
+ lambda args, sep, end, file: print(sep=sep, file=file, *args),
(True, True, False):
- lambda args, sep, end, file: print(sep=sep, end=end, *args),
+ lambda args, sep, end, file: print(sep=sep, end=end, *args),
(True, True, True):
- lambda args, sep, end, file: print(sep=sep, end=end, file=file, *args),
- }
+ lambda args, sep, end, file: print(sep=sep, end=end, file=file, *args),
+}
+
# Class used to test __str__ and print
class ClassWith__str__:
def __init__(self, x):
self.x = x
+
def __str__(self):
return self.x
+
class TestPrint(unittest.TestCase):
+ """Test correct operation of the print function."""
+
def check(self, expected, args,
- sep=NotDefined, end=NotDefined, file=NotDefined):
+ sep=NotDefined, end=NotDefined, file=NotDefined):
# Capture sys.stdout in a StringIO. Call print with args,
- # and with sep, end, and file, if they're defined. Result
- # must match expected.
+ # and with sep, end, and file, if they're defined. Result
+ # must match expected.
- # Look up the actual function to call, based on if sep, end, and file
- # are defined
+ # Look up the actual function to call, based on if sep, end,
+ # and file are defined.
fn = dispatch[(sep is not NotDefined,
end is not NotDefined,
file is not NotDefined)]
@@ -69,7 +62,7 @@ class TestPrint(unittest.TestCase):
def test_print(self):
def x(expected, args, sep=NotDefined, end=NotDefined):
# Run the test 2 ways: not using file, and using
- # file directed to a StringIO
+ # file directed to a StringIO.
self.check(expected, args, sep=sep, end=end)
@@ -101,11 +94,6 @@ class TestPrint(unittest.TestCase):
x('*\n', (ClassWith__str__('*'),))
x('abc 1\n', (ClassWith__str__('abc'), 1))
-# # 2.x unicode tests
-# x(u'1 2\n', ('1', u'2'))
-# x(u'u\1234\n', (u'u\1234',))
-# x(u' abc 1\n', (' ', ClassWith__str__(u'abc'), 1))
-
# errors
self.assertRaises(TypeError, print, '', sep=3)
self.assertRaises(TypeError, print, '', end=3)
@@ -113,12 +101,14 @@ class TestPrint(unittest.TestCase):
def test_print_flush(self):
# operation of the flush flag
- class filelike():
+ class filelike:
def __init__(self):
self.written = ''
self.flushed = 0
+
def write(self, str):
self.written += str
+
def flush(self):
self.flushed += 1
@@ -130,15 +120,13 @@ class TestPrint(unittest.TestCase):
self.assertEqual(f.flushed, 2)
# ensure exceptions from flush are passed through
- class noflush():
+ class noflush:
def write(self, str):
pass
+
def flush(self):
raise RuntimeError
self.assertRaises(RuntimeError, print, 1, file=noflush(), flush=True)
-def test_main():
- support.run_unittest(TestPrint)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_profile.py b/Lib/test/test_profile.py
index cd7ec58e23..1fc3c42669 100644
--- a/Lib/test/test_profile.py
+++ b/Lib/test/test_profile.py
@@ -3,9 +3,11 @@
import sys
import pstats
import unittest
+import os
from difflib import unified_diff
from io import StringIO
-from test.support import run_unittest
+from test.support import TESTFN, run_unittest, unlink
+from contextlib import contextmanager
import profile
from test.profilee import testfunc, timer
@@ -14,9 +16,13 @@ from test.profilee import testfunc, timer
class ProfileTest(unittest.TestCase):
profilerclass = profile.Profile
+ profilermodule = profile
methodnames = ['print_stats', 'print_callers', 'print_callees']
expected_max_output = ':0(max)'
+ def tearDown(self):
+ unlink(TESTFN)
+
def get_expected_output(self):
return _ProfileOutput
@@ -74,6 +80,19 @@ class ProfileTest(unittest.TestCase):
self.assertIn(self.expected_max_output, res,
"Profiling {0!r} didn't report max:\n{1}".format(stmt, res))
+ def test_run(self):
+ with silent():
+ self.profilermodule.run("int('1')")
+ self.profilermodule.run("int('1')", filename=TESTFN)
+ self.assertTrue(os.path.exists(TESTFN))
+
+ def test_runctx(self):
+ with silent():
+ self.profilermodule.runctx("testfunc()", globals(), locals())
+ self.profilermodule.runctx("testfunc()", globals(), locals(),
+ filename=TESTFN)
+ self.assertTrue(os.path.exists(TESTFN))
+
def regenerate_expected_output(filename, cls):
filename = filename.rstrip('co')
@@ -95,6 +114,14 @@ def regenerate_expected_output(filename, cls):
method, results[i+1]))
f.write('\nif __name__ == "__main__":\n main()\n')
+@contextmanager
+def silent():
+ stdout = sys.stdout
+ try:
+ sys.stdout = StringIO()
+ yield
+ finally:
+ sys.stdout = stdout
def test_main():
run_unittest(ProfileTest)
diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py
index 29297f8841..8916861f5b 100644
--- a/Lib/test/test_pty.py
+++ b/Lib/test/test_pty.py
@@ -187,7 +187,7 @@ class PtyTest(unittest.TestCase):
##debug("Reading from master_fd now that the child has exited")
##try:
## s1 = os.read(master_fd, 1024)
- ##except os.error:
+ ##except OSError:
## pass
##else:
## raise TestFailed("Read from master_fd did not raise exception")
diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py
index f3c1a6a44b..154c08a6a2 100644
--- a/Lib/test/test_py_compile.py
+++ b/Lib/test/test_py_compile.py
@@ -1,7 +1,9 @@
-import imp
+import importlib.util
import os
import py_compile
import shutil
+import stat
+import sys
import tempfile
import unittest
@@ -13,7 +15,7 @@ class PyCompileTests(unittest.TestCase):
self.directory = tempfile.mkdtemp()
self.source_path = os.path.join(self.directory, '_test.py')
self.pyc_path = self.source_path + 'c'
- self.cache_path = imp.cache_from_source(self.source_path)
+ self.cache_path = importlib.util.cache_from_source(self.source_path)
self.cwd_drive = os.path.splitdrive(os.getcwd())[0]
# In these tests we compute relative paths. When using Windows, the
# current working directory path and the 'self.source_path' might be
@@ -35,6 +37,26 @@ class PyCompileTests(unittest.TestCase):
self.assertTrue(os.path.exists(self.pyc_path))
self.assertFalse(os.path.exists(self.cache_path))
+ def test_do_not_overwrite_symlinks(self):
+ # In the face of a cfile argument being a symlink, bail out.
+ # Issue #17222
+ try:
+ os.symlink(self.pyc_path + '.actual', self.pyc_path)
+ except (NotImplementedError, OSError):
+ self.skipTest('need to be able to create a symlink for a file')
+ else:
+ assert os.path.islink(self.pyc_path)
+ with self.assertRaises(FileExistsError):
+ py_compile.compile(self.source_path, self.pyc_path)
+
+ @unittest.skipIf(not os.path.exists(os.devnull) or os.path.isfile(os.devnull),
+ 'requires os.devnull and for it to be a non-regular file')
+ def test_do_not_overwrite_nonregular_files(self):
+ # In the face of a cfile argument being a non-regular file, bail out.
+ # Issue #17222
+ with self.assertRaises(FileExistsError):
+ py_compile.compile(self.source_path, os.devnull)
+
def test_cache_path(self):
py_compile.compile(self.source_path)
self.assertTrue(os.path.exists(self.cache_path))
@@ -54,8 +76,22 @@ class PyCompileTests(unittest.TestCase):
self.assertTrue(os.path.exists(self.pyc_path))
self.assertFalse(os.path.exists(self.cache_path))
-def test_main():
- support.run_unittest(PyCompileTests)
+ @unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
+ 'non-root user required')
+ @unittest.skipIf(os.name == 'nt',
+ 'cannot control directory permissions on Windows')
+ def test_exceptions_propagate(self):
+ # Make sure that exceptions raised thanks to issues with writing
+ # bytecode.
+ # http://bugs.python.org/issue17244
+ mode = os.stat(self.directory)
+ os.chmod(self.directory, stat.S_IREAD)
+ try:
+ with self.assertRaises(IOError):
+ py_compile.compile(self.source_path, self.pyc_path)
+ finally:
+ os.chmod(self.directory, mode.st_mode)
+
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_pyclbr.py b/Lib/test/test_pyclbr.py
index e83989e2d8..88aff898d4 100644
--- a/Lib/test/test_pyclbr.py
+++ b/Lib/test/test_pyclbr.py
@@ -142,7 +142,7 @@ class PyclbrTest(TestCase):
self.checkModule('pyclbr')
self.checkModule('ast')
self.checkModule('doctest', ignore=("TestResults", "_SpoofOut",
- "DocTestCase"))
+ "DocTestCase", '_DocTestSuite'))
self.checkModule('difflib', ignore=("Match",))
def test_decorators(self):
@@ -158,7 +158,7 @@ class PyclbrTest(TestCase):
cm('random', ignore=('Random',)) # from _random import Random as CoreGenerator
cm('cgi', ignore=('log',)) # set with = in module
cm('pickle')
- cm('aifc', ignore=('openfp',)) # set with = in module
+ cm('aifc', ignore=('openfp', '_aifc_params')) # set with = in module
cm('sre_parse', ignore=('dump',)) # from sre_constants import *
cm('pdb')
cm('pydoc')
diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py
index 43f4163dae..cdf28a4d3d 100644
--- a/Lib/test/test_pydoc.py
+++ b/Lib/test/test_pydoc.py
@@ -11,6 +11,7 @@ import re
import string
import test.support
import time
+import types
import unittest
import xml.etree
import textwrap
@@ -29,10 +30,6 @@ try:
except ImportError:
threading = None
-# Just in case sys.modules["test"] has the optional attribute __loader__.
-if hasattr(pydoc_mod, "__loader__"):
- del pydoc_mod.__loader__
-
if test.support.HAVE_DOCSTRINGS:
expected_data_docstrings = (
'dictionary for instance variables (if defined)',
@@ -212,6 +209,75 @@ missing_pattern = "no Python documentation found for '%s'"
# output pattern for module with bad imports
badimport_pattern = "problem in %s - ImportError: No module named %r"
+expected_dynamicattribute_pattern = """
+Help on class DA in module %s:
+
+class DA(builtins.object)
+ | Data descriptors defined here:
+ |\x20\x20
+ | __dict__%s
+ |\x20\x20
+ | __weakref__%s
+ |\x20\x20
+ | ham
+ |\x20\x20
+ | ----------------------------------------------------------------------
+ | Data and other attributes inherited from Meta:
+ |\x20\x20
+ | ham = 'spam'
+""".strip()
+
+expected_virtualattribute_pattern1 = """
+Help on class Class in module %s:
+
+class Class(builtins.object)
+ | Data and other attributes inherited from Meta:
+ |\x20\x20
+ | LIFE = 42
+""".strip()
+
+expected_virtualattribute_pattern2 = """
+Help on class Class1 in module %s:
+
+class Class1(builtins.object)
+ | Data and other attributes inherited from Meta1:
+ |\x20\x20
+ | one = 1
+""".strip()
+
+expected_virtualattribute_pattern3 = """
+Help on class Class2 in module %s:
+
+class Class2(Class1)
+ | Method resolution order:
+ | Class2
+ | Class1
+ | builtins.object
+ |\x20\x20
+ | Data and other attributes inherited from Meta1:
+ |\x20\x20
+ | one = 1
+ |\x20\x20
+ | ----------------------------------------------------------------------
+ | Data and other attributes inherited from Meta3:
+ |\x20\x20
+ | three = 3
+ |\x20\x20
+ | ----------------------------------------------------------------------
+ | Data and other attributes inherited from Meta2:
+ |\x20\x20
+ | two = 2
+""".strip()
+
+expected_missingattribute_pattern = """
+Help on class C in module %s:
+
+class C(builtins.object)
+ | Data and other attributes defined here:
+ |\x20\x20
+ | here = 'present!'
+""".strip()
+
def run_pydoc(module_name, *args, **env):
"""
Runs pydoc on the specified module. Returns the stripped
@@ -421,6 +487,31 @@ class PydocDocTest(unittest.TestCase):
synopsis = pydoc.synopsis(TESTFN, {})
self.assertEqual(synopsis, 'line 1: h\xe9')
+ def test_splitdoc_with_description(self):
+ example_string = "I Am A Doc\n\n\nHere is my description"
+ self.assertEqual(pydoc.splitdoc(example_string),
+ ('I Am A Doc', '\nHere is my description'))
+
+ def test_is_object_or_method(self):
+ doc = pydoc.Doc()
+ # Bound Method
+ self.assertTrue(pydoc._is_some_method(doc.fail))
+ # Method Descriptor
+ self.assertTrue(pydoc._is_some_method(int.__add__))
+ # String
+ self.assertFalse(pydoc._is_some_method("I am not a method"))
+
+ def test_is_package_when_not_package(self):
+ with test.support.temp_cwd() as test_dir:
+ self.assertFalse(pydoc.ispackage(test_dir))
+
+ def test_is_package_when_is_package(self):
+ with test.support.temp_cwd() as test_dir:
+ init_path = os.path.join(test_dir, '__init__.py')
+ open(init_path, 'w').close()
+ self.assertTrue(pydoc.ispackage(test_dir))
+ os.remove(init_path)
+
def test_allmethods(self):
# issue 17476: allmethods was no longer returning unbound methods.
# This test is a bit fragile in the face of changes to object and type,
@@ -615,6 +706,127 @@ class TestHelper(unittest.TestCase):
self.assertEqual(sorted(pydoc.Helper.keywords),
sorted(keyword.kwlist))
+class PydocWithMetaClasses(unittest.TestCase):
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
+ def test_DynamicClassAttribute(self):
+ class Meta(type):
+ def __getattr__(self, name):
+ if name == 'ham':
+ return 'spam'
+ return super().__getattr__(name)
+ class DA(metaclass=Meta):
+ @types.DynamicClassAttribute
+ def ham(self):
+ return 'eggs'
+ expected_text_data_docstrings = tuple('\n | ' + s if s else ''
+ for s in expected_data_docstrings)
+ output = StringIO()
+ helper = pydoc.Helper(output=output)
+ helper(DA)
+ expected_text = expected_dynamicattribute_pattern % (
+ (__name__,) + expected_text_data_docstrings[:2])
+ result = output.getvalue().strip()
+ if result != expected_text:
+ print_diffs(expected_text, result)
+ self.fail("outputs are not equal, see diff above")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
+ def test_virtualClassAttributeWithOneMeta(self):
+ class Meta(type):
+ def __dir__(cls):
+ return ['__class__', '__module__', '__name__', 'LIFE']
+ def __getattr__(self, name):
+ if name =='LIFE':
+ return 42
+ return super().__getattr(name)
+ class Class(metaclass=Meta):
+ pass
+ output = StringIO()
+ helper = pydoc.Helper(output=output)
+ helper(Class)
+ expected_text = expected_virtualattribute_pattern1 % __name__
+ result = output.getvalue().strip()
+ if result != expected_text:
+ print_diffs(expected_text, result)
+ self.fail("outputs are not equal, see diff above")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
+ def test_virtualClassAttributeWithTwoMeta(self):
+ class Meta1(type):
+ def __dir__(cls):
+ return ['__class__', '__module__', '__name__', 'one']
+ def __getattr__(self, name):
+ if name =='one':
+ return 1
+ return super().__getattr__(name)
+ class Meta2(type):
+ def __dir__(cls):
+ return ['__class__', '__module__', '__name__', 'two']
+ def __getattr__(self, name):
+ if name =='two':
+ return 2
+ return super().__getattr__(name)
+ class Meta3(Meta1, Meta2):
+ def __dir__(cls):
+ return list(sorted(set(
+ ['__class__', '__module__', '__name__', 'three'] +
+ Meta1.__dir__(cls) + Meta2.__dir__(cls))))
+ def __getattr__(self, name):
+ if name =='three':
+ return 3
+ return super().__getattr__(name)
+ class Class1(metaclass=Meta1):
+ pass
+ class Class2(Class1, metaclass=Meta3):
+ pass
+ fail1 = fail2 = False
+ output = StringIO()
+ helper = pydoc.Helper(output=output)
+ helper(Class1)
+ expected_text1 = expected_virtualattribute_pattern2 % __name__
+ result1 = output.getvalue().strip()
+ if result1 != expected_text1:
+ print_diffs(expected_text1, result1)
+ fail1 = True
+ output = StringIO()
+ helper = pydoc.Helper(output=output)
+ helper(Class2)
+ expected_text2 = expected_virtualattribute_pattern3 % __name__
+ result2 = output.getvalue().strip()
+ if result2 != expected_text2:
+ print_diffs(expected_text2, result2)
+ fail2 = True
+ if fail1 or fail2:
+ self.fail("outputs are not equal, see diff above")
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ @unittest.skipIf(hasattr(sys, 'gettrace') and sys.gettrace(),
+ 'trace function introduces __locals__ unexpectedly')
+ def test_buggy_dir(self):
+ class M(type):
+ def __dir__(cls):
+ return ['__class__', '__name__', 'missing', 'here']
+ class C(metaclass=M):
+ here = 'present!'
+ output = StringIO()
+ helper = pydoc.Helper(output=output)
+ helper(C)
+ expected_text = expected_missingattribute_pattern % __name__
+ result = output.getvalue().strip()
+ if result != expected_text:
+ print_diffs(expected_text, result)
+ self.fail("outputs are not equal, see diff above")
+
@reap_threads
def test_main():
try:
@@ -624,6 +836,7 @@ def test_main():
PydocServerTest,
PydocUrlHandlerTest,
TestHelper,
+ PydocWithMetaClasses,
)
finally:
reap_children()
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index facddb1ebd..49a3f7bb00 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -1,10 +1,12 @@
#!/usr/bin/env python3
import unittest
+import unittest.mock
import random
import time
import pickle
import warnings
+from functools import partial
from math import log, exp, pi, fsum, sin
from test import support
@@ -46,6 +48,48 @@ class TestBasicOps:
self.assertRaises(TypeError, self.gen.seed, 1, 2, 3, 4)
self.assertRaises(TypeError, type(self.gen), [])
+ @unittest.mock.patch('random._urandom') # os.urandom
+ def test_seed_when_randomness_source_not_found(self, urandom_mock):
+ # Random.seed() uses time.time() when an operating system specific
+ # randomness source is not found. To test this on machines were it
+ # exists, run the above test, test_seedargs(), again after mocking
+ # os.urandom() so that it raises the exception expected when the
+ # randomness source is not available.
+ urandom_mock.side_effect = NotImplementedError
+ self.test_seedargs()
+
+ def test_shuffle(self):
+ shuffle = self.gen.shuffle
+ lst = []
+ shuffle(lst)
+ self.assertEqual(lst, [])
+ lst = [37]
+ shuffle(lst)
+ self.assertEqual(lst, [37])
+ seqs = [list(range(n)) for n in range(10)]
+ shuffled_seqs = [list(range(n)) for n in range(10)]
+ for shuffled_seq in shuffled_seqs:
+ shuffle(shuffled_seq)
+ for (seq, shuffled_seq) in zip(seqs, shuffled_seqs):
+ self.assertEqual(len(seq), len(shuffled_seq))
+ self.assertEqual(set(seq), set(shuffled_seq))
+ # The above tests all would pass if the shuffle was a
+ # no-op. The following non-deterministic test covers that. It
+ # asserts that the shuffled sequence of 1000 distinct elements
+ # must be different from the original one. Although there is
+ # mathematically a non-zero probability that this could
+ # actually happen in a genuinely random shuffle, it is
+ # completely negligible, given that the number of possible
+ # permutations of 1000 objects is 1000! (factorial of 1000),
+ # which is considerably larger than the number of atoms in the
+ # universe...
+ lst = list(range(1000))
+ shuffled_lst = list(range(1000))
+ shuffle(shuffled_lst)
+ self.assertTrue(lst != shuffled_lst)
+ shuffle(lst)
+ self.assertTrue(lst != shuffled_lst)
+
def test_choice(self):
choice = self.gen.choice
with self.assertRaises(IndexError):
@@ -65,6 +109,8 @@ class TestBasicOps:
self.assertEqual(len(uniq), k)
self.assertTrue(uniq <= set(population))
self.assertEqual(self.gen.sample([], 0), []) # test edge case N==k==0
+ # Exception raised if size of sample exceeds that of population
+ self.assertRaises(ValueError, self.gen.sample, population, N+1)
def test_sample_distribution(self):
# For the entire allowable range of 0 <= k <= N, validate that
@@ -205,6 +251,25 @@ class SystemRandom_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(set(range(start,stop)),
set([self.gen.randrange(start,stop) for i in range(100)]))
+ def test_randrange_nonunit_step(self):
+ rint = self.gen.randrange(0, 10, 2)
+ self.assertIn(rint, (0, 2, 4, 6, 8))
+ rint = self.gen.randrange(0, 2, 2)
+ self.assertEqual(rint, 0)
+
+ def test_randrange_errors(self):
+ raises = partial(self.assertRaises, ValueError, self.gen.randrange)
+ # Empty range
+ raises(3, 3)
+ raises(-721)
+ raises(0, 100, -12)
+ # Non-integer start/stop
+ raises(3.14159)
+ raises(0, 2.71828)
+ # Zero and non-integer step
+ raises(0, 42, 0)
+ raises(0, 42, 3.14159)
+
def test_genrandbits(self):
# Verify ranges
for k in range(1, 1000):
@@ -274,6 +339,16 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
# Last element s/b an int also
self.assertRaises(TypeError, self.gen.setstate, (2, (0,)*624+('a',), None))
+ # Little trick to make "tuple(x % (2**32) for x in internalstate)"
+ # raise ValueError. I cannot think of a simple way to achieve this, so
+ # I am opting for using a generator as the middle argument of setstate
+ # which attempts to cast a NaN to integer.
+ state_values = self.gen.getstate()[1]
+ state_values = list(state_values)
+ state_values[-1] = float('nan')
+ state = (int(x) for x in state_values)
+ self.assertRaises(TypeError, self.gen.setstate, (2, state, None))
+
def test_referenceImplementation(self):
# Compare the python implementation with results from the original
# code. Create 2000 53-bit precision random floats. Compare only
@@ -413,6 +488,38 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase):
self.assertEqual(k, numbits) # note the stronger assertion
self.assertTrue(2**k > n > 2**(k-1)) # note the stronger assertion
+ @unittest.mock.patch('random.Random.random')
+ def test_randbelow_overriden_random(self, random_mock):
+ # Random._randbelow() can only use random() when the built-in one
+ # has been overridden but no new getrandbits() method was supplied.
+ random_mock.side_effect = random.SystemRandom().random
+ maxsize = 1<<random.BPF
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning)
+ # Population range too large (n >= maxsize)
+ self.gen._randbelow(maxsize+1, maxsize = maxsize)
+ self.gen._randbelow(5640, maxsize = maxsize)
+
+ # This might be going too far to test a single line, but because of our
+ # noble aim of achieving 100% test coverage we need to write a case in
+ # which the following line in Random._randbelow() gets executed:
+ #
+ # rem = maxsize % n
+ # limit = (maxsize - rem) / maxsize
+ # r = random()
+ # while r >= limit:
+ # r = random() # <== *This line* <==<
+ #
+ # Therefore, to guarantee that the while loop is executed at least
+ # once, we need to mock random() so that it returns a number greater
+ # than 'limit' the first time it gets called.
+
+ n = 42
+ epsilon = 0.01
+ limit = (maxsize - (maxsize % n)) / maxsize
+ random_mock.side_effect = [limit + epsilon, limit - epsilon]
+ self.gen._randbelow(n, maxsize = maxsize)
+
def test_randrange_bug_1590891(self):
start = 1000000000000
stop = -100000000000000000000
@@ -530,6 +637,106 @@ class TestDistributions(unittest.TestCase):
random.vonmisesvariate(0, 1e15)
random.vonmisesvariate(0, 1e100)
+ def test_gammavariate_errors(self):
+ # Both alpha and beta must be > 0.0
+ self.assertRaises(ValueError, random.gammavariate, -1, 3)
+ self.assertRaises(ValueError, random.gammavariate, 0, 2)
+ self.assertRaises(ValueError, random.gammavariate, 2, 0)
+ self.assertRaises(ValueError, random.gammavariate, 1, -3)
+
+ @unittest.mock.patch('random.Random.random')
+ def test_gammavariate_full_code_coverage(self, random_mock):
+ # There are three different possibilities in the current implementation
+ # of random.gammavariate(), depending on the value of 'alpha'. What we
+ # are going to do here is to fix the values returned by random() to
+ # generate test cases that provide 100% line coverage of the method.
+
+ # #1: alpha > 1.0: we want the first random number to be outside the
+ # [1e-7, .9999999] range, so that the continue statement executes
+ # once. The values of u1 and u2 will be 0.5 and 0.3, respectively.
+ random_mock.side_effect = [1e-8, 0.5, 0.3]
+ returned_value = random.gammavariate(1.1, 2.3)
+ self.assertAlmostEqual(returned_value, 2.53)
+
+ # #2: alpha == 1: first random number less than 1e-7 to that the body
+ # of the while loop executes once. Then random.random() returns 0.45,
+ # which causes while to stop looping and the algorithm to terminate.
+ random_mock.side_effect = [1e-8, 0.45]
+ returned_value = random.gammavariate(1.0, 3.14)
+ self.assertAlmostEqual(returned_value, 2.507314166123803)
+
+ # #3: 0 < alpha < 1. This is the most complex region of code to cover,
+ # as there are multiple if-else statements. Let's take a look at the
+ # source code, and determine the values that we need accordingly:
+ #
+ # while 1:
+ # u = random()
+ # b = (_e + alpha)/_e
+ # p = b*u
+ # if p <= 1.0: # <=== (A)
+ # x = p ** (1.0/alpha)
+ # else: # <=== (B)
+ # x = -_log((b-p)/alpha)
+ # u1 = random()
+ # if p > 1.0: # <=== (C)
+ # if u1 <= x ** (alpha - 1.0): # <=== (D)
+ # break
+ # elif u1 <= _exp(-x): # <=== (E)
+ # break
+ # return x * beta
+ #
+ # First, we want (A) to be True. For that we need that:
+ # b*random() <= 1.0
+ # r1 = random() <= 1.0 / b
+ #
+ # We now get to the second if-else branch, and here, since p <= 1.0,
+ # (C) is False and we take the elif branch, (E). For it to be True,
+ # so that the break is executed, we need that:
+ # r2 = random() <= _exp(-x)
+ # r2 <= _exp(-(p ** (1.0/alpha)))
+ # r2 <= _exp(-((b*r1) ** (1.0/alpha)))
+
+ _e = random._e
+ _exp = random._exp
+ _log = random._log
+ alpha = 0.35
+ beta = 1.45
+ b = (_e + alpha)/_e
+ epsilon = 0.01
+
+ r1 = 0.8859296441566 # 1.0 / b
+ r2 = 0.3678794411714 # _exp(-((b*r1) ** (1.0/alpha)))
+
+ # These four "random" values result in the following trace:
+ # (A) True, (E) False --> [next iteration of while]
+ # (A) True, (E) True --> [while loop breaks]
+ random_mock.side_effect = [r1, r2 + epsilon, r1, r2]
+ returned_value = random.gammavariate(alpha, beta)
+ self.assertAlmostEqual(returned_value, 1.4499999999997544)
+
+ # Let's now make (A) be False. If this is the case, when we get to the
+ # second if-else 'p' is greater than 1, so (C) evaluates to True. We
+ # now encounter a second if statement, (D), which in order to execute
+ # must satisfy the following condition:
+ # r2 <= x ** (alpha - 1.0)
+ # r2 <= (-_log((b-p)/alpha)) ** (alpha - 1.0)
+ # r2 <= (-_log((b-(b*r1))/alpha)) ** (alpha - 1.0)
+ r1 = 0.8959296441566 # (1.0 / b) + epsilon -- so that (A) is False
+ r2 = 0.9445400408898141
+
+ # And these four values result in the following trace:
+ # (B) and (C) True, (D) False --> [next iteration of while]
+ # (B) and (C) True, (D) True [while loop breaks]
+ random_mock.side_effect = [r1, r2 + epsilon, r1, r2]
+ returned_value = random.gammavariate(alpha, beta)
+ self.assertAlmostEqual(returned_value, 1.5830349561760781)
+
+ @unittest.mock.patch('random.Random.gammavariate')
+ def test_betavariate_return_zero(self, gammavariate_mock):
+ # betavariate() returns zero when the Gamma distribution
+ # that it uses internally returns this same value.
+ gammavariate_mock.return_value = 0.0
+ self.assertEqual(0.0, random.betavariate(2.71828, 3.14159))
class TestModule(unittest.TestCase):
def testMagicConstants(self):
diff --git a/Lib/test/test_range.py b/Lib/test/test_range.py
index 2a13bfeabd..f088387c33 100644
--- a/Lib/test/test_range.py
+++ b/Lib/test/test_range.py
@@ -313,7 +313,7 @@ class RangeTest(unittest.TestCase):
self.assertRaises(TypeError, range, IN())
# Test use of user-defined classes in slice indices.
- self.assertEqual(list(range(10)[:I(5)]), list(range(5)))
+ self.assertEqual(range(10)[:I(5)], range(5))
with self.assertRaises(RuntimeError):
range(0, 10)[:IX()]
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index f093812442..8d63fac0be 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -3,10 +3,12 @@ from test.support import verbose, run_unittest, gc_collect, bigmemtest, _2G, \
import io
import re
from re import Scanner
+import sre_compile
import sre_constants
import sys
import string
import traceback
+import unittest
from weakref import proxy
# Misc tests from Tim Peters' re.doc
@@ -15,10 +17,26 @@ from weakref import proxy
# what you're doing. Some of these tests were carefully modeled to
# cover most of the code.
-import unittest
+class S(str):
+ def __getitem__(self, index):
+ return S(super().__getitem__(index))
+
+class B(bytes):
+ def __getitem__(self, index):
+ return B(super().__getitem__(index))
class ReTests(unittest.TestCase):
+ def assertTypedEqual(self, actual, expect, msg=None):
+ self.assertEqual(actual, expect, msg)
+ def recurse(actual, expect):
+ if isinstance(expect, (tuple, list)):
+ for x, y in zip(actual, expect):
+ recurse(x, y)
+ else:
+ self.assertIs(type(actual), type(expect), msg)
+ recurse(actual, expect)
+
def test_keep_buffer(self):
# See bug 14212
b = bytearray(b'x')
@@ -53,6 +71,15 @@ class ReTests(unittest.TestCase):
return str(int_value + 1)
def test_basic_re_sub(self):
+ self.assertTypedEqual(re.sub('y', 'a', 'xyz'), 'xaz')
+ self.assertTypedEqual(re.sub('y', S('a'), S('xyz')), 'xaz')
+ self.assertTypedEqual(re.sub(b'y', b'a', b'xyz'), b'xaz')
+ self.assertTypedEqual(re.sub(b'y', B(b'a'), B(b'xyz')), b'xaz')
+ self.assertTypedEqual(re.sub(b'y', bytearray(b'a'), bytearray(b'xyz')), b'xaz')
+ self.assertTypedEqual(re.sub(b'y', memoryview(b'a'), memoryview(b'xyz')), b'xaz')
+ for y in ("\xe0", "\u0430", "\U0001d49c"):
+ self.assertEqual(re.sub(y, 'a', 'x%sz' % y), 'xaz')
+
self.assertEqual(re.sub("(?i)b+", "x", "bbbb BBBB"), 'x x')
self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y'),
'9.3 -3 24x100y')
@@ -210,10 +237,29 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.subn("b*", "x", "xyz", 2), ('xxxyz', 2))
def test_re_split(self):
- self.assertEqual(re.split(":", ":a:b::c"), ['', 'a', 'b', '', 'c'])
- self.assertEqual(re.split(":*", ":a:b::c"), ['', 'a', 'b', 'c'])
- self.assertEqual(re.split("(:*)", ":a:b::c"),
- ['', ':', 'a', ':', 'b', '::', 'c'])
+ for string in ":a:b::c", S(":a:b::c"):
+ self.assertTypedEqual(re.split(":", string),
+ ['', 'a', 'b', '', 'c'])
+ self.assertTypedEqual(re.split(":*", string),
+ ['', 'a', 'b', 'c'])
+ self.assertTypedEqual(re.split("(:*)", string),
+ ['', ':', 'a', ':', 'b', '::', 'c'])
+ for string in (b":a:b::c", B(b":a:b::c"), bytearray(b":a:b::c"),
+ memoryview(b":a:b::c")):
+ self.assertTypedEqual(re.split(b":", string),
+ [b'', b'a', b'b', b'', b'c'])
+ self.assertTypedEqual(re.split(b":*", string),
+ [b'', b'a', b'b', b'c'])
+ self.assertTypedEqual(re.split(b"(:*)", string),
+ [b'', b':', b'a', b':', b'b', b'::', b'c'])
+ for a, b, c in ("\xe0\xdf\xe7", "\u0430\u0431\u0432",
+ "\U0001d49c\U0001d49e\U0001d4b5"):
+ string = ":%s:%s::%s" % (a, b, c)
+ self.assertEqual(re.split(":", string), ['', a, b, '', c])
+ self.assertEqual(re.split(":*", string), ['', a, b, c])
+ self.assertEqual(re.split("(:*)", string),
+ ['', ':', a, ':', b, '::', c])
+
self.assertEqual(re.split("(?::*)", ":a:b::c"), ['', 'a', 'b', 'c'])
self.assertEqual(re.split("(:)*", ":a:b::c"),
['', ':', 'a', ':', 'b', ':', 'c'])
@@ -235,22 +281,53 @@ class ReTests(unittest.TestCase):
def test_re_findall(self):
self.assertEqual(re.findall(":+", "abc"), [])
- self.assertEqual(re.findall(":+", "a:b::c:::d"), [":", "::", ":::"])
- self.assertEqual(re.findall("(:+)", "a:b::c:::d"), [":", "::", ":::"])
- self.assertEqual(re.findall("(:)(:*)", "a:b::c:::d"), [(":", ""),
- (":", ":"),
- (":", "::")])
+ for string in "a:b::c:::d", S("a:b::c:::d"):
+ self.assertTypedEqual(re.findall(":+", string),
+ [":", "::", ":::"])
+ self.assertTypedEqual(re.findall("(:+)", string),
+ [":", "::", ":::"])
+ self.assertTypedEqual(re.findall("(:)(:*)", string),
+ [(":", ""), (":", ":"), (":", "::")])
+ for string in (b"a:b::c:::d", B(b"a:b::c:::d"), bytearray(b"a:b::c:::d"),
+ memoryview(b"a:b::c:::d")):
+ self.assertTypedEqual(re.findall(b":+", string),
+ [b":", b"::", b":::"])
+ self.assertTypedEqual(re.findall(b"(:+)", string),
+ [b":", b"::", b":::"])
+ self.assertTypedEqual(re.findall(b"(:)(:*)", string),
+ [(b":", b""), (b":", b":"), (b":", b"::")])
+ for x in ("\xe0", "\u0430", "\U0001d49c"):
+ xx = x * 2
+ xxx = x * 3
+ string = "a%sb%sc%sd" % (x, xx, xxx)
+ self.assertEqual(re.findall("%s+" % x, string), [x, xx, xxx])
+ self.assertEqual(re.findall("(%s+)" % x, string), [x, xx, xxx])
+ self.assertEqual(re.findall("(%s)(%s*)" % (x, x), string),
+ [(x, ""), (x, x), (x, xx)])
def test_bug_117612(self):
self.assertEqual(re.findall(r"(a|(b))", "aba"),
[("a", ""),("b", "b"),("a", "")])
def test_re_match(self):
- self.assertEqual(re.match('a', 'a').groups(), ())
- self.assertEqual(re.match('(a)', 'a').groups(), ('a',))
- self.assertEqual(re.match(r'(a)', 'a').group(0), 'a')
- self.assertEqual(re.match(r'(a)', 'a').group(1), 'a')
- self.assertEqual(re.match(r'(a)', 'a').group(1, 1), ('a', 'a'))
+ for string in 'a', S('a'):
+ self.assertEqual(re.match('a', string).groups(), ())
+ self.assertEqual(re.match('(a)', string).groups(), ('a',))
+ self.assertEqual(re.match('(a)', string).group(0), 'a')
+ self.assertEqual(re.match('(a)', string).group(1), 'a')
+ self.assertEqual(re.match('(a)', string).group(1, 1), ('a', 'a'))
+ for string in b'a', B(b'a'), bytearray(b'a'), memoryview(b'a'):
+ self.assertEqual(re.match(b'a', string).groups(), ())
+ self.assertEqual(re.match(b'(a)', string).groups(), (b'a',))
+ self.assertEqual(re.match(b'(a)', string).group(0), b'a')
+ self.assertEqual(re.match(b'(a)', string).group(1), b'a')
+ self.assertEqual(re.match(b'(a)', string).group(1, 1), (b'a', b'a'))
+ for a in ("\xe0", "\u0430", "\U0001d49c"):
+ self.assertEqual(re.match(a, a).groups(), ())
+ self.assertEqual(re.match('(%s)' % a, a).groups(), (a,))
+ self.assertEqual(re.match('(%s)' % a, a).group(0), a)
+ self.assertEqual(re.match('(%s)' % a, a).group(1), a)
+ self.assertEqual(re.match('(%s)' % a, a).group(1, 1), (a, a))
pat = re.compile('((a)|(b))(c)?')
self.assertEqual(pat.match('a').groups(), ('a', 'a', None, None))
@@ -1053,6 +1130,28 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.compile(pattern, re.S).findall(b'xyz'),
[b'xyz'], msg=pattern)
+ def test_match_repr(self):
+ for string in '[abracadabra]', S('[abracadabra]'):
+ m = re.search(r'(.+)(.*?)\1', string)
+ self.assertEqual(repr(m), "<%s.%s object; "
+ "span=(1, 12), match='abracadabra'>" %
+ (type(m).__module__, type(m).__qualname__))
+ for string in (b'[abracadabra]', B(b'[abracadabra]'),
+ bytearray(b'[abracadabra]'),
+ memoryview(b'[abracadabra]')):
+ m = re.search(rb'(.+)(.*?)\1', string)
+ self.assertEqual(repr(m), "<%s.%s object; "
+ "span=(1, 12), match=b'abracadabra'>" %
+ (type(m).__module__, type(m).__qualname__))
+
+ first, second = list(re.finditer("(aa)|(bb)", "aa bb"))
+ self.assertEqual(repr(first), "<%s.%s object; "
+ "span=(0, 2), match='aa'>" %
+ (type(second).__module__, type(first).__qualname__))
+ self.assertEqual(repr(second), "<%s.%s object; "
+ "span=(3, 5), match='bb'>" %
+ (type(second).__module__, type(second).__qualname__))
+
def test_bug_2537(self):
# issue 2537: empty submatches
@@ -1064,6 +1163,22 @@ class ReTests(unittest.TestCase):
self.assertEqual(m.group(1), "")
self.assertEqual(m.group(2), "y")
+
+class ImplementationTest(unittest.TestCase):
+ """
+ Test implementation details of the re module.
+ """
+
+ def test_overlap_table(self):
+ f = sre_compile._generate_overlap_table
+ self.assertEqual(f(""), [])
+ self.assertEqual(f("a"), [0])
+ self.assertEqual(f("abcd"), [0, 0, 0, 0])
+ self.assertEqual(f("aaaa"), [0, 1, 2, 3])
+ self.assertEqual(f("ababba"), [0, 0, 1, 2, 0, 1])
+ self.assertEqual(f("abcabdac"), [0, 0, 0, 1, 2, 0, 1, 0])
+
+
def run_re_tests():
from test.re_tests import tests, SUCCEED, FAIL, SYNTAX_ERROR
if verbose:
@@ -1193,7 +1308,7 @@ def run_re_tests():
def test_main():
- run_unittest(ReTests)
+ run_unittest(__name__)
run_re_tests()
if __name__ == "__main__":
diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py
new file mode 100644
index 0000000000..a398a4f836
--- /dev/null
+++ b/Lib/test/test_regrtest.py
@@ -0,0 +1,275 @@
+"""
+Tests of regrtest.py.
+"""
+
+import argparse
+import faulthandler
+import getopt
+import os.path
+import unittest
+from test import regrtest, support
+
+class ParseArgsTestCase(unittest.TestCase):
+
+ """Test regrtest's argument parsing."""
+
+ def checkError(self, args, msg):
+ with support.captured_stderr() as err, self.assertRaises(SystemExit):
+ regrtest._parse_args(args)
+ self.assertIn(msg, err.getvalue())
+
+ def test_help(self):
+ for opt in '-h', '--help':
+ with self.subTest(opt=opt):
+ with support.captured_stdout() as out, \
+ self.assertRaises(SystemExit):
+ regrtest._parse_args([opt])
+ self.assertIn('Run Python regression tests.', out.getvalue())
+
+ @unittest.skipUnless(hasattr(faulthandler, 'dump_traceback_later'),
+ "faulthandler.dump_traceback_later() required")
+ def test_timeout(self):
+ ns = regrtest._parse_args(['--timeout', '4.2'])
+ self.assertEqual(ns.timeout, 4.2)
+ self.checkError(['--timeout'], 'expected one argument')
+ self.checkError(['--timeout', 'foo'], 'invalid float value')
+
+ def test_wait(self):
+ ns = regrtest._parse_args(['--wait'])
+ self.assertTrue(ns.wait)
+
+ def test_slaveargs(self):
+ ns = regrtest._parse_args(['--slaveargs', '[[], {}]'])
+ self.assertEqual(ns.slaveargs, '[[], {}]')
+ self.checkError(['--slaveargs'], 'expected one argument')
+
+ def test_start(self):
+ for opt in '-S', '--start':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, 'foo'])
+ self.assertEqual(ns.start, 'foo')
+ self.checkError([opt], 'expected one argument')
+
+ def test_verbose(self):
+ ns = regrtest._parse_args(['-v'])
+ self.assertEqual(ns.verbose, 1)
+ ns = regrtest._parse_args(['-vvv'])
+ self.assertEqual(ns.verbose, 3)
+ ns = regrtest._parse_args(['--verbose'])
+ self.assertEqual(ns.verbose, 1)
+ ns = regrtest._parse_args(['--verbose'] * 3)
+ self.assertEqual(ns.verbose, 3)
+ ns = regrtest._parse_args([])
+ self.assertEqual(ns.verbose, 0)
+
+ def test_verbose2(self):
+ for opt in '-w', '--verbose2':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.verbose2)
+
+ def test_verbose3(self):
+ for opt in '-W', '--verbose3':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.verbose3)
+
+ def test_quiet(self):
+ for opt in '-q', '--quiet':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.quiet)
+ self.assertEqual(ns.verbose, 0)
+
+ def test_slow(self):
+ for opt in '-o', '--slow':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.print_slow)
+
+ def test_header(self):
+ ns = regrtest._parse_args(['--header'])
+ self.assertTrue(ns.header)
+
+ def test_randomize(self):
+ for opt in '-r', '--randomize':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.randomize)
+
+ def test_randseed(self):
+ ns = regrtest._parse_args(['--randseed', '12345'])
+ self.assertEqual(ns.random_seed, 12345)
+ self.assertTrue(ns.randomize)
+ self.checkError(['--randseed'], 'expected one argument')
+ self.checkError(['--randseed', 'foo'], 'invalid int value')
+
+ def test_fromfile(self):
+ for opt in '-f', '--fromfile':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, 'foo'])
+ self.assertEqual(ns.fromfile, 'foo')
+ self.checkError([opt], 'expected one argument')
+ self.checkError([opt, 'foo', '-s'], "don't go together")
+
+ def test_exclude(self):
+ for opt in '-x', '--exclude':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.exclude)
+
+ def test_single(self):
+ for opt in '-s', '--single':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.single)
+ self.checkError([opt, '-f', 'foo'], "don't go together")
+
+ def test_match(self):
+ for opt in '-m', '--match':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, 'pattern'])
+ self.assertEqual(ns.match_tests, 'pattern')
+ self.checkError([opt], 'expected one argument')
+
+ def test_failfast(self):
+ for opt in '-G', '--failfast':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, '-v'])
+ self.assertTrue(ns.failfast)
+ ns = regrtest._parse_args([opt, '-W'])
+ self.assertTrue(ns.failfast)
+ self.checkError([opt], '-G/--failfast needs either -v or -W')
+
+ def test_use(self):
+ for opt in '-u', '--use':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, 'gui,network'])
+ self.assertEqual(ns.use_resources, ['gui', 'network'])
+ ns = regrtest._parse_args([opt, 'gui,none,network'])
+ self.assertEqual(ns.use_resources, ['network'])
+ expected = list(regrtest.RESOURCE_NAMES)
+ expected.remove('gui')
+ ns = regrtest._parse_args([opt, 'all,-gui'])
+ self.assertEqual(ns.use_resources, expected)
+ self.checkError([opt], 'expected one argument')
+ self.checkError([opt, 'foo'], 'invalid resource')
+
+ def test_memlimit(self):
+ for opt in '-M', '--memlimit':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, '4G'])
+ self.assertEqual(ns.memlimit, '4G')
+ self.checkError([opt], 'expected one argument')
+
+ def test_testdir(self):
+ ns = regrtest._parse_args(['--testdir', 'foo'])
+ self.assertEqual(ns.testdir, os.path.join(support.SAVEDCWD, 'foo'))
+ self.checkError(['--testdir'], 'expected one argument')
+
+ def test_runleaks(self):
+ for opt in '-L', '--runleaks':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.runleaks)
+
+ def test_huntrleaks(self):
+ for opt in '-R', '--huntrleaks':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, ':'])
+ self.assertEqual(ns.huntrleaks, (5, 4, 'reflog.txt'))
+ ns = regrtest._parse_args([opt, '6:'])
+ self.assertEqual(ns.huntrleaks, (6, 4, 'reflog.txt'))
+ ns = regrtest._parse_args([opt, ':3'])
+ self.assertEqual(ns.huntrleaks, (5, 3, 'reflog.txt'))
+ ns = regrtest._parse_args([opt, '6:3:leaks.log'])
+ self.assertEqual(ns.huntrleaks, (6, 3, 'leaks.log'))
+ self.checkError([opt], 'expected one argument')
+ self.checkError([opt, '6'],
+ 'needs 2 or 3 colon-separated arguments')
+ self.checkError([opt, 'foo:'], 'invalid huntrleaks value')
+ self.checkError([opt, '6:foo'], 'invalid huntrleaks value')
+
+ def test_multiprocess(self):
+ for opt in '-j', '--multiprocess':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, '2'])
+ self.assertEqual(ns.use_mp, 2)
+ self.checkError([opt], 'expected one argument')
+ self.checkError([opt, 'foo'], 'invalid int value')
+ self.checkError([opt, '2', '-T'], "don't go together")
+ self.checkError([opt, '2', '-l'], "don't go together")
+ self.checkError([opt, '2', '-M', '4G'], "don't go together")
+
+ def test_coverage(self):
+ for opt in '-T', '--coverage':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.trace)
+
+ def test_coverdir(self):
+ for opt in '-D', '--coverdir':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, 'foo'])
+ self.assertEqual(ns.coverdir,
+ os.path.join(support.SAVEDCWD, 'foo'))
+ self.checkError([opt], 'expected one argument')
+
+ def test_nocoverdir(self):
+ for opt in '-N', '--nocoverdir':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertIsNone(ns.coverdir)
+
+ def test_threshold(self):
+ for opt in '-t', '--threshold':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt, '1000'])
+ self.assertEqual(ns.threshold, 1000)
+ self.checkError([opt], 'expected one argument')
+ self.checkError([opt, 'foo'], 'invalid int value')
+
+ def test_nowindows(self):
+ for opt in '-n', '--nowindows':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.nowindows)
+
+ def test_forever(self):
+ for opt in '-F', '--forever':
+ with self.subTest(opt=opt):
+ ns = regrtest._parse_args([opt])
+ self.assertTrue(ns.forever)
+
+
+ def test_unrecognized_argument(self):
+ self.checkError(['--xxx'], 'usage:')
+
+ def test_long_option__partial(self):
+ ns = regrtest._parse_args(['--qui'])
+ self.assertTrue(ns.quiet)
+ self.assertEqual(ns.verbose, 0)
+
+ def test_two_options(self):
+ ns = regrtest._parse_args(['--quiet', '--exclude'])
+ self.assertTrue(ns.quiet)
+ self.assertEqual(ns.verbose, 0)
+ self.assertTrue(ns.exclude)
+
+ def test_option_with_empty_string_value(self):
+ ns = regrtest._parse_args(['--start', ''])
+ self.assertEqual(ns.start, '')
+
+ def test_arg(self):
+ ns = regrtest._parse_args(['foo'])
+ self.assertEqual(ns.args, ['foo'])
+
+ def test_option_and_arg(self):
+ ns = regrtest._parse_args(['--quiet', 'foo'])
+ self.assertTrue(ns.quiet)
+ self.assertEqual(ns.verbose, 0)
+ self.assertEqual(ns.args, ['foo'])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py
index 589ecdd4da..104e3b5705 100644
--- a/Lib/test/test_reprlib.py
+++ b/Lib/test/test_reprlib.py
@@ -3,11 +3,11 @@
Nick Mathewson
"""
-import imp
import sys
import os
import shutil
import importlib
+import importlib.util
import unittest
from test.support import run_unittest, create_empty_file, verbose
@@ -241,7 +241,8 @@ class LongReprTest(unittest.TestCase):
source_path_len += 2 * (len(self.longname) + 1)
# a path separator + `module_name` + ".py"
source_path_len += len(module_name) + 1 + len(".py")
- cached_path_len = source_path_len + len(imp.cache_from_source("x.py")) - len("x.py")
+ cached_path_len = (source_path_len +
+ len(importlib.util.cache_from_source("x.py")) - len("x.py"))
if os.name == 'nt' and cached_path_len >= 258:
# Under Windows, the max path len is 260 including C's terminating
# NUL character.
diff --git a/Lib/test/test_resource.py b/Lib/test/test_resource.py
index 0cf61cb17a..006198fc41 100644
--- a/Lib/test/test_resource.py
+++ b/Lib/test/test_resource.py
@@ -1,3 +1,6 @@
+import contextlib
+import sys
+import os
import unittest
from test import support
import time
@@ -61,7 +64,7 @@ class ResourceTest(unittest.TestCase):
for i in range(5):
time.sleep(.1)
f.flush()
- except IOError:
+ except OSError:
if not limit_set:
raise
if limit_set:
@@ -129,6 +132,27 @@ class ResourceTest(unittest.TestCase):
self.assertIsInstance(pagesize, int)
self.assertGreaterEqual(pagesize, 0)
+ @unittest.skipUnless(sys.platform == 'linux', 'test requires Linux')
+ def test_linux_constants(self):
+ for attr in ['MSGQUEUE', 'NICE', 'RTPRIO', 'RTTIME', 'SIGPENDING']:
+ with contextlib.suppress(AttributeError):
+ self.assertIsInstance(getattr(resource, 'RLIMIT_' + attr), int)
+
+ @unittest.skipUnless(hasattr(resource, 'prlimit'), 'no prlimit')
+ @support.requires_linux_version(2, 6, 36)
+ def test_prlimit(self):
+ self.assertRaises(TypeError, resource.prlimit)
+ if os.geteuid() != 0:
+ self.assertRaises(PermissionError, resource.prlimit,
+ 1, resource.RLIMIT_AS)
+ self.assertRaises(ProcessLookupError, resource.prlimit,
+ -1, resource.RLIMIT_AS)
+ limit = resource.getrlimit(resource.RLIMIT_AS)
+ self.assertEqual(resource.prlimit(0, resource.RLIMIT_AS), limit)
+ self.assertEqual(resource.prlimit(0, resource.RLIMIT_AS, limit),
+ limit)
+
+
def test_main(verbose=None):
support.run_unittest(ResourceTest)
diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py
index 2ddba3413a..16e2e12090 100644
--- a/Lib/test/test_runpy.py
+++ b/Lib/test/test_runpy.py
@@ -575,12 +575,5 @@ s = "non-ASCII: h\xe9"
self.assertEqual(result['s'], "non-ASCII: h\xe9")
-def test_main():
- run_unittest(
- ExecutionLayerTestCase,
- RunModuleTestCase,
- RunPathTestCase
- )
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py
index 070886d1ea..c6abf3d89a 100644
--- a/Lib/test/test_sched.py
+++ b/Lib/test/test_sched.py
@@ -197,8 +197,5 @@ class TestCase(unittest.TestCase):
self.assertEqual(l, [])
-def test_main():
- support.run_unittest(TestCase)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py
index 26ce0424f8..b325545f32 100644
--- a/Lib/test/test_scope.py
+++ b/Lib/test/test_scope.py
@@ -715,6 +715,19 @@ class ScopeTests(unittest.TestCase):
def b():
global a
+ def testClassNamespaceOverridesClosure(self):
+ # See #17853.
+ x = 42
+ class X:
+ locals()["x"] = 43
+ y = x
+ self.assertEqual(X.y, 43)
+ class X:
+ locals()["x"] = 43
+ del x
+ self.assertFalse(hasattr(X, "x"))
+ self.assertEqual(x, 42)
+
@cpython_only
def testCellLeak(self):
# Issue 17927.
@@ -743,10 +756,6 @@ class ScopeTests(unittest.TestCase):
del tester
self.assertIsNone(ref())
- def test__Class__Global(self):
- s = "class X:\n global __class__\n def f(self): super()"
- self.assertRaises(SyntaxError, exec, s)
-
def test_main():
run_unittest(ScopeTests)
diff --git a/Lib/test/test_select.py b/Lib/test/test_select.py
index ddb9a0f67e..8f9a1c9d88 100644
--- a/Lib/test/test_select.py
+++ b/Lib/test/test_select.py
@@ -5,7 +5,7 @@ import sys
import unittest
from test import support
-@unittest.skipIf(sys.platform[:3] in ('win', 'os2', 'riscos'),
+@unittest.skipIf((sys.platform[:3]=='win'),
"can't easily test on this system")
class SelectTestCase(unittest.TestCase):
@@ -32,7 +32,7 @@ class SelectTestCase(unittest.TestCase):
fp.close()
try:
select.select([fd], [], [], 0)
- except select.error as err:
+ except OSError as err:
self.assertEqual(err.errno, errno.EBADF)
else:
self.fail("exception not raised")
diff --git a/Lib/test/test_selectors.py b/Lib/test/test_selectors.py
new file mode 100644
index 0000000000..c64c87aa3f
--- /dev/null
+++ b/Lib/test/test_selectors.py
@@ -0,0 +1,429 @@
+import errno
+import random
+import selectors
+import signal
+import socket
+from test import support
+from time import sleep
+import unittest
+import unittest.mock
+try:
+ from time import monotonic as time
+except ImportError:
+ from time import time as time
+try:
+ import resource
+except ImportError:
+ resource = None
+
+
+if hasattr(socket, 'socketpair'):
+ socketpair = socket.socketpair
+else:
+ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
+ with socket.socket(family, type, proto) as l:
+ l.bind((support.HOST, 0))
+ l.listen(3)
+ c = socket.socket(family, type, proto)
+ try:
+ c.connect(l.getsockname())
+ caddr = c.getsockname()
+ while True:
+ a, addr = l.accept()
+ # check that we've got the correct client
+ if addr == caddr:
+ return c, a
+ a.close()
+ except OSError:
+ c.close()
+ raise
+
+
+def find_ready_matching(ready, flag):
+ match = []
+ for key, events in ready:
+ if events & flag:
+ match.append(key.fileobj)
+ return match
+
+
+class BaseSelectorTestCase(unittest.TestCase):
+
+ def test_register(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ key = s.register(rd, selectors.EVENT_READ, "data")
+ self.assertIsInstance(key, selectors.SelectorKey)
+ self.assertEqual(key.fileobj, rd)
+ self.assertEqual(key.fd, rd.fileno())
+ self.assertEqual(key.events, selectors.EVENT_READ)
+ self.assertEqual(key.data, "data")
+
+ # register an unknown event
+ self.assertRaises(ValueError, s.register, 0, 999999)
+
+ # register an invalid FD
+ self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ)
+
+ # register twice
+ self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ)
+
+ # register the same FD, but with a different object
+ self.assertRaises(KeyError, s.register, rd.fileno(),
+ selectors.EVENT_READ)
+
+ def test_unregister(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ s.register(rd, selectors.EVENT_READ)
+ s.unregister(rd)
+
+ # unregister an unknown file obj
+ self.assertRaises(KeyError, s.unregister, 999999)
+
+ # unregister twice
+ self.assertRaises(KeyError, s.unregister, rd)
+
+ def test_modify(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ key = s.register(rd, selectors.EVENT_READ)
+
+ # modify events
+ key2 = s.modify(rd, selectors.EVENT_WRITE)
+ self.assertNotEqual(key.events, key2.events)
+ self.assertEqual(key2, s.get_key(rd))
+
+ s.unregister(rd)
+
+ # modify data
+ d1 = object()
+ d2 = object()
+
+ key = s.register(rd, selectors.EVENT_READ, d1)
+ key2 = s.modify(rd, selectors.EVENT_READ, d2)
+ self.assertEqual(key.events, key2.events)
+ self.assertNotEqual(key.data, key2.data)
+ self.assertEqual(key2, s.get_key(rd))
+ self.assertEqual(key2.data, d2)
+
+ # modify unknown file obj
+ self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ)
+
+ # modify use a shortcut
+ d3 = object()
+ s.register = unittest.mock.Mock()
+ s.unregister = unittest.mock.Mock()
+
+ s.modify(rd, selectors.EVENT_READ, d3)
+ self.assertFalse(s.register.called)
+ self.assertFalse(s.unregister.called)
+
+ def test_close(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ s.register(rd, selectors.EVENT_READ)
+ s.register(wr, selectors.EVENT_WRITE)
+
+ s.close()
+ self.assertRaises(KeyError, s.get_key, rd)
+ self.assertRaises(KeyError, s.get_key, wr)
+
+ def test_get_key(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ key = s.register(rd, selectors.EVENT_READ, "data")
+ self.assertEqual(key, s.get_key(rd))
+
+ # unknown file obj
+ self.assertRaises(KeyError, s.get_key, 999999)
+
+ def test_get_map(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ keys = s.get_map()
+ self.assertFalse(keys)
+ self.assertEqual(len(keys), 0)
+ self.assertEqual(list(keys), [])
+ key = s.register(rd, selectors.EVENT_READ, "data")
+ self.assertIn(rd, keys)
+ self.assertEqual(key, keys[rd])
+ self.assertEqual(len(keys), 1)
+ self.assertEqual(list(keys), [rd.fileno()])
+ self.assertEqual(list(keys.values()), [key])
+
+ # unknown file obj
+ with self.assertRaises(KeyError):
+ keys[999999]
+
+ # Read-only mapping
+ with self.assertRaises(TypeError):
+ del keys[rd]
+
+ def test_select(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ s.register(rd, selectors.EVENT_READ)
+ wr_key = s.register(wr, selectors.EVENT_WRITE)
+
+ result = s.select()
+ for key, events in result:
+ self.assertTrue(isinstance(key, selectors.SelectorKey))
+ self.assertTrue(events)
+ self.assertFalse(events & ~(selectors.EVENT_READ |
+ selectors.EVENT_WRITE))
+
+ self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result)
+
+ def test_context_manager(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ with s as sel:
+ sel.register(rd, selectors.EVENT_READ)
+ sel.register(wr, selectors.EVENT_WRITE)
+
+ self.assertRaises(KeyError, s.get_key, rd)
+ self.assertRaises(KeyError, s.get_key, wr)
+
+ def test_fileno(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ if hasattr(s, 'fileno'):
+ fd = s.fileno()
+ self.assertTrue(isinstance(fd, int))
+ self.assertGreaterEqual(fd, 0)
+
+ def test_selector(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ NUM_SOCKETS = 12
+ MSG = b" This is a test."
+ MSG_LEN = len(MSG)
+ readers = []
+ writers = []
+ r2w = {}
+ w2r = {}
+
+ for i in range(NUM_SOCKETS):
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+ s.register(rd, selectors.EVENT_READ)
+ s.register(wr, selectors.EVENT_WRITE)
+ readers.append(rd)
+ writers.append(wr)
+ r2w[rd] = wr
+ w2r[wr] = rd
+
+ bufs = []
+
+ while writers:
+ ready = s.select()
+ ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE)
+ if not ready_writers:
+ self.fail("no sockets ready for writing")
+ wr = random.choice(ready_writers)
+ wr.send(MSG)
+
+ for i in range(10):
+ ready = s.select()
+ ready_readers = find_ready_matching(ready,
+ selectors.EVENT_READ)
+ if ready_readers:
+ break
+ # there might be a delay between the write to the write end and
+ # the read end is reported ready
+ sleep(0.1)
+ else:
+ self.fail("no sockets ready for reading")
+ self.assertEqual([w2r[wr]], ready_readers)
+ rd = ready_readers[0]
+ buf = rd.recv(MSG_LEN)
+ self.assertEqual(len(buf), MSG_LEN)
+ bufs.append(buf)
+ s.unregister(r2w[rd])
+ s.unregister(rd)
+ writers.remove(r2w[rd])
+
+ self.assertEqual(bufs, [MSG] * NUM_SOCKETS)
+
+ def test_timeout(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ s.register(wr, selectors.EVENT_WRITE)
+ t = time()
+ self.assertEqual(1, len(s.select(0)))
+ self.assertEqual(1, len(s.select(-1)))
+ self.assertLess(time() - t, 0.5)
+
+ s.unregister(wr)
+ s.register(rd, selectors.EVENT_READ)
+ t = time()
+ self.assertFalse(s.select(0))
+ self.assertFalse(s.select(-1))
+ self.assertLess(time() - t, 0.5)
+
+ t0 = time()
+ self.assertFalse(s.select(1))
+ t1 = time()
+ self.assertTrue(0.5 < t1 - t0 < 1.5, t1 - t0)
+
+ @unittest.skipUnless(hasattr(signal, "alarm"),
+ "signal.alarm() required for this test")
+ def test_select_interrupt(self):
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ rd, wr = socketpair()
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None)
+ self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
+ self.addCleanup(signal.alarm, 0)
+
+ signal.alarm(1)
+
+ s.register(rd, selectors.EVENT_READ)
+ t = time()
+ self.assertFalse(s.select(2))
+ self.assertLess(time() - t, 2.5)
+
+
+class ScalableSelectorMixIn:
+
+ # see issue #18963 for why it's skipped on older OS X versions
+ @support.requires_mac_ver(10, 5)
+ @unittest.skipUnless(resource, "Test needs resource module")
+ def test_above_fd_setsize(self):
+ # A scalable implementation should have no problem with more than
+ # FD_SETSIZE file descriptors. Since we don't know the value, we just
+ # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling.
+ soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+ try:
+ resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+ self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE,
+ (soft, hard))
+ NUM_FDS = hard
+ except (OSError, ValueError):
+ NUM_FDS = soft
+
+ # guard for already allocated FDs (stdin, stdout...)
+ NUM_FDS -= 32
+
+ s = self.SELECTOR()
+ self.addCleanup(s.close)
+
+ for i in range(NUM_FDS // 2):
+ try:
+ rd, wr = socketpair()
+ except OSError:
+ # too many FDs, skip - note that we should only catch EMFILE
+ # here, but apparently *BSD and Solaris can fail upon connect()
+ # or bind() with EADDRNOTAVAIL, so let's be safe
+ self.skipTest("FD limit reached")
+
+ self.addCleanup(rd.close)
+ self.addCleanup(wr.close)
+
+ try:
+ s.register(rd, selectors.EVENT_READ)
+ s.register(wr, selectors.EVENT_WRITE)
+ except OSError as e:
+ if e.errno == errno.ENOSPC:
+ # this can be raised by epoll if we go over
+ # fs.epoll.max_user_watches sysctl
+ self.skipTest("FD limit reached")
+ raise
+
+ self.assertEqual(NUM_FDS // 2, len(s.select()))
+
+
+class DefaultSelectorTestCase(BaseSelectorTestCase):
+
+ SELECTOR = selectors.DefaultSelector
+
+
+class SelectSelectorTestCase(BaseSelectorTestCase):
+
+ SELECTOR = selectors.SelectSelector
+
+
+@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
+ "Test needs selectors.PollSelector")
+class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
+
+ SELECTOR = getattr(selectors, 'PollSelector', None)
+
+
+@unittest.skipUnless(hasattr(selectors, 'EpollSelector'),
+ "Test needs selectors.EpollSelector")
+class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
+
+ SELECTOR = getattr(selectors, 'EpollSelector', None)
+
+
+@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'),
+ "Test needs selectors.KqueueSelector)")
+class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn):
+
+ SELECTOR = getattr(selectors, 'KqueueSelector', None)
+
+
+def test_main():
+ tests = [DefaultSelectorTestCase, SelectSelectorTestCase,
+ PollSelectorTestCase, EpollSelectorTestCase,
+ KqueueSelectorTestCase]
+ support.run_unittest(*tests)
+ support.reap_children()
+
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index c0bca2fb88..bfef621093 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -848,8 +848,6 @@ class TestBasicOps:
for v in self.set:
self.assertIn(v, self.values)
setiter = iter(self.set)
- # note: __length_hint__ is an internal undocumented API,
- # don't rely on it in your own programs
self.assertEqual(setiter.__length_hint__(), len(self.set))
def test_pickling(self):
diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py
index 13c126566d..bd51d868fe 100644
--- a/Lib/test/test_shelve.py
+++ b/Lib/test/test_shelve.py
@@ -148,6 +148,19 @@ class TestCase(unittest.TestCase):
p2 = d[encodedkey]
self.assertNotEqual(p1, p2) # Write creates new object in store
+ def test_with(self):
+ d1 = {}
+ with shelve.Shelf(d1, protocol=2, writeback=False) as s:
+ s['key1'] = [1,2,3,4]
+ self.assertEqual(s['key1'], [1,2,3,4])
+ self.assertEqual(len(s), 1)
+ self.assertRaises(ValueError, len, s)
+ try:
+ s['key1']
+ except ValueError:
+ pass
+ else:
+ self.fail('Closed shelf should not find a key')
from test import mapping_tests
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index df95bd9a5c..deb1577611 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -18,7 +18,8 @@ from shutil import (_make_tarball, _make_zipfile, make_archive,
register_archive_format, unregister_archive_format,
get_archive_formats, Error, unpack_archive,
register_unpack_format, RegistryError,
- unregister_unpack_format, get_unpack_formats)
+ unregister_unpack_format, get_unpack_formats,
+ SameFileError)
import tarfile
import warnings
@@ -742,14 +743,14 @@ class TestShutil(unittest.TestCase):
os.chmod(restrictive_subdir, 0o600)
shutil.copytree(src_dir, dst_dir)
- self.assertEquals(os.stat(src_dir).st_mode, os.stat(dst_dir).st_mode)
- self.assertEquals(os.stat(os.path.join(src_dir, 'permissive.txt')).st_mode,
+ self.assertEqual(os.stat(src_dir).st_mode, os.stat(dst_dir).st_mode)
+ self.assertEqual(os.stat(os.path.join(src_dir, 'permissive.txt')).st_mode,
os.stat(os.path.join(dst_dir, 'permissive.txt')).st_mode)
- self.assertEquals(os.stat(os.path.join(src_dir, 'restrictive.txt')).st_mode,
+ self.assertEqual(os.stat(os.path.join(src_dir, 'restrictive.txt')).st_mode,
os.stat(os.path.join(dst_dir, 'restrictive.txt')).st_mode)
restrictive_subdir_dst = os.path.join(dst_dir,
os.path.split(restrictive_subdir)[1])
- self.assertEquals(os.stat(restrictive_subdir).st_mode,
+ self.assertEqual(os.stat(restrictive_subdir).st_mode,
os.stat(restrictive_subdir_dst).st_mode)
@unittest.skipUnless(hasattr(os, 'link'), 'requires os.link')
@@ -765,7 +766,7 @@ class TestShutil(unittest.TestCase):
with open(src, 'w') as f:
f.write('cheddar')
os.link(src, dst)
- self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
+ self.assertRaises(shutil.SameFileError, shutil.copyfile, src, dst)
with open(src, 'r') as f:
self.assertEqual(f.read(), 'cheddar')
os.remove(dst)
@@ -785,7 +786,7 @@ class TestShutil(unittest.TestCase):
# to TESTFN/TESTFN/cheese, while it should point at
# TESTFN/cheese.
os.symlink('cheese', dst)
- self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
+ self.assertRaises(shutil.SameFileError, shutil.copyfile, src, dst)
with open(src, 'r') as f:
self.assertEqual(f.read(), 'cheddar')
os.remove(dst)
@@ -1293,6 +1294,16 @@ class TestShutil(unittest.TestCase):
self.assertTrue(os.path.exists(rv))
self.assertEqual(read_file(src_file), read_file(dst_file))
+ def test_copyfile_same_file(self):
+ # copyfile() should raise SameFileError if the source and destination
+ # are the same.
+ src_dir = self.mkdtemp()
+ src_file = os.path.join(src_dir, 'foo')
+ write_file(src_file, 'foo')
+ self.assertRaises(SameFileError, shutil.copyfile, src_file, src_file)
+ # But Error should work too, to stay backward compatible.
+ self.assertRaises(Error, shutil.copyfile, src_file, src_file)
+
def test_copytree_return_value(self):
# copytree returns its destination path.
src_dir = self.mkdtemp()
@@ -1588,7 +1599,7 @@ class TestCopyFile(unittest.TestCase):
self._exited_with = exc_type, exc_val, exc_tb
if self._raise_in_exit:
self._raised = True
- raise IOError("Cannot close")
+ raise OSError("Cannot close")
return self._suppress_at_exit
def tearDown(self):
@@ -1602,12 +1613,12 @@ class TestCopyFile(unittest.TestCase):
def test_w_source_open_fails(self):
def _open(filename, mode='r'):
if filename == 'srcfile':
- raise IOError('Cannot open "srcfile"')
+ raise OSError('Cannot open "srcfile"')
assert 0 # shouldn't reach here.
self._set_shutil_open(_open)
- self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile')
+ self.assertRaises(OSError, shutil.copyfile, 'srcfile', 'destfile')
def test_w_dest_open_fails(self):
@@ -1617,14 +1628,14 @@ class TestCopyFile(unittest.TestCase):
if filename == 'srcfile':
return srcfile
if filename == 'destfile':
- raise IOError('Cannot open "destfile"')
+ raise OSError('Cannot open "destfile"')
assert 0 # shouldn't reach here.
self._set_shutil_open(_open)
shutil.copyfile('srcfile', 'destfile')
self.assertTrue(srcfile._entered)
- self.assertTrue(srcfile._exited_with[0] is IOError)
+ self.assertTrue(srcfile._exited_with[0] is OSError)
self.assertEqual(srcfile._exited_with[1].args,
('Cannot open "destfile"',))
@@ -1646,7 +1657,7 @@ class TestCopyFile(unittest.TestCase):
self.assertTrue(srcfile._entered)
self.assertTrue(destfile._entered)
self.assertTrue(destfile._raised)
- self.assertTrue(srcfile._exited_with[0] is IOError)
+ self.assertTrue(srcfile._exited_with[0] is OSError)
self.assertEqual(srcfile._exited_with[1].args,
('Cannot close',))
@@ -1664,7 +1675,7 @@ class TestCopyFile(unittest.TestCase):
self._set_shutil_open(_open)
- self.assertRaises(IOError,
+ self.assertRaises(OSError,
shutil.copyfile, 'srcfile', 'destfile')
self.assertTrue(srcfile._entered)
self.assertTrue(destfile._entered)
@@ -1735,9 +1746,5 @@ class TermsizeTests(unittest.TestCase):
self.assertEqual(expected, actual)
-def test_main():
- support.run_unittest(TestShutil, TestMove, TestCopyFile,
- TermsizeTests, TestWhich)
-
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_signal.py b/Lib/test/test_signal.py
index 9b4ba5016e..3530d8a14e 100644
--- a/Lib/test/test_signal.py
+++ b/Lib/test/test_signal.py
@@ -15,9 +15,6 @@ try:
except ImportError:
threading = None
-if sys.platform in ('os2', 'riscos'):
- raise unittest.SkipTest("Can't test signal on %s" % sys.platform)
-
class HandlerBCalled(Exception):
pass
@@ -36,7 +33,7 @@ def exit_subprocess():
def ignoring_eintr(__func, *args, **kwargs):
try:
return __func(*args, **kwargs)
- except EnvironmentError as e:
+ except OSError as e:
if e.errno != errno.EINTR:
raise
return None
@@ -278,6 +275,57 @@ class WakeupSignalTests(unittest.TestCase):
assert_python_ok('-c', code)
+ def test_wakeup_write_error(self):
+ # Issue #16105: write() errors in the C signal handler should not
+ # pass silently.
+ # Use a subprocess to have only one thread.
+ code = """if 1:
+ import errno
+ import fcntl
+ import os
+ import signal
+ import sys
+ import time
+ from test.support import captured_stderr
+
+ def handler(signum, frame):
+ 1/0
+
+ signal.signal(signal.SIGALRM, handler)
+ r, w = os.pipe()
+ flags = fcntl.fcntl(r, fcntl.F_GETFL, 0)
+ fcntl.fcntl(r, fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+ # Set wakeup_fd a read-only file descriptor to trigger the error
+ signal.set_wakeup_fd(r)
+ try:
+ with captured_stderr() as err:
+ signal.alarm(1)
+ time.sleep(5.0)
+ except ZeroDivisionError:
+ # An ignored exception should have been printed out on stderr
+ err = err.getvalue()
+ if ('Exception ignored when trying to write to the signal wakeup fd'
+ not in err):
+ raise AssertionError(err)
+ if ('OSError: [Errno %d]' % errno.EBADF) not in err:
+ raise AssertionError(err)
+ else:
+ raise AssertionError("ZeroDivisionError not raised")
+ """
+ r, w = os.pipe()
+ try:
+ os.write(r, b'x')
+ except OSError:
+ pass
+ else:
+ self.skipTest("OS doesn't report write() error on the read end of a pipe")
+ finally:
+ os.close(r)
+ os.close(w)
+
+ assert_python_ok('-c', code)
+
def test_wakeup_fd_early(self):
self.check_wakeup("""def test():
import select
@@ -315,10 +363,10 @@ class WakeupSignalTests(unittest.TestCase):
# We attempt to get a signal during the select call
try:
select.select([read], [], [], TIMEOUT_FULL)
- except select.error:
+ except OSError:
pass
else:
- raise Exception("select.error not raised")
+ raise Exception("OSError not raised")
after_time = time.time()
dt = after_time - before_time
if dt >= TIMEOUT_HALF:
diff --git a/Lib/test/test_site.py b/Lib/test/test_site.py
index c294c65b58..cb7c393074 100644
--- a/Lib/test/test_site.py
+++ b/Lib/test/test_site.py
@@ -231,11 +231,7 @@ class HelperFunctionsTests(unittest.TestCase):
site.PREFIXES = ['xoxo']
dirs = site.getsitepackages()
- if sys.platform in ('os2emx', 'riscos'):
- self.assertEqual(len(dirs), 1)
- wanted = os.path.join('xoxo', 'Lib', 'site-packages')
- self.assertEqual(dirs[0], wanted)
- elif (sys.platform == "darwin" and
+ if (sys.platform == "darwin" and
sysconfig.get_config_var("PYTHONFRAMEWORK")):
# OS X framework builds
site.PREFIXES = ['Python.framework']
@@ -432,5 +428,38 @@ class ImportSideEffectTests(unittest.TestCase):
self.assertEqual(code, 200, msg="Can't find " + url)
+class StartupImportTests(unittest.TestCase):
+
+ def test_startup_imports(self):
+ # This tests checks which modules are loaded by Python when it
+ # initially starts upon startup.
+ popen = subprocess.Popen([sys.executable, '-I', '-v', '-c',
+ 'import sys; print(set(sys.modules))'],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ stdout, stderr = popen.communicate()
+ stdout = stdout.decode('utf-8')
+ stderr = stderr.decode('utf-8')
+ modules = eval(stdout)
+
+ self.assertIn('site', modules)
+
+ # http://bugs.python.org/issue19205
+ re_mods = {'re', '_sre', 'sre_compile', 'sre_constants', 'sre_parse'}
+ # _osx_support uses the re module in many placs
+ if sys.platform != 'darwin':
+ self.assertFalse(modules.intersection(re_mods), stderr)
+ # http://bugs.python.org/issue9548
+ self.assertNotIn('locale', modules, stderr)
+ if sys.platform != 'darwin':
+ # http://bugs.python.org/issue19209
+ self.assertNotIn('copyreg', modules, stderr)
+ # http://bugs.python.org/issue19218>
+ collection_mods = {'_collections', 'collections', 'functools',
+ 'heapq', 'itertools', 'keyword', 'operator',
+ 'reprlib', 'types', 'weakref'}
+ self.assertFalse(modules.intersection(collection_mods), stderr)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py
index 2df9271da7..9203d5eb19 100644
--- a/Lib/test/test_slice.py
+++ b/Lib/test/test_slice.py
@@ -4,8 +4,70 @@ import unittest
from test import support
from pickle import loads, dumps
+import itertools
+import operator
import sys
+
+def evaluate_slice_index(arg):
+ """
+ Helper function to convert a slice argument to an integer, and raise
+ TypeError with a suitable message on failure.
+
+ """
+ if hasattr(arg, '__index__'):
+ return operator.index(arg)
+ else:
+ raise TypeError(
+ "slice indices must be integers or "
+ "None or have an __index__ method")
+
+def slice_indices(slice, length):
+ """
+ Reference implementation for the slice.indices method.
+
+ """
+ # Compute step and length as integers.
+ length = operator.index(length)
+ step = 1 if slice.step is None else evaluate_slice_index(slice.step)
+
+ # Raise ValueError for negative length or zero step.
+ if length < 0:
+ raise ValueError("length should not be negative")
+ if step == 0:
+ raise ValueError("slice step cannot be zero")
+
+ # Find lower and upper bounds for start and stop.
+ lower = -1 if step < 0 else 0
+ upper = length - 1 if step < 0 else length
+
+ # Compute start.
+ if slice.start is None:
+ start = upper if step < 0 else lower
+ else:
+ start = evaluate_slice_index(slice.start)
+ start = max(start + length, lower) if start < 0 else min(start, upper)
+
+ # Compute stop.
+ if slice.stop is None:
+ stop = lower if step < 0 else upper
+ else:
+ stop = evaluate_slice_index(slice.stop)
+ stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
+
+ return start, stop, step
+
+
+# Class providing an __index__ method. Used for testing slice.indices.
+
+class MyIndexable(object):
+ def __init__(self, value):
+ self.value = value
+
+ def __index__(self):
+ return self.value
+
+
class SliceTest(unittest.TestCase):
def test_constructor(self):
@@ -75,6 +137,22 @@ class SliceTest(unittest.TestCase):
s = slice(obj)
self.assertTrue(s.stop is obj)
+ def check_indices(self, slice, length):
+ try:
+ actual = slice.indices(length)
+ except ValueError:
+ actual = "valueerror"
+ try:
+ expected = slice_indices(slice, length)
+ except ValueError:
+ expected = "valueerror"
+ self.assertEqual(actual, expected)
+
+ if length >= 0 and slice.step != 0:
+ actual = range(*slice.indices(length))
+ expected = range(length)[slice]
+ self.assertEqual(actual, expected)
+
def test_indices(self):
self.assertEqual(slice(None ).indices(10), (0, 10, 1))
self.assertEqual(slice(None, None, 2).indices(10), (0, 10, 2))
@@ -108,7 +186,41 @@ class SliceTest(unittest.TestCase):
self.assertEqual(list(range(10))[::sys.maxsize - 1], [0])
- self.assertRaises(OverflowError, slice(None).indices, 1<<100)
+ # Check a variety of start, stop, step and length values, including
+ # values exceeding sys.maxsize (see issue #14794).
+ vals = [None, -2**100, -2**30, -53, -7, -1, 0, 1, 7, 53, 2**30, 2**100]
+ lengths = [0, 1, 7, 53, 2**30, 2**100]
+ for slice_args in itertools.product(vals, repeat=3):
+ s = slice(*slice_args)
+ for length in lengths:
+ self.check_indices(s, length)
+ self.check_indices(slice(0, 10, 1), -3)
+
+ # Negative length should raise ValueError
+ with self.assertRaises(ValueError):
+ slice(None).indices(-1)
+
+ # Zero step should raise ValueError
+ with self.assertRaises(ValueError):
+ slice(0, 10, 0).indices(5)
+
+ # Using a start, stop or step or length that can't be interpreted as an
+ # integer should give a TypeError ...
+ with self.assertRaises(TypeError):
+ slice(0.0, 10, 1).indices(5)
+ with self.assertRaises(TypeError):
+ slice(0, 10.0, 1).indices(5)
+ with self.assertRaises(TypeError):
+ slice(0, 10, 1.0).indices(5)
+ with self.assertRaises(TypeError):
+ slice(0, 10, 1).indices(5.0)
+
+ # ... but it should be fine to use a custom class that provides index.
+ self.assertEqual(slice(0, 10, 1).indices(5), (0, 5, 1))
+ self.assertEqual(slice(MyIndexable(0), 10, 1).indices(5), (0, 5, 1))
+ self.assertEqual(slice(0, MyIndexable(10), 1).indices(5), (0, 5, 1))
+ self.assertEqual(slice(0, 10, MyIndexable(1)).indices(5), (0, 5, 1))
+ self.assertEqual(slice(0, 10, 1).indices(MyIndexable(5)), (0, 5, 1))
def test_setslice_without_getslice(self):
tmp = []
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
index 8d1dbbfc43..e6f39dec77 100644
--- a/Lib/test/test_smtplib.py
+++ b/Lib/test/test_smtplib.py
@@ -222,7 +222,7 @@ class DebuggingServerTests(unittest.TestCase):
self.assertEqual(smtp.source_address, ('127.0.0.1', port))
self.assertEqual(smtp.local_hostname, 'localhost')
smtp.quit()
- except IOError as e:
+ except OSError as e:
if e.errno == errno.EADDRINUSE:
self.skipTest("couldn't bind to port %d" % port)
raise
@@ -524,12 +524,6 @@ class DebuggingServerTests(unittest.TestCase):
class NonConnectingTests(unittest.TestCase):
- def setUp(self):
- smtplib.socket = mock_socket
-
- def tearDown(self):
- smtplib.socket = socket
-
def testNotConnected(self):
# Test various operations on an unconnected SMTP object that
# should raise exceptions (at present the attempt in SMTP.send
@@ -541,10 +535,10 @@ class NonConnectingTests(unittest.TestCase):
smtp.send, 'test msg')
def testNonnumericPort(self):
- # check that non-numeric port raises socket.error
- self.assertRaises(mock_socket.error, smtplib.SMTP,
+ # check that non-numeric port raises OSError
+ self.assertRaises(OSError, smtplib.SMTP,
"localhost", "bogus")
- self.assertRaises(mock_socket.error, smtplib.SMTP,
+ self.assertRaises(OSError, smtplib.SMTP,
"localhost:bogus")
@@ -825,6 +819,15 @@ class SMTPSimTests(unittest.TestCase):
self.assertIn(sim_auth_credentials['cram-md5'], str(err))
smtp.close()
+ def testAUTH_multiple(self):
+ # Test that multiple authentication methods are tried.
+ self.serv.add_feature("AUTH BOGUS PLAIN LOGIN CRAM-MD5")
+ smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
+ try: smtp.login(sim_auth[0], sim_auth[1])
+ except smtplib.SMTPAuthenticationError as err:
+ self.assertIn(sim_auth_login_password, str(err))
+ smtp.close()
+
def test_with_statement(self):
with smtplib.SMTP(HOST, self.port) as smtp:
code, message = smtp.noop()
diff --git a/Lib/test/test_sndhdr.py b/Lib/test/test_sndhdr.py
index 10046887d7..5e0abe0b36 100644
--- a/Lib/test/test_sndhdr.py
+++ b/Lib/test/test_sndhdr.py
@@ -12,7 +12,7 @@ class TestFormats(unittest.TestCase):
('sndhdr.hcom', ('hcom', 22050.0, 1, -1, 8)),
('sndhdr.sndt', ('sndt', 44100, 1, 5, 8)),
('sndhdr.voc', ('voc', 0, 1, -1, 8)),
- ('sndhdr.wav', ('wav', 44100, 2, -1, 16)),
+ ('sndhdr.wav', ('wav', 44100, 2, 5, 16)),
):
filename = findfile(filename, subdir="sndhdrdata")
what = sndhdr.what(filename)
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 12517ae95d..3c9fbf0603 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -2,7 +2,6 @@
import unittest
from test import support
-from unittest.case import _ExpectedFailure
import errno
import io
@@ -24,13 +23,13 @@ import math
import pickle
import struct
try:
- import fcntl
-except ImportError:
- fcntl = False
-try:
import multiprocessing
except ImportError:
multiprocessing = False
+try:
+ import fcntl
+except ImportError:
+ fcntl = None
HOST = support.HOST
MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return
@@ -46,7 +45,7 @@ def _have_socket_can():
"""Check whether CAN sockets are supported on this host."""
try:
s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
- except (AttributeError, socket.error, OSError):
+ except (AttributeError, OSError):
return False
else:
s.close()
@@ -121,12 +120,42 @@ class SocketCANTest(unittest.TestCase):
interface = 'vcan0'
bufsize = 128
+ """The CAN frame structure is defined in <linux/can.h>:
+
+ struct can_frame {
+ canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */
+ __u8 can_dlc; /* data length code: 0 .. 8 */
+ __u8 data[8] __attribute__((aligned(8)));
+ };
+ """
+ can_frame_fmt = "=IB3x8s"
+ can_frame_size = struct.calcsize(can_frame_fmt)
+
+ """The Broadcast Management Command frame structure is defined
+ in <linux/can/bcm.h>:
+
+ struct bcm_msg_head {
+ __u32 opcode;
+ __u32 flags;
+ __u32 count;
+ struct timeval ival1, ival2;
+ canid_t can_id;
+ __u32 nframes;
+ struct can_frame frames[0];
+ }
+
+ `bcm_msg_head` must be 8 bytes aligned because of the `frames` member (see
+ `struct can_frame` definition). Must use native not standard types for packing.
+ """
+ bcm_cmd_msg_fmt = "@3I4l2I"
+ bcm_cmd_msg_fmt += "x" * (struct.calcsize(bcm_cmd_msg_fmt) % 8)
+
def setUp(self):
self.s = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
self.addCleanup(self.s.close)
try:
self.s.bind((self.interface,))
- except socket.error:
+ except OSError:
self.skipTest('network interface `%s` does not exist' %
self.interface)
@@ -242,9 +271,6 @@ class ThreadableTest:
raise TypeError("test_func must be a callable function")
try:
test_func()
- except _ExpectedFailure:
- # We deliberately ignore expected failures
- pass
except BaseException as e:
self.queue.put(e)
finally:
@@ -295,7 +321,7 @@ class ThreadedCANSocketTest(SocketCANTest, ThreadableTest):
self.cli = socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW)
try:
self.cli.bind((self.interface,))
- except socket.error:
+ except OSError:
# skipTest should not be called here, and will be called in the
# server instead
pass
@@ -543,11 +569,7 @@ class SCTPStreamBase(InetTestBase):
class Inet6TestBase(InetTestBase):
"""Base class for IPv6 socket tests."""
- # Don't use "localhost" here - it may not have an IPv6 address
- # assigned to it by default (e.g. in /etc/hosts), and if someone
- # has assigned it an IPv4-mapped address, then it's unlikely to
- # work with the full IPv6 API.
- host = "::1"
+ host = support.HOSTv6
class UDP6TestBase(Inet6TestBase):
"""Base class for UDP-over-IPv6 tests."""
@@ -608,7 +630,7 @@ def requireSocket(*args):
for obj in args]
try:
s = socket.socket(*callargs)
- except socket.error as e:
+ except OSError as e:
# XXX: check errno?
err = str(e)
else:
@@ -626,8 +648,17 @@ class GeneralModuleTests(unittest.TestCase):
def test_repr(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.addCleanup(s.close)
- self.assertTrue(repr(s).startswith("<socket.socket object"))
+ with s:
+ self.assertIn('fd=%i' % s.fileno(), repr(s))
+ self.assertIn('family=%s' % socket.AF_INET, repr(s))
+ self.assertIn('type=%s' % socket.SOCK_STREAM, repr(s))
+ self.assertIn('proto=0', repr(s))
+ self.assertNotIn('raddr', repr(s))
+ s.bind(('127.0.0.1', 0))
+ self.assertIn('laddr', repr(s))
+ self.assertIn(str(s.getsockname()), repr(s))
+ self.assertIn('[closed]', repr(s))
+ self.assertNotIn('laddr', repr(s))
def test_weakref(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -645,11 +676,11 @@ class GeneralModuleTests(unittest.TestCase):
def testSocketError(self):
# Testing socket module exceptions
msg = "Error raising socket exception (%s)."
- with self.assertRaises(socket.error, msg=msg % 'socket.error'):
- raise socket.error
- with self.assertRaises(socket.error, msg=msg % 'socket.herror'):
+ with self.assertRaises(OSError, msg=msg % 'OSError'):
+ raise OSError
+ with self.assertRaises(OSError, msg=msg % 'socket.herror'):
raise socket.herror
- with self.assertRaises(socket.error, msg=msg % 'socket.gaierror'):
+ with self.assertRaises(OSError, msg=msg % 'socket.gaierror'):
raise socket.gaierror
def testSendtoErrors(self):
@@ -712,13 +743,13 @@ class GeneralModuleTests(unittest.TestCase):
hostname = socket.gethostname()
try:
ip = socket.gethostbyname(hostname)
- except socket.error:
+ except OSError:
# Probably name lookup wasn't set up right; skip this test
return
self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
try:
hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
- except socket.error:
+ except OSError:
# Probably a similar problem as above; skip this test
return
all_host_names = [hostname, hname] + aliases
@@ -726,13 +757,27 @@ class GeneralModuleTests(unittest.TestCase):
if not fqhn in all_host_names:
self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
+ def test_host_resolution(self):
+ for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2',
+ '1:1:1:1:1:1:1:1:1']:
+ self.assertRaises(OSError, socket.gethostbyname, addr)
+ self.assertRaises(OSError, socket.gethostbyaddr, addr)
+
+ for addr in [support.HOST, '10.0.0.1', '255.255.255.255']:
+ self.assertEqual(socket.gethostbyname(addr), addr)
+
+ # we don't test support.HOSTv6 because there's a chance it doesn't have
+ # a matching name entry (e.g. 'ip6-localhost')
+ for host in [support.HOST]:
+ self.assertIn(host, socket.gethostbyaddr(host)[2])
+
@unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()")
@unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()")
def test_sethostname(self):
oldhn = socket.gethostname()
try:
socket.sethostname('new')
- except socket.error as e:
+ except OSError as e:
if e.errno == errno.EPERM:
self.skipTest("test should be run as root")
else:
@@ -766,8 +811,8 @@ class GeneralModuleTests(unittest.TestCase):
'socket.if_nameindex() not available.')
def testInvalidInterfaceNameIndex(self):
# test nonexistent interface index/name
- self.assertRaises(socket.error, socket.if_indextoname, 0)
- self.assertRaises(socket.error, socket.if_nametoindex, '_DEADBEEF')
+ self.assertRaises(OSError, socket.if_indextoname, 0)
+ self.assertRaises(OSError, socket.if_nametoindex, '_DEADBEEF')
# test with invalid values
self.assertRaises(TypeError, socket.if_nametoindex, 0)
self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF')
@@ -789,7 +834,7 @@ class GeneralModuleTests(unittest.TestCase):
try:
# On some versions, this crashes the interpreter.
socket.getnameinfo(('x', 0, 0, 0), 0)
- except socket.error:
+ except OSError:
pass
def testNtoH(self):
@@ -836,17 +881,17 @@ class GeneralModuleTests(unittest.TestCase):
try:
port = socket.getservbyname(service, 'tcp')
break
- except socket.error:
+ except OSError:
pass
else:
- raise socket.error
+ raise OSError
# Try same call with optional protocol omitted
port2 = socket.getservbyname(service)
eq(port, port2)
# Try udp, but don't barf if it doesn't exist
try:
udpport = socket.getservbyname(service, 'udp')
- except socket.error:
+ except OSError:
udpport = None
else:
eq(udpport, port)
@@ -902,7 +947,7 @@ class GeneralModuleTests(unittest.TestCase):
g = lambda a: inet_pton(AF_INET, a)
assertInvalid = lambda func,a: self.assertRaises(
- (socket.error, ValueError), func, a
+ (OSError, ValueError), func, a
)
self.assertEqual(b'\x00\x00\x00\x00', f('0.0.0.0'))
@@ -935,9 +980,17 @@ class GeneralModuleTests(unittest.TestCase):
return
except ImportError:
return
+
+ if sys.platform == "win32":
+ try:
+ inet_pton(AF_INET6, '::')
+ except OSError as e:
+ if e.winerror == 10022:
+ return # IPv6 might not be installed on this PC
+
f = lambda a: inet_pton(AF_INET6, a)
assertInvalid = lambda a: self.assertRaises(
- (socket.error, ValueError), f, a
+ (OSError, ValueError), f, a
)
self.assertEqual(b'\x00' * 16, f('::'))
@@ -986,7 +1039,7 @@ class GeneralModuleTests(unittest.TestCase):
from socket import inet_ntoa as f, inet_ntop, AF_INET
g = lambda a: inet_ntop(AF_INET, a)
assertInvalid = lambda func,a: self.assertRaises(
- (socket.error, ValueError), func, a
+ (OSError, ValueError), func, a
)
self.assertEqual('1.0.1.0', f(b'\x01\x00\x01\x00'))
@@ -1013,9 +1066,17 @@ class GeneralModuleTests(unittest.TestCase):
return
except ImportError:
return
+
+ if sys.platform == "win32":
+ try:
+ inet_ntop(AF_INET6, b'\x00' * 16)
+ except OSError as e:
+ if e.winerror == 10022:
+ return # IPv6 might not be installed on this PC
+
f = lambda a: inet_ntop(AF_INET6, a)
assertInvalid = lambda a: self.assertRaises(
- (socket.error, ValueError), f, a
+ (OSError, ValueError), f, a
)
self.assertEqual('::', f(b'\x00' * 16))
@@ -1043,7 +1104,7 @@ class GeneralModuleTests(unittest.TestCase):
# At least for eCos. This is required for the S/390 to pass.
try:
my_ip_addr = socket.gethostbyname(socket.gethostname())
- except socket.error:
+ except OSError:
# Probably name lookup wasn't set up right; skip this test
return
self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
@@ -1070,13 +1131,19 @@ class GeneralModuleTests(unittest.TestCase):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
sock.close()
- self.assertRaises(socket.error, sock.send, b"spam")
+ self.assertRaises(OSError, sock.send, b"spam")
def testNewAttributes(self):
# testing .family, .type and .protocol
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.assertEqual(sock.family, socket.AF_INET)
- self.assertEqual(sock.type, socket.SOCK_STREAM)
+ if hasattr(socket, 'SOCK_CLOEXEC'):
+ self.assertIn(sock.type,
+ (socket.SOCK_STREAM | socket.SOCK_CLOEXEC,
+ socket.SOCK_STREAM))
+ else:
+ self.assertEqual(sock.type, socket.SOCK_STREAM)
self.assertEqual(sock.proto, 0)
sock.close()
@@ -1129,9 +1196,12 @@ class GeneralModuleTests(unittest.TestCase):
socket.getaddrinfo(HOST, 80)
socket.getaddrinfo(HOST, None)
# test family and socktype filters
- infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
- for family, _, _, _, _ in infos:
+ infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM)
+ for family, type, _, _, _ in infos:
self.assertEqual(family, socket.AF_INET)
+ self.assertEqual(str(family), 'AddressFamily.AF_INET')
+ self.assertEqual(type, socket.SOCK_STREAM)
+ self.assertEqual(str(type), 'SocketType.SOCK_STREAM')
infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
for _, socktype, _, _, _ in infos:
self.assertEqual(socktype, socket.SOCK_STREAM)
@@ -1173,7 +1243,7 @@ class GeneralModuleTests(unittest.TestCase):
def test_getnameinfo(self):
# only IP addresses are allowed
- self.assertRaises(socket.error, socket.getnameinfo, ('mail.python.org',0), 0)
+ self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0)
@unittest.skipUnless(support.is_resource_enabled('network'),
'network is not enabled')
@@ -1285,10 +1355,31 @@ class GeneralModuleTests(unittest.TestCase):
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
def test_flowinfo(self):
self.assertRaises(OverflowError, socket.getnameinfo,
- ('::1',0, 0xffffffff), 0)
+ (support.HOSTv6, 0, 0xffffffff), 0)
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
- self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
-
+ self.assertRaises(OverflowError, s.bind, (support.HOSTv6, 0, -10))
+
+ def test_str_for_enums(self):
+ # Make sure that the AF_* and SOCK_* constants have enum-like string
+ # reprs.
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ self.assertEqual(str(s.family), 'AddressFamily.AF_INET')
+ self.assertEqual(str(s.type), 'SocketType.SOCK_STREAM')
+
+ @unittest.skipIf(os.name == 'nt', 'Will not work on Windows')
+ def test_uknown_socket_family_repr(self):
+ # Test that when created with a family that's not one of the known
+ # AF_*/SOCK_* constants, socket.family just returns the number.
+ #
+ # To do this we fool socket.socket into believing it already has an
+ # open fd because on this path it doesn't actually verify the family and
+ # type and populates the socket object.
+ #
+ # On Windows this trick won't work, so the test is skipped.
+ fd, _ = tempfile.mkstemp()
+ with socket.socket(family=42424, type=13331, fileno=fd) as s:
+ self.assertEqual(s.family, 42424)
+ self.assertEqual(s.type, 13331)
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
class BasicCANTest(unittest.TestCase):
@@ -1298,10 +1389,35 @@ class BasicCANTest(unittest.TestCase):
socket.PF_CAN
socket.CAN_RAW
+ @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
+ 'socket.CAN_BCM required for this test.')
+ def testBCMConstants(self):
+ socket.CAN_BCM
+
+ # opcodes
+ socket.CAN_BCM_TX_SETUP # create (cyclic) transmission task
+ socket.CAN_BCM_TX_DELETE # remove (cyclic) transmission task
+ socket.CAN_BCM_TX_READ # read properties of (cyclic) transmission task
+ socket.CAN_BCM_TX_SEND # send one CAN frame
+ socket.CAN_BCM_RX_SETUP # create RX content filter subscription
+ socket.CAN_BCM_RX_DELETE # remove RX content filter subscription
+ socket.CAN_BCM_RX_READ # read properties of RX content filter subscription
+ socket.CAN_BCM_TX_STATUS # reply to TX_READ request
+ socket.CAN_BCM_TX_EXPIRED # notification on performed transmissions (count=0)
+ socket.CAN_BCM_RX_STATUS # reply to RX_READ request
+ socket.CAN_BCM_RX_TIMEOUT # cyclic message is absent
+ socket.CAN_BCM_RX_CHANGED # updated CAN frame (detected content change)
+
def testCreateSocket(self):
with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
pass
+ @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
+ 'socket.CAN_BCM required for this test.')
+ def testCreateBCMSocket(self):
+ with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM) as s:
+ pass
+
def testBindAny(self):
with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
s.bind(('', ))
@@ -1309,7 +1425,7 @@ class BasicCANTest(unittest.TestCase):
def testTooLongInterfaceName(self):
# most systems limit IFNAMSIZ to 16, take 1024 to be sure
with socket.socket(socket.PF_CAN, socket.SOCK_RAW, socket.CAN_RAW) as s:
- self.assertRaisesRegex(socket.error, 'interface name too long',
+ self.assertRaisesRegex(OSError, 'interface name too long',
s.bind, ('x' * 1024,))
@unittest.skipUnless(hasattr(socket, "CAN_RAW_LOOPBACK"),
@@ -1334,19 +1450,8 @@ class BasicCANTest(unittest.TestCase):
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
-@unittest.skipUnless(thread, 'Threading required for this test.')
class CANTest(ThreadedCANSocketTest):
- """The CAN frame structure is defined in <linux/can.h>:
-
- struct can_frame {
- canid_t can_id; /* 32 bit CAN_ID + EFF/RTR/ERR flags */
- __u8 can_dlc; /* data length code: 0 .. 8 */
- __u8 data[8] __attribute__((aligned(8)));
- };
- """
- can_frame_fmt = "=IB3x8s"
-
def __init__(self, methodName='runTest'):
ThreadedCANSocketTest.__init__(self, methodName=methodName)
@@ -1395,6 +1500,46 @@ class CANTest(ThreadedCANSocketTest):
self.cf2 = self.build_can_frame(0x12, b'\x99\x22\x33')
self.cli.send(self.cf2)
+ @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
+ 'socket.CAN_BCM required for this test.')
+ def _testBCM(self):
+ cf, addr = self.cli.recvfrom(self.bufsize)
+ self.assertEqual(self.cf, cf)
+ can_id, can_dlc, data = self.dissect_can_frame(cf)
+ self.assertEqual(self.can_id, can_id)
+ self.assertEqual(self.data, data)
+
+ @unittest.skipUnless(hasattr(socket, "CAN_BCM"),
+ 'socket.CAN_BCM required for this test.')
+ def testBCM(self):
+ bcm = socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_BCM)
+ self.addCleanup(bcm.close)
+ bcm.connect((self.interface,))
+ self.can_id = 0x123
+ self.data = bytes([0xc0, 0xff, 0xee])
+ self.cf = self.build_can_frame(self.can_id, self.data)
+ opcode = socket.CAN_BCM_TX_SEND
+ flags = 0
+ count = 0
+ ival1_seconds = ival1_usec = ival2_seconds = ival2_usec = 0
+ bcm_can_id = 0x0222
+ nframes = 1
+ assert len(self.cf) == 16
+ header = struct.pack(self.bcm_cmd_msg_fmt,
+ opcode,
+ flags,
+ count,
+ ival1_seconds,
+ ival1_usec,
+ ival2_seconds,
+ ival2_usec,
+ bcm_can_id,
+ nframes,
+ )
+ header_plus_frame = header + self.cf
+ bytes_sent = bcm.send(header_plus_frame)
+ self.assertEqual(bytes_sent, len(header_plus_frame))
+
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
class BasicRDSTest(unittest.TestCase):
@@ -1609,7 +1754,7 @@ class BasicTCPTest(SocketConnectedTest):
self.assertEqual(f, fileno)
# cli_conn cannot be used anymore...
self.assertTrue(self.cli_conn._closed)
- self.assertRaises(socket.error, self.cli_conn.recv, 1024)
+ self.assertRaises(OSError, self.cli_conn.recv, 1024)
self.cli_conn.close()
# ...but we can create another socket using the (still open)
# file descriptor
@@ -1978,7 +2123,7 @@ class SendmsgTests(SendrecvmsgServerTimeoutBase):
def _testSendmsgExcessCmsgReject(self):
if not hasattr(socket, "CMSG_SPACE"):
# Can only send one item
- with self.assertRaises(socket.error) as cm:
+ with self.assertRaises(OSError) as cm:
self.sendmsgToServer([MSG], [(0, 0, b""), (0, 0, b"")])
self.assertIsNone(cm.exception.errno)
self.sendToServer(b"done")
@@ -1989,7 +2134,7 @@ class SendmsgTests(SendrecvmsgServerTimeoutBase):
def _testSendmsgAfterClose(self):
self.cli_sock.close()
- self.assertRaises(socket.error, self.sendmsgToServer, [MSG])
+ self.assertRaises(OSError, self.sendmsgToServer, [MSG])
class SendmsgStreamTests(SendmsgTests):
@@ -2033,7 +2178,7 @@ class SendmsgStreamTests(SendmsgTests):
@testSendmsgDontWait.client_skip
def _testSendmsgDontWait(self):
try:
- with self.assertRaises(socket.error) as cm:
+ with self.assertRaises(OSError) as cm:
while True:
self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
self.assertIn(cm.exception.errno,
@@ -2053,9 +2198,9 @@ class SendmsgConnectionlessTests(SendmsgTests):
pass
def _testSendmsgNoDestAddr(self):
- self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ self.assertRaises(OSError, self.cli_sock.sendmsg,
[MSG])
- self.assertRaises(socket.error, self.cli_sock.sendmsg,
+ self.assertRaises(OSError, self.cli_sock.sendmsg,
[MSG], [], 0, None)
@@ -2141,7 +2286,7 @@ class RecvmsgGenericTests(SendrecvmsgBase):
def testRecvmsgAfterClose(self):
# Check that recvmsg[_into]() fails on a closed socket.
self.serv_sock.close()
- self.assertRaises(socket.error, self.doRecvmsg, self.serv_sock, 1024)
+ self.assertRaises(OSError, self.doRecvmsg, self.serv_sock, 1024)
def _testRecvmsgAfterClose(self):
pass
@@ -2587,7 +2732,7 @@ class SCMRightsTest(SendrecvmsgServerTimeoutBase):
# call fails, just send msg with no ancillary data.
try:
nbytes = self.sendmsgToServer([msg], ancdata)
- except socket.error as e:
+ except OSError as e:
# Check that it was the system call that failed
self.assertIsInstance(e.errno, int)
nbytes = self.sendmsgToServer([msg])
@@ -2965,7 +3110,7 @@ class RFC3542AncillaryTest(SendrecvmsgServerTimeoutBase):
array.array("i", [self.traffic_class]).tobytes() + b"\x00"),
(socket.IPPROTO_IPV6, socket.IPV6_HOPLIMIT,
array.array("i", [self.hop_limit]))])
- except socket.error as e:
+ except OSError as e:
self.assertIsInstance(e.errno, int)
nbytes = self.sendmsgToServer(
[MSG],
@@ -3423,10 +3568,10 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
self.serv.settimeout(self.timeout)
def checkInterruptedRecv(self, func, *args, **kwargs):
- # Check that func(*args, **kwargs) raises socket.error with an
+ # Check that func(*args, **kwargs) raises OSError with an
# errno of EINTR when interrupted by a signal.
self.setAlarm(self.alarm_time)
- with self.assertRaises(socket.error) as cm:
+ with self.assertRaises(OSError) as cm:
func(*args, **kwargs)
self.assertNotIsInstance(cm.exception, socket.timeout)
self.assertEqual(cm.exception.errno, errno.EINTR)
@@ -3483,9 +3628,9 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
def checkInterruptedSend(self, func, *args, **kwargs):
# Check that func(*args, **kwargs), run in a loop, raises
- # socket.error with an errno of EINTR when interrupted by a
+ # OSError with an errno of EINTR when interrupted by a
# signal.
- with self.assertRaises(socket.error) as cm:
+ with self.assertRaises(OSError) as cm:
while True:
self.setAlarm(self.alarm_time)
func(*args, **kwargs)
@@ -3584,7 +3729,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
start = time.time()
try:
self.serv.accept()
- except socket.error:
+ except OSError:
pass
end = time.time()
self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.")
@@ -3610,7 +3755,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
start = time.time()
try:
self.serv.accept()
- except socket.error:
+ except OSError:
pass
end = time.time()
self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.")
@@ -3640,7 +3785,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
self.serv.setblocking(0)
try:
conn, addr = self.serv.accept()
- except socket.error:
+ except OSError:
pass
else:
self.fail("Error trying to do non-blocking accept.")
@@ -3670,7 +3815,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
conn.setblocking(0)
try:
msg = conn.recv(len(MSG))
- except socket.error:
+ except OSError:
pass
else:
self.fail("Error trying to do non-blocking recv.")
@@ -3753,7 +3898,7 @@ class FileObjectClassTestCase(SocketConnectedTest):
# First read raises a timeout
self.assertRaises(socket.timeout, self.read_file.read, 1)
# Second read is disallowed
- with self.assertRaises(IOError) as ctx:
+ with self.assertRaises(OSError) as ctx:
self.read_file.read(1)
self.assertIn("cannot read from timed out object", str(ctx.exception))
@@ -3845,7 +3990,7 @@ class FileObjectClassTestCase(SocketConnectedTest):
self.read_file.close()
self.assertRaises(ValueError, self.read_file.fileno)
self.cli_conn.close()
- self.assertRaises(socket.error, self.cli_conn.getsockname)
+ self.assertRaises(OSError, self.cli_conn.getsockname)
def _testRealClose(self):
pass
@@ -3882,7 +4027,7 @@ class FileObjectInterruptedTestCase(unittest.TestCase):
@staticmethod
def _raise_eintr():
- raise socket.error(errno.EINTR, "interrupted")
+ raise OSError(errno.EINTR, "interrupted")
def _textiowrap_mock_socket(self, mock, buffering=-1):
raw = socket.SocketIO(mock, "r")
@@ -3994,7 +4139,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
self.assertEqual(msg, self.read_msg)
# ...until the file is itself closed
self.read_file.close()
- self.assertRaises(socket.error, self.cli_conn.recv, 1024)
+ self.assertRaises(OSError, self.cli_conn.recv, 1024)
def _testMakefileClose(self):
self.write_file.write(self.write_msg)
@@ -4143,7 +4288,7 @@ class NetworkConnectionNoServer(unittest.TestCase):
port = support.find_unused_port()
cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.addCleanup(cli.close)
- with self.assertRaises(socket.error) as cm:
+ with self.assertRaises(OSError) as cm:
cli.connect((HOST, port))
self.assertEqual(cm.exception.errno, errno.ECONNREFUSED)
@@ -4151,7 +4296,7 @@ class NetworkConnectionNoServer(unittest.TestCase):
# Issue #9792: errors raised by create_connection() should have
# a proper errno attribute.
port = support.find_unused_port()
- with self.assertRaises(socket.error) as cm:
+ with self.assertRaises(OSError) as cm:
socket.create_connection((HOST, port))
# Issue #16257: create_connection() calls getaddrinfo() against
@@ -4299,7 +4444,7 @@ class TCPTimeoutTest(SocketTCPTest):
foo = self.serv.accept()
except socket.timeout:
self.fail("caught timeout instead of error (TCP)")
- except socket.error:
+ except OSError:
ok = True
except:
self.fail("caught unexpected exception (TCP)")
@@ -4356,7 +4501,7 @@ class UDPTimeoutTest(SocketUDPTest):
foo = self.serv.recv(1024)
except socket.timeout:
self.fail("caught timeout instead of error (UDP)")
- except socket.error:
+ except OSError:
ok = True
except:
self.fail("caught unexpected exception (UDP)")
@@ -4366,10 +4511,10 @@ class UDPTimeoutTest(SocketUDPTest):
class TestExceptions(unittest.TestCase):
def testExceptionTree(self):
- self.assertTrue(issubclass(socket.error, Exception))
- self.assertTrue(issubclass(socket.herror, socket.error))
- self.assertTrue(issubclass(socket.gaierror, socket.error))
- self.assertTrue(issubclass(socket.timeout, socket.error))
+ self.assertTrue(issubclass(OSError, Exception))
+ self.assertTrue(issubclass(socket.herror, OSError))
+ self.assertTrue(issubclass(socket.gaierror, OSError))
+ self.assertTrue(issubclass(socket.timeout, OSError))
@unittest.skipUnless(sys.platform == 'linux', 'Linux specific test')
class TestLinuxAbstractNamespace(unittest.TestCase):
@@ -4396,7 +4541,7 @@ class TestLinuxAbstractNamespace(unittest.TestCase):
def testNameOverflow(self):
address = "\x00" + "h" * self.UNIX_PATH_MAX
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
- self.assertRaises(socket.error, s.bind, address)
+ self.assertRaises(OSError, s.bind, address)
def testStrName(self):
# Check that an abstract name can be passed as a string.
@@ -4638,7 +4783,7 @@ class ContextManagersTest(ThreadedTCPSocketTest):
self.assertTrue(sock._closed)
# exception inside with block
with socket.socket() as sock:
- self.assertRaises(socket.error, sock.sendall, b'foo')
+ self.assertRaises(OSError, sock.sendall, b'foo')
self.assertTrue(sock._closed)
def testCreateConnectionBase(self):
@@ -4666,19 +4811,76 @@ class ContextManagersTest(ThreadedTCPSocketTest):
with socket.create_connection(address) as sock:
sock.close()
self.assertTrue(sock._closed)
- self.assertRaises(socket.error, sock.sendall, b'foo')
+ self.assertRaises(OSError, sock.sendall, b'foo')
-@unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"),
- "SOCK_CLOEXEC not defined")
-@unittest.skipUnless(fcntl, "module fcntl not available")
-class CloexecConstantTest(unittest.TestCase):
+class InheritanceTest(unittest.TestCase):
+ @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"),
+ "SOCK_CLOEXEC not defined")
@support.requires_linux_version(2, 6, 28)
def test_SOCK_CLOEXEC(self):
with socket.socket(socket.AF_INET,
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s:
self.assertTrue(s.type & socket.SOCK_CLOEXEC)
- self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC)
+ self.assertFalse(s.get_inheritable())
+
+ def test_default_inheritable(self):
+ sock = socket.socket()
+ with sock:
+ self.assertEqual(sock.get_inheritable(), False)
+
+ def test_dup(self):
+ sock = socket.socket()
+ with sock:
+ newsock = sock.dup()
+ sock.close()
+ with newsock:
+ self.assertEqual(newsock.get_inheritable(), False)
+
+ def test_set_inheritable(self):
+ sock = socket.socket()
+ with sock:
+ sock.set_inheritable(True)
+ self.assertEqual(sock.get_inheritable(), True)
+
+ sock.set_inheritable(False)
+ self.assertEqual(sock.get_inheritable(), False)
+
+ @unittest.skipIf(fcntl is None, "need fcntl")
+ def test_get_inheritable_cloexec(self):
+ sock = socket.socket()
+ with sock:
+ fd = sock.fileno()
+ self.assertEqual(sock.get_inheritable(), False)
+
+ # clear FD_CLOEXEC flag
+ flags = fcntl.fcntl(fd, fcntl.F_GETFD)
+ flags &= ~fcntl.FD_CLOEXEC
+ fcntl.fcntl(fd, fcntl.F_SETFD, flags)
+
+ self.assertEqual(sock.get_inheritable(), True)
+
+ @unittest.skipIf(fcntl is None, "need fcntl")
+ def test_set_inheritable_cloexec(self):
+ sock = socket.socket()
+ with sock:
+ fd = sock.fileno()
+ self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC,
+ fcntl.FD_CLOEXEC)
+
+ sock.set_inheritable(True)
+ self.assertEqual(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC,
+ 0)
+
+
+ @unittest.skipUnless(hasattr(socket, "socketpair"),
+ "need socket.socketpair()")
+ def test_socketpair(self):
+ s1, s2 = socket.socketpair()
+ self.addCleanup(s1.close)
+ self.addCleanup(s2.close)
+ self.assertEqual(s1.get_inheritable(), False)
+ self.assertEqual(s2.get_inheritable(), False)
@unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"),
@@ -4847,7 +5049,7 @@ def test_main():
NetworkConnectionAttributesTest,
NetworkConnectionBehaviourTest,
ContextManagersTest,
- CloexecConstantTest,
+ InheritanceTest,
NonblockConstantTest
])
tests.append(BasicSocketPairTest)
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 59d8e5dcb7..0617b30a69 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -2,8 +2,8 @@
Test suite for socketserver.
"""
+import _imp as imp
import contextlib
-import imp
import os
import select
import signal
@@ -29,7 +29,7 @@ HOST = test.support.HOST
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
'requires Unix sockets')
-HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
+HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
def signal_alarm(n):
@@ -85,7 +85,7 @@ class SocketServerTest(unittest.TestCase):
for fn in self.test_files:
try:
os.remove(fn)
- except os.error:
+ except OSError:
pass
self.test_files[:] = []
@@ -96,21 +96,7 @@ class SocketServerTest(unittest.TestCase):
# XXX: We need a way to tell AF_UNIX to pick its own name
# like AF_INET provides port==0.
dir = None
- if os.name == 'os2':
- dir = '\socket'
fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
- if os.name == 'os2':
- # AF_UNIX socket names on OS/2 require a specific prefix
- # which can't include a drive letter and must also use
- # backslashes as directory separators
- if fn[1] == ':':
- fn = fn[2:]
- if fn[0] in (os.sep, os.altsep):
- fn = fn[1:]
- if os.sep == '/':
- fn = fn.replace(os.sep, os.altsep)
- else:
- fn = fn.replace(os.altsep, os.sep)
self.test_files.append(fn)
return fn
diff --git a/Lib/test/test_pep263.py b/Lib/test/test_source_encoding.py
index 324ae38619..cd9d2b374c 100644
--- a/Lib/test/test_pep263.py
+++ b/Lib/test/test_source_encoding.py
@@ -1,9 +1,12 @@
# -*- coding: koi8-r -*-
import unittest
-from test import support
+from test.support import TESTFN, unlink, unload
+import importlib
+import os
+import sys
-class PEP263Test(unittest.TestCase):
+class SourceEncodingTest(unittest.TestCase):
def test_pep263(self):
self.assertEqual(
@@ -72,9 +75,61 @@ class PEP263Test(unittest.TestCase):
with self.assertRaisesRegex(SyntaxError, 'BOM'):
compile(b'\xef\xbb\xbf# -*- coding: fake -*-\n', 'dummy', 'exec')
+ def test_bad_coding(self):
+ module_name = 'bad_coding'
+ self.verify_bad_module(module_name)
-def test_main():
- support.run_unittest(PEP263Test)
+ def test_bad_coding2(self):
+ module_name = 'bad_coding2'
+ self.verify_bad_module(module_name)
-if __name__=="__main__":
- test_main()
+ def verify_bad_module(self, module_name):
+ self.assertRaises(SyntaxError, __import__, 'test.' + module_name)
+
+ path = os.path.dirname(__file__)
+ filename = os.path.join(path, module_name + '.py')
+ with open(filename, "rb") as fp:
+ bytes = fp.read()
+ self.assertRaises(SyntaxError, compile, bytes, filename, 'exec')
+
+ def test_exec_valid_coding(self):
+ d = {}
+ exec(b'# coding: cp949\na = "\xaa\xa7"\n', d)
+ self.assertEqual(d['a'], '\u3047')
+
+ def test_file_parse(self):
+ # issue1134: all encodings outside latin-1 and utf-8 fail on
+ # multiline strings and long lines (>512 columns)
+ unload(TESTFN)
+ filename = TESTFN + ".py"
+ f = open(filename, "w", encoding="cp1252")
+ sys.path.insert(0, os.curdir)
+ try:
+ with f:
+ f.write("# -*- coding: cp1252 -*-\n")
+ f.write("'''A short string\n")
+ f.write("'''\n")
+ f.write("'A very long string %s'\n" % ("X" * 1000))
+
+ importlib.invalidate_caches()
+ __import__(TESTFN)
+ finally:
+ del sys.path[0]
+ unlink(filename)
+ unlink(filename + "c")
+ unlink(filename + "o")
+ unload(TESTFN)
+
+ def test_error_from_string(self):
+ # See http://bugs.python.org/issue6289
+ input = "# coding: ascii\n\N{SNOWMAN}".encode('utf-8')
+ with self.assertRaises(SyntaxError) as c:
+ compile(input, "<string>", "exec")
+ expected = "'ascii' codec can't decode byte 0xe2 in position 16: " \
+ "ordinal not in range(128)"
+ self.assertTrue(c.exception.args[0].startswith(expected),
+ msg=c.exception.args[0])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 06d4598a87..b1cb8c5424 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -6,6 +6,7 @@ from test import support
import socket
import select
import time
+import datetime
import gc
import os
import errno
@@ -17,16 +18,11 @@ import asyncore
import weakref
import platform
import functools
+from unittest import mock
ssl = support.import_module("ssl")
-PROTOCOLS = [
- ssl.PROTOCOL_SSLv3,
- ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1
-]
-if hasattr(ssl, 'PROTOCOL_SSLv2'):
- PROTOCOLS.append(ssl.PROTOCOL_SSLv2)
-
+PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
data_file = lambda name: os.path.join(os.path.dirname(__file__), name)
@@ -48,6 +44,11 @@ KEY_PASSWORD = "somepass"
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
+# Two keys and certs signed by the same CA (for SNI tests)
+SIGNED_CERTFILE = data_file("keycert3.pem")
+SIGNED_CERTFILE2 = data_file("keycert4.pem")
+SIGNING_CA = data_file("pycacert.pem")
+
SVN_PYTHON_ORG_ROOT_CERT = data_file("https_svn_python_org_root.pem")
EMPTYCERT = data_file("nullcert.pem")
@@ -60,6 +61,7 @@ NULLBYTECERT = data_file("nullbytecert.pem")
DHFILE = data_file("dh512.pem")
BYTES_DHFILE = os.fsencode(DHFILE)
+
def handle_error(prefix):
exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
if support.verbose:
@@ -73,6 +75,19 @@ def no_sslv2_implies_sslv3_hello():
# 0.9.7h or higher
return ssl.OPENSSL_VERSION_INFO >= (0, 9, 7, 8, 15)
+def asn1time(cert_time):
+ # Some versions of OpenSSL ignore seconds, see #18207
+ # 0.9.8.i
+ if ssl._OPENSSL_API_VERSION == (0, 9, 8, 9, 15):
+ fmt = "%b %d %H:%M:%S %Y GMT"
+ dt = datetime.datetime.strptime(cert_time, fmt)
+ dt = dt.replace(second=0)
+ cert_time = dt.strftime(fmt)
+ # %d adds leading zero but ASN1_TIME_print() uses leading space
+ if cert_time[4] == "0":
+ cert_time = cert_time[:4] + " " + cert_time[5:]
+
+ return cert_time
# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2
def skip_if_broken_ubuntu_ssl(func):
@@ -90,14 +105,12 @@ def skip_if_broken_ubuntu_ssl(func):
else:
return func
+needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
+
class BasicSocketTests(unittest.TestCase):
def test_constants(self):
- #ssl.PROTOCOL_SSLv2
- ssl.PROTOCOL_SSLv23
- ssl.PROTOCOL_SSLv3
- ssl.PROTOCOL_TLSv1
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED
@@ -175,8 +188,9 @@ class BasicSocketTests(unittest.TestCase):
(('organizationName', 'Python Software Foundation'),),
(('commonName', 'localhost'),))
)
- self.assertEqual(p['notAfter'], 'Oct 5 23:01:56 2020 GMT')
- self.assertEqual(p['notBefore'], 'Oct 8 23:01:56 2010 GMT')
+ # Note the next three asserts will fail if the keys are regenerated
+ self.assertEqual(p['notAfter'], asn1time('Oct 5 23:01:56 2020 GMT'))
+ self.assertEqual(p['notBefore'], asn1time('Oct 8 23:01:56 2010 GMT'))
self.assertEqual(p['serialNumber'], 'D7C7381919AFC24E')
self.assertEqual(p['subject'],
((('countryName', 'XY'),),
@@ -276,15 +290,15 @@ class BasicSocketTests(unittest.TestCase):
def test_wrapped_unconnected(self):
# Methods on an unconnected SSLSocket propagate the original
- # socket.error raise by the underlying socket object.
+ # OSError raise by the underlying socket object.
s = socket.socket(socket.AF_INET)
with ssl.wrap_socket(s) as ss:
- self.assertRaises(socket.error, ss.recv, 1)
- self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
- self.assertRaises(socket.error, ss.recvfrom, 1)
- self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
- self.assertRaises(socket.error, ss.send, b'x')
- self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
+ self.assertRaises(OSError, ss.recv, 1)
+ self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
+ self.assertRaises(OSError, ss.recvfrom, 1)
+ self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
+ self.assertRaises(OSError, ss.send, b'x')
+ self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the
@@ -309,15 +323,15 @@ class BasicSocketTests(unittest.TestCase):
with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
s.connect, (HOST, 8080))
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(OSError) as cm:
with socket.socket() as sock:
ssl.wrap_socket(sock, certfile=WRONGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(OSError) as cm:
with socket.socket() as sock:
ssl.wrap_socket(sock, certfile=CERTFILE, keyfile=WRONGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(OSError) as cm:
with socket.socket() as sock:
ssl.wrap_socket(sock, certfile=WRONGCERT, keyfile=WRONGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
@@ -489,15 +503,48 @@ class BasicSocketTests(unittest.TestCase):
support.gc_collect()
self.assertIn(r, str(cm.warning.args[0]))
+ def test_get_default_verify_paths(self):
+ paths = ssl.get_default_verify_paths()
+ self.assertEqual(len(paths), 6)
+ self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
+
+ with support.EnvironmentVarGuard() as env:
+ env["SSL_CERT_DIR"] = CAPATH
+ env["SSL_CERT_FILE"] = CERTFILE
+ paths = ssl.get_default_verify_paths()
+ self.assertEqual(paths.cafile, CERTFILE)
+ self.assertEqual(paths.capath, CAPATH)
+
+
+ @unittest.skipUnless(sys.platform == "win32", "Windows specific")
+ def test_enum_cert_store(self):
+ self.assertEqual(ssl.X509_ASN_ENCODING, 1)
+ self.assertEqual(ssl.PKCS_7_ASN_ENCODING, 0x00010000)
+
+ self.assertEqual(ssl.enum_cert_store("CA"),
+ ssl.enum_cert_store("CA", "certificate"))
+ ssl.enum_cert_store("CA", "crl")
+ self.assertEqual(ssl.enum_cert_store("ROOT"),
+ ssl.enum_cert_store("ROOT", "certificate"))
+ ssl.enum_cert_store("ROOT", "crl")
+
+ self.assertRaises(TypeError, ssl.enum_cert_store)
+ self.assertRaises(WindowsError, ssl.enum_cert_store, "")
+ self.assertRaises(ValueError, ssl.enum_cert_store, "CA", "wrong")
+
+ ca = ssl.enum_cert_store("CA")
+ self.assertIsInstance(ca, list)
+ self.assertIsInstance(ca[0], tuple)
+ self.assertEqual(len(ca[0]), 2)
+ self.assertIsInstance(ca[0][0], bytes)
+ self.assertIsInstance(ca[0][1], int)
+
class ContextTests(unittest.TestCase):
@skip_if_broken_ubuntu_ssl
def test_constructor(self):
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv2)
- ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv3)
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ for protocol in PROTOCOLS:
+ ssl.SSLContext(protocol)
self.assertRaises(TypeError, ssl.SSLContext)
self.assertRaises(ValueError, ssl.SSLContext, -1)
self.assertRaises(ValueError, ssl.SSLContext, 42)
@@ -557,7 +604,7 @@ class ContextTests(unittest.TestCase):
ctx.load_cert_chain(CERTFILE)
ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(OSError) as cm:
ctx.load_cert_chain(WRONGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
@@ -642,7 +689,7 @@ class ContextTests(unittest.TestCase):
ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
self.assertRaises(TypeError, ctx.load_verify_locations)
self.assertRaises(TypeError, ctx.load_verify_locations, None, None)
- with self.assertRaises(IOError) as cm:
+ with self.assertRaises(OSError) as cm:
ctx.load_verify_locations(WRONGCERT)
self.assertEqual(cm.exception.errno, errno.ENOENT)
with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
@@ -700,6 +747,75 @@ class ContextTests(unittest.TestCase):
self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
+ @needs_sni
+ def test_sni_callback(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+
+ # set_servername_callback expects a callable, or None
+ self.assertRaises(TypeError, ctx.set_servername_callback)
+ self.assertRaises(TypeError, ctx.set_servername_callback, 4)
+ self.assertRaises(TypeError, ctx.set_servername_callback, "")
+ self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
+
+ def dummycallback(sock, servername, ctx):
+ pass
+ ctx.set_servername_callback(None)
+ ctx.set_servername_callback(dummycallback)
+
+ @needs_sni
+ def test_sni_callback_refcycle(self):
+ # Reference cycles through the servername callback are detected
+ # and cleared.
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ def dummycallback(sock, servername, ctx, cycle=ctx):
+ pass
+ ctx.set_servername_callback(dummycallback)
+ wr = weakref.ref(ctx)
+ del ctx, dummycallback
+ gc.collect()
+ self.assertIs(wr(), None)
+
+ def test_cert_store_stats(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ self.assertEqual(ctx.cert_store_stats(),
+ {'x509_ca': 0, 'crl': 0, 'x509': 0})
+ ctx.load_cert_chain(CERTFILE)
+ self.assertEqual(ctx.cert_store_stats(),
+ {'x509_ca': 0, 'crl': 0, 'x509': 0})
+ ctx.load_verify_locations(CERTFILE)
+ self.assertEqual(ctx.cert_store_stats(),
+ {'x509_ca': 0, 'crl': 0, 'x509': 1})
+ ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT)
+ self.assertEqual(ctx.cert_store_stats(),
+ {'x509_ca': 1, 'crl': 0, 'x509': 2})
+
+ def test_get_ca_certs(self):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ self.assertEqual(ctx.get_ca_certs(), [])
+ # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
+ ctx.load_verify_locations(CERTFILE)
+ self.assertEqual(ctx.get_ca_certs(), [])
+ # but SVN_PYTHON_ORG_ROOT_CERT is a CA cert
+ ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT)
+ self.assertEqual(ctx.get_ca_certs(),
+ [{'issuer': ((('organizationName', 'Root CA'),),
+ (('organizationalUnitName', 'http://www.cacert.org'),),
+ (('commonName', 'CA Cert Signing Authority'),),
+ (('emailAddress', 'support@cacert.org'),)),
+ 'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
+ 'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
+ 'serialNumber': '00',
+ 'subject': ((('organizationName', 'Root CA'),),
+ (('organizationalUnitName', 'http://www.cacert.org'),),
+ (('commonName', 'CA Cert Signing Authority'),),
+ (('emailAddress', 'support@cacert.org'),)),
+ 'version': 3}])
+
+ with open(SVN_PYTHON_ORG_ROOT_CERT) as f:
+ pem = f.read()
+ der = ssl.PEM_cert_to_DER_cert(pem)
+ self.assertEqual(ctx.get_ca_certs(True), [der])
+
class SSLErrorTests(unittest.TestCase):
@@ -1015,6 +1131,22 @@ class NetworkedTests(unittest.TestCase):
finally:
s.close()
+ def test_get_ca_certs_capath(self):
+ # capath certs are loaded on request
+ with support.transient_internet("svn.python.org"):
+ ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ ctx.load_verify_locations(capath=CAPATH)
+ self.assertEqual(ctx.get_ca_certs(), [])
+ s = ctx.wrap_socket(socket.socket(socket.AF_INET))
+ s.connect(("svn.python.org", 443))
+ try:
+ cert = s.getpeercert()
+ self.assertTrue(cert)
+ finally:
+ s.close()
+ self.assertEqual(len(ctx.get_ca_certs()), 1)
+
try:
import threading
@@ -1142,7 +1274,7 @@ else:
sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
% (msg, ctype, msg.lower(), ctype))
self.write(msg.lower())
- except socket.error:
+ except OSError:
if self.server.chatty:
handle_error("Test server failure:\n")
self.close()
@@ -1252,7 +1384,7 @@ else:
return self.handle_close()
except ssl.SSLError:
raise
- except socket.error as err:
+ except OSError as err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
@@ -1356,19 +1488,19 @@ else:
except ssl.SSLError as x:
if support.verbose:
sys.stdout.write("\nSSLError is %s\n" % x.args[1])
- except socket.error as x:
+ except OSError as x:
if support.verbose:
- sys.stdout.write("\nsocket.error is %s\n" % x.args[1])
- except IOError as x:
+ sys.stdout.write("\nOSError is %s\n" % x.args[1])
+ except OSError as x:
if x.errno != errno.ENOENT:
raise
if support.verbose:
- sys.stdout.write("\IOError is %s\n" % str(x))
+ sys.stdout.write("\OSError is %s\n" % str(x))
else:
raise AssertionError("Use of invalid cert should have failed!")
def server_params_test(client_context, server_context, indata=b"FOO\n",
- chatty=True, connectionchatty=False):
+ chatty=True, connectionchatty=False, sni_name=None):
"""
Launch a server, connect a client to it and try various reads
and writes.
@@ -1378,7 +1510,8 @@ else:
chatty=chatty,
connectionchatty=False)
with server:
- with client_context.wrap_socket(socket.socket()) as s:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=sni_name) as s:
s.connect((HOST, server.port))
for arg in [indata, bytearray(indata), memoryview(indata)]:
if connectionchatty:
@@ -1402,6 +1535,7 @@ else:
stats.update({
'compression': s.compression(),
'cipher': s.cipher(),
+ 'peercert': s.getpeercert(),
'client_npn_protocol': s.selected_npn_protocol()
})
s.close()
@@ -1427,12 +1561,15 @@ else:
client_context.options = ssl.OP_ALL | client_options
server_context = ssl.SSLContext(server_protocol)
server_context.options = ssl.OP_ALL | server_options
+
+ # NOTE: we must enable "ALL" ciphers on the client, otherwise an
+ # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
+ # starting from OpenSSL 1.0.0 (see issue #8322).
+ if client_context.protocol == ssl.PROTOCOL_SSLv23:
+ client_context.set_ciphers("ALL")
+
for ctx in (client_context, server_context):
ctx.verify_mode = certsreqs
- # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
- # will send an SSLv3 hello (rather than SSLv2) starting from
- # OpenSSL 1.0.0 (see issue #8322).
- ctx.set_ciphers("ALL")
ctx.load_cert_chain(CERTFILE)
ctx.load_verify_locations(CERTFILE)
try:
@@ -1443,7 +1580,7 @@ else:
except ssl.SSLError:
if expect_success:
raise
- except socket.error as e:
+ except OSError as e:
if expect_success or e.errno != errno.ECONNRESET:
raise
else:
@@ -1462,10 +1599,11 @@ else:
if support.verbose:
sys.stdout.write("\n")
for protocol in PROTOCOLS:
- context = ssl.SSLContext(protocol)
- context.load_cert_chain(CERTFILE)
- server_params_test(context, context,
- chatty=True, connectionchatty=True)
+ with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
+ context = ssl.SSLContext(protocol)
+ context.load_cert_chain(CERTFILE)
+ server_params_test(context, context,
+ chatty=True, connectionchatty=True)
def test_getpeercert(self):
if support.verbose:
@@ -1476,8 +1614,14 @@ else:
context.load_cert_chain(CERTFILE)
server = ThreadedEchoServer(context=context, chatty=False)
with server:
- s = context.wrap_socket(socket.socket())
+ s = context.wrap_socket(socket.socket(),
+ do_handshake_on_connect=False)
s.connect((HOST, server.port))
+ # getpeercert() raise ValueError while the handshake isn't
+ # done.
+ with self.assertRaises(ValueError):
+ s.getpeercert()
+ s.do_handshake()
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
cipher = s.cipher()
@@ -1517,7 +1661,7 @@ else:
"badkey.pem"))
def test_rude_shutdown(self):
- """A brutal shutdown of an SSL server should raise an IOError
+ """A brutal shutdown of an SSL server should raise an OSError
in the client when attempting handshake.
"""
listener_ready = threading.Event()
@@ -1545,7 +1689,7 @@ else:
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(c)
- except IOError:
+ except OSError:
pass
else:
self.fail('connecting to closed SSL socket should have failed')
@@ -1588,7 +1732,7 @@ else:
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try:
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
- except (ssl.SSLError, socket.error) as x:
+ except OSError as x:
# this fails on some older versions of OpenSSL (0.9.7l, for instance)
if support.verbose:
sys.stdout.write(
@@ -1648,6 +1792,49 @@ else:
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_TLSv1)
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
+ "TLS version 1.1 not supported.")
+ def test_protocol_tlsv1_1(self):
+ """Connecting to a TLSv1.1 server with various client options.
+ Testing against older TLS versions."""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1_1)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
+
+
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
+ "TLS version 1.2 not supported.")
+ def test_protocol_tlsv1_2(self):
+ """Connecting to a TLSv1.2 server with various client options.
+ Testing against older TLS versions."""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True,
+ server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
+ client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1_2)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
+
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
@@ -1708,7 +1895,7 @@ else:
def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections."""
- server = make_https_server(self, CERTFILE)
+ server = make_https_server(self, certfile=CERTFILE)
# try to connect
if support.verbose:
sys.stdout.write('\n')
@@ -1964,6 +2151,20 @@ else:
self.assertIsInstance(remote, ssl.SSLSocket)
self.assertEqual(peer, client_addr)
+ def test_getpeercert_enotconn(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with context.wrap_socket(socket.socket()) as sock:
+ with self.assertRaises(OSError) as cm:
+ sock.getpeercert()
+ self.assertEqual(cm.exception.errno, errno.ENOTCONN)
+
+ def test_do_handshake_enotconn(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with context.wrap_socket(socket.socket()) as sock:
+ with self.assertRaises(OSError) as cm:
+ sock.do_handshake()
+ self.assertEqual(cm.exception.errno, errno.ENOTCONN)
+
def test_default_ciphers(self):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
try:
@@ -1975,7 +2176,7 @@ else:
ssl_version=ssl.PROTOCOL_SSLv23,
chatty=False) as server:
with context.wrap_socket(socket.socket()) as s:
- with self.assertRaises((OSError, ssl.SSLError)):
+ with self.assertRaises(OSError):
s.connect((HOST, server.port))
self.assertIn("no shared cipher", str(server.conn_errors[0]))
@@ -2108,6 +2309,124 @@ else:
if len(stats['server_npn_protocols']) else 'nothing'
self.assertEqual(server_result, expected, msg % (server_result, "server"))
+ def sni_contexts(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ other_context.load_cert_chain(SIGNED_CERTFILE2)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ return server_context, other_context, client_context
+
+ def check_common_name(self, stats, name):
+ cert = stats['peercert']
+ self.assertIn((('commonName', name),), cert['subject'])
+
+ @needs_sni
+ def test_sni_callback(self):
+ calls = []
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def servername_cb(ssl_sock, server_name, initial_context):
+ calls.append((server_name, initial_context))
+ if server_name is not None:
+ ssl_sock.context = other_context
+ server_context.set_servername_callback(servername_cb)
+
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name='supermessage')
+ # The hostname was fetched properly, and the certificate was
+ # changed for the connection.
+ self.assertEqual(calls, [("supermessage", server_context)])
+ # CERTFILE4 was selected
+ self.check_common_name(stats, 'fakehostname')
+
+ calls = []
+ # The callback is called with server_name=None
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name=None)
+ self.assertEqual(calls, [(None, server_context)])
+ self.check_common_name(stats, 'localhost')
+
+ # Check disabling the callback
+ calls = []
+ server_context.set_servername_callback(None)
+
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name='notfunny')
+ # Certificate didn't change
+ self.check_common_name(stats, 'localhost')
+ self.assertEqual(calls, [])
+
+ @needs_sni
+ def test_sni_callback_alert(self):
+ # Returning a TLS alert is reflected to the connecting client
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def cb_returning_alert(ssl_sock, server_name, initial_context):
+ return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
+ server_context.set_servername_callback(cb_returning_alert)
+
+ with self.assertRaises(ssl.SSLError) as cm:
+ stats = server_params_test(client_context, server_context,
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
+
+ @needs_sni
+ def test_sni_callback_raising(self):
+ # Raising fails the connection with a TLS handshake failure alert.
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def cb_raising(ssl_sock, server_name, initial_context):
+ 1/0
+ server_context.set_servername_callback(cb_raising)
+
+ with self.assertRaises(ssl.SSLError) as cm, \
+ support.captured_stderr() as stderr:
+ stats = server_params_test(client_context, server_context,
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
+ self.assertIn("ZeroDivisionError", stderr.getvalue())
+
+ @needs_sni
+ def test_sni_callback_wrong_return_type(self):
+ # Returning the wrong return type terminates the TLS connection
+ # with an internal error alert.
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def cb_wrong_return_type(ssl_sock, server_name, initial_context):
+ return "foo"
+ server_context.set_servername_callback(cb_wrong_return_type)
+
+ with self.assertRaises(ssl.SSLError) as cm, \
+ support.captured_stderr() as stderr:
+ stats = server_params_test(client_context, server_context,
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
+ self.assertIn("TypeError", stderr.getvalue())
+
+ def test_read_write_after_close_raises_valuerror(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+
+ with server:
+ s = context.wrap_socket(socket.socket())
+ s.connect((HOST, server.port))
+ s.close()
+
+ self.assertRaises(ValueError, s.read, 1024)
+ self.assertRaises(ValueError, s.write, b'hello')
+
def test_main(verbose=False):
if support.verbose:
@@ -2127,10 +2446,16 @@ def test_main(verbose=False):
(ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
print(" under %s" % plat)
print(" HAS_SNI = %r" % ssl.HAS_SNI)
+ print(" OP_ALL = 0x%8x" % ssl.OP_ALL)
+ try:
+ print(" OP_NO_TLSv1_1 = 0x%8x" % ssl.OP_NO_TLSv1_1)
+ except AttributeError:
+ pass
for filename in [
CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, BYTES_CERTFILE,
ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
+ SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
BADCERT, BADKEY, EMPTYCERT]:
if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename)
diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py
index eb3f07a63d..af6ced4204 100644
--- a/Lib/test/test_stat.py
+++ b/Lib/test/test_stat.py
@@ -1,9 +1,13 @@
import unittest
import os
-from test.support import TESTFN, run_unittest, import_fresh_module
-import stat
+from test.support import TESTFN, import_fresh_module
+
+c_stat = import_fresh_module('stat', fresh=['_stat'])
+py_stat = import_fresh_module('stat', blocked=['_stat'])
+
+class TestFilemode:
+ statmod = None
-class TestFilemode(unittest.TestCase):
file_flags = {'SF_APPEND', 'SF_ARCHIVED', 'SF_IMMUTABLE', 'SF_NOUNLINK',
'SF_SNAPSHOT', 'UF_APPEND', 'UF_COMPRESSED', 'UF_HIDDEN',
'UF_IMMUTABLE', 'UF_NODUMP', 'UF_NOUNLINK', 'UF_OPAQUE'}
@@ -63,17 +67,17 @@ class TestFilemode(unittest.TestCase):
st_mode = os.lstat(fname).st_mode
else:
st_mode = os.stat(fname).st_mode
- modestr = stat.filemode(st_mode)
+ modestr = self.statmod.filemode(st_mode)
return st_mode, modestr
def assertS_IS(self, name, mode):
# test format, lstrip is for S_IFIFO
- fmt = getattr(stat, "S_IF" + name.lstrip("F"))
- self.assertEqual(stat.S_IFMT(mode), fmt)
+ fmt = getattr(self.statmod, "S_IF" + name.lstrip("F"))
+ self.assertEqual(self.statmod.S_IFMT(mode), fmt)
# test that just one function returns true
testname = "S_IS" + name
for funcname in self.format_funcs:
- func = getattr(stat, funcname, None)
+ func = getattr(self.statmod, funcname, None)
if func is None:
if funcname == testname:
raise ValueError(funcname)
@@ -91,35 +95,35 @@ class TestFilemode(unittest.TestCase):
st_mode, modestr = self.get_mode()
self.assertEqual(modestr, '-rwx------')
self.assertS_IS("REG", st_mode)
- self.assertEqual(stat.S_IMODE(st_mode),
- stat.S_IRWXU)
+ self.assertEqual(self.statmod.S_IMODE(st_mode),
+ self.statmod.S_IRWXU)
os.chmod(TESTFN, 0o070)
st_mode, modestr = self.get_mode()
self.assertEqual(modestr, '----rwx---')
self.assertS_IS("REG", st_mode)
- self.assertEqual(stat.S_IMODE(st_mode),
- stat.S_IRWXG)
+ self.assertEqual(self.statmod.S_IMODE(st_mode),
+ self.statmod.S_IRWXG)
os.chmod(TESTFN, 0o007)
st_mode, modestr = self.get_mode()
self.assertEqual(modestr, '-------rwx')
self.assertS_IS("REG", st_mode)
- self.assertEqual(stat.S_IMODE(st_mode),
- stat.S_IRWXO)
+ self.assertEqual(self.statmod.S_IMODE(st_mode),
+ self.statmod.S_IRWXO)
os.chmod(TESTFN, 0o444)
st_mode, modestr = self.get_mode()
self.assertS_IS("REG", st_mode)
self.assertEqual(modestr, '-r--r--r--')
- self.assertEqual(stat.S_IMODE(st_mode), 0o444)
+ self.assertEqual(self.statmod.S_IMODE(st_mode), 0o444)
else:
os.chmod(TESTFN, 0o700)
st_mode, modestr = self.get_mode()
self.assertEqual(modestr[:3], '-rw')
self.assertS_IS("REG", st_mode)
- self.assertEqual(stat.S_IFMT(st_mode),
- stat.S_IFREG)
+ self.assertEqual(self.statmod.S_IFMT(st_mode),
+ self.statmod.S_IFREG)
def test_directory(self):
os.mkdir(TESTFN)
@@ -165,25 +169,34 @@ class TestFilemode(unittest.TestCase):
def test_module_attributes(self):
for key, value in self.stat_struct.items():
- modvalue = getattr(stat, key)
+ modvalue = getattr(self.statmod, key)
self.assertEqual(value, modvalue, key)
for key, value in self.permission_bits.items():
- modvalue = getattr(stat, key)
+ modvalue = getattr(self.statmod, key)
self.assertEqual(value, modvalue, key)
for key in self.file_flags:
- modvalue = getattr(stat, key)
+ modvalue = getattr(self.statmod, key)
self.assertIsInstance(modvalue, int)
for key in self.formats:
- modvalue = getattr(stat, key)
+ modvalue = getattr(self.statmod, key)
self.assertIsInstance(modvalue, int)
for key in self.format_funcs:
- func = getattr(stat, key)
+ func = getattr(self.statmod, key)
self.assertTrue(callable(func))
self.assertEqual(func(0), 0)
-def test_main():
- run_unittest(TestFilemode)
+class TestFilemodeCStat(TestFilemode, unittest.TestCase):
+ statmod = c_stat
+
+ formats = TestFilemode.formats | {'S_IFDOOR', 'S_IFPORT', 'S_IFWHT'}
+ format_funcs = TestFilemode.format_funcs | {'S_ISDOOR', 'S_ISPORT',
+ 'S_ISWHT'}
+
+
+class TestFilemodePyStat(TestFilemode, unittest.TestCase):
+ statmod = py_stat
+
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
new file mode 100644
index 0000000000..71b3654526
--- /dev/null
+++ b/Lib/test/test_statistics.py
@@ -0,0 +1,1534 @@
+"""Test suite for statistics module, including helper NumericTestCase and
+approx_equal function.
+
+"""
+
+import collections
+import decimal
+import doctest
+import math
+import random
+import types
+import unittest
+
+from decimal import Decimal
+from fractions import Fraction
+
+
+# Module to be tested.
+import statistics
+
+
+# === Helper functions and class ===
+
+def _calc_errors(actual, expected):
+ """Return the absolute and relative errors between two numbers.
+
+ >>> _calc_errors(100, 75)
+ (25, 0.25)
+ >>> _calc_errors(100, 100)
+ (0, 0.0)
+
+ Returns the (absolute error, relative error) between the two arguments.
+ """
+ base = max(abs(actual), abs(expected))
+ abs_err = abs(actual - expected)
+ rel_err = abs_err/base if base else float('inf')
+ return (abs_err, rel_err)
+
+
+def approx_equal(x, y, tol=1e-12, rel=1e-7):
+ """approx_equal(x, y [, tol [, rel]]) => True|False
+
+ Return True if numbers x and y are approximately equal, to within some
+ margin of error, otherwise return False. Numbers which compare equal
+ will also compare approximately equal.
+
+ x is approximately equal to y if the difference between them is less than
+ an absolute error tol or a relative error rel, whichever is bigger.
+
+ If given, both tol and rel must be finite, non-negative numbers. If not
+ given, default values are tol=1e-12 and rel=1e-7.
+
+ >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
+ True
+ >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
+ False
+
+ Absolute error is defined as abs(x-y); if that is less than or equal to
+ tol, x and y are considered approximately equal.
+
+ Relative error is defined as abs((x-y)/x) or abs((x-y)/y), whichever is
+ smaller, provided x or y are not zero. If that figure is less than or
+ equal to rel, x and y are considered approximately equal.
+
+ Complex numbers are not directly supported. If you wish to compare to
+ complex numbers, extract their real and imaginary parts and compare them
+ individually.
+
+ NANs always compare unequal, even with themselves. Infinities compare
+ approximately equal if they have the same sign (both positive or both
+ negative). Infinities with different signs compare unequal; so do
+ comparisons of infinities with finite numbers.
+ """
+ if tol < 0 or rel < 0:
+ raise ValueError('error tolerances must be non-negative')
+ # NANs are never equal to anything, approximately or otherwise.
+ if math.isnan(x) or math.isnan(y):
+ return False
+ # Numbers which compare equal also compare approximately equal.
+ if x == y:
+ # This includes the case of two infinities with the same sign.
+ return True
+ if math.isinf(x) or math.isinf(y):
+ # This includes the case of two infinities of opposite sign, or
+ # one infinity and one finite number.
+ return False
+ # Two finite numbers.
+ actual_error = abs(x - y)
+ allowed_error = max(tol, rel*max(abs(x), abs(y)))
+ return actual_error <= allowed_error
+
+
+# This class exists only as somewhere to stick a docstring containing
+# doctests. The following docstring and tests were originally in a separate
+# module. Now that it has been merged in here, I need somewhere to hang the.
+# docstring. Ultimately, this class will die, and the information below will
+# either become redundant, or be moved into more appropriate places.
+class _DoNothing:
+ """
+ When doing numeric work, especially with floats, exact equality is often
+ not what you want. Due to round-off error, it is often a bad idea to try
+ to compare floats with equality. Instead the usual procedure is to test
+ them with some (hopefully small!) allowance for error.
+
+ The ``approx_equal`` function allows you to specify either an absolute
+ error tolerance, or a relative error, or both.
+
+ Absolute error tolerances are simple, but you need to know the magnitude
+ of the quantities being compared:
+
+ >>> approx_equal(12.345, 12.346, tol=1e-3)
+ True
+ >>> approx_equal(12.345e6, 12.346e6, tol=1e-3) # tol is too small.
+ False
+
+ Relative errors are more suitable when the values you are comparing can
+ vary in magnitude:
+
+ >>> approx_equal(12.345, 12.346, rel=1e-4)
+ True
+ >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
+ True
+
+ but a naive implementation of relative error testing can run into trouble
+ around zero.
+
+ If you supply both an absolute tolerance and a relative error, the
+ comparison succeeds if either individual test succeeds:
+
+ >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
+ True
+
+ """
+ pass
+
+
+
+# We prefer this for testing numeric values that may not be exactly equal,
+# and avoid using TestCase.assertAlmostEqual, because it sucks :-)
+
+class NumericTestCase(unittest.TestCase):
+ """Unit test class for numeric work.
+
+ This subclasses TestCase. In addition to the standard method
+ ``TestCase.assertAlmostEqual``, ``assertApproxEqual`` is provided.
+ """
+ # By default, we expect exact equality, unless overridden.
+ tol = rel = 0
+
+ def assertApproxEqual(
+ self, first, second, tol=None, rel=None, msg=None
+ ):
+ """Test passes if ``first`` and ``second`` are approximately equal.
+
+ This test passes if ``first`` and ``second`` are equal to
+ within ``tol``, an absolute error, or ``rel``, a relative error.
+
+ If either ``tol`` or ``rel`` are None or not given, they default to
+ test attributes of the same name (by default, 0).
+
+ The objects may be either numbers, or sequences of numbers. Sequences
+ are tested element-by-element.
+
+ >>> class MyTest(NumericTestCase):
+ ... def test_number(self):
+ ... x = 1.0/6
+ ... y = sum([x]*6)
+ ... self.assertApproxEqual(y, 1.0, tol=1e-15)
+ ... def test_sequence(self):
+ ... a = [1.001, 1.001e-10, 1.001e10]
+ ... b = [1.0, 1e-10, 1e10]
+ ... self.assertApproxEqual(a, b, rel=1e-3)
+ ...
+ >>> import unittest
+ >>> from io import StringIO # Suppress test runner output.
+ >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
+ >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
+ <unittest.runner.TextTestResult run=2 errors=0 failures=0>
+
+ """
+ if tol is None:
+ tol = self.tol
+ if rel is None:
+ rel = self.rel
+ if (
+ isinstance(first, collections.Sequence) and
+ isinstance(second, collections.Sequence)
+ ):
+ check = self._check_approx_seq
+ else:
+ check = self._check_approx_num
+ check(first, second, tol, rel, msg)
+
+ def _check_approx_seq(self, first, second, tol, rel, msg):
+ if len(first) != len(second):
+ standardMsg = (
+ "sequences differ in length: %d items != %d items"
+ % (len(first), len(second))
+ )
+ msg = self._formatMessage(msg, standardMsg)
+ raise self.failureException(msg)
+ for i, (a,e) in enumerate(zip(first, second)):
+ self._check_approx_num(a, e, tol, rel, msg, i)
+
+ def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
+ if approx_equal(first, second, tol, rel):
+ # Test passes. Return early, we are done.
+ return None
+ # Otherwise we failed.
+ standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
+ msg = self._formatMessage(msg, standardMsg)
+ raise self.failureException(msg)
+
+ @staticmethod
+ def _make_std_err_msg(first, second, tol, rel, idx):
+ # Create the standard error message for approx_equal failures.
+ assert first != second
+ template = (
+ ' %r != %r\n'
+ ' values differ by more than tol=%r and rel=%r\n'
+ ' -> absolute error = %r\n'
+ ' -> relative error = %r'
+ )
+ if idx is not None:
+ header = 'numeric sequences first differ at index %d.\n' % idx
+ template = header + template
+ # Calculate actual errors:
+ abs_err, rel_err = _calc_errors(first, second)
+ return template % (first, second, tol, rel, abs_err, rel_err)
+
+
+# ========================
+# === Test the helpers ===
+# ========================
+
+
+# --- Tests for approx_equal ---
+
+class ApproxEqualSymmetryTest(unittest.TestCase):
+ # Test symmetry of approx_equal.
+
+ def test_relative_symmetry(self):
+ # Check that approx_equal treats relative error symmetrically.
+ # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
+ # doesn't matter.
+ #
+ # Note: the reason for this test is that an early version
+ # of approx_equal was not symmetric. A relative error test
+ # would pass, or fail, depending on which value was passed
+ # as the first argument.
+ #
+ args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
+ args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
+ assert len(args1) == len(args2)
+ for a, b in zip(args1, args2):
+ self.do_relative_symmetry(a, b)
+
+ def do_relative_symmetry(self, a, b):
+ a, b = min(a, b), max(a, b)
+ assert a < b
+ delta = b - a # The absolute difference between the values.
+ rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
+ # Choose an error margin halfway between the two.
+ rel = (rel_err1 + rel_err2)/2
+ # Now see that values a and b compare approx equal regardless of
+ # which is given first.
+ self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
+ self.assertTrue(approx_equal(b, a, tol=0, rel=rel))
+
+ def test_symmetry(self):
+ # Test that approx_equal(a, b) == approx_equal(b, a)
+ args = [-23, -2, 5, 107, 93568]
+ delta = 2
+ for x in args:
+ for type_ in (int, float, Decimal, Fraction):
+ x = type_(x)*100
+ y = x + delta
+ r = abs(delta/max(x, y))
+ # There are five cases to check:
+ # 1) actual error <= tol, <= rel
+ self.do_symmetry_test(x, y, tol=delta, rel=r)
+ self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
+ # 2) actual error > tol, > rel
+ self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
+ # 3) actual error <= tol, > rel
+ self.do_symmetry_test(x, y, tol=delta, rel=r/2)
+ # 4) actual error > tol, <= rel
+ self.do_symmetry_test(x, y, tol=delta-1, rel=r)
+ self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
+ # 5) exact equality test
+ self.do_symmetry_test(x, x, tol=0, rel=0)
+ self.do_symmetry_test(x, y, tol=0, rel=0)
+
+ def do_symmetry_test(self, a, b, tol, rel):
+ template = "approx_equal comparisons don't match for %r"
+ flag1 = approx_equal(a, b, tol, rel)
+ flag2 = approx_equal(b, a, tol, rel)
+ self.assertEqual(flag1, flag2, template.format((a, b, tol, rel)))
+
+
+class ApproxEqualExactTest(unittest.TestCase):
+ # Test the approx_equal function with exactly equal values.
+ # Equal values should compare as approximately equal.
+ # Test cases for exactly equal values, which should compare approx
+ # equal regardless of the error tolerances given.
+
+ def do_exactly_equal_test(self, x, tol, rel):
+ result = approx_equal(x, x, tol=tol, rel=rel)
+ self.assertTrue(result, 'equality failure for x=%r' % x)
+ result = approx_equal(-x, -x, tol=tol, rel=rel)
+ self.assertTrue(result, 'equality failure for x=%r' % -x)
+
+ def test_exactly_equal_ints(self):
+ # Test that equal int values are exactly equal.
+ for n in [42, 19740, 14974, 230, 1795, 700245, 36587]:
+ self.do_exactly_equal_test(n, 0, 0)
+
+ def test_exactly_equal_floats(self):
+ # Test that equal float values are exactly equal.
+ for x in [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]:
+ self.do_exactly_equal_test(x, 0, 0)
+
+ def test_exactly_equal_fractions(self):
+ # Test that equal Fraction values are exactly equal.
+ F = Fraction
+ for f in [F(1, 2), F(0), F(5, 3), F(9, 7), F(35, 36), F(3, 7)]:
+ self.do_exactly_equal_test(f, 0, 0)
+
+ def test_exactly_equal_decimals(self):
+ # Test that equal Decimal values are exactly equal.
+ D = Decimal
+ for d in map(D, "8.2 31.274 912.04 16.745 1.2047".split()):
+ self.do_exactly_equal_test(d, 0, 0)
+
+ def test_exactly_equal_absolute(self):
+ # Test that equal values are exactly equal with an absolute error.
+ for n in [16, 1013, 1372, 1198, 971, 4]:
+ # Test as ints.
+ self.do_exactly_equal_test(n, 0.01, 0)
+ # Test as floats.
+ self.do_exactly_equal_test(n/10, 0.01, 0)
+ # Test as Fractions.
+ f = Fraction(n, 1234)
+ self.do_exactly_equal_test(f, 0.01, 0)
+
+ def test_exactly_equal_absolute_decimals(self):
+ # Test equal Decimal values are exactly equal with an absolute error.
+ self.do_exactly_equal_test(Decimal("3.571"), Decimal("0.01"), 0)
+ self.do_exactly_equal_test(-Decimal("81.3971"), Decimal("0.01"), 0)
+
+ def test_exactly_equal_relative(self):
+ # Test that equal values are exactly equal with a relative error.
+ for x in [8347, 101.3, -7910.28, Fraction(5, 21)]:
+ self.do_exactly_equal_test(x, 0, 0.01)
+ self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))
+
+ def test_exactly_equal_both(self):
+ # Test that equal values are equal when both tol and rel are given.
+ for x in [41017, 16.742, -813.02, Fraction(3, 8)]:
+ self.do_exactly_equal_test(x, 0.1, 0.01)
+ D = Decimal
+ self.do_exactly_equal_test(D("7.2"), D("0.1"), D("0.01"))
+
+
+class ApproxEqualUnequalTest(unittest.TestCase):
+ # Unequal values should compare unequal with zero error tolerances.
+ # Test cases for unequal values, with exact equality test.
+
+ def do_exactly_unequal_test(self, x):
+ for a in (x, -x):
+ result = approx_equal(a, a+1, tol=0, rel=0)
+ self.assertFalse(result, 'inequality failure for x=%r' % a)
+
+ def test_exactly_unequal_ints(self):
+ # Test unequal int values are unequal with zero error tolerance.
+ for n in [951, 572305, 478, 917, 17240]:
+ self.do_exactly_unequal_test(n)
+
+ def test_exactly_unequal_floats(self):
+ # Test unequal float values are unequal with zero error tolerance.
+ for x in [9.51, 5723.05, 47.8, 9.17, 17.24]:
+ self.do_exactly_unequal_test(x)
+
+ def test_exactly_unequal_fractions(self):
+ # Test that unequal Fractions are unequal with zero error tolerance.
+ F = Fraction
+ for f in [F(1, 5), F(7, 9), F(12, 11), F(101, 99023)]:
+ self.do_exactly_unequal_test(f)
+
+ def test_exactly_unequal_decimals(self):
+ # Test that unequal Decimals are unequal with zero error tolerance.
+ for d in map(Decimal, "3.1415 298.12 3.47 18.996 0.00245".split()):
+ self.do_exactly_unequal_test(d)
+
+
+class ApproxEqualInexactTest(unittest.TestCase):
+ # Inexact test cases for approx_error.
+ # Test cases when comparing two values that are not exactly equal.
+
+ # === Absolute error tests ===
+
+ def do_approx_equal_abs_test(self, x, delta):
+ template = "Test failure for x={!r}, y={!r}"
+ for y in (x + delta, x - delta):
+ msg = template.format(x, y)
+ self.assertTrue(approx_equal(x, y, tol=2*delta, rel=0), msg)
+ self.assertFalse(approx_equal(x, y, tol=delta/2, rel=0), msg)
+
+ def test_approx_equal_absolute_ints(self):
+ # Test approximate equality of ints with an absolute error.
+ for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]:
+ self.do_approx_equal_abs_test(n, 10)
+ self.do_approx_equal_abs_test(n, 2)
+
+ def test_approx_equal_absolute_floats(self):
+ # Test approximate equality of floats with an absolute error.
+ for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]:
+ self.do_approx_equal_abs_test(x, 1.5)
+ self.do_approx_equal_abs_test(x, 0.01)
+ self.do_approx_equal_abs_test(x, 0.0001)
+
+ def test_approx_equal_absolute_fractions(self):
+ # Test approximate equality of Fractions with an absolute error.
+ delta = Fraction(1, 29)
+ numerators = [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71]
+ for f in (Fraction(n, 29) for n in numerators):
+ self.do_approx_equal_abs_test(f, delta)
+ self.do_approx_equal_abs_test(f, float(delta))
+
+ def test_approx_equal_absolute_decimals(self):
+ # Test approximate equality of Decimals with an absolute error.
+ delta = Decimal("0.01")
+ for d in map(Decimal, "1.0 3.5 36.08 61.79 7912.3648".split()):
+ self.do_approx_equal_abs_test(d, delta)
+ self.do_approx_equal_abs_test(-d, delta)
+
+ def test_cross_zero(self):
+ # Test for the case of the two values having opposite signs.
+ self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0))
+
+ # === Relative error tests ===
+
+ def do_approx_equal_rel_test(self, x, delta):
+ template = "Test failure for x={!r}, y={!r}"
+ for y in (x*(1+delta), x*(1-delta)):
+ msg = template.format(x, y)
+ self.assertTrue(approx_equal(x, y, tol=0, rel=2*delta), msg)
+ self.assertFalse(approx_equal(x, y, tol=0, rel=delta/2), msg)
+
+ def test_approx_equal_relative_ints(self):
+ # Test approximate equality of ints with a relative error.
+ self.assertTrue(approx_equal(64, 47, tol=0, rel=0.36))
+ self.assertTrue(approx_equal(64, 47, tol=0, rel=0.37))
+ # ---
+ self.assertTrue(approx_equal(449, 512, tol=0, rel=0.125))
+ self.assertTrue(approx_equal(448, 512, tol=0, rel=0.125))
+ self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125))
+
+ def test_approx_equal_relative_floats(self):
+ # Test approximate equality of floats with a relative error.
+ for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]:
+ self.do_approx_equal_rel_test(x, 0.02)
+ self.do_approx_equal_rel_test(x, 0.0001)
+
+ def test_approx_equal_relative_fractions(self):
+ # Test approximate equality of Fractions with a relative error.
+ F = Fraction
+ delta = Fraction(3, 8)
+ for f in [F(3, 84), F(17, 30), F(49, 50), F(92, 85)]:
+ for d in (delta, float(delta)):
+ self.do_approx_equal_rel_test(f, d)
+ self.do_approx_equal_rel_test(-f, d)
+
+ def test_approx_equal_relative_decimals(self):
+ # Test approximate equality of Decimals with a relative error.
+ for d in map(Decimal, "0.02 1.0 5.7 13.67 94.138 91027.9321".split()):
+ self.do_approx_equal_rel_test(d, Decimal("0.001"))
+ self.do_approx_equal_rel_test(-d, Decimal("0.05"))
+
+ # === Both absolute and relative error tests ===
+
+ # There are four cases to consider:
+ # 1) actual error <= both absolute and relative error
+ # 2) actual error <= absolute error but > relative error
+ # 3) actual error <= relative error but > absolute error
+ # 4) actual error > both absolute and relative error
+
+ def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
+ check = self.assertTrue if tol_flag else self.assertFalse
+ check(approx_equal(a, b, tol=tol, rel=0))
+ check = self.assertTrue if rel_flag else self.assertFalse
+ check(approx_equal(a, b, tol=0, rel=rel))
+ check = self.assertTrue if (tol_flag or rel_flag) else self.assertFalse
+ check(approx_equal(a, b, tol=tol, rel=rel))
+
+ def test_approx_equal_both1(self):
+ # Test actual error <= both absolute and relative error.
+ self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True)
+ self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True)
+
+ def test_approx_equal_both2(self):
+ # Test actual error <= absolute error but > relative error.
+ self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)
+
+ def test_approx_equal_both3(self):
+ # Test actual error <= relative error but > absolute error.
+ self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)
+
+ def test_approx_equal_both4(self):
+ # Test actual error > both absolute and relative error.
+ self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False)
+ self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False)
+
+
+class ApproxEqualSpecialsTest(unittest.TestCase):
+ # Test approx_equal with NANs and INFs and zeroes.
+
+ def test_inf(self):
+ for type_ in (float, Decimal):
+ inf = type_('inf')
+ self.assertTrue(approx_equal(inf, inf))
+ self.assertTrue(approx_equal(inf, inf, 0, 0))
+ self.assertTrue(approx_equal(inf, inf, 1, 0.01))
+ self.assertTrue(approx_equal(-inf, -inf))
+ self.assertFalse(approx_equal(inf, -inf))
+ self.assertFalse(approx_equal(inf, 1000))
+
+ def test_nan(self):
+ for type_ in (float, Decimal):
+ nan = type_('nan')
+ for other in (nan, type_('inf'), 1000):
+ self.assertFalse(approx_equal(nan, other))
+
+ def test_float_zeroes(self):
+ nzero = math.copysign(0.0, -1)
+ self.assertTrue(approx_equal(nzero, 0.0, tol=0.1, rel=0.1))
+
+ def test_decimal_zeroes(self):
+ nzero = Decimal("-0.0")
+ self.assertTrue(approx_equal(nzero, Decimal(0), tol=0.1, rel=0.1))
+
+
+class TestApproxEqualErrors(unittest.TestCase):
+ # Test error conditions of approx_equal.
+
+ def test_bad_tol(self):
+ # Test negative tol raises.
+ self.assertRaises(ValueError, approx_equal, 100, 100, -1, 0.1)
+
+ def test_bad_rel(self):
+ # Test negative rel raises.
+ self.assertRaises(ValueError, approx_equal, 100, 100, 1, -0.1)
+
+
+# --- Tests for NumericTestCase ---
+
+# The formatting routine that generates the error messages is complex enough
+# that it too needs testing.
+
+class TestNumericTestCase(unittest.TestCase):
+ # The exact wording of NumericTestCase error messages is *not* guaranteed,
+ # but we need to give them some sort of test to ensure that they are
+ # generated correctly. As a compromise, we look for specific substrings
+ # that are expected to be found even if the overall error message changes.
+
+ def do_test(self, args):
+ actual_msg = NumericTestCase._make_std_err_msg(*args)
+ expected = self.generate_substrings(*args)
+ for substring in expected:
+ self.assertIn(substring, actual_msg)
+
+ def test_numerictestcase_is_testcase(self):
+ # Ensure that NumericTestCase actually is a TestCase.
+ self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))
+
+ def test_error_msg_numeric(self):
+ # Test the error message generated for numeric comparisons.
+ args = (2.5, 4.0, 0.5, 0.25, None)
+ self.do_test(args)
+
+ def test_error_msg_sequence(self):
+ # Test the error message generated for sequence comparisons.
+ args = (3.75, 8.25, 1.25, 0.5, 7)
+ self.do_test(args)
+
+ def generate_substrings(self, first, second, tol, rel, idx):
+ """Return substrings we expect to see in error messages."""
+ abs_err, rel_err = _calc_errors(first, second)
+ substrings = [
+ 'tol=%r' % tol,
+ 'rel=%r' % rel,
+ 'absolute error = %r' % abs_err,
+ 'relative error = %r' % rel_err,
+ ]
+ if idx is not None:
+ substrings.append('differ at index %d' % idx)
+ return substrings
+
+
+# =======================================
+# === Tests for the statistics module ===
+# =======================================
+
+
+class GlobalsTest(unittest.TestCase):
+ module = statistics
+ expected_metadata = ["__doc__", "__all__"]
+
+ def test_meta(self):
+ # Test for the existence of metadata.
+ for meta in self.expected_metadata:
+ self.assertTrue(hasattr(self.module, meta),
+ "%s not present" % meta)
+
+ def test_check_all(self):
+ # Check everything in __all__ exists and is public.
+ module = self.module
+ for name in module.__all__:
+ # No private names in __all__:
+ self.assertFalse(name.startswith("_"),
+ 'private name "%s" in __all__' % name)
+ # And anything in __all__ must exist:
+ self.assertTrue(hasattr(module, name),
+ 'missing name "%s" in __all__' % name)
+
+
+class DocTests(unittest.TestCase):
+ def test_doc_tests(self):
+ failed, tried = doctest.testmod(statistics)
+ self.assertGreater(tried, 0)
+ self.assertEqual(failed, 0)
+
+class StatisticsErrorTest(unittest.TestCase):
+ def test_has_exception(self):
+ errmsg = (
+ "Expected StatisticsError to be a ValueError, but got a"
+ " subclass of %r instead."
+ )
+ self.assertTrue(hasattr(statistics, 'StatisticsError'))
+ self.assertTrue(
+ issubclass(statistics.StatisticsError, ValueError),
+ errmsg % statistics.StatisticsError.__base__
+ )
+
+
+# === Tests for private utility functions ===
+
+class ExactRatioTest(unittest.TestCase):
+ # Test _exact_ratio utility.
+
+ def test_int(self):
+ for i in (-20, -3, 0, 5, 99, 10**20):
+ self.assertEqual(statistics._exact_ratio(i), (i, 1))
+
+ def test_fraction(self):
+ numerators = (-5, 1, 12, 38)
+ for n in numerators:
+ f = Fraction(n, 37)
+ self.assertEqual(statistics._exact_ratio(f), (n, 37))
+
+ def test_float(self):
+ self.assertEqual(statistics._exact_ratio(0.125), (1, 8))
+ self.assertEqual(statistics._exact_ratio(1.125), (9, 8))
+ data = [random.uniform(-100, 100) for _ in range(100)]
+ for x in data:
+ num, den = statistics._exact_ratio(x)
+ self.assertEqual(x, num/den)
+
+ def test_decimal(self):
+ D = Decimal
+ _exact_ratio = statistics._exact_ratio
+ self.assertEqual(_exact_ratio(D("0.125")), (125, 1000))
+ self.assertEqual(_exact_ratio(D("12.345")), (12345, 1000))
+ self.assertEqual(_exact_ratio(D("-1.98")), (-198, 100))
+
+
+class DecimalToRatioTest(unittest.TestCase):
+ # Test _decimal_to_ratio private function.
+
+ def testSpecialsRaise(self):
+ # Test that NANs and INFs raise ValueError.
+ # Non-special values are covered by _exact_ratio above.
+ for d in (Decimal('NAN'), Decimal('sNAN'), Decimal('INF')):
+ self.assertRaises(ValueError, statistics._decimal_to_ratio, d)
+
+
+
+# === Tests for public functions ===
+
+class UnivariateCommonMixin:
+ # Common tests for most univariate functions that take a data argument.
+
+ def test_no_args(self):
+ # Fail if given no arguments.
+ self.assertRaises(TypeError, self.func)
+
+ def test_empty_data(self):
+ # Fail when the data argument (first argument) is empty.
+ for empty in ([], (), iter([])):
+ self.assertRaises(statistics.StatisticsError, self.func, empty)
+
+ def prepare_data(self):
+ """Return int data for various tests."""
+ data = list(range(10))
+ while data == sorted(data):
+ random.shuffle(data)
+ return data
+
+ def test_no_inplace_modifications(self):
+ # Test that the function does not modify its input data.
+ data = self.prepare_data()
+ assert len(data) != 1 # Necessary to avoid infinite loop.
+ assert data != sorted(data)
+ saved = data[:]
+ assert data is not saved
+ _ = self.func(data)
+ self.assertListEqual(data, saved, "data has been modified")
+
+ def test_order_doesnt_matter(self):
+ # Test that the order of data points doesn't change the result.
+
+ # CAUTION: due to floating point rounding errors, the result actually
+ # may depend on the order. Consider this test representing an ideal.
+ # To avoid this test failing, only test with exact values such as ints
+ # or Fractions.
+ data = [1, 2, 3, 3, 3, 4, 5, 6]*100
+ expected = self.func(data)
+ random.shuffle(data)
+ actual = self.func(data)
+ self.assertEqual(expected, actual)
+
+ def test_type_of_data_collection(self):
+ # Test that the type of iterable data doesn't effect the result.
+ class MyList(list):
+ pass
+ class MyTuple(tuple):
+ pass
+ def generator(data):
+ return (obj for obj in data)
+ data = self.prepare_data()
+ expected = self.func(data)
+ for kind in (list, tuple, iter, MyList, MyTuple, generator):
+ result = self.func(kind(data))
+ self.assertEqual(result, expected)
+
+ def test_range_data(self):
+ # Test that functions work with range objects.
+ data = range(20, 50, 3)
+ expected = self.func(list(data))
+ self.assertEqual(self.func(data), expected)
+
+ def test_bad_arg_types(self):
+ # Test that function raises when given data of the wrong type.
+
+ # Don't roll the following into a loop like this:
+ # for bad in list_of_bad:
+ # self.check_for_type_error(bad)
+ #
+ # Since assertRaises doesn't show the arguments that caused the test
+ # failure, it is very difficult to debug these test failures when the
+ # following are in a loop.
+ self.check_for_type_error(None)
+ self.check_for_type_error(23)
+ self.check_for_type_error(42.0)
+ self.check_for_type_error(object())
+
+ def check_for_type_error(self, *args):
+ self.assertRaises(TypeError, self.func, *args)
+
+ def test_type_of_data_element(self):
+ # Check the type of data elements doesn't affect the numeric result.
+ # This is a weaker test than UnivariateTypeMixin.testTypesConserved,
+ # because it checks the numeric result by equality, but not by type.
+ class MyFloat(float):
+ def __truediv__(self, other):
+ return type(self)(super().__truediv__(other))
+ def __add__(self, other):
+ return type(self)(super().__add__(other))
+ __radd__ = __add__
+
+ raw = self.prepare_data()
+ expected = self.func(raw)
+ for kind in (float, MyFloat, Decimal, Fraction):
+ data = [kind(x) for x in raw]
+ result = type(expected)(self.func(data))
+ self.assertEqual(result, expected)
+
+
+class UnivariateTypeMixin:
+ """Mixin class for type-conserving functions.
+
+ This mixin class holds test(s) for functions which conserve the type of
+ individual data points. E.g. the mean of a list of Fractions should itself
+ be a Fraction.
+
+ Not all tests to do with types need go in this class. Only those that
+ rely on the function returning the same type as its input data.
+ """
+ def test_types_conserved(self):
+ # Test that functions keeps the same type as their data points.
+ # (Excludes mixed data types.) This only tests the type of the return
+ # result, not the value.
+ class MyFloat(float):
+ def __truediv__(self, other):
+ return type(self)(super().__truediv__(other))
+ def __sub__(self, other):
+ return type(self)(super().__sub__(other))
+ def __rsub__(self, other):
+ return type(self)(super().__rsub__(other))
+ def __pow__(self, other):
+ return type(self)(super().__pow__(other))
+ def __add__(self, other):
+ return type(self)(super().__add__(other))
+ __radd__ = __add__
+
+ data = self.prepare_data()
+ for kind in (float, Decimal, Fraction, MyFloat):
+ d = [kind(x) for x in data]
+ result = self.func(d)
+ self.assertIs(type(result), kind)
+
+
+class TestSum(NumericTestCase, UnivariateCommonMixin, UnivariateTypeMixin):
+ # Test cases for statistics._sum() function.
+
+ def setUp(self):
+ self.func = statistics._sum
+
+ def test_empty_data(self):
+ # Override test for empty data.
+ for data in ([], (), iter([])):
+ self.assertEqual(self.func(data), 0)
+ self.assertEqual(self.func(data, 23), 23)
+ self.assertEqual(self.func(data, 2.3), 2.3)
+
+ def test_ints(self):
+ self.assertEqual(self.func([1, 5, 3, -4, -8, 20, 42, 1]), 60)
+ self.assertEqual(self.func([4, 2, 3, -8, 7], 1000), 1008)
+
+ def test_floats(self):
+ self.assertEqual(self.func([0.25]*20), 5.0)
+ self.assertEqual(self.func([0.125, 0.25, 0.5, 0.75], 1.5), 3.125)
+
+ def test_fractions(self):
+ F = Fraction
+ self.assertEqual(self.func([Fraction(1, 1000)]*500), Fraction(1, 2))
+
+ def test_decimals(self):
+ D = Decimal
+ data = [D("0.001"), D("5.246"), D("1.702"), D("-0.025"),
+ D("3.974"), D("2.328"), D("4.617"), D("2.843"),
+ ]
+ self.assertEqual(self.func(data), Decimal("20.686"))
+
+ def test_compare_with_math_fsum(self):
+ # Compare with the math.fsum function.
+ # Ideally we ought to get the exact same result, but sometimes
+ # we differ by a very slight amount :-(
+ data = [random.uniform(-100, 1000) for _ in range(1000)]
+ self.assertApproxEqual(self.func(data), math.fsum(data), rel=2e-16)
+
+ def test_start_argument(self):
+ # Test that the optional start argument works correctly.
+ data = [random.uniform(1, 1000) for _ in range(100)]
+ t = self.func(data)
+ self.assertEqual(t+42, self.func(data, 42))
+ self.assertEqual(t-23, self.func(data, -23))
+ self.assertEqual(t+1e20, self.func(data, 1e20))
+
+ def test_strings_fail(self):
+ # Sum of strings should fail.
+ self.assertRaises(TypeError, self.func, [1, 2, 3], '999')
+ self.assertRaises(TypeError, self.func, [1, 2, 3, '999'])
+
+ def test_bytes_fail(self):
+ # Sum of bytes should fail.
+ self.assertRaises(TypeError, self.func, [1, 2, 3], b'999')
+ self.assertRaises(TypeError, self.func, [1, 2, 3, b'999'])
+
+ def test_mixed_sum(self):
+ # Mixed sums are allowed.
+
+ # Careful here: order matters. Can't mix Fraction and Decimal directly,
+ # only after they're converted to float.
+ data = [1, 2, Fraction(1, 2), 3.0, Decimal("0.25")]
+ self.assertEqual(self.func(data), 6.75)
+
+
+class SumInternalsTest(NumericTestCase):
+ # Test internals of the sum function.
+
+ def test_ignore_instance_float_method(self):
+ # Test that __float__ methods on data instances are ignored.
+
+ # Python typically calls __dunder__ methods on the class, not the
+ # instance. The ``sum`` implementation calls __float__ directly. To
+ # better match the behaviour of Python, we call it only on the class,
+ # not the instance. This test will fail if somebody "fixes" that code.
+
+ # Create a fake __float__ method.
+ def __float__(self):
+ raise AssertionError('test fails')
+
+ # Inject it into an instance.
+ class MyNumber(Fraction):
+ pass
+ x = MyNumber(3)
+ x.__float__ = types.MethodType(__float__, x)
+
+ # Check it works as expected.
+ self.assertRaises(AssertionError, x.__float__)
+ self.assertEqual(float(x), 3.0)
+ # And now test the function.
+ self.assertEqual(statistics._sum([1.0, 2.0, x, 4.0]), 10.0)
+
+
+class SumTortureTest(NumericTestCase):
+ def test_torture(self):
+ # Tim Peters' torture test for sum, and variants of same.
+ self.assertEqual(statistics._sum([1, 1e100, 1, -1e100]*10000), 20000.0)
+ self.assertEqual(statistics._sum([1e100, 1, 1, -1e100]*10000), 20000.0)
+ self.assertApproxEqual(
+ statistics._sum([1e-100, 1, 1e-100, -1]*10000), 2.0e-96, rel=5e-16
+ )
+
+
+class SumSpecialValues(NumericTestCase):
+ # Test that sum works correctly with IEEE-754 special values.
+
+ def test_nan(self):
+ for type_ in (float, Decimal):
+ nan = type_('nan')
+ result = statistics._sum([1, nan, 2])
+ self.assertIs(type(result), type_)
+ self.assertTrue(math.isnan(result))
+
+ def check_infinity(self, x, inf):
+ """Check x is an infinity of the same type and sign as inf."""
+ self.assertTrue(math.isinf(x))
+ self.assertIs(type(x), type(inf))
+ self.assertEqual(x > 0, inf > 0)
+ assert x == inf
+
+ def do_test_inf(self, inf):
+ # Adding a single infinity gives infinity.
+ result = statistics._sum([1, 2, inf, 3])
+ self.check_infinity(result, inf)
+ # Adding two infinities of the same sign also gives infinity.
+ result = statistics._sum([1, 2, inf, 3, inf, 4])
+ self.check_infinity(result, inf)
+
+ def test_float_inf(self):
+ inf = float('inf')
+ for sign in (+1, -1):
+ self.do_test_inf(sign*inf)
+
+ def test_decimal_inf(self):
+ inf = Decimal('inf')
+ for sign in (+1, -1):
+ self.do_test_inf(sign*inf)
+
+ def test_float_mismatched_infs(self):
+ # Test that adding two infinities of opposite sign gives a NAN.
+ inf = float('inf')
+ result = statistics._sum([1, 2, inf, 3, -inf, 4])
+ self.assertTrue(math.isnan(result))
+
+ def test_decimal_mismatched_infs_to_nan(self):
+ # Test adding Decimal INFs with opposite sign returns NAN.
+ inf = Decimal('inf')
+ data = [1, 2, inf, 3, -inf, 4]
+ with decimal.localcontext(decimal.ExtendedContext):
+ self.assertTrue(math.isnan(statistics._sum(data)))
+
+ def test_decimal_mismatched_infs_to_nan(self):
+ # Test adding Decimal INFs with opposite sign raises InvalidOperation.
+ inf = Decimal('inf')
+ data = [1, 2, inf, 3, -inf, 4]
+ with decimal.localcontext(decimal.BasicContext):
+ self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
+
+ def test_decimal_snan_raises(self):
+ # Adding sNAN should raise InvalidOperation.
+ sNAN = Decimal('sNAN')
+ data = [1, sNAN, 2]
+ self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
+
+
+# === Tests for averages ===
+
+class AverageMixin(UnivariateCommonMixin):
+ # Mixin class holding common tests for averages.
+
+ def test_single_value(self):
+ # Average of a single value is the value itself.
+ for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')):
+ self.assertEqual(self.func([x]), x)
+
+ def test_repeated_single_value(self):
+ # The average of a single repeated value is the value itself.
+ for x in (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712')):
+ for count in (2, 5, 10, 20):
+ data = [x]*count
+ self.assertEqual(self.func(data), x)
+
+
+class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
+ def setUp(self):
+ self.func = statistics.mean
+
+ def test_torture_pep(self):
+ # "Torture Test" from PEP-450.
+ self.assertEqual(self.func([1e100, 1, 3, -1e100]), 1)
+
+ def test_ints(self):
+ # Test mean with ints.
+ data = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 4.8125)
+
+ def test_floats(self):
+ # Test mean with floats.
+ data = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 22.015625)
+
+ def test_decimals(self):
+ # Test mean with ints.
+ D = Decimal
+ data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")]
+ random.shuffle(data)
+ self.assertEqual(self.func(data), D("3.5896"))
+
+ def test_fractions(self):
+ # Test mean with Fractions.
+ F = Fraction
+ data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)]
+ random.shuffle(data)
+ self.assertEqual(self.func(data), F(1479, 1960))
+
+ def test_inf(self):
+ # Test mean with infinities.
+ raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later.
+ for kind in (float, Decimal):
+ for sign in (1, -1):
+ inf = kind("inf")*sign
+ data = raw + [inf]
+ result = self.func(data)
+ self.assertTrue(math.isinf(result))
+ self.assertEqual(result, inf)
+
+ def test_mismatched_infs(self):
+ # Test mean with infinities of opposite sign.
+ data = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')]
+ result = self.func(data)
+ self.assertTrue(math.isnan(result))
+
+ def test_nan(self):
+ # Test mean with NANs.
+ raw = [1, 3, 5, 7, 9] # Use only ints, to avoid TypeError later.
+ for kind in (float, Decimal):
+ inf = kind("nan")
+ data = raw + [inf]
+ result = self.func(data)
+ self.assertTrue(math.isnan(result))
+
+ def test_big_data(self):
+ # Test adding a large constant to every data point.
+ c = 1e9
+ data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
+ expected = self.func(data) + c
+ assert expected != c
+ result = self.func([x+c for x in data])
+ self.assertEqual(result, expected)
+
+ def test_doubled_data(self):
+ # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
+ data = [random.uniform(-3, 5) for _ in range(1000)]
+ expected = self.func(data)
+ actual = self.func(data*2)
+ self.assertApproxEqual(actual, expected)
+
+
+class TestMedian(NumericTestCase, AverageMixin):
+ # Common tests for median and all median.* functions.
+ def setUp(self):
+ self.func = statistics.median
+
+ def prepare_data(self):
+ """Overload method from UnivariateCommonMixin."""
+ data = super().prepare_data()
+ if len(data)%2 != 1:
+ data.append(2)
+ return data
+
+ def test_even_ints(self):
+ # Test median with an even number of int data points.
+ data = [1, 2, 3, 4, 5, 6]
+ assert len(data)%2 == 0
+ self.assertEqual(self.func(data), 3.5)
+
+ def test_odd_ints(self):
+ # Test median with an odd number of int data points.
+ data = [1, 2, 3, 4, 5, 6, 9]
+ assert len(data)%2 == 1
+ self.assertEqual(self.func(data), 4)
+
+ def test_odd_fractions(self):
+ # Test median works with an odd number of Fractions.
+ F = Fraction
+ data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7)]
+ assert len(data)%2 == 1
+ random.shuffle(data)
+ self.assertEqual(self.func(data), F(3, 7))
+
+ def test_even_fractions(self):
+ # Test median works with an even number of Fractions.
+ F = Fraction
+ data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), F(1, 2))
+
+ def test_odd_decimals(self):
+ # Test median works with an odd number of Decimals.
+ D = Decimal
+ data = [D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')]
+ assert len(data)%2 == 1
+ random.shuffle(data)
+ self.assertEqual(self.func(data), D('4.2'))
+
+ def test_even_decimals(self):
+ # Test median works with an even number of Decimals.
+ D = Decimal
+ data = [D('1.2'), D('2.5'), D('3.1'), D('4.2'), D('5.7'), D('5.8')]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), D('3.65'))
+
+
+class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
+ # Test conservation of data element type for median.
+ def setUp(self):
+ self.func = statistics.median
+
+ def prepare_data(self):
+ data = list(range(15))
+ assert len(data)%2 == 1
+ while data == sorted(data):
+ random.shuffle(data)
+ return data
+
+
+class TestMedianLow(TestMedian, UnivariateTypeMixin):
+ def setUp(self):
+ self.func = statistics.median_low
+
+ def test_even_ints(self):
+ # Test median_low with an even number of ints.
+ data = [1, 2, 3, 4, 5, 6]
+ assert len(data)%2 == 0
+ self.assertEqual(self.func(data), 3)
+
+ def test_even_fractions(self):
+ # Test median_low works with an even number of Fractions.
+ F = Fraction
+ data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), F(3, 7))
+
+ def test_even_decimals(self):
+ # Test median_low works with an even number of Decimals.
+ D = Decimal
+ data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), D('3.3'))
+
+
+class TestMedianHigh(TestMedian, UnivariateTypeMixin):
+ def setUp(self):
+ self.func = statistics.median_high
+
+ def test_even_ints(self):
+ # Test median_high with an even number of ints.
+ data = [1, 2, 3, 4, 5, 6]
+ assert len(data)%2 == 0
+ self.assertEqual(self.func(data), 4)
+
+ def test_even_fractions(self):
+ # Test median_high works with an even number of Fractions.
+ F = Fraction
+ data = [F(1, 7), F(2, 7), F(3, 7), F(4, 7), F(5, 7), F(6, 7)]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), F(4, 7))
+
+ def test_even_decimals(self):
+ # Test median_high works with an even number of Decimals.
+ D = Decimal
+ data = [D('1.1'), D('2.2'), D('3.3'), D('4.4'), D('5.5'), D('6.6')]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), D('4.4'))
+
+
+class TestMedianGrouped(TestMedian):
+ # Test median_grouped.
+ # Doesn't conserve data element types, so don't use TestMedianType.
+ def setUp(self):
+ self.func = statistics.median_grouped
+
+ def test_odd_number_repeated(self):
+ # Test median.grouped with repeated median values.
+ data = [12, 13, 14, 14, 14, 15, 15]
+ assert len(data)%2 == 1
+ self.assertEqual(self.func(data), 14)
+ #---
+ data = [12, 13, 14, 14, 14, 14, 15]
+ assert len(data)%2 == 1
+ self.assertEqual(self.func(data), 13.875)
+ #---
+ data = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30]
+ assert len(data)%2 == 1
+ self.assertEqual(self.func(data, 5), 19.375)
+ #---
+ data = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
+ assert len(data)%2 == 1
+ self.assertApproxEqual(self.func(data, 2), 20.66666667, tol=1e-8)
+
+ def test_even_number_repeated(self):
+ # Test median.grouped with repeated median values.
+ data = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30]
+ assert len(data)%2 == 0
+ self.assertApproxEqual(self.func(data, 5), 19.16666667, tol=1e-8)
+ #---
+ data = [2, 3, 4, 4, 4, 5]
+ assert len(data)%2 == 0
+ self.assertApproxEqual(self.func(data), 3.83333333, tol=1e-8)
+ #---
+ data = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
+ assert len(data)%2 == 0
+ self.assertEqual(self.func(data), 4.5)
+ #---
+ data = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
+ assert len(data)%2 == 0
+ self.assertEqual(self.func(data), 4.75)
+
+ def test_repeated_single_value(self):
+ # Override method from AverageMixin.
+ # Yet again, failure of median_grouped to conserve the data type
+ # causes me headaches :-(
+ for x in (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714')):
+ for count in (2, 5, 10, 20):
+ data = [x]*count
+ self.assertEqual(self.func(data), float(x))
+
+ def test_odd_fractions(self):
+ # Test median_grouped works with an odd number of Fractions.
+ F = Fraction
+ data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4)]
+ assert len(data)%2 == 1
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 3.0)
+
+ def test_even_fractions(self):
+ # Test median_grouped works with an even number of Fractions.
+ F = Fraction
+ data = [F(5, 4), F(9, 4), F(13, 4), F(13, 4), F(17, 4), F(17, 4)]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 3.25)
+
+ def test_odd_decimals(self):
+ # Test median_grouped works with an odd number of Decimals.
+ D = Decimal
+ data = [D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
+ assert len(data)%2 == 1
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 6.75)
+
+ def test_even_decimals(self):
+ # Test median_grouped works with an even number of Decimals.
+ D = Decimal
+ data = [D('5.5'), D('5.5'), D('6.5'), D('6.5'), D('7.5'), D('8.5')]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 6.5)
+ #---
+ data = [D('5.5'), D('5.5'), D('6.5'), D('7.5'), D('7.5'), D('8.5')]
+ assert len(data)%2 == 0
+ random.shuffle(data)
+ self.assertEqual(self.func(data), 7.0)
+
+ def test_interval(self):
+ # Test median_grouped with interval argument.
+ data = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
+ self.assertEqual(self.func(data, 0.25), 2.875)
+ data = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
+ self.assertApproxEqual(self.func(data, 0.25), 2.83333333, tol=1e-8)
+ data = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320, 340]
+ self.assertEqual(self.func(data, 20), 265.0)
+
+
+class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
+ # Test cases for the discrete version of mode.
+ def setUp(self):
+ self.func = statistics.mode
+
+ def prepare_data(self):
+ """Overload method from UnivariateCommonMixin."""
+ # Make sure test data has exactly one mode.
+ return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2]
+
+ def test_range_data(self):
+ # Override test from UnivariateCommonMixin.
+ data = range(20, 50, 3)
+ self.assertRaises(statistics.StatisticsError, self.func, data)
+
+ def test_nominal_data(self):
+ # Test mode with nominal data.
+ data = 'abcbdb'
+ self.assertEqual(self.func(data), 'b')
+ data = 'fe fi fo fum fi fi'.split()
+ self.assertEqual(self.func(data), 'fi')
+
+ def test_discrete_data(self):
+ # Test mode with discrete numeric data.
+ data = list(range(10))
+ for i in range(10):
+ d = data + [i]
+ random.shuffle(d)
+ self.assertEqual(self.func(d), i)
+
+ def test_bimodal_data(self):
+ # Test mode with bimodal data.
+ data = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
+ assert data.count(2) == data.count(6) == 4
+ # Check for an exception.
+ self.assertRaises(statistics.StatisticsError, self.func, data)
+
+ def test_unique_data_failure(self):
+ # Test mode exception when data points are all unique.
+ data = list(range(10))
+ self.assertRaises(statistics.StatisticsError, self.func, data)
+
+ def test_none_data(self):
+ # Test that mode raises TypeError if given None as data.
+
+ # This test is necessary because the implementation of mode uses
+ # collections.Counter, which accepts None and returns an empty dict.
+ self.assertRaises(TypeError, self.func, None)
+
+
+# === Tests for variances and standard deviations ===
+
+class VarianceStdevMixin(UnivariateCommonMixin):
+ # Mixin class holding common tests for variance and std dev.
+
+ # Subclasses should inherit from this before NumericTestClass, in order
+ # to see the rel attribute below. See testShiftData for an explanation.
+
+ rel = 1e-12
+
+ def test_single_value(self):
+ # Deviation of a single value is zero.
+ for x in (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392')):
+ self.assertEqual(self.func([x]), 0)
+
+ def test_repeated_single_value(self):
+ # The deviation of a single repeated value is zero.
+ for x in (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802')):
+ for count in (2, 3, 5, 15):
+ data = [x]*count
+ self.assertEqual(self.func(data), 0)
+
+ def test_domain_error_regression(self):
+ # Regression test for a domain error exception.
+ # (Thanks to Geremy Condra.)
+ data = [0.123456789012345]*10000
+ # All the items are identical, so variance should be exactly zero.
+ # We allow some small round-off error, but not much.
+ result = self.func(data)
+ self.assertApproxEqual(result, 0.0, tol=5e-17)
+ self.assertGreaterEqual(result, 0) # A negative result must fail.
+
+ def test_shift_data(self):
+ # Test that shifting the data by a constant amount does not affect
+ # the variance or stdev. Or at least not much.
+
+ # Due to rounding, this test should be considered an ideal. We allow
+ # some tolerance away from "no change at all" by setting tol and/or rel
+ # attributes. Subclasses may set tighter or looser error tolerances.
+ raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
+ expected = self.func(raw)
+ # Don't set shift too high, the bigger it is, the more rounding error.
+ shift = 1e5
+ data = [x + shift for x in raw]
+ self.assertApproxEqual(self.func(data), expected)
+
+ def test_shift_data_exact(self):
+ # Like test_shift_data, but result is always exact.
+ raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
+ assert all(x==int(x) for x in raw)
+ expected = self.func(raw)
+ shift = 10**9
+ data = [x + shift for x in raw]
+ self.assertEqual(self.func(data), expected)
+
+ def test_iter_list_same(self):
+ # Test that iter data and list data give the same result.
+
+ # This is an explicit test that iterators and lists are treated the
+ # same; justification for this test over and above the similar test
+ # in UnivariateCommonMixin is that an earlier design had variance and
+ # friends swap between one- and two-pass algorithms, which would
+ # sometimes give different results.
+ data = [random.uniform(-3, 8) for _ in range(1000)]
+ expected = self.func(data)
+ self.assertEqual(self.func(iter(data)), expected)
+
+
+class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
+ # Tests for population variance.
+ def setUp(self):
+ self.func = statistics.pvariance
+
+ def test_exact_uniform(self):
+ # Test the variance against an exact result for uniform data.
+ data = list(range(10000))
+ random.shuffle(data)
+ expected = (10000**2 - 1)/12 # Exact value.
+ self.assertEqual(self.func(data), expected)
+
+ def test_ints(self):
+ # Test population variance with int data.
+ data = [4, 7, 13, 16]
+ exact = 22.5
+ self.assertEqual(self.func(data), exact)
+
+ def test_fractions(self):
+ # Test population variance with Fraction data.
+ F = Fraction
+ data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)]
+ exact = F(3, 8)
+ result = self.func(data)
+ self.assertEqual(result, exact)
+ self.assertIsInstance(result, Fraction)
+
+ def test_decimals(self):
+ # Test population variance with Decimal data.
+ D = Decimal
+ data = [D("12.1"), D("12.2"), D("12.5"), D("12.9")]
+ exact = D('0.096875')
+ result = self.func(data)
+ self.assertEqual(result, exact)
+ self.assertIsInstance(result, Decimal)
+
+
+class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
+ # Tests for sample variance.
+ def setUp(self):
+ self.func = statistics.variance
+
+ def test_single_value(self):
+ # Override method from VarianceStdevMixin.
+ for x in (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084')):
+ self.assertRaises(statistics.StatisticsError, self.func, [x])
+
+ def test_ints(self):
+ # Test sample variance with int data.
+ data = [4, 7, 13, 16]
+ exact = 30
+ self.assertEqual(self.func(data), exact)
+
+ def test_fractions(self):
+ # Test sample variance with Fraction data.
+ F = Fraction
+ data = [F(1, 4), F(1, 4), F(3, 4), F(7, 4)]
+ exact = F(1, 2)
+ result = self.func(data)
+ self.assertEqual(result, exact)
+ self.assertIsInstance(result, Fraction)
+
+ def test_decimals(self):
+ # Test sample variance with Decimal data.
+ D = Decimal
+ data = [D(2), D(2), D(7), D(9)]
+ exact = 4*D('9.5')/D(3)
+ result = self.func(data)
+ self.assertEqual(result, exact)
+ self.assertIsInstance(result, Decimal)
+
+
+class TestPStdev(VarianceStdevMixin, NumericTestCase):
+ # Tests for population standard deviation.
+ def setUp(self):
+ self.func = statistics.pstdev
+
+ def test_compare_to_variance(self):
+ # Test that stdev is, in fact, the square root of variance.
+ data = [random.uniform(-17, 24) for _ in range(1000)]
+ expected = math.sqrt(statistics.pvariance(data))
+ self.assertEqual(self.func(data), expected)
+
+
+class TestStdev(VarianceStdevMixin, NumericTestCase):
+ # Tests for sample standard deviation.
+ def setUp(self):
+ self.func = statistics.stdev
+
+ def test_single_value(self):
+ # Override method from VarianceStdevMixin.
+ for x in (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719')):
+ self.assertRaises(statistics.StatisticsError, self.func, [x])
+
+ def test_compare_to_variance(self):
+ # Test that stdev is, in fact, the square root of variance.
+ data = [random.uniform(-2, 9) for _ in range(1000)]
+ expected = math.sqrt(statistics.variance(data))
+ self.assertEqual(self.func(data), expected)
+
+
+# === Run tests ===
+
+def load_tests(loader, tests, ignore):
+ """Used for doctest/unittest integration."""
+ tests.addTests(doctest.DocTestSuite())
+ return tests
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_strftime.py b/Lib/test/test_strftime.py
index 78c9c5bdb9..61215a15ba 100644
--- a/Lib/test/test_strftime.py
+++ b/Lib/test/test_strftime.py
@@ -189,19 +189,16 @@ class Y1900Tests(unittest.TestCase):
@unittest.skipIf(sys.platform == "win32", "Doesn't apply on Windows")
def test_y_before_1900_nonwin(self):
- self.assertEquals(
+ self.assertEqual(
time.strftime("%y", (1899, 1, 1, 0, 0, 0, 0, 0, 0)), "99")
def test_y_1900(self):
- self.assertEquals(
+ self.assertEqual(
time.strftime("%y", (1900, 1, 1, 0, 0, 0, 0, 0, 0)), "00")
def test_y_after_1900(self):
- self.assertEquals(
+ self.assertEqual(
time.strftime("%y", (2013, 1, 1, 0, 0, 0, 0, 0, 0)), "13")
-def test_main():
- support.run_unittest(StrftimeTest, Y1900Tests)
-
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py
index 50cae052f3..0107eebca6 100644
--- a/Lib/test/test_struct.py
+++ b/Lib/test/test_struct.py
@@ -1,4 +1,6 @@
+from collections import abc
import array
+import operator
import unittest
import struct
import sys
@@ -6,7 +8,6 @@ import sys
from test import support
ISBIGENDIAN = sys.byteorder == "big"
-IS32BIT = sys.maxsize == 0x7fffffff
integer_codes = 'b', 'B', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q', 'n', 'N'
byteorders = '', '@', '=', '<', '>', '!'
@@ -489,7 +490,7 @@ class StructTest(unittest.TestCase):
def test_bool(self):
class ExplodingBool(object):
def __bool__(self):
- raise IOError
+ raise OSError
for prefix in tuple("<>!=")+('',):
false = (), [], [], '', 0
true = [1], 'test', 5, -1, 0xffffffff+1, 0xffffffff/2
@@ -520,10 +521,10 @@ class StructTest(unittest.TestCase):
try:
struct.pack(prefix + '?', ExplodingBool())
- except IOError:
+ except OSError:
pass
else:
- self.fail("Expected IOError: struct.pack(%r, "
+ self.fail("Expected OSError: struct.pack(%r, "
"ExplodingBool())" % (prefix + '?'))
for c in [b'\x01', b'\x7f', b'\xff', b'\x0f', b'\xf0']:
@@ -536,10 +537,6 @@ class StructTest(unittest.TestCase):
hugecount2 = '{}b{}H'.format(sys.maxsize//2, sys.maxsize//2)
self.assertRaises(struct.error, struct.calcsize, hugecount2)
- if IS32BIT:
- def test_crasher(self):
- self.assertRaises(MemoryError, struct.pack, "357913941b", "a")
-
def test_trailing_counter(self):
store = array.array('b', b' '*100)
@@ -576,7 +573,7 @@ class StructTest(unittest.TestCase):
# The size of 'PyStructObject'
totalsize = support.calcobjsize('2n3P')
# The size taken up by the 'formatcode' dynamic array
- totalsize += struct.calcsize('P2n0P') * (number_of_codes + 1)
+ totalsize += struct.calcsize('P3n0P') * (number_of_codes + 1)
support.check_sizeof(self, struct.Struct(format_str), totalsize)
@support.cpython_only
@@ -587,14 +584,84 @@ class StructTest(unittest.TestCase):
self.check_sizeof('B' * 1234, 1234)
self.check_sizeof('fd', 2)
self.check_sizeof('xxxxxxxxxxxxxx', 0)
- self.check_sizeof('100H', 100)
+ self.check_sizeof('100H', 1)
self.check_sizeof('187s', 1)
self.check_sizeof('20p', 1)
self.check_sizeof('0s', 1)
self.check_sizeof('0c', 0)
+
+class UnpackIteratorTest(unittest.TestCase):
+ """
+ Tests for iterative unpacking (struct.Struct.iter_unpack).
+ """
+
+ def test_construct(self):
+ def _check_iterator(it):
+ self.assertIsInstance(it, abc.Iterator)
+ self.assertIsInstance(it, abc.Iterable)
+ s = struct.Struct('>ibcp')
+ it = s.iter_unpack(b"")
+ _check_iterator(it)
+ it = s.iter_unpack(b"1234567")
+ _check_iterator(it)
+ # Wrong bytes length
+ with self.assertRaises(struct.error):
+ s.iter_unpack(b"123456")
+ with self.assertRaises(struct.error):
+ s.iter_unpack(b"12345678")
+ # Zero-length struct
+ s = struct.Struct('>')
+ with self.assertRaises(struct.error):
+ s.iter_unpack(b"")
+ with self.assertRaises(struct.error):
+ s.iter_unpack(b"12")
+
+ def test_iterate(self):
+ s = struct.Struct('>IB')
+ b = bytes(range(1, 16))
+ it = s.iter_unpack(b)
+ self.assertEqual(next(it), (0x01020304, 5))
+ self.assertEqual(next(it), (0x06070809, 10))
+ self.assertEqual(next(it), (0x0b0c0d0e, 15))
+ self.assertRaises(StopIteration, next, it)
+ self.assertRaises(StopIteration, next, it)
+
+ def test_arbitrary_buffer(self):
+ s = struct.Struct('>IB')
+ b = bytes(range(1, 11))
+ it = s.iter_unpack(memoryview(b))
+ self.assertEqual(next(it), (0x01020304, 5))
+ self.assertEqual(next(it), (0x06070809, 10))
+ self.assertRaises(StopIteration, next, it)
+ self.assertRaises(StopIteration, next, it)
+
+ def test_length_hint(self):
+ lh = operator.length_hint
+ s = struct.Struct('>IB')
+ b = bytes(range(1, 16))
+ it = s.iter_unpack(b)
+ self.assertEqual(lh(it), 3)
+ next(it)
+ self.assertEqual(lh(it), 2)
+ next(it)
+ self.assertEqual(lh(it), 1)
+ next(it)
+ self.assertEqual(lh(it), 0)
+ self.assertRaises(StopIteration, next, it)
+ self.assertEqual(lh(it), 0)
+
+ def test_module_func(self):
+ # Sanity check for the global struct.iter_unpack()
+ it = struct.iter_unpack('>IB', bytes(range(1, 11)))
+ self.assertEqual(next(it), (0x01020304, 5))
+ self.assertEqual(next(it), (0x06070809, 10))
+ self.assertRaises(StopIteration, next, it)
+ self.assertRaises(StopIteration, next, it)
+
+
def test_main():
- support.run_unittest(StructTest)
+ support.run_unittest(__name__)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index a89e9556c1..353d0eadcc 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -38,7 +38,7 @@ class StructSeqTest(unittest.TestCase):
# os.stat() gives a complicated struct sequence.
st = os.stat(__file__)
rep = repr(st)
- self.assertTrue(rep.startswith(os.name + ".stat_result"))
+ self.assertTrue(rep.startswith("os.stat_result"))
self.assertIn("st_mode=", rep)
self.assertIn("st_ino=", rep)
self.assertIn("st_dev=", rep)
diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py
index 77f1ba3155..e12f593b19 100644
--- a/Lib/test/test_subprocess.py
+++ b/Lib/test/test_subprocess.py
@@ -11,6 +11,7 @@ import errno
import tempfile
import time
import re
+import selectors
import sysconfig
import warnings
import select
@@ -19,10 +20,6 @@ import gc
import textwrap
try:
- import resource
-except ImportError:
- resource = None
-try:
import threading
except ImportError:
threading = None
@@ -162,8 +159,28 @@ class ProcessTestCase(BaseTestCase):
stderr=subprocess.STDOUT)
self.assertIn(b'BDFL', output)
+ def test_check_output_stdin_arg(self):
+ # check_output() can be called with stdin set to a file
+ tf = tempfile.TemporaryFile()
+ self.addCleanup(tf.close)
+ tf.write(b'pear')
+ tf.seek(0)
+ output = subprocess.check_output(
+ [sys.executable, "-c",
+ "import sys; sys.stdout.write(sys.stdin.read().upper())"],
+ stdin=tf)
+ self.assertIn(b'PEAR', output)
+
+ def test_check_output_input_arg(self):
+ # check_output() can be called with input set to a string
+ output = subprocess.check_output(
+ [sys.executable, "-c",
+ "import sys; sys.stdout.write(sys.stdin.read().upper())"],
+ input=b'pear')
+ self.assertIn(b'PEAR', output)
+
def test_check_output_stdout_arg(self):
- # check_output() function stderr redirected to stdout
+ # check_output() refuses to accept 'stdout' argument
with self.assertRaises(ValueError) as c:
output = subprocess.check_output(
[sys.executable, "-c", "print('will not be run')"],
@@ -171,6 +188,20 @@ class ProcessTestCase(BaseTestCase):
self.fail("Expected ValueError when stdout arg supplied.")
self.assertIn('stdout', c.exception.args[0])
+ def test_check_output_stdin_with_input_arg(self):
+ # check_output() refuses to accept 'stdin' with 'input'
+ tf = tempfile.TemporaryFile()
+ self.addCleanup(tf.close)
+ tf.write(b'pear')
+ tf.seek(0)
+ with self.assertRaises(ValueError) as c:
+ output = subprocess.check_output(
+ [sys.executable, "-c", "print('will not be run')"],
+ stdin=tf, input=b'hare')
+ self.fail("Expected ValueError when stdin and input args supplied.")
+ self.assertIn('stdin', c.exception.args[0])
+ self.assertIn('input', c.exception.args[0])
+
def test_check_output_timeout(self):
# check_output() function with timeout arg
with self.assertRaises(subprocess.TimeoutExpired) as c:
@@ -853,8 +884,9 @@ class ProcessTestCase(BaseTestCase):
#
# UTF-16 and UTF-32-BE are sufficient to check both with BOM and
# without, and UTF-16 and UTF-32.
+ import _bootlocale
for encoding in ['utf-16', 'utf-32-be']:
- old_getpreferredencoding = locale.getpreferredencoding
+ old_getpreferredencoding = _bootlocale.getpreferredencoding
# Indirectly via io.TextIOWrapper, Popen() defaults to
# locale.getpreferredencoding(False) and earlier in Python 3.2 to
# locale.getpreferredencoding().
@@ -865,7 +897,7 @@ class ProcessTestCase(BaseTestCase):
encoding)
args = [sys.executable, '-c', code]
try:
- locale.getpreferredencoding = getpreferredencoding
+ _bootlocale.getpreferredencoding = getpreferredencoding
# We set stdin to be non-None because, as of this writing,
# a different code path is used when the number of pipes is
# zero or one.
@@ -874,7 +906,7 @@ class ProcessTestCase(BaseTestCase):
stdout=subprocess.PIPE)
stdout, stderr = popen.communicate(input='')
finally:
- locale.getpreferredencoding = old_getpreferredencoding
+ _bootlocale.getpreferredencoding = old_getpreferredencoding
self.assertEqual(stdout, '1\n2\n3\n4')
def test_no_leaking(self):
@@ -982,8 +1014,7 @@ class ProcessTestCase(BaseTestCase):
# value for that limit, but Windows has 2048, so we loop
# 1024 times (each call leaked two fds).
for i in range(1024):
- # Windows raises IOError. Others raise OSError.
- with self.assertRaises(EnvironmentError) as c:
+ with self.assertRaises(OSError) as c:
subprocess.Popen(['nonexisting_i_hope'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
@@ -1114,47 +1145,6 @@ class ProcessTestCase(BaseTestCase):
fds_after_exception = os.listdir(fd_directory)
self.assertEqual(fds_before_popen, fds_after_exception)
-
-# context manager
-class _SuppressCoreFiles(object):
- """Try to prevent core files from being created."""
- old_limit = None
-
- def __enter__(self):
- """Try to save previous ulimit, then set it to (0, 0)."""
- if resource is not None:
- try:
- self.old_limit = resource.getrlimit(resource.RLIMIT_CORE)
- resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
- except (ValueError, resource.error):
- pass
-
- if sys.platform == 'darwin':
- # Check if the 'Crash Reporter' on OSX was configured
- # in 'Developer' mode and warn that it will get triggered
- # when it is.
- #
- # This assumes that this context manager is used in tests
- # that might trigger the next manager.
- value = subprocess.Popen(['/usr/bin/defaults', 'read',
- 'com.apple.CrashReporter', 'DialogType'],
- stdout=subprocess.PIPE).communicate()[0]
- if value.strip() == b'developer':
- print("this tests triggers the Crash Reporter, "
- "that is intentional", end='')
- sys.stdout.flush()
-
- def __exit__(self, *args):
- """Return core file behavior to default."""
- if self.old_limit is None:
- return
- if resource is not None:
- try:
- resource.setrlimit(resource.RLIMIT_CORE, self.old_limit)
- except (ValueError, resource.error):
- pass
-
-
@unittest.skipIf(mswindows, "POSIX specific tests")
class POSIXProcessTestCase(BaseTestCase):
@@ -1243,7 +1233,7 @@ class POSIXProcessTestCase(BaseTestCase):
def test_run_abort(self):
# returncode handles signal termination
- with _SuppressCoreFiles():
+ with support.SuppressCrashReport():
p = subprocess.Popen([sys.executable, "-c",
'import os; os.abort()'])
p.wait()
@@ -1266,7 +1256,7 @@ class POSIXProcessTestCase(BaseTestCase):
try:
p = subprocess.Popen([sys.executable, "-c", ""],
preexec_fn=raise_it)
- except RuntimeError as e:
+ except subprocess.SubprocessError as e:
self.assertTrue(
subprocess._posixsubprocess,
"Expected a ValueError from the preexec_fn")
@@ -1306,9 +1296,10 @@ class POSIXProcessTestCase(BaseTestCase):
"""Issue16140: Don't double close pipes on preexec error."""
def raise_it():
- raise RuntimeError("force the _execute_child() errpipe_data path.")
+ raise subprocess.SubprocessError(
+ "force the _execute_child() errpipe_data path.")
- with self.assertRaises(RuntimeError):
+ with self.assertRaises(subprocess.SubprocessError):
self._TestExecuteChildPopen(
self, [sys.executable, "-c", "pass"],
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
@@ -1507,16 +1498,28 @@ class POSIXProcessTestCase(BaseTestCase):
# Terminating a dead process
self._kill_dead_process('terminate')
+ def _save_fds(self, save_fds):
+ fds = []
+ for fd in save_fds:
+ inheritable = os.get_inheritable(fd)
+ saved = os.dup(fd)
+ fds.append((fd, saved, inheritable))
+ return fds
+
+ def _restore_fds(self, fds):
+ for fd, saved, inheritable in fds:
+ os.dup2(saved, fd, inheritable=inheritable)
+ os.close(saved)
+
def check_close_std_fds(self, fds):
# Issue #9905: test that subprocess pipes still work properly with
# some standard fds closed
stdin = 0
- newfds = []
- for a in fds:
- b = os.dup(a)
- newfds.append(b)
- if a == 0:
- stdin = b
+ saved_fds = self._save_fds(fds)
+ for fd, saved, inheritable in saved_fds:
+ if fd == 0:
+ stdin = saved
+ break
try:
for fd in fds:
os.close(fd)
@@ -1531,10 +1534,7 @@ class POSIXProcessTestCase(BaseTestCase):
err = support.strip_python_stderr(err)
self.assertEqual((out, err), (b'apple', b'orange'))
finally:
- for b, a in zip(newfds, fds):
- os.dup2(b, a)
- for b in newfds:
- os.close(b)
+ self._restore_fds(saved_fds)
def test_close_fd_0(self):
self.check_close_std_fds([0])
@@ -1574,7 +1574,7 @@ class POSIXProcessTestCase(BaseTestCase):
os.lseek(temp_fds[1], 0, 0)
# move the standard file descriptors out of the way
- saved_fds = [os.dup(fd) for fd in range(3)]
+ saved_fds = self._save_fds(range(3))
try:
# duplicate the file objects over the standard fd's
for fd, temp_fd in enumerate(temp_fds):
@@ -1590,10 +1590,7 @@ class POSIXProcessTestCase(BaseTestCase):
stderr=temp_fds[0])
p.wait()
finally:
- # restore the original fd's underneath sys.stdin, etc.
- for std, saved in enumerate(saved_fds):
- os.dup2(saved, std)
- os.close(saved)
+ self._restore_fds(saved_fds)
for fd in temp_fds:
os.lseek(fd, 0, 0)
@@ -1617,7 +1614,7 @@ class POSIXProcessTestCase(BaseTestCase):
os.unlink(fname)
# save a copy of the standard file descriptors
- saved_fds = [os.dup(fd) for fd in range(3)]
+ saved_fds = self._save_fds(range(3))
try:
# duplicate the temp files over the standard fd's 0, 1, 2
for fd, temp_fd in enumerate(temp_fds):
@@ -1643,9 +1640,7 @@ class POSIXProcessTestCase(BaseTestCase):
out = os.read(stdout_no, 1024)
err = support.strip_python_stderr(os.read(stderr_no, 1024))
finally:
- for std, saved in enumerate(saved_fds):
- os.dup2(saved, std)
- os.close(saved)
+ self._restore_fds(saved_fds)
self.assertEqual(out, b"got STDIN")
self.assertEqual(err, b"err")
@@ -1677,12 +1672,12 @@ class POSIXProcessTestCase(BaseTestCase):
# Pure Python implementations keeps the message
self.assertIsNone(subprocess._posixsubprocess)
self.assertEqual(str(err), "surrogate:\uDCff")
- except RuntimeError as err:
+ except subprocess.SubprocessError as err:
# _posixsubprocess uses a default message
self.assertIsNotNone(subprocess._posixsubprocess)
self.assertEqual(str(err), "Exception occurred in preexec_fn.")
else:
- self.fail("Expected ValueError or RuntimeError")
+ self.fail("Expected ValueError or subprocess.SubprocessError")
def test_undecodable_env(self):
for key, value in (('test', 'abc\uDCFF'), ('test\uDCFF', '42')):
@@ -1816,6 +1811,9 @@ class POSIXProcessTestCase(BaseTestCase):
self.addCleanup(os.close, fd)
open_fds.add(fd)
+ for fd in open_fds:
+ os.set_inheritable(fd, True)
+
p = subprocess.Popen([sys.executable, fd_status],
stdout=subprocess.PIPE, close_fds=False)
output, ignored = p.communicate()
@@ -1860,6 +1858,8 @@ class POSIXProcessTestCase(BaseTestCase):
fds = os.pipe()
self.addCleanup(os.close, fds[0])
self.addCleanup(os.close, fds[1])
+ os.set_inheritable(fds[0], True)
+ os.set_inheritable(fds[1], True)
open_fds.update(fds)
for fd in open_fds:
@@ -1882,6 +1882,32 @@ class POSIXProcessTestCase(BaseTestCase):
close_fds=False, pass_fds=(fd, )))
self.assertIn('overriding close_fds', str(context.warning))
+ def test_pass_fds_inheritable(self):
+ script = support.findfile("fd_status.py", subdir="subprocessdata")
+
+ inheritable, non_inheritable = os.pipe()
+ self.addCleanup(os.close, inheritable)
+ self.addCleanup(os.close, non_inheritable)
+ os.set_inheritable(inheritable, True)
+ os.set_inheritable(non_inheritable, False)
+ pass_fds = (inheritable, non_inheritable)
+ args = [sys.executable, script]
+ args += list(map(str, pass_fds))
+
+ p = subprocess.Popen(args,
+ stdout=subprocess.PIPE, close_fds=True,
+ pass_fds=pass_fds)
+ output, ignored = p.communicate()
+ fds = set(map(int, output.split(b',')))
+
+ # the inheritable file descriptor must be inherited, so its inheritable
+ # flag must be set in the child process after fork() and before exec()
+ self.assertEqual(fds, set(pass_fds), "output=%a" % output)
+
+ # inheritable flag must not be changed in the parent process
+ self.assertEqual(os.get_inheritable(inheritable), True)
+ self.assertEqual(os.get_inheritable(non_inheritable), False)
+
def test_stdout_stdin_are_single_inout_fd(self):
with io.open(os.devnull, "r+") as inout:
p = subprocess.Popen([sys.executable, "-c", "import sys; sys.exit(0)"],
@@ -1969,7 +1995,7 @@ class POSIXProcessTestCase(BaseTestCase):
# let some time for the process to exit, and create a new Popen: this
# should trigger the wait() of p
time.sleep(0.2)
- with self.assertRaises(EnvironmentError) as c:
+ with self.assertRaises(OSError) as c:
with subprocess.Popen(['nonexisting_i_hope'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as proc:
@@ -2154,15 +2180,16 @@ class CommandTests(unittest.TestCase):
os.rmdir(dir)
-@unittest.skipUnless(getattr(subprocess, '_has_poll', False),
- "poll system call not supported")
+@unittest.skipUnless(hasattr(selectors, 'PollSelector'),
+ "Test needs selectors.PollSelector")
class ProcessTestCaseNoPoll(ProcessTestCase):
def setUp(self):
- subprocess._has_poll = False
+ self.orig_selector = subprocess._PopenSelector
+ subprocess._PopenSelector = selectors.SelectSelector
ProcessTestCase.setUp(self)
def tearDown(self):
- subprocess._has_poll = True
+ subprocess._PopenSelector = self.orig_selector
ProcessTestCase.tearDown(self)
diff --git a/Lib/test/test_sunau.py b/Lib/test/test_sunau.py
index 767314f5b2..81acd964c2 100644
--- a/Lib/test/test_sunau.py
+++ b/Lib/test/test_sunau.py
@@ -47,6 +47,34 @@ class SunauPCM16Test(audiotests.AudioWriteTests,
""")
+class SunauPCM24Test(audiotests.AudioWriteTests,
+ audiotests.AudioTestsWithSourceFile,
+ unittest.TestCase):
+ module = sunau
+ sndfilename = 'pluck-pcm24.au'
+ sndfilenframes = 3307
+ nchannels = 2
+ sampwidth = 3
+ framerate = 11025
+ nframes = 48
+ comptype = 'NONE'
+ compname = 'not compressed'
+ frames = bytes.fromhex("""\
+ 022D65FFEB9D 4B5A0F00FA54 3113C304EE2B 80DCD6084303 \
+ CBDEC006B261 48A99803F2F8 BFE82401B07D 036BFBFE7B5D \
+ B85756FA3EC9 B4B055F3502B 299830EBCB62 1A5CA7E6D99A \
+ EDFA3EE491BD C625EBE27884 0E05A9E0B6CF EF2929E02922 \
+ 5758D8E27067 FB3557E83E16 1377BFEF8402 D82C5BF7272A \
+ 978F16FB7745 F5F865FC1013 086635FB9C4E DF30FCFB40EE \
+ 117FE0FA3438 3EE6B8FB5AC3 BC77A3FCB2F4 66D6DAFF5F32 \
+ CF13B9041275 431D69097A8C C1BB600EC74E 5120B912A2BA \
+ EEDF641754C0 8207001664B7 7FFFFF14453F 8000001294E6 \
+ 499C1B0EB3B2 52B73E0DBCA0 EFB2B20F5FD8 CE3CDB0FBE12 \
+ E4B49C0CEA2D 6344A80A5A7C 08C8FE0A1FFE 2BB9860B0A0E \
+ 51486F0E44E1 8BCC64113B05 B6F4EC0EEB36 4413170A5B48 \
+ """)
+
+
class SunauPCM32Test(audiotests.AudioWriteTests,
audiotests.AudioTestsWithSourceFile,
unittest.TestCase):
diff --git a/Lib/test/test_sundry.py b/Lib/test/test_sundry.py
index e08cf0179d..2da3ac05dc 100644
--- a/Lib/test/test_sundry.py
+++ b/Lib/test/test_sundry.py
@@ -1,19 +1,26 @@
"""Do a minimal test of all the modules that aren't otherwise tested."""
-
-from test import support
+import importlib
import sys
+from test import support
import unittest
class TestUntestedModules(unittest.TestCase):
- def test_at_least_import_untested_modules(self):
+ def test_untested_modules_can_be_imported(self):
+ untested = ('bdb', 'encodings', 'formatter', 'imghdr',
+ 'nturl2path', 'tabnanny')
with support.check_warnings(quiet=True):
- import bdb
- import cgitb
+ for name in untested:
+ try:
+ support.import_module('test.test_{}'.format(name))
+ except unittest.SkipTest:
+ importlib.import_module(name)
+ else:
+ self.fail('{} has tests even though test_sundry claims '
+ 'otherwise'.format(name))
import distutils.bcppcompiler
import distutils.ccompiler
import distutils.cygwinccompiler
- import distutils.emxccompiler
import distutils.filelist
if sys.platform.startswith('win'):
import distutils.msvccompiler
@@ -39,28 +46,14 @@ class TestUntestedModules(unittest.TestCase):
import distutils.command.sdist
import distutils.command.upload
- import encodings
- import formatter
- import getpass
import html.entities
- import imghdr
- import keyword
- import mailcap
- import nturl2path
- import os2emxpath
- import pstats
- import py_compile
- import sndhdr
- import tabnanny
+
try:
- import tty # not available on Windows
+ import tty # Not available on Windows
except ImportError:
if support.verbose:
print("skipping tty")
-def test_main():
- support.run_unittest(TestUntestedModules)
-
if __name__ == "__main__":
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py
index 1e272ee2cf..37fc2d9134 100644
--- a/Lib/test/test_super.py
+++ b/Lib/test/test_super.py
@@ -44,6 +44,11 @@ class G(A):
class TestSuper(unittest.TestCase):
+ def tearDown(self):
+ # This fixes the damage that test_various___class___pathologies does.
+ nonlocal __class__
+ __class__ = TestSuper
+
def test_basics_working(self):
self.assertEqual(D().f(), 'ABCD')
@@ -81,8 +86,7 @@ class TestSuper(unittest.TestCase):
self.assertEqual(E().f(), 'AE')
- @unittest.expectedFailure
- def test___class___set(self):
+ def test_various___class___pathologies(self):
# See issue #12370
class X(A):
def f(self):
@@ -91,6 +95,31 @@ class TestSuper(unittest.TestCase):
x = X()
self.assertEqual(x.f(), 'A')
self.assertEqual(x.__class__, 413)
+ class X:
+ x = __class__
+ def f():
+ __class__
+ self.assertIs(X.x, type(self))
+ with self.assertRaises(NameError) as e:
+ exec("""class X:
+ __class__
+ def f():
+ __class__""", globals(), {})
+ self.assertIs(type(e.exception), NameError) # Not UnboundLocalError
+ class X:
+ global __class__
+ __class__ = 42
+ def f():
+ __class__
+ self.assertEqual(globals()["__class__"], 42)
+ del globals()["__class__"]
+ self.assertNotIn("__class__", X.__dict__)
+ class X:
+ nonlocal __class__
+ __class__ = 42
+ def f():
+ __class__
+ self.assertEqual(__class__, 42)
def test___class___instancemethod(self):
# See issue #14857
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 4edb1a8e01..16b660bce1 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -306,6 +306,7 @@ class TestSupport(unittest.TestCase):
# args_from_interpreter_flags
# can_symlink
# skip_unless_symlink
+ # SuppressCrashReport
def test_main():
diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py
index 5926b69c93..a9d3628819 100644
--- a/Lib/test/test_syntax.py
+++ b/Lib/test/test_syntax.py
@@ -33,7 +33,7 @@ SyntaxError: invalid syntax
>>> None = 1
Traceback (most recent call last):
-SyntaxError: assignment to keyword
+SyntaxError: can't assign to keyword
It's a syntax error to assign to the empty tuple. Why isn't it an
error to assign to the empty list? It will always raise some error at
@@ -233,7 +233,7 @@ Traceback (most recent call last):
SyntaxError: can't assign to generator expression
>>> None += 1
Traceback (most recent call last):
-SyntaxError: assignment to keyword
+SyntaxError: can't assign to keyword
>>> f() += 1
Traceback (most recent call last):
SyntaxError: can't assign to function call
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index acf6d364f2..0565f39e2c 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -6,6 +6,8 @@ import textwrap
import warnings
import operator
import codecs
+import gc
+import sysconfig
# count the number of test runs, used to create unique
# strings to intern in test_intern()
@@ -248,7 +250,7 @@ class SysModuleTest(unittest.TestCase):
sys.setrecursionlimit(%d)
f()""")
- with test.support.suppress_crash_popup():
+ with test.support.SuppressCrashReport():
for i in (50, 1000):
sub = subprocess.Popen([sys.executable, '-c', code % i],
stderr=subprocess.PIPE)
@@ -481,7 +483,7 @@ class SysModuleTest(unittest.TestCase):
def test_thread_info(self):
info = sys.thread_info
self.assertEqual(len(info), 3)
- self.assertIn(info.name, ('nt', 'os2', 'pthread', 'solaris', None))
+ self.assertIn(info.name, ('nt', 'pthread', 'solaris', None))
self.assertIn(info.lock, ('semaphore', 'mutex+cond', None))
def test_43581(self):
@@ -514,7 +516,7 @@ class SysModuleTest(unittest.TestCase):
attrs = ("debug",
"inspect", "interactive", "optimize", "dont_write_bytecode",
"no_user_site", "no_site", "ignore_environment", "verbose",
- "bytes_warning", "quiet", "hash_randomization")
+ "bytes_warning", "quiet", "hash_randomization", "isolated")
for attr in attrs:
self.assertTrue(hasattr(sys.flags, attr), attr)
self.assertEqual(type(getattr(sys.flags, attr)), int, attr)
@@ -543,6 +545,42 @@ class SysModuleTest(unittest.TestCase):
out = p.communicate()[0].strip()
self.assertEqual(out, b'?')
+ env["PYTHONIOENCODING"] = "ascii"
+ p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ env=env)
+ out, err = p.communicate()
+ self.assertEqual(out, b'')
+ self.assertIn(b'UnicodeEncodeError:', err)
+ self.assertIn(rb"'\xa2'", err)
+
+ env["PYTHONIOENCODING"] = "ascii:"
+ p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xa2))'],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ env=env)
+ out, err = p.communicate()
+ self.assertEqual(out, b'')
+ self.assertIn(b'UnicodeEncodeError:', err)
+ self.assertIn(rb"'\xa2'", err)
+
+ env["PYTHONIOENCODING"] = ":surrogateescape"
+ p = subprocess.Popen([sys.executable, "-c", 'print(chr(0xdcbd))'],
+ stdout=subprocess.PIPE, env=env)
+ out = p.communicate()[0].strip()
+ self.assertEqual(out, b'\xbd')
+
+ @unittest.skipUnless(test.support.FS_NONASCII,
+ 'requires OS support of non-ASCII encodings')
+ def test_ioencoding_nonascii(self):
+ env = dict(os.environ)
+
+ env["PYTHONIOENCODING"] = ""
+ p = subprocess.Popen([sys.executable, "-c",
+ 'print(%a)' % test.support.FS_NONASCII],
+ stdout=subprocess.PIPE, env=env)
+ out = p.communicate()[0].strip()
+ self.assertEqual(out, os.fsencode(test.support.FS_NONASCII))
+
@unittest.skipIf(sys.base_prefix != sys.prefix,
'Test is not venv-compatible')
def test_executable(self):
@@ -610,6 +648,36 @@ class SysModuleTest(unittest.TestCase):
ret, out, err = assert_python_ok(*args)
self.assertIn(b"free PyDictObjects", err)
+ @unittest.skipUnless(hasattr(sys, "getallocatedblocks"),
+ "sys.getallocatedblocks unavailable on this build")
+ def test_getallocatedblocks(self):
+ # Some sanity checks
+ with_pymalloc = sysconfig.get_config_var('WITH_PYMALLOC')
+ a = sys.getallocatedblocks()
+ self.assertIs(type(a), int)
+ if with_pymalloc:
+ self.assertGreater(a, 0)
+ else:
+ # When WITH_PYMALLOC isn't available, we don't know anything
+ # about the underlying implementation: the function might
+ # return 0 or something greater.
+ self.assertGreaterEqual(a, 0)
+ try:
+ # While we could imagine a Python session where the number of
+ # multiple buffer objects would exceed the sharing of references,
+ # it is unlikely to happen in a normal test run.
+ self.assertLess(a, sys.gettotalrefcount())
+ except AttributeError:
+ # gettotalrefcount() not available
+ pass
+ gc.collect()
+ b = sys.getallocatedblocks()
+ self.assertLessEqual(b, a)
+ gc.collect()
+ c = sys.getallocatedblocks()
+ self.assertIn(c, range(b - 50, b + 50))
+
+
class SizeofTest(unittest.TestCase):
def setUp(self):
@@ -654,7 +722,7 @@ class SizeofTest(unittest.TestCase):
samples = [b'', b'u'*100000]
for sample in samples:
x = bytearray(sample)
- check(x, vsize('inP') + x.__alloc__())
+ check(x, vsize('n2Pi') + x.__alloc__())
# bytearray_iterator
check(iter(bytearray()), size('nP'))
# cell
@@ -733,7 +801,7 @@ class SizeofTest(unittest.TestCase):
nfrees = len(x.f_code.co_freevars)
extras = x.f_code.co_stacksize + x.f_code.co_nlocals +\
ncells + nfrees - 1
- check(x, vsize('12P3i' + CO_MAXBLOCKS*'3i' + 'P' + extras*'P'))
+ check(x, vsize('13P3ic' + CO_MAXBLOCKS*'3i' + 'P' + extras*'P'))
# function
def func(): pass
check(func, size('12P'))
@@ -779,7 +847,7 @@ class SizeofTest(unittest.TestCase):
# memoryview
check(memoryview(b''), size('Pnin 2P2n2i5P 3cPn'))
# module
- check(unittest, size('PnP'))
+ check(unittest, size('PnPPP'))
# None
check(None, size(''))
# NotImplementedType
@@ -833,11 +901,11 @@ class SizeofTest(unittest.TestCase):
check((1,2,3), vsize('') + 3*self.P)
# type
# static type: PyTypeObject
- s = vsize('P2n15Pl4Pn9Pn11PI')
+ s = vsize('P2n15Pl4Pn9Pn11PIP')
check(int, s)
# (PyTypeObject + PyNumberMethods + PyMappingMethods +
# PySequenceMethods + PyBufferProcs + 4P)
- s = vsize('P2n15Pl4Pn9Pn11PI') + struct.calcsize('34P 3P 10P 2P 4P')
+ s = vsize('P2n15Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P')
# Separate block for PyDictKeysObject with 4 entries
s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P")
# class
diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py
index 03f67fd476..29686040ea 100644
--- a/Lib/test/test_sysconfig.py
+++ b/Lib/test/test_sysconfig.py
@@ -234,7 +234,7 @@ class TestSysConfig(unittest.TestCase):
self.assertTrue(os.path.isfile(config_h), config_h)
def test_get_scheme_names(self):
- wanted = ('nt', 'nt_user', 'os2', 'os2_home', 'osx_framework_user',
+ wanted = ('nt', 'nt_user', 'osx_framework_user',
'posix_home', 'posix_prefix', 'posix_user')
self.assertEqual(get_scheme_names(), wanted)
@@ -305,14 +305,13 @@ class TestSysConfig(unittest.TestCase):
if 'MACOSX_DEPLOYMENT_TARGET' in env:
del env['MACOSX_DEPLOYMENT_TARGET']
- with open('/dev/null', 'w') as devnull_fp:
- p = subprocess.Popen([
- sys.executable, '-c',
- 'import sysconfig; print(sysconfig.get_platform())',
- ],
- stdout=subprocess.PIPE,
- stderr=devnull_fp,
- env=env)
+ p = subprocess.Popen([
+ sys.executable, '-c',
+ 'import sysconfig; print(sysconfig.get_platform())',
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ env=env)
test_platform = p.communicate()[0].strip()
test_platform = test_platform.decode('utf-8')
status = p.wait()
@@ -325,20 +324,19 @@ class TestSysConfig(unittest.TestCase):
env = os.environ.copy()
env['MACOSX_DEPLOYMENT_TARGET'] = '10.1'
- with open('/dev/null') as dev_null:
- p = subprocess.Popen([
- sys.executable, '-c',
- 'import sysconfig; print(sysconfig.get_platform())',
- ],
- stdout=subprocess.PIPE,
- stderr=dev_null,
- env=env)
- test_platform = p.communicate()[0].strip()
- test_platform = test_platform.decode('utf-8')
- status = p.wait()
-
- self.assertEqual(status, 0)
- self.assertEqual(my_platform, test_platform)
+ p = subprocess.Popen([
+ sys.executable, '-c',
+ 'import sysconfig; print(sysconfig.get_platform())',
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ env=env)
+ test_platform = p.communicate()[0].strip()
+ test_platform = test_platform.decode('utf-8')
+ status = p.wait()
+
+ self.assertEqual(status, 0)
+ self.assertEqual(my_platform, test_platform)
def test_srcdir(self):
# See Issues #15322, #15364.
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index 238175ff3c..2f2bf6b8be 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -1732,20 +1732,20 @@ class ContextManagerTest(unittest.TestCase):
self.assertTrue(tar.closed, "context manager failed")
def test_closed(self):
- # The __enter__() method is supposed to raise IOError
+ # The __enter__() method is supposed to raise OSError
# if the TarFile object is already closed.
tar = tarfile.open(tarname)
tar.close()
- with self.assertRaises(IOError):
+ with self.assertRaises(OSError):
with tar:
pass
def test_exception(self):
- # Test if the IOError exception is passed through properly.
+ # Test if the OSError exception is passed through properly.
with self.assertRaises(Exception) as exc:
with tarfile.open(tarname) as tar:
- raise IOError
- self.assertIsInstance(exc.exception, IOError,
+ raise OSError
+ self.assertIsInstance(exc.exception, OSError,
"wrong exception raised in context manager")
self.assertTrue(tar.closed, "context manager failed")
diff --git a/Lib/test/test_tcl.py b/Lib/test/test_tcl.py
index cf1fcf9812..cf717d8ce5 100644
--- a/Lib/test/test_tcl.py
+++ b/Lib/test/test_tcl.py
@@ -254,40 +254,6 @@ class TclTest(unittest.TestCase):
for arg, res in testcases:
self.assertEqual(split(arg), res, msg=arg)
- def test_merge(self):
- with support.check_warnings(('merge is deprecated',
- DeprecationWarning)):
- merge = self.interp.tk.merge
- call = self.interp.tk.call
- testcases = [
- ((), ''),
- (('a',), 'a'),
- ((2,), '2'),
- (('',), '{}'),
- ('{', '\\{'),
- (('a', 'b', 'c'), 'a b c'),
- ((' ', '\t', '\r', '\n'), '{ } {\t} {\r} {\n}'),
- (('a', ' ', 'c'), 'a { } c'),
- (('a', '€'), 'a €'),
- (('a', '\U000104a2'), 'a \U000104a2'),
- (('a', b'\xe2\x82\xac'), 'a €'),
- (('a', ('b', 'c')), 'a {b c}'),
- (('a', 2), 'a 2'),
- (('a', 3.4), 'a 3.4'),
- (('a', (2, 3.4)), 'a {2 3.4}'),
- ((), ''),
- ((call('list', 1, '2', (3.4,)),), '{1 2 3.4}'),
- ]
- if tcl_version >= (8, 5):
- testcases += [
- ((call('dict', 'create', 12, '\u20ac', b'\xe2\x82\xac', (3.4,)),),
- '{12 € € 3.4}'),
- ]
- for args, res in testcases:
- self.assertEqual(merge(*args), res, msg=args)
- self.assertRaises(UnicodeDecodeError, merge, b'\x80')
- self.assertRaises(UnicodeEncodeError, merge, '\udc80')
-
class BigmemTclTest(unittest.TestCase):
diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py
index c9f2ccb462..ba33064623 100644
--- a/Lib/test/test_telnetlib.py
+++ b/Lib/test/test_telnetlib.py
@@ -1,10 +1,9 @@
import socket
-import select
+import selectors
import telnetlib
import time
import contextlib
-import unittest
from unittest import TestCase
from test import support
threading = support.import_module('threading')
@@ -112,40 +111,32 @@ class TelnetAlike(telnetlib.Telnet):
self._messages += out.getvalue()
return
-def mock_select(*s_args):
- block = False
- for l in s_args:
- for fob in l:
- if isinstance(fob, TelnetAlike):
- block = fob.sock.block
- if block:
- return [[], [], []]
- else:
- return s_args
-
-class MockPoller(object):
- test_case = None # Set during TestCase setUp.
+class MockSelector(selectors.BaseSelector):
def __init__(self):
- self._file_objs = []
+ super().__init__()
+ self.keys = {}
+
+ def register(self, fileobj, events, data=None):
+ key = selectors.SelectorKey(fileobj, 0, events, data)
+ self.keys[fileobj] = key
+ return key
- def register(self, fd, eventmask):
- self.test_case.assertTrue(hasattr(fd, 'fileno'), fd)
- self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI)
- self._file_objs.append(fd)
+ def unregister(self, fileobj):
+ key = self.keys.pop(fileobj)
+ return key
- def poll(self, timeout=None):
+ def select(self, timeout=None):
block = False
- for fob in self._file_objs:
- if isinstance(fob, TelnetAlike):
- block = fob.sock.block
+ for fileobj in self.keys:
+ if isinstance(fileobj, TelnetAlike):
+ block = fileobj.sock.block
+ break
if block:
return []
else:
- return zip(self._file_objs, [select.POLLIN]*len(self._file_objs))
+ return [(key, key.events) for key in self.keys.values()]
- def unregister(self, fd):
- self._file_objs.remove(fd)
@contextlib.contextmanager
def test_socket(reads):
@@ -159,7 +150,7 @@ def test_socket(reads):
socket.create_connection = old_conn
return
-def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
+def test_telnet(reads=(), cls=TelnetAlike):
''' return a telnetlib.Telnet object that uses a SocketStub with
reads queued up to be read '''
for x in reads:
@@ -167,29 +158,14 @@ def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
with test_socket(reads):
telnet = cls('dummy', 0)
telnet._messages = '' # debuglevel output
- if use_poll is not None:
- if use_poll and not telnet._has_poll:
- raise unittest.SkipTest('select.poll() required.')
- telnet._has_poll = use_poll
return telnet
-
class ExpectAndReadTestCase(TestCase):
def setUp(self):
- self.old_select = select.select
- select.select = mock_select
- self.old_poll = False
- if hasattr(select, 'poll'):
- self.old_poll = select.poll
- select.poll = MockPoller
- MockPoller.test_case = self
-
+ self.old_selector = telnetlib._TelnetSelector
+ telnetlib._TelnetSelector = MockSelector
def tearDown(self):
- if self.old_poll:
- MockPoller.test_case = None
- select.poll = self.old_poll
- select.select = self.old_select
-
+ telnetlib._TelnetSelector = self.old_selector
class ReadTests(ExpectAndReadTestCase):
def test_read_until(self):
@@ -208,22 +184,6 @@ class ReadTests(ExpectAndReadTestCase):
data = telnet.read_until(b'match')
self.assertEqual(data, expect)
- def test_read_until_with_poll(self):
- """Use select.poll() to implement telnet.read_until()."""
- want = [b'x' * 10, b'match', b'y' * 10]
- telnet = test_telnet(want, use_poll=True)
- select.select = lambda *_: self.fail('unexpected select() call.')
- data = telnet.read_until(b'match')
- self.assertEqual(data, b''.join(want[:-1]))
-
- def test_read_until_with_select(self):
- """Use select.select() to implement telnet.read_until()."""
- want = [b'x' * 10, b'match', b'y' * 10]
- telnet = test_telnet(want, use_poll=False)
- if self.old_poll:
- select.poll = lambda *_: self.fail('unexpected poll() call.')
- data = telnet.read_until(b'match')
- self.assertEqual(data, b''.join(want[:-1]))
def test_read_all(self):
"""
@@ -427,23 +387,6 @@ class ExpectTests(ExpectAndReadTestCase):
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
- def test_expect_with_poll(self):
- """Use select.poll() to implement telnet.expect()."""
- want = [b'x' * 10, b'match', b'y' * 10]
- telnet = test_telnet(want, use_poll=True)
- select.select = lambda *_: self.fail('unexpected select() call.')
- (_,_,data) = telnet.expect([b'match'])
- self.assertEqual(data, b''.join(want[:-1]))
-
- def test_expect_with_select(self):
- """Use select.select() to implement telnet.expect()."""
- want = [b'x' * 10, b'match', b'y' * 10]
- telnet = test_telnet(want, use_poll=False)
- if self.old_poll:
- select.poll = lambda *_: self.fail('unexpected poll() call.')
- (_,_,data) = telnet.expect([b'match'])
- self.assertEqual(data, b''.join(want[:-1]))
-
def test_main(verbose=None):
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index 4ff56b8bf6..ac4d8609df 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -36,7 +36,7 @@ else:
# Common functionality.
class BaseTestCase(unittest.TestCase):
- str_check = re.compile(r"[a-zA-Z0-9_-]{6}$")
+ str_check = re.compile(r"^[a-z0-9_-]{8}$")
def setUp(self):
self._warnings_manager = support.check_warnings()
@@ -63,7 +63,7 @@ class BaseTestCase(unittest.TestCase):
nbase = nbase[len(pre):len(nbase)-len(suf)]
self.assertTrue(self.str_check.match(nbase),
- "random string '%s' does not match /^[a-zA-Z0-9_-]{6}$/"
+ "random string '%s' does not match ^[a-z0-9_-]{8}$"
% nbase)
@@ -152,7 +152,7 @@ class TestRandomNameSequence(BaseTestCase):
# via any bugs above
try:
os.kill(pid, signal.SIGKILL)
- except EnvironmentError:
+ except OSError:
pass
os.close(read_fd)
os.close(write_fd)
@@ -191,7 +191,7 @@ class TestCandidateTempdirList(BaseTestCase):
try:
dirname = os.getcwd()
- except (AttributeError, os.error):
+ except (AttributeError, OSError):
dirname = os.curdir
self.assertIn(dirname, cand)
@@ -332,7 +332,7 @@ class TestMkstempInner(BaseTestCase):
file = self.do_create()
mode = stat.S_IMODE(os.stat(file.name).st_mode)
expected = 0o600
- if sys.platform in ('win32', 'os2emx'):
+ if sys.platform == 'win32':
# There's no distinction among 'user', 'group' and 'world';
# replicate the 'user' bits.
user = expected >> 6
@@ -350,6 +350,7 @@ class TestMkstempInner(BaseTestCase):
v="q"
file = self.do_create()
+ self.assertEqual(os.get_inheritable(file.fd), False)
fd = "%d" % file.fd
try:
@@ -366,7 +367,7 @@ class TestMkstempInner(BaseTestCase):
# On Windows a spawn* /path/ with embedded spaces shouldn't be quoted,
# but an arg with embedded spaces should be decorated with double
# quotes on each end
- if sys.platform in ('win32',):
+ if sys.platform == 'win32':
decorated = '"%s"' % sys.executable
tester = '"%s"' % tester
else:
@@ -477,6 +478,20 @@ class TestGetTempDir(BaseTestCase):
self.assertTrue(a is b)
+ def test_case_sensitive(self):
+ # gettempdir should not flatten its case
+ # even on a case-insensitive file system
+ case_sensitive_tempdir = tempfile.mkdtemp("-Temp")
+ _tempdir, tempfile.tempdir = tempfile.tempdir, None
+ try:
+ with support.EnvironmentVarGuard() as env:
+ # Fake the first env var which is checked as a candidate
+ env["TMPDIR"] = case_sensitive_tempdir
+ self.assertEqual(tempfile.gettempdir(), case_sensitive_tempdir)
+ finally:
+ tempfile.tempdir = _tempdir
+ support.rmdir(case_sensitive_tempdir)
+
class TestMkstemp(BaseTestCase):
"""Test mkstemp()."""
@@ -566,7 +581,7 @@ class TestMkdtemp(BaseTestCase):
mode = stat.S_IMODE(os.stat(dir).st_mode)
mode &= 0o777 # Mask off sticky bits inherited from /tmp
expected = 0o700
- if sys.platform in ('win32', 'os2emx'):
+ if sys.platform == 'win32':
# There's no distinction among 'user', 'group' and 'world';
# replicate the 'user' bits.
user = expected >> 6
diff --git a/Lib/test/test_textwrap.py b/Lib/test/test_textwrap.py
index c86f5cfae8..1bba77eb9e 100644
--- a/Lib/test/test_textwrap.py
+++ b/Lib/test/test_textwrap.py
@@ -9,9 +9,8 @@
#
import unittest
-from test import support
-from textwrap import TextWrapper, wrap, fill, dedent, indent
+from textwrap import TextWrapper, wrap, fill, dedent, indent, shorten
class BaseTestCase(unittest.TestCase):
@@ -430,6 +429,90 @@ What a mess!
self.check_wrap(text, 7, ["aa \xe4\xe4-", "\xe4\xe4"])
+class MaxLinesTestCase(BaseTestCase):
+ text = "Hello there, how are you this fine day? I'm glad to hear it!"
+
+ def test_simple(self):
+ self.check_wrap(self.text, 12,
+ ["Hello [...]"],
+ max_lines=0)
+ self.check_wrap(self.text, 12,
+ ["Hello [...]"],
+ max_lines=1)
+ self.check_wrap(self.text, 12,
+ ["Hello there,",
+ "how [...]"],
+ max_lines=2)
+ self.check_wrap(self.text, 13,
+ ["Hello there,",
+ "how are [...]"],
+ max_lines=2)
+ self.check_wrap(self.text, 80, [self.text], max_lines=1)
+ self.check_wrap(self.text, 12,
+ ["Hello there,",
+ "how are you",
+ "this fine",
+ "day? I'm",
+ "glad to hear",
+ "it!"],
+ max_lines=6)
+
+ def test_spaces(self):
+ # strip spaces before placeholder
+ self.check_wrap(self.text, 12,
+ ["Hello there,",
+ "how are you",
+ "this fine",
+ "day? [...]"],
+ max_lines=4)
+ # placeholder at the start of line
+ self.check_wrap(self.text, 6,
+ ["Hello",
+ "[...]"],
+ max_lines=2)
+ # final spaces
+ self.check_wrap(self.text + ' ' * 10, 12,
+ ["Hello there,",
+ "how are you",
+ "this fine",
+ "day? I'm",
+ "glad to hear",
+ "it!"],
+ max_lines=6)
+
+ def test_placeholder(self):
+ self.check_wrap(self.text, 12,
+ ["Hello..."],
+ max_lines=1,
+ placeholder='...')
+ self.check_wrap(self.text, 12,
+ ["Hello there,",
+ "how are..."],
+ max_lines=2,
+ placeholder='...')
+ # long placeholder and indentation
+ with self.assertRaises(ValueError):
+ wrap(self.text, 16, initial_indent=' ',
+ max_lines=1, placeholder=' [truncated]...')
+ with self.assertRaises(ValueError):
+ wrap(self.text, 16, subsequent_indent=' ',
+ max_lines=2, placeholder=' [truncated]...')
+ self.check_wrap(self.text, 16,
+ [" Hello there,",
+ " [truncated]..."],
+ max_lines=2,
+ initial_indent=' ',
+ subsequent_indent=' ',
+ placeholder=' [truncated]...')
+ self.check_wrap(self.text, 16,
+ [" [truncated]..."],
+ max_lines=1,
+ initial_indent=' ',
+ subsequent_indent=' ',
+ placeholder=' [truncated]...')
+ self.check_wrap(self.text, 80, [self.text], placeholder='.' * 1000)
+
+
class LongWordTestCase (BaseTestCase):
def setUp(self):
self.wrapper = TextWrapper()
@@ -490,6 +573,14 @@ How *do* you spell that odd word, anyways?
result = wrap(self.text, width=30, break_long_words=0)
self.check(result, expect)
+ def test_max_lines_long(self):
+ self.check_wrap(self.text, 12,
+ ['Did you say ',
+ '"supercalifr',
+ 'agilisticexp',
+ '[...]'],
+ max_lines=4)
+
class IndentTestCases(BaseTestCase):
@@ -777,12 +868,62 @@ class IndentTestCase(unittest.TestCase):
self.assertEqual(indent(text, prefix, predicate), expect)
-def test_main():
- support.run_unittest(WrapTestCase,
- LongWordTestCase,
- IndentTestCases,
- DedentTestCase,
- IndentTestCase)
+class ShortenTestCase(BaseTestCase):
+
+ def check_shorten(self, text, width, expect, **kwargs):
+ result = shorten(text, width, **kwargs)
+ self.check(result, expect)
+
+ def test_simple(self):
+ # Simple case: just words, spaces, and a bit of punctuation
+ text = "Hello there, how are you this fine day? I'm glad to hear it!"
+
+ self.check_shorten(text, 18, "Hello there, [...]")
+ self.check_shorten(text, len(text), text)
+ self.check_shorten(text, len(text) - 1,
+ "Hello there, how are you this fine day? "
+ "I'm glad to [...]")
+
+ def test_placeholder(self):
+ text = "Hello there, how are you this fine day? I'm glad to hear it!"
+
+ self.check_shorten(text, 17, "Hello there,$$", placeholder='$$')
+ self.check_shorten(text, 18, "Hello there, how$$", placeholder='$$')
+ self.check_shorten(text, 18, "Hello there, $$", placeholder=' $$')
+ self.check_shorten(text, len(text), text, placeholder='$$')
+ self.check_shorten(text, len(text) - 1,
+ "Hello there, how are you this fine day? "
+ "I'm glad to hear$$", placeholder='$$')
+
+ def test_empty_string(self):
+ self.check_shorten("", 6, "")
+
+ def test_whitespace(self):
+ # Whitespace collapsing
+ text = """
+ This is a paragraph that already has
+ line breaks and \t tabs too."""
+ self.check_shorten(text, 62,
+ "This is a paragraph that already has line "
+ "breaks and tabs too.")
+ self.check_shorten(text, 61,
+ "This is a paragraph that already has line "
+ "breaks and [...]")
+
+ self.check_shorten("hello world! ", 12, "hello world!")
+ self.check_shorten("hello world! ", 11, "hello [...]")
+ # The leading space is trimmed from the placeholder
+ # (it would be ugly otherwise).
+ self.check_shorten("hello world! ", 10, "[...]")
+
+ def test_width_too_small_for_placeholder(self):
+ shorten("x" * 20, width=8, placeholder="(......)")
+ with self.assertRaises(ValueError):
+ shorten("x" * 20, width=8, placeholder="(.......)")
+
+ def test_first_word_too_long_but_placeholder_fits(self):
+ self.check_shorten("Helloo", 5, "[...]")
+
if __name__ == '__main__':
- test_main()
+ unittest.main()
diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py
index a191e157bc..f9a721b03b 100644
--- a/Lib/test/test_thread.py
+++ b/Lib/test/test_thread.py
@@ -68,7 +68,7 @@ class ThreadRunningTests(BasicThreadTest):
thread.stack_size(0)
self.assertEqual(thread.stack_size(), 0, "stack_size not reset to default")
- if os.name not in ("nt", "os2", "posix"):
+ if os.name not in ("nt", "posix"):
return
tss_supported = True
diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py
index 6c2965ba6e..3d961b5ee2 100644
--- a/Lib/test/test_threaded_import.py
+++ b/Lib/test/test_threaded_import.py
@@ -5,8 +5,8 @@
# complains several times about module random having no attribute
# randrange, and then Python hangs.
+import _imp as imp
import os
-import imp
import importlib
import sys
import time
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index 0ebeb39cbd..66eace021e 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -11,6 +11,7 @@ import re
import sys
_thread = import_module('_thread')
threading = import_module('threading')
+import _testcapi
import time
import unittest
import weakref
@@ -20,6 +21,15 @@ import subprocess
from test import lock_tests
+
+# Between fork() and exec(), only async-safe functions are allowed (issues
+# #12316 and #11870), and fork() from a worker thread is known to trigger
+# problems with some operating systems (issue #3863): skip problematic tests
+# on platforms known to behave badly.
+platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5',
+ 'hp-ux11')
+
+
# A trivial mutable counter.
class Counter(object):
def __init__(self):
@@ -99,7 +109,7 @@ class ThreadTests(BaseTestCase):
if verbose:
print('waiting for all tasks to complete')
for t in threads:
- t.join(NUMTASKS)
+ t.join()
self.assertTrue(not t.is_alive())
self.assertNotEqual(t.ident, 0)
self.assertFalse(t.ident is None)
@@ -467,6 +477,127 @@ class ThreadTests(BaseTestCase):
pid, status = os.waitpid(pid, 0)
self.assertEqual(0, status)
+ def test_main_thread(self):
+ main = threading.main_thread()
+ self.assertEqual(main.name, 'MainThread')
+ self.assertEqual(main.ident, threading.current_thread().ident)
+ self.assertEqual(main.ident, threading.get_ident())
+
+ def f():
+ self.assertNotEqual(threading.main_thread().ident,
+ threading.current_thread().ident)
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
+ @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()")
+ def test_main_thread_after_fork(self):
+ code = """if 1:
+ import os, threading
+
+ pid = os.fork()
+ if pid == 0:
+ main = threading.main_thread()
+ print(main.name)
+ print(main.ident == threading.current_thread().ident)
+ print(main.ident == threading.get_ident())
+ else:
+ os.waitpid(pid, 0)
+ """
+ _, out, err = assert_python_ok("-c", code)
+ data = out.decode().replace('\r', '')
+ self.assertEqual(err, b"")
+ self.assertEqual(data, "MainThread\nTrue\nTrue\n")
+
+ @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
+ @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
+ @unittest.skipUnless(hasattr(os, 'waitpid'), "test needs os.waitpid()")
+ def test_main_thread_after_fork_from_nonmain_thread(self):
+ code = """if 1:
+ import os, threading, sys
+
+ def f():
+ pid = os.fork()
+ if pid == 0:
+ main = threading.main_thread()
+ print(main.name)
+ print(main.ident == threading.current_thread().ident)
+ print(main.ident == threading.get_ident())
+ # stdout is fully buffered because not a tty,
+ # we have to flush before exit.
+ sys.stdout.flush()
+ else:
+ os.waitpid(pid, 0)
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+ """
+ _, out, err = assert_python_ok("-c", code)
+ data = out.decode().replace('\r', '')
+ self.assertEqual(err, b"")
+ self.assertEqual(data, "Thread-1\nTrue\nTrue\n")
+
+ def test_tstate_lock(self):
+ # Test an implementation detail of Thread objects.
+ started = _thread.allocate_lock()
+ finish = _thread.allocate_lock()
+ started.acquire()
+ finish.acquire()
+ def f():
+ started.release()
+ finish.acquire()
+ time.sleep(0.01)
+ # The tstate lock is None until the thread is started
+ t = threading.Thread(target=f)
+ self.assertIs(t._tstate_lock, None)
+ t.start()
+ started.acquire()
+ self.assertTrue(t.is_alive())
+ # The tstate lock can't be acquired when the thread is running
+ # (or suspended).
+ tstate_lock = t._tstate_lock
+ self.assertFalse(tstate_lock.acquire(timeout=0), False)
+ finish.release()
+ # When the thread ends, the state_lock can be successfully
+ # acquired.
+ self.assertTrue(tstate_lock.acquire(timeout=5), False)
+ # But is_alive() is still True: we hold _tstate_lock now, which
+ # prevents is_alive() from knowing the thread's end-of-life C code
+ # is done.
+ self.assertTrue(t.is_alive())
+ # Let is_alive() find out the C code is done.
+ tstate_lock.release()
+ self.assertFalse(t.is_alive())
+ # And verify the thread disposed of _tstate_lock.
+ self.assertTrue(t._tstate_lock is None)
+
+ def test_repr_stopped(self):
+ # Verify that "stopped" shows up in repr(Thread) appropriately.
+ started = _thread.allocate_lock()
+ finish = _thread.allocate_lock()
+ started.acquire()
+ finish.acquire()
+ def f():
+ started.release()
+ finish.acquire()
+ t = threading.Thread(target=f)
+ t.start()
+ started.acquire()
+ self.assertIn("started", repr(t))
+ finish.release()
+ # "stopped" should appear in the repr in a reasonable amount of time.
+ # Implementation detail: as of this writing, that's trivially true
+ # if .join() is called, and almost trivially true if .is_alive() is
+ # called. The detail we're testing here is that "stopped" shows up
+ # "all on its own".
+ LOOKING_FOR = "stopped"
+ for i in range(500):
+ if LOOKING_FOR in repr(t):
+ break
+ time.sleep(0.01)
+ self.assertIn(LOOKING_FOR, repr(t)) # we waited at least 5 seconds
def test_BoundedSemaphore_limit(self):
# BoundedSemaphore should raise ValueError if released too often.
@@ -486,14 +617,53 @@ class ThreadTests(BaseTestCase):
t.join()
self.assertRaises(ValueError, bs.release)
-class ThreadJoinOnShutdown(BaseTestCase):
+ def test_locals_at_exit(self):
+ # Issue #19466: thread locals must not be deleted before destructors
+ # are called
+ rc, out, err = assert_python_ok("-c", """if 1:
+ import threading
+
+ class Atexit:
+ def __del__(self):
+ print("thread_dict.atexit = %r" % thread_dict.atexit)
+
+ thread_dict = threading.local()
+ thread_dict.atexit = "atexit"
+
+ atexit = Atexit()
+ """)
+ self.assertEqual(out.rstrip(), b"thread_dict.atexit = 'atexit'")
+
+ def test_warnings_at_exit(self):
+ # Issue #19466: try to call most destructors at Python shutdown before
+ # destroying Python thread states
+ filename = __file__
+ rc, out, err = assert_python_ok("-Wd", "-c", """if 1:
+ import time
+ import threading
- # Between fork() and exec(), only async-safe functions are allowed (issues
- # #12316 and #11870), and fork() from a worker thread is known to trigger
- # problems with some operating systems (issue #3863): skip problematic tests
- # on platforms known to behave badly.
- platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5',
- 'os2emx', 'hp-ux11')
+ def open_sleep():
+ # a warning will be emitted when the open file will be
+ # destroyed (without being explicitly closed) while the daemon
+ # thread is destroyed
+ fileobj = open(%a, 'rb')
+ start_event.set()
+ time.sleep(60.0)
+
+ start_event = threading.Event()
+
+ thread = threading.Thread(target=open_sleep)
+ thread.daemon = True
+ thread.start()
+
+ # wait until the thread started
+ start_event.wait()
+ """ % filename)
+ self.assertRegex(err.rstrip(),
+ b"^sys:1: ResourceWarning: unclosed file ")
+
+
+class ThreadJoinOnShutdown(BaseTestCase):
def _run_and_join(self, script):
script = """if 1:
@@ -566,144 +736,8 @@ class ThreadJoinOnShutdown(BaseTestCase):
"""
self._run_and_join(script)
- def assertScriptHasOutput(self, script, expected_output):
- rc, out, err = assert_python_ok("-c", script)
- data = out.decode().replace('\r', '')
- self.assertEqual(data, expected_output)
-
- @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
- def test_4_joining_across_fork_in_worker_thread(self):
- # There used to be a possible deadlock when forking from a child
- # thread. See http://bugs.python.org/issue6643.
-
- # The script takes the following steps:
- # - The main thread in the parent process starts a new thread and then
- # tries to join it.
- # - The join operation acquires the Lock inside the thread's _block
- # Condition. (See threading.py:Thread.join().)
- # - We stub out the acquire method on the condition to force it to wait
- # until the child thread forks. (See LOCK ACQUIRED HERE)
- # - The child thread forks. (See LOCK HELD and WORKER THREAD FORKS
- # HERE)
- # - The main thread of the parent process enters Condition.wait(),
- # which releases the lock on the child thread.
- # - The child process returns. Without the necessary fix, when the
- # main thread of the child process (which used to be the child thread
- # in the parent process) attempts to exit, it will try to acquire the
- # lock in the Thread._block Condition object and hang, because the
- # lock was held across the fork.
-
- script = """if 1:
- import os, time, threading
-
- finish_join = False
- start_fork = False
-
- def worker():
- # Wait until this thread's lock is acquired before forking to
- # create the deadlock.
- global finish_join
- while not start_fork:
- time.sleep(0.01)
- # LOCK HELD: Main thread holds lock across this call.
- childpid = os.fork()
- finish_join = True
- if childpid != 0:
- # Parent process just waits for child.
- os.waitpid(childpid, 0)
- # Child process should just return.
-
- w = threading.Thread(target=worker)
-
- # Stub out the private condition variable's lock acquire method.
- # This acquires the lock and then waits until the child has forked
- # before returning, which will release the lock soon after. If
- # someone else tries to fix this test case by acquiring this lock
- # before forking instead of resetting it, the test case will
- # deadlock when it shouldn't.
- condition = w._block
- orig_acquire = condition.acquire
- call_count_lock = threading.Lock()
- call_count = 0
- def my_acquire():
- global call_count
- global start_fork
- orig_acquire() # LOCK ACQUIRED HERE
- start_fork = True
- if call_count == 0:
- while not finish_join:
- time.sleep(0.01) # WORKER THREAD FORKS HERE
- with call_count_lock:
- call_count += 1
- condition.acquire = my_acquire
-
- w.start()
- w.join()
- print('end of main')
- """
- self.assertScriptHasOutput(script, "end of main\n")
-
- @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
- @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
- def test_5_clear_waiter_locks_to_avoid_crash(self):
- # Check that a spawned thread that forks doesn't segfault on certain
- # platforms, namely OS X. This used to happen if there was a waiter
- # lock in the thread's condition variable's waiters list. Even though
- # we know the lock will be held across the fork, it is not safe to
- # release locks held across forks on all platforms, so releasing the
- # waiter lock caused a segfault on OS X. Furthermore, since locks on
- # OS X are (as of this writing) implemented with a mutex + condition
- # variable instead of a semaphore, while we know that the Python-level
- # lock will be acquired, we can't know if the internal mutex will be
- # acquired at the time of the fork.
-
- script = """if True:
- import os, time, threading
-
- start_fork = False
-
- def worker():
- # Wait until the main thread has attempted to join this thread
- # before continuing.
- while not start_fork:
- time.sleep(0.01)
- childpid = os.fork()
- if childpid != 0:
- # Parent process just waits for child.
- (cpid, rc) = os.waitpid(childpid, 0)
- assert cpid == childpid
- assert rc == 0
- print('end of worker thread')
- else:
- # Child process should just return.
- pass
-
- w = threading.Thread(target=worker)
-
- # Stub out the private condition variable's _release_save method.
- # This releases the condition's lock and flips the global that
- # causes the worker to fork. At this point, the problematic waiter
- # lock has been acquired once by the waiter and has been put onto
- # the waiters list.
- condition = w._block
- orig_release_save = condition._release_save
- def my_release_save():
- global start_fork
- orig_release_save()
- # Waiter lock held here, condition lock released.
- start_fork = True
- condition._release_save = my_release_save
-
- w.start()
- w.join()
- print('end of main thread')
- """
- output = "end of worker thread\nend of main thread\n"
- self.assertScriptHasOutput(script, output)
-
- @unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
- def test_6_daemon_threads(self):
+ def test_4_daemon_threads(self):
# Check that a daemon thread cannot crash the interpreter on shutdown
# by manipulating internal structures that are being disposed of in
# the main thread.
@@ -713,6 +747,10 @@ class ThreadJoinOnShutdown(BaseTestCase):
import sys
import time
import threading
+ import warnings
+
+ # ignore "unclosed file ..." warnings
+ warnings.filterwarnings('ignore', '', ResourceWarning)
thread_has_run = set()
@@ -769,6 +807,111 @@ class ThreadJoinOnShutdown(BaseTestCase):
for t in threads:
t.join()
+ @unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
+ def test_clear_threads_states_after_fork(self):
+ # Issue #17094: check that threads states are cleared after fork()
+
+ # start a bunch of threads
+ threads = []
+ for i in range(16):
+ t = threading.Thread(target=lambda : time.sleep(0.3))
+ threads.append(t)
+ t.start()
+
+ pid = os.fork()
+ if pid == 0:
+ # check that threads states have been cleared
+ if len(sys._current_frames()) == 1:
+ os._exit(0)
+ else:
+ os._exit(1)
+ else:
+ _, status = os.waitpid(pid, 0)
+ self.assertEqual(0, status)
+
+ for t in threads:
+ t.join()
+
+
+class SubinterpThreadingTests(BaseTestCase):
+
+ def test_threads_join(self):
+ # Non-daemon threads should be joined at subinterpreter shutdown
+ # (issue #18808)
+ r, w = os.pipe()
+ self.addCleanup(os.close, r)
+ self.addCleanup(os.close, w)
+ code = r"""if 1:
+ import os
+ import threading
+ import time
+
+ def f():
+ # Sleep a bit so that the thread is still running when
+ # Py_EndInterpreter is called.
+ time.sleep(0.05)
+ os.write(%d, b"x")
+ threading.Thread(target=f).start()
+ """ % (w,)
+ ret = _testcapi.run_in_subinterp(code)
+ self.assertEqual(ret, 0)
+ # The thread was joined properly.
+ self.assertEqual(os.read(r, 1), b"x")
+
+ def test_threads_join_2(self):
+ # Same as above, but a delay gets introduced after the thread's
+ # Python code returned but before the thread state is deleted.
+ # To achieve this, we register a thread-local object which sleeps
+ # a bit when deallocated.
+ r, w = os.pipe()
+ self.addCleanup(os.close, r)
+ self.addCleanup(os.close, w)
+ code = r"""if 1:
+ import os
+ import threading
+ import time
+
+ class Sleeper:
+ def __del__(self):
+ time.sleep(0.05)
+
+ tls = threading.local()
+
+ def f():
+ # Sleep a bit so that the thread is still running when
+ # Py_EndInterpreter is called.
+ time.sleep(0.05)
+ tls.x = Sleeper()
+ os.write(%d, b"x")
+ threading.Thread(target=f).start()
+ """ % (w,)
+ ret = _testcapi.run_in_subinterp(code)
+ self.assertEqual(ret, 0)
+ # The thread was joined properly.
+ self.assertEqual(os.read(r, 1), b"x")
+
+ def test_daemon_threads_fatal_error(self):
+ subinterp_code = r"""if 1:
+ import os
+ import threading
+ import time
+
+ def f():
+ # Make sure the daemon thread is still running when
+ # Py_EndInterpreter is called.
+ time.sleep(10)
+ threading.Thread(target=f, daemon=True).start()
+ """
+ script = r"""if 1:
+ import _testcapi
+
+ _testcapi.run_in_subinterp(%r)
+ """ % (subinterp_code,)
+ with test.support.SuppressCrashReport():
+ rc, out, err = assert_python_failure("-c", script)
+ self.assertIn("Fatal Python error: Py_EndInterpreter: "
+ "not the last thread", err.decode())
+
class ThreadingExceptionTests(BaseTestCase):
# A RuntimeError should be raised if Thread.start() is called
diff --git a/Lib/test/test_threadsignals.py b/Lib/test/test_threadsignals.py
index f975a75e85..b1004e6689 100644
--- a/Lib/test/test_threadsignals.py
+++ b/Lib/test/test_threadsignals.py
@@ -8,7 +8,7 @@ from test.support import run_unittest, import_module
thread = import_module('_thread')
import time
-if sys.platform[:3] in ('win', 'os2') or sys.platform=='riscos':
+if (sys.platform[:3] == 'win'):
raise unittest.SkipTest("Can't test signal on %s" % sys.platform)
process_pid = os.getpid()
diff --git a/Lib/test/test_timeout.py b/Lib/test/test_timeout.py
index dcf201b507..703c43ab2b 100644
--- a/Lib/test/test_timeout.py
+++ b/Lib/test/test_timeout.py
@@ -207,7 +207,7 @@ class TCPTimeoutTestCase(TimeoutTestCase):
sock.connect((whitehole))
except socket.timeout:
pass
- except IOError as err:
+ except OSError as err:
if err.errno == errno.ECONNREFUSED:
skip = False
finally:
diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py
index 5bce2af68a..96ef951139 100644
--- a/Lib/test/test_traceback.py
+++ b/Lib/test/test_traceback.py
@@ -150,30 +150,74 @@ class SyntaxTracebackCases(unittest.TestCase):
class TracebackFormatTests(unittest.TestCase):
- def test_traceback_format(self):
+ def some_exception(self):
+ raise KeyError('blah')
+
+ def check_traceback_format(self, cleanup_func=None):
try:
- raise KeyError('blah')
+ self.some_exception()
except KeyError:
type_, value, tb = sys.exc_info()
+ if cleanup_func is not None:
+ # Clear the inner frames, not this one
+ cleanup_func(tb.tb_next)
traceback_fmt = 'Traceback (most recent call last):\n' + \
''.join(traceback.format_tb(tb))
file_ = StringIO()
traceback_print(tb, file_)
python_fmt = file_.getvalue()
+ # Call all _tb and _exc functions
+ with captured_output("stderr") as tbstderr:
+ traceback.print_tb(tb)
+ tbfile = StringIO()
+ traceback.print_tb(tb, file=tbfile)
+ with captured_output("stderr") as excstderr:
+ traceback.print_exc()
+ excfmt = traceback.format_exc()
+ excfile = StringIO()
+ traceback.print_exc(file=excfile)
else:
raise Error("unable to create test traceback string")
# Make sure that Python and the traceback module format the same thing
self.assertEqual(traceback_fmt, python_fmt)
+ # Now verify the _tb func output
+ self.assertEqual(tbstderr.getvalue(), tbfile.getvalue())
+ # Now verify the _exc func output
+ self.assertEqual(excstderr.getvalue(), excfile.getvalue())
+ self.assertEqual(excfmt, excfile.getvalue())
# Make sure that the traceback is properly indented.
tb_lines = python_fmt.splitlines()
- self.assertEqual(len(tb_lines), 3)
- banner, location, source_line = tb_lines
+ self.assertEqual(len(tb_lines), 5)
+ banner = tb_lines[0]
+ location, source_line = tb_lines[-2:]
self.assertTrue(banner.startswith('Traceback'))
self.assertTrue(location.startswith(' File'))
self.assertTrue(source_line.startswith(' raise'))
+ def test_traceback_format(self):
+ self.check_traceback_format()
+
+ def test_traceback_format_with_cleared_frames(self):
+ # Check that traceback formatting also works with a clear()ed frame
+ def cleanup_tb(tb):
+ tb.tb_frame.clear()
+ self.check_traceback_format(cleanup_tb)
+
+ def test_stack_format(self):
+ # Verify _stack functions. Note we have to use _getframe(1) to
+ # compare them without this frame appearing in the output
+ with captured_output("stderr") as ststderr:
+ traceback.print_stack(sys._getframe(1))
+ stfile = StringIO()
+ traceback.print_stack(sys._getframe(1), file=stfile)
+ self.assertEqual(ststderr.getvalue(), stfile.getvalue())
+
+ stfmt = traceback.format_stack(sys._getframe(1))
+
+ self.assertEqual(ststderr.getvalue(), "".join(stfmt))
+
cause_message = (
"\nThe above exception was the direct cause "
@@ -344,6 +388,36 @@ class CExcReportingTests(BaseExceptionReportingTests, unittest.TestCase):
return s.getvalue()
+class MiscTracebackCases(unittest.TestCase):
+ #
+ # Check non-printing functions in traceback module
+ #
+
+ def test_clear(self):
+ def outer():
+ middle()
+ def middle():
+ inner()
+ def inner():
+ i = 1
+ 1/0
+
+ try:
+ outer()
+ except:
+ type_, value, tb = sys.exc_info()
+
+ # Initial assertion: there's one local in the inner frame.
+ inner_frame = tb.tb_next.tb_next.tb_next.tb_frame
+ self.assertEqual(len(inner_frame.f_locals), 1)
+
+ # Clear traceback frames
+ traceback.clear_frames(tb)
+
+ # Local variable dict should now be empty.
+ self.assertEqual(len(inner_frame.f_locals), 0)
+
+
def test_main():
run_unittest(__name__)
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index 3ee4c6bc5f..ec10752e6a 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -2,6 +2,7 @@
from test.support import run_unittest, run_with_locale
import collections
+import pickle
import locale
import sys
import types
@@ -1077,9 +1078,19 @@ class SimpleNamespaceTests(unittest.TestCase):
ns2 = types.SimpleNamespace()
ns2.x = "spam"
ns2._y = 5
+ name = "namespace"
- self.assertEqual(repr(ns1), "namespace(w=3, x=1, y=2)")
- self.assertEqual(repr(ns2), "namespace(_y=5, x='spam')")
+ self.assertEqual(repr(ns1), "{name}(w=3, x=1, y=2)".format(name=name))
+ self.assertEqual(repr(ns2), "{name}(_y=5, x='spam')".format(name=name))
+
+ def test_equal(self):
+ ns1 = types.SimpleNamespace(x=1)
+ ns2 = types.SimpleNamespace()
+ ns2.x = 1
+
+ self.assertEqual(types.SimpleNamespace(), types.SimpleNamespace())
+ self.assertEqual(ns1, ns2)
+ self.assertNotEqual(ns2, types.SimpleNamespace())
def test_nested(self):
ns1 = types.SimpleNamespace(a=1, b=2)
@@ -1117,11 +1128,12 @@ class SimpleNamespaceTests(unittest.TestCase):
ns1.spam = ns1
ns2.spam = ns3
ns3.spam = ns2
+ name = "namespace"
+ repr1 = "{name}(c='cookie', spam={name}(...))".format(name=name)
+ repr2 = "{name}(spam={name}(spam={name}(...), x=1))".format(name=name)
- self.assertEqual(repr(ns1),
- "namespace(c='cookie', spam=namespace(...))")
- self.assertEqual(repr(ns2),
- "namespace(spam=namespace(spam=namespace(...), x=1))")
+ self.assertEqual(repr(ns1), repr1)
+ self.assertEqual(repr(ns2), repr2)
def test_as_dict(self):
ns = types.SimpleNamespace(spam='spamspamspam')
@@ -1144,6 +1156,19 @@ class SimpleNamespaceTests(unittest.TestCase):
self.assertIs(type(spam), Spam)
self.assertEqual(vars(spam), {'ham': 8, 'eggs': 9})
+ def test_pickle(self):
+ ns = types.SimpleNamespace(breakfast="spam", lunch="spam")
+
+ for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+ pname = "protocol {}".format(protocol)
+ try:
+ ns_pickled = pickle.dumps(ns, protocol)
+ except TypeError as e:
+ raise TypeError(pname) from e
+ ns_roundtrip = pickle.loads(ns_pickled)
+
+ self.assertEqual(ns, ns_roundtrip, pname)
+
def test_main():
run_unittest(TypesTests, MappingProxyTests, ClassCreationTests,
diff --git a/Lib/test/test_ucn.py b/Lib/test/test_ucn.py
index 2e6374561f..59bde7475a 100644
--- a/Lib/test/test_ucn.py
+++ b/Lib/test/test_ucn.py
@@ -173,7 +173,7 @@ class UnicodeNamesTest(unittest.TestCase):
try:
testdata = support.open_urlresource(url, encoding="utf-8",
check=check_version)
- except (IOError, HTTPException):
+ except (OSError, HTTPException):
self.skipTest("Could not retrieve " + url)
self.addCleanup(testdata.close)
for line in testdata:
diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py
index 9dc3438bea..727897e53e 100644
--- a/Lib/test/test_unicode.py
+++ b/Lib/test/test_unicode.py
@@ -7,6 +7,7 @@ Written by Marc-Andre Lemburg (mal@lemburg.com).
"""#"
import _string
import codecs
+import itertools
import struct
import sys
import unittest
@@ -31,6 +32,16 @@ def search_function(encoding):
return None
codecs.register(search_function)
+def duplicate_string(text):
+ """
+ Try to get a fresh clone of the specified text:
+ new object with a reference count of 1.
+
+ This is a best-effort: latin1 single letters and the empty
+ string ('') are singletons and cannot be cloned.
+ """
+ return text.encode().decode()
+
class UnicodeTest(string_tests.CommonTest,
string_tests.MixinStrUnicodeUserStringTest,
string_tests.MixinStrUnicodeTest,
@@ -863,11 +874,9 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual('{0:d}'.format(G('data')), 'G(data)')
self.assertEqual('{0!s}'.format(G('data')), 'string is data')
- msg = 'object.__format__ with a non-empty format string is deprecated'
- with support.check_warnings((msg, DeprecationWarning)):
- self.assertEqual('{0:^10}'.format(E('data')), ' E(data) ')
- self.assertEqual('{0:^10s}'.format(E('data')), ' E(data) ')
- self.assertEqual('{0:>15s}'.format(G('data')), ' string is data')
+ self.assertRaises(TypeError, '{0:^10}'.format, E('data'))
+ self.assertRaises(TypeError, '{0:^10s}'.format, E('data'))
+ self.assertRaises(TypeError, '{0:>15s}'.format, G('data'))
self.assertEqual("{0:date: %Y-%m-%d}".format(I(year=2007,
month=8,
@@ -903,7 +912,7 @@ class UnicodeTest(string_tests.CommonTest,
self.assertRaises(ValueError, "{0".format)
self.assertRaises(IndexError, "{0.}".format)
self.assertRaises(ValueError, "{0.}".format, 0)
- self.assertRaises(IndexError, "{0[}".format)
+ self.assertRaises(ValueError, "{0[}".format)
self.assertRaises(ValueError, "{0[}".format, [])
self.assertRaises(KeyError, "{0]}".format)
self.assertRaises(ValueError, "{0.[]}".format, 0)
@@ -955,6 +964,14 @@ class UnicodeTest(string_tests.CommonTest,
'')
self.assertEqual("{[{}]}".format({"{}": 5}), "5")
+ self.assertEqual("{[{}]}".format({"{}" : "a"}), "a")
+ self.assertEqual("{[{]}".format({"{" : "a"}), "a")
+ self.assertEqual("{[}]}".format({"}" : "a"}), "a")
+ self.assertEqual("{[[]}".format({"[" : "a"}), "a")
+ self.assertEqual("{[!]}".format({"!" : "a"}), "a")
+ self.assertRaises(ValueError, "{a{}b}".format, 42)
+ self.assertRaises(ValueError, "{a{b}".format, 42)
+ self.assertRaises(ValueError, "{[}".format, 42)
def test_format_map(self):
self.assertEqual(''.format_map({}), '')
@@ -1107,6 +1124,38 @@ class UnicodeTest(string_tests.CommonTest,
self.assertEqual('%.1s' % "a\xe9\u20ac", 'a')
self.assertEqual('%.2s' % "a\xe9\u20ac", 'a\xe9')
+ def test_formatting_with_enum(self):
+ # issue18780
+ import enum
+ class Float(float, enum.Enum):
+ PI = 3.1415926
+ class Int(enum.IntEnum):
+ IDES = 15
+ class Str(str, enum.Enum):
+ ABC = 'abc'
+ # Testing Unicode formatting strings...
+ self.assertEqual("%s, %s" % (Str.ABC, Str.ABC),
+ 'Str.ABC, Str.ABC')
+ self.assertEqual("%s, %s, %d, %i, %u, %f, %5.2f" %
+ (Str.ABC, Str.ABC,
+ Int.IDES, Int.IDES, Int.IDES,
+ Float.PI, Float.PI),
+ 'Str.ABC, Str.ABC, 15, 15, 15, 3.141593, 3.14')
+
+ # formatting jobs delegated from the string implementation:
+ self.assertEqual('...%(foo)s...' % {'foo':Str.ABC},
+ '...Str.ABC...')
+ self.assertEqual('...%(foo)s...' % {'foo':Int.IDES},
+ '...Int.IDES...')
+ self.assertEqual('...%(foo)i...' % {'foo':Int.IDES},
+ '...15...')
+ self.assertEqual('...%(foo)d...' % {'foo':Int.IDES},
+ '...15...')
+ self.assertEqual('...%(foo)u...' % {'foo':Int.IDES, 'def':Float.PI},
+ '...15...')
+ self.assertEqual('...%(foo)f...' % {'foo':Float.PI,'def':123},
+ '...3.141593...')
+
@support.cpython_only
def test_formatting_huge_precision(self):
from _testcapi import INT_MAX
@@ -1782,7 +1831,7 @@ class UnicodeTest(string_tests.CommonTest,
# 0-127
s = bytes(range(128))
for encoding in (
- 'cp037', 'cp1026',
+ 'cp037', 'cp1026', 'cp273',
'cp437', 'cp500', 'cp720', 'cp737', 'cp775', 'cp850',
'cp852', 'cp855', 'cp858', 'cp860', 'cp861', 'cp862',
'cp863', 'cp865', 'cp866',
@@ -1810,7 +1859,7 @@ class UnicodeTest(string_tests.CommonTest,
# 128-255
s = bytes(range(128, 256))
for encoding in (
- 'cp037', 'cp1026',
+ 'cp037', 'cp1026', 'cp273',
'cp437', 'cp500', 'cp720', 'cp737', 'cp775', 'cp850',
'cp852', 'cp855', 'cp858', 'cp860', 'cp861', 'cp862',
'cp863', 'cp865', 'cp866',
@@ -2004,9 +2053,10 @@ class UnicodeTest(string_tests.CommonTest,
# Test PyUnicode_FromFormat()
def test_from_format(self):
support.import_module('ctypes')
- from ctypes import (pythonapi, py_object,
+ from ctypes import (
+ pythonapi, py_object, sizeof,
c_int, c_long, c_longlong, c_ssize_t,
- c_uint, c_ulong, c_ulonglong, c_size_t)
+ c_uint, c_ulong, c_ulonglong, c_size_t, c_void_p)
name = "PyUnicode_FromFormat"
_PyUnicode_FromFormat = getattr(pythonapi, name)
_PyUnicode_FromFormat.restype = py_object
@@ -2017,9 +2067,13 @@ class UnicodeTest(string_tests.CommonTest,
for arg in args)
return _PyUnicode_FromFormat(format, *cargs)
+ def check_format(expected, format, *args):
+ text = PyUnicode_FromFormat(format, *args)
+ self.assertEqual(expected, text)
+
# ascii format, non-ascii argument
- text = PyUnicode_FromFormat(b'ascii\x7f=%U', 'unicode\xe9')
- self.assertEqual(text, 'ascii\x7f=unicode\xe9')
+ check_format('ascii\x7f=unicode\xe9',
+ b'ascii\x7f=%U', 'unicode\xe9')
# non-ascii format, ascii argument: ensure that PyUnicode_FromFormatV()
# raises an error
@@ -2029,64 +2083,205 @@ class UnicodeTest(string_tests.CommonTest,
PyUnicode_FromFormat, b'unicode\xe9=%s', 'ascii')
# test "%c"
- self.assertEqual(PyUnicode_FromFormat(b'%c', c_int(0xabcd)), '\uabcd')
- self.assertEqual(PyUnicode_FromFormat(b'%c', c_int(0x10ffff)), '\U0010ffff')
+ check_format('\uabcd',
+ b'%c', c_int(0xabcd))
+ check_format('\U0010ffff',
+ b'%c', c_int(0x10ffff))
with self.assertRaises(OverflowError):
PyUnicode_FromFormat(b'%c', c_int(0x110000))
# Issue #18183
- self.assertEqual(
- PyUnicode_FromFormat(b'%c%c', c_int(0x10000), c_int(0x100000)),
- '\U00010000\U00100000')
+ check_format('\U00010000\U00100000',
+ b'%c%c', c_int(0x10000), c_int(0x100000))
# test "%"
- self.assertEqual(PyUnicode_FromFormat(b'%'), '%')
- self.assertEqual(PyUnicode_FromFormat(b'%%'), '%')
- self.assertEqual(PyUnicode_FromFormat(b'%%s'), '%s')
- self.assertEqual(PyUnicode_FromFormat(b'[%%]'), '[%]')
- self.assertEqual(PyUnicode_FromFormat(b'%%%s', b'abc'), '%abc')
+ check_format('%',
+ b'%')
+ check_format('%',
+ b'%%')
+ check_format('%s',
+ b'%%s')
+ check_format('[%]',
+ b'[%%]')
+ check_format('%abc',
+ b'%%%s', b'abc')
+
+ # truncated string
+ check_format('abc',
+ b'%.3s', b'abcdef')
+ check_format('abc[\ufffd',
+ b'%.5s', 'abc[\u20ac]'.encode('utf8'))
+ check_format("'\\u20acABC'",
+ b'%A', '\u20acABC')
+ check_format("'\\u20",
+ b'%.5A', '\u20acABCDEF')
+ check_format("'\u20acABC'",
+ b'%R', '\u20acABC')
+ check_format("'\u20acA",
+ b'%.3R', '\u20acABCDEF')
+ check_format('\u20acAB',
+ b'%.3S', '\u20acABCDEF')
+ check_format('\u20acAB',
+ b'%.3U', '\u20acABCDEF')
+ check_format('\u20acAB',
+ b'%.3V', '\u20acABCDEF', None)
+ check_format('abc[\ufffd',
+ b'%.5V', None, 'abc[\u20ac]'.encode('utf8'))
+
+ # following tests comes from #7330
+ # test width modifier and precision modifier with %S
+ check_format("repr= abc",
+ b'repr=%5S', 'abc')
+ check_format("repr=ab",
+ b'repr=%.2S', 'abc')
+ check_format("repr= ab",
+ b'repr=%5.2S', 'abc')
+
+ # test width modifier and precision modifier with %R
+ check_format("repr= 'abc'",
+ b'repr=%8R', 'abc')
+ check_format("repr='ab",
+ b'repr=%.3R', 'abc')
+ check_format("repr= 'ab",
+ b'repr=%5.3R', 'abc')
+
+ # test width modifier and precision modifier with %A
+ check_format("repr= 'abc'",
+ b'repr=%8A', 'abc')
+ check_format("repr='ab",
+ b'repr=%.3A', 'abc')
+ check_format("repr= 'ab",
+ b'repr=%5.3A', 'abc')
+
+ # test width modifier and precision modifier with %s
+ check_format("repr= abc",
+ b'repr=%5s', b'abc')
+ check_format("repr=ab",
+ b'repr=%.2s', b'abc')
+ check_format("repr= ab",
+ b'repr=%5.2s', b'abc')
+
+ # test width modifier and precision modifier with %U
+ check_format("repr= abc",
+ b'repr=%5U', 'abc')
+ check_format("repr=ab",
+ b'repr=%.2U', 'abc')
+ check_format("repr= ab",
+ b'repr=%5.2U', 'abc')
+
+ # test width modifier and precision modifier with %V
+ check_format("repr= abc",
+ b'repr=%5V', 'abc', b'123')
+ check_format("repr=ab",
+ b'repr=%.2V', 'abc', b'123')
+ check_format("repr= ab",
+ b'repr=%5.2V', 'abc', b'123')
+ check_format("repr= 123",
+ b'repr=%5V', None, b'123')
+ check_format("repr=12",
+ b'repr=%.2V', None, b'123')
+ check_format("repr= 12",
+ b'repr=%5.2V', None, b'123')
# test integer formats (%i, %d, %u)
- self.assertEqual(PyUnicode_FromFormat(b'%03i', c_int(10)), '010')
- self.assertEqual(PyUnicode_FromFormat(b'%0.4i', c_int(10)), '0010')
- self.assertEqual(PyUnicode_FromFormat(b'%i', c_int(-123)), '-123')
- self.assertEqual(PyUnicode_FromFormat(b'%li', c_long(-123)), '-123')
- self.assertEqual(PyUnicode_FromFormat(b'%lli', c_longlong(-123)), '-123')
- self.assertEqual(PyUnicode_FromFormat(b'%zi', c_ssize_t(-123)), '-123')
-
- self.assertEqual(PyUnicode_FromFormat(b'%d', c_int(-123)), '-123')
- self.assertEqual(PyUnicode_FromFormat(b'%ld', c_long(-123)), '-123')
- self.assertEqual(PyUnicode_FromFormat(b'%lld', c_longlong(-123)), '-123')
- self.assertEqual(PyUnicode_FromFormat(b'%zd', c_ssize_t(-123)), '-123')
-
- self.assertEqual(PyUnicode_FromFormat(b'%u', c_uint(123)), '123')
- self.assertEqual(PyUnicode_FromFormat(b'%lu', c_ulong(123)), '123')
- self.assertEqual(PyUnicode_FromFormat(b'%llu', c_ulonglong(123)), '123')
- self.assertEqual(PyUnicode_FromFormat(b'%zu', c_size_t(123)), '123')
+ check_format('010',
+ b'%03i', c_int(10))
+ check_format('0010',
+ b'%0.4i', c_int(10))
+ check_format('-123',
+ b'%i', c_int(-123))
+ check_format('-123',
+ b'%li', c_long(-123))
+ check_format('-123',
+ b'%lli', c_longlong(-123))
+ check_format('-123',
+ b'%zi', c_ssize_t(-123))
+
+ check_format('-123',
+ b'%d', c_int(-123))
+ check_format('-123',
+ b'%ld', c_long(-123))
+ check_format('-123',
+ b'%lld', c_longlong(-123))
+ check_format('-123',
+ b'%zd', c_ssize_t(-123))
+
+ check_format('123',
+ b'%u', c_uint(123))
+ check_format('123',
+ b'%lu', c_ulong(123))
+ check_format('123',
+ b'%llu', c_ulonglong(123))
+ check_format('123',
+ b'%zu', c_size_t(123))
+
+ # test long output
+ min_longlong = -(2 ** (8 * sizeof(c_longlong) - 1))
+ max_longlong = -min_longlong - 1
+ check_format(str(min_longlong),
+ b'%lld', c_longlong(min_longlong))
+ check_format(str(max_longlong),
+ b'%lld', c_longlong(max_longlong))
+ max_ulonglong = 2 ** (8 * sizeof(c_ulonglong)) - 1
+ check_format(str(max_ulonglong),
+ b'%llu', c_ulonglong(max_ulonglong))
+ PyUnicode_FromFormat(b'%p', c_void_p(-1))
+
+ # test padding (width and/or precision)
+ check_format('123'.rjust(10, '0'),
+ b'%010i', c_int(123))
+ check_format('123'.rjust(100),
+ b'%100i', c_int(123))
+ check_format('123'.rjust(100, '0'),
+ b'%.100i', c_int(123))
+ check_format('123'.rjust(80, '0').rjust(100),
+ b'%100.80i', c_int(123))
+
+ check_format('123'.rjust(10, '0'),
+ b'%010u', c_uint(123))
+ check_format('123'.rjust(100),
+ b'%100u', c_uint(123))
+ check_format('123'.rjust(100, '0'),
+ b'%.100u', c_uint(123))
+ check_format('123'.rjust(80, '0').rjust(100),
+ b'%100.80u', c_uint(123))
+
+ check_format('123'.rjust(10, '0'),
+ b'%010x', c_int(0x123))
+ check_format('123'.rjust(100),
+ b'%100x', c_int(0x123))
+ check_format('123'.rjust(100, '0'),
+ b'%.100x', c_int(0x123))
+ check_format('123'.rjust(80, '0').rjust(100),
+ b'%100.80x', c_int(0x123))
# test %A
- text = PyUnicode_FromFormat(b'%%A:%A', 'abc\xe9\uabcd\U0010ffff')
- self.assertEqual(text, r"%A:'abc\xe9\uabcd\U0010ffff'")
+ check_format(r"%A:'abc\xe9\uabcd\U0010ffff'",
+ b'%%A:%A', 'abc\xe9\uabcd\U0010ffff')
# test %V
- text = PyUnicode_FromFormat(b'repr=%V', 'abc', b'xyz')
- self.assertEqual(text, 'repr=abc')
+ check_format('repr=abc',
+ b'repr=%V', 'abc', b'xyz')
# Test string decode from parameter of %s using utf-8.
# b'\xe4\xba\xba\xe6\xb0\x91' is utf-8 encoded byte sequence of
# '\u4eba\u6c11'
- text = PyUnicode_FromFormat(b'repr=%V', None, b'\xe4\xba\xba\xe6\xb0\x91')
- self.assertEqual(text, 'repr=\u4eba\u6c11')
+ check_format('repr=\u4eba\u6c11',
+ b'repr=%V', None, b'\xe4\xba\xba\xe6\xb0\x91')
#Test replace error handler.
- text = PyUnicode_FromFormat(b'repr=%V', None, b'abc\xff')
- self.assertEqual(text, 'repr=abc\ufffd')
+ check_format('repr=abc\ufffd',
+ b'repr=%V', None, b'abc\xff')
# not supported: copy the raw format string. these tests are just here
# to check for crashs and should not be considered as specifications
- self.assertEqual(PyUnicode_FromFormat(b'%1%s', b'abc'), '%s')
- self.assertEqual(PyUnicode_FromFormat(b'%1abc'), '%1abc')
- self.assertEqual(PyUnicode_FromFormat(b'%+i', c_int(10)), '%+i')
- self.assertEqual(PyUnicode_FromFormat(b'%.%s', b'abc'), '%.%s')
+ check_format('%s',
+ b'%1%s', b'abc')
+ check_format('%1abc',
+ b'%1abc')
+ check_format('%+i',
+ b'%+i', c_int(10))
+ check_format('%.%s',
+ b'%.%s', b'abc')
# Test PyUnicode_AsWideChar()
def test_aswidechar(self):
@@ -2210,6 +2405,80 @@ class UnicodeTest(string_tests.CommonTest,
self.assertNotEqual(abc, abcdef)
self.assertEqual(abcdef.decode('unicode_internal'), text)
+ def test_compare(self):
+ # Issue #17615
+ N = 10
+ ascii = 'a' * N
+ ascii2 = 'z' * N
+ latin = '\x80' * N
+ latin2 = '\xff' * N
+ bmp = '\u0100' * N
+ bmp2 = '\uffff' * N
+ astral = '\U00100000' * N
+ astral2 = '\U0010ffff' * N
+ strings = (
+ ascii, ascii2,
+ latin, latin2,
+ bmp, bmp2,
+ astral, astral2)
+ for text1, text2 in itertools.combinations(strings, 2):
+ equal = (text1 is text2)
+ self.assertEqual(text1 == text2, equal)
+ self.assertEqual(text1 != text2, not equal)
+
+ if equal:
+ self.assertTrue(text1 <= text2)
+ self.assertTrue(text1 >= text2)
+
+ # text1 is text2: duplicate strings to skip the "str1 == str2"
+ # optimization in unicode_compare_eq() and really compare
+ # character per character
+ copy1 = duplicate_string(text1)
+ copy2 = duplicate_string(text2)
+ self.assertIsNot(copy1, copy2)
+
+ self.assertTrue(copy1 == copy2)
+ self.assertFalse(copy1 != copy2)
+
+ self.assertTrue(copy1 <= copy2)
+ self.assertTrue(copy2 >= copy2)
+
+ self.assertTrue(ascii < ascii2)
+ self.assertTrue(ascii < latin)
+ self.assertTrue(ascii < bmp)
+ self.assertTrue(ascii < astral)
+ self.assertFalse(ascii >= ascii2)
+ self.assertFalse(ascii >= latin)
+ self.assertFalse(ascii >= bmp)
+ self.assertFalse(ascii >= astral)
+
+ self.assertFalse(latin < ascii)
+ self.assertTrue(latin < latin2)
+ self.assertTrue(latin < bmp)
+ self.assertTrue(latin < astral)
+ self.assertTrue(latin >= ascii)
+ self.assertFalse(latin >= latin2)
+ self.assertFalse(latin >= bmp)
+ self.assertFalse(latin >= astral)
+
+ self.assertFalse(bmp < ascii)
+ self.assertFalse(bmp < latin)
+ self.assertTrue(bmp < bmp2)
+ self.assertTrue(bmp < astral)
+ self.assertTrue(bmp >= ascii)
+ self.assertTrue(bmp >= latin)
+ self.assertFalse(bmp >= bmp2)
+ self.assertFalse(bmp >= astral)
+
+ self.assertFalse(astral < ascii)
+ self.assertFalse(astral < latin)
+ self.assertFalse(astral < bmp2)
+ self.assertTrue(astral < astral2)
+ self.assertTrue(astral >= ascii)
+ self.assertTrue(astral >= latin)
+ self.assertTrue(astral >= bmp2)
+ self.assertFalse(astral >= astral2)
+
class StringModuleTest(unittest.TestCase):
def test_formatter_parser(self):
diff --git a/Lib/test/test_unicodedata.py b/Lib/test/test_unicodedata.py
index 99aa0033cd..707b30e50c 100644
--- a/Lib/test/test_unicodedata.py
+++ b/Lib/test/test_unicodedata.py
@@ -21,7 +21,7 @@ errors = 'surrogatepass'
class UnicodeMethodsTest(unittest.TestCase):
# update this, if the database changes
- expectedchecksum = 'bf7a78f1a532421b5033600102e23a92044dbba9'
+ expectedchecksum = 'e74e878de71b6e780ffac271785c3cb58f6251f3'
def test_method_checksum(self):
h = hashlib.sha1()
@@ -80,7 +80,7 @@ class UnicodeDatabaseTest(unittest.TestCase):
class UnicodeFunctionsTest(UnicodeDatabaseTest):
# update this, if the database changes
- expectedchecksum = '17fe2f12b788e4fff5479b469c4404bb6ecf841f'
+ expectedchecksum = 'f0b74d26776331cc7bdc3a4698f037d73f2cee2b'
def test_function_checksum(self):
data = []
h = hashlib.sha1()
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 7a34a05d05..94f640b923 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -16,6 +16,7 @@ from nturl2path import url2pathname, pathname2url
from base64 import b64encode
import collections
+
def hexescape(char):
"""Escape char as RFC 2396 specifies"""
hex_repr = hex(ord(char))[2:].upper()
@@ -238,7 +239,7 @@ class urlopen_HttpTests(unittest.TestCase, FakeHTTPMixin):
self.check_read(b"1.1")
def test_read_bogus(self):
- # urlopen() should raise IOError for many error codes.
+ # urlopen() should raise OSError for many error codes.
self.fakehttp(b'''HTTP/1.1 401 Authentication Required
Date: Wed, 02 Jan 2008 03:03:54 GMT
Server: Apache/1.3.33 (Debian GNU/Linux) mod_ssl/2.8.22 OpenSSL/0.9.7e
@@ -246,12 +247,12 @@ Connection: close
Content-Type: text/html; charset=iso-8859-1
''')
try:
- self.assertRaises(IOError, urlopen, "http://python.org/")
+ self.assertRaises(OSError, urlopen, "http://python.org/")
finally:
self.unfakehttp()
def test_invalid_redirect(self):
- # urlopen() should raise IOError for many error codes.
+ # urlopen() should raise OSError for many error codes.
self.fakehttp(b'''HTTP/1.1 302 Found
Date: Wed, 02 Jan 2008 03:03:54 GMT
Server: Apache/1.3.33 (Debian GNU/Linux) mod_ssl/2.8.22 OpenSSL/0.9.7e
@@ -266,19 +267,20 @@ Content-Type: text/html; charset=iso-8859-1
self.unfakehttp()
def test_empty_socket(self):
- # urlopen() raises IOError if the underlying socket does not send any
+ # urlopen() raises OSError if the underlying socket does not send any
# data. (#1680230)
self.fakehttp(b'')
try:
- self.assertRaises(IOError, urlopen, "http://something")
+ self.assertRaises(OSError, urlopen, "http://something")
finally:
self.unfakehttp()
def test_missing_localfile(self):
# Test for #10836
- # 3.3 - URLError is not captured, explicit IOError is raised.
- with self.assertRaises(IOError):
+ with self.assertRaises(urllib.error.URLError) as e:
urlopen('file://localhost/a/file/which/doesnot/exists.py')
+ self.assertTrue(e.exception.filename)
+ self.assertTrue(e.exception.reason)
def test_file_notexists(self):
fd, tmp_file = tempfile.mkstemp()
@@ -291,20 +293,21 @@ Content-Type: text/html; charset=iso-8859-1
os.close(fd)
os.unlink(tmp_file)
self.assertFalse(os.path.exists(tmp_file))
- # 3.3 - IOError instead of URLError
- with self.assertRaises(IOError):
+ with self.assertRaises(urllib.error.URLError):
urlopen(tmp_fileurl)
def test_ftp_nohost(self):
test_ftp_url = 'ftp:///path'
- # 3.3 - IOError instead of URLError
- with self.assertRaises(IOError):
+ with self.assertRaises(urllib.error.URLError) as e:
urlopen(test_ftp_url)
+ self.assertFalse(e.exception.filename)
+ self.assertTrue(e.exception.reason)
def test_ftp_nonexisting(self):
- # 3.3 - IOError instead of URLError
- with self.assertRaises(IOError):
+ with self.assertRaises(urllib.error.URLError) as e:
urlopen('ftp://localhost/a/file/which/doesnot/exists.py')
+ self.assertFalse(e.exception.filename)
+ self.assertTrue(e.exception.reason)
def test_userpass_inurl(self):
@@ -341,6 +344,79 @@ Content-Type: text/html; charset=iso-8859-1
with support.check_warnings(('',DeprecationWarning)):
urllib.request.URLopener()
+class urlopen_DataTests(unittest.TestCase):
+ """Test urlopen() opening a data URL."""
+
+ def setUp(self):
+ # text containing URL special- and unicode-characters
+ self.text = "test data URLs :;,%=& \u00f6 \u00c4 "
+ # 2x1 pixel RGB PNG image with one black and one white pixel
+ self.image = (
+ b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x02\x00\x00\x00'
+ b'\x01\x08\x02\x00\x00\x00{@\xe8\xdd\x00\x00\x00\x01sRGB\x00\xae'
+ b'\xce\x1c\xe9\x00\x00\x00\x0fIDAT\x08\xd7c```\xf8\xff\xff?\x00'
+ b'\x06\x01\x02\xfe\no/\x1e\x00\x00\x00\x00IEND\xaeB`\x82')
+
+ self.text_url = (
+ "data:text/plain;charset=UTF-8,test%20data%20URLs%20%3A%3B%2C%25%3"
+ "D%26%20%C3%B6%20%C3%84%20")
+ self.text_url_base64 = (
+ "data:text/plain;charset=ISO-8859-1;base64,dGVzdCBkYXRhIFVSTHMgOjs"
+ "sJT0mIPYgxCA%3D")
+ # base64 encoded data URL that contains ignorable spaces,
+ # such as "\n", " ", "%0A", and "%20".
+ self.image_url = (
+ "\n"
+ "QOjdAAAAAXNSR0IArs4c6QAAAA9JREFUCNdj%0AYGBg%2BP//PwAGAQL%2BCm8 "
+ "vHgAAAABJRU5ErkJggg%3D%3D%0A%20")
+
+ self.text_url_resp = urllib.request.urlopen(self.text_url)
+ self.text_url_base64_resp = urllib.request.urlopen(
+ self.text_url_base64)
+ self.image_url_resp = urllib.request.urlopen(self.image_url)
+
+ def test_interface(self):
+ # Make sure object returned by urlopen() has the specified methods
+ for attr in ("read", "readline", "readlines",
+ "close", "info", "geturl", "getcode", "__iter__"):
+ self.assertTrue(hasattr(self.text_url_resp, attr),
+ "object returned by urlopen() lacks %s attribute" %
+ attr)
+
+ def test_info(self):
+ self.assertIsInstance(self.text_url_resp.info(), email.message.Message)
+ self.assertEqual(self.text_url_base64_resp.info().get_params(),
+ [('text/plain', ''), ('charset', 'ISO-8859-1')])
+ self.assertEqual(self.image_url_resp.info()['content-length'],
+ str(len(self.image)))
+ self.assertEqual(urllib.request.urlopen("data:,").info().get_params(),
+ [('text/plain', ''), ('charset', 'US-ASCII')])
+
+ def test_geturl(self):
+ self.assertEqual(self.text_url_resp.geturl(), self.text_url)
+ self.assertEqual(self.text_url_base64_resp.geturl(),
+ self.text_url_base64)
+ self.assertEqual(self.image_url_resp.geturl(), self.image_url)
+
+ def test_read_text(self):
+ self.assertEqual(self.text_url_resp.read().decode(
+ dict(self.text_url_resp.info().get_params())['charset']), self.text)
+
+ def test_read_text_base64(self):
+ self.assertEqual(self.text_url_base64_resp.read().decode(
+ dict(self.text_url_base64_resp.info().get_params())['charset']),
+ self.text)
+
+ def test_read_image(self):
+ self.assertEqual(self.image_url_resp.read(), self.image)
+
+ def test_missing_comma(self):
+ self.assertRaises(ValueError,urllib.request.urlopen,'data:text/plain')
+
+ def test_invalid_base64_data(self):
+ # missing padding character
+ self.assertRaises(ValueError,urllib.request.urlopen,'data:;base64,Cg=')
+
class urlretrieve_FileTests(unittest.TestCase):
"""Test urllib.urlretrieve() on local files"""
@@ -1291,6 +1367,7 @@ class URLopener_Tests(unittest.TestCase):
# self.assertEqual(ftp.ftp.sock.gettimeout(), 30)
# ftp.close()
+
class RequestTests(unittest.TestCase):
"""Unit tests for urllib.request.Request."""
diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py
index 33f90f48ca..dbd1c60ec2 100644
--- a/Lib/test/test_urllib2.py
+++ b/Lib/test/test_urllib2.py
@@ -11,6 +11,7 @@ import urllib.request
# The proxy bypass method imported below has logic specific to the OSX
# proxy config data structure but is testable on all platforms.
from urllib.request import Request, OpenerDirector, _proxy_bypass_macosx_sysconf
+from urllib.parse import urlparse
import urllib.error
# XXX
@@ -118,6 +119,15 @@ class RequestHdrsTests(unittest.TestCase):
self.assertIsNone(req.get_header("Not-there"))
self.assertEqual(req.get_header("Not-there", "default"), "default")
+ req.remove_header("Spam-eggs")
+ self.assertFalse(req.has_header("Spam-eggs"))
+
+ req.add_unredirected_header("Unredirected-spam", "Eggs")
+ self.assertTrue(req.has_header("Unredirected-spam"))
+
+ req.remove_header("Unredirected-spam")
+ self.assertFalse(req.has_header("Unredirected-spam"))
+
def test_password_manager(self):
mgr = urllib.request.HTTPPasswordMgr()
@@ -283,6 +293,7 @@ class MockHTTPClass:
self.req_headers = []
self.data = None
self.raise_on_endheaders = False
+ self.sock = None
self._tunnel_headers = {}
def __call__(self, host, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
@@ -310,8 +321,7 @@ class MockHTTPClass:
if body:
self.data = body
if self.raise_on_endheaders:
- import socket
- raise socket.error()
+ raise OSError()
def getresponse(self):
return MockHTTPResponse(MockFile(), {}, 200, "OK")
@@ -596,27 +606,6 @@ class OpenerDirectorTests(unittest.TestCase):
self.assertTrue(args[1] is None or
isinstance(args[1], MockResponse))
- def test_method_deprecations(self):
- req = Request("http://www.example.com")
-
- with self.assertWarns(DeprecationWarning):
- req.add_data("data")
- with self.assertWarns(DeprecationWarning):
- req.get_data()
- with self.assertWarns(DeprecationWarning):
- req.has_data()
- with self.assertWarns(DeprecationWarning):
- req.get_host()
- with self.assertWarns(DeprecationWarning):
- req.get_selector()
- with self.assertWarns(DeprecationWarning):
- req.is_unverifiable()
- with self.assertWarns(DeprecationWarning):
- req.get_origin_req_host()
- with self.assertWarns(DeprecationWarning):
- req.get_type()
-
-
def sanepathname2url(path):
try:
path.encode("utf-8")
@@ -811,7 +800,7 @@ class HandlerTests(unittest.TestCase):
("Foo", "bar"), ("Spam", "eggs")])
self.assertEqual(http.data, data)
- # check socket.error converted to URLError
+ # check OSError converted to URLError
http.raise_on_endheaders = True
self.assertRaises(urllib.error.URLError, h.do_open, http, req)
@@ -916,6 +905,36 @@ class HandlerTests(unittest.TestCase):
p_ds_req = h.do_request_(ds_req)
self.assertEqual(p_ds_req.unredirected_hdrs["Host"],"example.com")
+ def test_full_url_setter(self):
+ # Checks to ensure that components are set correctly after setting the
+ # full_url of a Request object
+
+ urls = [
+ 'http://example.com?foo=bar#baz',
+ 'http://example.com?foo=bar&spam=eggs#bash',
+ 'http://example.com',
+ ]
+
+ # testing a reusable request instance, but the url parameter is
+ # required, so just use a dummy one to instantiate
+ r = Request('http://example.com')
+ for url in urls:
+ r.full_url = url
+ parsed = urlparse(url)
+
+ self.assertEqual(r.get_full_url(), url)
+ # full_url setter uses splittag to split into components.
+ # splittag sets the fragment as None while urlparse sets it to ''
+ self.assertEqual(r.fragment or '', parsed.fragment)
+ self.assertEqual(urlparse(r.get_full_url()).query, parsed.query)
+
+ def test_full_url_deleter(self):
+ r = Request('http://www.example.com')
+ del r.full_url
+ self.assertIsNone(r.full_url)
+ self.assertIsNone(r.fragment)
+ self.assertEqual(r.selector, '')
+
def test_fixpath_in_weirdurls(self):
# Issue4493: urllib2 to supply '/' when to urls where path does not
# start with'/'
@@ -1415,6 +1434,21 @@ class MiscTests(unittest.TestCase):
self.opener_has_handler(o, MyHTTPHandler)
self.opener_has_handler(o, MyOtherHTTPHandler)
+ @unittest.skipUnless(support.is_resource_enabled('network'),
+ 'test requires network access')
+ def test_issue16464(self):
+ opener = urllib.request.build_opener()
+ request = urllib.request.Request("http://www.python.org/~jeremy/")
+ self.assertEqual(None, request.data)
+
+ opener.open(request, "1".encode("us-ascii"))
+ self.assertEqual(b"1", request.data)
+ self.assertEqual("1", request.get_header("Content-length"))
+
+ opener.open(request, "1234567890".encode("us-ascii"))
+ self.assertEqual(b"1234567890", request.data)
+ self.assertEqual("10", request.get_header("Content-length"))
+
def test_HTTPError_interface(self):
"""
Issue 13211 reveals that HTTPError didn't implement the URLError
@@ -1426,23 +1460,31 @@ class MiscTests(unittest.TestCase):
err = urllib.error.HTTPError(url, code, msg, hdrs, fp)
self.assertTrue(hasattr(err, 'reason'))
self.assertEqual(err.reason, 'something bad happened')
- self.assertTrue(hasattr(err, 'hdrs'))
- self.assertEqual(err.hdrs, 'Content-Length: 42')
+ self.assertTrue(hasattr(err, 'headers'))
+ self.assertEqual(err.headers, 'Content-Length: 42')
expected_errmsg = 'HTTP Error %s: %s' % (err.code, err.msg)
self.assertEqual(str(err), expected_errmsg)
-
class RequestTests(unittest.TestCase):
+ class PutRequest(Request):
+ method='PUT'
def setUp(self):
self.get = Request("http://www.python.org/~jeremy/")
self.post = Request("http://www.python.org/~jeremy/",
"data",
headers={"X-Test": "test"})
+ self.head = Request("http://www.python.org/~jeremy/", method='HEAD')
+ self.put = self.PutRequest("http://www.python.org/~jeremy/")
+ self.force_post = self.PutRequest("http://www.python.org/~jeremy/",
+ method="POST")
def test_method(self):
self.assertEqual("POST", self.post.get_method())
self.assertEqual("GET", self.get.get_method())
+ self.assertEqual("HEAD", self.head.get_method())
+ self.assertEqual("PUT", self.put.get_method())
+ self.assertEqual("POST", self.force_post.get_method())
def test_data(self):
self.assertFalse(self.get.data)
@@ -1451,6 +1493,25 @@ class RequestTests(unittest.TestCase):
self.assertTrue(self.get.data)
self.assertEqual("POST", self.get.get_method())
+ # issue 16464
+ # if we change data we need to remove content-length header
+ # (cause it's most probably calculated for previous value)
+ def test_setting_data_should_remove_content_length(self):
+ self.assertNotIn("Content-length", self.get.unredirected_hdrs)
+ self.get.add_unredirected_header("Content-length", 42)
+ self.assertEqual(42, self.get.unredirected_hdrs["Content-length"])
+ self.get.data = "spam"
+ self.assertNotIn("Content-length", self.get.unredirected_hdrs)
+
+ # issue 17485 same for deleting data.
+ def test_deleting_data_should_remove_content_length(self):
+ self.assertNotIn("Content-length", self.get.unredirected_hdrs)
+ self.get.data = 'foo'
+ self.get.add_unredirected_header("Content-length", 3)
+ self.assertEqual(3, self.get.unredirected_hdrs["Content-length"])
+ del self.get.data
+ self.assertNotIn("Content-length", self.get.unredirected_hdrs)
+
def test_get_full_url(self):
self.assertEqual("http://www.python.org/~jeremy/",
self.get.get_full_url())
@@ -1492,21 +1553,13 @@ class RequestTests(unittest.TestCase):
req = Request(url)
self.assertEqual(req.get_full_url(), url)
- def test_HTTPError_interface_call(self):
- """
- Issue 15701 - HTTPError interface has info method available from URLError
- """
- err = urllib.request.HTTPError(msg="something bad happened", url=None,
- code=None, hdrs='Content-Length:42', fp=None)
- self.assertTrue(hasattr(err, 'reason'))
- assert hasattr(err, 'reason')
- assert hasattr(err, 'info')
- assert callable(err.info)
- try:
- err.info()
- except AttributeError:
- self.fail('err.info call failed.')
- self.assertEqual(err.info(), "Content-Length:42")
+ def test_url_fullurl_get_full_url(self):
+ urls = ['http://docs.python.org',
+ 'http://docs.python.org/library/urllib2.html#OK',
+ 'http://www.python.org/?qs=query#fragment=true' ]
+ for url in urls:
+ req = Request(url)
+ self.assertEqual(req.get_full_url(), req.full_url)
def test_main(verbose=None):
from test import test_urllib2
diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py
index b1aa158455..08250c36e5 100644
--- a/Lib/test/test_urllib2_localnet.py
+++ b/Lib/test/test_urllib2_localnet.py
@@ -9,7 +9,10 @@ import unittest
import hashlib
from test import support
threading = support.import_module('threading')
-
+try:
+ import ssl
+except ImportError:
+ ssl = None
here = os.path.dirname(__file__)
# Self-signed cert file for 'localhost'
@@ -17,6 +20,7 @@ CERT_localhost = os.path.join(here, 'keycert.pem')
# Self-signed cert file for 'fakehostname'
CERT_fakehostname = os.path.join(here, 'keycert2.pem')
+
# Loopback http server infrastructure
class LoopbackHttpServer(http.server.HTTPServer):
@@ -353,12 +357,15 @@ class TestUrlopen(unittest.TestCase):
def setUp(self):
super(TestUrlopen, self).setUp()
# Ignore proxies for localhost tests.
+ self.old_environ = os.environ.copy()
os.environ['NO_PROXY'] = '*'
self.server = None
def tearDown(self):
if self.server is not None:
self.server.stop()
+ os.environ.clear()
+ os.environ.update(self.old_environ)
super(TestUrlopen, self).tearDown()
def urlopen(self, url, data=None, **kwargs):
@@ -386,14 +393,14 @@ class TestUrlopen(unittest.TestCase):
handler.port = port
return handler
- def start_https_server(self, responses=None, certfile=CERT_localhost):
+ def start_https_server(self, responses=None, **kwargs):
if not hasattr(urllib.request, 'HTTPSHandler'):
self.skipTest('ssl support required')
from test.ssl_servers import make_https_server
if responses is None:
responses = [(200, [], b"we care a bit")]
handler = GetRequestHandler(responses)
- server = make_https_server(self, certfile=certfile, handler_class=handler)
+ server = make_https_server(self, handler_class=handler, **kwargs)
handler.port = server.port
return handler
@@ -483,6 +490,21 @@ class TestUrlopen(unittest.TestCase):
self.urlopen("https://localhost:%s/bizarre" % handler.port,
cadefault=True)
+ def test_https_sni(self):
+ if ssl is None:
+ self.skipTest("ssl module required")
+ if not ssl.HAS_SNI:
+ self.skipTest("SNI support required in OpenSSL")
+ sni_name = None
+ def cb_sni(ssl_sock, server_name, initial_context):
+ nonlocal sni_name
+ sni_name = server_name
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.set_servername_callback(cb_sni)
+ handler = self.start_https_server(context=context, certfile=CERT_localhost)
+ self.urlopen("https://localhost:%s" % handler.port)
+ self.assertEqual(sni_name, "localhost")
+
def test_sending_headers(self):
handler = self.start_server()
req = urllib.request.Request("http://localhost:%s/" % handler.port,
@@ -529,7 +551,7 @@ class TestUrlopen(unittest.TestCase):
# so we run the test only when -unetwork/-uall is specified to
# mitigate the problem a bit (see #17564)
support.requires('network')
- self.assertRaises(IOError,
+ self.assertRaises(OSError,
# Given that both VeriSign and various ISPs have in
# the past or are presently hijacking various invalid
# domain name requests in an attempt to boost traffic
diff --git a/Lib/test/test_urllib2net.py b/Lib/test/test_urllib2net.py
index 7f3c93adaa..fba3ceac83 100644
--- a/Lib/test/test_urllib2net.py
+++ b/Lib/test/test_urllib2net.py
@@ -164,6 +164,14 @@ class OtherNetworkTests(unittest.TestCase):
self.assertEqual(res.geturl(),
"http://docs.python.org/2/glossary.html#glossary")
+ def test_redirect_url_withfrag(self):
+ redirect_url_with_frag = "http://bitly.com/urllibredirecttest"
+ with support.transient_internet(redirect_url_with_frag):
+ req = urllib.request.Request(redirect_url_with_frag)
+ res = urllib.request.urlopen(req)
+ self.assertEqual(res.geturl(),
+ "http://docs.python.org/3.4/glossary.html#term-global-interpreter-lock")
+
def test_custom_headers(self):
url = "http://www.example.com"
with support.transient_internet(url):
@@ -216,7 +224,7 @@ class OtherNetworkTests(unittest.TestCase):
debug(url)
try:
f = urlopen(url, req, TIMEOUT)
- except EnvironmentError as err:
+ except OSError as err:
debug(err)
if expected_err:
msg = ("Didn't get expected error(s) %s for %s %s, got %s: %s" %
@@ -330,35 +338,6 @@ class TimeoutTest(unittest.TestCase):
self.assertEqual(u.fp.fp.raw._sock.gettimeout(), 60)
-@unittest.skipUnless(ssl, "requires SSL support")
-class HTTPSTests(unittest.TestCase):
-
- def test_sni(self):
- self.skipTest("test disabled - test server needed")
- # Checks that Server Name Indication works, if supported by the
- # OpenSSL linked to.
- # The ssl module itself doesn't have server-side support for SNI,
- # so we rely on a third-party test site.
- expect_sni = ssl.HAS_SNI
- with support.transient_internet("XXX"):
- u = urllib.request.urlopen("XXX")
- contents = u.readall()
- if expect_sni:
- self.assertIn(b"Great", contents)
- self.assertNotIn(b"Unfortunately", contents)
- else:
- self.assertNotIn(b"Great", contents)
- self.assertIn(b"Unfortunately", contents)
-
-
-def test_main():
- support.requires("network")
- support.run_unittest(AuthTests,
- HTTPSTests,
- OtherNetworkTests,
- CloseSocketTest,
- TimeoutTest,
- )
-
if __name__ == "__main__":
- test_main()
+ support.requires("network")
+ unittest.main()
diff --git a/Lib/test/test_urllibnet.py b/Lib/test/test_urllibnet.py
index 20efca6adb..b6888f3d7b 100644
--- a/Lib/test/test_urllibnet.py
+++ b/Lib/test/test_urllibnet.py
@@ -124,16 +124,15 @@ class urlopenNetworkTests(unittest.TestCase):
else:
# This happens with some overzealous DNS providers such as OpenDNS
self.skipTest("%r should not resolve for test to work" % bogus_domain)
- self.assertRaises(IOError,
- # SF patch 809915: In Sep 2003, VeriSign started
- # highjacking invalid .com and .net addresses to
- # boost traffic to their own site. This test
- # started failing then. One hopes the .invalid
- # domain will be spared to serve its defined
- # purpose.
- # urllib.urlopen, "http://www.sadflkjsasadf.com/")
- urllib.request.urlopen,
- "http://sadflkjsasf.i.nvali.d/")
+ failure_explanation = ('opening an invalid URL did not raise OSError; '
+ 'can be caused by a broken DNS server '
+ '(e.g. returns 404 or hijacks page)')
+ with self.assertRaises(OSError, msg=failure_explanation):
+ # SF patch 809915: In Sep 2003, VeriSign started highjacking
+ # invalid .com and .net addresses to boost traffic to their own
+ # site. This test started failing then. One hopes the .invalid
+ # domain will be spared to serve its defined purpose.
+ urllib.request.urlopen("http://sadflkjsasf.i.nvali.d/")
class urlretrieveNetworkTests(unittest.TestCase):
diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py
index 378a427bc5..c938f09951 100755
--- a/Lib/test/test_urlparse.py
+++ b/Lib/test/test_urlparse.py
@@ -847,6 +847,14 @@ class UrlParseTestCase(unittest.TestCase):
self.assertEqual(p1.path, '863-1234')
self.assertEqual(p1.params, 'phone-context=+1-914-555')
+ def test_unwrap(self):
+ url = urllib.parse.unwrap('<URL:type://host/path>')
+ self.assertEqual(url, 'type://host/path')
+
+ def test_Quoter_repr(self):
+ quoter = urllib.parse.Quoter(urllib.parse._ALWAYS_SAFE)
+ self.assertIn('Quoter', repr(quoter))
+
def test_main():
support.run_unittest(UrlParseTestCase)
diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py
index 0fae88bcf9..dbbe1570c9 100644
--- a/Lib/test/test_venv.py
+++ b/Lib/test/test_venv.py
@@ -109,13 +109,68 @@ class BasicTest(BaseTest):
out, err = p.communicate()
self.assertEqual(out.strip(), expected.encode())
+ if sys.platform == 'win32':
+ ENV_SUBDIRS = (
+ ('Scripts',),
+ ('Include',),
+ ('Lib',),
+ ('Lib', 'site-packages'),
+ )
+ else:
+ ENV_SUBDIRS = (
+ ('bin',),
+ ('include',),
+ ('lib',),
+ ('lib', 'python%d.%d' % sys.version_info[:2]),
+ ('lib', 'python%d.%d' % sys.version_info[:2], 'site-packages'),
+ )
+
+ def create_contents(self, paths, filename):
+ """
+ Create some files in the environment which are unrelated
+ to the virtual environment.
+ """
+ for subdirs in paths:
+ d = os.path.join(self.env_dir, *subdirs)
+ os.mkdir(d)
+ fn = os.path.join(d, filename)
+ with open(fn, 'wb') as f:
+ f.write(b'Still here?')
+
def test_overwrite_existing(self):
"""
- Test control of overwriting an existing environment directory.
+ Test creating environment in an existing directory.
"""
- self.assertRaises(ValueError, venv.create, self.env_dir)
+ self.create_contents(self.ENV_SUBDIRS, 'foo')
+ venv.create(self.env_dir)
+ for subdirs in self.ENV_SUBDIRS:
+ fn = os.path.join(self.env_dir, *(subdirs + ('foo',)))
+ self.assertTrue(os.path.exists(fn))
+ with open(fn, 'rb') as f:
+ self.assertEqual(f.read(), b'Still here?')
+
builder = venv.EnvBuilder(clear=True)
builder.create(self.env_dir)
+ for subdirs in self.ENV_SUBDIRS:
+ fn = os.path.join(self.env_dir, *(subdirs + ('foo',)))
+ self.assertFalse(os.path.exists(fn))
+
+ def clear_directory(self, path):
+ for fn in os.listdir(path):
+ fn = os.path.join(path, fn)
+ if os.path.islink(fn) or os.path.isfile(fn):
+ os.remove(fn)
+ elif os.path.isdir(fn):
+ shutil.rmtree(fn)
+
+ def test_unoverwritable_fails(self):
+ #create a file clashing with directories in the env dir
+ for paths in self.ENV_SUBDIRS[:3]:
+ fn = os.path.join(self.env_dir, *paths)
+ with open(fn, 'wb') as f:
+ f.write(b'')
+ self.assertRaises((ValueError, OSError), venv.create, self.env_dir)
+ self.clear_directory(self.env_dir)
def test_upgrade(self):
"""
diff --git a/Lib/test/test_wait3.py b/Lib/test/test_wait3.py
index bd06c8d8bd..f6a065d850 100644
--- a/Lib/test/test_wait3.py
+++ b/Lib/test/test_wait3.py
@@ -7,15 +7,11 @@ import unittest
from test.fork_wait import ForkWait
from test.support import run_unittest, reap_children
-try:
- os.fork
-except AttributeError:
- raise unittest.SkipTest("os.fork not defined -- skipping test_wait3")
+if not hasattr(os, 'fork'):
+ raise unittest.SkipTest("os.fork not defined")
-try:
- os.wait3
-except AttributeError:
- raise unittest.SkipTest("os.wait3 not defined -- skipping test_wait3")
+if not hasattr(os, 'wait3'):
+ raise unittest.SkipTest("os.wait3 not defined")
class Wait3Test(ForkWait):
def wait_impl(self, cpid):
diff --git a/Lib/test/test_warnings.py b/Lib/test/test_warnings.py
index 9f6d775c59..87463ac1a3 100644
--- a/Lib/test/test_warnings.py
+++ b/Lib/test/test_warnings.py
@@ -330,6 +330,19 @@ class WarnTests(BaseTest):
warning_tests.__name__ = module_name
sys.argv = argv
+ def test_warn_explicit_non_ascii_filename(self):
+ with original_warnings.catch_warnings(record=True,
+ module=self.module) as w:
+ self.module.resetwarnings()
+ self.module.filterwarnings("always", category=UserWarning)
+ for filename in ("nonascii\xe9\u20ac", "surrogate\udc80"):
+ try:
+ os.fsencode(filename)
+ except UnicodeEncodeError:
+ continue
+ self.module.warn_explicit("text", UserWarning, filename, 1)
+ self.assertEqual(w[-1].filename, filename)
+
def test_warn_explicit_type_errors(self):
# warn_explicit() should error out gracefully if it is given objects
# of the wrong types.
@@ -787,6 +800,25 @@ class BootstrapTest(unittest.TestCase):
env=env)
self.assertEqual(retcode, 0)
+class FinalizationTest(unittest.TestCase):
+ def test_finalization(self):
+ # Issue #19421: warnings.warn() should not crash
+ # during Python finalization
+ code = """
+import warnings
+warn = warnings.warn
+
+class A:
+ def __del__(self):
+ warn("test")
+
+a=A()
+ """
+ rc, out, err = assert_python_ok("-c", code)
+ # note: "__main__" filename is not correct, it should be the name
+ # of the script
+ self.assertEqual(err, b'__main__:7: UserWarning: test')
+
def setUpModule():
py_warnings.onceregistry.clear()
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index 571e33f492..551d95cb91 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -7,11 +7,15 @@ import operator
import contextlib
import copy
-from test import support
+from test import support, script_helper
# Used in ReferencesTestCase.test_ref_created_during_del() .
ref_from_del = None
+# Used by FinalizeTestCase as a global that may be replaced by None
+# when the interpreter shuts down.
+_global_var = 'foobar'
+
class C:
def method(self):
pass
@@ -47,6 +51,11 @@ class Object:
return NotImplemented
def __hash__(self):
return hash(self.arg)
+ def some_method(self):
+ return 4
+ def other_method(self):
+ return 5
+
class RefCycle:
def __init__(self):
@@ -797,6 +806,30 @@ class ReferencesTestCase(TestBase):
del root
gc.collect()
+ def test_callback_attribute(self):
+ x = Object(1)
+ callback = lambda ref: None
+ ref1 = weakref.ref(x, callback)
+ self.assertIs(ref1.__callback__, callback)
+
+ ref2 = weakref.ref(x)
+ self.assertIsNone(ref2.__callback__)
+
+ def test_callback_attribute_after_deletion(self):
+ x = Object(1)
+ ref = weakref.ref(x, self.callback)
+ self.assertIsNotNone(ref.__callback__)
+ del x
+ support.gc_collect()
+ self.assertIsNone(ref.__callback__)
+
+ def test_set_callback_attribute(self):
+ x = Object(1)
+ callback = lambda ref: None
+ ref1 = weakref.ref(x, callback)
+ with self.assertRaises(AttributeError):
+ ref1.__callback__ = lambda ref: None
+
class SubclassableWeakrefTestCase(TestBase):
@@ -901,6 +934,140 @@ class SubclassableWeakrefTestCase(TestBase):
self.assertEqual(self.cbcalled, 0)
+class WeakMethodTestCase(unittest.TestCase):
+
+ def _subclass(self):
+ """Return a Object subclass overriding `some_method`."""
+ class C(Object):
+ def some_method(self):
+ return 6
+ return C
+
+ def test_alive(self):
+ o = Object(1)
+ r = weakref.WeakMethod(o.some_method)
+ self.assertIsInstance(r, weakref.ReferenceType)
+ self.assertIsInstance(r(), type(o.some_method))
+ self.assertIs(r().__self__, o)
+ self.assertIs(r().__func__, o.some_method.__func__)
+ self.assertEqual(r()(), 4)
+
+ def test_object_dead(self):
+ o = Object(1)
+ r = weakref.WeakMethod(o.some_method)
+ del o
+ gc.collect()
+ self.assertIs(r(), None)
+
+ def test_method_dead(self):
+ C = self._subclass()
+ o = C(1)
+ r = weakref.WeakMethod(o.some_method)
+ del C.some_method
+ gc.collect()
+ self.assertIs(r(), None)
+
+ def test_callback_when_object_dead(self):
+ # Test callback behaviour when object dies first.
+ C = self._subclass()
+ calls = []
+ def cb(arg):
+ calls.append(arg)
+ o = C(1)
+ r = weakref.WeakMethod(o.some_method, cb)
+ del o
+ gc.collect()
+ self.assertEqual(calls, [r])
+ # Callback is only called once.
+ C.some_method = Object.some_method
+ gc.collect()
+ self.assertEqual(calls, [r])
+
+ def test_callback_when_method_dead(self):
+ # Test callback behaviour when method dies first.
+ C = self._subclass()
+ calls = []
+ def cb(arg):
+ calls.append(arg)
+ o = C(1)
+ r = weakref.WeakMethod(o.some_method, cb)
+ del C.some_method
+ gc.collect()
+ self.assertEqual(calls, [r])
+ # Callback is only called once.
+ del o
+ gc.collect()
+ self.assertEqual(calls, [r])
+
+ @support.cpython_only
+ def test_no_cycles(self):
+ # A WeakMethod doesn't create any reference cycle to itself.
+ o = Object(1)
+ def cb(_):
+ pass
+ r = weakref.WeakMethod(o.some_method, cb)
+ wr = weakref.ref(r)
+ del r
+ self.assertIs(wr(), None)
+
+ def test_equality(self):
+ def _eq(a, b):
+ self.assertTrue(a == b)
+ self.assertFalse(a != b)
+ def _ne(a, b):
+ self.assertTrue(a != b)
+ self.assertFalse(a == b)
+ x = Object(1)
+ y = Object(1)
+ a = weakref.WeakMethod(x.some_method)
+ b = weakref.WeakMethod(y.some_method)
+ c = weakref.WeakMethod(x.other_method)
+ d = weakref.WeakMethod(y.other_method)
+ # Objects equal, same method
+ _eq(a, b)
+ _eq(c, d)
+ # Objects equal, different method
+ _ne(a, c)
+ _ne(a, d)
+ _ne(b, c)
+ _ne(b, d)
+ # Objects unequal, same or different method
+ z = Object(2)
+ e = weakref.WeakMethod(z.some_method)
+ f = weakref.WeakMethod(z.other_method)
+ _ne(a, e)
+ _ne(a, f)
+ _ne(b, e)
+ _ne(b, f)
+ del x, y, z
+ gc.collect()
+ # Dead WeakMethods compare by identity
+ refs = a, b, c, d, e, f
+ for q in refs:
+ for r in refs:
+ self.assertEqual(q == r, q is r)
+ self.assertEqual(q != r, q is not r)
+
+ def test_hashing(self):
+ # Alive WeakMethods are hashable if the underlying object is
+ # hashable.
+ x = Object(1)
+ y = Object(1)
+ a = weakref.WeakMethod(x.some_method)
+ b = weakref.WeakMethod(y.some_method)
+ c = weakref.WeakMethod(y.other_method)
+ # Since WeakMethod objects are equal, the hashes should be equal.
+ self.assertEqual(hash(a), hash(b))
+ ha = hash(a)
+ # Dead WeakMethods retain their old hash value
+ del x, y
+ gc.collect()
+ self.assertEqual(hash(a), ha)
+ self.assertEqual(hash(b), ha)
+ # If it wasn't hashed when alive, a dead WeakMethod cannot be hashed.
+ self.assertRaises(TypeError, hash, c)
+
+
class MappingTestCase(TestBase):
COUNT = 10
@@ -1388,6 +1555,151 @@ class WeakKeyDictionaryTestCase(mapping_tests.BasicTestMappingProtocol):
def _reference(self):
return self.__ref.copy()
+
+class FinalizeTestCase(unittest.TestCase):
+
+ class A:
+ pass
+
+ def _collect_if_necessary(self):
+ # we create no ref-cycles so in CPython no gc should be needed
+ if sys.implementation.name != 'cpython':
+ support.gc_collect()
+
+ def test_finalize(self):
+ def add(x,y,z):
+ res.append(x + y + z)
+ return x + y + z
+
+ a = self.A()
+
+ res = []
+ f = weakref.finalize(a, add, 67, 43, z=89)
+ self.assertEqual(f.alive, True)
+ self.assertEqual(f.peek(), (a, add, (67,43), {'z':89}))
+ self.assertEqual(f(), 199)
+ self.assertEqual(f(), None)
+ self.assertEqual(f(), None)
+ self.assertEqual(f.peek(), None)
+ self.assertEqual(f.detach(), None)
+ self.assertEqual(f.alive, False)
+ self.assertEqual(res, [199])
+
+ res = []
+ f = weakref.finalize(a, add, 67, 43, 89)
+ self.assertEqual(f.peek(), (a, add, (67,43,89), {}))
+ self.assertEqual(f.detach(), (a, add, (67,43,89), {}))
+ self.assertEqual(f(), None)
+ self.assertEqual(f(), None)
+ self.assertEqual(f.peek(), None)
+ self.assertEqual(f.detach(), None)
+ self.assertEqual(f.alive, False)
+ self.assertEqual(res, [])
+
+ res = []
+ f = weakref.finalize(a, add, x=67, y=43, z=89)
+ del a
+ self._collect_if_necessary()
+ self.assertEqual(f(), None)
+ self.assertEqual(f(), None)
+ self.assertEqual(f.peek(), None)
+ self.assertEqual(f.detach(), None)
+ self.assertEqual(f.alive, False)
+ self.assertEqual(res, [199])
+
+ def test_order(self):
+ a = self.A()
+ res = []
+
+ f1 = weakref.finalize(a, res.append, 'f1')
+ f2 = weakref.finalize(a, res.append, 'f2')
+ f3 = weakref.finalize(a, res.append, 'f3')
+ f4 = weakref.finalize(a, res.append, 'f4')
+ f5 = weakref.finalize(a, res.append, 'f5')
+
+ # make sure finalizers can keep themselves alive
+ del f1, f4
+
+ self.assertTrue(f2.alive)
+ self.assertTrue(f3.alive)
+ self.assertTrue(f5.alive)
+
+ self.assertTrue(f5.detach())
+ self.assertFalse(f5.alive)
+
+ f5() # nothing because previously unregistered
+ res.append('A')
+ f3() # => res.append('f3')
+ self.assertFalse(f3.alive)
+ res.append('B')
+ f3() # nothing because previously called
+ res.append('C')
+ del a
+ self._collect_if_necessary()
+ # => res.append('f4')
+ # => res.append('f2')
+ # => res.append('f1')
+ self.assertFalse(f2.alive)
+ res.append('D')
+ f2() # nothing because previously called by gc
+
+ expected = ['A', 'f3', 'B', 'C', 'f4', 'f2', 'f1', 'D']
+ self.assertEqual(res, expected)
+
+ def test_all_freed(self):
+ # we want a weakrefable subclass of weakref.finalize
+ class MyFinalizer(weakref.finalize):
+ pass
+
+ a = self.A()
+ res = []
+ def callback():
+ res.append(123)
+ f = MyFinalizer(a, callback)
+
+ wr_callback = weakref.ref(callback)
+ wr_f = weakref.ref(f)
+ del callback, f
+
+ self.assertIsNotNone(wr_callback())
+ self.assertIsNotNone(wr_f())
+
+ del a
+ self._collect_if_necessary()
+
+ self.assertIsNone(wr_callback())
+ self.assertIsNone(wr_f())
+ self.assertEqual(res, [123])
+
+ @classmethod
+ def run_in_child(cls):
+ def error():
+ # Create an atexit finalizer from inside a finalizer called
+ # at exit. This should be the next to be run.
+ g1 = weakref.finalize(cls, print, 'g1')
+ print('f3 error')
+ 1/0
+
+ # cls should stay alive till atexit callbacks run
+ f1 = weakref.finalize(cls, print, 'f1', _global_var)
+ f2 = weakref.finalize(cls, print, 'f2', _global_var)
+ f3 = weakref.finalize(cls, error)
+ f4 = weakref.finalize(cls, print, 'f4', _global_var)
+
+ assert f1.atexit == True
+ f2.atexit = False
+ assert f3.atexit == True
+ assert f4.atexit == True
+
+ def test_atexit(self):
+ prog = ('from test.test_weakref import FinalizeTestCase;'+
+ 'FinalizeTestCase.run_in_child()')
+ rc, out, err = script_helper.assert_python_ok('-c', prog)
+ out = out.decode('ascii').splitlines()
+ self.assertEqual(out, ['f4 foobar', 'f3 error', 'g1', 'f1 foobar'])
+ self.assertTrue(b'ZeroDivisionError' in err)
+
+
libreftest = """ Doctest for examples in the library reference: weakref.rst
>>> import weakref
@@ -1476,10 +1788,12 @@ __test__ = {'libreftest' : libreftest}
def test_main():
support.run_unittest(
ReferencesTestCase,
+ WeakMethodTestCase,
MappingTestCase,
WeakValueDictionaryTestCase,
WeakKeyDictionaryTestCase,
SubclassableWeakrefTestCase,
+ FinalizeTestCase,
)
support.run_doctest(sys.modules[__name__])
diff --git a/Lib/test/test_winreg.py b/Lib/test/test_winreg.py
index cb4cde9bc6..ef4ce552f1 100644
--- a/Lib/test/test_winreg.py
+++ b/Lib/test/test_winreg.py
@@ -8,7 +8,7 @@ threading = support.import_module("threading")
from platform import machine
# Do this first so test will be skipped if module doesn't exist
-support.import_module('winreg')
+support.import_module('winreg', required_on=['win'])
# Now import everything
from winreg import *
@@ -57,13 +57,13 @@ class BaseWinregTests(unittest.TestCase):
def delete_tree(self, root, subkey):
try:
hkey = OpenKey(root, subkey, KEY_ALL_ACCESS)
- except WindowsError:
+ except OSError:
# subkey does not exist
return
while True:
try:
subsubkey = EnumKey(hkey, 0)
- except WindowsError:
+ except OSError:
# no more subkeys
break
self.delete_tree(hkey, subsubkey)
@@ -100,7 +100,7 @@ class BaseWinregTests(unittest.TestCase):
QueryInfoKey(int_sub_key)
self.fail("It appears the CloseKey() function does "
"not close the actual key!")
- except EnvironmentError:
+ except OSError:
pass
# ... and close that key that way :-)
int_key = int(key)
@@ -109,7 +109,7 @@ class BaseWinregTests(unittest.TestCase):
QueryInfoKey(int_key)
self.fail("It appears the key.Close() function "
"does not close the actual key!")
- except EnvironmentError:
+ except OSError:
pass
def _read_test_data(self, root_key, subkeystr="sub_key", OpenKey=OpenKey):
@@ -126,7 +126,7 @@ class BaseWinregTests(unittest.TestCase):
while 1:
try:
data = EnumValue(sub_key, index)
- except EnvironmentError:
+ except OSError:
break
self.assertEqual(data in test_data, True,
"Didn't read back the correct test data")
@@ -147,7 +147,7 @@ class BaseWinregTests(unittest.TestCase):
try:
EnumKey(key, 1)
self.fail("Was able to get a second key when I only have one!")
- except EnvironmentError:
+ except OSError:
pass
key.Close()
@@ -171,7 +171,7 @@ class BaseWinregTests(unittest.TestCase):
# Shouldnt be able to delete it twice!
DeleteKey(key, subkeystr)
self.fail("Deleting the key twice succeeded")
- except EnvironmentError:
+ except OSError:
pass
key.Close()
DeleteKey(root_key, test_key_name)
@@ -179,7 +179,7 @@ class BaseWinregTests(unittest.TestCase):
try:
key = OpenKey(root_key, test_key_name)
self.fail("Could open the non-existent key")
- except WindowsError: # Use this error name this time
+ except OSError: # Use this error name this time
pass
def _test_all(self, root_key, subkeystr="sub_key"):
@@ -230,7 +230,7 @@ class LocalWinregTests(BaseWinregTests):
def test_inexistant_remote_registry(self):
connect = lambda: ConnectRegistry("abcdefghijkl", HKEY_CURRENT_USER)
- self.assertRaises(WindowsError, connect)
+ self.assertRaises(OSError, connect)
def testExpandEnvironmentStrings(self):
r = ExpandEnvironmentStrings("%windir%\\test")
@@ -242,8 +242,8 @@ class LocalWinregTests(BaseWinregTests):
try:
with ConnectRegistry(None, HKEY_LOCAL_MACHINE) as h:
self.assertNotEqual(h.handle, 0)
- raise WindowsError
- except WindowsError:
+ raise OSError
+ except OSError:
self.assertEqual(h.handle, 0)
def test_changing_value(self):
@@ -407,7 +407,7 @@ class Win64WinregTests(BaseWinregTests):
open_fail = lambda: OpenKey(HKEY_CURRENT_USER,
test_reflect_key_name, 0,
KEY_READ | KEY_WOW64_64KEY)
- self.assertRaises(WindowsError, open_fail)
+ self.assertRaises(OSError, open_fail)
# Now explicitly open the 64-bit version of the key
with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0,
@@ -447,7 +447,7 @@ class Win64WinregTests(BaseWinregTests):
open_fail = lambda: OpenKeyEx(HKEY_CURRENT_USER,
test_reflect_key_name, 0,
KEY_READ | KEY_WOW64_64KEY)
- self.assertRaises(WindowsError, open_fail)
+ self.assertRaises(OSError, open_fail)
# Make sure the 32-bit key is actually there
with OpenKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0,
diff --git a/Lib/test/test_winsound.py b/Lib/test/test_winsound.py
index eb7f75f066..61d864a648 100644
--- a/Lib/test/test_winsound.py
+++ b/Lib/test/test_winsound.py
@@ -22,7 +22,7 @@ def has_sound(sound):
key = winreg.OpenKeyEx(winreg.HKEY_CURRENT_USER,
"AppEvents\Schemes\Apps\.Default\{0}\.Default".format(sound))
return winreg.EnumValue(key, 0)[1] != ""
- except WindowsError:
+ except OSError:
return False
class BeepTest(unittest.TestCase):
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 54965345c7..614e598f6c 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -10,6 +10,7 @@ import io
import operator
import pickle
import sys
+import types
import unittest
import weakref
@@ -240,7 +241,6 @@ class ElementTreeTest(unittest.TestCase):
self.assertEqual(ET.XML, ET.fromstring)
self.assertEqual(ET.PI, ET.ProcessingInstruction)
- self.assertEqual(ET.XMLParser, ET.XMLTreeBuilder)
def test_simpleops(self):
# Basic method sanity checks.
@@ -433,15 +433,6 @@ class ElementTreeTest(unittest.TestCase):
' <empty-element />\n'
'</root>')
- parser = ET.XMLTreeBuilder() # 1.2 compatibility
- parser.feed(data)
- self.serialize_check(parser.close(),
- '<root>\n'
- ' <element key="value">text</element>\n'
- ' <element>text</element>tail\n'
- ' <empty-element />\n'
- '</root>')
-
target = ET.TreeBuilder()
parser = ET.XMLParser(target=target)
parser.feed(data)
@@ -959,6 +950,160 @@ class ElementTreeTest(unittest.TestCase):
self.assertEqual(serialized, expected)
+class XMLPullParserTest(unittest.TestCase):
+
+ def _feed(self, parser, data, chunk_size=None):
+ if chunk_size is None:
+ parser.feed(data)
+ else:
+ for i in range(0, len(data), chunk_size):
+ parser.feed(data[i:i+chunk_size])
+
+ def assert_event_tags(self, parser, expected):
+ events = parser.read_events()
+ self.assertEqual([(action, elem.tag) for action, elem in events],
+ expected)
+
+ def test_simple_xml(self):
+ for chunk_size in (None, 1, 5):
+ with self.subTest(chunk_size=chunk_size):
+ parser = ET.XMLPullParser()
+ self.assert_event_tags(parser, [])
+ self._feed(parser, "<!-- comment -->\n", chunk_size)
+ self.assert_event_tags(parser, [])
+ self._feed(parser,
+ "<root>\n <element key='value'>text</element",
+ chunk_size)
+ self.assert_event_tags(parser, [])
+ self._feed(parser, ">\n", chunk_size)
+ self.assert_event_tags(parser, [('end', 'element')])
+ self._feed(parser, "<element>text</element>tail\n", chunk_size)
+ self._feed(parser, "<empty-element/>\n", chunk_size)
+ self.assert_event_tags(parser, [
+ ('end', 'element'),
+ ('end', 'empty-element'),
+ ])
+ self._feed(parser, "</root>\n", chunk_size)
+ self.assert_event_tags(parser, [('end', 'root')])
+ self.assertIsNone(parser.close())
+
+ def test_feed_while_iterating(self):
+ parser = ET.XMLPullParser()
+ it = parser.read_events()
+ self._feed(parser, "<root>\n <element key='value'>text</element>\n")
+ action, elem = next(it)
+ self.assertEqual((action, elem.tag), ('end', 'element'))
+ self._feed(parser, "</root>\n")
+ action, elem = next(it)
+ self.assertEqual((action, elem.tag), ('end', 'root'))
+ with self.assertRaises(StopIteration):
+ next(it)
+
+ def test_simple_xml_with_ns(self):
+ parser = ET.XMLPullParser()
+ self.assert_event_tags(parser, [])
+ self._feed(parser, "<!-- comment -->\n")
+ self.assert_event_tags(parser, [])
+ self._feed(parser, "<root xmlns='namespace'>\n")
+ self.assert_event_tags(parser, [])
+ self._feed(parser, "<element key='value'>text</element")
+ self.assert_event_tags(parser, [])
+ self._feed(parser, ">\n")
+ self.assert_event_tags(parser, [('end', '{namespace}element')])
+ self._feed(parser, "<element>text</element>tail\n")
+ self._feed(parser, "<empty-element/>\n")
+ self.assert_event_tags(parser, [
+ ('end', '{namespace}element'),
+ ('end', '{namespace}empty-element'),
+ ])
+ self._feed(parser, "</root>\n")
+ self.assert_event_tags(parser, [('end', '{namespace}root')])
+ self.assertIsNone(parser.close())
+
+ def test_ns_events(self):
+ parser = ET.XMLPullParser(events=('start-ns', 'end-ns'))
+ self._feed(parser, "<!-- comment -->\n")
+ self._feed(parser, "<root xmlns='namespace'>\n")
+ self.assertEqual(
+ list(parser.read_events()),
+ [('start-ns', ('', 'namespace'))])
+ self._feed(parser, "<element key='value'>text</element")
+ self._feed(parser, ">\n")
+ self._feed(parser, "<element>text</element>tail\n")
+ self._feed(parser, "<empty-element/>\n")
+ self._feed(parser, "</root>\n")
+ self.assertEqual(list(parser.read_events()), [('end-ns', None)])
+ self.assertIsNone(parser.close())
+
+ def test_events(self):
+ parser = ET.XMLPullParser(events=())
+ self._feed(parser, "<root/>\n")
+ self.assert_event_tags(parser, [])
+
+ parser = ET.XMLPullParser(events=('start', 'end'))
+ self._feed(parser, "<!-- comment -->\n")
+ self.assert_event_tags(parser, [])
+ self._feed(parser, "<root>\n")
+ self.assert_event_tags(parser, [('start', 'root')])
+ self._feed(parser, "<element key='value'>text</element")
+ self.assert_event_tags(parser, [('start', 'element')])
+ self._feed(parser, ">\n")
+ self.assert_event_tags(parser, [('end', 'element')])
+ self._feed(parser,
+ "<element xmlns='foo'>text<empty-element/></element>tail\n")
+ self.assert_event_tags(parser, [
+ ('start', '{foo}element'),
+ ('start', '{foo}empty-element'),
+ ('end', '{foo}empty-element'),
+ ('end', '{foo}element'),
+ ])
+ self._feed(parser, "</root>")
+ self.assertIsNone(parser.close())
+ self.assert_event_tags(parser, [('end', 'root')])
+
+ parser = ET.XMLPullParser(events=('start',))
+ self._feed(parser, "<!-- comment -->\n")
+ self.assert_event_tags(parser, [])
+ self._feed(parser, "<root>\n")
+ self.assert_event_tags(parser, [('start', 'root')])
+ self._feed(parser, "<element key='value'>text</element")
+ self.assert_event_tags(parser, [('start', 'element')])
+ self._feed(parser, ">\n")
+ self.assert_event_tags(parser, [])
+ self._feed(parser,
+ "<element xmlns='foo'>text<empty-element/></element>tail\n")
+ self.assert_event_tags(parser, [
+ ('start', '{foo}element'),
+ ('start', '{foo}empty-element'),
+ ])
+ self._feed(parser, "</root>")
+ self.assertIsNone(parser.close())
+
+ def test_events_sequence(self):
+ # Test that events can be some sequence that's not just a tuple or list
+ eventset = {'end', 'start'}
+ parser = ET.XMLPullParser(events=eventset)
+ self._feed(parser, "<foo>bar</foo>")
+ self.assert_event_tags(parser, [('start', 'foo'), ('end', 'foo')])
+
+ class DummyIter:
+ def __init__(self):
+ self.events = iter(['start', 'end', 'start-ns'])
+ def __iter__(self):
+ return self
+ def __next__(self):
+ return next(self.events)
+
+ parser = ET.XMLPullParser(events=DummyIter())
+ self._feed(parser, "<foo>bar</foo>")
+ self.assert_event_tags(parser, [('start', 'foo'), ('end', 'foo')])
+
+
+ def test_unknown_event(self):
+ with self.assertRaises(ValueError):
+ ET.XMLPullParser(events=('start', 'end', 'bogus'))
+
+
#
# xinclude tests (samples from appendix C of the xinclude specification)
@@ -1300,7 +1445,7 @@ class BugsTest(unittest.TestCase):
# Don't crash when using custom entities.
ENTITIES = {'rsquo': '\u2019', 'lsquo': '\u2018'}
- parser = ET.XMLTreeBuilder()
+ parser = ET.XMLParser()
parser.entity.update(ENTITIES)
parser.feed("""<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE patent-application-publication SYSTEM "pap-v15-2001-01-31.dtd" []>
@@ -1638,6 +1783,11 @@ class ElementFindTest(unittest.TestCase):
self.assertEqual(e.find('./tag[last()-1]').attrib['class'], 'c')
self.assertEqual(e.find('./tag[last()-2]').attrib['class'], 'b')
+ self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[0]')
+ self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[-1]')
+ self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[last()-0]')
+ self.assertRaisesRegex(SyntaxError, 'XPath', e.find, './tag[last()+1]')
+
def test_findall(self):
e = ET.XML(SAMPLE_XML)
e[2] = ET.XML(SAMPLE_SECTION)
@@ -1897,7 +2047,7 @@ class TreeBuilderTest(unittest.TestCase):
# Mimick SimpleTAL's behaviour (issue #16089): both versions of
# TreeBuilder should be able to cope with a subclass of the
# pure Python Element class.
- base = ET._Element
+ base = ET._Element_Py
# Not from a C extension
self.assertEqual(base.__module__, 'xml.etree.ElementTree')
# Force some multiple inheritance with a C class to make things
@@ -2256,6 +2406,18 @@ class IOTest(unittest.TestCase):
ET.tostring(root, 'utf-16'),
b''.join(ET.tostringlist(root, 'utf-16')))
+ def test_short_empty_elements(self):
+ root = ET.fromstring('<tag>a<x />b<y></y>c</tag>')
+ self.assertEqual(
+ ET.tostring(root, 'unicode'),
+ '<tag>a<x />b<y />c</tag>')
+ self.assertEqual(
+ ET.tostring(root, 'unicode', short_empty_elements=True),
+ '<tag>a<x />b<y />c</tag>')
+ self.assertEqual(
+ ET.tostring(root, 'unicode', short_empty_elements=False),
+ '<tag>a<x></x>b<y></y>c</tag>')
+
class ParseErrorTest(unittest.TestCase):
def test_subclass(self):
@@ -2321,8 +2483,11 @@ class NoAcceleratorTest(unittest.TestCase):
# Test that the C accelerator was not imported for pyET
def test_correct_import_pyET(self):
- self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree')
- self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree')
+ # The type of methods defined in Python code is types.FunctionType,
+ # while the type of methods defined inside _elementtree is
+ # <class 'wrapper_descriptor'>
+ self.assertIsInstance(pyET.Element.__init__, types.FunctionType)
+ self.assertIsInstance(pyET.XMLParser.__init__, types.FunctionType)
# --------------------------------------------------------------------
@@ -2392,6 +2557,7 @@ def test_main(module=None):
ElementIterTest,
TreeBuilderTest,
XMLParserTest,
+ XMLPullParserTest,
BugsTest,
]
diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py
index bcaa724344..b3ff7ae37d 100644
--- a/Lib/test/test_xml_etree_c.py
+++ b/Lib/test/test_xml_etree_c.py
@@ -2,6 +2,7 @@
import sys, struct
from test import support
from test.support import import_fresh_module
+import types
import unittest
cET = import_fresh_module('xml.etree.ElementTree',
@@ -33,14 +34,22 @@ class TestAliasWorking(unittest.TestCase):
@unittest.skipUnless(cET, 'requires _elementtree')
+@support.cpython_only
class TestAcceleratorImported(unittest.TestCase):
# Test that the C accelerator was imported, as expected
def test_correct_import_cET(self):
+ # SubElement is a function so it retains _elementtree as its module.
self.assertEqual(cET.SubElement.__module__, '_elementtree')
def test_correct_import_cET_alias(self):
self.assertEqual(cET_alias.SubElement.__module__, '_elementtree')
+ def test_parser_comes_from_C(self):
+ # The type of methods defined in Python code is types.FunctionType,
+ # while the type of methods defined inside _elementtree is
+ # <class 'wrapper_descriptor'>
+ self.assertNotIsInstance(cET.Element.__init__, types.FunctionType)
+
@unittest.skipUnless(cET, 'requires _elementtree')
@support.cpython_only
diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py
index cc52259c82..99b3eda13e 100644
--- a/Lib/test/test_xmlrpc.py
+++ b/Lib/test/test_xmlrpc.py
@@ -15,6 +15,10 @@ import contextlib
from test import support
try:
+ import gzip
+except ImportError:
+ gzip = None
+try:
import threading
except ImportError:
threading = None
@@ -216,7 +220,7 @@ class XMLRPCTestCase(unittest.TestCase):
xmlrpc.client.ServerProxy('https://localhost:9999').bad_function()
except NotImplementedError:
self.assertFalse(has_ssl, "xmlrpc client's error with SSL support")
- except socket.error:
+ except OSError:
self.assertTrue(has_ssl)
class HelperTestCase(unittest.TestCase):
@@ -376,6 +380,11 @@ def http_server(evt, numrequests, requestHandler=None):
if name == 'div':
return 'This is the div function'
+ class Fixture:
+ @staticmethod
+ def getData():
+ return '42'
+
def my_function():
'''This is my function'''
return True
@@ -407,7 +416,8 @@ def http_server(evt, numrequests, requestHandler=None):
serv.register_function(pow)
serv.register_function(lambda x,y: x+y, 'add')
serv.register_function(my_function)
- serv.register_instance(TestInstanceClass())
+ testInstance = TestInstanceClass()
+ serv.register_instance(testInstance, allow_dotted_names=True)
evt.set()
# handle up to 'numrequests' requests
@@ -500,7 +510,7 @@ def is_unavailable_exception(e):
return True
exc_mess = e.headers.get('X-exception')
except AttributeError:
- # Ignore socket.errors here.
+ # Ignore OSErrors here.
exc_mess = str(e)
if exc_mess and 'temporarily unavailable' in exc_mess.lower():
@@ -515,7 +525,7 @@ def make_request_and_skipIf(condition, reason):
def make_request_and_skip(self):
try:
xmlrpclib.ServerProxy(URL).my_function()
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
if not is_unavailable_exception(e):
raise
raise unittest.SkipTest(reason)
@@ -553,7 +563,7 @@ class SimpleServerTestCase(BaseServerTestCase):
try:
p = xmlrpclib.ServerProxy(URL)
self.assertEqual(p.pow(6,8), 6**8)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -566,7 +576,7 @@ class SimpleServerTestCase(BaseServerTestCase):
p = xmlrpclib.ServerProxy(URL)
self.assertEqual(p.add(start_string, end_string),
start_string + end_string)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -587,12 +597,13 @@ class SimpleServerTestCase(BaseServerTestCase):
def test_introspection1(self):
expected_methods = set(['pow', 'div', 'my_function', 'add',
'system.listMethods', 'system.methodHelp',
- 'system.methodSignature', 'system.multicall'])
+ 'system.methodSignature', 'system.multicall',
+ 'Fixture'])
try:
p = xmlrpclib.ServerProxy(URL)
meth = p.system.listMethods()
self.assertEqual(set(meth), expected_methods)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -605,7 +616,7 @@ class SimpleServerTestCase(BaseServerTestCase):
p = xmlrpclib.ServerProxy(URL)
divhelp = p.system.methodHelp('div')
self.assertEqual(divhelp, 'This is the div function')
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -619,7 +630,7 @@ class SimpleServerTestCase(BaseServerTestCase):
p = xmlrpclib.ServerProxy(URL)
myfunction = p.system.methodHelp('my_function')
self.assertEqual(myfunction, 'This is my function')
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -632,7 +643,7 @@ class SimpleServerTestCase(BaseServerTestCase):
p = xmlrpclib.ServerProxy(URL)
divsig = p.system.methodSignature('div')
self.assertEqual(divsig, 'signatures not supported')
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -649,7 +660,7 @@ class SimpleServerTestCase(BaseServerTestCase):
self.assertEqual(add_result, 2+3)
self.assertEqual(pow_result, 6**8)
self.assertEqual(div_result, 127//42)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -670,7 +681,7 @@ class SimpleServerTestCase(BaseServerTestCase):
self.assertEqual(result.results[0]['faultString'],
'<class \'Exception\'>:method "this_is_not_exists" '
'is not supported')
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -686,6 +697,12 @@ class SimpleServerTestCase(BaseServerTestCase):
# This avoids waiting for the socket timeout.
self.test_simple1()
+ def test_allow_dotted_names_true(self):
+ # XXX also need allow_dotted_names_false test.
+ server = xmlrpclib.ServerProxy("http://%s:%d/RPC2" % (ADDR, PORT))
+ data = server.Fixture.getData()
+ self.assertEqual(data, '42')
+
def test_unicode_host(self):
server = xmlrpclib.ServerProxy("http://%s:%d/RPC2" % (ADDR, PORT))
self.assertEqual(server.add("a", "\xe9"), "a\xe9")
@@ -793,6 +810,7 @@ class KeepaliveServerTestCase2(BaseKeepaliveServerTestCase):
#A test case that verifies that gzip encoding works in both directions
#(for a request and the response)
+@unittest.skipIf(gzip is None, 'requires gzip')
class GzipServerTestCase(BaseServerTestCase):
#a request handler that supports keep-alive and logs requests into a
#class variable
@@ -923,7 +941,7 @@ class FailingServerTestCase(unittest.TestCase):
try:
p = xmlrpclib.ServerProxy(URL)
self.assertEqual(p.pow(6,8), 6**8)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e):
# protocol error; provide additional information in test output
@@ -936,7 +954,7 @@ class FailingServerTestCase(unittest.TestCase):
try:
p = xmlrpclib.ServerProxy(URL)
p.pow(6,8)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e) and hasattr(e, "headers"):
# The two server-side error headers shouldn't be sent back in this case
@@ -956,7 +974,7 @@ class FailingServerTestCase(unittest.TestCase):
try:
p = xmlrpclib.ServerProxy(URL)
p.pow(6,8)
- except (xmlrpclib.ProtocolError, socket.error) as e:
+ except (xmlrpclib.ProtocolError, OSError) as e:
# ignore failures due to non-blocking socket 'unavailable' errors
if not is_unavailable_exception(e) and hasattr(e, "headers"):
# We should get error info in the response
@@ -1084,23 +1102,13 @@ class UseBuiltinTypesTestCase(unittest.TestCase):
@support.reap_threads
def test_main():
- xmlrpc_tests = [XMLRPCTestCase, HelperTestCase, DateTimeTestCase,
- BinaryTestCase, FaultTestCase]
- xmlrpc_tests.append(UseBuiltinTypesTestCase)
- xmlrpc_tests.append(SimpleServerTestCase)
- xmlrpc_tests.append(KeepaliveServerTestCase1)
- xmlrpc_tests.append(KeepaliveServerTestCase2)
- try:
- import gzip
- xmlrpc_tests.append(GzipServerTestCase)
- except ImportError:
- pass #gzip not supported in this build
- xmlrpc_tests.append(MultiPathServerTestCase)
- xmlrpc_tests.append(ServerProxyTestCase)
- xmlrpc_tests.append(FailingServerTestCase)
- xmlrpc_tests.append(CGIHandlerTestCase)
-
- support.run_unittest(*xmlrpc_tests)
+ support.run_unittest(XMLRPCTestCase, HelperTestCase, DateTimeTestCase,
+ BinaryTestCase, FaultTestCase, UseBuiltinTypesTestCase,
+ SimpleServerTestCase, KeepaliveServerTestCase1,
+ KeepaliveServerTestCase2, GzipServerTestCase,
+ MultiPathServerTestCase, ServerProxyTestCase, FailingServerTestCase,
+ CGIHandlerTestCase)
+
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_xmlrpc_net.py b/Lib/test/test_xmlrpc_net.py
index dfb5f9aa3d..00aca199df 100644
--- a/Lib/test/test_xmlrpc_net.py
+++ b/Lib/test/test_xmlrpc_net.py
@@ -9,32 +9,7 @@ from test import support
import xmlrpc.client as xmlrpclib
-class CurrentTimeTest(unittest.TestCase):
-
- def test_current_time(self):
- # Get the current time from xmlrpc.com. This code exercises
- # the minimal HTTP functionality in xmlrpclib.
- self.skipTest("time.xmlrpc.com is unreliable")
- server = xmlrpclib.ServerProxy("http://time.xmlrpc.com/RPC2")
- try:
- t0 = server.currentTime.getCurrentTime()
- except socket.error as e:
- self.skipTest("network error: %s" % e)
- return
-
- # Perform a minimal sanity check on the result, just to be sure
- # the request means what we think it means.
- t1 = xmlrpclib.DateTime()
-
- dt0 = xmlrpclib._datetime_type(t0.value)
- dt1 = xmlrpclib._datetime_type(t1.value)
- if dt0 > dt1:
- delta = dt0 - dt1
- else:
- delta = dt1 - dt0
- # The difference between the system time here and the system
- # time on the server should not be too big.
- self.assertTrue(delta.days <= 1)
+class PythonBuildersTest(unittest.TestCase):
def test_python_builders(self):
# Get the list of builders from the XMLRPC buildbot interface at
@@ -42,7 +17,7 @@ class CurrentTimeTest(unittest.TestCase):
server = xmlrpclib.ServerProxy("http://buildbot.python.org/all/xmlrpc/")
try:
builders = server.getAllBuilders()
- except socket.error as e:
+ except OSError as e:
self.skipTest("network error: %s" % e)
return
self.addCleanup(lambda: server('close')())
@@ -55,7 +30,7 @@ class CurrentTimeTest(unittest.TestCase):
def test_main():
support.requires("network")
- support.run_unittest(CurrentTimeTest)
+ support.run_unittest(PythonBuildersTest)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py
index ad0c0b7b41..7249b13d4c 100644
--- a/Lib/test/test_zipfile.py
+++ b/Lib/test/test_zipfile.py
@@ -1,7 +1,7 @@
import io
import os
import sys
-import imp
+import importlib.util
import time
import shutil
import struct
@@ -557,7 +557,7 @@ class PyZipFileTests(unittest.TestCase):
if os.altsep is not None:
path_split.extend(fn.split(os.altsep))
if '__pycache__' in path_split:
- fn = imp.source_from_cache(fn)
+ fn = importlib.util.source_from_cache(fn)
else:
fn = fn[:-1]
@@ -591,6 +591,34 @@ class PyZipFileTests(unittest.TestCase):
self.assertCompiledIn('email/__init__.py', names)
self.assertCompiledIn('email/mime/text.py', names)
+ def test_write_filtered_python_package(self):
+ import test
+ packagedir = os.path.dirname(test.__file__)
+
+ with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp:
+
+ # first make sure that the test folder gives error messages
+ # (on the badsyntax_... files)
+ with captured_stdout() as reportSIO:
+ zipfp.writepy(packagedir)
+ reportStr = reportSIO.getvalue()
+ self.assertTrue('SyntaxError' in reportStr)
+
+ # then check that the filter works on the whole package
+ with captured_stdout() as reportSIO:
+ zipfp.writepy(packagedir, filterfunc=lambda whatever: False)
+ reportStr = reportSIO.getvalue()
+ self.assertTrue('SyntaxError' not in reportStr)
+
+ # then check that the filter works on individual files
+ with captured_stdout() as reportSIO:
+ zipfp.writepy(packagedir, filterfunc=lambda fn:
+ 'bad' not in fn)
+ reportStr = reportSIO.getvalue()
+ if reportStr:
+ print(reportStr)
+ self.assertTrue('SyntaxError' not in reportStr)
+
def test_write_with_optimization(self):
import email
packagedir = os.path.dirname(email.__file__)
@@ -600,7 +628,7 @@ class PyZipFileTests(unittest.TestCase):
ext = '.pyo' if optlevel == 1 else '.pyc'
with TemporaryFile() as t, \
- zipfile.PyZipFile(t, "w", optimize=optlevel) as zipfp:
+ zipfile.PyZipFile(t, "w", optimize=optlevel) as zipfp:
zipfp.writepy(packagedir)
names = zipfp.namelist()
@@ -630,6 +658,26 @@ class PyZipFileTests(unittest.TestCase):
finally:
shutil.rmtree(TESTFN2)
+ def test_write_python_directory_filtered(self):
+ os.mkdir(TESTFN2)
+ try:
+ with open(os.path.join(TESTFN2, "mod1.py"), "w") as fp:
+ fp.write("print(42)\n")
+
+ with open(os.path.join(TESTFN2, "mod2.py"), "w") as fp:
+ fp.write("print(42 * 42)\n")
+
+ with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp:
+ zipfp.writepy(TESTFN2, filterfunc=lambda fn:
+ not fn.endswith('mod2.py'))
+
+ names = zipfp.namelist()
+ self.assertCompiledIn('mod1.py', names)
+ self.assertNotIn('mod2.py', names)
+
+ finally:
+ shutil.rmtree(TESTFN2)
+
def test_write_non_pyfile(self):
with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp:
with open(TESTFN, 'w') as f:
@@ -733,25 +781,25 @@ class ExtractTests(unittest.TestCase):
def test_extract_hackers_arcnames_windows_only(self):
"""Test combination of path fixing and windows name sanitization."""
windows_hacknames = [
- (r'..\foo\bar', 'foo/bar'),
- (r'..\/foo\/bar', 'foo/bar'),
- (r'foo/\..\/bar', 'foo/bar'),
- (r'foo\/../\bar', 'foo/bar'),
- (r'C:foo/bar', 'foo/bar'),
- (r'C:/foo/bar', 'foo/bar'),
- (r'C://foo/bar', 'foo/bar'),
- (r'C:\foo\bar', 'foo/bar'),
- (r'//conky/mountpoint/foo/bar', 'foo/bar'),
- (r'\\conky\mountpoint\foo\bar', 'foo/bar'),
- (r'///conky/mountpoint/foo/bar', 'conky/mountpoint/foo/bar'),
- (r'\\\conky\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'),
- (r'//conky//mountpoint/foo/bar', 'conky/mountpoint/foo/bar'),
- (r'\\conky\\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'),
- (r'//?/C:/foo/bar', 'foo/bar'),
- (r'\\?\C:\foo\bar', 'foo/bar'),
- (r'C:/../C:/foo/bar', 'C_/foo/bar'),
- (r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'),
- ('../../foo../../ba..r', 'foo/ba..r'),
+ (r'..\foo\bar', 'foo/bar'),
+ (r'..\/foo\/bar', 'foo/bar'),
+ (r'foo/\..\/bar', 'foo/bar'),
+ (r'foo\/../\bar', 'foo/bar'),
+ (r'C:foo/bar', 'foo/bar'),
+ (r'C:/foo/bar', 'foo/bar'),
+ (r'C://foo/bar', 'foo/bar'),
+ (r'C:\foo\bar', 'foo/bar'),
+ (r'//conky/mountpoint/foo/bar', 'foo/bar'),
+ (r'\\conky\mountpoint\foo\bar', 'foo/bar'),
+ (r'///conky/mountpoint/foo/bar', 'conky/mountpoint/foo/bar'),
+ (r'\\\conky\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'),
+ (r'//conky//mountpoint/foo/bar', 'conky/mountpoint/foo/bar'),
+ (r'\\conky\\mountpoint\foo\bar', 'conky/mountpoint/foo/bar'),
+ (r'//?/C:/foo/bar', 'foo/bar'),
+ (r'\\?\C:\foo\bar', 'foo/bar'),
+ (r'C:/../C:/foo/bar', 'C_/foo/bar'),
+ (r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'),
+ ('../../foo../../ba..r', 'foo/ba..r'),
]
self._test_extract_hackers_arcnames(windows_hacknames)
@@ -877,10 +925,10 @@ class OtherTests(unittest.TestCase):
def test_unsupported_version(self):
# File has an extract_version of 120
data = (b'PK\x03\x04x\x00\x00\x00\x00\x00!p\xa1@\x00\x00\x00\x00\x00\x00'
- b'\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00xPK\x01\x02x\x03x\x00\x00\x00\x00'
- b'\x00!p\xa1@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00'
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00xPK\x05\x06'
- b'\x00\x00\x00\x00\x01\x00\x01\x00/\x00\x00\x00\x1f\x00\x00\x00\x00\x00')
+ b'\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00xPK\x01\x02x\x03x\x00\x00\x00\x00'
+ b'\x00!p\xa1@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00'
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00\x00xPK\x05\x06'
+ b'\x00\x00\x00\x00\x01\x00\x01\x00/\x00\x00\x00\x1f\x00\x00\x00\x00\x00')
self.assertRaises(NotImplementedError, zipfile.ZipFile,
io.BytesIO(data), 'r')
@@ -913,7 +961,7 @@ class OtherTests(unittest.TestCase):
try:
with zipfile.ZipFile(TESTFN, 'a') as zf:
zf.writestr(filename, content)
- except IOError:
+ except OSError:
self.fail('Could not append data to a non-existent zip file.')
self.assertTrue(os.path.exists(TESTFN))
@@ -985,7 +1033,7 @@ class OtherTests(unittest.TestCase):
fp.seek(0, 0)
self.assertTrue(zipfile.is_zipfile(fp))
- def test_non_existent_file_raises_IOError(self):
+ def test_non_existent_file_raises_OSError(self):
# make sure we don't raise an AttributeError when a partially-constructed
# ZipFile instance is finalized; this tests for regression on SF tracker
# bug #403871.
@@ -997,7 +1045,7 @@ class OtherTests(unittest.TestCase):
# it is ignored, but the user should be sufficiently annoyed by
# the message on the output that regression will be noticed
# quickly.
- self.assertRaises(IOError, zipfile.ZipFile, TESTFN)
+ self.assertRaises(OSError, zipfile.ZipFile, TESTFN)
def test_empty_file_raises_BadZipFile(self):
f = open(TESTFN, 'w')
@@ -1066,11 +1114,11 @@ class OtherTests(unittest.TestCase):
def test_unsupported_compression(self):
# data is declared as shrunk, but actually deflated
data = (b'PK\x03\x04.\x00\x00\x00\x01\x00\xe4C\xa1@\x00\x00\x00'
- b'\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00x\x03\x00PK\x01'
- b'\x02.\x03.\x00\x00\x00\x01\x00\xe4C\xa1@\x00\x00\x00\x00\x02\x00\x00'
- b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
- b'\x80\x01\x00\x00\x00\x00xPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x00'
- b'/\x00\x00\x00!\x00\x00\x00\x00\x00')
+ b'\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00x\x03\x00PK\x01'
+ b'\x02.\x03.\x00\x00\x00\x01\x00\xe4C\xa1@\x00\x00\x00\x00\x02\x00\x00'
+ b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
+ b'\x80\x01\x00\x00\x00\x00xPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x00'
+ b'/\x00\x00\x00!\x00\x00\x00\x00\x00')
with zipfile.ZipFile(io.BytesIO(data), 'r') as zipf:
self.assertRaises(NotImplementedError, zipf.open, 'x')
@@ -1184,7 +1232,7 @@ class OtherTests(unittest.TestCase):
def test_open_empty_file(self):
# Issue 1710703: Check that opening a file with less than 22 bytes
# raises a BadZipFile exception (rather than the previously unhelpful
- # IOError)
+ # OSError)
f = open(TESTFN, 'w')
f.close()
self.assertRaises(zipfile.BadZipFile, zipfile.ZipFile, TESTFN, 'r')
@@ -1232,57 +1280,57 @@ class AbstractBadCrcTests:
class StoredBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
compression = zipfile.ZIP_STORED
zip_with_bad_crc = (
- b'PK\003\004\024\0\0\0\0\0 \213\212;:r'
- b'\253\377\f\0\0\0\f\0\0\0\005\0\0\000af'
- b'ilehello,AworldP'
- b'K\001\002\024\003\024\0\0\0\0\0 \213\212;:'
- b'r\253\377\f\0\0\0\f\0\0\0\005\0\0\0\0'
- b'\0\0\0\0\0\0\0\200\001\0\0\0\000afi'
- b'lePK\005\006\0\0\0\0\001\0\001\0003\000'
- b'\0\0/\0\0\0\0\0')
+ b'PK\003\004\024\0\0\0\0\0 \213\212;:r'
+ b'\253\377\f\0\0\0\f\0\0\0\005\0\0\000af'
+ b'ilehello,AworldP'
+ b'K\001\002\024\003\024\0\0\0\0\0 \213\212;:'
+ b'r\253\377\f\0\0\0\f\0\0\0\005\0\0\0\0'
+ b'\0\0\0\0\0\0\0\200\001\0\0\0\000afi'
+ b'lePK\005\006\0\0\0\0\001\0\001\0003\000'
+ b'\0\0/\0\0\0\0\0')
@requires_zlib
class DeflateBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
compression = zipfile.ZIP_DEFLATED
zip_with_bad_crc = (
- b'PK\x03\x04\x14\x00\x00\x00\x08\x00n}\x0c=FA'
- b'KE\x10\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
- b'ile\xcbH\xcd\xc9\xc9W(\xcf/\xcaI\xc9\xa0'
- b'=\x13\x00PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00n'
- b'}\x0c=FAKE\x10\x00\x00\x00n\x00\x00\x00\x05'
- b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00'
- b'\x00afilePK\x05\x06\x00\x00\x00\x00\x01\x00'
- b'\x01\x003\x00\x00\x003\x00\x00\x00\x00\x00')
+ b'PK\x03\x04\x14\x00\x00\x00\x08\x00n}\x0c=FA'
+ b'KE\x10\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
+ b'ile\xcbH\xcd\xc9\xc9W(\xcf/\xcaI\xc9\xa0'
+ b'=\x13\x00PK\x01\x02\x14\x03\x14\x00\x00\x00\x08\x00n'
+ b'}\x0c=FAKE\x10\x00\x00\x00n\x00\x00\x00\x05'
+ b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\x01\x00\x00\x00'
+ b'\x00afilePK\x05\x06\x00\x00\x00\x00\x01\x00'
+ b'\x01\x003\x00\x00\x003\x00\x00\x00\x00\x00')
@requires_bz2
class Bzip2BadCrcTests(AbstractBadCrcTests, unittest.TestCase):
compression = zipfile.ZIP_BZIP2
zip_with_bad_crc = (
- b'PK\x03\x04\x14\x03\x00\x00\x0c\x00nu\x0c=FA'
- b'KE8\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
- b'ileBZh91AY&SY\xd4\xa8\xca'
- b'\x7f\x00\x00\x0f\x11\x80@\x00\x06D\x90\x80 \x00 \xa5'
- b'P\xd9!\x03\x03\x13\x13\x13\x89\xa9\xa9\xc2u5:\x9f'
- b'\x8b\xb9"\x9c(HjTe?\x80PK\x01\x02\x14'
- b'\x03\x14\x03\x00\x00\x0c\x00nu\x0c=FAKE8'
- b'\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00'
- b'\x00 \x80\x80\x81\x00\x00\x00\x00afilePK'
- b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00[\x00'
- b'\x00\x00\x00\x00')
+ b'PK\x03\x04\x14\x03\x00\x00\x0c\x00nu\x0c=FA'
+ b'KE8\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
+ b'ileBZh91AY&SY\xd4\xa8\xca'
+ b'\x7f\x00\x00\x0f\x11\x80@\x00\x06D\x90\x80 \x00 \xa5'
+ b'P\xd9!\x03\x03\x13\x13\x13\x89\xa9\xa9\xc2u5:\x9f'
+ b'\x8b\xb9"\x9c(HjTe?\x80PK\x01\x02\x14'
+ b'\x03\x14\x03\x00\x00\x0c\x00nu\x0c=FAKE8'
+ b'\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00'
+ b'\x00 \x80\x80\x81\x00\x00\x00\x00afilePK'
+ b'\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00\x00[\x00'
+ b'\x00\x00\x00\x00')
@requires_lzma
class LzmaBadCrcTests(AbstractBadCrcTests, unittest.TestCase):
compression = zipfile.ZIP_LZMA
zip_with_bad_crc = (
- b'PK\x03\x04\x14\x03\x00\x00\x0e\x00nu\x0c=FA'
- b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
- b'ile\t\x04\x05\x00]\x00\x00\x00\x04\x004\x19I'
- b'\xee\x8d\xe9\x17\x89:3`\tq!.8\x00PK'
- b'\x01\x02\x14\x03\x14\x03\x00\x00\x0e\x00nu\x0c=FA'
- b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00'
- b'\x00\x00\x00\x00 \x80\x80\x81\x00\x00\x00\x00afil'
- b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
- b'\x00>\x00\x00\x00\x00\x00')
+ b'PK\x03\x04\x14\x03\x00\x00\x0e\x00nu\x0c=FA'
+ b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00af'
+ b'ile\t\x04\x05\x00]\x00\x00\x00\x04\x004\x19I'
+ b'\xee\x8d\xe9\x17\x89:3`\tq!.8\x00PK'
+ b'\x01\x02\x14\x03\x14\x03\x00\x00\x0e\x00nu\x0c=FA'
+ b'KE\x1b\x00\x00\x00n\x00\x00\x00\x05\x00\x00\x00\x00\x00'
+ b'\x00\x00\x00\x00 \x80\x80\x81\x00\x00\x00\x00afil'
+ b'ePK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x003\x00\x00'
+ b'\x00>\x00\x00\x00\x00\x00')
class DecryptionTests(unittest.TestCase):
@@ -1291,22 +1339,22 @@ class DecryptionTests(unittest.TestCase):
ZIP file."""
data = (
- b'PK\x03\x04\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00\x1a\x00'
- b'\x00\x00\x08\x00\x00\x00test.txt\xfa\x10\xa0gly|\xfa-\xc5\xc0=\xf9y'
- b'\x18\xe0\xa8r\xb3Z}Lg\xbc\xae\xf9|\x9b\x19\xe4\x8b\xba\xbb)\x8c\xb0\xdbl'
- b'PK\x01\x02\x14\x00\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00'
- b'\x1a\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x01\x00 \x00\xb6\x81'
- b'\x00\x00\x00\x00test.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x006\x00'
- b'\x00\x00L\x00\x00\x00\x00\x00' )
+ b'PK\x03\x04\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00\x1a\x00'
+ b'\x00\x00\x08\x00\x00\x00test.txt\xfa\x10\xa0gly|\xfa-\xc5\xc0=\xf9y'
+ b'\x18\xe0\xa8r\xb3Z}Lg\xbc\xae\xf9|\x9b\x19\xe4\x8b\xba\xbb)\x8c\xb0\xdbl'
+ b'PK\x01\x02\x14\x00\x14\x00\x01\x00\x00\x00n\x92i.#y\xef?&\x00\x00\x00'
+ b'\x1a\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x01\x00 \x00\xb6\x81'
+ b'\x00\x00\x00\x00test.txtPK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x006\x00'
+ b'\x00\x00L\x00\x00\x00\x00\x00' )
data2 = (
- b'PK\x03\x04\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02'
- b'\x00\x00\x04\x00\x15\x00zeroUT\t\x00\x03\xd6\x8b\x92G\xda\x8b\x92GUx\x04'
- b'\x00\xe8\x03\xe8\x03\xc7<M\xb5a\xceX\xa3Y&\x8b{oE\xd7\x9d\x8c\x98\x02\xc0'
- b'PK\x07\x08xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00PK\x01\x02\x17\x03'
- b'\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00'
- b'\x04\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x81\x00\x00\x00\x00ze'
- b'roUT\x05\x00\x03\xd6\x8b\x92GUx\x00\x00PK\x05\x06\x00\x00\x00\x00\x01'
- b'\x00\x01\x00?\x00\x00\x00[\x00\x00\x00\x00\x00' )
+ b'PK\x03\x04\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02'
+ b'\x00\x00\x04\x00\x15\x00zeroUT\t\x00\x03\xd6\x8b\x92G\xda\x8b\x92GUx\x04'
+ b'\x00\xe8\x03\xe8\x03\xc7<M\xb5a\xceX\xa3Y&\x8b{oE\xd7\x9d\x8c\x98\x02\xc0'
+ b'PK\x07\x08xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00PK\x01\x02\x17\x03'
+ b'\x14\x00\t\x00\x08\x00\xcf}38xu\xaa\xb2\x14\x00\x00\x00\x00\x02\x00\x00'
+ b'\x04\x00\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x81\x00\x00\x00\x00ze'
+ b'roUT\x05\x00\x03\xd6\x8b\x92GUx\x00\x00PK\x05\x06\x00\x00\x00\x00\x01'
+ b'\x00\x01\x00?\x00\x00\x00[\x00\x00\x00\x00\x00' )
plain = b'zipfile.py encryption test'
plain2 = b'\x00'*512
@@ -1668,6 +1716,5 @@ class LzmaUniversalNewlineTests(AbstractUniversalNewlineTests,
unittest.TestCase):
compression = zipfile.ZIP_LZMA
-
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py
index f7cb8b977f..3f16041f51 100644
--- a/Lib/test/test_zipimport.py
+++ b/Lib/test/test_zipimport.py
@@ -1,13 +1,12 @@
import sys
import os
import marshal
-import imp
+import importlib.util
import struct
import time
import unittest
from test import support
-from test.test_importhooks import ImportHooksBaseTestCase, test_src, test_co
from zipfile import ZipFile, ZipInfo, ZIP_STORED, ZIP_DEFLATED
@@ -17,6 +16,14 @@ import doctest
import inspect
import io
from traceback import extract_tb, extract_stack, print_tb
+
+test_src = """\
+def get_name():
+ return __name__
+def get_file():
+ return __file__
+"""
+test_co = compile(test_src, "<???>", "exec")
raise_src = 'def do_raise(): raise TypeError\n'
def make_pyc(co, mtime, size):
@@ -27,7 +34,8 @@ def make_pyc(co, mtime, size):
mtime = int(mtime)
else:
mtime = int(-0x100000000 + int(mtime))
- pyc = imp.get_magic() + struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data
+ pyc = (importlib.util.MAGIC_NUMBER +
+ struct.pack("<ii", int(mtime), size & 0xFFFFFFFF) + data)
return pyc
def module_path_to_dotted_name(path):
@@ -42,10 +50,27 @@ TESTPACK = "ziptestpackage"
TESTPACK2 = "ziptestpackage2"
TEMP_ZIP = os.path.abspath("junk95142.zip")
-pyc_file = imp.cache_from_source(TESTMOD + '.py')
+pyc_file = importlib.util.cache_from_source(TESTMOD + '.py')
pyc_ext = ('.pyc' if __debug__ else '.pyo')
+class ImportHooksBaseTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self.path = sys.path[:]
+ self.meta_path = sys.meta_path[:]
+ self.path_hooks = sys.path_hooks[:]
+ sys.path_importer_cache.clear()
+ self.modules_before = support.modules_setup()
+
+ def tearDown(self):
+ sys.path[:] = self.path
+ sys.meta_path[:] = self.meta_path
+ sys.path_hooks[:] = self.path_hooks
+ sys.path_importer_cache.clear()
+ support.modules_cleanup(*self.modules_before)
+
+
class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
compression = ZIP_STORED
@@ -196,6 +221,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
for name, (mtime, data) in files.items():
zinfo = ZipInfo(name, time.localtime(mtime))
zinfo.compress_type = self.compression
+ zinfo.comment = b"spam"
z.writestr(zinfo, data)
z.close()
@@ -245,6 +271,7 @@ class UncompressedZipImportTestCase(ImportHooksBaseTestCase):
for name, (mtime, data) in files.items():
zinfo = ZipInfo(name, time.localtime(mtime))
zinfo.compress_type = self.compression
+ zinfo.comment = b"eggs"
z.writestr(zinfo, data)
z.close()
@@ -459,7 +486,7 @@ class BadFileZipImportTestCase(unittest.TestCase):
self.assertRaises(error, z.load_module, 'abc')
self.assertRaises(error, z.get_code, 'abc')
- self.assertRaises(IOError, z.get_data, 'abc')
+ self.assertRaises(OSError, z.get_data, 'abc')
self.assertRaises(error, z.get_source, 'abc')
self.assertRaises(error, z.is_package, 'abc')
finally:
diff --git a/Lib/test/test_zipimport_support.py b/Lib/test/test_zipimport_support.py
index f7f3398015..84ba5e047a 100644
--- a/Lib/test/test_zipimport_support.py
+++ b/Lib/test/test_zipimport_support.py
@@ -227,13 +227,15 @@ class ZipSupportTests(unittest.TestCase):
p = spawn_python(script_name)
p.stdin.write(b'l\n')
data = kill_python(p)
- self.assertIn(script_name.encode('utf-8'), data)
+ # bdb/pdb applies normcase to its filename before displaying
+ self.assertIn(os.path.normcase(script_name.encode('utf-8')), data)
zip_name, run_name = make_zip_script(d, "test_zip",
script_name, '__main__.py')
p = spawn_python(zip_name)
p.stdin.write(b'l\n')
data = kill_python(p)
- self.assertIn(run_name.encode('utf-8'), data)
+ # bdb/pdb applies normcase to its filename before displaying
+ self.assertIn(os.path.normcase(run_name.encode('utf-8')), data)
def test_main():
diff --git a/Lib/test/tf_inherit_check.py b/Lib/test/tf_inherit_check.py
index 92ebd95e52..afe50d2325 100644
--- a/Lib/test/tf_inherit_check.py
+++ b/Lib/test/tf_inherit_check.py
@@ -11,7 +11,7 @@ try:
try:
os.write(fd, b"blat")
- except os.error:
+ except OSError:
# Success -- could not write to fd.
sys.exit(0)
else: