diff options
Diffstat (limited to 'astroid')
-rw-r--r-- | astroid/inference.py | 1 | ||||
-rw-r--r-- | astroid/inference_tip.py | 51 | ||||
-rw-r--r-- | astroid/nodes/node_ng.py | 21 | ||||
-rw-r--r-- | astroid/transforms.py | 11 | ||||
-rw-r--r-- | astroid/typing.py | 42 |
5 files changed, 84 insertions, 42 deletions
diff --git a/astroid/inference.py b/astroid/inference.py index 4dadc116..1729d81d 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -254,7 +254,6 @@ def infer_name( return bases._infer_stmts(stmts, context, frame) -# pylint: disable=no-value-for-parameter # The order of the decorators here is important # See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5 nodes.Name._infer = decorators.raise_if_nothing_inferred( diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 92cb6b4f..44a7fcf1 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -6,26 +6,25 @@ from __future__ import annotations -import sys -from collections.abc import Callable, Iterator +from collections.abc import Generator +from typing import Any, TypeVar from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG -from astroid.typing import InferenceResult, InferFn - -if sys.version_info >= (3, 11): - from typing import ParamSpec -else: - from typing_extensions import ParamSpec - -_P = ParamSpec("_P") +from astroid.typing import ( + InferenceResult, + InferFn, + TransformFn, +) _cache: dict[ - tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] + tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult] ] = {} -_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set() +_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set() + +_NodesT = TypeVar("_NodesT", bound=NodeNG) def clear_inference_tip_cache() -> None: @@ -33,21 +32,22 @@ def clear_inference_tip_cache() -> None: _cache.clear() -def _inference_tip_cached( - func: Callable[_P, Iterator[InferenceResult]], -) -> Callable[_P, Iterator[InferenceResult]]: +def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]: """Cache decorator used for inference tips.""" - def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]: - node = args[0] - context = args[1] + def inner( + node: _NodesT, + context: InferenceContext | None = None, + **kwargs: Any, + ) -> Generator[InferenceResult, None, None]: partial_cache_key = (func, node) if partial_cache_key in _CURRENTLY_INFERRING: # If through recursion we end up trying to infer the same # func + node we raise here. raise UseInferenceDefault try: - return _cache[func, node, context] + yield from _cache[func, node, context] + return except KeyError: # Recursion guard with a partial cache key. # Using the full key causes a recursion error on PyPy. @@ -55,16 +55,18 @@ def _inference_tip_cached( # with slightly different contexts while still passing the simple # test cases included with this commit. _CURRENTLY_INFERRING.add(partial_cache_key) - result = _cache[func, node, context] = list(func(*args, **kwargs)) + result = _cache[func, node, context] = list(func(node, context, **kwargs)) # Remove recursion guard. _CURRENTLY_INFERRING.remove(partial_cache_key) - return iter(result) + yield from result return inner -def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn: +def inference_tip( + infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False +) -> TransformFn[_NodesT]: """Given an instance specific inference function, return a function to be given to AstroidManager().register_transform to set this inference function. @@ -86,7 +88,9 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> excess overwrites. """ - def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG: + def transform( + node: _NodesT, infer_function: InferFn[_NodesT] = infer_function + ) -> _NodesT: if ( raise_on_overwrite and node._explicit_inference is not None @@ -100,7 +104,6 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> node=node, ) ) - # pylint: disable=no-value-for-parameter node._explicit_inference = _inference_tip_cached(infer_function) return node diff --git a/astroid/nodes/node_ng.py b/astroid/nodes/node_ng.py index de5dec77..31c842ee 100644 --- a/astroid/nodes/node_ng.py +++ b/astroid/nodes/node_ng.py @@ -5,6 +5,7 @@ from __future__ import annotations import pprint +import sys import warnings from collections.abc import Generator, Iterator from functools import cached_property @@ -37,6 +38,12 @@ from astroid.nodes.const import OP_PRECEDENCE from astroid.nodes.utils import Position from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + if TYPE_CHECKING: from astroid import nodes @@ -80,7 +87,7 @@ class NodeNG: _other_other_fields: ClassVar[tuple[str, ...]] = () """Attributes that contain AST-dependent fields.""" # instance specific inference function infer(node, context) - _explicit_inference: InferFn | None = None + _explicit_inference: InferFn[Self] | None = None def __init__( self, @@ -137,9 +144,17 @@ class NodeNG: # explicit_inference is not bound, give it self explicitly try: if context is None: - yield from self._explicit_inference(self, context, **kwargs) + yield from self._explicit_inference( + self, # type: ignore[arg-type] + context, + **kwargs, + ) return - for result in self._explicit_inference(self, context, **kwargs): + for result in self._explicit_inference( + self, # type: ignore[arg-type] + context, + **kwargs, + ): context.nodes_inferred += 1 yield result return diff --git a/astroid/transforms.py b/astroid/transforms.py index f6c72794..29332223 100644 --- a/astroid/transforms.py +++ b/astroid/transforms.py @@ -9,7 +9,7 @@ from collections.abc import Callable from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload from astroid.context import _invalidate_cache -from astroid.typing import SuccessfulInferenceResult +from astroid.typing import SuccessfulInferenceResult, TransformFn if TYPE_CHECKING: from astroid import nodes @@ -17,9 +17,6 @@ if TYPE_CHECKING: _SuccessfulInferenceResultT = TypeVar( "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult ) - _Transform = Callable[ - [_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult] - ] _Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]] _Vistables = Union[ @@ -52,7 +49,7 @@ class TransformVisitor: type[SuccessfulInferenceResult], list[ tuple[ - _Transform[SuccessfulInferenceResult], + TransformFn[SuccessfulInferenceResult], _Predicate[SuccessfulInferenceResult], ] ], @@ -123,7 +120,7 @@ class TransformVisitor: def register_transform( self, node_class: type[_SuccessfulInferenceResultT], - transform: _Transform[_SuccessfulInferenceResultT], + transform: TransformFn[_SuccessfulInferenceResultT], predicate: _Predicate[_SuccessfulInferenceResultT] | None = None, ) -> None: """Register `transform(node)` function to be applied on the given node. @@ -139,7 +136,7 @@ class TransformVisitor: def unregister_transform( self, node_class: type[_SuccessfulInferenceResultT], - transform: _Transform[_SuccessfulInferenceResultT], + transform: TransformFn[_SuccessfulInferenceResultT], predicate: _Predicate[_SuccessfulInferenceResultT] | None = None, ) -> None: """Unregister the given transform.""" diff --git a/astroid/typing.py b/astroid/typing.py index f42832e4..0ae30fcc 100644 --- a/astroid/typing.py +++ b/astroid/typing.py @@ -4,7 +4,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Generic, + Protocol, + TypedDict, + TypeVar, + Union, +) if TYPE_CHECKING: from astroid import bases, exceptions, nodes, transforms, util @@ -12,9 +22,6 @@ if TYPE_CHECKING: from astroid.interpreter._import import spec -_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG") - - class InferenceErrorInfo(TypedDict): """Store additional Inference error information raised with StopIteration exception. @@ -24,9 +31,6 @@ class InferenceErrorInfo(TypedDict): context: InferenceContext | None -InferFn = Callable[..., Any] - - class AstroidManagerBrain(TypedDict): """Dictionary to store relevant information for a AstroidManager class.""" @@ -46,6 +50,11 @@ SuccessfulInferenceResult = Union["nodes.NodeNG", "bases.Proxy"] _SuccessfulInferenceResultT = TypeVar( "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult ) +_SuccessfulInferenceResultT_contra = TypeVar( + "_SuccessfulInferenceResultT_contra", + bound=SuccessfulInferenceResult, + contravariant=True, +) ConstFactoryResult = Union[ "nodes.List", @@ -67,3 +76,22 @@ InferBinaryOp = Callable[ ], Generator[InferenceResult, None, None], ] + + +class InferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]): + def __call__( + self, + node: _SuccessfulInferenceResultT_contra, + context: InferenceContext | None = None, + **kwargs: Any, + ) -> Generator[InferenceResult, None, None]: + ... # pragma: no cover + + +class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]): + def __call__( + self, + node: _SuccessfulInferenceResultT, + infer_function: InferFn[_SuccessfulInferenceResultT] = ..., + ) -> _SuccessfulInferenceResultT | None: + ... # pragma: no cover |