summaryrefslogtreecommitdiff
path: root/astroid/constraint.py
blob: 6e23b592f1d0e66f54b3b524df320fc70d9ccfb0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

"""Classes representing different types of constraints on inference values."""
from __future__ import annotations

import sys
from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Union

from astroid import bases, nodes, util
from astroid.typing import InferenceResult

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

# Node kinds a constraint can be attached to: names and attribute accesses,
# including their Assign* variants. Used as the type of the ``node`` / ``expr``
# arguments of ``Constraint.match`` and ``get_constraints``.
# Spelled with ``typing.Union`` because this alias is evaluated at runtime (it is
# not an annotation), presumably to stay compatible with Pythons that lack
# ``X | Y`` on classes.
_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name]


class Constraint(ABC):
    """Abstract base for a single constraint on a variable.

    Concrete subclasses recognise a pattern in an expression (``match``) and
    decide whether an inferred value is compatible with it (``satisfied_by``).
    """

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        #: The node that this constraint applies to.
        self.node = node
        #: True when the constraint is inverted, e.g. "is not" rather than "is".
        self.negate = negate

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on ``node`` from ``expr``, or return None.

        Implementations inspect ``expr`` and, when it fits the constraint's
        pattern, return a new instance for ``node``. A true ``negate`` requests
        the inverted constraint.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Report whether the inferred value ``inferred`` meets this constraint."""


class NoneConstraint(Constraint):
    """Represents an "is None" or "is not None" constraint.

    The comparison may be written with the name on either side: ``x is None``
    and ``None is x`` (and their ``is not`` counterparts) are equivalent.
    """

    # Shared ``None`` literal node that operands and inferred values are compared
    # against via ``_matches`` (which compares ``Const`` nodes by value).
    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Return a new constraint for node matched from expr, if expr matches
        the constraint pattern.

        The pattern is a single ``is`` / ``is not`` comparison between ``node``
        and ``None``, with the operands in either order. The resulting constraint
        is negated when exactly one of ``negate`` and an ``is not`` operator
        applies. Return None when expr does not match.
        """
        if not isinstance(expr, nodes.Compare) or len(expr.ops) != 1:
            return None

        op, right = expr.ops[0]
        if op not in {"is", "is not"}:
            return None

        left = expr.left
        # Identity is symmetric, so ``None is x`` narrows x just like ``x is None``.
        node_vs_none = _matches(left, node) and _matches(right, cls.CONST_NONE)
        none_vs_node = _matches(left, cls.CONST_NONE) and _matches(right, node)
        if not (node_vs_none or none_vs_node):
            return None

        # "is not" flips the requested polarity, hence the XOR.
        return cls(node=node, negate=bool(negate) ^ (op == "is not"))

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Return True if this constraint is satisfied by the given inferred value.

        Uninferable values always satisfy the constraint.
        """
        # Assume true if uninferable
        if isinstance(inferred, util.UninferableBase):
            return True

        # "is None" holds when inferred matches None; "is not None" is the inverse.
        return self.negate ^ _matches(inferred, self.CONST_NONE)


def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If, set[Constraint]]:
    """Collect the ``if``-derived constraints that apply to ``expr``.

    Walk upward from ``expr`` through its ancestors, stopping on reaching
    ``frame`` (or the tree root). For every enclosing ``If`` node, the field that
    holds the current ancestor decides how the ``If`` test is used:

    * inside ``body`` the test is matched as written;
    * inside ``orelse`` the test is matched inverted;
    * anywhere else (e.g. the test itself) nothing is recorded.

    Return a dict mapping each such ``If`` node to the non-empty set of
    constraints it produced. ``If`` nodes that yield no constraint are omitted.
    """
    result: dict[nodes.If, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr
    while child is not None and child is not frame:
        ancestor = child.parent
        if isinstance(ancestor, nodes.If):
            field_name, _ = ancestor.locate_child(child)
            if field_name in ("body", "orelse"):
                inverted = field_name == "orelse"
                found = set(_match_constraint(expr, ancestor.test, invert=inverted))
                if found:
                    result[ancestor] = found
        child = ancestor

    return result


# Registry consumed by ``_match_constraint``: each class listed here has its
# ``match`` classmethod tried against an ``if`` test, so a new constraint type
# takes effect only once it is added to this set.
ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,))
"""All supported constraint types."""


def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Decide whether ``node1`` and ``node2`` refer to the same thing.

    Only three pairings can match:

    * two ``Name`` nodes with the same identifier;
    * two ``Attribute`` nodes with the same attribute name whose receiver
      expressions match recursively;
    * two ``Const`` nodes whose values compare equal.

    Any other combination is reported as a mismatch.
    """
    if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
        result = node1.name == node2.name
    elif isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
        result = node1.attrname == node2.attrname and _matches(
            node1.expr, node2.expr
        )
    elif isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
        result = node1.value == node2.value
    else:
        result = False
    return result


def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Lazily yield a constraint from every registered class that matches.

    Each class in ``ALL_CONSTRAINT_CLASSES`` is asked to ``match`` ``node``
    against ``expr``, passing ``invert`` as the negation flag. Falsy results
    (i.e. ``None`` for "no match") are skipped.
    """
    attempts = (
        constraint_type.match(node, expr, invert)
        for constraint_type in ALL_CONSTRAINT_CLASSES
    )
    yield from filter(None, attempts)