diff options
Diffstat (limited to 'astroid')
-rw-r--r-- | astroid/inference_tip.py | 30 |
1 files changed, 22 insertions, 8 deletions
diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 5b855c9e..92cb6b4f 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -9,6 +9,7 @@ from __future__ import annotations import sys from collections.abc import Callable, Iterator +from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG from astroid.typing import InferenceResult, InferFn @@ -20,7 +21,11 @@ else: _P = ParamSpec("_P") -_cache: dict[tuple[InferFn, NodeNG], list[InferenceResult] | None] = {} +_cache: dict[ + tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] +] = {} + +_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set() def clear_inference_tip_cache() -> None: @@ -35,16 +40,25 @@ def _inference_tip_cached( def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]: node = args[0] - try: - result = _cache[func, node] + context = args[1] + partial_cache_key = (func, node) + if partial_cache_key in _CURRENTLY_INFERRING: # If through recursion we end up trying to infer the same # func + node we raise here. - if result is None: - raise UseInferenceDefault() + raise UseInferenceDefault + try: + return _cache[func, node, context] except KeyError: - _cache[func, node] = None - result = _cache[func, node] = list(func(*args, **kwargs)) - assert result + # Recursion guard with a partial cache key. + # Using the full key causes a recursion error on PyPy. + # It's a pragmatic compromise to avoid so much recursive inference + # with slightly different contexts while still passing the simple + # test cases included with this commit. + _CURRENTLY_INFERRING.add(partial_cache_key) + result = _cache[func, node, context] = list(func(*args, **kwargs)) + # Remove recursion guard. + _CURRENTLY_INFERRING.remove(partial_cache_key) + return iter(result) return inner |