summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-08-05 21:47:43 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-08-05 22:13:11 -0400
commitc7b489b25802f7a25ef78d0731411295c611cc1c (patch)
treef5e3b66ab8eb8bb7398c0195fa2b2f1de8ab91c4 /lib/sqlalchemy
parent71a3ccbdef0d88e9231b7de9c51e4ed60b3b7181 (diff)
downloadsqlalchemy-c7b489b25802f7a25ef78d0731411295c611cc1c.tar.gz
Implement relationship AND criteria; global loader criteria
Added the ability to add arbitrary criteria to the ON clause generated by a relationship attribute in a query, which applies to methods such as :meth:`_query.Query.join` as well as loader options like :func:`_orm.joinedload`. Additionally, a "global" version of the option allows limiting criteria to be applied to particular entities in a query globally. Documentation is minimal at this point, new examples will be coming in a subsequent commit. Some adjustments to execution options in how they are represented in the ORMExecuteState as well as well as a few ORM tests that forgot to get merged in a preceding commit. Fixes: #4472 Change-Id: I2b8fc57092dedf35ebd16f6343ad0f0d7d332beb
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py5
-rw-r--r--lib/sqlalchemy/orm/__init__.py3
-rw-r--r--lib/sqlalchemy/orm/attributes.py28
-rw-r--r--lib/sqlalchemy/orm/context.py107
-rw-r--r--lib/sqlalchemy/orm/interfaces.py26
-rw-r--r--lib/sqlalchemy/orm/query.py14
-rw-r--r--lib/sqlalchemy/orm/relationships.py48
-rw-r--r--lib/sqlalchemy/orm/session.py101
-rw-r--r--lib/sqlalchemy/orm/strategies.py40
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py12
-rw-r--r--lib/sqlalchemy/orm/util.py184
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
12 files changed, 478 insertions, 94 deletions
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 53545826b..fe9bbaf02 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -207,7 +207,6 @@ class ShardedSession(Session):
def execute_and_instances(orm_context):
-
if orm_context.is_select:
load_options = active_options = orm_context.load_options
update_options = None
@@ -237,8 +236,8 @@ def execute_and_instances(orm_context):
if active_options._refresh_identity_token is not None:
shard_id = active_options._refresh_identity_token
- elif "_sa_shard_id" in orm_context.merged_execution_options:
- shard_id = orm_context.merged_execution_options["_sa_shard_id"]
+ elif "_sa_shard_id" in orm_context.execution_options:
+ shard_id = orm_context.execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
shard_id = orm_context.bind_arguments["shard_id"]
else:
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 32ec60322..458103838 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -48,6 +48,7 @@ from .strategy_options import Load # noqa
from .util import aliased # noqa
from .util import Bundle # noqa
from .util import join # noqa
+from .util import LoaderCriteriaOption # noqa
from .util import object_mapper # noqa
from .util import outerjoin # noqa
from .util import polymorphic_union # noqa
@@ -101,6 +102,8 @@ def create_session(bind=None, **kwargs):
return Session(bind=bind, **kwargs)
+with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm")
+
relationship = public_factory(RelationshipProperty, ".orm.relationship")
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 6dd95a5a9..2e1b9dc75 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -50,6 +50,7 @@ from .. import inspection
from .. import util
from ..sql import base as sql_base
from ..sql import roles
+from ..sql import traversals
from ..sql import visitors
@@ -58,6 +59,7 @@ class QueryableAttribute(
interfaces._MappedAttribute,
interfaces.InspectionAttr,
interfaces.PropComparator,
+ traversals.HasCopyInternals,
roles.JoinTargetRole,
roles.OnClauseRole,
sql_base.Immutable,
@@ -91,6 +93,7 @@ class QueryableAttribute(
impl=None,
comparator=None,
of_type=None,
+ extra_criteria=(),
):
self.class_ = class_
self.key = key
@@ -98,6 +101,7 @@ class QueryableAttribute(
self.impl = impl
self.comparator = comparator
self._of_type = of_type
+ self._extra_criteria = extra_criteria
manager = manager_of_class(class_)
# manager is None in the case of AliasedClass
@@ -114,6 +118,7 @@ class QueryableAttribute(
("key", visitors.ExtendedInternalTraversal.dp_string),
("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
]
def __reduce__(self):
@@ -240,6 +245,29 @@ class QueryableAttribute(
impl=self.impl,
comparator=self.comparator.of_type(entity),
of_type=inspection.inspect(entity),
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.and_(*other),
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
+ )
+
+ def _clone(self, **kw):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria,
)
def label(self, name):
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 96725e55b..a35b2f9fd 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -11,9 +11,9 @@ from .base import _is_aliased_class
from .interfaces import ORMColumnsClauseRole
from .path_registry import PathRegistry
from .util import _entity_corresponds_to
+from .util import _ORMJoin
from .util import aliased
from .util import Bundle
-from .util import join as orm_join
from .util import ORMAdapter
from .. import exc as sa_exc
from .. import future
@@ -78,7 +78,6 @@ class QueryContext(object):
_yield_per = None
_refresh_state = None
_lazy_loaded_from = None
- _params = _EMPTY_DICT
def __init__(
self,
@@ -308,6 +307,9 @@ class ORMFromStatementCompileState(ORMCompileState):
multi_row_eager_loaders = False
compound_eager_adapter = None
+ extra_criteria_entities = _EMPTY_DICT
+ eager_joins = _EMPTY_DICT
+
@classmethod
def create_for_statement(cls, statement_container, compiler, **kw):
@@ -338,6 +340,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if toplevel and statement_container._with_options:
self.attributes = {"_unbound_load_dedupes": set()}
+ self.global_attributes = compiler._global_attributes
for opt in statement_container._with_options:
if opt._is_compile_state:
@@ -345,6 +348,7 @@ class ORMFromStatementCompileState(ORMCompileState):
else:
self.attributes = {}
+ self.global_attributes = compiler._global_attributes
if statement_container._with_context_options:
for fn, key in statement_container._with_context_options:
@@ -352,8 +356,6 @@ class ORMFromStatementCompileState(ORMCompileState):
self.primary_columns = []
self.secondary_columns = []
- self.eager_joins = {}
- self.single_inh_entities = {}
self.create_eager_joins = []
self._fallback_from_clauses = []
@@ -423,11 +425,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def create_for_statement(cls, statement, compiler, **kw):
"""compiler hook, we arrive here from compiler.visit_select() only."""
+ self = cls.__new__(cls)
+
if compiler is not None:
toplevel = not compiler.stack
compiler._rewrites_selected_columns = True
+ self.global_attributes = compiler._global_attributes
else:
toplevel = True
+ self.global_attributes = {}
select_statement = statement
@@ -437,8 +443,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement._compile_options
)
- self = cls.__new__(cls)
-
if select_statement._execution_options:
# execution options should not impact the compilation of a
# query, and at the moment subqueryloader is putting some things
@@ -516,7 +520,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.primary_columns = []
self.secondary_columns = []
self.eager_joins = {}
- self.single_inh_entities = {}
+ self.extra_criteria_entities = {}
self.create_eager_joins = []
self._fallback_from_clauses = []
@@ -634,7 +638,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if self.compile_options._enable_single_crit:
- self._adjust_for_single_inheritance()
+ self._adjust_for_extra_criteria()
if not self.primary_columns:
if self.compile_options._only_load_props:
@@ -1408,6 +1412,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left, right, onclause, prop, create_aliases, aliased_generation
)
+ if not r_info.is_selectable:
+ extra_criteria = self._get_extra_criteria(r_info)
+ else:
+ extra_criteria = ()
+
if replace_from_obj_index is not None:
# splice into an existing element in the
# self._from_obj list
@@ -1416,12 +1425,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.from_clauses = (
self.from_clauses[:replace_from_obj_index]
+ [
- orm_join(
+ _ORMJoin(
left_clause,
right,
onclause,
isouter=outerjoin,
full=full,
+ _extra_criteria=extra_criteria,
)
]
+ self.from_clauses[replace_from_obj_index + 1 :]
@@ -1440,8 +1450,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left_clause = left
self.from_clauses = self.from_clauses + [
- orm_join(
- left_clause, r_info, onclause, isouter=outerjoin, full=full
+ _ORMJoin(
+ left_clause,
+ r_info,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
)
]
@@ -1848,8 +1863,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
or kwargs.get("group_by", False)
)
- def _adjust_for_single_inheritance(self):
- """Apply single-table-inheritance filtering.
+ def _get_extra_criteria(self, ext_info):
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in self.global_attributes:
+ return tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in self.global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+ else:
+ return ()
+
+ def _adjust_for_extra_criteria(self):
+ """Apply extra criteria filtering.
For all distinct single-table-inheritance mappers represented in
the columns clause of this query, as well as the "select from entity",
@@ -1857,38 +1887,50 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
clause of the given QueryContext such that only the appropriate
subtypes are selected from the total results.
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ associated with the global context.
+
"""
for fromclause in self.from_clauses:
ext_info = fromclause._annotations.get("parententity", None)
if (
ext_info
- and ext_info.mapper._single_table_criterion is not None
- and ext_info not in self.single_inh_entities
+ and (
+ ext_info.mapper._single_table_criterion is not None
+ or ("additional_entity_criteria", ext_info.mapper)
+ in self.global_attributes
+ )
+ and ext_info not in self.extra_criteria_entities
):
- self.single_inh_entities[ext_info] = (
+ self.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
)
- search = set(self.single_inh_entities.values())
+ search = set(self.extra_criteria_entities.values())
for (ext_info, adapter) in search:
if ext_info in self._join_entities:
continue
+
single_crit = ext_info.mapper._single_table_criterion
+
+ additional_entity_criteria = self._get_extra_criteria(ext_info)
+
if single_crit is not None:
+ additional_entity_criteria += (single_crit,)
+
+ current_adapter = self._get_current_adapter()
+ for crit in additional_entity_criteria:
if adapter:
- single_crit = adapter.traverse(single_crit)
+ crit = adapter.traverse(crit)
- current_adapter = self._get_current_adapter()
if current_adapter:
- single_crit = sql_util._deep_annotate(
- single_crit, {"_orm_adapt": True}
- )
- single_crit = current_adapter(single_crit, False)
- self._where_criteria += (single_crit,)
+ crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+ crit = current_adapter(crit, False)
+ self._where_criteria += (crit,)
def _column_descriptions(query_or_select_stmt, compile_state=None):
@@ -2205,9 +2247,13 @@ class _MapperEntity(_QueryEntity):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
- if single_table_crit is not None:
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
ext_info = self.entity_zero
- compile_state.single_inh_entities[ext_info] = (
+ compile_state.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
)
@@ -2528,8 +2574,13 @@ class _ORMColumnEntity(_ColumnEntity):
ezero = self.entity_zero
single_table_crit = self.mapper._single_table_criterion
- if single_table_crit is not None:
- compile_state.single_inh_entities[ezero] = (
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+
+ compile_state.extra_criteria_entities[ezero] = (
ezero,
ezero._adapter if ezero.is_aliased_class else None,
)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 4cf820ae3..068c85073 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -480,6 +480,32 @@ class PropComparator(operators.ColumnOperators):
return self.operate(PropComparator.of_type_op, class_)
+ def and_(self, *criteria):
+ """Add additional criteria to the ON clause that's represented by this
+ relationship attribute.
+
+ E.g.::
+
+
+ stmt = select(User).join(
+ User.addresses.and_(Address.email_address != 'foo')
+ )
+
+ stmt = select(User).options(
+ joinedload(User.addresses.and_(Address.email_address != 'foo'))
+ )
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`loader_option_criteria`
+
+ :func:`.with_loader_criteria`
+
+ """
+ return self.operate(operators.and_, *criteria)
+
def any(self, criterion=None, **kwargs):
r"""Return true if this collection contains any member that meets the
given criterion.
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d60c03bdc..68ca0365b 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1997,6 +1997,20 @@ class Query(
filter(a1.email_address == 'ed@foo.com').\
filter(a2.email_address == 'ed@bar.com')
+ **Augmenting Built-in ON Clauses**
+
+ As a substitute for providing a full custom ON condition for an
+ existing relationship, the :meth:`_orm.PropComparator.and_` function
+ may be applied to a relationship attribute to augment additional
+ criteria into the ON clause; the additional criteria will be combined
+ with the default criteria using AND::
+
+ q = session.query(User).join(
+ User.addresses.and_(Address.email_address != 'foo@bar.com')
+ )
+
+ .. versionadded:: 1.4
+
**Joining to Tables and Subqueries**
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index cb490b7d7..794b9422c 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -1115,9 +1115,15 @@ class RelationshipProperty(StrategizedProperty):
"""
_of_type = None
+ _extra_criteria = ()
def __init__(
- self, prop, parentmapper, adapt_to_entity=None, of_type=None
+ self,
+ prop,
+ parentmapper,
+ adapt_to_entity=None,
+ of_type=None,
+ extra_criteria=(),
):
"""Construction of :class:`.RelationshipProperty.Comparator`
is internal to the ORM's attribute mechanics.
@@ -1128,6 +1134,7 @@ class RelationshipProperty(StrategizedProperty):
self._adapt_to_entity = adapt_to_entity
if of_type:
self._of_type = of_type
+ self._extra_criteria = extra_criteria
def adapt_to_entity(self, adapt_to_entity):
return self.__class__(
@@ -1191,6 +1198,7 @@ class RelationshipProperty(StrategizedProperty):
source_polymorphic=True,
of_type_entity=of_type_entity,
alias_secondary=True,
+ extra_criteria=self._extra_criteria,
)
if sj is not None:
return pj & sj
@@ -1202,12 +1210,30 @@ class RelationshipProperty(StrategizedProperty):
See :meth:`.PropComparator.of_type` for an example.
+
"""
return RelationshipProperty.Comparator(
self.property,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
of_type=cls,
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ """Add AND criteria.
+
+ See :meth:`.PropComparator.and_` for an example.
+
+ .. versionadded:: 1.4
+
+ """
+ return RelationshipProperty.Comparator(
+ self.property,
+ self._parententity,
+ adapt_to_entity=self._adapt_to_entity,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
)
def in_(self, other):
@@ -2439,6 +2465,7 @@ class RelationshipProperty(StrategizedProperty):
dest_selectable=None,
of_type_entity=None,
alias_secondary=False,
+ extra_criteria=(),
):
aliased = False
@@ -2489,7 +2516,11 @@ class RelationshipProperty(StrategizedProperty):
target_adapter,
dest_selectable,
) = self._join_condition.join_targets(
- source_selectable, dest_selectable, aliased, single_crit
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit,
+ extra_criteria,
)
if source_selectable is None:
source_selectable = self.parent.local_table
@@ -3427,7 +3458,12 @@ class JoinCondition(object):
)
def join_targets(
- self, source_selectable, dest_selectable, aliased, single_crit=None
+ self,
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit=None,
+ extra_criteria=(),
):
"""Given a source and destination selectable, create a
join between them.
@@ -3463,6 +3499,12 @@ class JoinCondition(object):
else:
primaryjoin = primaryjoin & single_crit
+ if extra_criteria:
+ if secondaryjoin is not None:
+ secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
+ else:
+ primaryjoin = primaryjoin & sql.and_(*extra_criteria)
+
if aliased:
if secondary is not None:
secondary = secondary._anonymous_fromclause(flat=True)
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 339c57bdc..e9d4ac2c6 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -102,23 +102,26 @@ CLOSED = util.symbol("CLOSED")
class ORMExecuteState(util.MemoizedSlots):
- """Stateful object used for the :meth:`.SessionEvents.do_orm_execute`
+ """Represents a call to the :meth:`_orm.Session.execute` method, as passed
+ to the :meth:`.SessionEvents.do_orm_execute` event hook.
.. versionadded:: 1.4
+
"""
__slots__ = (
"session",
"statement",
"parameters",
- "_execution_options",
- "_merged_execution_options",
+ "execution_options",
+ "local_execution_options",
"bind_arguments",
"_compile_state_cls",
"_starting_event_idx",
"_events_todo",
"_future",
+ "_update_execution_options",
)
def __init__(
@@ -135,7 +138,10 @@ class ORMExecuteState(util.MemoizedSlots):
self.session = session
self.statement = statement
self.parameters = parameters
- self._execution_options = execution_options
+ self.local_execution_options = execution_options
+ self.execution_options = statement._execution_options.union(
+ execution_options
+ )
self.bind_arguments = bind_arguments
self._compile_state_cls = compile_state_cls
self._events_todo = list(events_todo)
@@ -182,9 +188,8 @@ class ORMExecuteState(util.MemoizedSlots):
.. seealso::
- :ref:`examples_caching` - includes example use of the
- :meth:`.SessionEvents.do_orm_execute` hook as well as the
- :meth:`.ORMExecuteState.invoke_query` method.
+ :ref:`do_orm_execute_re_executing` - background and examples on the
+ appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`.
"""
@@ -203,11 +208,9 @@ class ORMExecuteState(util.MemoizedSlots):
else:
_params = self.parameters
+ _execution_options = self.local_execution_options
if execution_options:
- _execution_options = dict(self._execution_options)
- _execution_options.update(execution_options)
- else:
- _execution_options = self._execution_options
+ _execution_options = _execution_options.union(execution_options)
return self.session.execute(
statement,
@@ -255,42 +258,9 @@ class ORMExecuteState(util.MemoizedSlots):
def _is_crud(self):
return isinstance(self.statement, (dml.Update, dml.Delete))
- @property
- def execution_options(self):
- """Placeholder for execution options.
-
- Raises an informative message, as there are local options
- vs. merged options that can be viewed, via the
- :attr:`.ORMExecuteState.local_execution_options` and
- :attr:`.ORMExecuteState.merged_execution_options` methods.
-
-
- """
- raise AttributeError(
- "Please use .local_execution_options or "
- ".merged_execution_options"
- )
-
- @property
- def local_execution_options(self):
- """Dictionary view of the execution options passed to the
- :meth:`.Session.execute` method. This does not include options
- that may be associated with the statement being invoked.
-
- """
- return util.immutabledict(self._execution_options)
-
- @property
- def merged_execution_options(self):
- """Dictionary view of all execution options merged together;
- this includes those of the statement as well as those passed to
- :meth:`.Session.execute`, with the local options taking precedence.
-
- """
- return self._merged_execution_options
-
- def _memoized_attr__merged_execution_options(self):
- return self.statement._execution_options.union(self._execution_options)
+ def update_execution_options(self, **opts):
+ # TODO: no coverage
+ self.local_execution_options = self.local_execution_options.union(opts)
def _orm_compile_options(self):
opts = self.statement._compile_options
@@ -329,6 +299,20 @@ class ORMExecuteState(util.MemoizedSlots):
return None
@property
+ def is_relationship_load(self):
+ """Return True if this load is loading objects on behalf of a
+ relationship.
+
+ This means, the loader in effect is either a LazyLoader,
+ SelectInLoader, SubqueryLoader, or similar, and the entire
+ SELECT statement being emitted is on behalf of a relationship
+ load.
+
+ """
+ path = self.loader_strategy_path
+ return path is not None and not path.is_root
+
+ @property
def load_options(self):
"""Return the load_options that will be used for this execution."""
@@ -337,7 +321,7 @@ class ORMExecuteState(util.MemoizedSlots):
"This ORM execution is not against a SELECT statement "
"so there are no load options."
)
- return self._execution_options.get(
+ return self.execution_options.get(
"_sa_orm_load_options", context.QueryContext.default_load_options
)
@@ -351,7 +335,7 @@ class ORMExecuteState(util.MemoizedSlots):
"This ORM execution is not against an UPDATE or DELETE "
"statement so there are no update options."
)
- return self._execution_options.get(
+ return self.execution_options.get(
"_sa_orm_update_options",
persistence.BulkUDCompileState.default_update_options,
)
@@ -1003,8 +987,6 @@ class Session(_SessionClassMethods):
:ref:`migration_20_toplevel`
- :ref:`migration_20_result_rows`
-
:param info: optional dictionary of arbitrary data to be associated
with this :class:`.Session`. Is available via the
:attr:`.Session.info` attribute. Note the dictionary is copied at
@@ -1282,7 +1264,7 @@ class Session(_SessionClassMethods):
the operation will release the current SAVEPOINT but not commit
the outermost database transaction.
- If :term:`2.x-style` use is in effect via the
+ If :term:`2.0-style` use is in effect via the
:paramref:`_orm.Session.future` flag, the outermost database
transaction is committed unconditionally, automatically releasing any
SAVEPOINTs in effect.
@@ -1416,7 +1398,7 @@ class Session(_SessionClassMethods):
self,
statement,
params=None,
- execution_options=util.immutabledict(),
+ execution_options=util.EMPTY_DICT,
bind_arguments=None,
future=False,
_parent_execute_state=None,
@@ -1576,6 +1558,8 @@ class Session(_SessionClassMethods):
else:
compile_state_cls = None
+ execution_options = util.coerce_to_immutabledict(execution_options)
+
if compile_state_cls is not None:
(
statement,
@@ -1591,8 +1575,11 @@ class Session(_SessionClassMethods):
else:
bind_arguments.setdefault("clause", statement)
if future:
- execution_options = util.immutabledict().merge_with(
- execution_options, {"future_result": True}
+ # not sure if immutabledict is working w/ this syntax
+ # execution_options =
+ # execution_options.union(future_result=True)
+ execution_options = execution_options.union(
+ {"future_result": True}
)
if _parent_execute_state:
@@ -1619,6 +1606,10 @@ class Session(_SessionClassMethods):
if result:
return result
+ # TODO: coverage for this pattern
+ statement = orm_exec_state.statement
+ execution_options = orm_exec_state.local_execution_options
+
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind, close_with_result=True)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 44f303fee..53166bd91 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -1975,6 +1975,7 @@ class JoinedLoader(AbstractRelationshipLoader):
clauses,
innerjoin,
chained_from_outerjoin,
+ loadopt._extra_criteria if loadopt else (),
)
)
@@ -1993,6 +1994,7 @@ class JoinedLoader(AbstractRelationshipLoader):
clauses,
innerjoin,
chained_from_outerjoin,
+ extra_criteria,
):
if parentmapper is None:
localparent = query_entity.mapper
@@ -2081,6 +2083,17 @@ class JoinedLoader(AbstractRelationshipLoader):
or query_entity.entity_zero.represents_outer_join
)
+ extra_join_criteria = extra_criteria
+ additional_entity_criteria = compile_state.global_attributes.get(
+ ("additional_entity_criteria", self.mapper), ()
+ )
+ if additional_entity_criteria:
+ extra_join_criteria += tuple(
+ ae._resolve_where_criteria(self.mapper)
+ for ae in additional_entity_criteria
+ if ae.propagate_to_loaders
+ )
+
if attach_on_outside:
# this is the "classic" eager join case.
eagerjoin = orm_util._ORMJoin(
@@ -2092,11 +2105,12 @@ class JoinedLoader(AbstractRelationshipLoader):
or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
_left_memo=self.parent,
_right_memo=self.mapper,
+ _extra_criteria=extra_join_criteria,
)
else:
# all other cases are innerjoin=='nested' approach
eagerjoin = self._splice_nested_inner_join(
- path, towrap, clauses, onclause,
+ path, towrap, clauses, onclause, extra_join_criteria
)
compile_state.eager_joins[query_entity_key] = eagerjoin
@@ -2128,7 +2142,7 @@ class JoinedLoader(AbstractRelationshipLoader):
)
def _splice_nested_inner_join(
- self, path, join_obj, clauses, onclause, splicing=False
+ self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
):
if splicing is False:
@@ -2137,7 +2151,12 @@ class JoinedLoader(AbstractRelationshipLoader):
assert isinstance(join_obj, orm_util._ORMJoin)
elif isinstance(join_obj, sql.selectable.FromGrouping):
return self._splice_nested_inner_join(
- path, join_obj.element, clauses, onclause, splicing,
+ path,
+ join_obj.element,
+ clauses,
+ onclause,
+ extra_criteria,
+ splicing,
)
elif not isinstance(join_obj, orm_util._ORMJoin):
if path[-2] is splicing:
@@ -2148,18 +2167,29 @@ class JoinedLoader(AbstractRelationshipLoader):
isouter=False,
_left_memo=splicing,
_right_memo=path[-1].mapper,
+ _extra_criteria=extra_criteria,
)
else:
# only here if splicing == True
return None
target_join = self._splice_nested_inner_join(
- path, join_obj.right, clauses, onclause, join_obj._right_memo,
+ path,
+ join_obj.right,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._right_memo,
)
if target_join is None:
right_splice = False
target_join = self._splice_nested_inner_join(
- path, join_obj.left, clauses, onclause, join_obj._left_memo,
+ path,
+ join_obj.left,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._left_memo,
)
if target_join is None:
# should only return None when recursively called,
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index b405153b9..b3913ec5b 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -78,6 +78,7 @@ class Load(Generative, LoaderOption):
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
(
"_context_cache_key",
visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
@@ -101,6 +102,7 @@ class Load(Generative, LoaderOption):
load.context = {}
load.local_opts = {}
load._of_type = None
+ load._extra_criteria = ()
return load
@property
@@ -124,6 +126,7 @@ class Load(Generative, LoaderOption):
strategy = None
propagate_to_loaders = False
_of_type = None
+ _extra_criteria = ()
def process_compile_state(self, compile_state):
if not compile_state.compile_options._enable_eagerloads:
@@ -248,6 +251,9 @@ class Load(Generative, LoaderOption):
else:
return None
+ if attr._extra_criteria:
+ self._extra_criteria = attr._extra_criteria
+
if getattr(attr, "_of_type", None):
ac = attr._of_type
ext_info = of_type_info = inspect(ac)
@@ -356,6 +362,7 @@ class Load(Generative, LoaderOption):
cloned = self._clone_for_bind_strategy(attr, strategy, "relationship")
self.path = cloned.path
self._of_type = cloned._of_type
+ self._extra_criteria = cloned._extra_criteria
cloned.is_class_strategy = self.is_class_strategy = False
self.propagate_to_loaders = cloned.propagate_to_loaders
@@ -413,6 +420,7 @@ class Load(Generative, LoaderOption):
if existing:
if merge_opts:
existing.local_opts.update(self.local_opts)
+ existing._extra_criteria += self._extra_criteria
else:
path.set(context, "loader", self)
else:
@@ -420,6 +428,7 @@ class Load(Generative, LoaderOption):
path.set(context, "loader", self)
if existing and existing.is_opts_only:
self.local_opts.update(existing.local_opts)
+ existing._extra_criteria += self._extra_criteria
def _set_path_strategy(self):
if not self.is_class_strategy and self.path.has_entity:
@@ -507,11 +516,13 @@ class _UnboundLoad(Load):
self.path = ()
self._to_bind = []
self.local_opts = {}
+ self._extra_criteria = ()
_cache_key_traversal = [
("path", visitors.ExtendedInternalTraversal.dp_multi_list),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
]
@@ -576,6 +587,7 @@ class _UnboundLoad(Load):
if attr:
path = path + (attr,)
self.path = path
+ self._extra_criteria = getattr(attr, "_extra_criteria", ())
return path
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 71ee29597..82fad0815 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -23,6 +23,7 @@ from .base import object_state # noqa
from .base import state_attribute_str # noqa
from .base import state_class_str # noqa
from .base import state_str # noqa
+from .interfaces import LoaderOption
from .interfaces import MapperProperty # noqa
from .interfaces import ORMColumnsClauseRole
from .interfaces import ORMEntityColumnsClauseRole
@@ -38,6 +39,7 @@ from ..engine.result import result_tuple
from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
+from ..sql import lambdas
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
@@ -854,6 +856,184 @@ class AliasedInsp(
return "aliased(%s)" % (self._target.__name__,)
+class LoaderCriteriaOption(LoaderOption):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ :class:`_orm.LoaderCriteriaOption` is invoked using the
+ :func:`_orm.with_loader_criteria` function; see that function for
+ details.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _traverse_internals = [
+ ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("where_criteria", visitors.InternalTraversal.dp_clauseelement),
+ ("include_aliases", visitors.InternalTraversal.dp_boolean),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(
+ self,
+ entity_or_base,
+ where_criteria,
+ loader_only=False,
+ include_aliases=False,
+ propagate_to_loaders=True,
+ ):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ .. versionadded:: 1.4
+
+ The :func:`_orm.with_loader_criteria` option is intended to add
+ limiting criteria to a particular kind of entity in a query,
+ **globally**, meaning it will apply to the entity as it appears
+ in the SELECT query as well as within any subqueries, join
+ conditions, and relationship loads, including both eager and lazy
+ loaders, without the need for it to be specified in any particular
+ part of the query. The rendering logic uses the same system used by
+ single table inheritance to ensure a certain discriminator is applied
+ to a table.
+
+ E.g., using :term:`2.0-style` queries, we can limit the way the
+ ``User.addresses`` collection is loaded, regardless of the kind
+ of loading used::
+
+ from sqlalchemy.orm import with_loader_criteria
+
+ stmt = select(User).options(
+ selectinload(User.addresses),
+ with_loader_criteria(Address, Address.email_address != 'foo'))
+ )
+
+ Above, the "selectinload" for ``User.addresses`` will apply the
+ given filtering criteria to the WHERE clause.
+
+ Another example, where the filtering will be applied to the
+ ON clause of the join, in this example using :term:`1.x style`
+ queries::
+
+ q = session.query(User).outerjoin(User.addresses).options(
+ with_loader_criteria(Address, Address.email_address != 'foo'))
+ )
+
+ The primary purpose of :func:`_orm.with_loader_criteria` is to use
+ it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler
+ to ensure that all occurrences of a particular entity are filtered
+ in a certain way, such as filtering for access control roles. It
+ also can be used to apply criteria to relationship loads. In the
+ example below, we can apply a certain set of rules to all queries
+ emitted by a particular :class:`_orm.Session`::
+
+ session = Session(bind=engine)
+
+ @event.listens_for("do_orm_execute", session)
+ def _add_filtering_criteria(execute_state):
+ execute_state.statement = execute_state.statement.options(
+ with_loader_criteria(
+ SecurityRole,
+ lambda cls: cls.role.in_(['some_role']),
+ include_aliases=True
+ )
+ )
+
+ The given class will expand to include all mapped subclass and
+ need not itself be a mapped class.
+
+
+ :param entity_or_base: a mapped class, or a class that is a super
+ class of a particular set of mapped classes, to which the rule
+ will apply.
+
+ :param where_criteria: a Core SQL expression that applies limiting
+ criteria. This may also be a "lambda:" or Python function that
+ accepts a target class as an argument, when the given class is
+ a base with many different mapped subclasses.
+
+ :param include_aliases: if True, apply the rule to :func:`_orm.aliased`
+ constructs as well.
+
+ :param propagate_to_loaders: defaults to True, apply to relationship
+ loaders such as lazy loaders.
+
+
+ .. seealso::
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :func:`_orm.with_loader_criteria`.
+
+ :ref:`do_orm_execute_global_criteria` - basic example on how to
+ combine :func:`_orm.with_loader_criteria` with the
+ :meth:`_orm.SessionEvents.do_orm_execute` event.
+
+ """
+ entity = inspection.inspect(entity_or_base, False)
+ if entity is None:
+ self.root_entity = entity_or_base
+ self.entity = None
+ else:
+ self.root_entity = None
+ self.entity = entity
+
+ if callable(where_criteria):
+ self.deferred_where_criteria = True
+ self.where_criteria = lambdas.DeferredLambdaElement(
+ where_criteria,
+ roles.WhereHavingRole,
+ lambda_args=(
+ self.root_entity
+ if self.root_entity is not None
+ else self.entity.entity,
+ ),
+ )
+ else:
+ self.deferred_where_criteria = False
+ self.where_criteria = coercions.expect(
+ roles.WhereHavingRole, where_criteria
+ )
+
+ self.include_aliases = include_aliases
+ self.propagate_to_loaders = propagate_to_loaders
+
+ def _all_mappers(self):
+ if self.entity:
+ for ent in self.entity.mapper.self_and_descendants:
+ yield ent
+ else:
+ stack = list(self.root_entity.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ ent = inspection.inspect(subclass)
+ if ent:
+ for mp in ent.mapper.self_and_descendants:
+ yield mp
+ else:
+ stack.extend(subclass.__subclasses__())
+
+ def _resolve_where_criteria(self, ext_info):
+ if self.deferred_where_criteria:
+ return self.where_criteria._resolve_with_args(ext_info.entity)
+ else:
+ return self.where_criteria
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ # if options to limit the criteria to immediate query only,
+ # use compile_state.attributes instead
+
+ for mp in self._all_mappers():
+ load_criteria = compile_state.global_attributes.setdefault(
+ ("additional_entity_criteria", mp), []
+ )
+
+ load_criteria.append(self)
+
+
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
inspection._inspects(AliasedInsp)(lambda target: target)
@@ -1270,6 +1450,7 @@ class _ORMJoin(expression.Join):
full=False,
_left_memo=None,
_right_memo=None,
+ _extra_criteria=(),
):
left_info = inspection.inspect(left)
@@ -1291,6 +1472,7 @@ class _ORMJoin(expression.Join):
if isinstance(onclause, attributes.QueryableAttribute):
on_selectable = onclause.comparator._source_selectable()
prop = onclause.property
+ _extra_criteria += onclause._extra_criteria
elif isinstance(onclause, MapperProperty):
# used internally by joined eager loader...possibly not ideal
prop = onclause
@@ -1319,6 +1501,7 @@ class _ORMJoin(expression.Join):
source_polymorphic=True,
of_type_entity=right_info,
alias_secondary=True,
+ extra_criteria=_extra_criteria,
)
if sj is not None:
@@ -1331,6 +1514,7 @@ class _ORMJoin(expression.Join):
onclause = sj
else:
onclause = pj
+
self._target_adapter = target_adapter
expression.Join.__init__(self, left, right, onclause, isouter, full)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ac4055bdf..b8984316c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -792,6 +792,10 @@ class SQLCompiler(Compiled):
def prefetch(self):
return list(self.insert_prefetch + self.update_prefetch)
+ @util.memoized_property
+ def _global_attributes(self):
+ return {}
+
@util.memoized_instancemethod
def _init_cte_state(self):
"""Initialize collections related to CTEs only if