summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNed Batchelder <ned@nedbatchelder.com>2016-01-10 12:37:36 -0500
committerNed Batchelder <ned@nedbatchelder.com>2016-01-10 12:37:36 -0500
commit7865e5067d30c98f280ad06175b26c93508783c5 (patch)
treefeb82c87c38795f13f2da6746848d2b9ebb882e7
parentf6d343cf2259491ce9556758c5398e0db76d804c (diff)
downloadpython-coveragepy-git-7865e5067d30c98f280ad06175b26c93508783c5.tar.gz
Make multiprocessing support work with spawned processes, which is what Windows uses.
-rw-r--r--coverage/monkey.py30
-rw-r--r--igor.py2
-rw-r--r--tests/test_concurrency.py30
3 files changed, 51 insertions, 11 deletions
diff --git a/coverage/monkey.py b/coverage/monkey.py
index b896dbf5..3f78d7dc 100644
--- a/coverage/monkey.py
+++ b/coverage/monkey.py
@@ -12,7 +12,6 @@ import sys
PATCHED_MARKER = "_coverage$patched"
if sys.version_info >= (3, 4):
-
klass = multiprocessing.process.BaseProcess
else:
klass = multiprocessing.Process
@@ -49,4 +48,33 @@ def patch_multiprocessing():
else:
multiprocessing.Process = ProcessWithCoverage
+ # When spawning processes rather than forking them, we have no state in the
+ # new process. We sneak in there with a Stowaway: we stuff one of our own
+ # objects into the data that gets pickled and sent to the sub-process. When
+ # the Stowaway is unpickled, it's __setstate__ method is called, which
+ # re-applies the monkey-patch.
+ # Windows only spawns, so this is needed to keep Windows working.
+ try:
+ from multiprocessing import spawn
+ original_get_preparation_data = spawn.get_preparation_data
+ except (ImportError, AttributeError):
+ pass
+ else:
+ def get_preparation_data_with_stowaway(name):
+ """Get the original preparation data, and also insert our stowaway."""
+ d = original_get_preparation_data(name)
+ d['stowaway'] = Stowaway()
+ return d
+
+ spawn.get_preparation_data = get_preparation_data_with_stowaway
+
setattr(multiprocessing, PATCHED_MARKER, True)
+
+
+class Stowaway(object):
+ """An object to pickle, so when it is unpickled, it can apply the monkey-patch."""
+ def __getstate__(self):
+ return {}
+
+ def __setstate__(self, state):
+ patch_multiprocessing()
diff --git a/igor.py b/igor.py
index b857fcc3..54b8da16 100644
--- a/igor.py
+++ b/igor.py
@@ -332,7 +332,7 @@ def print_banner(label):
which_python = os.path.relpath(sys.executable)
except ValueError:
# On Windows having a python executable on a different drives
- # than the sources cannot be relative
+ # than the sources cannot be relative.
which_python = sys.executable
print('=== %s %s %s (%s) ===' % (impl, version, label, which_python))
sys.stdout.flush()
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
index 0f5ffe95..04eb9853 100644
--- a/tests/test_concurrency.py
+++ b/tests/test_concurrency.py
@@ -3,6 +3,7 @@
"""Tests for concurrency libraries."""
+import multiprocessing
import threading
import coverage
@@ -227,6 +228,7 @@ class MultiprocessingTest(CoverageTest):
import multiprocessing
import os
import time
+ import sys
def func(x):
# Need to pause, or the tasks go too quick, and some processes
@@ -240,6 +242,7 @@ class MultiprocessingTest(CoverageTest):
return os.getpid(), y
if __name__ == "__main__":
+ if len(sys.argv) > 1: multiprocessing.set_start_method(sys.argv[1])
pool = multiprocessing.Pool(3)
inputs = range(30)
outputs = pool.imap_unordered(func, inputs)
@@ -253,16 +256,25 @@ class MultiprocessingTest(CoverageTest):
pool.join()
""")
- out = self.run_command(
- "coverage run --concurrency=multiprocessing multi.py"
- )
- total = sum(x*x if x%2 else x*x*x for x in range(30))
- self.assertEqual(out.rstrip(), "3 pids, total = %d" % total)
+ if env.PYVERSION >= (3, 4):
+ start_methods = ['fork', 'spawn']
+ else:
+ start_methods = ['']
+
+ for start_method in start_methods:
+ if start_method and start_method not in multiprocessing.get_all_start_methods():
+ continue
+
+ out = self.run_command(
+ "coverage run --concurrency=multiprocessing multi.py %s" % start_method
+ )
+ total = sum(x*x if x%2 else x*x*x for x in range(30))
+ self.assertEqual(out.rstrip(), "3 pids, total = %d" % total)
- self.run_command("coverage combine")
- out = self.run_command("coverage report -m")
- last_line = self.squeezed_lines(out)[-1]
- self.assertEqual(last_line, "multi.py 21 0 100%")
+ self.run_command("coverage combine")
+ out = self.run_command("coverage report -m")
+ last_line = self.squeezed_lines(out)[-1]
+ self.assertEqual(last_line, "multi.py 23 0 100%")
def print_simple_annotation(code, linenos):