summaryrefslogtreecommitdiff
path: root/astroid/context.py
blob: cccc81c0778564fb4a13149d38897538e46fceba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

"""Various context related utilities, including inference and call contexts."""

from __future__ import annotations

import contextlib
import pprint
from collections.abc import Iterator
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple

from astroid.typing import InferenceResult, SuccessfulInferenceResult

if TYPE_CHECKING:
    from astroid import constraint, nodes
    from astroid.nodes.node_classes import Keyword, NodeNG

# Type of the module-wide inference cache. The key mirrors the one described in
# the ``InferenceContext.inferred`` docstring:
# ``(node, lookupname, callcontext, boundnode)``. The last two slots carry the
# context's ``callcontext`` (a ``CallContext``) and ``boundnode`` objects, so
# they are typed accordingly rather than as strings.
_InferenceCache = Dict[
    Tuple[
        "NodeNG",
        Optional[str],
        Optional["CallContext"],
        Optional["SuccessfulInferenceResult"],
    ],
    Sequence["NodeNG"],
]

# Single cache shared by every InferenceContext (returned by its ``inferred``
# property); emptied in place by ``_invalidate_cache``.
_INFERENCE_CACHE: _InferenceCache = {}


def _invalidate_cache() -> None:
    """Empty the module-wide inference cache.

    The dict is cleared in place rather than rebound, so every holder of a
    reference to it (e.g. ``InferenceContext.inferred``, which returns this
    same object) observes the emptied cache.
    """
    _INFERENCE_CACHE.clear()


class InferenceContext:
    """Provide context for inference.

    Store already inferred nodes to save time
    Account for already visited nodes to stop infinite recursion
    """

    __slots__ = (
        "path",
        "lookupname",
        "callcontext",
        "boundnode",
        "extra_context",
        "constraints",
        "_nodes_inferred",
    )

    # NOTE(review): not read anywhere in this class; presumably an upper bound
    # on ``nodes_inferred`` that callers elsewhere enforce — confirm.
    max_inferred = 100

    def __init__(
        self,
        path: set[tuple[nodes.NodeNG, str | None]] | None = None,
        nodes_inferred: list[int] | None = None,
    ) -> None:
        # The counter lives in a one-element list so that ``clone`` can hand the
        # very same list to the clone, making the count shared between them.
        if nodes_inferred is None:
            self._nodes_inferred = [0]
        else:
            self._nodes_inferred = nodes_inferred

        self.path = path or set()
        """Path of visited nodes and their lookupname.

        Currently this key is ``(node, context.lookupname)``
        """
        self.lookupname: str | None = None
        """The original name of the node.

        e.g.
        foo = 1
        The inference of 'foo' is nodes.Const(1) but the lookup name is 'foo'
        """
        self.callcontext: CallContext | None = None
        """The call arguments and keywords for the given context."""
        self.boundnode: SuccessfulInferenceResult | None = None
        """The bound node of the given context.

        e.g. the bound node of object.__new__(cls) is the object node
        """
        self.extra_context: dict[SuccessfulInferenceResult, InferenceContext] = {}
        """Context that needs to be passed down through call stacks for call arguments."""

        self.constraints: dict[str, dict[nodes.If, set[constraint.Constraint]]] = {}
        """The constraints on nodes."""

    @property
    def nodes_inferred(self) -> int:
        """
        Number of nodes inferred in this context and all its clones/descendants.

        Wrap inner value in a mutable cell to allow for mutating a class
        variable in the presence of __slots__
        """
        return self._nodes_inferred[0]

    @nodes_inferred.setter
    def nodes_inferred(self, value: int) -> None:
        # Write through the shared cell so every clone sees the new count.
        self._nodes_inferred[0] = value

    @property
    def inferred(self) -> _InferenceCache:
        """
        Inferred node contexts to their mapped results.

        Currently the key is ``(node, lookupname, callcontext, boundnode)``
        and the value is tuple of the inferred results
        """
        return _INFERENCE_CACHE

    def push(self, node: nodes.NodeNG) -> bool:
        """Push node into inference path.

        Allows one to see if the given node has already
        been looked at for this inference context.

        :returns: ``True`` if ``(node, self.lookupname)`` was already on the
            path (nothing is added); ``False`` after adding it.
        """
        name = self.lookupname
        if (node, name) in self.path:
            return True

        self.path.add((node, name))
        return False

    def clone(self) -> InferenceContext:
        """Clone inference path.

        For example, each side of a binary operation (BinOp)
        starts with the same context but diverge as each side is inferred
        so the InferenceContext will need to be cloned.

        The clone gets its own copies of ``path`` and ``constraints`` (shallow),
        shares ``extra_context`` and the ``nodes_inferred`` counter with this
        context, copies ``callcontext`` and ``boundnode`` by reference, and
        starts with ``lookupname`` reset to ``None``.
        """
        # NOTE(review): lookupname is not propagated to the clone; confirm this
        # is intentional.
        clone = InferenceContext(self.path.copy(), nodes_inferred=self._nodes_inferred)
        clone.callcontext = self.callcontext
        clone.boundnode = self.boundnode
        clone.extra_context = self.extra_context
        clone.constraints = self.constraints.copy()
        return clone

    @contextlib.contextmanager
    def restore_path(self) -> Iterator[None]:
        """Snapshot ``self.path`` and restore it when the ``with`` block exits.

        The restore runs in a ``finally`` clause so it also happens when the
        body raises (or when a generator suspended inside the block is closed).
        Otherwise nodes pushed inside the block would leak into later inference
        that reuses this context.
        """
        path = set(self.path)
        try:
            yield
        finally:
            self.path = path

    def is_empty(self) -> bool:
        """Return ``True`` if no state at all has been recorded on this context."""
        return (
            not self.path
            and not self.nodes_inferred
            and not self.callcontext
            and not self.boundnode
            and not self.lookupname
            and not self.extra_context
            and not self.constraints
        )

    def __str__(self) -> str:
        # One ``field=value`` line per slot, pretty-printed to fit ~80 columns.
        state = (
            f"{field}={pprint.pformat(getattr(self, field), width=80 - len(field))}"
            for field in self.__slots__
        )
        return "{}({})".format(type(self).__name__, ",\n    ".join(state))


class CallContext:
    """Describe a single call site: positional args, keywords and callee."""

    __slots__ = ("args", "keywords", "callee")

    def __init__(
        self,
        args: list[NodeNG],
        keywords: list[Keyword] | None = None,
        callee: InferenceResult | None = None,
    ):
        # Positional arguments, stored exactly as passed in.
        self.args = args
        # Keyword arguments flattened to ``(name, value)`` pairs, read from each
        # keyword's ``arg`` / ``value`` attributes; ``None`` or ``[]`` gives ``[]``.
        self.keywords = [(kw.arg, kw.value) for kw in (keywords or ())]
        # The function (or other object) being called, when known.
        self.callee = callee


def copy_context(context: InferenceContext | None) -> InferenceContext:
    """Return ``context.clone()``, or a brand-new context when *context* is None."""
    return InferenceContext() if context is None else context.clone()


def bind_context_to_node(
    context: InferenceContext | None, node: SuccessfulInferenceResult
) -> InferenceContext:
    """Return a copy of *context* whose ``boundnode`` is *node*.

    The bound node lets further inference retrieve the correct function name
    or attribute value. The caller's context is never modified: a clone (or a
    fresh context when *context* is None) receives the bound node instead, so
    it cannot be incorrectly propagated higher up in the call stack.
    """
    bound = copy_context(context)
    bound.boundnode = node
    return bound