diff options
Diffstat (limited to 'lib/sqlalchemy/orm/base.py')
-rw-r--r-- | lib/sqlalchemy/orm/base.py | 81 |
1 files changed, 47 insertions, 34 deletions
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 0ace9b1cb..63f873fd0 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -18,6 +18,7 @@ from typing import Any from typing import Callable from typing import Dict from typing import Generic +from typing import no_type_check from typing import Optional from typing import overload from typing import Type @@ -35,17 +36,20 @@ from ..sql.elements import SQLCoreOperations from ..util import FastIntFlag from ..util.langhelpers import TypingOnly from ..util.typing import Literal -from ..util.typing import Self if typing.TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _ExternalEntityType from ._typing import _InternalEntityType from .attributes import InstrumentedAttribute from .instrumentation import ClassManager + from .interfaces import PropComparator from .mapper import Mapper from .state import InstanceState from .util import AliasedClass + from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement _T = TypeVar("_T", bound=Any) @@ -191,35 +195,34 @@ EXT_CONTINUE = util.symbol("EXT_CONTINUE") EXT_STOP = util.symbol("EXT_STOP") EXT_SKIP = util.symbol("EXT_SKIP") -ONETOMANY = util.symbol( - "ONETOMANY", + +class RelationshipDirection(Enum): + ONETOMANY = 1 """Indicates the one-to-many direction for a :func:`_orm.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """, -) + """ -MANYTOONE = util.symbol( - "MANYTOONE", + MANYTOONE = 2 """Indicates the many-to-one direction for a :func:`_orm.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """, -) + """ -MANYTOMANY = util.symbol( - "MANYTOMANY", + MANYTOMANY = 3 """Indicates the many-to-many direction for a :func:`_orm.relationship`. This symbol is typically used by the internals but may be exposed within certain API features. - """, -) + """ + + +ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection) class InspectionAttrExtensionType(Enum): @@ -249,7 +252,7 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") -_F = TypeVar("_F", bound=Callable) +_F = TypeVar("_F", bound=Callable[..., Any]) _Self = TypeVar("_Self") @@ -397,29 +400,34 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: return None -def _class_to_mapper(class_or_mapper: Union[Mapper[_T], _T]) -> Mapper[_T]: +def _class_to_mapper( + class_or_mapper: Union[Mapper[_T], Type[_T]] +) -> Mapper[_T]: + # can't get mypy to see an overload for this insp = inspection.inspect(class_or_mapper, False) if insp is not None: - return insp.mapper + return insp.mapper # type: ignore else: + assert isinstance(class_or_mapper, type) raise exc.UnmappedClassError(class_or_mapper) def _mapper_or_none( - entity: Union[_T, _InternalEntityType[_T]] + entity: Union[Type[_T], _InternalEntityType[_T]] ) -> Optional[Mapper[_T]]: """Return the :class:`_orm.Mapper` for the given class or None if the class is not mapped. """ + # can't get mypy to see an overload for this insp = inspection.inspect(entity, False) if insp is not None: - return insp.mapper + return insp.mapper # type: ignore else: return None -def _is_mapped_class(entity): +def _is_mapped_class(entity: Any) -> bool: """Return True if the given object is a mapped class, :class:`_orm.Mapper`, or :class:`.AliasedClass`. """ @@ -432,20 +440,13 @@ def _is_mapped_class(entity): ) -def _orm_columns(entity): - insp = inspection.inspect(entity, False) - if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"): - return [c for c in insp.selectable.c] - else: - return [entity] - - -def _is_aliased_class(entity): +def _is_aliased_class(entity: Any) -> bool: insp = inspection.inspect(entity, False) return insp is not None and getattr(insp, "is_aliased_class", False) -def _entity_descriptor(entity, key): +@no_type_check +def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: """Return a class attribute given an entity and string name. May return :class:`.InstrumentedAttribute` or user-defined @@ -651,16 +652,26 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): if typing.TYPE_CHECKING: - def of_type(self, class_): + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: ... - def and_(self, *criteria): + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[bool]: ... - def any(self, criterion=None, **kwargs): # noqa: A001 + def any( # noqa: A001 + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: ... - def has(self, criterion=None, **kwargs): + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: ... @@ -673,7 +684,9 @@ class ORMDescriptor(Generic[_T], TypingOnly): if typing.TYPE_CHECKING: @overload - def __get__(self: Self, instance: Any, owner: Literal[None]) -> Self: + def __get__( + self, instance: Any, owner: Literal[None] + ) -> ORMDescriptor[_T]: ... @overload |