diff options
-rw-r--r-- | ChangeLog | 5 | ||||
-rw-r--r-- | astroid/protocols.py | 5 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 30 |
3 files changed, 39 insertions, 1 deletions
@@ -56,6 +56,11 @@ Change log for the astroid package (used to be astng) Close PyCQA/pylint#1843 + * Fix ``contextlib.contextmanager`` inference for nested + context managers + + Close #1699 + 2017-12-15 -- 1.6.0 diff --git a/astroid/protocols.py b/astroid/protocols.py index 30303dec..bd1d594f 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -443,7 +443,10 @@ def _infer_context_manager(self, mgr, context): # Get the first yield point. If it has multiple yields, # then a RuntimeError will be raised. # TODO(cpopa): Handle flows. - yield_point = next(func.nodes_of_class(nodes.Yield), None) + possible_yield_points = func.nodes_of_class(nodes.Yield) + # Ignore yields in nested functions + yield_point = next((node for node in possible_yield_points + if node.scope() == func), None) if yield_point: if not yield_point.value: # TODO(cpopa): an empty yield. Should be wrapped to Const. diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 98a1b284..6508b3fe 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -2218,6 +2218,36 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): self.assertRaises(InferenceError, next, module['other_decorators'].infer()) self.assertRaises(InferenceError, next, module['no_yield'].infer()) + def test_nested_contextmanager(self): + """Make sure contextmanager works with nested functions + + Previously contextmanager would retrieve + the first yield instead of the yield in the + proper scope + + Fixes https://github.com/PyCQA/pylint/issues/1746 + """ + code = """ + from contextlib import contextmanager + + @contextmanager + def outer(): + @contextmanager + def inner(): + yield 2 + yield inner + + with outer() as ctx: + ctx #@ + with ctx() as val: + val #@ + """ + context_node, value_node = extract_node(code) + value = next(value_node.infer()) + context = next(context_node.infer()) + assert isinstance(context, nodes.FunctionDef) + assert isinstance(value, nodes.Const) + def test_unary_op_leaks_stop_iteration(self): node = extract_node('+[] #@') self.assertEqual(util.Uninferable, next(node.infer())) |