summaryrefslogtreecommitdiff
path: root/astroid/transforms.py
diff options
context:
space:
mode:
authorDaniƫl van Noord <13665637+DanielNoord@users.noreply.github.com>2023-03-22 13:47:04 +0100
committerGitHub <noreply@github.com>2023-03-22 13:47:04 +0100
commit598e4c3fc51173562fcbdda9c8413dd4e5f92b06 (patch)
treec9ac8f02c460c518a5af5e8aaf3ae8792b11e0cb /astroid/transforms.py
parent7ed0804279c4334093d410fc25831dfa86ab5e8c (diff)
downloadastroid-git-598e4c3fc51173562fcbdda9c8413dd4e5f92b06.tar.gz
Add typing to ``TransformVisitor`` (#2062)
Diffstat (limited to 'astroid/transforms.py')
-rw-r--r--astroid/transforms.py115
1 files changed, 90 insertions, 25 deletions
diff --git a/astroid/transforms.py b/astroid/transforms.py
index be37879f..3751ffb7 100644
--- a/astroid/transforms.py
+++ b/astroid/transforms.py
@@ -4,13 +4,34 @@
from __future__ import annotations
-import collections
-from typing import TYPE_CHECKING
+from collections import defaultdict
+from collections.abc import Callable
+from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload
from astroid.context import _invalidate_cache
+from astroid.typing import SuccessfulInferenceResult
if TYPE_CHECKING:
- from astroid import NodeNG
+ from astroid import nodes
+
+ _SuccessfulInferenceResultT = TypeVar(
+ "_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
+ )
+ _Transform = Callable[
+ [_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult]
+ ]
+ _Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]]
+
+_Vistables = Union[
+ "nodes.NodeNG", List["nodes.NodeNG"], Tuple["nodes.NodeNG", ...], str, None
+]
+_VisitReturns = Union[
+ SuccessfulInferenceResult,
+ List[SuccessfulInferenceResult],
+ Tuple[SuccessfulInferenceResult, ...],
+ str,
+ None,
+]
class TransformVisitor:
@@ -24,17 +45,26 @@ class TransformVisitor:
Based on its usage in AstroidManager.brain, it should not be reinstantiated.
"""
- def __init__(self):
- self.transforms = collections.defaultdict(list)
-
- def _transform(self, node: NodeNG) -> NodeNG:
+ def __init__(self) -> None:
+ # The typing here is incorrect, but it's the best we can do
+ # Refer to register_transform and unregister_transform for the correct types
+ self.transforms: defaultdict[
+ type[SuccessfulInferenceResult],
+ list[
+ tuple[
+ _Transform[SuccessfulInferenceResult],
+ _Predicate[SuccessfulInferenceResult],
+ ]
+ ],
+ ] = defaultdict(list)
+
+ def _transform(self, node: SuccessfulInferenceResult) -> SuccessfulInferenceResult:
"""Call matching transforms for the given node if any and return the
transformed node.
"""
cls = node.__class__
- transforms = self.transforms[cls]
- for transform_func, predicate in transforms:
+ for transform_func, predicate in self.transforms[cls]:
if predicate is None or predicate(node):
ret = transform_func(node)
# if the transformation function returns something, it's
@@ -47,16 +77,40 @@ class TransformVisitor:
break
return node
- def _visit(self, node):
- if hasattr(node, "_astroid_fields"):
- for name in node._astroid_fields:
- value = getattr(node, name)
- visited = self._visit_generic(value)
- if visited != value:
- setattr(node, name, visited)
+ def _visit(self, node: nodes.NodeNG) -> SuccessfulInferenceResult:
+ for name in node._astroid_fields:
+ value = getattr(node, name)
+ value = cast(_Vistables, value)
+ visited = self._visit_generic(value)
+ if visited != value:
+ setattr(node, name, visited)
return self._transform(node)
- def _visit_generic(self, node):
+ @overload
+ def _visit_generic(self, node: None) -> None:
+ ...
+
+ @overload
+ def _visit_generic(self, node: str) -> str:
+ ...
+
+ @overload
+ def _visit_generic(
+ self, node: list[nodes.NodeNG]
+ ) -> list[SuccessfulInferenceResult]:
+ ...
+
+ @overload
+ def _visit_generic(
+ self, node: tuple[nodes.NodeNG, ...]
+ ) -> tuple[SuccessfulInferenceResult, ...]:
+ ...
+
+ @overload
+ def _visit_generic(self, node: nodes.NodeNG) -> SuccessfulInferenceResult:
+ ...
+
+ def _visit_generic(self, node: _Vistables) -> _VisitReturns:
if isinstance(node, list):
return [self._visit_generic(child) for child in node]
if isinstance(node, tuple):
@@ -66,21 +120,32 @@ class TransformVisitor:
return self._visit(node)
- def register_transform(self, node_class, transform, predicate=None) -> None:
- """Register `transform(node)` function to be applied on the given
- astroid's `node_class` if `predicate` is None or returns true
+ def register_transform(
+ self,
+ node_class: type[_SuccessfulInferenceResultT],
+ transform: _Transform[_SuccessfulInferenceResultT],
+ predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
+ ) -> None:
+ """Register `transform(node)` function to be applied on the given node.
+
+ The transform will only be applied if `predicate` is None or returns true
when called with the node as argument.
The transform function may return a value which is then used to
substitute the original node in the tree.
"""
- self.transforms[node_class].append((transform, predicate))
-
- def unregister_transform(self, node_class, transform, predicate=None) -> None:
+ self.transforms[node_class].append((transform, predicate)) # type: ignore[index, arg-type]
+
+ def unregister_transform(
+ self,
+ node_class: type[_SuccessfulInferenceResultT],
+ transform: _Transform[_SuccessfulInferenceResultT],
+ predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
+ ) -> None:
"""Unregister the given transform."""
- self.transforms[node_class].remove((transform, predicate))
+ self.transforms[node_class].remove((transform, predicate)) # type: ignore[index, arg-type]
- def visit(self, module):
+ def visit(self, module: nodes.Module) -> SuccessfulInferenceResult:
"""Walk the given astroid *tree* and transform each encountered node.
Only the nodes which have transforms registered will actually