diff options
Diffstat (limited to 'lib/sqlalchemy/orm/path_registry.py')
-rw-r--r-- | lib/sqlalchemy/orm/path_registry.py | 126 |
1 files changed, 78 insertions, 48 deletions
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index 361cea975..36c14a672 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from ..sql.cache_key import _CacheKeyTraversalType from ..sql.elements import BindParameter from ..sql.visitors import anon_map + from ..util.typing import _LiteralStar from ..util.typing import TypeGuard def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: @@ -80,7 +81,7 @@ def _unreduce_path(path: _SerializedPath) -> PathRegistry: return PathRegistry.deserialize(path) -_WILDCARD_TOKEN = "*" +_WILDCARD_TOKEN: _LiteralStar = "*" _DEFAULT_TOKEN = "_sa_default" @@ -115,6 +116,7 @@ class PathRegistry(HasCacheKey): is_token = False is_root = False has_entity = False + is_property = False is_entity = False path: _PathRepresentation @@ -175,7 +177,40 @@ class PathRegistry(HasCacheKey): def __hash__(self) -> int: return id(self) - def __getitem__(self, key: Any) -> PathRegistry: + @overload + def __getitem__(self, entity: str) -> TokenRegistry: + ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: + ... + + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + ... + + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... + + def __getitem__( + self, + entity: Union[ + str, int, slice, _InternalEntityType[Any], MapperProperty[Any] + ], + ) -> Union[ + TokenRegistry, + _PathElementType, + _PathRepresentation, + PropRegistry, + AbstractEntityRegistry, + ]: raise NotImplementedError() # TODO: what are we using this for? @@ -343,18 +378,8 @@ class RootRegistry(CreatesToken): is_root = True is_unnatural = False - @overload - def __getitem__(self, entity: str) -> TokenRegistry: - ... - - @overload - def __getitem__( - self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: - ... - - def __getitem__( - self, entity: Union[str, _InternalEntityType[Any]] + def _getitem( + self, entity: Any ) -> Union[TokenRegistry, AbstractEntityRegistry]: if entity in PathToken._intern: if TYPE_CHECKING: @@ -368,6 +393,9 @@ class RootRegistry(CreatesToken): f"invalid argument for RootRegistry.__getitem__: {entity}" ) + if not TYPE_CHECKING: + __getitem__ = _getitem + PathRegistry.root = RootRegistry() @@ -441,12 +469,15 @@ class TokenRegistry(PathRegistry): else: yield self - def __getitem__(self, entity: Any) -> Any: + def _getitem(self, entity: Any) -> Any: try: return self.path[entity] except TypeError as err: raise IndexError(f"{entity}") from err + if not TYPE_CHECKING: + __getitem__ = _getitem + class PropRegistry(PathRegistry): __slots__ = ( @@ -463,6 +494,7 @@ class PropRegistry(PathRegistry): "is_unnatural", ) inherit_cache = True + is_property = True prop: MapperProperty[Any] mapper: Optional[Mapper[Any]] @@ -557,21 +589,7 @@ class PropRegistry(PathRegistry): assert self.entity is not None return self[self.entity] - @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: - ... - - @overload - def __getitem__(self, entity: int) -> _PathElementType: - ... - - @overload - def __getitem__( - self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: - ... - - def __getitem__( + def _getitem( self, entity: Union[int, slice, _InternalEntityType[Any]] ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: if isinstance(entity, (int, slice)): @@ -579,6 +597,9 @@ class PropRegistry(PathRegistry): else: return SlotsEntityRegistry(self, entity) + if not TYPE_CHECKING: + __getitem__ = _getitem + class AbstractEntityRegistry(CreatesToken): __slots__ = ( @@ -643,6 +664,10 @@ class AbstractEntityRegistry(CreatesToken): self.natural_path = self.path @property + def root_entity(self) -> _InternalEntityType[Any]: + return cast("_InternalEntityType[Any]", self.path[0]) + + @property def entity_path(self) -> PathRegistry: return self @@ -653,23 +678,7 @@ class AbstractEntityRegistry(CreatesToken): def __bool__(self) -> bool: return True - @overload - def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: - ... - - @overload - def __getitem__(self, entity: str) -> TokenRegistry: - ... - - @overload - def __getitem__(self, entity: int) -> _PathElementType: - ... - - @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: - ... - - def __getitem__( + def _getitem( self, entity: Any ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: if isinstance(entity, (int, slice)): @@ -679,6 +688,9 @@ class AbstractEntityRegistry(CreatesToken): else: return PropRegistry(self, entity) + if not TYPE_CHECKING: + __getitem__ = _getitem + class SlotsEntityRegistry(AbstractEntityRegistry): # for aliased class, return lightweight, no-cycles created @@ -715,10 +727,28 @@ class CachingEntityRegistry(AbstractEntityRegistry): def pop(self, key: Any, default: Any) -> Any: return self._cache.pop(key, default) - def __getitem__(self, entity: Any) -> Any: + def _getitem(self, entity: Any) -> Any: if isinstance(entity, (int, slice)): return self.path[entity] elif isinstance(entity, PathToken): return TokenRegistry(self, entity) else: return self._cache[entity] + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +if TYPE_CHECKING: + + def path_is_entity( + path: PathRegistry, + ) -> TypeGuard[AbstractEntityRegistry]: + ... + + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: + ... + +else: + path_is_entity = operator.attrgetter("is_entity") + path_is_property = operator.attrgetter("is_property") |