summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ChangeLog4
-rw-r--r--astroid/protocols.py2
-rw-r--r--astroid/tests/unittest_inference.py33
3 files changed, 37 insertions, 2 deletions
diff --git a/ChangeLog b/ChangeLog
index 79ebc06e..1a4e4eff 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -6,6 +6,10 @@ What's New in astroid 2.3.0?
============================
Release Date: TBA
+* Drop a superfluous and wrong callcontext when inferring the result of a context manager
+
+ Close PyCQA/pylint#2859
+
* ``igetattr`` raises ``InferenceError`` on re-inference of the same object
This prevents ``StopIteration`` from leaking when we encounter the same
diff --git a/astroid/protocols.py b/astroid/protocols.py
index db0f5d9d..bf497ff1 100644
--- a/astroid/protocols.py
+++ b/astroid/protocols.py
@@ -494,8 +494,6 @@ def _infer_context_manager(self, mgr, context):
raise exceptions.InferenceError(node=inferred)
if not isinstance(enter, bases.BoundMethod):
raise exceptions.InferenceError(node=enter)
- if not context.callcontext:
- context.callcontext = contextmod.CallContext(args=[inferred])
yield from enter.infer_call_result(self, context)
else:
raise exceptions.InferenceError(node=mgr)
diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py
index 0dd23d87..b388655c 100644
--- a/astroid/tests/unittest_inference.py
+++ b/astroid/tests/unittest_inference.py
@@ -5213,5 +5213,38 @@ def test_prevent_recursion_error_in_igetattr_and_context_manager_inference():
assert next(node.infer()) is util.Uninferable
+def test_infer_context_manager_with_unknown_args():
+ code = """
+ class client_log(object):
+ def __init__(self, client):
+ self.client = client
+ def __enter__(self):
+ return self.client
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ with client_log(None) as c:
+ c #@
+ """
+ node = extract_node(code)
+ assert next(node.infer()) is util.Uninferable
+
+ # But if we know the argument, then it is easy
+ code = """
+ class client_log(object):
+ def __init__(self, client=24):
+ self.client = client
+ def __enter__(self):
+ return self.client
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ with client_log(None) as c:
+ c #@
+ """
+ node = extract_node(code)
+ assert isinstance(next(node.infer()), nodes.Const)
+
+
if __name__ == "__main__":
unittest.main()