diff options
author | David Liu <david@cs.toronto.edu> | 2023-01-06 15:18:30 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-06 21:18:30 +0100 |
commit | 21880ddb44225c16c064ad0b2412e64f2f399709 (patch) | |
tree | 3bb25ca12360aae45d06df3eb72998745663a033 /astroid/constraint.py | |
parent | f476ebca6cfbfb30d1c723ca4144e23645d5164d (diff) | |
download | astroid-git-21880ddb44225c16c064ad0b2412e64f2f399709.tar.gz |
Support "is None" constraints from if statements during inference (#1189)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com>
Co-authored-by: Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
Diffstat (limited to 'astroid/constraint.py')
-rw-r--r-- | astroid/constraint.py | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/astroid/constraint.py b/astroid/constraint.py new file mode 100644 index 00000000..deed9ac5 --- /dev/null +++ b/astroid/constraint.py @@ -0,0 +1,137 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt + +"""Classes representing different types of constraints on inference values.""" +from __future__ import annotations + +import sys +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import Union + +from astroid import bases, nodes, util +from astroid.typing import InferenceResult + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] + + +class Constraint(ABC): + """Represents a single constraint on a variable.""" + + def __init__(self, node: nodes.NodeNG, negate: bool) -> None: + self.node = node + """The node that this constraint applies to.""" + self.negate = negate + """True if this constraint is negated. E.g., "is not" instead of "is".""" + + @classmethod + @abstractmethod + def match( + cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: + """Return a new constraint for node matched from expr, if expr matches + the constraint pattern. + + If negate is True, negate the constraint. + """ + + @abstractmethod + def satisfied_by(self, inferred: InferenceResult) -> bool: + """Return True if this constraint is satisfied by the given inferred value.""" + + +class NoneConstraint(Constraint): + """Represents an "is None" or "is not None" constraint.""" + + CONST_NONE: nodes.Const = nodes.Const(None) + + @classmethod + def match( + cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: + """Return a new constraint for node matched from expr, if expr matches + the constraint pattern. + + Negate the constraint based on the value of negate. + """ + if isinstance(expr, nodes.Compare) and len(expr.ops) == 1: + left = expr.left + op, right = expr.ops[0] + if op in {"is", "is not"} and ( + _matches(left, node) and _matches(right, cls.CONST_NONE) + ): + negate = (op == "is" and negate) or (op == "is not" and not negate) + return cls(node=node, negate=negate) + + return None + + def satisfied_by(self, inferred: InferenceResult) -> bool: + """Return True if this constraint is satisfied by the given inferred value.""" + # Assume true if uninferable + if inferred is util.Uninferable: + return True + + # Return the XOR of self.negate and matches(inferred, self.CONST_NONE) + return self.negate ^ _matches(inferred, self.CONST_NONE) + + +def get_constraints( + expr: _NameNodes, frame: nodes.LocalsDictNodeNG +) -> dict[nodes.If, set[Constraint]]: + """Returns the constraints for the given expression. + + The returned dictionary maps the node where the constraint was generated to the + corresponding constraint(s). + + Constraints are computed statically by analysing the code surrounding expr. + Currently this only supports constraints generated from if conditions. + """ + current_node: nodes.NodeNG | None = expr + constraints_mapping: dict[nodes.If, set[Constraint]] = {} + while current_node is not None and current_node is not frame: + parent = current_node.parent + if isinstance(parent, nodes.If): + branch, _ = parent.locate_child(current_node) + constraints: set[Constraint] | None = None + if branch == "body": + constraints = set(_match_constraint(expr, parent.test)) + elif branch == "orelse": + constraints = set(_match_constraint(expr, parent.test, invert=True)) + + if constraints: + constraints_mapping[parent] = constraints + current_node = parent + + return constraints_mapping + + +ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,)) +"""All supported constraint types.""" + + +def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool: + """Returns True if the two nodes match.""" + if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name): + return node1.name == node2.name + if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute): + return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr) + if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const): + return node1.value == node2.value + + return False + + +def _match_constraint( + node: _NameNodes, expr: nodes.NodeNG, invert: bool = False +) -> Iterator[Constraint]: + """Yields all constraint patterns for node that match.""" + for constraint_cls in ALL_CONSTRAINT_CLASSES: + constraint = constraint_cls.match(node, expr, invert) + if constraint: + yield constraint |