summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Walls <jacobtylerwalls@gmail.com>2023-05-07 20:54:18 -0400
committerGitHub <noreply@github.com>2023-05-07 20:54:18 -0400
commit900c5467b80d2b5d531990d3da1d1666e9edb0f0 (patch)
treebe98c1a80bd8d7102ff7ddc60a94d10c2e7e6c52
parent0740a0dd5e9cb48bb1a400aded498e4db1fcfca9 (diff)
downloadastroid-git-900c5467b80d2b5d531990d3da1d1666e9edb0f0.tar.gz
Improve typing of inference functions (#2166)
Co-authored-by: Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>
-rw-r--r--astroid/inference.py1
-rw-r--r--astroid/inference_tip.py51
-rw-r--r--astroid/nodes/node_ng.py21
-rw-r--r--astroid/transforms.py11
-rw-r--r--astroid/typing.py42
5 files changed, 84 insertions, 42 deletions
diff --git a/astroid/inference.py b/astroid/inference.py
index 4dadc116..1729d81d 100644
--- a/astroid/inference.py
+++ b/astroid/inference.py
@@ -254,7 +254,6 @@ def infer_name(
return bases._infer_stmts(stmts, context, frame)
-# pylint: disable=no-value-for-parameter
# The order of the decorators here is important
# See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5
nodes.Name._infer = decorators.raise_if_nothing_inferred(
diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py
index 92cb6b4f..44a7fcf1 100644
--- a/astroid/inference_tip.py
+++ b/astroid/inference_tip.py
@@ -6,26 +6,25 @@
from __future__ import annotations
-import sys
-from collections.abc import Callable, Iterator
+from collections.abc import Generator
+from typing import Any, TypeVar
from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
-from astroid.typing import InferenceResult, InferFn
-
-if sys.version_info >= (3, 11):
- from typing import ParamSpec
-else:
- from typing_extensions import ParamSpec
-
-_P = ParamSpec("_P")
+from astroid.typing import (
+ InferenceResult,
+ InferFn,
+ TransformFn,
+)
_cache: dict[
- tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
+ tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult]
] = {}
-_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set()
+_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set()
+
+_NodesT = TypeVar("_NodesT", bound=NodeNG)
def clear_inference_tip_cache() -> None:
@@ -33,21 +32,22 @@ def clear_inference_tip_cache() -> None:
_cache.clear()
-def _inference_tip_cached(
- func: Callable[_P, Iterator[InferenceResult]],
-) -> Callable[_P, Iterator[InferenceResult]]:
+def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]:
"""Cache decorator used for inference tips."""
- def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
- node = args[0]
- context = args[1]
+ def inner(
+ node: _NodesT,
+ context: InferenceContext | None = None,
+ **kwargs: Any,
+ ) -> Generator[InferenceResult, None, None]:
partial_cache_key = (func, node)
if partial_cache_key in _CURRENTLY_INFERRING:
# If through recursion we end up trying to infer the same
# func + node we raise here.
raise UseInferenceDefault
try:
- return _cache[func, node, context]
+ yield from _cache[func, node, context]
+ return
except KeyError:
# Recursion guard with a partial cache key.
# Using the full key causes a recursion error on PyPy.
@@ -55,16 +55,18 @@ def _inference_tip_cached(
# with slightly different contexts while still passing the simple
# test cases included with this commit.
_CURRENTLY_INFERRING.add(partial_cache_key)
- result = _cache[func, node, context] = list(func(*args, **kwargs))
+ result = _cache[func, node, context] = list(func(node, context, **kwargs))
# Remove recursion guard.
_CURRENTLY_INFERRING.remove(partial_cache_key)
- return iter(result)
+ yield from result
return inner
-def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
+def inference_tip(
+ infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False
+) -> TransformFn[_NodesT]:
"""Given an instance specific inference function, return a function to be
given to AstroidManager().register_transform to set this inference function.
@@ -86,7 +88,9 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) ->
excess overwrites.
"""
- def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
+ def transform(
+ node: _NodesT, infer_function: InferFn[_NodesT] = infer_function
+ ) -> _NodesT:
if (
raise_on_overwrite
and node._explicit_inference is not None
@@ -100,7 +104,6 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) ->
node=node,
)
)
- # pylint: disable=no-value-for-parameter
node._explicit_inference = _inference_tip_cached(infer_function)
return node
diff --git a/astroid/nodes/node_ng.py b/astroid/nodes/node_ng.py
index de5dec77..31c842ee 100644
--- a/astroid/nodes/node_ng.py
+++ b/astroid/nodes/node_ng.py
@@ -5,6 +5,7 @@
from __future__ import annotations
import pprint
+import sys
import warnings
from collections.abc import Generator, Iterator
from functools import cached_property
@@ -37,6 +38,12 @@ from astroid.nodes.const import OP_PRECEDENCE
from astroid.nodes.utils import Position
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
+
+
if TYPE_CHECKING:
from astroid import nodes
@@ -80,7 +87,7 @@ class NodeNG:
_other_other_fields: ClassVar[tuple[str, ...]] = ()
"""Attributes that contain AST-dependent fields."""
# instance specific inference function infer(node, context)
- _explicit_inference: InferFn | None = None
+ _explicit_inference: InferFn[Self] | None = None
def __init__(
self,
@@ -137,9 +144,17 @@ class NodeNG:
# explicit_inference is not bound, give it self explicitly
try:
if context is None:
- yield from self._explicit_inference(self, context, **kwargs)
+ yield from self._explicit_inference(
+ self, # type: ignore[arg-type]
+ context,
+ **kwargs,
+ )
return
- for result in self._explicit_inference(self, context, **kwargs):
+ for result in self._explicit_inference(
+ self, # type: ignore[arg-type]
+ context,
+ **kwargs,
+ ):
context.nodes_inferred += 1
yield result
return
diff --git a/astroid/transforms.py b/astroid/transforms.py
index f6c72794..29332223 100644
--- a/astroid/transforms.py
+++ b/astroid/transforms.py
@@ -9,7 +9,7 @@ from collections.abc import Callable
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload
from astroid.context import _invalidate_cache
-from astroid.typing import SuccessfulInferenceResult
+from astroid.typing import SuccessfulInferenceResult, TransformFn
if TYPE_CHECKING:
from astroid import nodes
@@ -17,9 +17,6 @@ if TYPE_CHECKING:
_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
- _Transform = Callable[
- [_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult]
- ]
_Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]]
_Vistables = Union[
@@ -52,7 +49,7 @@ class TransformVisitor:
type[SuccessfulInferenceResult],
list[
tuple[
- _Transform[SuccessfulInferenceResult],
+ TransformFn[SuccessfulInferenceResult],
_Predicate[SuccessfulInferenceResult],
]
],
@@ -123,7 +120,7 @@ class TransformVisitor:
def register_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
- transform: _Transform[_SuccessfulInferenceResultT],
+ transform: TransformFn[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Register `transform(node)` function to be applied on the given node.
@@ -139,7 +136,7 @@ class TransformVisitor:
def unregister_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
- transform: _Transform[_SuccessfulInferenceResultT],
+ transform: TransformFn[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Unregister the given transform."""
diff --git a/astroid/typing.py b/astroid/typing.py
index f42832e4..0ae30fcc 100644
--- a/astroid/typing.py
+++ b/astroid/typing.py
@@ -4,7 +4,17 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generator,
+ Generic,
+ Protocol,
+ TypedDict,
+ TypeVar,
+ Union,
+)
if TYPE_CHECKING:
from astroid import bases, exceptions, nodes, transforms, util
@@ -12,9 +22,6 @@ if TYPE_CHECKING:
from astroid.interpreter._import import spec
-_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG")
-
-
class InferenceErrorInfo(TypedDict):
"""Store additional Inference error information
raised with StopIteration exception.
@@ -24,9 +31,6 @@ class InferenceErrorInfo(TypedDict):
context: InferenceContext | None
-InferFn = Callable[..., Any]
-
-
class AstroidManagerBrain(TypedDict):
"""Dictionary to store relevant information for a AstroidManager class."""
@@ -46,6 +50,11 @@ SuccessfulInferenceResult = Union["nodes.NodeNG", "bases.Proxy"]
_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
+_SuccessfulInferenceResultT_contra = TypeVar(
+ "_SuccessfulInferenceResultT_contra",
+ bound=SuccessfulInferenceResult,
+ contravariant=True,
+)
ConstFactoryResult = Union[
"nodes.List",
@@ -67,3 +76,22 @@ InferBinaryOp = Callable[
],
Generator[InferenceResult, None, None],
]
+
+
+class InferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]):
+ def __call__(
+ self,
+ node: _SuccessfulInferenceResultT_contra,
+ context: InferenceContext | None = None,
+ **kwargs: Any,
+ ) -> Generator[InferenceResult, None, None]:
+ ... # pragma: no cover
+
+
+class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]):
+ def __call__(
+ self,
+ node: _SuccessfulInferenceResultT,
+ infer_function: InferFn[_SuccessfulInferenceResultT] = ...,
+ ) -> _SuccessfulInferenceResultT | None:
+ ... # pragma: no cover