diff options
-rw-r--r-- | astroid/protocols.py | 2 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 29 |
2 files changed, 31 insertions, 0 deletions
diff --git a/astroid/protocols.py b/astroid/protocols.py index 8bd3813..21f0c29 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -414,6 +414,8 @@ def _infer_context_manager(self, mgr, context): return if not isinstance(enter, bases.BoundMethod): return + if not context.callcontext: + context.callcontext = contextmod.CallContext(args=[inferred]) for result in enter.infer_call_result(self, context): yield result diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 6756838..3dbcd5c 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -3059,6 +3059,35 @@ class InferenceTest(resources.SysPathSetup, unittest.TestCase): inferred = next(node.infer()) self.assertRaises(InferenceError, next, inferred.infer_call_result(node)) + def test_context_call_for_context_managers(self): + ast_nodes = test_utils.extract_node(''' + class A: + def __enter__(self): + return self + class B: + __enter__ = lambda self: self + class C: + @property + def a(self): return A() + def __enter__(self): + return self.a + with A() as a: + a #@ + with B() as b: + b #@ + with C() as c: + c #@ + ''') + first_a = next(ast_nodes[0].infer()) + self.assertIsInstance(first_a, Instance) + self.assertEqual(first_a.name, 'A') + second_b = next(ast_nodes[1].infer()) + self.assertIsInstance(second_b, Instance) + self.assertEqual(second_b.name, 'B') + third_c = next(ast_nodes[2].infer()) + self.assertIsInstance(third_c, Instance) + self.assertEqual(third_c.name, 'A') + class GetattrTest(unittest.TestCase): |