summaryrefslogtreecommitdiff
path: root/astroid/constraint.py
diff options
context:
space:
mode:
authorDavid Liu <david@cs.toronto.edu>2023-01-06 15:18:30 -0500
committerGitHub <noreply@github.com>2023-01-06 21:18:30 +0100
commit21880ddb44225c16c064ad0b2412e64f2f399709 (patch)
tree3bb25ca12360aae45d06df3eb72998745663a033 /astroid/constraint.py
parentf476ebca6cfbfb30d1c723ca4144e23645d5164d (diff)
downloadastroid-git-21880ddb44225c16c064ad0b2412e64f2f399709.tar.gz
Support "is None" constraints from if statements during inference (#1189)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com> Co-authored-by: Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
Diffstat (limited to 'astroid/constraint.py')
-rw-r--r--astroid/constraint.py137
1 files changed, 137 insertions, 0 deletions
diff --git a/astroid/constraint.py b/astroid/constraint.py
new file mode 100644
index 00000000..deed9ac5
--- /dev/null
+++ b/astroid/constraint.py
@@ -0,0 +1,137 @@
+# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
+# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
+# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
+
+"""Classes representing different types of constraints on inference values."""
+from __future__ import annotations
+
+import sys
+from abc import ABC, abstractmethod
+from collections.abc import Iterator
+from typing import Union
+
+from astroid import bases, nodes, util
+from astroid.typing import InferenceResult
+
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
+
+_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name]
+
+
+class Constraint(ABC):
+ """Represents a single constraint on a variable."""
+
+ def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
+ self.node = node
+ """The node that this constraint applies to."""
+ self.negate = negate
+ """True if this constraint is negated. E.g., "is not" instead of "is"."""
+
+ @classmethod
+ @abstractmethod
+ def match(
+ cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
+ ) -> Self | None:
+ """Return a new constraint for node matched from expr, if expr matches
+ the constraint pattern.
+
+ If negate is True, negate the constraint.
+ """
+
+ @abstractmethod
+ def satisfied_by(self, inferred: InferenceResult) -> bool:
+ """Return True if this constraint is satisfied by the given inferred value."""
+
+
+class NoneConstraint(Constraint):
+ """Represents an "is None" or "is not None" constraint."""
+
+ CONST_NONE: nodes.Const = nodes.Const(None)
+
+ @classmethod
+ def match(
+ cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
+ ) -> Self | None:
+ """Return a new constraint for node matched from expr, if expr matches
+ the constraint pattern.
+
+ Negate the constraint based on the value of negate.
+ """
+ if isinstance(expr, nodes.Compare) and len(expr.ops) == 1:
+ left = expr.left
+ op, right = expr.ops[0]
+ if op in {"is", "is not"} and (
+ _matches(left, node) and _matches(right, cls.CONST_NONE)
+ ):
+ negate = (op == "is" and negate) or (op == "is not" and not negate)
+ return cls(node=node, negate=negate)
+
+ return None
+
+ def satisfied_by(self, inferred: InferenceResult) -> bool:
+ """Return True if this constraint is satisfied by the given inferred value."""
+ # Assume true if uninferable
+ if inferred is util.Uninferable:
+ return True
+
+ # Return the XOR of self.negate and matches(inferred, self.CONST_NONE)
+ return self.negate ^ _matches(inferred, self.CONST_NONE)
+
+
+def get_constraints(
+ expr: _NameNodes, frame: nodes.LocalsDictNodeNG
+) -> dict[nodes.If, set[Constraint]]:
+ """Returns the constraints for the given expression.
+
+ The returned dictionary maps the node where the constraint was generated to the
+ corresponding constraint(s).
+
+ Constraints are computed statically by analysing the code surrounding expr.
+ Currently this only supports constraints generated from if conditions.
+ """
+ current_node: nodes.NodeNG | None = expr
+ constraints_mapping: dict[nodes.If, set[Constraint]] = {}
+ while current_node is not None and current_node is not frame:
+ parent = current_node.parent
+ if isinstance(parent, nodes.If):
+ branch, _ = parent.locate_child(current_node)
+ constraints: set[Constraint] | None = None
+ if branch == "body":
+ constraints = set(_match_constraint(expr, parent.test))
+ elif branch == "orelse":
+ constraints = set(_match_constraint(expr, parent.test, invert=True))
+
+ if constraints:
+ constraints_mapping[parent] = constraints
+ current_node = parent
+
+ return constraints_mapping
+
+
+ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,))
+"""All supported constraint types."""
+
+
+def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
+ """Returns True if the two nodes match."""
+ if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name):
+ return node1.name == node2.name
+ if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute):
+ return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr)
+ if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const):
+ return node1.value == node2.value
+
+ return False
+
+
+def _match_constraint(
+ node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
+) -> Iterator[Constraint]:
+ """Yields all constraint patterns for node that match."""
+ for constraint_cls in ALL_CONSTRAINT_CLASSES:
+ constraint = constraint_cls.match(node, expr, invert)
+ if constraint:
+ yield constraint