summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/base.py')
-rw-r--r--lib/sqlalchemy/orm/base.py81
1 files changed, 47 insertions, 34 deletions
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
index 0ace9b1cb..63f873fd0 100644
--- a/lib/sqlalchemy/orm/base.py
+++ b/lib/sqlalchemy/orm/base.py
@@ -18,6 +18,7 @@ from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
+from typing import no_type_check
from typing import Optional
from typing import overload
from typing import Type
@@ -35,17 +36,20 @@ from ..sql.elements import SQLCoreOperations
from ..util import FastIntFlag
from ..util.langhelpers import TypingOnly
from ..util.typing import Literal
-from ..util.typing import Self
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _ExternalEntityType
from ._typing import _InternalEntityType
from .attributes import InstrumentedAttribute
from .instrumentation import ClassManager
+ from .interfaces import PropComparator
from .mapper import Mapper
from .state import InstanceState
from .util import AliasedClass
+ from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _InfoType
+ from ..sql.elements import ColumnElement
_T = TypeVar("_T", bound=Any)
@@ -191,35 +195,34 @@ EXT_CONTINUE = util.symbol("EXT_CONTINUE")
EXT_STOP = util.symbol("EXT_STOP")
EXT_SKIP = util.symbol("EXT_SKIP")
-ONETOMANY = util.symbol(
- "ONETOMANY",
+
+class RelationshipDirection(Enum):
+ ONETOMANY = 1
"""Indicates the one-to-many direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
-MANYTOONE = util.symbol(
- "MANYTOONE",
+ MANYTOONE = 2
"""Indicates the many-to-one direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
-MANYTOMANY = util.symbol(
- "MANYTOMANY",
+ MANYTOMANY = 3
"""Indicates the many-to-many direction for a :func:`_orm.relationship`.
This symbol is typically used by the internals but may be exposed within
certain API features.
- """,
-)
+ """
+
+
+ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection)
class InspectionAttrExtensionType(Enum):
@@ -249,7 +252,7 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
-_F = TypeVar("_F", bound=Callable)
+_F = TypeVar("_F", bound=Callable[..., Any])
_Self = TypeVar("_Self")
@@ -397,29 +400,34 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]:
return None
-def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]:
+def _class_to_mapper(
+ class_or_mapper: Union[Mapper[_T], Type[_T]]
+) -> Mapper[_T]:
+ # can't get mypy to see an overload for this
insp = inspection.inspect(class_or_mapper, False)
if insp is not None:
- return insp.mapper
+ return insp.mapper # type: ignore
else:
+ assert isinstance(class_or_mapper, type)
raise exc.UnmappedClassError(class_or_mapper)
def _mapper_or_none(
- entity: Union[_T, _InternalEntityType[_T]]
+ entity: Union[Type[_T], _InternalEntityType[_T]]
) -> Optional[Mapper[_T]]:
"""Return the :class:`_orm.Mapper` for the given class or None if the
class is not mapped.
"""
+ # can't get mypy to see an overload for this
insp = inspection.inspect(entity, False)
if insp is not None:
- return insp.mapper
+ return insp.mapper # type: ignore
else:
return None
-def _is_mapped_class(entity):
+def _is_mapped_class(entity: Any) -> bool:
"""Return True if the given object is a mapped class,
:class:`_orm.Mapper`, or :class:`.AliasedClass`.
"""
@@ -432,20 +440,13 @@ def _is_mapped_class(entity):
)
-def _orm_columns(entity):
- insp = inspection.inspect(entity, False)
- if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
- return [c for c in insp.selectable.c]
- else:
- return [entity]
-
-
-def _is_aliased_class(entity):
+def _is_aliased_class(entity: Any) -> bool:
insp = inspection.inspect(entity, False)
return insp is not None and getattr(insp, "is_aliased_class", False)
-def _entity_descriptor(entity, key):
+@no_type_check
+def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any:
"""Return a class attribute given an entity and string name.
May return :class:`.InstrumentedAttribute` or user-defined
@@ -651,16 +652,26 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
if typing.TYPE_CHECKING:
- def of_type(self, class_):
+ def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
...
- def and_(self, *criteria):
+ def and_(
+ self, *criteria: _ColumnExpressionArgument[bool]
+ ) -> PropComparator[bool]:
...
- def any(self, criterion=None, **kwargs): # noqa: A001
+ def any( # noqa: A001
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
...
- def has(self, criterion=None, **kwargs):
+ def has(
+ self,
+ criterion: Optional[_ColumnExpressionArgument[bool]] = None,
+ **kwargs: Any,
+ ) -> ColumnElement[bool]:
...
@@ -673,7 +684,9 @@ class ORMDescriptor(Generic[_T], TypingOnly):
if typing.TYPE_CHECKING:
@overload
- def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self:
+ def __get__(
+ self, instance: Any, owner: Literal[None]
+ ) -> ORMDescriptor[_T]:
...
@overload