diff options
23 files changed, 1731 insertions, 100 deletions
diff --git a/doc/build/changelog/unreleased_14/4472.rst b/doc/build/changelog/unreleased_14/4472.rst new file mode 100644 index 000000000..6de5058c1 --- /dev/null +++ b/doc/build/changelog/unreleased_14/4472.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: feature, orm + :tickets: 4472 + + Added the ability to add arbitrary criteria to the ON clause generated + by a relationship attribute in a query, which applies to methods such + as :meth:`_query.Query.join` as well as loader options like + :func:`_orm.joinedload`. Additionally, a "global" version of the option + allows limiting criteria to be applied to particular entities in + a query globally. + + .. seealso:: + + :ref:`loader_option_criteria` + + :func:`_orm.with_loader_criteria` + + .. TODO: add links to new examples section and session-related + documentation involving do_orm_execute event when merged
\ No newline at end of file diff --git a/doc/build/orm/loading_relationships.rst b/doc/build/orm/loading_relationships.rst index 50d3cc51a..8909d9a6e 100644 --- a/doc/build/orm/loading_relationships.rst +++ b/doc/build/orm/loading_relationships.rst @@ -112,13 +112,10 @@ the string name of an attribute against a parent, or for greater specificity can accommodate a class-bound attribute directly:: # set children to load lazily - session.query(Parent).options(lazyload('children')).all() - - # same, using class-bound attribute session.query(Parent).options(lazyload(Parent.children)).all() # set children to load eagerly with a join - session.query(Parent).options(joinedload('children')).all() + session.query(Parent).options(joinedload(Parent.children)).all() The loader options can also be "chained" using **method chaining** to specify how loading should occur further levels deep:: @@ -141,6 +138,48 @@ collections loaded. When the ``children`` collection on a particular objects, but additionally apply eager loading to the ``subelements`` collection on each member of ``children``. +The above examples, using :class:`_orm.Query`, are now referred to as +:term:`1.x style` queries. The options system is available as well for +:term:`2.0 style` queries using the :meth:`_sql.Select.options` method:: + + stmt = select(Parent).options( + lazyload(Parent.children). + subqueryload(Child.subelements)) + + result = session.execute(stmt) + +Under the hood, :class:`_orm.Query` is ultimately using the above +:class:`_sql.select` based mechanism. + + +.. _loader_option_criteria: + +Adding Criteria to loader options +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The relationship attributes used to indicate loader options include the +ability to add additional filtering criteria to the ON clause of the join +that's created, or to the WHERE criteria involved, depending on the loader +strategy. This can be achieved using the :meth:`.PropComparator.and_` +method which will pass through an option such that loaded results are limited +to the given filter criteria:: + + session.query(A).options(lazyload(A.bs.and_(B.id > 5))) + +When using limiting criteria, if a particular collection is already loaded +it won't be refreshed; to ensure the new criteria takes place, apply +the :meth:`_orm.Query.populate_existing` option:: + + session.query(A).options(lazyload(A.bs.and_(B.id > 5))).populate_existing() + +In order to add filtering criteria to all occurrences of an entity throughout +a query, regardless of loader strategy or where it occurs in the loading +process, see the :func:`_orm.with_loader_criteria` function. + +.. versionadded:: 1.4 + +Specifying Sub-Options with Load.options() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Using method chaining, the loader style of each link in the path is explicitly stated. To navigate along a path without changing the existing loader style of a particular attribute, the :func:`.defaultload` method/function may be used:: @@ -1263,6 +1302,7 @@ Relationship Loader API .. autofunction:: lazyload .. autoclass:: Load + :members: .. autofunction:: noload diff --git a/doc/build/orm/query.rst b/doc/build/orm/query.rst index 3fddd6c34..ed45a65e7 100644 --- a/doc/build/orm/query.rst +++ b/doc/build/orm/query.rst @@ -44,6 +44,8 @@ ORM-Specific Query Constructs .. autoclass:: sqlalchemy.orm.strategy_options.Load :members: +.. autofunction:: sqlalchemy.orm.with_loader_criteria + .. autofunction:: join .. autofunction:: outerjoin diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 53545826b..fe9bbaf02 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -207,7 +207,6 @@ class ShardedSession(Session): def execute_and_instances(orm_context): - if orm_context.is_select: load_options = active_options = orm_context.load_options update_options = None @@ -237,8 +236,8 @@ def execute_and_instances(orm_context): if active_options._refresh_identity_token is not None: shard_id = active_options._refresh_identity_token - elif "_sa_shard_id" in orm_context.merged_execution_options: - shard_id = orm_context.merged_execution_options["_sa_shard_id"] + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: shard_id = orm_context.bind_arguments["shard_id"] else: diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 32ec60322..458103838 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -48,6 +48,7 @@ from .strategy_options import Load # noqa from .util import aliased # noqa from .util import Bundle # noqa from .util import join # noqa +from .util import LoaderCriteriaOption # noqa from .util import object_mapper # noqa from .util import outerjoin # noqa from .util import polymorphic_union # noqa @@ -101,6 +102,8 @@ def create_session(bind=None, **kwargs): return Session(bind=bind, **kwargs) +with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm") + relationship = public_factory(RelationshipProperty, ".orm.relationship") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6dd95a5a9..2e1b9dc75 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -50,6 +50,7 @@ from .. import inspection from .. import util from ..sql import base as sql_base from ..sql import roles +from ..sql import traversals from ..sql import visitors @@ -58,6 +59,7 @@ class QueryableAttribute( interfaces._MappedAttribute, interfaces.InspectionAttr, interfaces.PropComparator, + traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, sql_base.Immutable, @@ -91,6 +93,7 @@ class QueryableAttribute( impl=None, comparator=None, of_type=None, + extra_criteria=(), ): self.class_ = class_ self.key = key @@ -98,6 +101,7 @@ class QueryableAttribute( self.impl = impl self.comparator = comparator self._of_type = of_type + self._extra_criteria = extra_criteria manager = manager_of_class(class_) # manager is None in the case of AliasedClass @@ -114,6 +118,7 @@ class QueryableAttribute( ("key", visitors.ExtendedInternalTraversal.dp_string), ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ] def __reduce__(self): @@ -240,6 +245,29 @@ class QueryableAttribute( impl=self.impl, comparator=self.comparator.of_type(entity), of_type=inspection.inspect(entity), + extra_criteria=self._extra_criteria, + ) + + def and_(self, *other): + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator.and_(*other), + of_type=self._of_type, + extra_criteria=self._extra_criteria + other, + ) + + def _clone(self, **kw): + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator, + of_type=self._of_type, + extra_criteria=self._extra_criteria, ) def label(self, name): diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 96725e55b..a35b2f9fd 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -11,9 +11,9 @@ from .base import _is_aliased_class from .interfaces import ORMColumnsClauseRole from .path_registry import PathRegistry from .util import _entity_corresponds_to +from .util import _ORMJoin from .util import aliased from .util import Bundle -from .util import join as orm_join from .util import ORMAdapter from .. import exc as sa_exc from .. import future @@ -78,7 +78,6 @@ class QueryContext(object): _yield_per = None _refresh_state = None _lazy_loaded_from = None - _params = _EMPTY_DICT def __init__( self, @@ -308,6 +307,9 @@ class ORMFromStatementCompileState(ORMCompileState): multi_row_eager_loaders = False compound_eager_adapter = None + extra_criteria_entities = _EMPTY_DICT + eager_joins = _EMPTY_DICT + @classmethod def create_for_statement(cls, statement_container, compiler, **kw): @@ -338,6 +340,7 @@ class ORMFromStatementCompileState(ORMCompileState): if toplevel and statement_container._with_options: self.attributes = {"_unbound_load_dedupes": set()} + self.global_attributes = compiler._global_attributes for opt in statement_container._with_options: if opt._is_compile_state: @@ -345,6 +348,7 @@ class ORMFromStatementCompileState(ORMCompileState): else: self.attributes = {} + self.global_attributes = compiler._global_attributes if statement_container._with_context_options: for fn, key in statement_container._with_context_options: @@ -352,8 +356,6 @@ class ORMFromStatementCompileState(ORMCompileState): self.primary_columns = [] self.secondary_columns = [] - self.eager_joins = {} - self.single_inh_entities = {} self.create_eager_joins = [] self._fallback_from_clauses = [] @@ -423,11 +425,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def create_for_statement(cls, statement, compiler, **kw): """compiler hook, we arrive here from compiler.visit_select() only.""" + self = cls.__new__(cls) + if compiler is not None: toplevel = not compiler.stack compiler._rewrites_selected_columns = True + self.global_attributes = compiler._global_attributes else: toplevel = True + self.global_attributes = {} select_statement = statement @@ -437,8 +443,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement._compile_options ) - self = cls.__new__(cls) - if select_statement._execution_options: # execution options should not impact the compilation of a # query, and at the moment subqueryloader is putting some things @@ -516,7 +520,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.primary_columns = [] self.secondary_columns = [] self.eager_joins = {} - self.single_inh_entities = {} + self.extra_criteria_entities = {} self.create_eager_joins = [] self._fallback_from_clauses = [] @@ -634,7 +638,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): if self.compile_options._enable_single_crit: - self._adjust_for_single_inheritance() + self._adjust_for_extra_criteria() if not self.primary_columns: if self.compile_options._only_load_props: @@ -1408,6 +1412,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState): left, right, onclause, prop, create_aliases, aliased_generation ) + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + if replace_from_obj_index is not None: # splice into an existing element in the # self._from_obj list @@ -1416,12 +1425,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.from_clauses = ( self.from_clauses[:replace_from_obj_index] + [ - orm_join( + _ORMJoin( left_clause, right, onclause, isouter=outerjoin, full=full, + _extra_criteria=extra_criteria, ) ] + self.from_clauses[replace_from_obj_index + 1 :] @@ -1440,8 +1450,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState): left_clause = left self.from_clauses = self.from_clauses + [ - orm_join( - left_clause, r_info, onclause, isouter=outerjoin, full=full + _ORMJoin( + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, ) ] @@ -1848,8 +1863,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState): or kwargs.get("group_by", False) ) - def _adjust_for_single_inheritance(self): - """Apply single-table-inheritance filtering. + def _get_extra_criteria(self, ext_info): + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in self.global_attributes: + return tuple( + ae._resolve_where_criteria(ext_info) + for ae in self.global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if ae.include_aliases or ae.entity is ext_info + ) + else: + return () + + def _adjust_for_extra_criteria(self): + """Apply extra criteria filtering. For all distinct single-table-inheritance mappers represented in the columns clause of this query, as well as the "select from entity", @@ -1857,38 +1887,50 @@ class ORMSelectCompileState(ORMCompileState, SelectState): clause of the given QueryContext such that only the appropriate subtypes are selected from the total results. + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + associated with the global context. + """ for fromclause in self.from_clauses: ext_info = fromclause._annotations.get("parententity", None) if ( ext_info - and ext_info.mapper._single_table_criterion is not None - and ext_info not in self.single_inh_entities + and ( + ext_info.mapper._single_table_criterion is not None + or ("additional_entity_criteria", ext_info.mapper) + in self.global_attributes + ) + and ext_info not in self.extra_criteria_entities ): - self.single_inh_entities[ext_info] = ( + self.extra_criteria_entities[ext_info] = ( ext_info, ext_info._adapter if ext_info.is_aliased_class else None, ) - search = set(self.single_inh_entities.values()) + search = set(self.extra_criteria_entities.values()) for (ext_info, adapter) in search: if ext_info in self._join_entities: continue + single_crit = ext_info.mapper._single_table_criterion + + additional_entity_criteria = self._get_extra_criteria(ext_info) + if single_crit is not None: + additional_entity_criteria += (single_crit,) + + current_adapter = self._get_current_adapter() + for crit in additional_entity_criteria: if adapter: - single_crit = adapter.traverse(single_crit) + crit = adapter.traverse(crit) - current_adapter = self._get_current_adapter() if current_adapter: - single_crit = sql_util._deep_annotate( - single_crit, {"_orm_adapt": True} - ) - single_crit = current_adapter(single_crit, False) - self._where_criteria += (single_crit,) + crit = sql_util._deep_annotate(crit, {"_orm_adapt": True}) + crit = current_adapter(crit, False) + self._where_criteria += (crit,) def _column_descriptions(query_or_select_stmt, compile_state=None): @@ -2205,9 +2247,13 @@ class _MapperEntity(_QueryEntity): adapter = self._get_entity_clauses(compile_state) single_table_crit = self.mapper._single_table_criterion - if single_table_crit is not None: + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): ext_info = self.entity_zero - compile_state.single_inh_entities[ext_info] = ( + compile_state.extra_criteria_entities[ext_info] = ( ext_info, ext_info._adapter if ext_info.is_aliased_class else None, ) @@ -2528,8 +2574,13 @@ class _ORMColumnEntity(_ColumnEntity): ezero = self.entity_zero single_table_crit = self.mapper._single_table_criterion - if single_table_crit is not None: - compile_state.single_inh_entities[ezero] = ( + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): + + compile_state.extra_criteria_entities[ezero] = ( ezero, ezero._adapter if ezero.is_aliased_class else None, ) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 4cf820ae3..068c85073 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -480,6 +480,32 @@ class PropComparator(operators.ColumnOperators): return self.operate(PropComparator.of_type_op, class_) + def and_(self, *criteria): + """Add additional criteria to the ON clause that's represented by this + relationship attribute. + + E.g.:: + + + stmt = select(User).join( + User.addresses.and_(Address.email_address != 'foo') + ) + + stmt = select(User).options( + joinedload(User.addresses.and_(Address.email_address != 'foo')) + ) + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`loader_option_criteria` + + :func:`.with_loader_criteria` + + """ + return self.operate(operators.and_, *criteria) + def any(self, criterion=None, **kwargs): r"""Return true if this collection contains any member that meets the given criterion. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d60c03bdc..68ca0365b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1997,6 +1997,20 @@ class Query( filter(a1.email_address == 'ed@foo.com').\ filter(a2.email_address == 'ed@bar.com') + **Augmenting Built-in ON Clauses** + + As a substitute for providing a full custom ON condition for an + existing relationship, the :meth:`_orm.PropComparator.and_` function + may be applied to a relationship attribute to augment additional + criteria into the ON clause; the additional criteria will be combined + with the default criteria using AND:: + + q = session.query(User).join( + User.addresses.and_(Address.email_address != 'foo@bar.com') + ) + + .. versionadded:: 1.4 + **Joining to Tables and Subqueries** diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index cb490b7d7..794b9422c 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1115,9 +1115,15 @@ class RelationshipProperty(StrategizedProperty): """ _of_type = None + _extra_criteria = () def __init__( - self, prop, parentmapper, adapt_to_entity=None, of_type=None + self, + prop, + parentmapper, + adapt_to_entity=None, + of_type=None, + extra_criteria=(), ): """Construction of :class:`.RelationshipProperty.Comparator` is internal to the ORM's attribute mechanics. @@ -1128,6 +1134,7 @@ class RelationshipProperty(StrategizedProperty): self._adapt_to_entity = adapt_to_entity if of_type: self._of_type = of_type + self._extra_criteria = extra_criteria def adapt_to_entity(self, adapt_to_entity): return self.__class__( @@ -1191,6 +1198,7 @@ class RelationshipProperty(StrategizedProperty): source_polymorphic=True, of_type_entity=of_type_entity, alias_secondary=True, + extra_criteria=self._extra_criteria, ) if sj is not None: return pj & sj @@ -1202,12 +1210,30 @@ class RelationshipProperty(StrategizedProperty): See :meth:`.PropComparator.of_type` for an example. + """ return RelationshipProperty.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, of_type=cls, + extra_criteria=self._extra_criteria, + ) + + def and_(self, *other): + """Add AND criteria. + + See :meth:`.PropComparator.and_` for an example. + + .. versionadded:: 1.4 + + """ + return RelationshipProperty.Comparator( + self.property, + self._parententity, + adapt_to_entity=self._adapt_to_entity, + of_type=self._of_type, + extra_criteria=self._extra_criteria + other, ) def in_(self, other): @@ -2439,6 +2465,7 @@ class RelationshipProperty(StrategizedProperty): dest_selectable=None, of_type_entity=None, alias_secondary=False, + extra_criteria=(), ): aliased = False @@ -2489,7 +2516,11 @@ class RelationshipProperty(StrategizedProperty): target_adapter, dest_selectable, ) = self._join_condition.join_targets( - source_selectable, dest_selectable, aliased, single_crit + source_selectable, + dest_selectable, + aliased, + single_crit, + extra_criteria, ) if source_selectable is None: source_selectable = self.parent.local_table @@ -3427,7 +3458,12 @@ class JoinCondition(object): ) def join_targets( - self, source_selectable, dest_selectable, aliased, single_crit=None + self, + source_selectable, + dest_selectable, + aliased, + single_crit=None, + extra_criteria=(), ): """Given a source and destination selectable, create a join between them. @@ -3463,6 +3499,12 @@ class JoinCondition(object): else: primaryjoin = primaryjoin & single_crit + if extra_criteria: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & sql.and_(*extra_criteria) + else: + primaryjoin = primaryjoin & sql.and_(*extra_criteria) + if aliased: if secondary is not None: secondary = secondary._anonymous_fromclause(flat=True) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 339c57bdc..e9d4ac2c6 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -102,23 +102,26 @@ CLOSED = util.symbol("CLOSED") class ORMExecuteState(util.MemoizedSlots): - """Stateful object used for the :meth:`.SessionEvents.do_orm_execute` + """Represents a call to the :meth:`_orm.Session.execute` method, as passed + to the :meth:`.SessionEvents.do_orm_execute` event hook. .. versionadded:: 1.4 + """ __slots__ = ( "session", "statement", "parameters", - "_execution_options", - "_merged_execution_options", + "execution_options", + "local_execution_options", "bind_arguments", "_compile_state_cls", "_starting_event_idx", "_events_todo", "_future", + "_update_execution_options", ) def __init__( @@ -135,7 +138,10 @@ class ORMExecuteState(util.MemoizedSlots): self.session = session self.statement = statement self.parameters = parameters - self._execution_options = execution_options + self.local_execution_options = execution_options + self.execution_options = statement._execution_options.union( + execution_options + ) self.bind_arguments = bind_arguments self._compile_state_cls = compile_state_cls self._events_todo = list(events_todo) @@ -182,9 +188,8 @@ class ORMExecuteState(util.MemoizedSlots): .. seealso:: - :ref:`examples_caching` - includes example use of the - :meth:`.SessionEvents.do_orm_execute` hook as well as the - :meth:`.ORMExecuteState.invoke_query` method. + :ref:`do_orm_execute_re_executing` - background and examples on the + appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`. """ @@ -203,11 +208,9 @@ class ORMExecuteState(util.MemoizedSlots): else: _params = self.parameters + _execution_options = self.local_execution_options if execution_options: - _execution_options = dict(self._execution_options) - _execution_options.update(execution_options) - else: - _execution_options = self._execution_options + _execution_options = _execution_options.union(execution_options) return self.session.execute( statement, @@ -255,42 +258,9 @@ class ORMExecuteState(util.MemoizedSlots): def _is_crud(self): return isinstance(self.statement, (dml.Update, dml.Delete)) - @property - def execution_options(self): - """Placeholder for execution options. - - Raises an informative message, as there are local options - vs. merged options that can be viewed, via the - :attr:`.ORMExecuteState.local_execution_options` and - :attr:`.ORMExecuteState.merged_execution_options` methods. - - - """ - raise AttributeError( - "Please use .local_execution_options or " - ".merged_execution_options" - ) - - @property - def local_execution_options(self): - """Dictionary view of the execution options passed to the - :meth:`.Session.execute` method. This does not include options - that may be associated with the statement being invoked. - - """ - return util.immutabledict(self._execution_options) - - @property - def merged_execution_options(self): - """Dictionary view of all execution options merged together; - this includes those of the statement as well as those passed to - :meth:`.Session.execute`, with the local options taking precedence. - - """ - return self._merged_execution_options - - def _memoized_attr__merged_execution_options(self): - return self.statement._execution_options.union(self._execution_options) + def update_execution_options(self, **opts): + # TODO: no coverage + self.local_execution_options = self.local_execution_options.union(opts) def _orm_compile_options(self): opts = self.statement._compile_options @@ -329,6 +299,20 @@ class ORMExecuteState(util.MemoizedSlots): return None @property + def is_relationship_load(self): + """Return True if this load is loading objects on behalf of a + relationship. + + This means, the loader in effect is either a LazyLoader, + SelectInLoader, SubqueryLoader, or similar, and the entire + SELECT statement being emitted is on behalf of a relationship + load. + + """ + path = self.loader_strategy_path + return path is not None and not path.is_root + + @property def load_options(self): """Return the load_options that will be used for this execution.""" @@ -337,7 +321,7 @@ class ORMExecuteState(util.MemoizedSlots): "This ORM execution is not against a SELECT statement " "so there are no load options." ) - return self._execution_options.get( + return self.execution_options.get( "_sa_orm_load_options", context.QueryContext.default_load_options ) @@ -351,7 +335,7 @@ class ORMExecuteState(util.MemoizedSlots): "This ORM execution is not against an UPDATE or DELETE " "statement so there are no update options." ) - return self._execution_options.get( + return self.execution_options.get( "_sa_orm_update_options", persistence.BulkUDCompileState.default_update_options, ) @@ -1003,8 +987,6 @@ class Session(_SessionClassMethods): :ref:`migration_20_toplevel` - :ref:`migration_20_result_rows` - :param info: optional dictionary of arbitrary data to be associated with this :class:`.Session`. Is available via the :attr:`.Session.info` attribute. Note the dictionary is copied at @@ -1282,7 +1264,7 @@ class Session(_SessionClassMethods): the operation will release the current SAVEPOINT but not commit the outermost database transaction. - If :term:`2.x-style` use is in effect via the + If :term:`2.0-style` use is in effect via the :paramref:`_orm.Session.future` flag, the outermost database transaction is committed unconditionally, automatically releasing any SAVEPOINTs in effect. @@ -1416,7 +1398,7 @@ class Session(_SessionClassMethods): self, statement, params=None, - execution_options=util.immutabledict(), + execution_options=util.EMPTY_DICT, bind_arguments=None, future=False, _parent_execute_state=None, @@ -1576,6 +1558,8 @@ class Session(_SessionClassMethods): else: compile_state_cls = None + execution_options = util.coerce_to_immutabledict(execution_options) + if compile_state_cls is not None: ( statement, @@ -1591,8 +1575,11 @@ class Session(_SessionClassMethods): else: bind_arguments.setdefault("clause", statement) if future: - execution_options = util.immutabledict().merge_with( - execution_options, {"future_result": True} + # not sure if immutabledict is working w/ this syntax + # execution_options = + # execution_options.union(future_result=True) + execution_options = execution_options.union( + {"future_result": True} ) if _parent_execute_state: @@ -1619,6 +1606,10 @@ class Session(_SessionClassMethods): if result: return result + # TODO: coverage for this pattern + statement = orm_exec_state.statement + execution_options = orm_exec_state.local_execution_options + bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind, close_with_result=True) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 44f303fee..53166bd91 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1975,6 +1975,7 @@ class JoinedLoader(AbstractRelationshipLoader): clauses, innerjoin, chained_from_outerjoin, + loadopt._extra_criteria if loadopt else (), ) ) @@ -1993,6 +1994,7 @@ class JoinedLoader(AbstractRelationshipLoader): clauses, innerjoin, chained_from_outerjoin, + extra_criteria, ): if parentmapper is None: localparent = query_entity.mapper @@ -2081,6 +2083,17 @@ class JoinedLoader(AbstractRelationshipLoader): or query_entity.entity_zero.represents_outer_join ) + extra_join_criteria = extra_criteria + additional_entity_criteria = compile_state.global_attributes.get( + ("additional_entity_criteria", self.mapper), () + ) + if additional_entity_criteria: + extra_join_criteria += tuple( + ae._resolve_where_criteria(self.mapper) + for ae in additional_entity_criteria + if ae.propagate_to_loaders + ) + if attach_on_outside: # this is the "classic" eager join case. eagerjoin = orm_util._ORMJoin( @@ -2092,11 +2105,12 @@ class JoinedLoader(AbstractRelationshipLoader): or (chained_from_outerjoin and isinstance(towrap, sql.Join)), _left_memo=self.parent, _right_memo=self.mapper, + _extra_criteria=extra_join_criteria, ) else: # all other cases are innerjoin=='nested' approach eagerjoin = self._splice_nested_inner_join( - path, towrap, clauses, onclause, + path, towrap, clauses, onclause, extra_join_criteria ) compile_state.eager_joins[query_entity_key] = eagerjoin @@ -2128,7 +2142,7 @@ class JoinedLoader(AbstractRelationshipLoader): ) def _splice_nested_inner_join( - self, path, join_obj, clauses, onclause, splicing=False + self, path, join_obj, clauses, onclause, extra_criteria, splicing=False ): if splicing is False: @@ -2137,7 +2151,12 @@ class JoinedLoader(AbstractRelationshipLoader): assert isinstance(join_obj, orm_util._ORMJoin) elif isinstance(join_obj, sql.selectable.FromGrouping): return self._splice_nested_inner_join( - path, join_obj.element, clauses, onclause, splicing, + path, + join_obj.element, + clauses, + onclause, + extra_criteria, + splicing, ) elif not isinstance(join_obj, orm_util._ORMJoin): if path[-2] is splicing: @@ -2148,18 +2167,29 @@ class JoinedLoader(AbstractRelationshipLoader): isouter=False, _left_memo=splicing, _right_memo=path[-1].mapper, + _extra_criteria=extra_criteria, ) else: # only here if splicing == True return None target_join = self._splice_nested_inner_join( - path, join_obj.right, clauses, onclause, join_obj._right_memo, + path, + join_obj.right, + clauses, + onclause, + extra_criteria, + join_obj._right_memo, ) if target_join is None: right_splice = False target_join = self._splice_nested_inner_join( - path, join_obj.left, clauses, onclause, join_obj._left_memo, + path, + join_obj.left, + clauses, + onclause, + extra_criteria, + join_obj._left_memo, ) if target_join is None: # should only return None when recursively called, diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index b405153b9..b3913ec5b 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -78,6 +78,7 @@ class Load(Generative, LoaderOption): ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ( "_context_cache_key", visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples, @@ -101,6 +102,7 @@ class Load(Generative, LoaderOption): load.context = {} load.local_opts = {} load._of_type = None + load._extra_criteria = () return load @property @@ -124,6 +126,7 @@ class Load(Generative, LoaderOption): strategy = None propagate_to_loaders = False _of_type = None + _extra_criteria = () def process_compile_state(self, compile_state): if not compile_state.compile_options._enable_eagerloads: @@ -248,6 +251,9 @@ class Load(Generative, LoaderOption): else: return None + if attr._extra_criteria: + self._extra_criteria = attr._extra_criteria + if getattr(attr, "_of_type", None): ac = attr._of_type ext_info = of_type_info = inspect(ac) @@ -356,6 +362,7 @@ class Load(Generative, LoaderOption): cloned = self._clone_for_bind_strategy(attr, strategy, "relationship") self.path = cloned.path self._of_type = cloned._of_type + self._extra_criteria = cloned._extra_criteria cloned.is_class_strategy = self.is_class_strategy = False self.propagate_to_loaders = cloned.propagate_to_loaders @@ -413,6 +420,7 @@ class Load(Generative, LoaderOption): if existing: if merge_opts: existing.local_opts.update(self.local_opts) + existing._extra_criteria += self._extra_criteria else: path.set(context, "loader", self) else: @@ -420,6 +428,7 @@ class Load(Generative, LoaderOption): path.set(context, "loader", self) if existing and existing.is_opts_only: self.local_opts.update(existing.local_opts) + existing._extra_criteria += self._extra_criteria def _set_path_strategy(self): if not self.is_class_strategy and self.path.has_entity: @@ -507,11 +516,13 @@ class _UnboundLoad(Load): self.path = () self._to_bind = [] self.local_opts = {} + self._extra_criteria = () _cache_key_traversal = [ ("path", visitors.ExtendedInternalTraversal.dp_multi_list), ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), ] @@ -576,6 +587,7 @@ class _UnboundLoad(Load): if attr: path = path + (attr,) self.path = path + self._extra_criteria = getattr(attr, "_extra_criteria", ()) return path diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 71ee29597..82fad0815 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -23,6 +23,7 @@ from .base import object_state # noqa from .base import state_attribute_str # noqa from .base import state_class_str # noqa from .base import state_str # noqa +from .interfaces import LoaderOption from .interfaces import MapperProperty # noqa from .interfaces import ORMColumnsClauseRole from .interfaces import ORMEntityColumnsClauseRole @@ -38,6 +39,7 @@ from ..engine.result import result_tuple from ..sql import base as sql_base from ..sql import coercions from ..sql import expression +from ..sql import lambdas from ..sql import roles from ..sql import util as sql_util from ..sql import visitors @@ -854,6 +856,184 @@ class AliasedInsp( return "aliased(%s)" % (self._target.__name__,) +class LoaderCriteriaOption(LoaderOption): + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + :class:`_orm.LoaderCriteriaOption` is invoked using the + :func:`_orm.with_loader_criteria` function; see that function for + details. + + .. versionadded:: 1.4 + + """ + + _traverse_internals = [ + ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("where_criteria", visitors.InternalTraversal.dp_clauseelement), + ("include_aliases", visitors.InternalTraversal.dp_boolean), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), + ] + + def __init__( + self, + entity_or_base, + where_criteria, + loader_only=False, + include_aliases=False, + propagate_to_loaders=True, + ): + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + .. versionadded:: 1.4 + + The :func:`_orm.with_loader_criteria` option is intended to add + limiting criteria to a particular kind of entity in a query, + **globally**, meaning it will apply to the entity as it appears + in the SELECT query as well as within any subqueries, join + conditions, and relationship loads, including both eager and lazy + loaders, without the need for it to be specified in any particular + part of the query. The rendering logic uses the same system used by + single table inheritance to ensure a certain discriminator is applied + to a table. + + E.g., using :term:`2.0-style` queries, we can limit the way the + ``User.addresses`` collection is loaded, regardless of the kind + of loading used:: + + from sqlalchemy.orm import with_loader_criteria + + stmt = select(User).options( + selectinload(User.addresses), + with_loader_criteria(Address, Address.email_address != 'foo')) + ) + + Above, the "selectinload" for ``User.addresses`` will apply the + given filtering criteria to the WHERE clause. + + Another example, where the filtering will be applied to the + ON clause of the join, in this example using :term:`1.x style` + queries:: + + q = session.query(User).outerjoin(User.addresses).options( + with_loader_criteria(Address, Address.email_address != 'foo')) + ) + + The primary purpose of :func:`_orm.with_loader_criteria` is to use + it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler + to ensure that all occurrences of a particular entity are filtered + in a certain way, such as filtering for access control roles. It + also can be used to apply criteria to relationship loads. In the + example below, we can apply a certain set of rules to all queries + emitted by a particular :class:`_orm.Session`:: + + session = Session(bind=engine) + + @event.listens_for("do_orm_execute", session) + def _add_filtering_criteria(execute_state): + execute_state.statement = execute_state.statement.options( + with_loader_criteria( + SecurityRole, + lambda cls: cls.role.in_(['some_role']), + include_aliases=True + ) + ) + + The given class will expand to include all mapped subclass and + need not itself be a mapped class. + + + :param entity_or_base: a mapped class, or a class that is a super + class of a particular set of mapped classes, to which the rule + will apply. + + :param where_criteria: a Core SQL expression that applies limiting + criteria. This may also be a "lambda:" or Python function that + accepts a target class as an argument, when the given class is + a base with many different mapped subclasses. + + :param include_aliases: if True, apply the rule to :func:`_orm.aliased` + constructs as well. + + :param propagate_to_loaders: defaults to True, apply to relationship + loaders such as lazy loaders. + + + .. seealso:: + + :ref:`examples_session_orm_events` - includes examples of using + :func:`_orm.with_loader_criteria`. + + :ref:`do_orm_execute_global_criteria` - basic example on how to + combine :func:`_orm.with_loader_criteria` with the + :meth:`_orm.SessionEvents.do_orm_execute` event. + + """ + entity = inspection.inspect(entity_or_base, False) + if entity is None: + self.root_entity = entity_or_base + self.entity = None + else: + self.root_entity = None + self.entity = entity + + if callable(where_criteria): + self.deferred_where_criteria = True + self.where_criteria = lambdas.DeferredLambdaElement( + where_criteria, + roles.WhereHavingRole, + lambda_args=( + self.root_entity + if self.root_entity is not None + else self.entity.entity, + ), + ) + else: + self.deferred_where_criteria = False + self.where_criteria = coercions.expect( + roles.WhereHavingRole, where_criteria + ) + + self.include_aliases = include_aliases + self.propagate_to_loaders = propagate_to_loaders + + def _all_mappers(self): + if self.entity: + for ent in self.entity.mapper.self_and_descendants: + yield ent + else: + stack = list(self.root_entity.__subclasses__()) + while stack: + subclass = stack.pop(0) + ent = inspection.inspect(subclass) + if ent: + for mp in ent.mapper.self_and_descendants: + yield mp + else: + stack.extend(subclass.__subclasses__()) + + def _resolve_where_criteria(self, ext_info): + if self.deferred_where_criteria: + return self.where_criteria._resolve_with_args(ext_info.entity) + else: + return self.where_criteria + + def process_compile_state(self, compile_state): + """Apply a modification to a given :class:`.CompileState`.""" + + # if options to limit the criteria to immediate query only, + # use compile_state.attributes instead + + for mp in self._all_mappers(): + load_criteria = compile_state.global_attributes.setdefault( + ("additional_entity_criteria", mp), [] + ) + + load_criteria.append(self) + + inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) inspection._inspects(AliasedInsp)(lambda target: target) @@ -1270,6 +1450,7 @@ class _ORMJoin(expression.Join): full=False, _left_memo=None, _right_memo=None, + _extra_criteria=(), ): left_info = inspection.inspect(left) @@ -1291,6 +1472,7 @@ class _ORMJoin(expression.Join): if isinstance(onclause, attributes.QueryableAttribute): on_selectable = onclause.comparator._source_selectable() prop = onclause.property + _extra_criteria += onclause._extra_criteria elif isinstance(onclause, MapperProperty): # used internally by joined eager loader...possibly not ideal prop = onclause @@ -1319,6 +1501,7 @@ class _ORMJoin(expression.Join): source_polymorphic=True, of_type_entity=right_info, alias_secondary=True, + extra_criteria=_extra_criteria, ) if sj is not None: @@ -1331,6 +1514,7 @@ class _ORMJoin(expression.Join): onclause = sj else: onclause = pj + self._target_adapter = target_adapter expression.Join.__init__(self, left, right, onclause, isouter, full) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ac4055bdf..b8984316c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -792,6 +792,10 @@ class SQLCompiler(Compiled): def prefetch(self): return list(self.insert_prefetch + self.update_prefetch) + @util.memoized_property + def _global_attributes(self): + return {} + @util.memoized_instancemethod def _init_cte_state(self): """Initialize collections related to CTEs only if diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index 6279dcf55..c8e83bbd7 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -1017,8 +1017,8 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest): if ckey: break else: - if "_cache_key" in orm_context.merged_execution_options: - ckey = orm_context.merged_execution_options["_cache_key"] + if "_cache_key" in orm_context.execution_options: + ckey = orm_context.execution_options["_cache_key"] if ckey is not None: return get_value( diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index e33e95cc0..86e0bd360 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -1302,6 +1302,28 @@ class _PolymorphicTestBase(object): [e1, e3], ) + def test_join_and_thru_polymorphic_nonaliased_one(self): + sess = create_session() + eq_( + sess.query(Company) + .join(Company.employees) + .join(Person.paperwork.and_(Paperwork.description.like("%#2%"))) + .all(), + [c1], + ) + + def test_join_and_thru_polymorphic_aliased_one(self): + sess = create_session() + ea = aliased(Person) + pa = aliased(Paperwork) + eq_( + sess.query(Company) + .join(ea, Company.employees) + .join(pa, ea.paperwork.and_(pa.description.like("%#2%"))) + .all(), + [c1], + ) + def test_join_through_polymorphic_nonaliased_one(self): sess = create_session() eq_( diff --git a/test/orm/test_bundle.py b/test/orm/test_bundle.py index f4af84094..9d1d0b61b 100644 --- a/test/orm/test_bundle.py +++ b/test/orm/test_bundle.py @@ -3,6 +3,7 @@ from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String +from sqlalchemy import testing from sqlalchemy.orm import aliased from sqlalchemy.orm import Bundle from sqlalchemy.orm import mapper @@ -186,6 +187,35 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): ], ) + def test_multi_bundle_future(self): + Data = self.classes.Data + Other = self.classes.Other + + d1 = aliased(Data) + + b1 = Bundle("b1", d1.d1, d1.d2) + b2 = Bundle("b2", Data.d1, Other.o1) + + sess = Session(testing.db, future=True) + + stmt = ( + select(b1, b2) + .join(Data.others) + .join(d1, d1.id == Data.id) + .filter(b1.c.d1 == "d3d1") + ) + + eq_( + sess.execute(stmt).all(), + [ + (("d3d1", "d3d2"), ("d3d1", "d3o0")), + (("d3d1", "d3d2"), ("d3d1", "d3o1")), + (("d3d1", "d3d2"), ("d3d1", "d3o2")), + (("d3d1", "d3d2"), ("d3d1", "d3o3")), + (("d3d1", "d3d2"), ("d3d1", "d3o4")), + ], + ) + def test_single_entity(self): Data = self.classes.Data sess = Session() @@ -197,6 +227,18 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL): [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")], ) + def test_single_entity_future(self): + Data = self.classes.Data + sess = Session(testing.db, future=True) + + b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True) + + stmt = select(b1).filter(b1.c.d1.between("d3d1", "d5d1")) + eq_( + sess.execute(stmt).scalars().all(), + [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")], + ) + def test_single_entity_flag_but_multi_entities(self): Data = self.classes.Data sess = Session() diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py index 02b1b9fbf..45a60a5cb 100644 --- a/test/orm/test_cache_key.py +++ b/test/orm/test_cache_key.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload +from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql.base import CacheableOptions from sqlalchemy.sql.visitors import InternalTraversal @@ -65,6 +66,62 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): compare_values=True, ) + def test_loader_criteria(self): + User, Address = self.classes("User", "Address") + + from sqlalchemy import Column, Integer, String + + class Foo(object): + id = Column(Integer) + name = Column(String) + + self._run_cache_key_fixture( + lambda: ( + with_loader_criteria(User, User.name != "somename"), + with_loader_criteria(User, User.id != 5), + with_loader_criteria(User, lambda cls: cls.id == 10), + with_loader_criteria(Address, Address.id != 5), + with_loader_criteria(Foo, lambda cls: cls.id == 10), + ), + compare_values=True, + ) + + def test_loader_criteria_bound_param_thing(self): + from sqlalchemy import Column, Integer + + class Foo(object): + id = Column(Integer) + + def go(param): + return with_loader_criteria(Foo, lambda cls: cls.id == param) + + g1 = go(10) + g2 = go(20) + + ck1 = g1._generate_cache_key() + ck2 = g2._generate_cache_key() + + eq_(ck1.key, ck2.key) + eq_(ck1.bindparams[0].key, ck2.bindparams[0].key) + eq_(ck1.bindparams[0].value, 10) + eq_(ck2.bindparams[0].value, 20) + + def test_instrumented_attributes(self): + User, Address, Keyword, Order, Item = self.classes( + "User", "Address", "Keyword", "Order", "Item" + ) + + self._run_cache_key_fixture( + lambda: ( + User.addresses, + User.addresses.of_type(aliased(Address)), + User.orders, + User.orders.and_(Order.id != 5), + User.orders.and_(Order.description != "somename"), + ), + compare_values=True, + ) + def test_unbound_options(self): User, Address, Keyword, Order, Item = self.classes( "User", "Address", "Keyword", "Order", "Item" @@ -75,6 +132,10 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): joinedload(User.addresses), joinedload(User.addresses.of_type(aliased(Address))), joinedload("addresses"), + joinedload(User.orders), + joinedload(User.orders.and_(Order.id != 5)), + joinedload(User.orders.and_(Order.id == 5)), + joinedload(User.orders.and_(Order.description != "somename")), joinedload(User.orders).selectinload("items"), joinedload(User.orders).selectinload(Order.items), defer(User.id), @@ -110,6 +171,10 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): User.addresses.of_type(aliased(Address)) ), Load(User).joinedload(User.orders), + Load(User).joinedload(User.orders.and_(Order.id != 5)), + Load(User).joinedload( + User.orders.and_(Order.description != "somename") + ), Load(User).defer(User.id), Load(User).subqueryload("addresses"), Load(Address).defer("id"), @@ -169,6 +234,9 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest): select(User).join(Address, User.addresses), select(User).join(a1, User.addresses), select(User).join(User.addresses.of_type(a1)), + select(User).join( + User.addresses.and_(Address.email_address == "foo") + ), select(User) .join(Address, User.addresses) .join_from(User, Order), diff --git a/test/orm/test_events.py b/test/orm/test_events.py index b68e0d2e6..df48cfe63 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -2,6 +2,8 @@ import sqlalchemy as sa from sqlalchemy import event from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import literal_column +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.declarative import declarative_base @@ -47,6 +49,170 @@ class _RemoveListeners(object): super(_RemoveListeners, self).teardown() +class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): + run_setup_mappers = "once" + run_inserts = "once" + run_deletes = None + + @classmethod + def setup_mappers(cls): + cls._setup_stock_mapping() + + def _caching_session_fixture(self): + + cache = {} + + maker = sessionmaker(testing.db, future=True) + + def get_value(cache_key, cache, createfunc): + if cache_key in cache: + return cache[cache_key]() + else: + cache[cache_key] = retval = createfunc().freeze() + return retval() + + @event.listens_for(maker, "do_orm_execute", retval=True) + def do_orm_execute(orm_context): + ckey = None + for opt in orm_context.user_defined_options: + ckey = opt.get_cache_key(orm_context) + if ckey: + break + else: + if "cache_key" in orm_context.execution_options: + ckey = orm_context.execution_options["cache_key"] + + if ckey is not None: + return get_value(ckey, cache, orm_context.invoke_statement,) + + return maker() + + def test_cache_option(self): + User, Address = self.classes("User", "Address") + + with self.sql_execution_asserter(testing.db) as asserter: + + with self._caching_session_fixture() as session: + stmt = ( + select(User) + .where(User.id == 7) + .execution_options(cache_key="user7") + ) + + result = session.execute(stmt) + + eq_( + result.scalars().all(), + [User(id=7, addresses=[Address(id=1)])], + ) + + result = session.execute(stmt) + + eq_( + result.scalars().all(), + [User(id=7, addresses=[Address(id=1)])], + ) + + asserter.assert_( + CompiledSQL( + "SELECT users.id, users.name FROM users " + "WHERE users.id = :id_1", + [{"id_1": 7}], + ), + CompiledSQL( + "SELECT addresses.id AS addresses_id, addresses.user_id AS " + "addresses_user_id, " + "addresses.email_address AS addresses_email_address " + "FROM addresses WHERE :param_1 = addresses.user_id " + "ORDER BY addresses.id", + [{"param_1": 7}], + ), + ) + + def test_chained_events_one(self): + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + ctx.update_execution_options(one=True) + + @event.listens_for(sess, "do_orm_execute") + def two(ctx): + ctx.update_execution_options(two=True) + + @event.listens_for(sess, "do_orm_execute") + def three(ctx): + ctx.update_execution_options(three=True) + + @event.listens_for(sess, "do_orm_execute") + def four(ctx): + ctx.update_execution_options(four=True) + + result = sess.execute(select(literal_column("1"))) + + eq_( + result.context.execution_options, + { + "four": True, + "future_result": True, + "one": True, + "three": True, + "two": True, + }, + ) + + def test_chained_events_two(self): + + sess = Session(testing.db, future=True) + + def added(ctx): + ctx.update_execution_options(added_evt=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + ctx.update_execution_options(one=True) + + @event.listens_for(sess, "do_orm_execute", retval=True) + def two(ctx): + ctx.update_execution_options(two=True) + return ctx.invoke_statement( + statement=ctx.statement.execution_options(statement_two=True) + ) + + @event.listens_for(sess, "do_orm_execute") + def three(ctx): + ctx.update_execution_options(three=True) + + @event.listens_for(sess, "do_orm_execute") + def four(ctx): + ctx.update_execution_options(four=True) + return ctx.invoke_statement( + statement=ctx.statement.execution_options(statement_four=True) + ) + + @event.listens_for(sess, "do_orm_execute") + def five(ctx): + ctx.update_execution_options(five=True) + + result = sess.execute(select(literal_column("1")), _add_event=added) + + eq_( + result.context.execution_options, + { + "statement_two": True, + "statement_four": True, + "future_result": True, + "one": True, + "two": True, + "three": True, + "four": True, + "five": True, + "added_evt": True, + }, + ) + + class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): run_inserts = None diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 208db9d85..b5a6e3b29 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1391,6 +1391,7 @@ class PickleTest(PathTest, QueryTest): "propagate_to_loaders": True, "_of_type": None, "_to_bind": to_bind, + "_extra_criteria": (), }, ) diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py new file mode 100644 index 000000000..c4bcf0404 --- /dev/null +++ b/test/orm/test_relationship_criteria.py @@ -0,0 +1,867 @@ +import datetime +import random + +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import event +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import orm +from sqlalchemy import select +from sqlalchemy import sql +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy.orm import aliased +from sqlalchemy.orm import joinedload +from sqlalchemy.orm import mapper +from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload +from sqlalchemy.orm import Session +from sqlalchemy.orm import with_loader_criteria +from sqlalchemy.testing import eq_ +from sqlalchemy.testing.assertsql import CompiledSQL +from test.orm import _fixtures + + +class _Fixtures(_fixtures.FixtureTest): + @testing.fixture + def user_address_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) + return User, Address + + @testing.fixture + def order_item_fixture(self): + Order, Item = self.classes("Order", "Item") + orders, items, order_items = self.tables( + "orders", "items", "order_items" + ) + + mapper( + Order, + orders, + properties={ + # m2m + "items": relationship( + Item, secondary=order_items, order_by=items.c.id + ), + }, + ) + mapper(Item, items) + + return Order, Item + + @testing.fixture + def mixin_fixture(self): + users = self.tables.users + + class HasFoob(object): + name = Column(String) + + class UserWFoob(HasFoob, self.Comparable): + pass + + mapper( + UserWFoob, users, + ) + return HasFoob, UserWFoob + + +class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): + """ + combinations: + + + with_loader_criteria + # for these we have mapper_criteria + + select(mapper) # select_mapper + select(mapper.col, mapper.col) # select_mapper_col + select(func.count()).select_from(mapper) # select_from_mapper + select(a).join(mapper, a.target) # select_join_mapper + select(a).options(joinedload(a.target)) # select_joinedload_mapper + + + # for these we have aliased_criteria, inclaliased_criteria + + select(aliased) # select_aliased + select(aliased.col, aliased.col) # select_aliased_col + select(func.count()).select_from(aliased) # select_from_aliased + select(a).join(aliased, a.target) # select_join_aliased + select(a).options(joinedload(a.target.of_type(aliased)) + # select_joinedload_aliased + + """ + + __dialect__ = "default" + + def test_select_mapper_mapper_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + stmt = select(User).options( + with_loader_criteria(User, User.name != "name") + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name " + "FROM users WHERE users.name != :name_1", + ) + + def test_select_from_mapper_mapper_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + stmt = ( + select(sql.func.count()) + .select_from(User) + .options(with_loader_criteria(User, User.name != "name")) + ) + + self.assert_compile( + stmt, + "SELECT count(*) AS count_1 FROM users " + "WHERE users.name != :name_1", + ) + + def test_select_mapper_columns_mapper_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + stmt = select(User.id, User.name).options( + with_loader_criteria(User, User.name != "name") + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name " + "FROM users WHERE users.name != :name_1", + ) + + def test_select_join_mapper_mapper_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + stmt = ( + select(User) + .join(User.addresses) + .options( + with_loader_criteria(Address, Address.email_address != "name") + ) + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users " + "JOIN addresses ON users.id = addresses.user_id " + "AND addresses.email_address != :email_address_1", + ) + + def test_select_joinm2m_mapper_mapper_criteria(self, order_item_fixture): + Order, Item = order_item_fixture + + stmt = ( + select(Order) + .join(Order.items) + .options( + with_loader_criteria(Item, Item.description != "description") + ) + ) + + self.assert_compile( + stmt, + "SELECT orders.id, orders.user_id, orders.address_id, " + "orders.description, orders.isopen FROM orders " + "JOIN order_items AS order_items_1 " + "ON orders.id = order_items_1.order_id " + "JOIN items ON items.id = order_items_1.item_id " + "AND items.description != :description_1", + ) + + def test_select_joinedload_mapper_mapper_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + stmt = select(User).options( + joinedload(User.addresses), + with_loader_criteria(Address, Address.email_address != "name"), + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address " + "FROM users LEFT OUTER JOIN addresses AS addresses_1 " + "ON users.id = addresses_1.user_id " + "AND addresses_1.email_address != :email_address_1 " + "ORDER BY addresses_1.id", + ) + + def test_select_selectinload_mapper_mapper_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + stmt = select(User).options( + selectinload(User.addresses), + with_loader_criteria(Address, Address.email_address != "name"), + ) + + s = Session(testing.db, future=True) + + with self.sql_execution_asserter() as asserter: + + s.execute(stmt).all() + + asserter.assert_( + CompiledSQL("SELECT users.id, users.name FROM users", [],), + CompiledSQL( + "SELECT addresses.user_id AS addresses_user_id, addresses.id " + "AS addresses_id, addresses.email_address " + "AS addresses_email_address FROM addresses " + "WHERE addresses.user_id IN ([POSTCOMPILE_primary_keys]) " + "AND addresses.email_address != :email_address_1 " + "ORDER BY addresses.id", + [{"primary_keys": [7, 8, 9, 10], "email_address_1": "name"}], + ), + ) + + def test_select_lazyload_mapper_mapper_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + stmt = ( + select(User) + .options( + with_loader_criteria(Address, Address.email_address != "name"), + ) + .order_by(User.id) + ) + + s = Session(testing.db, future=True) + + with self.sql_execution_asserter() as asserter: + for u in s.execute(stmt).scalars(): + u.addresses + + asserter.assert_( + CompiledSQL( + "SELECT users.id, users.name FROM users ORDER BY users.id", [], + ), + CompiledSQL( + "SELECT addresses.id AS addresses_id, " + "addresses.user_id AS addresses_user_id, " + "addresses.email_address AS addresses_email_address " + "FROM addresses WHERE :param_1 = addresses.user_id " + "AND addresses.email_address != :email_address_1 " + "ORDER BY addresses.id", + [{"param_1": 7, "email_address_1": "name"}], + ), + CompiledSQL( + "SELECT addresses.id AS addresses_id, " + "addresses.user_id AS addresses_user_id, " + "addresses.email_address AS addresses_email_address " + "FROM addresses WHERE :param_1 = addresses.user_id " + "AND addresses.email_address != :email_address_1 " + "ORDER BY addresses.id", + [{"param_1": 8, "email_address_1": "name"}], + ), + CompiledSQL( + "SELECT addresses.id AS addresses_id, " + "addresses.user_id AS addresses_user_id, " + "addresses.email_address AS addresses_email_address " + "FROM addresses WHERE :param_1 = addresses.user_id " + "AND addresses.email_address != :email_address_1 " + "ORDER BY addresses.id", + [{"param_1": 9, "email_address_1": "name"}], + ), + CompiledSQL( + "SELECT addresses.id AS addresses_id, " + "addresses.user_id AS addresses_user_id, " + "addresses.email_address AS addresses_email_address " + "FROM addresses WHERE :param_1 = addresses.user_id " + "AND addresses.email_address != :email_address_1 " + "ORDER BY addresses.id", + [{"param_1": 10, "email_address_1": "name"}], + ), + ) + + def test_select_aliased_inclaliased_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + u1 = aliased(User) + stmt = select(u1).options( + with_loader_criteria( + User, User.name != "name", include_aliases=True + ) + ) + + self.assert_compile( + stmt, + "SELECT users_1.id, users_1.name " + "FROM users AS users_1 WHERE users_1.name != :name_1", + ) + + def test_select_from_aliased_inclaliased_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + u1 = aliased(User) + stmt = ( + select(sql.func.count()) + .select_from(u1) + .options( + with_loader_criteria( + User, User.name != "name", include_aliases=True + ) + ) + ) + + self.assert_compile( + stmt, + "SELECT count(*) AS count_1 FROM users AS users_1 " + "WHERE users_1.name != :name_1", + ) + + def test_select_aliased_columns_inclaliased_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + u1 = aliased(User) + stmt = select(u1.id, u1.name).options( + with_loader_criteria( + User, User.name != "name", include_aliases=True + ) + ) + + self.assert_compile( + stmt, + "SELECT users_1.id, users_1.name " + "FROM users AS users_1 WHERE users_1.name != :name_1", + ) + + def test_select_join_aliased_inclaliased_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + a1 = aliased(Address) + stmt = ( + select(User) + .join(User.addresses.of_type(a1)) + .options( + with_loader_criteria( + Address, + Address.email_address != "name", + include_aliases=True, + ) + ) + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users " + "JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id " + "AND addresses_1.email_address != :email_address_1", + ) + + def test_select_joinm2m_aliased_inclaliased_criteria( + self, order_item_fixture + ): + Order, Item = order_item_fixture + + i1 = aliased(Item) + + stmt = ( + select(Order) + .join(Order.items.of_type(i1)) + .options( + with_loader_criteria( + Item, + Item.description != "description", + include_aliases=True, + ) + ) + ) + + self.assert_compile( + stmt, + "SELECT orders.id, orders.user_id, orders.address_id, " + "orders.description, orders.isopen FROM orders " + "JOIN order_items AS order_items_1 " + "ON orders.id = order_items_1.order_id " + "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " + "AND items_1.description != :description_1", + ) + + def test_select_aliased_aliased_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + u1 = aliased(User) + stmt = select(u1).options(with_loader_criteria(u1, u1.name != "name")) + + self.assert_compile( + stmt, + "SELECT users_1.id, users_1.name " + "FROM users AS users_1 WHERE users_1.name != :name_1", + ) + + def test_select_aliased_columns_aliased_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + u1 = aliased(User) + stmt = select(u1.id, u1.name).options( + with_loader_criteria(u1, u1.name != "name") + ) + + self.assert_compile( + stmt, + "SELECT users_1.id, users_1.name " + "FROM users AS users_1 WHERE users_1.name != :name_1", + ) + + def test_joinedload_global_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + s = Session(testing.db, future=True) + + stmt = select(User).options( + joinedload(User.addresses), + with_loader_criteria(Address, Address.email_address != "email"), + ) + + with self.sql_execution_asserter() as asserter: + + s.execute(stmt) + + asserter.assert_( + CompiledSQL( + "SELECT users.id, users.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address FROM " + "users LEFT OUTER JOIN addresses AS addresses_1 " + "ON users.id = addresses_1.user_id " + "AND addresses_1.email_address != :email_address_1 " + "ORDER BY addresses_1.id", + [{"email_address_1": "email"}], + ), + ) + + def test_query_count_global_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + s = Session(testing.db) + + q = s.query(User).options(with_loader_criteria(User, User.id != 8)) + + with self.sql_execution_asserter() as asserter: + q.count() + + asserter.assert_( + CompiledSQL( + "SELECT count(*) AS count_1 FROM (SELECT " + "users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id != :id_1) AS anon_1", + [{"id_1": 8}], + ), + ) + + def test_query_count_after_the_fact_global_criteria( + self, user_address_fixture + ): + User, Address = user_address_fixture + + s = Session(testing.db) + + # this essentially tests that the query.from_self() which takes + # place in count() is one that can still be affected by + # the loader criteria, meaning it has to be an ORM query + + q = s.query(User) + + @event.listens_for(s, "do_orm_execute") + def add_criteria(orm_context): + orm_context.statement = orm_context.statement.options( + with_loader_criteria(User, User.id != 8) + ) + + with self.sql_execution_asserter() as asserter: + q.count() + + asserter.assert_( + CompiledSQL( + "SELECT count(*) AS count_1 FROM (SELECT " + "users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id != :id_1) AS anon_1", + [{"id_1": 8}], + ), + ) + + def test_select_count_subquery_global_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + stmt = select(User).subquery() + + stmt = ( + select(sql.func.count()) + .select_from(stmt) + .options(with_loader_criteria(User, User.id != 8)) + ) + + self.assert_compile( + stmt, + "SELECT count(*) AS count_1 FROM (SELECT users.id AS id, " + "users.name AS name FROM users WHERE users.id != :id_1) AS anon_1", + ) + + def test_query_outerjoin_global_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + s = Session(testing.db) + + q = ( + s.query(User, Address) + .outerjoin(User.addresses) + .options( + with_loader_criteria( + Address, ~Address.email_address.like("ed@%"), + ) + ) + .order_by(User.id) + ) + + self.assert_compile( + q, + "SELECT users.id AS users_id, users.name AS users_name, " + "addresses.id AS addresses_id, " + "addresses.user_id AS addresses_user_id, " + "addresses.email_address AS addresses_email_address " + "FROM users LEFT OUTER JOIN addresses " + "ON users.id = addresses.user_id AND " + "addresses.email_address NOT LIKE :email_address_1 " + "ORDER BY users.id", + ) + eq_( + q.all(), + [ + (User(id=7), Address(id=1)), + (User(id=8), None), # three addresses not here + (User(id=9), Address(id=5)), + (User(id=10), None), + ], + ) + + def test_caching_and_binds_lambda(self, mixin_fixture): + HasFoob, UserWFoob = mixin_fixture + + statement = select(UserWFoob).filter(UserWFoob.id < 10) + + def go(value): + return statement.options( + with_loader_criteria( + HasFoob, + lambda cls: cls.name == value, + include_aliases=True, + ) + ) + + s = Session(testing.db, future=True) + + for i in range(10): + name = random.choice(["ed", "fred", "jack"]) + stmt = go(name) + + eq_(s.execute(stmt).scalars().all(), [UserWFoob(name=name)]) + + +class TemporalFixtureTest(testing.fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + class HasTemporal(object): + """Mixin that identifies a class as having a timestamp column""" + + timestamp = Column( + DateTime, default=datetime.datetime.utcnow, nullable=False + ) + + cls.HasTemporal = HasTemporal + + def temporal_range(range_lower, range_upper): + return with_loader_criteria( + HasTemporal, + lambda cls: cls.timestamp.between(range_lower, range_upper), + include_aliases=True, + ) + + cls.temporal_range = staticmethod(temporal_range) + + class Parent(HasTemporal, cls.DeclarativeBasic): + __tablename__ = "parent" + id = Column(Integer, primary_key=True) + children = relationship("Child", order_by="Child.id") + + class Child(HasTemporal, cls.DeclarativeBasic): + __tablename__ = "child" + id = Column(Integer, primary_key=True) + parent_id = Column( + Integer, ForeignKey("parent.id"), nullable=False + ) + + @classmethod + def insert_data(cls, connection): + Parent, Child = cls.classes("Parent", "Child") + + sess = Session(connection) + c1, c2, c3, c4, c5 = [ + Child(timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00)), + Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)), + Child(timestamp=datetime.datetime(2009, 10, 20, 12, 00, 00)), + Child(timestamp=datetime.datetime(2009, 10, 12, 12, 00, 00)), + Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)), + ] + + p1 = Parent( + timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00), + children=[c1, c2, c3], + ) + p2 = Parent( + timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00), + children=[c4, c5], + ) + + sess.add_all([p1, p2]) + sess.commit() + + @testing.combinations((True,), (False,), argnames="use_caching") + @testing.combinations( + (None,), + (orm.lazyload,), + (orm.joinedload,), + (orm.subqueryload,), + (orm.selectinload,), + argnames="loader_strategy", + ) + def test_same_relatinship_load_different_range( + self, use_caching, loader_strategy + ): + """This is the first test that exercises lazy loading, which uses + a lambda select, which then needs to transform the select to have + different bound parameters if it's not cached (or generate a working + list of parameters if it is), which then calls into a + with_loader_crieria that itself has another lambda inside of it, + which means we have to traverse and replace that lambda's expression, + but we can't evaluate it until compile time, so the inner lambda + holds onto the "transform" function so it can run it as needed. + this makes use of a new feature in visitors that exports a + "run this traversal later" function. + + All of these individual features, cloning lambdaelements, + running replacement traversals later, are very new and need a lot + of tests, most likely in test/sql/test_lambdas.py. + + the test is from the "temporal_range" example which is the whole + use case this feature is designed for and it is a whopper. + + + """ + Parent, Child = self.classes("Parent", "Child") + temporal_range = self.temporal_range + + if use_caching: + Parent.children.property.bake_queries = True + eng = testing.db + else: + Parent.children.property.bake_queries = False + eng = testing.db.execution_options(compiled_cache=None) + + sess = Session(eng, future=True) + + if loader_strategy: + loader_options = (loader_strategy(Parent.children),) + else: + loader_options = () + + p1 = sess.execute( + select(Parent).filter( + Parent.timestamp == datetime.datetime(2009, 10, 15, 12, 00, 00) + ) + ).scalar() + c1, c2 = p1.children[0:2] + c2_id = c2.id + + p2 = sess.execute( + select(Parent).filter( + Parent.timestamp == datetime.datetime(2009, 10, 17, 12, 00, 00) + ) + ).scalar() + c5 = p2.children[1] + + parents = ( + sess.execute( + select(Parent) + .execution_options(populate_existing=True) + .options( + temporal_range( + datetime.datetime(2009, 10, 16, 12, 00, 00), + datetime.datetime(2009, 10, 18, 12, 00, 00), + ), + *loader_options + ) + ) + .scalars() + .all() + ) + + assert parents[0] == p2 + assert parents[0].children == [c5] + + parents = ( + sess.execute( + select(Parent) + .execution_options(populate_existing=True) + .join(Parent.children) + .filter(Child.id == c2_id) + .options( + temporal_range( + datetime.datetime(2009, 10, 15, 11, 00, 00), + datetime.datetime(2009, 10, 18, 12, 00, 00), + ), + *loader_options + ) + ) + .scalars() + .all() + ) + + assert parents[0] == p1 + assert parents[0].children == [c1, c2] + + +class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def user_address_fixture(self): + users, Address, addresses, User = ( + self.tables.users, + self.classes.Address, + self.tables.addresses, + self.classes.User, + ) + + mapper( + User, + users, + properties={ + "addresses": relationship( + mapper(Address, addresses), order_by=Address.id + ) + }, + ) + return User, Address + + def test_joinedload_local_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + s = Session(testing.db, future=True) + + stmt = select(User).options( + joinedload(User.addresses.and_(Address.email_address != "email")), + ) + + with self.sql_execution_asserter() as asserter: + + s.execute(stmt) + + asserter.assert_( + CompiledSQL( + "SELECT users.id, users.name, addresses_1.id AS id_1, " + "addresses_1.user_id, addresses_1.email_address FROM " + "users LEFT OUTER JOIN addresses AS addresses_1 " + "ON users.id = addresses_1.user_id " + "AND addresses_1.email_address != :email_address_1 " + "ORDER BY addresses_1.id", + [{"email_address_1": "email"}], + ), + ) + + def test_query_join_local_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + s = Session(testing.db) + + q = s.query(User).join( + User.addresses.and_(Address.email_address != "email") + ) + + self.assert_compile( + q, + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users JOIN addresses ON users.id = addresses.user_id " + "AND addresses.email_address != :email_address_1", + ) + + def test_select_join_local_criteria(self, user_address_fixture): + User, Address = user_address_fixture + + stmt = select(User).join( + User.addresses.and_(Address.email_address != "email") + ) + + self.assert_compile( + stmt, + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id " + "AND addresses.email_address != :email_address_1", + ) + + def test_select_joinm2m_local_criteria(self, order_item_fixture): + Order, Item = order_item_fixture + + stmt = select(Order).join( + Order.items.and_(Item.description != "description") + ) + + self.assert_compile( + stmt, + "SELECT orders.id, orders.user_id, orders.address_id, " + "orders.description, orders.isopen " + "FROM orders JOIN order_items AS order_items_1 " + "ON orders.id = order_items_1.order_id " + "JOIN items ON items.id = order_items_1.item_id " + "AND items.description != :description_1", + ) + + def test_select_joinm2m_aliased_local_criteria(self, order_item_fixture): + Order, Item = order_item_fixture + + i1 = aliased(Item) + stmt = select(Order).join( + Order.items.of_type(i1).and_(i1.description != "description") + ) + + self.assert_compile( + stmt, + "SELECT orders.id, orders.user_id, orders.address_id, " + "orders.description, orders.isopen " + "FROM orders JOIN order_items AS order_items_1 " + "ON orders.id = order_items_1.order_id " + "JOIN items AS items_1 ON items_1.id = order_items_1.item_id " + "AND items_1.description != :description_1", + ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index b573accbd..7aad2cab8 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1512,3 +1512,23 @@ class CompareClausesTest(fixtures.TestBase): is_true(x_p_a.compare(x_p)) is_true(x_p.compare(x_p_a)) is_false(x_p_a.compare(x_a)) + + +class ExecutableFlagsTest(fixtures.TestBase): + @testing.combinations( + (select(column("a")),), + (table("q", column("a")).insert(),), + (table("q", column("a")).update(),), + (table("q", column("a")).delete(),), + (lambda_stmt(lambda: select(column("a"))),), + ) + def test_is_select(self, case): + if isinstance(case, LambdaElement): + resolved_case = case._resolved + else: + resolved_case = case + + if isinstance(resolved_case, Select): + is_true(case.is_select) + else: + is_false(case.is_select) |