summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/util.py')
-rw-r--r--lib/sqlalchemy/orm/util.py184
1 files changed, 184 insertions, 0 deletions
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 71ee29597..82fad0815 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -23,6 +23,7 @@ from .base import object_state # noqa
from .base import state_attribute_str # noqa
from .base import state_class_str # noqa
from .base import state_str # noqa
+from .interfaces import LoaderOption
from .interfaces import MapperProperty # noqa
from .interfaces import ORMColumnsClauseRole
from .interfaces import ORMEntityColumnsClauseRole
@@ -38,6 +39,7 @@ from ..engine.result import result_tuple
from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
+from ..sql import lambdas
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
@@ -854,6 +856,184 @@ class AliasedInsp(
return "aliased(%s)" % (self._target.__name__,)
+class LoaderCriteriaOption(LoaderOption):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ :class:`_orm.LoaderCriteriaOption` is invoked using the
+ :func:`_orm.with_loader_criteria` function; see that function for
+ details.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _traverse_internals = [
+ ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("where_criteria", visitors.InternalTraversal.dp_clauseelement),
+ ("include_aliases", visitors.InternalTraversal.dp_boolean),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(
+ self,
+ entity_or_base,
+ where_criteria,
+ loader_only=False,
+ include_aliases=False,
+ propagate_to_loaders=True,
+ ):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ .. versionadded:: 1.4
+
+ The :func:`_orm.with_loader_criteria` option is intended to add
+ limiting criteria to a particular kind of entity in a query,
+ **globally**, meaning it will apply to the entity as it appears
+ in the SELECT query as well as within any subqueries, join
+ conditions, and relationship loads, including both eager and lazy
+ loaders, without the need for it to be specified in any particular
+ part of the query. The rendering logic uses the same system used by
+ single table inheritance to ensure a certain discriminator is applied
+ to a table.
+
+ E.g., using :term:`2.0-style` queries, we can limit the way the
+ ``User.addresses`` collection is loaded, regardless of the kind
+ of loading used::
+
+ from sqlalchemy.orm import with_loader_criteria
+
+ stmt = select(User).options(
+ selectinload(User.addresses),
+ with_loader_criteria(Address, Address.email_address != 'foo'))
+ )
+
+ Above, the "selectinload" for ``User.addresses`` will apply the
+ given filtering criteria to the WHERE clause.
+
+ Another example, where the filtering will be applied to the
+ ON clause of the join, in this example using :term:`1.x style`
+ queries::
+
+ q = session.query(User).outerjoin(User.addresses).options(
+ with_loader_criteria(Address, Address.email_address != 'foo'))
+ )
+
+ The primary purpose of :func:`_orm.with_loader_criteria` is to use
+ it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler
+ to ensure that all occurrences of a particular entity are filtered
+ in a certain way, such as filtering for access control roles. It
+ also can be used to apply criteria to relationship loads. In the
+ example below, we can apply a certain set of rules to all queries
+ emitted by a particular :class:`_orm.Session`::
+
+ session = Session(bind=engine)
+
+ @event.listens_for("do_orm_execute", session)
+ def _add_filtering_criteria(execute_state):
+ execute_state.statement = execute_state.statement.options(
+ with_loader_criteria(
+ SecurityRole,
+ lambda cls: cls.role.in_(['some_role']),
+ include_aliases=True
+ )
+ )
+
+ The given class will expand to include all mapped subclass and
+ need not itself be a mapped class.
+
+
+ :param entity_or_base: a mapped class, or a class that is a super
+ class of a particular set of mapped classes, to which the rule
+ will apply.
+
+ :param where_criteria: a Core SQL expression that applies limiting
+ criteria. This may also be a "lambda:" or Python function that
+ accepts a target class as an argument, when the given class is
+ a base with many different mapped subclasses.
+
+ :param include_aliases: if True, apply the rule to :func:`_orm.aliased`
+ constructs as well.
+
+ :param propagate_to_loaders: defaults to True, apply to relationship
+ loaders such as lazy loaders.
+
+
+ .. seealso::
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :func:`_orm.with_loader_criteria`.
+
+ :ref:`do_orm_execute_global_criteria` - basic example on how to
+ combine :func:`_orm.with_loader_criteria` with the
+ :meth:`_orm.SessionEvents.do_orm_execute` event.
+
+ """
+ entity = inspection.inspect(entity_or_base, False)
+ if entity is None:
+ self.root_entity = entity_or_base
+ self.entity = None
+ else:
+ self.root_entity = None
+ self.entity = entity
+
+ if callable(where_criteria):
+ self.deferred_where_criteria = True
+ self.where_criteria = lambdas.DeferredLambdaElement(
+ where_criteria,
+ roles.WhereHavingRole,
+ lambda_args=(
+ self.root_entity
+ if self.root_entity is not None
+ else self.entity.entity,
+ ),
+ )
+ else:
+ self.deferred_where_criteria = False
+ self.where_criteria = coercions.expect(
+ roles.WhereHavingRole, where_criteria
+ )
+
+ self.include_aliases = include_aliases
+ self.propagate_to_loaders = propagate_to_loaders
+
+ def _all_mappers(self):
+ if self.entity:
+ for ent in self.entity.mapper.self_and_descendants:
+ yield ent
+ else:
+ stack = list(self.root_entity.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ ent = inspection.inspect(subclass)
+ if ent:
+ for mp in ent.mapper.self_and_descendants:
+ yield mp
+ else:
+ stack.extend(subclass.__subclasses__())
+
+ def _resolve_where_criteria(self, ext_info):
+ if self.deferred_where_criteria:
+ return self.where_criteria._resolve_with_args(ext_info.entity)
+ else:
+ return self.where_criteria
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ # if options to limit the criteria to immediate query only,
+ # use compile_state.attributes instead
+
+ for mp in self._all_mappers():
+ load_criteria = compile_state.global_attributes.setdefault(
+ ("additional_entity_criteria", mp), []
+ )
+
+ load_criteria.append(self)
+
+
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
inspection._inspects(AliasedInsp)(lambda target: target)
@@ -1270,6 +1450,7 @@ class _ORMJoin(expression.Join):
full=False,
_left_memo=None,
_right_memo=None,
+ _extra_criteria=(),
):
left_info = inspection.inspect(left)
@@ -1291,6 +1472,7 @@ class _ORMJoin(expression.Join):
if isinstance(onclause, attributes.QueryableAttribute):
on_selectable = onclause.comparator._source_selectable()
prop = onclause.property
+ _extra_criteria += onclause._extra_criteria
elif isinstance(onclause, MapperProperty):
# used internally by joined eager loader...possibly not ideal
prop = onclause
@@ -1319,6 +1501,7 @@ class _ORMJoin(expression.Join):
source_polymorphic=True,
of_type_entity=right_info,
alias_secondary=True,
+ extra_criteria=_extra_criteria,
)
if sj is not None:
@@ -1331,6 +1514,7 @@ class _ORMJoin(expression.Join):
onclause = sj
else:
onclause = pj
+
self._target_adapter = target_adapter
expression.Join.__init__(self, left, right, onclause, isouter, full)