diff options
author | Daniƫl van Noord <13665637+DanielNoord@users.noreply.github.com> | 2023-03-22 13:47:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-22 13:47:04 +0100 |
commit | 598e4c3fc51173562fcbdda9c8413dd4e5f92b06 (patch) | |
tree | c9ac8f02c460c518a5af5e8aaf3ae8792b11e0cb /astroid/transforms.py | |
parent | 7ed0804279c4334093d410fc25831dfa86ab5e8c (diff) | |
download | astroid-git-598e4c3fc51173562fcbdda9c8413dd4e5f92b06.tar.gz |
Add typing to ``TransformVisitor`` (#2062)
Diffstat (limited to 'astroid/transforms.py')
-rw-r--r-- | astroid/transforms.py | 115 |
1 files changed, 90 insertions, 25 deletions
diff --git a/astroid/transforms.py b/astroid/transforms.py index be37879f..3751ffb7 100644 --- a/astroid/transforms.py +++ b/astroid/transforms.py @@ -4,13 +4,34 @@ from __future__ import annotations -import collections -from typing import TYPE_CHECKING +from collections import defaultdict +from collections.abc import Callable +from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload from astroid.context import _invalidate_cache +from astroid.typing import SuccessfulInferenceResult if TYPE_CHECKING: - from astroid import NodeNG + from astroid import nodes + + _SuccessfulInferenceResultT = TypeVar( + "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult + ) + _Transform = Callable[ + [_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult] + ] + _Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]] + +_Vistables = Union[ + "nodes.NodeNG", List["nodes.NodeNG"], Tuple["nodes.NodeNG", ...], str, None +] +_VisitReturns = Union[ + SuccessfulInferenceResult, + List[SuccessfulInferenceResult], + Tuple[SuccessfulInferenceResult, ...], + str, + None, +] class TransformVisitor: @@ -24,17 +45,26 @@ class TransformVisitor: Based on its usage in AstroidManager.brain, it should not be reinstantiated. """ - def __init__(self): - self.transforms = collections.defaultdict(list) - - def _transform(self, node: NodeNG) -> NodeNG: + def __init__(self) -> None: + # The typing here is incorrect, but it's the best we can do + # Refer to register_transform and unregister_transform for the correct types + self.transforms: defaultdict[ + type[SuccessfulInferenceResult], + list[ + tuple[ + _Transform[SuccessfulInferenceResult], + _Predicate[SuccessfulInferenceResult], + ] + ], + ] = defaultdict(list) + + def _transform(self, node: SuccessfulInferenceResult) -> SuccessfulInferenceResult: """Call matching transforms for the given node if any and return the transformed node. """ cls = node.__class__ - transforms = self.transforms[cls] - for transform_func, predicate in transforms: + for transform_func, predicate in self.transforms[cls]: if predicate is None or predicate(node): ret = transform_func(node) # if the transformation function returns something, it's @@ -47,16 +77,40 @@ class TransformVisitor: break return node - def _visit(self, node): - if hasattr(node, "_astroid_fields"): - for name in node._astroid_fields: - value = getattr(node, name) - visited = self._visit_generic(value) - if visited != value: - setattr(node, name, visited) + def _visit(self, node: nodes.NodeNG) -> SuccessfulInferenceResult: + for name in node._astroid_fields: + value = getattr(node, name) + value = cast(_Vistables, value) + visited = self._visit_generic(value) + if visited != value: + setattr(node, name, visited) return self._transform(node) - def _visit_generic(self, node): + @overload + def _visit_generic(self, node: None) -> None: + ... + + @overload + def _visit_generic(self, node: str) -> str: + ... + + @overload + def _visit_generic( + self, node: list[nodes.NodeNG] + ) -> list[SuccessfulInferenceResult]: + ... + + @overload + def _visit_generic( + self, node: tuple[nodes.NodeNG, ...] + ) -> tuple[SuccessfulInferenceResult, ...]: + ... + + @overload + def _visit_generic(self, node: nodes.NodeNG) -> SuccessfulInferenceResult: + ... + + def _visit_generic(self, node: _Vistables) -> _VisitReturns: if isinstance(node, list): return [self._visit_generic(child) for child in node] if isinstance(node, tuple): @@ -66,21 +120,32 @@ class TransformVisitor: return self._visit(node) - def register_transform(self, node_class, transform, predicate=None) -> None: - """Register `transform(node)` function to be applied on the given - astroid's `node_class` if `predicate` is None or returns true + def register_transform( + self, + node_class: type[_SuccessfulInferenceResultT], + transform: _Transform[_SuccessfulInferenceResultT], + predicate: _Predicate[_SuccessfulInferenceResultT] | None = None, + ) -> None: + """Register `transform(node)` function to be applied on the given node. + + The transform will only be applied if `predicate` is None or returns true when called with the node as argument. The transform function may return a value which is then used to substitute the original node in the tree. """ - self.transforms[node_class].append((transform, predicate)) - - def unregister_transform(self, node_class, transform, predicate=None) -> None: + self.transforms[node_class].append((transform, predicate)) # type: ignore[index, arg-type] + + def unregister_transform( + self, + node_class: type[_SuccessfulInferenceResultT], + transform: _Transform[_SuccessfulInferenceResultT], + predicate: _Predicate[_SuccessfulInferenceResultT] | None = None, + ) -> None: """Unregister the given transform.""" - self.transforms[node_class].remove((transform, predicate)) + self.transforms[node_class].remove((transform, predicate)) # type: ignore[index, arg-type] - def visit(self, module): + def visit(self, module: nodes.Module) -> SuccessfulInferenceResult: """Walk the given astroid *tree* and transform each encountered node. Only the nodes which have transforms registered will actually |