diff options
author | David Liu <david@cs.toronto.edu> | 2023-01-06 15:18:30 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-06 21:18:30 +0100 |
commit | 21880ddb44225c16c064ad0b2412e64f2f399709 (patch) | |
tree | 3bb25ca12360aae45d06df3eb72998745663a033 /astroid/bases.py | |
parent | f476ebca6cfbfb30d1c723ca4144e23645d5164d (diff) | |
download | astroid-git-21880ddb44225c16c064ad0b2412e64f2f399709.tar.gz |
Support "is None" constraints from if statements during inference (#1189)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
Co-authored-by: Pierre Sassoulas <pierre.sassoulas@gmail.com>
Co-authored-by: Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
Diffstat (limited to 'astroid/bases.py')
-rw-r--r-- | astroid/bases.py | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/astroid/bases.py b/astroid/bases.py index 2c047855..bf99ddce 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -11,7 +11,7 @@ import collections import collections.abc import sys from collections.abc import Sequence -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from astroid import decorators, nodes from astroid.const import PY310_PLUS @@ -35,6 +35,9 @@ if sys.version_info >= (3, 8): else: from typing_extensions import Literal +if TYPE_CHECKING: + from astroid.constraint import Constraint + objectmodel = lazy_import("interpreter.objectmodel") helpers = lazy_import("helpers") manager = lazy_import("manager") @@ -146,11 +149,14 @@ def _infer_stmts( ) -> collections.abc.Generator[InferenceResult, None, None]: """Return an iterator on statements inferred by each statement in *stmts*.""" inferred = False + constraint_failed = False if context is not None: name = context.lookupname context = context.clone() + constraints = context.constraints.get(name, {}) else: name = None + constraints = {} context = InferenceContext() for stmt in stmts: @@ -161,16 +167,26 @@ def _infer_stmts( # 'context' is always InferenceContext and Instances get '_infer_name' from ClassDef context.lookupname = stmt._infer_name(frame, name) # type: ignore[union-attr] try: + stmt_constraints: set[Constraint] = set() + for constraint_stmt, potential_constraints in constraints.items(): + if not constraint_stmt.parent_of(stmt): + stmt_constraints.update(potential_constraints) # Mypy doesn't recognize that 'stmt' can't be Uninferable for inf in stmt.infer(context=context): # type: ignore[union-attr] - yield inf - inferred = True + if all(constraint.satisfied_by(inf) for constraint in stmt_constraints): + yield inf + inferred = True + else: + constraint_failed = True except NameInferenceError: continue except InferenceError: yield Uninferable inferred = True - if not inferred: + + if not inferred and constraint_failed: + yield Uninferable + elif not inferred: raise InferenceError( "Inference failed for all members of {stmts!r}.", stmts=stmts, |