diff options
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 217e2d2ab..b550f8f28 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -21,7 +21,6 @@ from typing import Any from typing import Callable from typing import cast from typing import ClassVar -from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -31,6 +30,7 @@ from typing import Optional from typing import overload from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -42,6 +42,10 @@ from ..util.typing import Literal from ..util.typing import Protocol from ..util.typing import Self +if TYPE_CHECKING: + from .annotation import _AnnotationDict + from .elements import ColumnElement + if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import prefix_anon_map as prefix_anon_map from ._py_util import cache_anon_map as anon_map @@ -590,13 +594,23 @@ _dispatch_lookup = HasTraversalDispatch._dispatch_lookup _generate_traversal_dispatch() +SelfExternallyTraversible = TypeVar( + "SelfExternallyTraversible", bound="ExternallyTraversible" +) + + class ExternallyTraversible(HasTraverseInternals, Visitable): __slots__ = () - _annotations: Collection[Any] = () + _annotations: Mapping[Any, Any] = util.EMPTY_DICT if typing.TYPE_CHECKING: + def _annotate( + self: SelfExternallyTraversible, values: _AnnotationDict + ) -> SelfExternallyTraversible: + ... + def get_children( self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any ) -> Iterable[ExternallyTraversible]: @@ -624,6 +638,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): _ET = TypeVar("_ET", bound=ExternallyTraversible) +_CE = TypeVar("_CE", bound="ColumnElement[Any]") _TraverseCallableType = Callable[[_ET], None] @@ -633,10 +648,8 @@ class _CloneCallableType(Protocol): ... -class _TraverseTransformCallableType(Protocol): - def __call__( - self, element: ExternallyTraversible, **kw: Any - ) -> Optional[ExternallyTraversible]: +class _TraverseTransformCallableType(Protocol[_ET]): + def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: ... @@ -1074,16 +1087,25 @@ def cloned_traverse( def replacement_traverse( obj: Literal[None], opts: Mapping[str, Any], - replace: _TraverseTransformCallableType, + replace: _TraverseTransformCallableType[Any], ) -> None: ... @overload def replacement_traverse( + obj: _CE, + opts: Mapping[str, Any], + replace: _TraverseTransformCallableType[Any], +) -> _CE: + ... + + +@overload +def replacement_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], - replace: _TraverseTransformCallableType, + replace: _TraverseTransformCallableType[Any], ) -> ExternallyTraversible: ... @@ -1091,7 +1113,7 @@ def replacement_traverse( def replacement_traverse( obj: Optional[ExternallyTraversible], opts: Mapping[str, Any], - replace: _TraverseTransformCallableType, + replace: _TraverseTransformCallableType[Any], ) -> Optional[ExternallyTraversible]: """Clone the given expression structure, allowing element replacement by a given replacement function. @@ -1134,7 +1156,7 @@ def replacement_traverse( newelem = replace(elem) if newelem is not None: stop_on.add(id(newelem)) - return newelem + return newelem # type: ignore else: # base "already seen" on id(), not hash, so that we don't # replace an Annotated element with its non-annotated one, and @@ -1145,11 +1167,11 @@ def replacement_traverse( newelem = kw["replace"](elem) if newelem is not None: cloned[id_elem] = newelem - return newelem + return newelem # type: ignore cloned[id_elem] = newelem = elem._clone(**kw) newelem._copy_internals(clone=clone, **kw) - return cloned[id_elem] + return cloned[id_elem] # type: ignore if obj is not None: obj = clone( |