summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/visitors.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/visitors.py')
-rw-r--r--lib/sqlalchemy/sql/visitors.py46
1 files changed, 34 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index 217e2d2ab..b550f8f28 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -21,7 +21,6 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import ClassVar
-from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -31,6 +30,7 @@ from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -42,6 +42,10 @@ from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
+if TYPE_CHECKING:
+ from .annotation import _AnnotationDict
+ from .elements import ColumnElement
+
if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_util import prefix_anon_map as prefix_anon_map
from ._py_util import cache_anon_map as anon_map
@@ -590,13 +594,23 @@ _dispatch_lookup = HasTraversalDispatch._dispatch_lookup
_generate_traversal_dispatch()
+SelfExternallyTraversible = TypeVar(
+ "SelfExternallyTraversible", bound="ExternallyTraversible"
+)
+
+
class ExternallyTraversible(HasTraverseInternals, Visitable):
__slots__ = ()
- _annotations: Collection[Any] = ()
+ _annotations: Mapping[Any, Any] = util.EMPTY_DICT
if typing.TYPE_CHECKING:
+ def _annotate(
+ self: SelfExternallyTraversible, values: _AnnotationDict
+ ) -> SelfExternallyTraversible:
+ ...
+
def get_children(
self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any
) -> Iterable[ExternallyTraversible]:
@@ -624,6 +638,7 @@ class ExternallyTraversible(HasTraverseInternals, Visitable):
_ET = TypeVar("_ET", bound=ExternallyTraversible)
+_CE = TypeVar("_CE", bound="ColumnElement[Any]")
_TraverseCallableType = Callable[[_ET], None]
@@ -633,10 +648,8 @@ class _CloneCallableType(Protocol):
...
-class _TraverseTransformCallableType(Protocol):
- def __call__(
- self, element: ExternallyTraversible, **kw: Any
- ) -> Optional[ExternallyTraversible]:
+class _TraverseTransformCallableType(Protocol[_ET]):
+ def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]:
...
@@ -1074,16 +1087,25 @@ def cloned_traverse(
def replacement_traverse(
obj: Literal[None],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> None:
...
@overload
def replacement_traverse(
+ obj: _CE,
+ opts: Mapping[str, Any],
+ replace: _TraverseTransformCallableType[Any],
+) -> _CE:
+ ...
+
+
+@overload
+def replacement_traverse(
obj: ExternallyTraversible,
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> ExternallyTraversible:
...
@@ -1091,7 +1113,7 @@ def replacement_traverse(
def replacement_traverse(
obj: Optional[ExternallyTraversible],
opts: Mapping[str, Any],
- replace: _TraverseTransformCallableType,
+ replace: _TraverseTransformCallableType[Any],
) -> Optional[ExternallyTraversible]:
"""Clone the given expression structure, allowing element
replacement by a given replacement function.
@@ -1134,7 +1156,7 @@ def replacement_traverse(
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
- return newelem
+ return newelem # type: ignore
else:
# base "already seen" on id(), not hash, so that we don't
# replace an Annotated element with its non-annotated one, and
@@ -1145,11 +1167,11 @@ def replacement_traverse(
newelem = kw["replace"](elem)
if newelem is not None:
cloned[id_elem] = newelem
- return newelem
+ return newelem # type: ignore
cloned[id_elem] = newelem = elem._clone(**kw)
newelem._copy_internals(clone=clone, **kw)
- return cloned[id_elem]
+ return cloned[id_elem] # type: ignore
if obj is not None:
obj = clone(