summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNed Batchelder <ned@nedbatchelder.com>2016-01-02 10:18:04 -0500
committerNed Batchelder <ned@nedbatchelder.com>2016-01-02 10:18:04 -0500
commitbaf18bed45cbd943f379f9ca4e7747fb607552c8 (patch)
tree8dd7986ce95861ebde3a2a95837a7ad20c01a96e
parent82dae969e9318e35bccfc08c0e652cbb931403c6 (diff)
downloadpython-coveragepy-baf18bed45cbd943f379f9ca4e7747fb607552c8.tar.gz
Handle yield-from and await. All tests pass
-rw-r--r--coverage/parser.py88
-rw-r--r--coverage/test_helpers.py12
-rw-r--r--tests/test_arcs.py20
3 files changed, 77 insertions, 43 deletions
diff --git a/coverage/parser.py b/coverage/parser.py
index 2396fb8..0462802 100644
--- a/coverage/parser.py
+++ b/coverage/parser.py
@@ -327,11 +327,17 @@ class AstArcAnalyzer(object):
def __init__(self, text):
self.root_node = ast.parse(text)
if int(os.environ.get("COVERAGE_ASTDUMP", 0)):
+ # Dump the AST so that failing tests have helpful output.
ast_dump(self.root_node)
self.arcs = None
self.block_stack = []
+ def collect_arcs(self):
+ self.arcs = set()
+ self.add_arcs_for_code_objects(self.root_node)
+ return self.arcs
+
def blocks(self):
"""Yield the blocks in nearest-to-farthest order."""
return reversed(self.block_stack)
@@ -361,16 +367,19 @@ class AstArcAnalyzer(object):
def line_default(self, node):
return node.lineno
- def collect_arcs(self):
- self.arcs = set()
- self.add_arcs_for_code_objects(self.root_node)
- return self.arcs
-
def add_arcs(self, node):
- """add the arcs for `node`.
+ """Add the arcs for `node`.
Return a set of line numbers, exits from this node to the next.
"""
+ # Yield-froms and awaits can appear anywhere.
+ # TODO: this is probably over-doing it, and too expensive. Can we
+ # instrument the ast walking to see how many nodes we are revisiting?
+ if isinstance(node, ast.stmt):
+ for name, value in ast.iter_fields(node):
+ if isinstance(value, ast.expr) and self.contains_return_expression(value):
+ self.process_return_exits([self.line_for_node(node)])
+ break
node_name = node.__class__.__name__
handler = getattr(self, "handle_" + node_name, self.handle_default)
return handler(node)
@@ -404,6 +413,7 @@ class AstArcAnalyzer(object):
# TODO: multi-line listcomps
# TODO: nested function definitions
# TODO: multiple `except` clauses
+ # TODO: return->finally
def process_break_exits(self, exits):
for block in self.blocks():
@@ -443,6 +453,7 @@ class AstArcAnalyzer(object):
def process_return_exits(self, exits):
for block in self.blocks():
+ # TODO: need a check here for TryBlock
if isinstance(block, FunctionBlock):
# TODO: what if there is no enclosing function?
for exit in exits:
@@ -587,6 +598,7 @@ class AstArcAnalyzer(object):
def handle_default(self, node):
node_name = node.__class__.__name__
if node_name not in ["Assign", "Assert", "AugAssign", "Expr", "Import", "Pass", "Print"]:
+ # TODO: put 1/0 here to find unhandled nodes.
print("*** Unhandled: {0}".format(node))
return set([self.line_for_node(node)])
@@ -628,6 +640,14 @@ class AstArcAnalyzer(object):
self.arcs.add((start, -start))
# TODO: test multi-line lambdas
+ def contains_return_expression(self, node):
+ """Is there a yield-from or await in `node` someplace?"""
+ for child in ast.walk(node):
+ if child.__class__.__name__ in ["YieldFrom", "Await"]:
+ return True
+
+ return False
+
## Opcodes that guide the ByteParser.
@@ -1045,7 +1065,13 @@ class Chunk(object):
)
-SKIP_FIELDS = ["ctx"]
+SKIP_DUMP_FIELDS = ["ctx"]
+
+def is_simple_value(value):
+ return (
+ value in [None, [], (), {}, set()] or
+ isinstance(value, (string_class, int, float))
+ )
def ast_dump(node, depth=0):
indent = " " * depth
@@ -1055,30 +1081,36 @@ def ast_dump(node, depth=0):
lineno = getattr(node, "lineno", None)
if lineno is not None:
- linemark = " @ {0}".format(lineno)
+ linemark = " @ {0}".format(node.lineno)
else:
linemark = ""
- print("{0}<{1}{2}".format(indent, node.__class__.__name__, linemark))
-
- indent += " "
- for field_name, value in ast.iter_fields(node):
- if field_name in SKIP_FIELDS:
- continue
- prefix = "{0}{1}:".format(indent, field_name)
- if value is None:
- print("{0} None".format(prefix))
- elif isinstance(value, (string_class, int, float)):
- print("{0} {1!r}".format(prefix, value))
- elif isinstance(value, list):
- if value == []:
- print("{0} []".format(prefix))
- else:
+ head = "{0}<{1}{2}".format(indent, node.__class__.__name__, linemark)
+
+ named_fields = [
+ (name, value)
+ for name, value in ast.iter_fields(node)
+ if name not in SKIP_DUMP_FIELDS
+ ]
+ if not named_fields:
+ print("{0}>".format(head))
+ elif len(named_fields) == 1 and is_simple_value(named_fields[0][1]):
+ field_name, value = named_fields[0]
+ print("{0} {1}: {2!r}>".format(head, field_name, value))
+ else:
+ print(head)
+ print("{0}# mro: {1}".format(indent, ", ".join(c.__name__ for c in node.__class__.__mro__[1:])))
+ next_indent = indent + " "
+ for field_name, value in named_fields:
+ prefix = "{0}{1}:".format(next_indent, field_name)
+ if is_simple_value(value):
+ print("{0} {1!r}".format(prefix, value))
+ elif isinstance(value, list):
print("{0} [".format(prefix))
for n in value:
ast_dump(n, depth + 8)
- print("{0}]".format(indent))
- else:
- print(prefix)
- ast_dump(value, depth + 8)
+ print("{0}]".format(next_indent))
+ else:
+ print(prefix)
+ ast_dump(value, depth + 8)
- print("{0}>".format(" " * depth))
+ print("{0}>".format(indent))
diff --git a/coverage/test_helpers.py b/coverage/test_helpers.py
index 50cc329..092daa0 100644
--- a/coverage/test_helpers.py
+++ b/coverage/test_helpers.py
@@ -162,20 +162,20 @@ class StdStreamCapturingMixin(TestCase):
# nose keeps stdout from littering the screen, so we can safely Tee it,
# but it doesn't capture stderr, so we don't want to Tee stderr to the
# real stderr, since it will interfere with our nice field of dots.
- self.old_stdout = sys.stdout
+ old_stdout = sys.stdout
self.captured_stdout = StringIO()
sys.stdout = Tee(sys.stdout, self.captured_stdout)
- self.old_stderr = sys.stderr
+ old_stderr = sys.stderr
self.captured_stderr = StringIO()
sys.stderr = self.captured_stderr
- self.addCleanup(self.cleanup_std_streams)
+ self.addCleanup(self.cleanup_std_streams, old_stdout, old_stderr)
- def cleanup_std_streams(self):
+ def cleanup_std_streams(self, old_stdout, old_stderr):
"""Restore stdout and stderr."""
- sys.stdout = self.old_stdout
- sys.stderr = self.old_stderr
+ sys.stdout = old_stdout
+ sys.stderr = old_stderr
def stdout(self):
"""Return the data written to stdout during the test."""
diff --git a/tests/test_arcs.py b/tests/test_arcs.py
index 303b10e..6ba663b 100644
--- a/tests/test_arcs.py
+++ b/tests/test_arcs.py
@@ -867,23 +867,25 @@ class AsyncTest(CoverageTest):
self.check_coverage("""\
import asyncio
- async def compute(x, y):
+ async def compute(x, y): # 3
print("Compute %s + %s ..." % (x, y))
await asyncio.sleep(0.001)
- return x + y
+ return x + y # 6
- async def print_sum(x, y):
- result = await compute(x, y)
+ async def print_sum(x, y): # 8
+ result = (0 +
+ await compute(x, y) # A
+ )
print("%s + %s = %s" % (x, y, result))
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_event_loop() # E
loop.run_until_complete(print_sum(1, 2))
- loop.close()
+ loop.close() # G
""",
arcz=
- ".1 13 38 8C CD DE E. "
- ".4 45 56 6-3 "
- ".9 9A A-8",
+ ".1 13 38 8E EF FG G. "
+ ".4 45 56 5-3 6-3 "
+ ".9 9-8 9C C-8",
)
self.assertEqual(self.stdout(), "Compute 1 + 2 ...\n1 + 2 = 3\n")