summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog5
-rw-r--r--astroid/protocols.py5
-rw-r--r--astroid/tests/unittest_inference.py30
3 files changed, 39 insertions, 1 deletions
diff --git a/ChangeLog b/ChangeLog
index 33e09594..edaea218 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -56,6 +56,11 @@ Change log for the astroid package (used to be astng)
Close PyCQA/pylint#1843
+ * Fix ``contextlib.contextmanager`` inference for nested
+ context managers
+
+ Close #1699
+
2017-12-15 -- 1.6.0
diff --git a/astroid/protocols.py b/astroid/protocols.py
index 30303dec..bd1d594f 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -443,7 +443,10 @@ def _infer_context_manager(self, mgr, context):
# Get the first yield point. If it has multiple yields,
# then a RuntimeError will be raised.
# TODO(cpopa): Handle flows.
- yield_point = next(func.nodes_of_class(nodes.Yield), None)
+ possible_yield_points = func.nodes_of_class(nodes.Yield)
+ # Ignore yields in nested functions
+ yield_point = next((node for node in possible_yield_points
+ if node.scope() == func), None)
if yield_point:
if not yield_point.value:
# TODO(cpopa): an empty yield. Should be wrapped to Const.
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 98a1b284..6508b3fe 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -2218,6 +2218,36 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
self.assertRaises(InferenceError, next, module['other_decorators'].infer())
self.assertRaises(InferenceError, next, module['no_yield'].infer())
+ def test_nested_contextmanager(self):
+ """Make sure contextmanager works with nested functions
+
+ Previously contextmanager would retrieve
+ the first yield instead of the yield in the
+ proper scope
+
+ Fixes https://github.com/PyCQA/pylint/issues/1746
+ """
+ code = """
+ from contextlib import contextmanager
+
+ @contextmanager
+ def outer():
+ @contextmanager
+ def inner():
+ yield 2
+ yield inner
+
+ with outer() as ctx:
+ ctx #@
+ with ctx() as val:
+ val #@
+ """
+ context_node, value_node = extract_node(code)
+ value = next(value_node.infer())
+ context = next(context_node.infer())
+ assert isinstance(context, nodes.FunctionDef)
+ assert isinstance(value, nodes.Const)
+
def test_unary_op_leaks_stop_iteration(self):
node = extract_node('+[] #@')
self.assertEqual(util.Uninferable, next(node.infer()))