summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/_typing.py10
-rw-r--r--lib/sqlalchemy/sql/annotation.py18
-rw-r--r--lib/sqlalchemy/sql/base.py22
-rw-r--r--lib/sqlalchemy/sql/coercions.py11
-rw-r--r--lib/sqlalchemy/sql/elements.py18
-rw-r--r--lib/sqlalchemy/sql/selectable.py66
-rw-r--r--lib/sqlalchemy/sql/traversals.py37
-rw-r--r--lib/sqlalchemy/sql/util.py44
-rw-r--r--lib/sqlalchemy/sql/visitors.py46
9 files changed, 210 insertions, 62 deletions
diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py
index f49a6d3ec..ed1bd2832 100644
--- a/lib/sqlalchemy/sql/_typing.py
+++ b/lib/sqlalchemy/sql/_typing.py
@@ -61,6 +61,9 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
+
+
class _HasClauseElement(Protocol):
"""indicates a class that has a __clause_element__() method"""
@@ -68,6 +71,13 @@ class _HasClauseElement(Protocol):
...
+class _CoreAdapterProto(Protocol):
+ """protocol for the ClauseAdapter/ColumnAdapter.traverse() method."""
+
+ def __call__(self, obj: _CE) -> _CE:
+ ...
+
+
# match column types that are not ORM entities
_NOT_ENTITY = TypeVar(
"_NOT_ENTITY",
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index fa36c09fc..56d88bc2f 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -454,9 +454,23 @@ def _deep_annotate(
return element
+@overload
+def _deep_deannotate(
+ element: Literal[None], values: Optional[Sequence[str]] = None
+) -> Literal[None]:
+ ...
+
+
+@overload
def _deep_deannotate(
element: _SA, values: Optional[Sequence[str]] = None
) -> _SA:
+ ...
+
+
+def _deep_deannotate(
+ element: Optional[_SA], values: Optional[Sequence[str]] = None
+) -> Optional[_SA]:
"""Deep copy the given element, removing annotations."""
cloned: Dict[Any, SupportsAnnotations] = {}
@@ -482,9 +496,7 @@ def _deep_deannotate(
return element
-def _shallow_annotate(
- element: SupportsAnnotations, annotations: _AnnotationDict
-) -> SupportsAnnotations:
+def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA:
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 248b48a25..f5a9c10c0 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -750,6 +750,17 @@ class _MetaOptions(type):
o1.__dict__.update(other)
return o1
+ if TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> Any:
+ ...
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ ...
+
+ def __delattr__(self, key: str) -> None:
+ ...
+
class Options(metaclass=_MetaOptions):
"""A cacheable option dictionary with defaults."""
@@ -904,6 +915,17 @@ class Options(metaclass=_MetaOptions):
else:
return existing_options, exec_options
+ if TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> Any:
+ ...
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ ...
+
+ def __delattr__(self, key: str) -> None:
+ ...
+
class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index eef5cf211..501188b12 100644
--- a/lib/sqlalchemy/sql/coercions.py
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -56,6 +56,7 @@ if typing.TYPE_CHECKING:
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import DQLDMLClauseElement
+ from .elements import NamedColumn
from .elements import SQLCoreOperations
from .schema import Column
from .selectable import _ColumnsClauseElement
@@ -199,6 +200,15 @@ def expect(
@overload
def expect(
+ role: Type[roles.LabeledColumnExprRole[Any]],
+ element: _ColumnExpressionArgument[_T],
+ **kw: Any,
+) -> NamedColumn[_T]:
+ ...
+
+
+@overload
+def expect(
role: Union[
Type[roles.ExpressionElementRole[Any]],
Type[roles.LimitOffsetRole],
@@ -217,6 +227,7 @@ def expect(
Type[roles.LimitOffsetRole],
Type[roles.WhereHavingRole],
Type[roles.OnClauseRole],
+ Type[roles.ColumnArgumentRole],
],
element: Any,
**kw: Any,
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 41b7f6392..61c5379d8 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -503,7 +503,7 @@ class ClauseElement(
def params(
self: SelfClauseElement,
- __optionaldict: Optional[Dict[str, Any]] = None,
+ __optionaldict: Optional[Mapping[str, Any]] = None,
**kwargs: Any,
) -> SelfClauseElement:
"""Return a copy with :func:`_expression.bindparam` elements
@@ -525,7 +525,7 @@ class ClauseElement(
def _replace_params(
self: SelfClauseElement,
unique: bool,
- optionaldict: Optional[Dict[str, Any]],
+ optionaldict: Optional[Mapping[str, Any]],
kwargs: Dict[str, Any],
) -> SelfClauseElement:
@@ -545,7 +545,7 @@ class ClauseElement(
{"bindparam": visit_bindparam},
)
- def compare(self, other, **kw):
+ def compare(self, other: ClauseElement, **kw: Any) -> bool:
r"""Compare this :class:`_expression.ClauseElement` to
the given :class:`_expression.ClauseElement`.
@@ -2516,7 +2516,9 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]):
return False_._singleton
@classmethod
- def _ifnone(cls, other):
+ def _ifnone(
+ cls, other: Optional[ColumnElement[Any]]
+ ) -> ColumnElement[Any]:
if other is None:
return cls._instance()
else:
@@ -4226,7 +4228,13 @@ class NamedColumn(KeyedColumnElement[_T]):
) -> Optional[str]:
return name
- def _bind_param(self, operator, obj, type_=None, expanding=False):
+ def _bind_param(
+ self,
+ operator: OperatorType,
+ obj: Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ expanding: bool = False,
+ ) -> BindParameter[_T]:
return BindParameter(
self.key,
obj,
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index d0b0f1476..fd98f17e3 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -64,6 +64,7 @@ from .base import _EntityNamespace
from .base import _expand_cloned
from .base import _from_objects
from .base import _generative
+from .base import _NoArg
from .base import _select_iterables
from .base import CacheableOptions
from .base import ColumnCollection
@@ -131,6 +132,7 @@ if TYPE_CHECKING:
from .dml import Insert
from .dml import Update
from .elements import KeyedColumnElement
+ from .elements import Label
from .elements import NamedColumn
from .elements import TextClause
from .functions import Function
@@ -212,7 +214,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
"""
raise NotImplementedError()
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`.ReturnsRows` is
'derived' from the given :class:`.FromClause`.
@@ -778,7 +780,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return TableSample._construct(self, sampling, name, seed)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
"""Return ``True`` if this :class:`_expression.FromClause` is
'derived' from the given ``FromClause``.
@@ -1128,11 +1130,14 @@ class SelectLabelStyle(Enum):
"""
+ LABEL_STYLE_LEGACY_ORM = 3
+
(
LABEL_STYLE_NONE,
LABEL_STYLE_TABLENAME_PLUS_COL,
LABEL_STYLE_DISAMBIGUATE_ONLY,
+ _,
) = list(SelectLabelStyle)
LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY
@@ -1231,7 +1236,7 @@ class Join(roles.DMLTableRole, FromClause):
id(self.right),
)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
return (
# use hash() to ensure direct comparison to annotated works
# as well
@@ -1635,7 +1640,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
"""Legacy for dialects that are referring to Alias.original."""
return self.element
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
if fromclause in self._cloned_set:
return True
return self.element.is_derived_from(fromclause)
@@ -2840,7 +2845,7 @@ class FromGrouping(GroupedElement, FromClause):
def foreign_keys(self):
return self.element.foreign_keys
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
return self.element.is_derived_from(fromclause)
def alias(
@@ -3080,11 +3085,17 @@ class ForUpdateArg(ClauseElement):
def __init__(
self,
- nowait=False,
- read=False,
- of=None,
- skip_locked=False,
- key_share=False,
+ *,
+ nowait: bool = False,
+ read: bool = False,
+ of: Optional[
+ Union[
+ _ColumnExpressionArgument[Any],
+ Sequence[_ColumnExpressionArgument[Any]],
+ ]
+ ] = None,
+ skip_locked: bool = False,
+ key_share: bool = False,
):
"""Represents arguments specified to
:meth:`_expression.Select.for_update`.
@@ -3455,7 +3466,7 @@ class SelectBase(
return ScalarSelect(self)
- def label(self, name):
+ def label(self, name: Optional[str]) -> Label[Any]:
"""Return a 'scalar' representation of this selectable, embedded as a
subquery with a label.
@@ -3667,6 +3678,7 @@ class GenerativeSelect(SelectBase, Generative):
@_generative
def with_for_update(
self: SelfGenerativeSelect,
+ *,
nowait: bool = False,
read: bool = False,
of: Optional[
@@ -4064,7 +4076,11 @@ class GenerativeSelect(SelectBase, Generative):
@_generative
def order_by(
- self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfGenerativeSelect,
+ __first: Union[
+ Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of ORDER BY
criteria applied.
@@ -4092,18 +4108,22 @@ class GenerativeSelect(SelectBase, Generative):
"""
- if len(clauses) == 1 and clauses[0] is None:
+ if not clauses and __first is None:
self._order_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
self._order_by_clauses += tuple(
coercions.expect(roles.OrderByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
return self
@_generative
def group_by(
- self: SelfGenerativeSelect, *clauses: _ColumnExpressionArgument[Any]
+ self: SelfGenerativeSelect,
+ __first: Union[
+ Literal[None, _NoArg.NO_ARG], _ColumnExpressionArgument[Any]
+ ] = _NoArg.NO_ARG,
+ *clauses: _ColumnExpressionArgument[Any],
) -> SelfGenerativeSelect:
r"""Return a new selectable with the given list of GROUP BY
criterion applied.
@@ -4128,12 +4148,12 @@ class GenerativeSelect(SelectBase, Generative):
"""
- if len(clauses) == 1 and clauses[0] is None:
+ if not clauses and __first is None:
self._group_by_clauses = ()
- else:
+ elif __first is not _NoArg.NO_ARG:
self._group_by_clauses += tuple(
coercions.expect(roles.GroupByRole, clause)
- for clause in clauses
+ for clause in (__first,) + clauses
)
return self
@@ -4257,7 +4277,7 @@ class CompoundSelect(HasCompileState, GenerativeSelect, ExecutableReturnsRows):
) -> GroupedElement:
return SelectStatementGrouping(self)
- def is_derived_from(self, fromclause: FromClause) -> bool:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
for s in self.selects:
if s.is_derived_from(fromclause):
return True
@@ -4959,7 +4979,7 @@ class Select(
_raw_columns: List[_ColumnsClauseElement]
- _distinct = False
+ _distinct: bool = False
_distinct_on: Tuple[ColumnElement[Any], ...] = ()
_correlate: Tuple[FromClause, ...] = ()
_correlate_except: Optional[Tuple[FromClause, ...]] = None
@@ -5478,8 +5498,8 @@ class Select(
return iter(self._all_selected_columns)
- def is_derived_from(self, fromclause: FromClause) -> bool:
- if self in fromclause._cloned_set:
+ def is_derived_from(self, fromclause: Optional[FromClause]) -> bool:
+ if fromclause is not None and self in fromclause._cloned_set:
return True
for f in self._iterate_from_elements():
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index aceed99a5..94e635740 100644
--- a/lib/sqlalchemy/sql/traversals.py
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -19,6 +19,7 @@ from typing import Callable
from typing import Deque
from typing import Dict
from typing import Iterable
+from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
@@ -39,7 +40,7 @@ COMPARE_FAILED = False
COMPARE_SUCCEEDED = True
-def compare(obj1, obj2, **kw):
+def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
strategy: TraversalComparatorStrategy
if kw.get("use_proxies", False):
strategy = ColIdentityComparatorStrategy()
@@ -49,7 +50,7 @@ def compare(obj1, obj2, **kw):
return strategy.compare(obj1, obj2, **kw)
-def _preconfigure_traversals(target_hierarchy):
+def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
for cls in util.walk_subclasses(target_hierarchy):
if hasattr(cls, "_generate_cache_attrs") and hasattr(
cls, "_traverse_internals"
@@ -482,14 +483,22 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
def __init__(self):
self.stack: Deque[
- Tuple[ExternallyTraversible, ExternallyTraversible]
+ Tuple[
+ Optional[ExternallyTraversible],
+ Optional[ExternallyTraversible],
+ ]
] = deque()
self.cache = set()
def _memoized_attr_anon_map(self):
return (anon_map(), anon_map())
- def compare(self, obj1, obj2, **kw):
+ def compare(
+ self,
+ obj1: ExternallyTraversible,
+ obj2: ExternallyTraversible,
+ **kw: Any,
+ ) -> bool:
stack = self.stack
cache = self.cache
@@ -551,6 +560,10 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
elif left_attrname in attributes_compared:
continue
+ assert left_visit_sym is not None
+ assert left_attrname is not None
+ assert right_attrname is not None
+
dispatch = self.dispatch(left_visit_sym)
assert dispatch, (
f"{self.__class__} has no dispatch for "
@@ -595,6 +608,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in zip_longest(left, right, fillvalue=None):
+ if l is None:
+ if r is not None:
+ return COMPARE_FAILED
+ else:
+ continue
+ elif r is None:
+ return COMPARE_FAILED
+
if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
self.anon_map[1], []
):
@@ -604,6 +625,14 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
self, attrname, left_parent, left, right_parent, right, **kw
):
for l, r in zip_longest(left, right, fillvalue=None):
+ if l is None:
+ if r is not None:
+ return COMPARE_FAILED
+ else:
+ continue
+ elif r is None:
+ return COMPARE_FAILED
+
if (
l._gen_cache_key(self.anon_map[0], [])
if l._is_has_cache_key
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 262689128..390e23952 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -73,6 +73,7 @@ if typing.TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ._typing import _EquivalentColumnMap
from ._typing import _TypeEngineArgument
+ from .elements import BinaryExpression
from .elements import TextClause
from .selectable import _JoinTargetElement
from .selectable import _SelectIterable
@@ -86,8 +87,15 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.row import Row
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
-def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
+
+def join_condition(
+ a: FromClause,
+ b: FromClause,
+ a_subset: Optional[FromClause] = None,
+ consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
+) -> ColumnElement[bool]:
"""Create a join condition between two tables or selectables.
e.g.::
@@ -118,7 +126,9 @@ def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
)
-def find_join_source(clauses, join_to):
+def find_join_source(
+ clauses: List[FromClause], join_to: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
@@ -144,7 +154,9 @@ def find_join_source(clauses, join_to):
return idx
-def find_left_clause_that_matches_given(clauses, join_from):
+def find_left_clause_that_matches_given(
+ clauses: Sequence[FromClause], join_from: FromClause
+) -> List[int]:
"""Given a list of FROM clauses and a selectable,
return the indexes from the list of
clauses which is derived from the selectable.
@@ -243,7 +255,12 @@ def find_left_clause_to_join_from(
return idx
-def visit_binary_product(fn, expr):
+def visit_binary_product(
+ fn: Callable[
+ [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
+ ],
+ expr: ColumnElement[Any],
+) -> None:
"""Produce a traversal of the given expression, delivering
column comparisons to the given function.
@@ -278,19 +295,19 @@ def visit_binary_product(fn, expr):
a binary comparison is passed as pairs.
"""
- stack: List[ClauseElement] = []
+ stack: List[BinaryExpression[Any]] = []
- def visit(element):
+ def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
if isinstance(element, ScalarSelect):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
elif element.__visit_name__ == "binary" and operators.is_comparison(
- element.operator
+ element.operator # type: ignore
):
- stack.insert(0, element)
- for l in visit(element.left):
- for r in visit(element.right):
+ stack.insert(0, element) # type: ignore
+ for l in visit(element.left): # type: ignore
+ for r in visit(element.right): # type: ignore
fn(stack[0], l, r)
stack.pop(0)
for elem in element.get_children():
@@ -502,7 +519,7 @@ def extract_first_column_annotation(column, annotation_name):
return None
-def selectables_overlap(left, right):
+def selectables_overlap(left: FromClause, right: FromClause) -> bool:
"""Return True if left/right have some overlapping selectable"""
return bool(
@@ -701,7 +718,7 @@ class _repr_params(_repr_base):
return "[%s]" % (", ".join(trunc(value) for value in params))
-def adapt_criterion_to_null(crit, nulls):
+def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
"""given criterion containing bind params, convert selected elements
to IS NULL.
@@ -922,9 +939,6 @@ def criterion_as_pairs(
return pairs
-_CE = TypeVar("_CE", bound="ClauseElement")
-
-
class ClauseAdapter(visitors.ReplacingExternalTraversal):
"""Clones and modifies clauses based on column correspondence.
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 217e2d2ab..b550f8f28 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -21,7 +21,6 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import ClassVar
-from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -31,6 +30,7 @@ from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -42,6 +42,10 @@ from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
+if TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .elements import ColumnElement
+
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import prefix_anon_map as prefix_anon_map
from ._py_util import cache_anon_map as anon_map
@@ -590,13 +594,23 @@ _dispatch_lookup = HasTraversalDispatch._dispatch_lookup
_generate_traversal_dispatch()
+SelfExternallyTraversible = TypeVar(
+ "SelfExternallyTraversible", bound="ExternallyTraversible"
+)
+
+
class ExternallyTraversible(HasTraverseInternals, Visitable):
__slots__ = ()
- _annotations: Collection[Any] = ()
+ _annotations: Mapping[Any, Any] = util.EMPTY_DICT
if typing.TYPE_CHECKING:
+ def _annotate(
+ self: SelfExternallyTraversible, values: _AnnotationDict
+ ) -> SelfExternallyTraversible:
+ ...
+
def get_children(
self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
@@ -624,6 +638,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
_TraverseCallableType = Callable[[_ET], None]
@@ -633,10 +648,8 @@ class _CloneCallableType(Protocol):
...
-class _TraverseTransformCallableType(Protocol):
- def __call__(
- self, element: ExternallyTraversible, **kw: Any
- ) -> Optional[ExternallyTraversible]:
+class _TraverseTransformCallableType(Protocol[_ET]):
+ def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]:
...
@@ -1074,16 +1087,25 @@ def cloned_traverse(
def replacement_traverse(
obj: Literal[None],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> None:
...
@overload
def replacement_traverse(
+ obj: _CE,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType[Any],
+) -> _CE:
+ ...
+
+
+@overload
+def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> ExternallyTraversible:
...
@@ -1091,7 +1113,7 @@ def replacement_traverse(
def replacement_traverse(
obj: Optional[ExternallyTraversible],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
@@ -1134,7 +1156,7 @@ def replacement_traverse(
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
- return newelem
+ return newelem # type: ignore
else:
# base "already seen" on id(), not hash, so that we don't
# replace an Annotated element with its non-annotated one, and
@@ -1145,11 +1167,11 @@ def replacement_traverse(
newelem = kw["replace"](elem)
if newelem is not None:
cloned[id_elem] = newelem
- return newelem
+ return newelem # type: ignore
cloned[id_elem] = newelem = elem._clone(**kw)
newelem._copy_internals(clone=clone, **kw)
- return cloned[id_elem]
+ return cloned[id_elem] # type: ignore
if obj is not None:
obj = clone(