diff options
Diffstat (limited to 'lib/sqlalchemy/orm/query.py')
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 643 |
1 files changed, 411 insertions, 232 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index a60a167ac..419891708 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -23,18 +23,23 @@ from __future__ import annotations import collections.abc as collections_abc import operator from typing import Any +from typing import Callable +from typing import cast +from typing import Dict from typing import Generic from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import overload from typing import Sequence from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union -from . import exc as orm_exc +from . import attributes from . import interfaces from . import loading from . import util as orm_util @@ -44,7 +49,6 @@ from .context import _column_descriptions from .context import _determine_last_joined_entity from .context import _legacy_filter_by_entity_zero from .context import FromStatement -from .context import LABEL_STYLE_LEGACY_ORM from .context import ORMCompileState from .context import QueryContext from .interfaces import ORMColumnDescription @@ -60,6 +64,8 @@ from .. import sql from .. import util from ..engine import Result from ..engine import Row +from ..event import dispatcher +from ..event import EventTarget from ..sql import coercions from ..sql import expression from ..sql import roles @@ -71,8 +77,10 @@ from ..sql._typing import _TP from ..sql.annotation import SupportsCloneAnnotations from ..sql.base import _entity_namespace_key from ..sql.base import _generative +from ..sql.base import _NoArg from ..sql.base import Executable from ..sql.base import Generative +from ..sql.elements import BooleanClauseList from ..sql.expression import Exists from ..sql.selectable import _MemoizedSelectEntities from ..sql.selectable import _SelectFromElements @@ -81,17 +89,31 @@ from ..sql.selectable import HasHints from ..sql.selectable import HasPrefixes from ..sql.selectable import HasSuffixes from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectLabelStyle from ..util.typing import Literal +from ..util.typing import Self if TYPE_CHECKING: from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InternalEntityType + from .mapper import Mapper + from .path_registry import PathRegistry + from .session import _PKIdentityArgument from .session import Session + from .state import InstanceState + from ..engine.cursor import CursorResult + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.result import FrozenResult from ..engine.result import ScalarResult from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _JoinTargetArgument from ..sql._typing import _MAYBE_ENTITY from ..sql._typing import _no_kw from ..sql._typing import _NOT_ENTITY + from ..sql._typing import _OnClauseArgument from ..sql._typing import _PropagateAttrsType from ..sql._typing import _T0 from ..sql._typing import _T1 @@ -102,18 +124,25 @@ if TYPE_CHECKING: from ..sql._typing import _T6 from ..sql._typing import _T7 from ..sql._typing import _TypedColumnClauseArgument as _TCCA - from ..sql.roles import TypedColumnsClauseRole + from ..sql.base import CacheableOptions + from ..sql.base import ExecutableOption + from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.selectable import _JoinTargetElement from ..sql.selectable import _SetupJoinsElement from ..sql.selectable import Alias + from ..sql.selectable import CTE from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import FromClause from ..sql.selectable import ScalarSelect from ..sql.selectable import Subquery + __all__ = ["Query", "QueryContext"] _T = TypeVar("_T", bound=Any) -SelfQuery = TypeVar("SelfQuery", bound="Query") +SelfQuery = TypeVar("SelfQuery", bound="Query[Any]") @inspection._self_inspects @@ -124,6 +153,7 @@ class Query( HasPrefixes, HasSuffixes, HasHints, + EventTarget, log.Identified, Generative, Executable, @@ -150,40 +180,47 @@ class Query( """ # elements that are in Core and can be cached in the same way - _where_criteria = () - _having_criteria = () + _where_criteria: Tuple[ColumnElement[Any], ...] = () + _having_criteria: Tuple[ColumnElement[Any], ...] = () - _order_by_clauses = () - _group_by_clauses = () - _limit_clause = None - _offset_clause = None + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None - _distinct = False - _distinct_on = () + _distinct: bool = False + _distinct_on: Tuple[ColumnElement[Any], ...] = () - _for_update_arg = None - _correlate = () - _auto_correlate = True - _from_obj = () + _for_update_arg: Optional[ForUpdateArg] = None + _correlate: Tuple[FromClause, ...] = () + _auto_correlate: bool = True + _from_obj: Tuple[FromClause, ...] = () _setup_joins: Tuple[_SetupJoinsElement, ...] = () - _label_style = LABEL_STYLE_LEGACY_ORM + _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM _memoized_select_entities = () - _compile_options = ORMCompileState.default_compile_options + _compile_options: Union[ + Type[CacheableOptions], CacheableOptions + ] = ORMCompileState.default_compile_options + _with_options: Tuple[ExecutableOption, ...] load_options = QueryContext.default_load_options + { "_legacy_uniquing": True } - _params = util.EMPTY_DICT + _params: util.immutabledict[str, Any] = util.EMPTY_DICT # local Query builder state, not needed for # compilation or execution _enable_assertions = True - _statement = None + _statement: Optional[ExecutableReturnsRows] = None + + session: Session + + dispatch: dispatcher[Query[_T]] # mirrors that of ClauseElement, used to propagate the "orm" # plugin as well as the "subject" of the plugin, e.g. the mapper @@ -224,14 +261,23 @@ class Query( """ - self.session = session + # session is usually present. There's one case in subqueryloader + # where it stores a Query without a Session and also there are tests + # for the query(Entity).with_session(session) API which is likely in + # some old recipes, however these are legacy as select() can now be + # used. + self.session = session # type: ignore self._set_entities(entities) - def _set_propagate_attrs(self, values): - self._propagate_attrs = util.immutabledict(values) + def _set_propagate_attrs( + self: SelfQuery, values: Mapping[str, Any] + ) -> SelfQuery: + self._propagate_attrs = util.immutabledict(values) # type: ignore return self - def _set_entities(self, entities): + def _set_entities( + self, entities: Iterable[_ColumnsClauseArgument[Any]] + ) -> None: self._raw_columns = [ coercions.expect( roles.ColumnsClauseRole, @@ -242,15 +288,7 @@ class Query( for ent in util.to_list(entities) ] - @overload - def tuples(self: Query[Row[_TP]]) -> Query[_TP]: - ... - - @overload def tuples(self: Query[_O]) -> Query[Tuple[_O]]: - ... - - def tuples(self) -> Query[Any]: """return a tuple-typed form of this :class:`.Query`. This method invokes the :meth:`.Query.only_return_tuples` @@ -270,29 +308,27 @@ class Query( .. versionadded:: 2.0 """ - return self.only_return_tuples(True) + return self.only_return_tuples(True) # type: ignore - def _entity_from_pre_ent_zero(self): + def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]: if not self._raw_columns: return None ent = self._raw_columns[0] if "parententity" in ent._annotations: - return ent._annotations["parententity"] - elif isinstance(ent, ORMColumnsClauseRole): - return ent.entity + return ent._annotations["parententity"] # type: ignore elif "bundle" in ent._annotations: - return ent._annotations["bundle"] + return ent._annotations["bundle"] # type: ignore else: # label, other SQL expression for element in visitors.iterate(ent): if "parententity" in element._annotations: - return element._annotations["parententity"] + return element._annotations["parententity"] # type: ignore # noqa: E501 else: return None - def _only_full_mapper_zero(self, methname): + def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]: if ( len(self._raw_columns) != 1 or "parententity" not in self._raw_columns[0]._annotations @@ -303,9 +339,11 @@ class Query( "a single mapped class." % methname ) - return self._raw_columns[0]._annotations["parententity"] + return self._raw_columns[0]._annotations["parententity"] # type: ignore # noqa: E501 - def _set_select_from(self, obj, set_base_alias): + def _set_select_from( + self, obj: Iterable[_FromClauseArgument], set_base_alias: bool + ) -> None: fa = [ coercions.expect( roles.StrictFromClauseRole, @@ -320,19 +358,22 @@ class Query( self._from_obj = tuple(fa) @_generative - def _set_lazyload_from(self: SelfQuery, state) -> SelfQuery: + def _set_lazyload_from( + self: SelfQuery, state: InstanceState[Any] + ) -> SelfQuery: self.load_options += {"_lazy_loaded_from": state} return self - def _get_condition(self): - return self._no_criterion_condition( - "get", order_by=False, distinct=False - ) + def _get_condition(self) -> None: + """used by legacy BakedQuery""" + self._no_criterion_condition("get", order_by=False, distinct=False) - def _get_existing_condition(self): + def _get_existing_condition(self) -> None: self._no_criterion_assertion("get", order_by=False, distinct=False) - def _no_criterion_assertion(self, meth, order_by=True, distinct=True): + def _no_criterion_assertion( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: if not self._enable_assertions: return if ( @@ -351,7 +392,9 @@ class Query( "Query with existing criterion. " % meth ) - def _no_criterion_condition(self, meth, order_by=True, distinct=True): + def _no_criterion_condition( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: self._no_criterion_assertion(meth, order_by, distinct) self._from_obj = self._setup_joins = () @@ -362,7 +405,7 @@ class Query( self._order_by_clauses = self._group_by_clauses = () - def _no_clauseelement_condition(self, meth): + def _no_clauseelement_condition(self, meth: str) -> None: if not self._enable_assertions: return if self._order_by_clauses: @@ -372,7 +415,7 @@ class Query( ) self._no_criterion_condition(meth) - def _no_statement_condition(self, meth): + def _no_statement_condition(self, meth: str) -> None: if not self._enable_assertions: return if self._statement is not None: @@ -384,7 +427,7 @@ class Query( % meth ) - def _no_limit_offset(self, meth): + def _no_limit_offset(self, meth: str) -> None: if not self._enable_assertions: return if self._limit_clause is not None or self._offset_clause is not None: @@ -395,21 +438,21 @@ class Query( ) @property - def _has_row_limiting_clause(self): + def _has_row_limiting_clause(self) -> bool: return ( self._limit_clause is not None or self._offset_clause is not None ) def _get_options( - self, - populate_existing=None, - version_check=None, - only_load_props=None, - refresh_state=None, - identity_token=None, - ): - load_options = {} - compile_options = {} + self: SelfQuery, + populate_existing: Optional[bool] = None, + version_check: Optional[bool] = None, + only_load_props: Optional[Sequence[str]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + identity_token: Optional[Any] = None, + ) -> SelfQuery: + load_options: Dict[str, Any] = {} + compile_options: Dict[str, Any] = {} if version_check: load_options["_version_check"] = version_check @@ -430,11 +473,18 @@ class Query( return self - def _clone(self): - return self._generate() + def _clone(self: Self, **kw: Any) -> Self: + return self._generate() # type: ignore + + def _get_select_statement_only(self) -> Select[_T]: + if self._statement is not None: + raise sa_exc.InvalidRequestError( + "Can't call this method on a Query that uses from_statement()" + ) + return cast("Select[_T]", self.statement) @property - def statement(self): + def statement(self) -> Union[Select[_T], FromStatement[_T]]: """The full SELECT statement represented by this Query. The statement by default will not have disambiguating labels @@ -474,14 +524,15 @@ class Query( return stmt - def _final_statement(self, legacy_query_style=True): + def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]: """Return the 'final' SELECT statement for this :class:`.Query`. + This is used by the testing suite only and is fairly inefficient. + This is the Core-only select() that will be rendered by a complete compilation of this query, and is what .statement used to return in 1.3. - This method creates a complete compile state so is fairly expensive. """ @@ -489,9 +540,11 @@ class Query( return q._compile_state( use_legacy_query_style=legacy_query_style - ).statement + ).statement # type: ignore - def _statement_20(self, for_statement=False, use_legacy_query_style=True): + def _statement_20( + self, for_statement: bool = False, use_legacy_query_style: bool = True + ) -> Union[Select[_T], FromStatement[_T]]: # TODO: this event needs to be deprecated, as it currently applies # only to ORM query and occurs at this spot that is now more # or less an artificial spot @@ -500,7 +553,7 @@ class Query( new_query = fn(self) if new_query is not None and new_query is not self: self = new_query - if not fn._bake_ok: + if not fn._bake_ok: # type: ignore self._compile_options += {"_bake_ok": False} compile_options = self._compile_options @@ -509,6 +562,8 @@ class Query( "_use_legacy_query_style": use_legacy_query_style, } + stmt: Union[Select[_T], FromStatement[_T]] + if self._statement is not None: stmt = FromStatement(self._raw_columns, self._statement) stmt.__dict__.update( @@ -541,10 +596,10 @@ class Query( def subquery( self, - name=None, - with_labels=False, - reduce_columns=False, - ): + name: Optional[str] = None, + with_labels: bool = False, + reduce_columns: bool = False, + ) -> Subquery: """Return the full SELECT statement represented by this :class:`_query.Query`, embedded within an :class:`_expression.Alias`. @@ -571,13 +626,21 @@ class Query( if with_labels: q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) - q = q.statement + stmt = q._get_select_statement_only() + + if TYPE_CHECKING: + assert isinstance(stmt, Select) if reduce_columns: - q = q.reduce_columns() - return q.alias(name=name) + stmt = stmt.reduce_columns() + return stmt.subquery(name=name) - def cte(self, name=None, recursive=False, nesting=False): + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: r"""Return the full SELECT statement represented by this :class:`_query.Query` represented as a common table expression (CTE). @@ -632,11 +695,13 @@ class Query( :meth:`_expression.HasCTE.cte` """ - return self.enable_eagerloads(False).statement.cte( - name=name, recursive=recursive, nesting=nesting + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .cte(name=name, recursive=recursive, nesting=nesting) ) - def label(self, name): + def label(self, name: Optional[str]) -> Label[Any]: """Return the full SELECT statement represented by this :class:`_query.Query`, converted to a scalar subquery with a label of the given name. @@ -645,7 +710,11 @@ class Query( """ - return self.enable_eagerloads(False).statement.label(name) + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .label(name) + ) @overload def as_scalar( @@ -704,10 +773,14 @@ class Query( """ - return self.enable_eagerloads(False).statement.scalar_subquery() + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .scalar_subquery() + ) @property - def selectable(self): + def selectable(self) -> Union[Select[_T], FromStatement[_T]]: """Return the :class:`_expression.Select` object emitted by this :class:`_query.Query`. @@ -718,7 +791,7 @@ class Query( """ return self.__clause_element__() - def __clause_element__(self): + def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: return ( self._with_compile_options( _enable_eagerloads=False, _render_for_subquery=True @@ -759,7 +832,7 @@ class Query( return self @property - def is_single_entity(self): + def is_single_entity(self) -> bool: """Indicates if this :class:`_query.Query` returns tuples or single entities. @@ -785,7 +858,7 @@ class Query( ) @_generative - def enable_eagerloads(self: SelfQuery, value) -> SelfQuery: + def enable_eagerloads(self: SelfQuery, value: bool) -> SelfQuery: """Control whether or not eager joins and subqueries are rendered. @@ -804,7 +877,7 @@ class Query( return self @_generative - def _with_compile_options(self: SelfQuery, **opt) -> SelfQuery: + def _with_compile_options(self: SelfQuery, **opt: Any) -> SelfQuery: self._compile_options += opt return self @@ -813,13 +886,15 @@ class Query( alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) " "instead.", ) - def with_labels(self): - return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + def with_labels(self: SelfQuery) -> SelfQuery: + return self.set_label_style( + SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL + ) apply_labels = with_labels @property - def get_label_style(self): + def get_label_style(self) -> SelectLabelStyle: """ Retrieve the current label style. @@ -828,7 +903,7 @@ class Query( """ return self._label_style - def set_label_style(self, style): + def set_label_style(self: SelfQuery, style: SelectLabelStyle) -> SelfQuery: """Apply column labels to the return value of Query.statement. Indicates that this Query's `statement` accessor should return @@ -864,7 +939,7 @@ class Query( return self @_generative - def enable_assertions(self: SelfQuery, value) -> SelfQuery: + def enable_assertions(self: SelfQuery, value: bool) -> SelfQuery: """Control whether assertions are generated. When set to False, the returned Query will @@ -887,7 +962,7 @@ class Query( return self @property - def whereclause(self): + def whereclause(self) -> Optional[ColumnElement[bool]]: """A readonly attribute which returns the current WHERE criterion for this Query. @@ -895,12 +970,12 @@ class Query( criterion has been established. """ - return sql.elements.BooleanClauseList._construct_for_whereclause( + return BooleanClauseList._construct_for_whereclause( self._where_criteria ) @_generative - def _with_current_path(self: SelfQuery, path) -> SelfQuery: + def _with_current_path(self: SelfQuery, path: PathRegistry) -> SelfQuery: """indicate that this query applies to objects loaded within a certain path. @@ -913,7 +988,7 @@ class Query( return self @_generative - def yield_per(self: SelfQuery, count) -> SelfQuery: + def yield_per(self: SelfQuery, count: int) -> SelfQuery: r"""Yield only ``count`` rows at a time. The purpose of this method is when fetching very large result sets @@ -938,7 +1013,7 @@ class Query( ":meth:`_orm.Query.get`", alternative="The method is now available as :meth:`_orm.Session.get`", ) - def get(self, ident): + def get(self, ident: _PKIdentityArgument) -> Optional[Any]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -1022,7 +1097,12 @@ class Query( # it return self._get_impl(ident, loading.load_on_pk_identity) - def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None): + def _get_impl( + self, + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., Any], + identity_token: Optional[Any] = None, + ) -> Optional[Any]: mapper = self._only_full_mapper_zero("get") return self.session._get_impl( mapper, @@ -1036,7 +1116,7 @@ class Query( ) @property - def lazy_loaded_from(self): + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: """An :class:`.InstanceState` that is using this :class:`_query.Query` for a lazy load operation. @@ -1050,14 +1130,17 @@ class Query( :attr:`.ORMExecuteState.lazy_loaded_from` """ - return self.load_options._lazy_loaded_from + return self.load_options._lazy_loaded_from # type: ignore @property - def _current_path(self): - return self._compile_options._current_path + def _current_path(self) -> PathRegistry: + return self._compile_options._current_path # type: ignore @_generative - def correlate(self: SelfQuery, *fromclauses) -> SelfQuery: + def correlate( + self: SelfQuery, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> SelfQuery: """Return a :class:`.Query` construct which will correlate the given FROM clauses to that of an enclosing :class:`.Query` or :func:`~.expression.select`. @@ -1082,13 +1165,13 @@ class Query( if fromclauses and fromclauses[0] in {None, False}: self._correlate = () else: - self._correlate = set(self._correlate).union( + self._correlate = self._correlate + tuple( coercions.expect(roles.FromClauseRole, f) for f in fromclauses ) return self @_generative - def autoflush(self: SelfQuery, setting) -> SelfQuery: + def autoflush(self: SelfQuery, setting: bool) -> SelfQuery: """Return a Query with a specific 'autoflush' setting. As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method @@ -1116,7 +1199,7 @@ class Query( return self @_generative - def _with_invoke_all_eagers(self: SelfQuery, value) -> SelfQuery: + def _with_invoke_all_eagers(self: SelfQuery, value: bool) -> SelfQuery: """Set the 'invoke all eagers' flag which causes joined- and subquery loaders to traverse into already-loaded related objects and collections. @@ -1132,7 +1215,14 @@ class Query( alternative="Use the :func:`_orm.with_parent` standalone construct.", ) @util.preload_module("sqlalchemy.orm.relationships") - def with_parent(self, instance, property=None, from_entity=None): # noqa + def with_parent( + self: SelfQuery, + instance: object, + property: Optional[ # noqa: A002 + attributes.QueryableAttribute[Any] + ] = None, + from_entity: Optional[_ExternalEntityType[Any]] = None, + ) -> SelfQuery: """Add filtering criterion that relates the given instance to a child object or collection, using its attribute state as well as an established :func:`_orm.relationship()` @@ -1150,7 +1240,7 @@ class Query( An instance which has some :func:`_orm.relationship`. :param property: - String property name, or class-bound attribute, which indicates + Class bound attribute which indicates what relationship from the instance should be used to reconcile the parent/child relationship. @@ -1172,21 +1262,27 @@ class Query( for prop in mapper.iterate_properties: if ( isinstance(prop, relationships.Relationship) - and prop.mapper is entity_zero.mapper + and prop.mapper is entity_zero.mapper # type: ignore ): - property = prop # noqa + property = prop # type: ignore # noqa: A001 break else: raise sa_exc.InvalidRequestError( "Could not locate a property which relates instances " "of class '%s' to instances of class '%s'" % ( - entity_zero.mapper.class_.__name__, + entity_zero.mapper.class_.__name__, # type: ignore instance.__class__.__name__, ) ) - return self.filter(with_parent(instance, property, entity_zero.entity)) + return self.filter( + with_parent( + instance, + property, # type: ignore + entity_zero.entity, # type: ignore + ) + ) @_generative def add_entity( @@ -1211,7 +1307,7 @@ class Query( return self @_generative - def with_session(self: SelfQuery, session) -> SelfQuery: + def with_session(self: SelfQuery, session: Session) -> SelfQuery: """Return a :class:`_query.Query` that will use the given :class:`.Session`. @@ -1237,7 +1333,9 @@ class Query( self.session = session return self - def _legacy_from_self(self, *entities): + def _legacy_from_self( + self: SelfQuery, *entities: _ColumnsClauseArgument[Any] + ) -> SelfQuery: # used for query.count() as well as for the same # function in BakedQuery, as well as some old tests in test_baked.py. @@ -1255,13 +1353,13 @@ class Query( return q @_generative - def _set_enable_single_crit(self: SelfQuery, val) -> SelfQuery: + def _set_enable_single_crit(self: SelfQuery, val: bool) -> SelfQuery: self._compile_options += {"_enable_single_crit": val} return self @_generative def _from_selectable( - self: SelfQuery, fromclause, set_entity_from=True + self: SelfQuery, fromclause: FromClause, set_entity_from: bool = True ) -> SelfQuery: for attr in ( "_where_criteria", @@ -1292,7 +1390,7 @@ class Query( "is deprecated and will be removed in a " "future release. Please use :meth:`_query.Query.with_entities`", ) - def values(self, *columns): + def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]: """Return an iterator yielding result tuples corresponding to the given list of columns @@ -1304,7 +1402,7 @@ class Query( q._set_entities(columns) if not q.load_options._yield_per: q.load_options += {"_yield_per": 10} - return iter(q) + return iter(q) # type: ignore _values = values @@ -1315,25 +1413,24 @@ class Query( "future release. Please use :meth:`_query.Query.with_entities` " "in combination with :meth:`_query.Query.scalar`", ) - def value(self, column): + def value(self, column: _ColumnExpressionArgument[Any]) -> Any: """Return a scalar result corresponding to the given column expression. """ try: - return next(self.values(column))[0] + return next(self.values(column))[0] # type: ignore except StopIteration: return None @overload - def with_entities( - self, _entity: _EntityType[_O], **kwargs: Any - ) -> Query[_O]: + def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def with_entities( - self, _colexpr: TypedColumnsClauseRole[_T] + self, + _colexpr: roles.TypedColumnsClauseRole[_T], ) -> RowReturningQuery[Tuple[_T]]: ... @@ -1418,14 +1515,14 @@ class Query( @overload def with_entities( - self: SelfQuery, *entities: _ColumnsClauseArgument[Any] - ) -> SelfQuery: + self, *entities: _ColumnsClauseArgument[Any] + ) -> Query[Any]: ... @_generative def with_entities( - self: SelfQuery, *entities: _ColumnsClauseArgument[Any], **__kw: Any - ) -> SelfQuery: + self, *entities: _ColumnsClauseArgument[Any], **__kw: Any + ) -> Query[Any]: r"""Return a new :class:`_query.Query` replacing the SELECT list with the given entities. @@ -1451,12 +1548,18 @@ class Query( """ if __kw: raise _no_kw() - _MemoizedSelectEntities._generate_for_statement(self) + + # Query has all the same fields as Select for this operation + # this could in theory be based on a protocol but not sure if it's + # worth it + _MemoizedSelectEntities._generate_for_statement(self) # type: ignore self._set_entities(entities) return self @_generative - def add_columns(self, *column: _ColumnExpressionArgument) -> Query[Any]: + def add_columns( + self, *column: _ColumnExpressionArgument[Any] + ) -> Query[Any]: """Add one or more column expressions to the list of result columns to be returned.""" @@ -1479,7 +1582,7 @@ class Query( "is deprecated and will be removed in a " "future release. Please use :meth:`_query.Query.add_columns`", ) - def add_column(self, column) -> Query[Any]: + def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]: """Add a column expression to the list of result columns to be returned. @@ -1487,7 +1590,7 @@ class Query( return self.add_columns(column) @_generative - def options(self: SelfQuery, *args) -> SelfQuery: + def options(self: SelfQuery, *args: ExecutableOption) -> SelfQuery: """Return a new :class:`_query.Query` object, applying the given list of mapper options. @@ -1505,18 +1608,21 @@ class Query( opts = tuple(util.flatten_iterator(args)) if self._compile_options._current_path: + # opting for lower method overhead for the checks for opt in opts: - if opt._is_legacy_option: - opt.process_query_conditionally(self) + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query_conditionally(self) # type: ignore else: for opt in opts: - if opt._is_legacy_option: - opt.process_query(self) + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query(self) # type: ignore self._with_options += opts return self - def with_transformation(self, fn): + def with_transformation( + self, fn: Callable[[Query[Any]], Query[Any]] + ) -> Query[Any]: """Return a new :class:`_query.Query` object transformed by the given function. @@ -1535,7 +1641,7 @@ class Query( """ return fn(self) - def get_execution_options(self): + def get_execution_options(self) -> _ImmutableExecuteOptions: """Get the non-SQL options which will take effect during execution. .. versionadded:: 1.3 @@ -1547,7 +1653,7 @@ class Query( return self._execution_options @_generative - def execution_options(self: SelfQuery, **kwargs) -> SelfQuery: + def execution_options(self: SelfQuery, **kwargs: Any) -> SelfQuery: """Set non-SQL options which take effect during execution. Options allowed here include all of those accepted by @@ -1596,11 +1702,17 @@ class Query( @_generative def with_for_update( self: SelfQuery, - read=False, - nowait=False, - of=None, - skip_locked=False, - key_share=False, + *, + nowait: bool = False, + read: bool = False, + of: Optional[ + Union[ + _ColumnExpressionArgument[Any], + Sequence[_ColumnExpressionArgument[Any]], + ] + ] = None, + skip_locked: bool = False, + key_share: bool = False, ) -> SelfQuery: """return a new :class:`_query.Query` with the specified options for the @@ -1659,7 +1771,9 @@ class Query( return self @_generative - def params(self: SelfQuery, *args, **kwargs) -> SelfQuery: + def params( + self: SelfQuery, __params: Optional[Dict[str, Any]] = None, **kw: Any + ) -> SelfQuery: r"""Add values for bind parameters which may have been specified in filter(). @@ -1669,17 +1783,14 @@ class Query( contain unicode keys in which case \**kwargs cannot be used. """ - if len(args) == 1: - kwargs.update(args[0]) - elif len(args) > 0: - raise sa_exc.ArgumentError( - "params() takes zero or one positional argument, " - "which is a dictionary." - ) - self._params = self._params.union(kwargs) + if __params: + kw.update(__params) + self._params = self._params.union(kw) return self - def where(self: SelfQuery, *criterion) -> SelfQuery: + def where( + self: SelfQuery, *criterion: _ColumnExpressionArgument[bool] + ) -> SelfQuery: """A synonym for :meth:`.Query.filter`. .. versionadded:: 1.4 @@ -1716,16 +1827,18 @@ class Query( :meth:`_query.Query.filter_by` - filter on keyword expressions. """ - for criterion in list(criterion): - criterion = coercions.expect( - roles.WhereHavingRole, criterion, apply_propagate_attrs=self + for crit in list(criterion): + crit = coercions.expect( + roles.WhereHavingRole, crit, apply_propagate_attrs=self ) - self._where_criteria += (criterion,) + self._where_criteria += (crit,) return self @util.memoized_property - def _last_joined_entity(self): + def _last_joined_entity( + self, + ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: if self._setup_joins: return _determine_last_joined_entity( self._setup_joins, @@ -1733,7 +1846,7 @@ class Query( else: return None - def _filter_by_zero(self): + def _filter_by_zero(self) -> Any: """for the filter_by() method, return the target entity for which we will attempt to derive an expression from based on string name. @@ -1800,13 +1913,6 @@ class Query( """ from_entity = self._filter_by_zero() - if from_entity is None: - raise sa_exc.InvalidRequestError( - "Can't use filter_by when the first entity '%s' of a query " - "is not a mapped class. Please use the filter method instead, " - "or change the order of the entities in the query" - % self._query_entity_zero() - ) clauses = [ _entity_namespace_key(from_entity, key) == value @@ -1815,9 +1921,12 @@ class Query( return self.filter(*clauses) @_generative - @_assertions(_no_statement_condition, _no_limit_offset) def order_by( - self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + self: SelfQuery, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionArgument[Any], ) -> SelfQuery: """Apply one or more ORDER BY criteria to the query and return the newly resulting :class:`_query.Query`. @@ -1844,20 +1953,27 @@ class Query( """ - if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False): + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("order_by") + + if not clauses and (__first is None or __first is False): self._order_by_clauses = () - else: + elif __first is not _NoArg.NO_ARG: criterion = tuple( coercions.expect(roles.OrderByRole, clause) - for clause in clauses + for clause in (__first,) + clauses ) self._order_by_clauses += criterion + return self @_generative - @_assertions(_no_statement_condition, _no_limit_offset) def group_by( - self: SelfQuery, *clauses: _ColumnExpressionArgument[Any] + self: SelfQuery, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], _ColumnExpressionArgument[Any] + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionArgument[Any], ) -> SelfQuery: """Apply one or more GROUP BY criterion to the query and return the newly resulting :class:`_query.Query`. @@ -1878,12 +1994,15 @@ class Query( """ - if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False): + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("group_by") + + if not clauses and (__first is None or __first is False): self._group_by_clauses = () - else: + elif __first is not _NoArg.NO_ARG: criterion = tuple( coercions.expect(roles.GroupByRole, clause) - for clause in clauses + for clause in (__first,) + clauses ) self._group_by_clauses += criterion return self @@ -1916,8 +2035,9 @@ class Query( self._having_criteria += (having_criteria,) return self - def _set_op(self, expr_fn, *q): - return self._from_selectable(expr_fn(*([self] + list(q))).subquery()) + def _set_op(self: SelfQuery, expr_fn: Any, *q: Query[Any]) -> SelfQuery: + list_of_queries = (self,) + q + return self._from_selectable(expr_fn(*(list_of_queries)).subquery()) def union(self: SelfQuery, *q: Query[Any]) -> SelfQuery: """Produce a UNION of this Query against one or more queries. @@ -2006,7 +2126,12 @@ class Query( @_generative @_assertions(_no_statement_condition, _no_limit_offset) def join( - self: SelfQuery, target, onclause=None, *, isouter=False, full=False + self: SelfQuery, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, ) -> SelfQuery: r"""Create a SQL JOIN against this :class:`_query.Query` object's criterion @@ -2193,20 +2318,23 @@ class Query( """ - target = coercions.expect( + join_target = coercions.expect( roles.JoinTargetRole, target, apply_propagate_attrs=self, legacy=True, ) if onclause is not None: - onclause = coercions.expect( + onclause_element = coercions.expect( roles.OnClauseRole, onclause, legacy=True ) + else: + onclause_element = None + self._setup_joins += ( ( - target, - onclause, + join_target, + onclause_element, None, { "isouter": isouter, @@ -2218,7 +2346,13 @@ class Query( self.__dict__.pop("_last_joined_entity", None) return self - def outerjoin(self, target, onclause=None, *, full=False): + def outerjoin( + self: SelfQuery, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> SelfQuery: """Create a left outer join against this ``Query`` object's criterion and apply generatively, returning the newly resulting ``Query``. @@ -2295,7 +2429,7 @@ class Query( self._set_select_from(from_obj, False) return self - def __getitem__(self, item): + def __getitem__(self, item: Any) -> Any: return orm_util._getitem( self, item, @@ -2303,7 +2437,11 @@ class Query( @_generative @_assertions(_no_statement_condition) - def slice(self: SelfQuery, start, stop) -> SelfQuery: + def slice( + self: SelfQuery, + start: int, + stop: int, + ) -> SelfQuery: """Computes the "slice" of the :class:`_query.Query` represented by the given indices and returns the resulting :class:`_query.Query`. @@ -2341,7 +2479,9 @@ class Query( @_generative @_assertions(_no_statement_condition) - def limit(self: SelfQuery, limit) -> SelfQuery: + def limit( + self: SelfQuery, limit: Union[int, _ColumnExpressionArgument[int]] + ) -> SelfQuery: """Apply a ``LIMIT`` to the query and return the newly resulting ``Query``. @@ -2351,7 +2491,9 @@ class Query( @_generative @_assertions(_no_statement_condition) - def offset(self: SelfQuery, offset) -> SelfQuery: + def offset( + self: SelfQuery, offset: Union[int, _ColumnExpressionArgument[int]] + ) -> SelfQuery: """Apply an ``OFFSET`` to the query and return the newly resulting ``Query``. @@ -2361,7 +2503,9 @@ class Query( @_generative @_assertions(_no_statement_condition) - def distinct(self: SelfQuery, *expr) -> SelfQuery: + def distinct( + self: SelfQuery, *expr: _ColumnExpressionArgument[Any] + ) -> SelfQuery: r"""Apply a ``DISTINCT`` to the query and return the newly resulting ``Query``. @@ -2415,7 +2559,7 @@ class Query( :ref:`faq_query_deduplicating` """ - return self._iter().all() + return self._iter().all() # type: ignore @_generative @_assertions(_no_clauseelement_condition) @@ -2462,9 +2606,9 @@ class Query( """ # replicates limit(1) behavior if self._statement is not None: - return self._iter().first() + return self._iter().first() # type: ignore else: - return self.limit(1)._iter().first() + return self.limit(1)._iter().first() # type: ignore def one_or_none(self) -> Optional[_T]: """Return at most one result or raise an exception. @@ -2490,7 +2634,7 @@ class Query( :meth:`_query.Query.one` """ - return self._iter().one_or_none() + return self._iter().one_or_none() # type: ignore def one(self) -> _T: """Return exactly one result or raise an exception. @@ -2537,18 +2681,18 @@ class Query( if not isinstance(ret, collections_abc.Sequence): return ret return ret[0] - except orm_exc.NoResultFound: + except sa_exc.NoResultFound: return None def __iter__(self) -> Iterable[_T]: - return self._iter().__iter__() + return self._iter().__iter__() # type: ignore def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: # new style execution. params = self._params statement = self._statement_20() - result = self.session.execute( + result: Union[ScalarResult[_T], Result[_T]] = self.session.execute( statement, params, execution_options={"_sa_orm_load_options": self.load_options}, @@ -2556,7 +2700,7 @@ class Query( # legacy: automatically set scalars, unique if result._attributes.get("is_single_entity", False): - result = result.scalars() + result = cast("Result[_T]", result).scalars() if ( result._attributes.get("filtered", False) @@ -2580,7 +2724,7 @@ class Query( return str(statement.compile(bind)) - def _get_bind_args(self, statement, fn, **kw): + def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any: return fn(clause=statement, **kw) @property @@ -2634,7 +2778,11 @@ class Query( return _column_descriptions(self, legacy=True) - def instances(self, result_proxy: Result, context=None) -> Any: + def instances( + self, + result_proxy: CursorResult[Any], + context: Optional[QueryContext] = None, + ) -> Any: """Return an ORM result given a :class:`_engine.CursorResult` and :class:`.QueryContext`. @@ -2661,7 +2809,7 @@ class Query( # legacy: automatically set scalars, unique if result._attributes.get("is_single_entity", False): - result = result.scalars() + result = result.scalars() # type: ignore if result._attributes.get("filtered", False): result = result.unique() @@ -2675,7 +2823,13 @@ class Query( ":func:`_orm.merge_frozen_result` function.", enable_warnings=False, # warnings occur via loading.merge_result ) - def merge_result(self, iterator, load=True): + def merge_result( + self, + iterator: Union[ + FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object] + ], + load: bool = True, + ) -> Union[FrozenResult[Any], Iterable[Any]]: """Merge a result into this :class:`_query.Query` object's Session. Given an iterator returned by a :class:`_query.Query` @@ -2743,7 +2897,8 @@ class Query( self.enable_eagerloads(False) .add_columns(sql.literal_column("1")) .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) - .statement.with_only_columns(1) + ._get_select_statement_only() + .with_only_columns(1) ) ezero = self._entity_from_pre_ent_zero() @@ -2752,7 +2907,7 @@ class Query( return sql.exists(inner) - def count(self): + def count(self) -> int: r"""Return a count of rows this the SQL formed by this :class:`Query` would return. @@ -2806,9 +2961,11 @@ class Query( """ col = sql.func.count(sql.literal_column("*")) - return self._legacy_from_self(col).enable_eagerloads(False).scalar() + return ( # type: ignore + self._legacy_from_self(col).enable_eagerloads(False).scalar() + ) - def delete(self, synchronize_session="evaluate"): + def delete(self, synchronize_session: str = "evaluate") -> int: r"""Perform a DELETE with an arbitrary WHERE clause. Deletes rows matched by this query from the database. @@ -2850,20 +3007,28 @@ class Query( self = bulk_del.query - delete_ = sql.delete(*self._raw_columns) + delete_ = sql.delete(*self._raw_columns) # type: ignore delete_._where_criteria = self._where_criteria - result = self.session.execute( - delete_, - self._params, - execution_options={"synchronize_session": synchronize_session}, + result: CursorResult[Any] = cast( + "CursorResult[Any]", + self.session.execute( + delete_, + self._params, + execution_options={"synchronize_session": synchronize_session}, + ), ) - bulk_del.result = result + bulk_del.result = result # type: ignore self.session.dispatch.after_bulk_delete(bulk_del) result.close() return result.rowcount - def update(self, values, synchronize_session="evaluate", update_args=None): + def update( + self, + values: Dict[_DMLColumnArgument, Any], + synchronize_session: str = "evaluate", + update_args: Optional[Dict[Any, Any]] = None, + ) -> int: r"""Perform an UPDATE with an arbitrary WHERE clause. Updates rows matched by this query in the database. @@ -2926,28 +3091,33 @@ class Query( bulk_ud.query = new_query self = bulk_ud.query - upd = sql.update(*self._raw_columns) + upd = sql.update(*self._raw_columns) # type: ignore ppo = update_args.pop("preserve_parameter_order", False) if ppo: - upd = upd.ordered_values(*values) + upd = upd.ordered_values(*values) # type: ignore else: upd = upd.values(values) if update_args: upd = upd.with_dialect_options(**update_args) upd._where_criteria = self._where_criteria - result = self.session.execute( - upd, - self._params, - execution_options={"synchronize_session": synchronize_session}, + result: CursorResult[Any] = cast( + "CursorResult[Any]", + self.session.execute( + upd, + self._params, + execution_options={"synchronize_session": synchronize_session}, + ), ) - bulk_ud.result = result + bulk_ud.result = result # type: ignore self.session.dispatch.after_bulk_update(bulk_ud) result.close() return result.rowcount - def _compile_state(self, for_statement=False, **kw): + def _compile_state( + self, for_statement: bool = False, **kw: Any + ) -> ORMCompileState: """Create an out-of-compiler ORMCompileState object. The ORMCompileState object is normally created directly as a result @@ -2971,13 +3141,14 @@ class Query( # ORMSelectCompileState. We could also base this on # query._statement is not None as we have the ORM Query here # however this is the more general path. - compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( - stmt, "orm" + compile_state_cls = cast( + ORMCompileState, + ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), ) return compile_state_cls.create_for_statement(stmt, None) - def _compile_context(self, for_statement=False): + def _compile_context(self, for_statement: bool = False) -> QueryContext: compile_state = self._compile_state(for_statement=for_statement) context = QueryContext( compile_state, @@ -3006,7 +3177,7 @@ class AliasOption(interfaces.LoaderOption): """ - def process_compile_state(self, compile_state: ORMCompileState): + def process_compile_state(self, compile_state: ORMCompileState) -> None: pass @@ -3017,12 +3188,12 @@ class BulkUD: """ - def __init__(self, query): + def __init__(self, query: Query[Any]): self.query = query.enable_eagerloads(False) self._validate_query_state() self.mapper = self.query._entity_from_pre_ent_zero() - def _validate_query_state(self): + def _validate_query_state(self) -> None: for attr, methname, notset, op in ( ("_limit_clause", "limit()", None, operator.is_), ("_offset_clause", "offset()", None, operator.is_), @@ -3049,14 +3220,19 @@ class BulkUD: ) @property - def session(self): + def session(self) -> Session: return self.query.session class BulkUpdate(BulkUD): """BulkUD which handles UPDATEs.""" - def __init__(self, query, values, update_kwargs): + def __init__( + self, + query: Query[Any], + values: Dict[_DMLColumnArgument, Any], + update_kwargs: Optional[Dict[Any, Any]], + ): super(BulkUpdate, self).__init__(query) self.values = values self.update_kwargs = update_kwargs @@ -3067,4 +3243,7 @@ class BulkDelete(BulkUD): class RowReturningQuery(Query[Row[_TP]]): - pass + if TYPE_CHECKING: + + def tuples(self) -> Query[_TP]: # type: ignore + ... |