summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-04-19 21:06:41 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-04-27 14:46:36 -0400
commitad11c482e2233f44e8747d4d5a2b17a995fff1fa (patch)
tree57f8ddd30928951519fd6ac0f418e9cbf8e65610 /lib/sqlalchemy/orm/session.py
parent033d1a16e7a220555d7611a5b8cacb1bd83822ae (diff)
downloadsqlalchemy-ad11c482e2233f44e8747d4d5a2b17a995fff1fa.tar.gz
pep484 ORM / SQL result support
after some experimentation it seems mypy is more amenable to the generic types being fully integrated rather than having separate spin-off types. so key structures like Result, Row, Select become generic. For DML Insert, Update, Delete, these are spun into type-specific subclasses ReturningInsert, ReturningUpdate, ReturningDelete, which is fine since the "row-ness" of these constructs doesn't happen until returning() is called in any case. a Tuple based model is then integrated so that these objects can carry along information about their return types. Overloads at the .execute() level carry through the Tuple from the invoked object to the result. To suit the issue of AliasedClass generating attributes that are dynamic, experimented with a custom subclass AsAliased, but then just settled on having aliased() lie to the type checker and return `Type[_O]`, essentially. will need some type-related accessors for with_polymorphic() also. Additionally, identified an issue in Update when used "mysql style" against a join(), it basically doesn't work if asked to UPDATE two tables on the same column name. added an error message to the specific condition where it happens with a very non-specific error message that we hit a thing we can't do right now, suggest multi-table update as a possible cause. Change-Id: I5eff7eefe1d6166ee74160b2785c5e6a81fa8b95
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py211
1 files changed, 200 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 74035ec0a..263d56101 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -27,6 +27,7 @@ from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
import weakref
@@ -85,12 +86,16 @@ from ..util.typing import Literal
from ..util.typing import Protocol
if typing.TYPE_CHECKING:
+ from ._typing import _EntityType
from ._typing import _IdentityKeyType
from ._typing import _InstanceDict
+ from ._typing import _O
+ from .context import FromStatement
from .interfaces import ORMOption
from .interfaces import UserDefinedOption
from .mapper import Mapper
from .path_registry import PathRegistry
+ from .query import RowReturningQuery
from ..engine import Result
from ..engine import Row
from ..engine import RowMapping
@@ -104,10 +109,23 @@ if typing.TYPE_CHECKING:
from ..event import _InstanceLevelDispatch
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _InfoType
+ from ..sql._typing import _T0
+ from ..sql._typing import _T1
+ from ..sql._typing import _T2
+ from ..sql._typing import _T3
+ from ..sql._typing import _T4
+ from ..sql._typing import _T5
+ from ..sql._typing import _T6
+ from ..sql._typing import _T7
+ from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import Executable
from ..sql.elements import ClauseElement
+ from ..sql.roles import TypedColumnsClauseRole
from ..sql.schema import Table
- from ..sql.selectable import TableClause
+ from ..sql.selectable import Select
+ from ..sql.selectable import TypedReturnsRows
+
+_T = TypeVar("_T", bound=Any)
__all__ = [
"Session",
@@ -189,7 +207,7 @@ class _SessionClassMethods:
ident: Union[Any, Tuple[Any, ...]] = None,
*,
instance: Optional[Any] = None,
- row: Optional[Union[Row, RowMapping]] = None,
+ row: Optional[Union[Row[Any], RowMapping]] = None,
identity_token: Optional[Any] = None,
) -> _IdentityKeyType[Any]:
"""Return an identity key.
@@ -295,7 +313,7 @@ class ORMExecuteState(util.MemoizedSlots):
params: Optional[_CoreAnyExecuteParams] = None,
execution_options: Optional[_ExecuteOptionsParameter] = None,
bind_arguments: Optional[_BindArguments] = None,
- ) -> Result:
+ ) -> Result[Any]:
"""Execute the statement represented by this
:class:`.ORMExecuteState`, without re-invoking events that have
already proceeded.
@@ -1718,7 +1736,7 @@ class Session(_SessionClassMethods, EventTarget):
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
_scalar_result: bool = ...,
- ) -> Result:
+ ) -> Result[Any]:
...
def _execute_internal(
@@ -1789,7 +1807,7 @@ class Session(_SessionClassMethods, EventTarget):
)
for idx, fn in enumerate(events_todo):
orm_exec_state._starting_event_idx = idx
- fn_result: Optional[Result] = fn(orm_exec_state)
+ fn_result: Optional[Result[Any]] = fn(orm_exec_state)
if fn_result:
if _scalar_result:
return fn_result.scalar()
@@ -1806,10 +1824,12 @@ class Session(_SessionClassMethods, EventTarget):
if _scalar_result and not compile_state_cls:
if TYPE_CHECKING:
params = cast(_CoreSingleExecuteParams, params)
- return conn.scalar(statement, params or {}, execution_options)
+ return conn.scalar(
+ statement, params or {}, execution_options=execution_options
+ )
- result: Result = conn.execute(
- statement, params or {}, execution_options
+ result: Result[Any] = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if compile_state_cls:
@@ -1827,6 +1847,32 @@ class Session(_SessionClassMethods, EventTarget):
else:
return result
+ @overload
+ def execute(
+ self,
+ statement: TypedReturnsRows[_T],
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[_T]:
+ ...
+
+ @overload
+ def execute(
+ self,
+ statement: Executable,
+ params: Optional[_CoreAnyExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ _parent_execute_state: Optional[Any] = None,
+ _add_event: Optional[Any] = None,
+ ) -> Result[Any]:
+ ...
+
def execute(
self,
statement: Executable,
@@ -1836,7 +1882,7 @@ class Session(_SessionClassMethods, EventTarget):
bind_arguments: Optional[_BindArguments] = None,
_parent_execute_state: Optional[Any] = None,
_add_event: Optional[Any] = None,
- ) -> Result:
+ ) -> Result[Any]:
r"""Execute a SQL expression construct.
Returns a :class:`_engine.Result` object representing
@@ -1897,6 +1943,30 @@ class Session(_SessionClassMethods, EventTarget):
_add_event=_add_event,
)
+ @overload
+ def scalar(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Optional[_T]:
+ ...
+
+ @overload
+ def scalar(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> Any:
+ ...
+
def scalar(
self,
statement: Executable,
@@ -1923,6 +1993,30 @@ class Session(_SessionClassMethods, EventTarget):
**kw,
)
+ @overload
+ def scalars(
+ self,
+ statement: TypedReturnsRows[Tuple[_T]],
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[_T]:
+ ...
+
+ @overload
+ def scalars(
+ self,
+ statement: Executable,
+ params: Optional[_CoreSingleExecuteParams] = None,
+ *,
+ execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
+ bind_arguments: Optional[_BindArguments] = None,
+ **kw: Any,
+ ) -> ScalarResult[Any]:
+ ...
+
def scalars(
self,
statement: Executable,
@@ -2284,8 +2378,103 @@ class Session(_SessionClassMethods, EventTarget):
f'{", ".join(context)} or this Session.'
)
+ @overload
+ def query(self, _entity: _EntityType[_O]) -> Query[_O]:
+ ...
+
+ @overload
+ def query(
+ self, _colexpr: TypedColumnsClauseRole[_T]
+ ) -> RowReturningQuery[Tuple[_T]]:
+ ...
+
+ # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8
+
+ # code within this block is **programmatically,
+ # statically generated** by tools/generate_tuple_map_overloads.py
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+ ) -> RowReturningQuery[Tuple[_T0, _T1]]:
+ ...
+
+ @overload
+ def query(
+ self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]:
+ ...
+
+ @overload
+ def query(
+ self,
+ __ent0: _TCCA[_T0],
+ __ent1: _TCCA[_T1],
+ __ent2: _TCCA[_T2],
+ __ent3: _TCCA[_T3],
+ __ent4: _TCCA[_T4],
+ __ent5: _TCCA[_T5],
+ __ent6: _TCCA[_T6],
+ __ent7: _TCCA[_T7],
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]:
+ ...
+
+ # END OVERLOADED FUNCTIONS self.query
+
+ @overload
+ def query(
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
+ ) -> Query[Any]:
+ ...
+
def query(
- self, *entities: _ColumnsClauseArgument, **kwargs: Any
+ self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any
) -> Query[Any]:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
@@ -2486,7 +2675,7 @@ class Session(_SessionClassMethods, EventTarget):
with_for_update = ForUpdateArg._from_argument(with_for_update)
- stmt = sql.select(object_mapper(instance))
+ stmt: Select[Any] = sql.select(object_mapper(instance))
if (
loading.load_on_ident(
self,