diff options
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 211 |
1 files changed, 200 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 74035ec0a..263d56101 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -27,6 +27,7 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union import weakref @@ -85,12 +86,16 @@ from ..util.typing import Literal from ..util.typing import Protocol if typing.TYPE_CHECKING: + from ._typing import _EntityType from ._typing import _IdentityKeyType from ._typing import _InstanceDict + from ._typing import _O + from .context import FromStatement from .interfaces import ORMOption from .interfaces import UserDefinedOption from .mapper import Mapper from .path_registry import PathRegistry + from .query import RowReturningQuery from ..engine import Result from ..engine import Row from ..engine import RowMapping @@ -104,10 +109,23 @@ if typing.TYPE_CHECKING: from ..event import _InstanceLevelDispatch from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _InfoType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import Executable from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole from ..sql.schema import Table - from ..sql.selectable import TableClause + from ..sql.selectable import Select + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) __all__ = [ "Session", @@ -189,7 +207,7 @@ class _SessionClassMethods: ident: Union[Any, Tuple[Any, ...]] = None, *, instance: Optional[Any] = None, - row: Optional[Union[Row, RowMapping]] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, identity_token: Optional[Any] = None, ) -> _IdentityKeyType[Any]: """Return an identity key. @@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots): params: Optional[_CoreAnyExecuteParams] = None, execution_options: Optional[_ExecuteOptionsParameter] = None, bind_arguments: Optional[_BindArguments] = None, - ) -> Result: + ) -> Result[Any]: """Execute the statement represented by this :class:`.ORMExecuteState`, without re-invoking events that have already proceeded. @@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result: + ) -> Result[Any]: ... def _execute_internal( @@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget): ) for idx, fn in enumerate(events_todo): orm_exec_state._starting_event_idx = idx - fn_result: Optional[Result] = fn(orm_exec_state) + fn_result: Optional[Result[Any]] = fn(orm_exec_state) if fn_result: if _scalar_result: return fn_result.scalar() @@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget): if _scalar_result and not compile_state_cls: if TYPE_CHECKING: params = cast(_CoreSingleExecuteParams, params) - return conn.scalar(statement, params or {}, execution_options) + return conn.scalar( + statement, params or {}, execution_options=execution_options + ) - result: Result = conn.execute( - statement, params or {}, execution_options + result: Result[Any] = conn.execute( + statement, params or {}, execution_options=execution_options ) if compile_state_cls: @@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget): else: return result + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: + ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + ... + def execute( self, statement: Executable, @@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result: + ) -> Result[Any]: r"""Execute a SQL expression construct. Returns a :class:`_engine.Result` object representing @@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget): _add_event=_add_event, ) + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: + ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + ... + def scalar( self, statement: Executable, @@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget): **kw, ) + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: + ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + ... + def scalars( self, statement: Executable, @@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget): f'{", ".join(context)} or this Session.' ) + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: + ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + ... + def query( - self, *entities: _ColumnsClauseArgument, **kwargs: Any + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any ) -> Query[Any]: """Return a new :class:`_query.Query` object corresponding to this :class:`_orm.Session`. @@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update = ForUpdateArg._from_argument(with_for_update) - stmt = sql.select(object_mapper(instance)) + stmt: Select[Any] = sql.select(object_mapper(instance)) if ( loading.load_on_ident( self, |