summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py211
1 files changed, 200 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 74035ec0a..263d56101 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -27,6 +27,7 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -85,12 +86,16 @@ from ..util.typing import Literal
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InstanceDict
+ from ._typing import _O
+ from .context import FromStatement
from .interfaces import ORMOption
from .interfaces import UserDefinedOption
from .mapper import Mapper
from .path_registry import PathRegistry
+ from .query import RowReturningQuery
from ..engine import Result
from ..engine import Row
from ..engine import RowMapping
@@ -104,10 +109,23 @@ if typing.TYPE_CHECKING:
from ..event import _InstanceLevelDispatch
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
from ..sql.elements import ClauseElement
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.schema import Table
- from ..sql.selectable import TableClause
+ from ..sql.selectable import Select
+ from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
__all__ = [
"Session",
@@ -189,7 +207,7 @@ class _SessionClassMethods:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Union[Row, RowMapping]] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
"""Return an identity key.
@@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots):
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
bind_arguments: Optional[_BindArguments] = None,
- ) -> Result:
+ ) -> Result[Any]:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
already proceeded.
@@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
_scalar_result: bool = ...,
- ) -> Result:
+ ) -> Result[Any]:
...
def _execute_internal(
@@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget):
)
for idx, fn in enumerate(events_todo):
orm_exec_state._starting_event_idx = idx
- fn_result: Optional[Result] = fn(orm_exec_state)
+ fn_result: Optional[Result[Any]] = fn(orm_exec_state)
if fn_result:
if _scalar_result:
return fn_result.scalar()
@@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget):
if _scalar_result and not compile_state_cls:
if TYPE_CHECKING:
params = cast(_CoreSingleExecuteParams, params)
- return conn.scalar(statement, params or {}, execution_options)
+ return conn.scalar(
+ statement, params or {}, execution_options=execution_options
+ )
- result: Result = conn.execute(
- statement, params or {}, execution_options
+ result: Result[Any] = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if compile_state_cls:
@@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget):
else:
return result
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
def execute(
self,
statement: Executable,
@@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a SQL expression construct.
Returns a :class:`_engine.Result` object representing
@@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget):
_add_event=_add_event,
)
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
@@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget):
**kw,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
def scalars(
self,
statement: Executable,
@@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget):
f'{", ".join(context)} or this Session.'
)
+ @overload
+ def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+ ...
+
+ @overload
+ def query(
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.query
+
+ @overload
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+ ) -> Query[Any]:
+ ...
+
def query(
- self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> Query[Any]:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget):
with_for_update = ForUpdateArg._from_argument(with_for_update)
- stmt = sql.select(object_mapper(instance))
+ stmt: Select[Any] = sql.select(object_mapper(instance))
if (
loading.load_on_ident(
self,