diff options
author | Ned Batchelder <ned@nedbatchelder.com> | 2016-01-02 10:18:04 -0500 |
---|---|---|
committer | Ned Batchelder <ned@nedbatchelder.com> | 2016-01-02 10:18:04 -0500 |
commit | baf18bed45cbd943f379f9ca4e7747fb607552c8 (patch) | |
tree | 8dd7986ce95861ebde3a2a95837a7ad20c01a96e | |
parent | 82dae969e9318e35bccfc08c0e652cbb931403c6 (diff) | |
download | python-coveragepy-baf18bed45cbd943f379f9ca4e7747fb607552c8.tar.gz |
Handle yield-from and await. All tests pass
-rw-r--r-- | coverage/parser.py | 88 | ||||
-rw-r--r-- | coverage/test_helpers.py | 12 | ||||
-rw-r--r-- | tests/test_arcs.py | 20 |
3 files changed, 77 insertions, 43 deletions
diff --git a/coverage/parser.py b/coverage/parser.py index 2396fb8..0462802 100644 --- a/coverage/parser.py +++ b/coverage/parser.py @@ -327,11 +327,17 @@ class AstArcAnalyzer(object): def __init__(self, text): self.root_node = ast.parse(text) if int(os.environ.get("COVERAGE_ASTDUMP", 0)): + # Dump the AST so that failing tests have helpful output. ast_dump(self.root_node) self.arcs = None self.block_stack = [] + def collect_arcs(self): + self.arcs = set() + self.add_arcs_for_code_objects(self.root_node) + return self.arcs + def blocks(self): """Yield the blocks in nearest-to-farthest order.""" return reversed(self.block_stack) @@ -361,16 +367,19 @@ class AstArcAnalyzer(object): def line_default(self, node): return node.lineno - def collect_arcs(self): - self.arcs = set() - self.add_arcs_for_code_objects(self.root_node) - return self.arcs - def add_arcs(self, node): - """add the arcs for `node`. + """Add the arcs for `node`. Return a set of line numbers, exits from this node to the next. """ + # Yield-froms and awaits can appear anywhere. + # TODO: this is probably over-doing it, and too expensive. Can we + # instrument the ast walking to see how many nodes we are revisiting? + if isinstance(node, ast.stmt): + for name, value in ast.iter_fields(node): + if isinstance(value, ast.expr) and self.contains_return_expression(value): + self.process_return_exits([self.line_for_node(node)]) + break node_name = node.__class__.__name__ handler = getattr(self, "handle_" + node_name, self.handle_default) return handler(node) @@ -404,6 +413,7 @@ class AstArcAnalyzer(object): # TODO: multi-line listcomps # TODO: nested function definitions # TODO: multiple `except` clauses + # TODO: return->finally def process_break_exits(self, exits): for block in self.blocks(): @@ -443,6 +453,7 @@ class AstArcAnalyzer(object): def process_return_exits(self, exits): for block in self.blocks(): + # TODO: need a check here for TryBlock if isinstance(block, FunctionBlock): # TODO: what if there is no enclosing function? for exit in exits: @@ -587,6 +598,7 @@ class AstArcAnalyzer(object): def handle_default(self, node): node_name = node.__class__.__name__ if node_name not in ["Assign", "Assert", "AugAssign", "Expr", "Import", "Pass", "Print"]: + # TODO: put 1/0 here to find unhandled nodes. print("*** Unhandled: {0}".format(node)) return set([self.line_for_node(node)]) @@ -628,6 +640,14 @@ class AstArcAnalyzer(object): self.arcs.add((start, -start)) # TODO: test multi-line lambdas + def contains_return_expression(self, node): + """Is there a yield-from or await in `node` someplace?""" + for child in ast.walk(node): + if child.__class__.__name__ in ["YieldFrom", "Await"]: + return True + + return False + ## Opcodes that guide the ByteParser. @@ -1045,7 +1065,13 @@ class Chunk(object): ) -SKIP_FIELDS = ["ctx"] +SKIP_DUMP_FIELDS = ["ctx"] + +def is_simple_value(value): + return ( + value in [None, [], (), {}, set()] or + isinstance(value, (string_class, int, float)) + ) def ast_dump(node, depth=0): indent = " " * depth @@ -1055,30 +1081,36 @@ def ast_dump(node, depth=0): lineno = getattr(node, "lineno", None) if lineno is not None: - linemark = " @ {0}".format(lineno) + linemark = " @ {0}".format(node.lineno) else: linemark = "" - print("{0}<{1}{2}".format(indent, node.__class__.__name__, linemark)) - - indent += " " - for field_name, value in ast.iter_fields(node): - if field_name in SKIP_FIELDS: - continue - prefix = "{0}{1}:".format(indent, field_name) - if value is None: - print("{0} None".format(prefix)) - elif isinstance(value, (string_class, int, float)): - print("{0} {1!r}".format(prefix, value)) - elif isinstance(value, list): - if value == []: - print("{0} []".format(prefix)) - else: + head = "{0}<{1}{2}".format(indent, node.__class__.__name__, linemark) + + named_fields = [ + (name, value) + for name, value in ast.iter_fields(node) + if name not in SKIP_DUMP_FIELDS + ] + if not named_fields: + print("{0}>".format(head)) + elif len(named_fields) == 1 and is_simple_value(named_fields[0][1]): + field_name, value = named_fields[0] + print("{0} {1}: {2!r}>".format(head, field_name, value)) + else: + print(head) + print("{0}# mro: {1}".format(indent, ", ".join(c.__name__ for c in node.__class__.__mro__[1:]))) + next_indent = indent + " " + for field_name, value in named_fields: + prefix = "{0}{1}:".format(next_indent, field_name) + if is_simple_value(value): + print("{0} {1!r}".format(prefix, value)) + elif isinstance(value, list): print("{0} [".format(prefix)) for n in value: ast_dump(n, depth + 8) - print("{0}]".format(indent)) - else: - print(prefix) - ast_dump(value, depth + 8) + print("{0}]".format(next_indent)) + else: + print(prefix) + ast_dump(value, depth + 8) - print("{0}>".format(" " * depth)) + print("{0}>".format(indent)) diff --git a/coverage/test_helpers.py b/coverage/test_helpers.py index 50cc329..092daa0 100644 --- a/coverage/test_helpers.py +++ b/coverage/test_helpers.py @@ -162,20 +162,20 @@ class StdStreamCapturingMixin(TestCase): # nose keeps stdout from littering the screen, so we can safely Tee it, # but it doesn't capture stderr, so we don't want to Tee stderr to the # real stderr, since it will interfere with our nice field of dots. - self.old_stdout = sys.stdout + old_stdout = sys.stdout self.captured_stdout = StringIO() sys.stdout = Tee(sys.stdout, self.captured_stdout) - self.old_stderr = sys.stderr + old_stderr = sys.stderr self.captured_stderr = StringIO() sys.stderr = self.captured_stderr - self.addCleanup(self.cleanup_std_streams) + self.addCleanup(self.cleanup_std_streams, old_stdout, old_stderr) - def cleanup_std_streams(self): + def cleanup_std_streams(self, old_stdout, old_stderr): """Restore stdout and stderr.""" - sys.stdout = self.old_stdout - sys.stderr = self.old_stderr + sys.stdout = old_stdout + sys.stderr = old_stderr def stdout(self): """Return the data written to stdout during the test.""" diff --git a/tests/test_arcs.py b/tests/test_arcs.py index 303b10e..6ba663b 100644 --- a/tests/test_arcs.py +++ b/tests/test_arcs.py @@ -867,23 +867,25 @@ class AsyncTest(CoverageTest): self.check_coverage("""\ import asyncio - async def compute(x, y): + async def compute(x, y): # 3 print("Compute %s + %s ..." % (x, y)) await asyncio.sleep(0.001) - return x + y + return x + y # 6 - async def print_sum(x, y): - result = await compute(x, y) + async def print_sum(x, y): # 8 + result = (0 + + await compute(x, y) # A + ) print("%s + %s = %s" % (x, y, result)) - loop = asyncio.get_event_loop() + loop = asyncio.get_event_loop() # E loop.run_until_complete(print_sum(1, 2)) - loop.close() + loop.close() # G """, arcz= - ".1 13 38 8C CD DE E. " - ".4 45 56 6-3 " - ".9 9A A-8", + ".1 13 38 8E EF FG G. " + ".4 45 56 5-3 6-3 " + ".9 9-8 9C C-8", ) self.assertEqual(self.stdout(), "Compute 1 + 2 ...\n1 + 2 = 3\n") |