diff options
Diffstat (limited to 'lib/sqlalchemy/sql/annotation.py')
-rw-r--r-- | lib/sqlalchemy/sql/annotation.py | 305 |
1 files changed, 239 insertions, 66 deletions
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index b76393ad6..7afc2de97 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -13,22 +13,77 @@ associations. from __future__ import annotations +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar + from . import operators -from .base import HasCacheKey -from .traversals import anon_map +from .cache_key import HasCacheKey +from .visitors import anon_map +from .visitors import ExternallyTraversible from .visitors import InternalTraversal from .. import util +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .visitors import _TraverseInternalsType + from ..util.typing import Self + +_AnnotationDict = Mapping[str, Any] + +EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT + -EMPTY_ANNOTATIONS = util.immutabledict() +SelfSupportsAnnotations = TypeVar( + "SelfSupportsAnnotations", bound="SupportsAnnotations" +) -class SupportsAnnotations: +class SupportsAnnotations(ExternallyTraversible): __slots__ = () - _annotations = EMPTY_ANNOTATIONS + _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS + proxy_set: Set[SupportsAnnotations] + _is_immutable: bool + + def _annotate(self, values: _AnnotationDict) -> SupportsAnnotations: + raise NotImplementedError() + + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + raise NotImplementedError() @util.memoized_property - def _annotations_cache_key(self): + def _annotations_cache_key(self) -> Tuple[Any, ...]: anon_map_ = anon_map() return ( "_annotations", @@ -47,14 +102,22 @@ class SupportsAnnotations: ) +SelfSupportsCloneAnnotations = TypeVar( + "SelfSupportsCloneAnnotations", bound="SupportsCloneAnnotations" +) + + class SupportsCloneAnnotations(SupportsAnnotations): - __slots__ = () + if not typing.TYPE_CHECKING: + __slots__ = () - _clone_annotations_traverse_internals = [ + _clone_annotations_traverse_internals: _TraverseInternalsType = [ ("_annotations", InternalTraversal.dp_annotations_key) ] - def _annotate(self, values): + def _annotate( + self: SelfSupportsCloneAnnotations, values: _AnnotationDict + ) -> SelfSupportsCloneAnnotations: """return a copy of this ClauseElement with annotations updated by the given dictionary. @@ -65,7 +128,9 @@ class SupportsCloneAnnotations(SupportsAnnotations): new.__dict__.pop("_generate_cache_key", None) return new - def _with_annotations(self, values): + def _with_annotations( + self: SelfSupportsCloneAnnotations, values: _AnnotationDict + ) -> SelfSupportsCloneAnnotations: """return a copy of this ClauseElement with annotations replaced by the given dictionary. @@ -76,7 +141,27 @@ class SupportsCloneAnnotations(SupportsAnnotations): new.__dict__.pop("_generate_cache_key", None) return new - def _deannotate(self, values=None, clone=False): + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: """return a copy of this :class:`_expression.ClauseElement` with annotations removed. @@ -96,24 +181,52 @@ class SupportsCloneAnnotations(SupportsAnnotations): return self +SelfSupportsWrappingAnnotations = TypeVar( + "SelfSupportsWrappingAnnotations", bound="SupportsWrappingAnnotations" +) + + class SupportsWrappingAnnotations(SupportsAnnotations): __slots__ = () - def _annotate(self, values): + _constructor: Callable[..., SupportsWrappingAnnotations] + entity_namespace: Mapping[str, Any] + + def _annotate(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations updated by the given dictionary. """ - return Annotated(self, values) + return Annotated._as_annotated_instance(self, values) - def _with_annotations(self, values): + def _with_annotations(self, values: _AnnotationDict) -> Annotated: """return a copy of this ClauseElement with annotations replaced by the given dictionary. """ - return Annotated(self, values) - - def _deannotate(self, values=None, clone=False): + return Annotated._as_annotated_instance(self, values) + + @overload + def _deannotate( + self: SelfSupportsAnnotations, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfSupportsAnnotations: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: """return a copy of this :class:`_expression.ClauseElement` with annotations removed. @@ -129,8 +242,11 @@ class SupportsWrappingAnnotations(SupportsAnnotations): return self -class Annotated: - """clones a SupportsAnnotated and applies an 'annotations' dictionary. +SelfAnnotated = TypeVar("SelfAnnotated", bound="Annotated") + + +class Annotated(SupportsAnnotations): + """clones a SupportsAnnotations and applies an 'annotations' dictionary. Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place @@ -151,21 +267,26 @@ class Annotated: _is_column_operators = False - def __new__(cls, *args): - if not args: - # clone constructor - return object.__new__(cls) - else: - element, values = args - # pull appropriate subclass from registry of annotated - # classes - try: - cls = annotated_classes[element.__class__] - except KeyError: - cls = _new_annotation_type(element.__class__, cls) - return object.__new__(cls) - - def __init__(self, element, values): + @classmethod + def _as_annotated_instance( + cls, element: SupportsWrappingAnnotations, values: _AnnotationDict + ) -> Annotated: + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return cls(element, values) + + _annotations: util.immutabledict[str, Any] + __element: SupportsWrappingAnnotations + _hash: int + + def __new__(cls: Type[SelfAnnotated], *args: Any) -> SelfAnnotated: + return object.__new__(cls) + + def __init__( + self, element: SupportsWrappingAnnotations, values: _AnnotationDict + ): self.__dict__ = element.__dict__.copy() self.__dict__.pop("_annotations_cache_key", None) self.__dict__.pop("_generate_cache_key", None) @@ -173,11 +294,15 @@ class Annotated: self._annotations = util.immutabledict(values) self._hash = hash(element) - def _annotate(self, values): + def _annotate( + self: SelfAnnotated, values: _AnnotationDict + ) -> SelfAnnotated: _values = self._annotations.union(values) return self._with_annotations(_values) - def _with_annotations(self, values): + def _with_annotations( + self: SelfAnnotated, values: util.immutabledict[str, Any] + ) -> SelfAnnotated: clone = self.__class__.__new__(self.__class__) clone.__dict__ = self.__dict__.copy() clone.__dict__.pop("_annotations_cache_key", None) @@ -185,7 +310,27 @@ class Annotated: clone._annotations = values return clone - def _deannotate(self, values=None, clone=True): + @overload + def _deannotate( + self: SelfAnnotated, + values: Literal[None] = ..., + clone: bool = ..., + ) -> SelfAnnotated: + ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> Annotated: + ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = True, + ) -> SupportsAnnotations: if values is None: return self.__element else: @@ -199,14 +344,18 @@ class Annotated: ) ) - def _compiler_dispatch(self, visitor, **kw): - return self.__element.__class__._compiler_dispatch(self, visitor, **kw) + if not typing.TYPE_CHECKING: + # manually proxy some methods that need extra attention + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any: + return self.__element.__class__._compiler_dispatch( + self, visitor, **kw + ) - @property - def _constructor(self): - return self.__element._constructor + @property + def _constructor(self): + return self.__element._constructor - def _clone(self, **kw): + def _clone(self: SelfAnnotated, **kw: Any) -> SelfAnnotated: clone = self.__element._clone(**kw) if clone is self.__element: # detect immutable, don't change anything @@ -217,22 +366,25 @@ class Annotated: clone.__dict__.update(self.__dict__) return self.__class__(clone, self._annotations) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]: return self.__class__, (self.__element, self._annotations) - def __hash__(self): + def __hash__(self) -> int: return self._hash - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if self._is_column_operators: return self.__element.__class__.__eq__(self, other) else: return hash(other) == hash(self) @property - def entity_namespace(self): + def entity_namespace(self) -> Mapping[str, Any]: if "entity_namespace" in self._annotations: - return self._annotations["entity_namespace"].entity_namespace + return cast( + SupportsWrappingAnnotations, + self._annotations["entity_namespace"], + ).entity_namespace else: return self.__element.entity_namespace @@ -242,12 +394,19 @@ class Annotated: # so that the resulting objects are pickleable; additionally, other # decisions can be made up front about the type of object being annotated # just once per class rather than per-instance. -annotated_classes = {} +annotated_classes: Dict[ + Type[SupportsWrappingAnnotations], Type[Annotated] +] = {} + +_SA = TypeVar("_SA", bound="SupportsAnnotations") def _deep_annotate( - element, annotations, exclude=None, detect_subquery_cols=False -): + element: _SA, + annotations: _AnnotationDict, + exclude: Optional[Sequence[SupportsAnnotations]] = None, + detect_subquery_cols: bool = False, +) -> _SA: """Deep copy the given ClauseElement, annotating each element with the given annotations dictionary. @@ -258,9 +417,9 @@ def _deep_annotate( # annotated objects hack the __hash__() method so if we want to # uniquely process them we have to use id() - cloned_ids = {} + cloned_ids: Dict[int, SupportsAnnotations] = {} - def clone(elem, **kw): + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: kw["detect_subquery_cols"] = detect_subquery_cols id_ = id(elem) @@ -285,17 +444,20 @@ def _deep_annotate( return newelem if element is not None: - element = clone(element) - clone = None # remove gc cycles + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles return element -def _deep_deannotate(element, values=None): +def _deep_deannotate( + element: _SA, values: Optional[Sequence[str]] = None +) -> _SA: """Deep copy the given element, removing annotations.""" - cloned = {} + cloned: Dict[Any, SupportsAnnotations] = {} - def clone(elem, **kw): + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + key: Any if values: key = id(elem) else: @@ -310,12 +472,14 @@ def _deep_deannotate(element, values=None): return cloned[key] if element is not None: - element = clone(element) - clone = None # remove gc cycles + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles return element -def _shallow_annotate(element, annotations): +def _shallow_annotate( + element: SupportsAnnotations, annotations: _AnnotationDict +) -> SupportsAnnotations: """Annotate the given ClauseElement and copy its internals so that internal objects refer to the new annotated object. @@ -328,7 +492,13 @@ def _shallow_annotate(element, annotations): return element -def _new_annotation_type(cls, base_cls): +def _new_annotation_type( + cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated] +) -> Type[Annotated]: + """Generates a new class that subclasses Annotated and proxies a given + element type. + + """ if issubclass(cls, Annotated): return cls elif cls in annotated_classes: @@ -342,8 +512,9 @@ def _new_annotation_type(cls, base_cls): base_cls = annotated_classes[super_] break - annotated_classes[cls] = anno_cls = type( - "Annotated%s" % cls.__name__, (base_cls, cls), {} + annotated_classes[cls] = anno_cls = cast( + Type[Annotated], + type("Annotated%s" % cls.__name__, (base_cls, cls), {}), ) globals()["Annotated%s" % cls.__name__] = anno_cls @@ -359,13 +530,15 @@ def _new_annotation_type(cls, base_cls): # some classes include this even if they have traverse_internals # e.g. BindParameter, add it if present. if cls.__dict__.get("inherit_cache", False): - anno_cls.inherit_cache = True + anno_cls.inherit_cache = True # type: ignore anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) return anno_cls -def _prepare_annotations(target_hierarchy, base_cls): +def _prepare_annotations( + target_hierarchy: Type[SupportsAnnotations], base_cls: Type[Annotated] +) -> None: for cls in util.walk_subclasses(target_hierarchy): _new_annotation_type(cls, base_cls) |