diff options
Diffstat (limited to 'lib/sqlalchemy/orm/scoping.py')
-rw-r--r-- | lib/sqlalchemy/orm/scoping.py | 219 |
1 files changed, 198 insertions, 21 deletions
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 93d18b8d7..9220c44c7 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -13,6 +13,7 @@ from typing import Dict from typing import Iterable from typing import Iterator from typing import Optional +from typing import overload from typing import Sequence from typing import Tuple from typing import Type @@ -20,8 +21,6 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from . import exc as orm_exc -from .base import class_mapper from .session import Session from .. import exc as sa_exc from .. import util @@ -33,11 +32,13 @@ from ..util import warn_deprecated from ..util.typing import Protocol if TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _IdentityKeyType from .identity import IdentityMap from .interfaces import ORMOption from .mapper import Mapper from .query import Query + from .query import RowReturningQuery from .session import _BindArguments from .session import _EntityBindKey from .session import _PKIdentityArgument @@ -48,19 +49,33 @@ if TYPE_CHECKING: from ..engine import Engine from ..engine import Result from ..engine import Row + from ..engine import RowMapping from ..engine.interfaces import _CoreAnyExecuteParams from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.result import ScalarResult from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole from ..sql.selectable import ForUpdateArg + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) class _QueryDescriptorType(Protocol): - def __get__(self, instance: Any, owner: Type[Any]) -> Optional[Query[Any]]: + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... @@ -236,7 +251,7 @@ class scoped_session: self.registry.clear() def query_property( - self, query_cls: Optional[Type[Query[Any]]] = None + self, query_cls: Optional[Type[Query[_T]]] = None ) -> _QueryDescriptorType: """return a class property which produces a :class:`_query.Query` object @@ -264,20 +279,13 @@ class scoped_session: """ class query: - def __get__( - s, instance: Any, owner: Type[Any] - ) -> Optional[Query[Any]]: - try: - mapper = class_mapper(owner) - assert mapper is not None - if query_cls: - # custom query class - return query_cls(mapper, session=self.registry()) - else: - # session's configured query class - return self.registry().query(mapper) - except orm_exc.UnmappedClassError: - return None + def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]: + if query_cls: + # custom query class + return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501 + else: + # session's configured query class + return self.registry().query(owner) return query() @@ -548,6 +556,32 @@ class scoped_session: return self._proxied.delete(instance) + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + def execute( self, statement: Executable, @@ -557,7 +591,7 @@ class scoped_session: bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result: + ) -> Result[Any]: r"""Execute a SQL expression construct. .. container:: class_bases @@ -1430,8 +1464,103 @@ class scoped_session: return self._proxied.merge(instance, load=load, options=options) + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload def query( - self, *entities: _ColumnsClauseArgument, **kwargs: Any + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any ) -> Query[Any]: r"""Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -1559,6 +1688,30 @@ class scoped_session: return self._proxied.rollback() + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + def scalar( self, statement: Executable, @@ -1590,6 +1743,30 @@ class scoped_session: **kw, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + def scalars( self, statement: Executable, @@ -1848,7 +2025,7 @@ class scoped_session: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Row] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: r"""Return an identity key. |