diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-05 21:47:43 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-05 22:13:11 -0400 |
commit | c7b489b25802f7a25ef78d0731411295c611cc1c (patch) | |
tree | f5e3b66ab8eb8bb7398c0195fa2b2f1de8ab91c4 /lib/sqlalchemy | |
parent | 71a3ccbdef0d88e9231b7de9c51e4ed60b3b7181 (diff) | |
download | sqlalchemy-c7b489b25802f7a25ef78d0731411295c611cc1c.tar.gz |
Implement relationship AND criteria; global loader criteria
Added the ability to add arbitrary criteria to the ON clause generated
by a relationship attribute in a query, which applies to methods such
as :meth:`_query.Query.join` as well as loader options like
:func:`_orm.joinedload`. Additionally, a "global" version of the option
allows limiting criteria to be applied to particular entities in
a query globally.
Documentation is minimal at this point, new examples will
be coming in a subsequent commit.
Some adjustments to execution options in how they are represented
in the ORMExecuteState as well as well as a few ORM tests that
forgot to get merged in a preceding commit.
Fixes: #4472
Change-Id: I2b8fc57092dedf35ebd16f6343ad0f0d7d332beb
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/ext/horizontal_shard.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 3 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 28 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/context.py | 107 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 14 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 48 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 101 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 40 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 12 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 184 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 |
12 files changed, 478 insertions, 94 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 53545826b..fe9bbaf02 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -207,7 +207,6 @@ class ShardedSession(Session): def execute_and_instances(orm_context): - if orm_context.is_select: load_options = active_options = orm_context.load_options update_options = None @@ -237,8 +236,8 @@ def execute_and_instances(orm_context): if active_options._refresh_identity_token is not None: shard_id = active_options._refresh_identity_token - elif "_sa_shard_id" in orm_context.merged_execution_options: - shard_id = orm_context.merged_execution_options["_sa_shard_id"] + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: shard_id = orm_context.bind_arguments["shard_id"] else: diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 32ec60322..458103838 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -48,6 +48,7 @@ from .strategy_options import Load # noqa from .util import aliased # noqa from .util import Bundle # noqa from .util import join # noqa +from .util import LoaderCriteriaOption # noqa from .util import object_mapper # noqa from .util import outerjoin # noqa from .util import polymorphic_union # noqa @@ -101,6 +102,8 @@ def create_session(bind=None, **kwargs): return Session(bind=bind, **kwargs) +with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm") + relationship = public_factory(RelationshipProperty, ".orm.relationship") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6dd95a5a9..2e1b9dc75 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -50,6 +50,7 @@ from .. import inspection from .. import util from ..sql import base as sql_base from ..sql import roles +from ..sql import traversals from ..sql import visitors @@ -58,6 +59,7 @@ class QueryableAttribute( interfaces._MappedAttribute, interfaces.InspectionAttr, interfaces.PropComparator, + traversals.HasCopyInternals, roles.JoinTargetRole, roles.OnClauseRole, sql_base.Immutable, @@ -91,6 +93,7 @@ class QueryableAttribute( impl=None, comparator=None, of_type=None, + extra_criteria=(), ): self.class_ = class_ self.key = key @@ -98,6 +101,7 @@ class QueryableAttribute( self.impl = impl self.comparator = comparator self._of_type = of_type + self._extra_criteria = extra_criteria manager = manager_of_class(class_) # manager is None in the case of AliasedClass @@ -114,6 +118,7 @@ class QueryableAttribute( ("key", visitors.ExtendedInternalTraversal.dp_string), ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ] def __reduce__(self): @@ -240,6 +245,29 @@ class QueryableAttribute( impl=self.impl, comparator=self.comparator.of_type(entity), of_type=inspection.inspect(entity), + extra_criteria=self._extra_criteria, + ) + + def and_(self, *other): + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator.and_(*other), + of_type=self._of_type, + extra_criteria=self._extra_criteria + other, + ) + + def _clone(self, **kw): + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator, + of_type=self._of_type, + extra_criteria=self._extra_criteria, ) def label(self, name): diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 96725e55b..a35b2f9fd 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -11,9 +11,9 @@ from .base import _is_aliased_class from .interfaces import ORMColumnsClauseRole from .path_registry import PathRegistry from .util import _entity_corresponds_to +from .util import _ORMJoin from .util import aliased from .util import Bundle -from .util import join as orm_join from .util import ORMAdapter from .. import exc as sa_exc from .. import future @@ -78,7 +78,6 @@ class QueryContext(object): _yield_per = None _refresh_state = None _lazy_loaded_from = None - _params = _EMPTY_DICT def __init__( self, @@ -308,6 +307,9 @@ class ORMFromStatementCompileState(ORMCompileState): multi_row_eager_loaders = False compound_eager_adapter = None + extra_criteria_entities = _EMPTY_DICT + eager_joins = _EMPTY_DICT + @classmethod def create_for_statement(cls, statement_container, compiler, **kw): @@ -338,6 +340,7 @@ class ORMFromStatementCompileState(ORMCompileState): if toplevel and statement_container._with_options: self.attributes = {"_unbound_load_dedupes": set()} + self.global_attributes = compiler._global_attributes for opt in statement_container._with_options: if opt._is_compile_state: @@ -345,6 +348,7 @@ class ORMFromStatementCompileState(ORMCompileState): else: self.attributes = {} + self.global_attributes = compiler._global_attributes if statement_container._with_context_options: for fn, key in statement_container._with_context_options: @@ -352,8 +356,6 @@ class ORMFromStatementCompileState(ORMCompileState): self.primary_columns = [] self.secondary_columns = [] - self.eager_joins = {} - self.single_inh_entities = {} self.create_eager_joins = [] self._fallback_from_clauses = [] @@ -423,11 +425,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def create_for_statement(cls, statement, compiler, **kw): """compiler hook, we arrive here from compiler.visit_select() only.""" + self = cls.__new__(cls) + if compiler is not None: toplevel = not compiler.stack compiler._rewrites_selected_columns = True + self.global_attributes = compiler._global_attributes else: toplevel = True + self.global_attributes = {} select_statement = statement @@ -437,8 +443,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): statement._compile_options ) - self = cls.__new__(cls) - if select_statement._execution_options: # execution options should not impact the compilation of a # query, and at the moment subqueryloader is putting some things @@ -516,7 +520,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.primary_columns = [] self.secondary_columns = [] self.eager_joins = {} - self.single_inh_entities = {} + self.extra_criteria_entities = {} self.create_eager_joins = [] self._fallback_from_clauses = [] @@ -634,7 +638,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState): if self.compile_options._enable_single_crit: - self._adjust_for_single_inheritance() + self._adjust_for_extra_criteria() if not self.primary_columns: if self.compile_options._only_load_props: @@ -1408,6 +1412,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState): left, right, onclause, prop, create_aliases, aliased_generation ) + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + if replace_from_obj_index is not None: # splice into an existing element in the # self._from_obj list @@ -1416,12 +1425,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.from_clauses = ( self.from_clauses[:replace_from_obj_index] + [ - orm_join( + _ORMJoin( left_clause, right, onclause, isouter=outerjoin, full=full, + _extra_criteria=extra_criteria, ) ] + self.from_clauses[replace_from_obj_index + 1 :] @@ -1440,8 +1450,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState): left_clause = left self.from_clauses = self.from_clauses + [ - orm_join( - left_clause, r_info, onclause, isouter=outerjoin, full=full + _ORMJoin( + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, ) ] @@ -1848,8 +1863,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState): or kwargs.get("group_by", False) ) - def _adjust_for_single_inheritance(self): - """Apply single-table-inheritance filtering. + def _get_extra_criteria(self, ext_info): + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in self.global_attributes: + return tuple( + ae._resolve_where_criteria(ext_info) + for ae in self.global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if ae.include_aliases or ae.entity is ext_info + ) + else: + return () + + def _adjust_for_extra_criteria(self): + """Apply extra criteria filtering. For all distinct single-table-inheritance mappers represented in the columns clause of this query, as well as the "select from entity", @@ -1857,38 +1887,50 @@ class ORMSelectCompileState(ORMCompileState, SelectState): clause of the given QueryContext such that only the appropriate subtypes are selected from the total results. + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + associated with the global context. + """ for fromclause in self.from_clauses: ext_info = fromclause._annotations.get("parententity", None) if ( ext_info - and ext_info.mapper._single_table_criterion is not None - and ext_info not in self.single_inh_entities + and ( + ext_info.mapper._single_table_criterion is not None + or ("additional_entity_criteria", ext_info.mapper) + in self.global_attributes + ) + and ext_info not in self.extra_criteria_entities ): - self.single_inh_entities[ext_info] = ( + self.extra_criteria_entities[ext_info] = ( ext_info, ext_info._adapter if ext_info.is_aliased_class else None, ) - search = set(self.single_inh_entities.values()) + search = set(self.extra_criteria_entities.values()) for (ext_info, adapter) in search: if ext_info in self._join_entities: continue + single_crit = ext_info.mapper._single_table_criterion + + additional_entity_criteria = self._get_extra_criteria(ext_info) + if single_crit is not None: + additional_entity_criteria += (single_crit,) + + current_adapter = self._get_current_adapter() + for crit in additional_entity_criteria: if adapter: - single_crit = adapter.traverse(single_crit) + crit = adapter.traverse(crit) - current_adapter = self._get_current_adapter() if current_adapter: - single_crit = sql_util._deep_annotate( - single_crit, {"_orm_adapt": True} - ) - single_crit = current_adapter(single_crit, False) - self._where_criteria += (single_crit,) + crit = sql_util._deep_annotate(crit, {"_orm_adapt": True}) + crit = current_adapter(crit, False) + self._where_criteria += (crit,) def _column_descriptions(query_or_select_stmt, compile_state=None): @@ -2205,9 +2247,13 @@ class _MapperEntity(_QueryEntity): adapter = self._get_entity_clauses(compile_state) single_table_crit = self.mapper._single_table_criterion - if single_table_crit is not None: + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): ext_info = self.entity_zero - compile_state.single_inh_entities[ext_info] = ( + compile_state.extra_criteria_entities[ext_info] = ( ext_info, ext_info._adapter if ext_info.is_aliased_class else None, ) @@ -2528,8 +2574,13 @@ class _ORMColumnEntity(_ColumnEntity): ezero = self.entity_zero single_table_crit = self.mapper._single_table_criterion - if single_table_crit is not None: - compile_state.single_inh_entities[ezero] = ( + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): + + compile_state.extra_criteria_entities[ezero] = ( ezero, ezero._adapter if ezero.is_aliased_class else None, ) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 4cf820ae3..068c85073 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -480,6 +480,32 @@ class PropComparator(operators.ColumnOperators): return self.operate(PropComparator.of_type_op, class_) + def and_(self, *criteria): + """Add additional criteria to the ON clause that's represented by this + relationship attribute. + + E.g.:: + + + stmt = select(User).join( + User.addresses.and_(Address.email_address != 'foo') + ) + + stmt = select(User).options( + joinedload(User.addresses.and_(Address.email_address != 'foo')) + ) + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`loader_option_criteria` + + :func:`.with_loader_criteria` + + """ + return self.operate(operators.and_, *criteria) + def any(self, criterion=None, **kwargs): r"""Return true if this collection contains any member that meets the given criterion. diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d60c03bdc..68ca0365b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1997,6 +1997,20 @@ class Query( filter(a1.email_address == 'ed@foo.com').\ filter(a2.email_address == 'ed@bar.com') + **Augmenting Built-in ON Clauses** + + As a substitute for providing a full custom ON condition for an + existing relationship, the :meth:`_orm.PropComparator.and_` function + may be applied to a relationship attribute to augment additional + criteria into the ON clause; the additional criteria will be combined + with the default criteria using AND:: + + q = session.query(User).join( + User.addresses.and_(Address.email_address != 'foo@bar.com') + ) + + .. versionadded:: 1.4 + **Joining to Tables and Subqueries** diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index cb490b7d7..794b9422c 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1115,9 +1115,15 @@ class RelationshipProperty(StrategizedProperty): """ _of_type = None + _extra_criteria = () def __init__( - self, prop, parentmapper, adapt_to_entity=None, of_type=None + self, + prop, + parentmapper, + adapt_to_entity=None, + of_type=None, + extra_criteria=(), ): """Construction of :class:`.RelationshipProperty.Comparator` is internal to the ORM's attribute mechanics. @@ -1128,6 +1134,7 @@ class RelationshipProperty(StrategizedProperty): self._adapt_to_entity = adapt_to_entity if of_type: self._of_type = of_type + self._extra_criteria = extra_criteria def adapt_to_entity(self, adapt_to_entity): return self.__class__( @@ -1191,6 +1198,7 @@ class RelationshipProperty(StrategizedProperty): source_polymorphic=True, of_type_entity=of_type_entity, alias_secondary=True, + extra_criteria=self._extra_criteria, ) if sj is not None: return pj & sj @@ -1202,12 +1210,30 @@ class RelationshipProperty(StrategizedProperty): See :meth:`.PropComparator.of_type` for an example. + """ return RelationshipProperty.Comparator( self.property, self._parententity, adapt_to_entity=self._adapt_to_entity, of_type=cls, + extra_criteria=self._extra_criteria, + ) + + def and_(self, *other): + """Add AND criteria. + + See :meth:`.PropComparator.and_` for an example. + + .. versionadded:: 1.4 + + """ + return RelationshipProperty.Comparator( + self.property, + self._parententity, + adapt_to_entity=self._adapt_to_entity, + of_type=self._of_type, + extra_criteria=self._extra_criteria + other, ) def in_(self, other): @@ -2439,6 +2465,7 @@ class RelationshipProperty(StrategizedProperty): dest_selectable=None, of_type_entity=None, alias_secondary=False, + extra_criteria=(), ): aliased = False @@ -2489,7 +2516,11 @@ class RelationshipProperty(StrategizedProperty): target_adapter, dest_selectable, ) = self._join_condition.join_targets( - source_selectable, dest_selectable, aliased, single_crit + source_selectable, + dest_selectable, + aliased, + single_crit, + extra_criteria, ) if source_selectable is None: source_selectable = self.parent.local_table @@ -3427,7 +3458,12 @@ class JoinCondition(object): ) def join_targets( - self, source_selectable, dest_selectable, aliased, single_crit=None + self, + source_selectable, + dest_selectable, + aliased, + single_crit=None, + extra_criteria=(), ): """Given a source and destination selectable, create a join between them. @@ -3463,6 +3499,12 @@ class JoinCondition(object): else: primaryjoin = primaryjoin & single_crit + if extra_criteria: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & sql.and_(*extra_criteria) + else: + primaryjoin = primaryjoin & sql.and_(*extra_criteria) + if aliased: if secondary is not None: secondary = secondary._anonymous_fromclause(flat=True) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 339c57bdc..e9d4ac2c6 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -102,23 +102,26 @@ CLOSED = util.symbol("CLOSED") class ORMExecuteState(util.MemoizedSlots): - """Stateful object used for the :meth:`.SessionEvents.do_orm_execute` + """Represents a call to the :meth:`_orm.Session.execute` method, as passed + to the :meth:`.SessionEvents.do_orm_execute` event hook. .. versionadded:: 1.4 + """ __slots__ = ( "session", "statement", "parameters", - "_execution_options", - "_merged_execution_options", + "execution_options", + "local_execution_options", "bind_arguments", "_compile_state_cls", "_starting_event_idx", "_events_todo", "_future", + "_update_execution_options", ) def __init__( @@ -135,7 +138,10 @@ class ORMExecuteState(util.MemoizedSlots): self.session = session self.statement = statement self.parameters = parameters - self._execution_options = execution_options + self.local_execution_options = execution_options + self.execution_options = statement._execution_options.union( + execution_options + ) self.bind_arguments = bind_arguments self._compile_state_cls = compile_state_cls self._events_todo = list(events_todo) @@ -182,9 +188,8 @@ class ORMExecuteState(util.MemoizedSlots): .. seealso:: - :ref:`examples_caching` - includes example use of the - :meth:`.SessionEvents.do_orm_execute` hook as well as the - :meth:`.ORMExecuteState.invoke_query` method. + :ref:`do_orm_execute_re_executing` - background and examples on the + appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`. """ @@ -203,11 +208,9 @@ class ORMExecuteState(util.MemoizedSlots): else: _params = self.parameters + _execution_options = self.local_execution_options if execution_options: - _execution_options = dict(self._execution_options) - _execution_options.update(execution_options) - else: - _execution_options = self._execution_options + _execution_options = _execution_options.union(execution_options) return self.session.execute( statement, @@ -255,42 +258,9 @@ class ORMExecuteState(util.MemoizedSlots): def _is_crud(self): return isinstance(self.statement, (dml.Update, dml.Delete)) - @property - def execution_options(self): - """Placeholder for execution options. - - Raises an informative message, as there are local options - vs. merged options that can be viewed, via the - :attr:`.ORMExecuteState.local_execution_options` and - :attr:`.ORMExecuteState.merged_execution_options` methods. - - - """ - raise AttributeError( - "Please use .local_execution_options or " - ".merged_execution_options" - ) - - @property - def local_execution_options(self): - """Dictionary view of the execution options passed to the - :meth:`.Session.execute` method. This does not include options - that may be associated with the statement being invoked. - - """ - return util.immutabledict(self._execution_options) - - @property - def merged_execution_options(self): - """Dictionary view of all execution options merged together; - this includes those of the statement as well as those passed to - :meth:`.Session.execute`, with the local options taking precedence. - - """ - return self._merged_execution_options - - def _memoized_attr__merged_execution_options(self): - return self.statement._execution_options.union(self._execution_options) + def update_execution_options(self, **opts): + # TODO: no coverage + self.local_execution_options = self.local_execution_options.union(opts) def _orm_compile_options(self): opts = self.statement._compile_options @@ -329,6 +299,20 @@ class ORMExecuteState(util.MemoizedSlots): return None @property + def is_relationship_load(self): + """Return True if this load is loading objects on behalf of a + relationship. + + This means, the loader in effect is either a LazyLoader, + SelectInLoader, SubqueryLoader, or similar, and the entire + SELECT statement being emitted is on behalf of a relationship + load. + + """ + path = self.loader_strategy_path + return path is not None and not path.is_root + + @property def load_options(self): """Return the load_options that will be used for this execution.""" @@ -337,7 +321,7 @@ class ORMExecuteState(util.MemoizedSlots): "This ORM execution is not against a SELECT statement " "so there are no load options." ) - return self._execution_options.get( + return self.execution_options.get( "_sa_orm_load_options", context.QueryContext.default_load_options ) @@ -351,7 +335,7 @@ class ORMExecuteState(util.MemoizedSlots): "This ORM execution is not against an UPDATE or DELETE " "statement so there are no update options." ) - return self._execution_options.get( + return self.execution_options.get( "_sa_orm_update_options", persistence.BulkUDCompileState.default_update_options, ) @@ -1003,8 +987,6 @@ class Session(_SessionClassMethods): :ref:`migration_20_toplevel` - :ref:`migration_20_result_rows` - :param info: optional dictionary of arbitrary data to be associated with this :class:`.Session`. Is available via the :attr:`.Session.info` attribute. Note the dictionary is copied at @@ -1282,7 +1264,7 @@ class Session(_SessionClassMethods): the operation will release the current SAVEPOINT but not commit the outermost database transaction. - If :term:`2.x-style` use is in effect via the + If :term:`2.0-style` use is in effect via the :paramref:`_orm.Session.future` flag, the outermost database transaction is committed unconditionally, automatically releasing any SAVEPOINTs in effect. @@ -1416,7 +1398,7 @@ class Session(_SessionClassMethods): self, statement, params=None, - execution_options=util.immutabledict(), + execution_options=util.EMPTY_DICT, bind_arguments=None, future=False, _parent_execute_state=None, @@ -1576,6 +1558,8 @@ class Session(_SessionClassMethods): else: compile_state_cls = None + execution_options = util.coerce_to_immutabledict(execution_options) + if compile_state_cls is not None: ( statement, @@ -1591,8 +1575,11 @@ class Session(_SessionClassMethods): else: bind_arguments.setdefault("clause", statement) if future: - execution_options = util.immutabledict().merge_with( - execution_options, {"future_result": True} + # not sure if immutabledict is working w/ this syntax + # execution_options = + # execution_options.union(future_result=True) + execution_options = execution_options.union( + {"future_result": True} ) if _parent_execute_state: @@ -1619,6 +1606,10 @@ class Session(_SessionClassMethods): if result: return result + # TODO: coverage for this pattern + statement = orm_exec_state.statement + execution_options = orm_exec_state.local_execution_options + bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind, close_with_result=True) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 44f303fee..53166bd91 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1975,6 +1975,7 @@ class JoinedLoader(AbstractRelationshipLoader): clauses, innerjoin, chained_from_outerjoin, + loadopt._extra_criteria if loadopt else (), ) ) @@ -1993,6 +1994,7 @@ class JoinedLoader(AbstractRelationshipLoader): clauses, innerjoin, chained_from_outerjoin, + extra_criteria, ): if parentmapper is None: localparent = query_entity.mapper @@ -2081,6 +2083,17 @@ class JoinedLoader(AbstractRelationshipLoader): or query_entity.entity_zero.represents_outer_join ) + extra_join_criteria = extra_criteria + additional_entity_criteria = compile_state.global_attributes.get( + ("additional_entity_criteria", self.mapper), () + ) + if additional_entity_criteria: + extra_join_criteria += tuple( + ae._resolve_where_criteria(self.mapper) + for ae in additional_entity_criteria + if ae.propagate_to_loaders + ) + if attach_on_outside: # this is the "classic" eager join case. eagerjoin = orm_util._ORMJoin( @@ -2092,11 +2105,12 @@ class JoinedLoader(AbstractRelationshipLoader): or (chained_from_outerjoin and isinstance(towrap, sql.Join)), _left_memo=self.parent, _right_memo=self.mapper, + _extra_criteria=extra_join_criteria, ) else: # all other cases are innerjoin=='nested' approach eagerjoin = self._splice_nested_inner_join( - path, towrap, clauses, onclause, + path, towrap, clauses, onclause, extra_join_criteria ) compile_state.eager_joins[query_entity_key] = eagerjoin @@ -2128,7 +2142,7 @@ class JoinedLoader(AbstractRelationshipLoader): ) def _splice_nested_inner_join( - self, path, join_obj, clauses, onclause, splicing=False + self, path, join_obj, clauses, onclause, extra_criteria, splicing=False ): if splicing is False: @@ -2137,7 +2151,12 @@ class JoinedLoader(AbstractRelationshipLoader): assert isinstance(join_obj, orm_util._ORMJoin) elif isinstance(join_obj, sql.selectable.FromGrouping): return self._splice_nested_inner_join( - path, join_obj.element, clauses, onclause, splicing, + path, + join_obj.element, + clauses, + onclause, + extra_criteria, + splicing, ) elif not isinstance(join_obj, orm_util._ORMJoin): if path[-2] is splicing: @@ -2148,18 +2167,29 @@ class JoinedLoader(AbstractRelationshipLoader): isouter=False, _left_memo=splicing, _right_memo=path[-1].mapper, + _extra_criteria=extra_criteria, ) else: # only here if splicing == True return None target_join = self._splice_nested_inner_join( - path, join_obj.right, clauses, onclause, join_obj._right_memo, + path, + join_obj.right, + clauses, + onclause, + extra_criteria, + join_obj._right_memo, ) if target_join is None: right_splice = False target_join = self._splice_nested_inner_join( - path, join_obj.left, clauses, onclause, join_obj._left_memo, + path, + join_obj.left, + clauses, + onclause, + extra_criteria, + join_obj._left_memo, ) if target_join is None: # should only return None when recursively called, diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index b405153b9..b3913ec5b 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -78,6 +78,7 @@ class Load(Generative, LoaderOption): ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ( "_context_cache_key", visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples, @@ -101,6 +102,7 @@ class Load(Generative, LoaderOption): load.context = {} load.local_opts = {} load._of_type = None + load._extra_criteria = () return load @property @@ -124,6 +126,7 @@ class Load(Generative, LoaderOption): strategy = None propagate_to_loaders = False _of_type = None + _extra_criteria = () def process_compile_state(self, compile_state): if not compile_state.compile_options._enable_eagerloads: @@ -248,6 +251,9 @@ class Load(Generative, LoaderOption): else: return None + if attr._extra_criteria: + self._extra_criteria = attr._extra_criteria + if getattr(attr, "_of_type", None): ac = attr._of_type ext_info = of_type_info = inspect(ac) @@ -356,6 +362,7 @@ class Load(Generative, LoaderOption): cloned = self._clone_for_bind_strategy(attr, strategy, "relationship") self.path = cloned.path self._of_type = cloned._of_type + self._extra_criteria = cloned._extra_criteria cloned.is_class_strategy = self.is_class_strategy = False self.propagate_to_loaders = cloned.propagate_to_loaders @@ -413,6 +420,7 @@ class Load(Generative, LoaderOption): if existing: if merge_opts: existing.local_opts.update(self.local_opts) + existing._extra_criteria += self._extra_criteria else: path.set(context, "loader", self) else: @@ -420,6 +428,7 @@ class Load(Generative, LoaderOption): path.set(context, "loader", self) if existing and existing.is_opts_only: self.local_opts.update(existing.local_opts) + existing._extra_criteria += self._extra_criteria def _set_path_strategy(self): if not self.is_class_strategy and self.path.has_entity: @@ -507,11 +516,13 @@ class _UnboundLoad(Load): self.path = () self._to_bind = [] self.local_opts = {} + self._extra_criteria = () _cache_key_traversal = [ ("path", visitors.ExtendedInternalTraversal.dp_multi_list), ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict), ] @@ -576,6 +587,7 @@ class _UnboundLoad(Load): if attr: path = path + (attr,) self.path = path + self._extra_criteria = getattr(attr, "_extra_criteria", ()) return path diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 71ee29597..82fad0815 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -23,6 +23,7 @@ from .base import object_state # noqa from .base import state_attribute_str # noqa from .base import state_class_str # noqa from .base import state_str # noqa +from .interfaces import LoaderOption from .interfaces import MapperProperty # noqa from .interfaces import ORMColumnsClauseRole from .interfaces import ORMEntityColumnsClauseRole @@ -38,6 +39,7 @@ from ..engine.result import result_tuple from ..sql import base as sql_base from ..sql import coercions from ..sql import expression +from ..sql import lambdas from ..sql import roles from ..sql import util as sql_util from ..sql import visitors @@ -854,6 +856,184 @@ class AliasedInsp( return "aliased(%s)" % (self._target.__name__,) +class LoaderCriteriaOption(LoaderOption): + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + :class:`_orm.LoaderCriteriaOption` is invoked using the + :func:`_orm.with_loader_criteria` function; see that function for + details. + + .. versionadded:: 1.4 + + """ + + _traverse_internals = [ + ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("where_criteria", visitors.InternalTraversal.dp_clauseelement), + ("include_aliases", visitors.InternalTraversal.dp_boolean), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), + ] + + def __init__( + self, + entity_or_base, + where_criteria, + loader_only=False, + include_aliases=False, + propagate_to_loaders=True, + ): + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + .. versionadded:: 1.4 + + The :func:`_orm.with_loader_criteria` option is intended to add + limiting criteria to a particular kind of entity in a query, + **globally**, meaning it will apply to the entity as it appears + in the SELECT query as well as within any subqueries, join + conditions, and relationship loads, including both eager and lazy + loaders, without the need for it to be specified in any particular + part of the query. The rendering logic uses the same system used by + single table inheritance to ensure a certain discriminator is applied + to a table. + + E.g., using :term:`2.0-style` queries, we can limit the way the + ``User.addresses`` collection is loaded, regardless of the kind + of loading used:: + + from sqlalchemy.orm import with_loader_criteria + + stmt = select(User).options( + selectinload(User.addresses), + with_loader_criteria(Address, Address.email_address != 'foo')) + ) + + Above, the "selectinload" for ``User.addresses`` will apply the + given filtering criteria to the WHERE clause. + + Another example, where the filtering will be applied to the + ON clause of the join, in this example using :term:`1.x style` + queries:: + + q = session.query(User).outerjoin(User.addresses).options( + with_loader_criteria(Address, Address.email_address != 'foo')) + ) + + The primary purpose of :func:`_orm.with_loader_criteria` is to use + it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler + to ensure that all occurrences of a particular entity are filtered + in a certain way, such as filtering for access control roles. It + also can be used to apply criteria to relationship loads. In the + example below, we can apply a certain set of rules to all queries + emitted by a particular :class:`_orm.Session`:: + + session = Session(bind=engine) + + @event.listens_for("do_orm_execute", session) + def _add_filtering_criteria(execute_state): + execute_state.statement = execute_state.statement.options( + with_loader_criteria( + SecurityRole, + lambda cls: cls.role.in_(['some_role']), + include_aliases=True + ) + ) + + The given class will expand to include all mapped subclass and + need not itself be a mapped class. + + + :param entity_or_base: a mapped class, or a class that is a super + class of a particular set of mapped classes, to which the rule + will apply. + + :param where_criteria: a Core SQL expression that applies limiting + criteria. This may also be a "lambda:" or Python function that + accepts a target class as an argument, when the given class is + a base with many different mapped subclasses. + + :param include_aliases: if True, apply the rule to :func:`_orm.aliased` + constructs as well. + + :param propagate_to_loaders: defaults to True, apply to relationship + loaders such as lazy loaders. + + + .. seealso:: + + :ref:`examples_session_orm_events` - includes examples of using + :func:`_orm.with_loader_criteria`. + + :ref:`do_orm_execute_global_criteria` - basic example on how to + combine :func:`_orm.with_loader_criteria` with the + :meth:`_orm.SessionEvents.do_orm_execute` event. + + """ + entity = inspection.inspect(entity_or_base, False) + if entity is None: + self.root_entity = entity_or_base + self.entity = None + else: + self.root_entity = None + self.entity = entity + + if callable(where_criteria): + self.deferred_where_criteria = True + self.where_criteria = lambdas.DeferredLambdaElement( + where_criteria, + roles.WhereHavingRole, + lambda_args=( + self.root_entity + if self.root_entity is not None + else self.entity.entity, + ), + ) + else: + self.deferred_where_criteria = False + self.where_criteria = coercions.expect( + roles.WhereHavingRole, where_criteria + ) + + self.include_aliases = include_aliases + self.propagate_to_loaders = propagate_to_loaders + + def _all_mappers(self): + if self.entity: + for ent in self.entity.mapper.self_and_descendants: + yield ent + else: + stack = list(self.root_entity.__subclasses__()) + while stack: + subclass = stack.pop(0) + ent = inspection.inspect(subclass) + if ent: + for mp in ent.mapper.self_and_descendants: + yield mp + else: + stack.extend(subclass.__subclasses__()) + + def _resolve_where_criteria(self, ext_info): + if self.deferred_where_criteria: + return self.where_criteria._resolve_with_args(ext_info.entity) + else: + return self.where_criteria + + def process_compile_state(self, compile_state): + """Apply a modification to a given :class:`.CompileState`.""" + + # if options to limit the criteria to immediate query only, + # use compile_state.attributes instead + + for mp in self._all_mappers(): + load_criteria = compile_state.global_attributes.setdefault( + ("additional_entity_criteria", mp), [] + ) + + load_criteria.append(self) + + inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) inspection._inspects(AliasedInsp)(lambda target: target) @@ -1270,6 +1450,7 @@ class _ORMJoin(expression.Join): full=False, _left_memo=None, _right_memo=None, + _extra_criteria=(), ): left_info = inspection.inspect(left) @@ -1291,6 +1472,7 @@ class _ORMJoin(expression.Join): if isinstance(onclause, attributes.QueryableAttribute): on_selectable = onclause.comparator._source_selectable() prop = onclause.property + _extra_criteria += onclause._extra_criteria elif isinstance(onclause, MapperProperty): # used internally by joined eager loader...possibly not ideal prop = onclause @@ -1319,6 +1501,7 @@ class _ORMJoin(expression.Join): source_polymorphic=True, of_type_entity=right_info, alias_secondary=True, + extra_criteria=_extra_criteria, ) if sj is not None: @@ -1331,6 +1514,7 @@ class _ORMJoin(expression.Join): onclause = sj else: onclause = pj + self._target_adapter = target_adapter expression.Join.__init__(self, left, right, onclause, isouter, full) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ac4055bdf..b8984316c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -792,6 +792,10 @@ class SQLCompiler(Compiled): def prefetch(self): return list(self.insert_prefetch + self.update_prefetch) + @util.memoized_property + def _global_attributes(self): + return {} + @util.memoized_instancemethod def _init_cte_state(self): """Initialize collections related to CTEs only if |