summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/path_registry.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/path_registry.py')
-rw-r--r--lib/sqlalchemy/orm/path_registry.py126
1 files changed, 78 insertions, 48 deletions
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
index 361cea975..36c14a672 100644
--- a/lib/sqlalchemy/orm/path_registry.py
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
from ..sql.cache_key import _CacheKeyTraversalType
from ..sql.elements import BindParameter
from ..sql.visitors import anon_map
+ from ..util.typing import _LiteralStar
from ..util.typing import TypeGuard
def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]:
@@ -80,7 +81,7 @@ def _unreduce_path(path: _SerializedPath) -> PathRegistry:
return PathRegistry.deserialize(path)
-_WILDCARD_TOKEN = "*"
+_WILDCARD_TOKEN: _LiteralStar = "*"
_DEFAULT_TOKEN = "_sa_default"
@@ -115,6 +116,7 @@ class PathRegistry(HasCacheKey):
is_token = False
is_root = False
has_entity = False
+ is_property = False
is_entity = False
path: _PathRepresentation
@@ -175,7 +177,40 @@ class PathRegistry(HasCacheKey):
def __hash__(self) -> int:
return id(self)
- def __getitem__(self, key: Any) -> PathRegistry:
+ @overload
+ def __getitem__(self, entity: str) -> TokenRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: int) -> _PathElementType:
+ ...
+
+ @overload
+ def __getitem__(self, entity: slice) -> _PathRepresentation:
+ ...
+
+ @overload
+ def __getitem__(
+ self, entity: _InternalEntityType[Any]
+ ) -> AbstractEntityRegistry:
+ ...
+
+ @overload
+ def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
+ ...
+
+ def __getitem__(
+ self,
+ entity: Union[
+ str, int, slice, _InternalEntityType[Any], MapperProperty[Any]
+ ],
+ ) -> Union[
+ TokenRegistry,
+ _PathElementType,
+ _PathRepresentation,
+ PropRegistry,
+ AbstractEntityRegistry,
+ ]:
raise NotImplementedError()
# TODO: what are we using this for?
@@ -343,18 +378,8 @@ class RootRegistry(CreatesToken):
is_root = True
is_unnatural = False
- @overload
- def __getitem__(self, entity: str) -> TokenRegistry:
- ...
-
- @overload
- def __getitem__(
- self, entity: _InternalEntityType[Any]
- ) -> AbstractEntityRegistry:
- ...
-
- def __getitem__(
- self, entity: Union[str, _InternalEntityType[Any]]
+ def _getitem(
+ self, entity: Any
) -> Union[TokenRegistry, AbstractEntityRegistry]:
if entity in PathToken._intern:
if TYPE_CHECKING:
@@ -368,6 +393,9 @@ class RootRegistry(CreatesToken):
f"invalid argument for RootRegistry.__getitem__: {entity}"
)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
PathRegistry.root = RootRegistry()
@@ -441,12 +469,15 @@ class TokenRegistry(PathRegistry):
else:
yield self
- def __getitem__(self, entity: Any) -> Any:
+ def _getitem(self, entity: Any) -> Any:
try:
return self.path[entity]
except TypeError as err:
raise IndexError(f"{entity}") from err
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class PropRegistry(PathRegistry):
__slots__ = (
@@ -463,6 +494,7 @@ class PropRegistry(PathRegistry):
"is_unnatural",
)
inherit_cache = True
+ is_property = True
prop: MapperProperty[Any]
mapper: Optional[Mapper[Any]]
@@ -557,21 +589,7 @@ class PropRegistry(PathRegistry):
assert self.entity is not None
return self[self.entity]
- @overload
- def __getitem__(self, entity: slice) -> _PathRepresentation:
- ...
-
- @overload
- def __getitem__(self, entity: int) -> _PathElementType:
- ...
-
- @overload
- def __getitem__(
- self, entity: _InternalEntityType[Any]
- ) -> AbstractEntityRegistry:
- ...
-
- def __getitem__(
+ def _getitem(
self, entity: Union[int, slice, _InternalEntityType[Any]]
) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]:
if isinstance(entity, (int, slice)):
@@ -579,6 +597,9 @@ class PropRegistry(PathRegistry):
else:
return SlotsEntityRegistry(self, entity)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class AbstractEntityRegistry(CreatesToken):
__slots__ = (
@@ -643,6 +664,10 @@ class AbstractEntityRegistry(CreatesToken):
self.natural_path = self.path
@property
+ def root_entity(self) -> _InternalEntityType[Any]:
+ return cast("_InternalEntityType[Any]", self.path[0])
+
+ @property
def entity_path(self) -> PathRegistry:
return self
@@ -653,23 +678,7 @@ class AbstractEntityRegistry(CreatesToken):
def __bool__(self) -> bool:
return True
- @overload
- def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry:
- ...
-
- @overload
- def __getitem__(self, entity: str) -> TokenRegistry:
- ...
-
- @overload
- def __getitem__(self, entity: int) -> _PathElementType:
- ...
-
- @overload
- def __getitem__(self, entity: slice) -> _PathRepresentation:
- ...
-
- def __getitem__(
+ def _getitem(
self, entity: Any
) -> Union[_PathElementType, _PathRepresentation, PathRegistry]:
if isinstance(entity, (int, slice)):
@@ -679,6 +688,9 @@ class AbstractEntityRegistry(CreatesToken):
else:
return PropRegistry(self, entity)
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
class SlotsEntityRegistry(AbstractEntityRegistry):
# for aliased class, return lightweight, no-cycles created
@@ -715,10 +727,28 @@ class CachingEntityRegistry(AbstractEntityRegistry):
def pop(self, key: Any, default: Any) -> Any:
return self._cache.pop(key, default)
- def __getitem__(self, entity: Any) -> Any:
+ def _getitem(self, entity: Any) -> Any:
if isinstance(entity, (int, slice)):
return self.path[entity]
elif isinstance(entity, PathToken):
return TokenRegistry(self, entity)
else:
return self._cache[entity]
+
+ if not TYPE_CHECKING:
+ __getitem__ = _getitem
+
+
+if TYPE_CHECKING:
+
+ def path_is_entity(
+ path: PathRegistry,
+ ) -> TypeGuard[AbstractEntityRegistry]:
+ ...
+
+ def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]:
+ ...
+
+else:
+ path_is_entity = operator.attrgetter("is_entity")
+ path_is_property = operator.attrgetter("is_property")