summaryrefslogtreecommitdiff
path: root/astroid/inference_tip.py
blob: 9eda5b4fc795dd5310fc466c7314dceee0d4b05e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

"""Transform utilities (filters and decorator)."""

from __future__ import annotations

from collections import OrderedDict
from collections.abc import Generator
from typing import Any, TypeVar

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import (
    InferenceResult,
    InferFn,
    TransformFn,
)

# Memoized inference results, keyed by (inference function, node, context).
# Kept ordered so the oldest entry can be evicted first once the cache
# exceeds its size bound (eviction happens in _inference_tip_cached).
_cache: OrderedDict[
    tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult]
] = OrderedDict()

# Recursion guard: (func, node) pairs whose inference is currently in
# progress. A re-entrant call for the same pair aborts with
# UseInferenceDefault instead of recursing forever.
_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set()

# Type variable for node subtypes, so decorated inference functions keep
# their precise node type through the decorator.
_NodesT = TypeVar("_NodesT", bound=NodeNG)


def clear_inference_tip_cache() -> None:
    """Drop every memoized inference-tip result from the module cache."""
    _cache.clear()


def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]:
    """Cache decorator used for inference tips.

    Wraps *func* so that its results are memoized in the module-level
    ``_cache`` (keyed by ``(func, node, context)``) and so that re-entrant
    inference of the same ``(func, node)`` pair aborts with
    :class:`UseInferenceDefault` instead of recursing without bound.
    """

    def inner(
        node: _NodesT,
        context: InferenceContext | None = None,
        **kwargs: Any,
    ) -> Generator[InferenceResult, None, None]:
        # Context deliberately excluded from the recursion-guard key; see
        # the comment in the except branch below.
        partial_cache_key = (func, node)
        if partial_cache_key in _CURRENTLY_INFERRING:
            # If through recursion we end up trying to infer the same
            # func + node we raise here.
            _CURRENTLY_INFERRING.remove(partial_cache_key)
            raise UseInferenceDefault
        if context is not None and context.is_empty():
            # Fresh, empty contexts will defeat the cache.
            # Normalize them to None so equivalent calls share one entry.
            context = None
        try:
            # Cache hit: replay the memoized results and stop.
            yield from _cache[func, node, context]
            return
        except KeyError:
            # Recursion guard with a partial cache key.
            # Using the full key causes a recursion error on PyPy.
            # It's a pragmatic compromise to avoid so much recursive inference
            # with slightly different contexts while still passing the simple
            # test cases included with this commit.
            _CURRENTLY_INFERRING.add(partial_cache_key)
            try:
                # May raise UseInferenceDefault
                # The list() forces full evaluation while the guard is set,
                # and stores a replayable snapshot in the cache.
                result = _cache[func, node, context] = list(
                    func(node, context, **kwargs)
                )
            finally:
                # Remove recursion guard.
                try:
                    _CURRENTLY_INFERRING.remove(partial_cache_key)
                except KeyError:
                    pass  # Recursion may beat us to the punch.

                # Bound the cache: evict the oldest entry (FIFO) once it
                # grows past 64 items.
                if len(_cache) > 64:
                    _cache.popitem(last=False)

        # https://github.com/pylint-dev/pylint/issues/8686
        yield from result  # pylint: disable=used-before-assignment

    return inner


def inference_tip(
    infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False
) -> TransformFn[_NodesT]:
    """Given an instance specific inference function, return a function to be
    given to AstroidManager().register_transform to set this inference function.

    :param infer_function: The inference function to attach to matching nodes.
    :param bool raise_on_overwrite: Raise an `InferenceOverwriteError`
        if the inference tip will overwrite another. Used for debugging

    Typical usage

    .. sourcecode:: python

       AstroidManager().register_transform(Call, inference_tip(infer_named_tuple),
                                  predicate)

    .. Note::

        Using an inference tip will override
        any previously set inference tip for the given
        node. Use a predicate in the transform to prevent
        excess overwrites.
    """

    def transform(
        node: _NodesT, infer_function: InferFn[_NodesT] = infer_function
    ) -> _NodesT:
        if (
            raise_on_overwrite
            and node._explicit_inference is not None
            and node._explicit_inference is not infer_function
        ):
            # Fix: the existing/new bindings were previously swapped, so the
            # error reported the two inference functions backwards. The
            # inference that is "already set" is node._explicit_inference;
            # infer_function is the one attempting the overwrite.
            raise InferenceOverwriteError(
                "Inference already set to {existing_inference}. "
                "Trying to overwrite with {new_inference} for {node}".format(
                    existing_inference=node._explicit_inference,
                    new_inference=infer_function,
                    node=node,
                )
            )
        # Attach the (cached) inference function; node.infer() consults
        # _explicit_inference before falling back to default inference.
        node._explicit_inference = _inference_tip_cached(infer_function)
        return node

    return transform