diff options
-rw-r--r-- | ChangeLog | 4 | ||||
-rw-r--r-- | astroid/protocols.py | 2 | ||||
-rw-r--r-- | astroid/tests/unittest_inference.py | 33 |
3 files changed, 37 insertions, 2 deletions
@@ -6,6 +6,10 @@ What's New in astroid 2.3.0? ============================ Release Date: TBA +* Drop a superfluous and wrong callcontext when inferring the result of a context manager + + Close PyCQA/pylint#2859 + * ``igetattr`` raises ``InferenceError`` on re-inference of the same object This prevents ``StopIteration`` from leaking when we encounter the same diff --git a/astroid/protocols.py b/astroid/protocols.py index db0f5d9d..bf497ff1 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -494,8 +494,6 @@ def _infer_context_manager(self, mgr, context): raise exceptions.InferenceError(node=inferred) if not isinstance(enter, bases.BoundMethod): raise exceptions.InferenceError(node=enter) - if not context.callcontext: - context.callcontext = contextmod.CallContext(args=[inferred]) yield from enter.infer_call_result(self, context) else: raise exceptions.InferenceError(node=mgr) diff --git a/astroid/tests/unittest_inference.py b/astroid/tests/unittest_inference.py index 0dd23d87..b388655c 100644 --- a/astroid/tests/unittest_inference.py +++ b/astroid/tests/unittest_inference.py @@ -5213,5 +5213,38 @@ def test_prevent_recursion_error_in_igetattr_and_context_manager_inference(): assert next(node.infer()) is util.Uninferable +def test_infer_context_manager_with_unknown_args(): + code = """ + class client_log(object): + def __init__(self, client): + self.client = client + def __enter__(self): + return self.client + def __exit__(self, exc_type, exc_value, traceback): + pass + + with client_log(None) as c: + c #@ + """ + node = extract_node(code) + assert next(node.infer()) is util.Uninferable + + # But if we know the argument, then it is easy + code = """ + class client_log(object): + def __init__(self, client=24): + self.client = client + def __enter__(self): + return self.client + def __exit__(self, exc_type, exc_value, traceback): + pass + + with client_log(None) as c: + c #@ + """ + node = extract_node(code) + assert isinstance(next(node.infer()), nodes.Const) + + if __name__ == "__main__": unittest.main() |