summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudiu Popa <pcmanticore@gmail.com>2015-10-22 10:13:56 +0100
committerClaudiu Popa <pcmanticore@gmail.com>2015-10-22 10:13:56 +0100
commitb65cde98f3abb85f80882ea39de40196f69ab2ec (patch)
tree1b7c21613d528f3b25e6646221f6fc0b0e608794
parentb64661734ed34ac100d229e800d59e1cb4ebf619 (diff)
downloadastroid-b65cde98f3abb85f80882ea39de40196f69ab2ec.tar.gz
Create a context call when __enter__ is called for solving what a context manager returns.
-rw-r--r--astroid/protocols.py2
-rw-r--r--astroid/tests/unittest_inference.py29
2 files changed, 31 insertions, 0 deletions
diff --git a/astroid/protocols.py b/astroid/protocols.py
index 8bd3813..21f0c29 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -414,6 +414,8 @@ def _infer_context_manager(self, mgr, context):
return
if not isinstance(enter, bases.BoundMethod):
return
+ if not context.callcontext:
+ context.callcontext = contextmod.CallContext(args=[inferred])
for result in enter.infer_call_result(self, context):
yield result
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 6756838..3dbcd5c 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -3059,6 +3059,35 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase):
inferred = next(node.infer())
self.assertRaises(InferenceError, next, inferred.infer_call_result(node))
+ def test_context_call_for_context_managers(self):
+ ast_nodes = test_utils.extract_node('''
+ class A:
+ def __enter__(self):
+ return self
+ class B:
+ __enter__ = lambda self: self
+ class C:
+ @property
+ def a(self): return A()
+ def __enter__(self):
+ return self.a
+ with A() as a:
+ a #@
+ with B() as b:
+ b #@
+ with C() as c:
+ c #@
+ ''')
+ first_a = next(ast_nodes[0].infer())
+ self.assertIsInstance(first_a, Instance)
+ self.assertEqual(first_a.name, 'A')
+ second_b = next(ast_nodes[1].infer())
+ self.assertIsInstance(second_b, Instance)
+ self.assertEqual(second_b.name, 'B')
+ third_c = next(ast_nodes[2].infer())
+ self.assertIsInstance(third_c, Instance)
+ self.assertEqual(third_c.name, 'A')
+
class GetattrTest(unittest.TestCase):