summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_14/4472.rst19
-rw-r--r--doc/build/orm/loading_relationships.rst48
-rw-r--r--doc/build/orm/query.rst2
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py5
-rw-r--r--lib/sqlalchemy/orm/__init__.py3
-rw-r--r--lib/sqlalchemy/orm/attributes.py28
-rw-r--r--lib/sqlalchemy/orm/context.py107
-rw-r--r--lib/sqlalchemy/orm/interfaces.py26
-rw-r--r--lib/sqlalchemy/orm/query.py14
-rw-r--r--lib/sqlalchemy/orm/relationships.py48
-rw-r--r--lib/sqlalchemy/orm/session.py101
-rw-r--r--lib/sqlalchemy/orm/strategies.py40
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py12
-rw-r--r--lib/sqlalchemy/orm/util.py184
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--test/ext/test_baked.py4
-rw-r--r--test/orm/inheritance/test_polymorphic_rel.py22
-rw-r--r--test/orm/test_bundle.py42
-rw-r--r--test/orm/test_cache_key.py68
-rw-r--r--test/orm/test_events.py166
-rw-r--r--test/orm/test_options.py1
-rw-r--r--test/orm/test_relationship_criteria.py867
-rw-r--r--test/sql/test_compare.py20
23 files changed, 1731 insertions, 100 deletions
diff --git a/doc/build/changelog/unreleased_14/4472.rst b/doc/build/changelog/unreleased_14/4472.rst
new file mode 100644
index 000000000..6de5058c1
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/4472.rst
@@ -0,0 +1,19 @@
+.. change::
+ :tags: feature, orm
+ :tickets: 4472
+
+ Added the ability to add arbitrary criteria to the ON clause generated
+ by a relationship attribute in a query, which applies to methods such
+ as :meth:`_query.Query.join` as well as loader options like
+ :func:`_orm.joinedload`. Additionally, a "global" version of the option
+ allows limiting criteria to be applied to particular entities in
+ a query globally.
+
+ .. seealso::
+
+ :ref:`loader_option_criteria`
+
+ :func:`_orm.with_loader_criteria`
+
+ .. TODO: add links to new examples section and session-related
+ documentation involving do_orm_execute event when merged \ No newline at end of file
diff --git a/doc/build/orm/loading_relationships.rst b/doc/build/orm/loading_relationships.rst
index 50d3cc51a..8909d9a6e 100644
--- a/doc/build/orm/loading_relationships.rst
+++ b/doc/build/orm/loading_relationships.rst
@@ -112,13 +112,10 @@ the string name of an attribute against a parent, or for greater specificity
can accommodate a class-bound attribute directly::
# set children to load lazily
- session.query(Parent).options(lazyload('children')).all()
-
- # same, using class-bound attribute
session.query(Parent).options(lazyload(Parent.children)).all()
# set children to load eagerly with a join
- session.query(Parent).options(joinedload('children')).all()
+ session.query(Parent).options(joinedload(Parent.children)).all()
The loader options can also be "chained" using **method chaining**
to specify how loading should occur further levels deep::
@@ -141,6 +138,48 @@ collections loaded. When the ``children`` collection on a particular
objects, but additionally apply eager loading to the ``subelements``
collection on each member of ``children``.
+The above examples, using :class:`_orm.Query`, are now referred to as
+:term:`1.x style` queries. The options system is available as well for
+:term:`2.0 style` queries using the :meth:`_sql.Select.options` method::
+
+ stmt = select(Parent).options(
+ lazyload(Parent.children).
+ subqueryload(Child.subelements))
+
+ result = session.execute(stmt)
+
+Under the hood, :class:`_orm.Query` is ultimately using the above
+:class:`_sql.select` based mechanism.
+
+
+.. _loader_option_criteria:
+
+Adding Criteria to loader options
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The relationship attributes used to indicate loader options include the
+ability to add additional filtering criteria to the ON clause of the join
+that's created, or to the WHERE criteria involved, depending on the loader
+strategy. This can be achieved using the :meth:`.PropComparator.and_`
+method which will pass through an option such that loaded results are limited
+to the given filter criteria::
+
+ session.query(A).options(lazyload(A.bs.and_(B.id > 5)))
+
+When using limiting criteria, if a particular collection is already loaded
+it won't be refreshed; to ensure the new criteria takes place, apply
+the :meth:`_orm.Query.populate_existing` option::
+
+ session.query(A).options(lazyload(A.bs.and_(B.id > 5))).populate_existing()
+
+In order to add filtering criteria to all occurrences of an entity throughout
+a query, regardless of loader strategy or where it occurs in the loading
+process, see the :func:`_orm.with_loader_criteria` function.
+
+.. versionadded:: 1.4
+
+Specifying Sub-Options with Load.options()
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Using method chaining, the loader style of each link in the path is explicitly
stated. To navigate along a path without changing the existing loader style
of a particular attribute, the :func:`.defaultload` method/function may be used::
@@ -1263,6 +1302,7 @@ Relationship Loader API
.. autofunction:: lazyload
.. autoclass:: Load
+ :members:
.. autofunction:: noload
diff --git a/doc/build/orm/query.rst b/doc/build/orm/query.rst
index 3fddd6c34..ed45a65e7 100644
--- a/doc/build/orm/query.rst
+++ b/doc/build/orm/query.rst
@@ -44,6 +44,8 @@ ORM-Specific Query Constructs
.. autoclass:: sqlalchemy.orm.strategy_options.Load
:members:
+.. autofunction:: sqlalchemy.orm.with_loader_criteria
+
.. autofunction:: join
.. autofunction:: outerjoin
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index 53545826b..fe9bbaf02 100644
--- a/lib/sqlalchemy/ext/horizontal_shard.py
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -207,7 +207,6 @@ class ShardedSession(Session):
def execute_and_instances(orm_context):
-
if orm_context.is_select:
load_options = active_options = orm_context.load_options
update_options = None
@@ -237,8 +236,8 @@ def execute_and_instances(orm_context):
if active_options._refresh_identity_token is not None:
shard_id = active_options._refresh_identity_token
- elif "_sa_shard_id" in orm_context.merged_execution_options:
- shard_id = orm_context.merged_execution_options["_sa_shard_id"]
+ elif "_sa_shard_id" in orm_context.execution_options:
+ shard_id = orm_context.execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
shard_id = orm_context.bind_arguments["shard_id"]
else:
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
index 32ec60322..458103838 100644
--- a/lib/sqlalchemy/orm/__init__.py
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -48,6 +48,7 @@ from .strategy_options import Load # noqa
from .util import aliased # noqa
from .util import Bundle # noqa
from .util import join # noqa
+from .util import LoaderCriteriaOption # noqa
from .util import object_mapper # noqa
from .util import outerjoin # noqa
from .util import polymorphic_union # noqa
@@ -101,6 +102,8 @@ def create_session(bind=None, **kwargs):
return Session(bind=bind, **kwargs)
+with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm")
+
relationship = public_factory(RelationshipProperty, ".orm.relationship")
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 6dd95a5a9..2e1b9dc75 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -50,6 +50,7 @@ from .. import inspection
from .. import util
from ..sql import base as sql_base
from ..sql import roles
+from ..sql import traversals
from ..sql import visitors
@@ -58,6 +59,7 @@ class QueryableAttribute(
interfaces._MappedAttribute,
interfaces.InspectionAttr,
interfaces.PropComparator,
+ traversals.HasCopyInternals,
roles.JoinTargetRole,
roles.OnClauseRole,
sql_base.Immutable,
@@ -91,6 +93,7 @@ class QueryableAttribute(
impl=None,
comparator=None,
of_type=None,
+ extra_criteria=(),
):
self.class_ = class_
self.key = key
@@ -98,6 +101,7 @@ class QueryableAttribute(
self.impl = impl
self.comparator = comparator
self._of_type = of_type
+ self._extra_criteria = extra_criteria
manager = manager_of_class(class_)
# manager is None in the case of AliasedClass
@@ -114,6 +118,7 @@ class QueryableAttribute(
("key", visitors.ExtendedInternalTraversal.dp_string),
("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
]
def __reduce__(self):
@@ -240,6 +245,29 @@ class QueryableAttribute(
impl=self.impl,
comparator=self.comparator.of_type(entity),
of_type=inspection.inspect(entity),
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.and_(*other),
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
+ )
+
+ def _clone(self, **kw):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria,
)
def label(self, name):
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 96725e55b..a35b2f9fd 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -11,9 +11,9 @@ from .base import _is_aliased_class
from .interfaces import ORMColumnsClauseRole
from .path_registry import PathRegistry
from .util import _entity_corresponds_to
+from .util import _ORMJoin
from .util import aliased
from .util import Bundle
-from .util import join as orm_join
from .util import ORMAdapter
from .. import exc as sa_exc
from .. import future
@@ -78,7 +78,6 @@ class QueryContext(object):
_yield_per = None
_refresh_state = None
_lazy_loaded_from = None
- _params = _EMPTY_DICT
def __init__(
self,
@@ -308,6 +307,9 @@ class ORMFromStatementCompileState(ORMCompileState):
multi_row_eager_loaders = False
compound_eager_adapter = None
+ extra_criteria_entities = _EMPTY_DICT
+ eager_joins = _EMPTY_DICT
+
@classmethod
def create_for_statement(cls, statement_container, compiler, **kw):
@@ -338,6 +340,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if toplevel and statement_container._with_options:
self.attributes = {"_unbound_load_dedupes": set()}
+ self.global_attributes = compiler._global_attributes
for opt in statement_container._with_options:
if opt._is_compile_state:
@@ -345,6 +348,7 @@ class ORMFromStatementCompileState(ORMCompileState):
else:
self.attributes = {}
+ self.global_attributes = compiler._global_attributes
if statement_container._with_context_options:
for fn, key in statement_container._with_context_options:
@@ -352,8 +356,6 @@ class ORMFromStatementCompileState(ORMCompileState):
self.primary_columns = []
self.secondary_columns = []
- self.eager_joins = {}
- self.single_inh_entities = {}
self.create_eager_joins = []
self._fallback_from_clauses = []
@@ -423,11 +425,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def create_for_statement(cls, statement, compiler, **kw):
"""compiler hook, we arrive here from compiler.visit_select() only."""
+ self = cls.__new__(cls)
+
if compiler is not None:
toplevel = not compiler.stack
compiler._rewrites_selected_columns = True
+ self.global_attributes = compiler._global_attributes
else:
toplevel = True
+ self.global_attributes = {}
select_statement = statement
@@ -437,8 +443,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement._compile_options
)
- self = cls.__new__(cls)
-
if select_statement._execution_options:
# execution options should not impact the compilation of a
# query, and at the moment subqueryloader is putting some things
@@ -516,7 +520,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.primary_columns = []
self.secondary_columns = []
self.eager_joins = {}
- self.single_inh_entities = {}
+ self.extra_criteria_entities = {}
self.create_eager_joins = []
self._fallback_from_clauses = []
@@ -634,7 +638,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if self.compile_options._enable_single_crit:
- self._adjust_for_single_inheritance()
+ self._adjust_for_extra_criteria()
if not self.primary_columns:
if self.compile_options._only_load_props:
@@ -1408,6 +1412,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left, right, onclause, prop, create_aliases, aliased_generation
)
+ if not r_info.is_selectable:
+ extra_criteria = self._get_extra_criteria(r_info)
+ else:
+ extra_criteria = ()
+
if replace_from_obj_index is not None:
# splice into an existing element in the
# self._from_obj list
@@ -1416,12 +1425,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.from_clauses = (
self.from_clauses[:replace_from_obj_index]
+ [
- orm_join(
+ _ORMJoin(
left_clause,
right,
onclause,
isouter=outerjoin,
full=full,
+ _extra_criteria=extra_criteria,
)
]
+ self.from_clauses[replace_from_obj_index + 1 :]
@@ -1440,8 +1450,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left_clause = left
self.from_clauses = self.from_clauses + [
- orm_join(
- left_clause, r_info, onclause, isouter=outerjoin, full=full
+ _ORMJoin(
+ left_clause,
+ r_info,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
)
]
@@ -1848,8 +1863,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
or kwargs.get("group_by", False)
)
- def _adjust_for_single_inheritance(self):
- """Apply single-table-inheritance filtering.
+ def _get_extra_criteria(self, ext_info):
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in self.global_attributes:
+ return tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in self.global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+ else:
+ return ()
+
+ def _adjust_for_extra_criteria(self):
+ """Apply extra criteria filtering.
For all distinct single-table-inheritance mappers represented in
the columns clause of this query, as well as the "select from entity",
@@ -1857,38 +1887,50 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
clause of the given QueryContext such that only the appropriate
subtypes are selected from the total results.
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ associated with the global context.
+
"""
for fromclause in self.from_clauses:
ext_info = fromclause._annotations.get("parententity", None)
if (
ext_info
- and ext_info.mapper._single_table_criterion is not None
- and ext_info not in self.single_inh_entities
+ and (
+ ext_info.mapper._single_table_criterion is not None
+ or ("additional_entity_criteria", ext_info.mapper)
+ in self.global_attributes
+ )
+ and ext_info not in self.extra_criteria_entities
):
- self.single_inh_entities[ext_info] = (
+ self.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
)
- search = set(self.single_inh_entities.values())
+ search = set(self.extra_criteria_entities.values())
for (ext_info, adapter) in search:
if ext_info in self._join_entities:
continue
+
single_crit = ext_info.mapper._single_table_criterion
+
+ additional_entity_criteria = self._get_extra_criteria(ext_info)
+
if single_crit is not None:
+ additional_entity_criteria += (single_crit,)
+
+ current_adapter = self._get_current_adapter()
+ for crit in additional_entity_criteria:
if adapter:
- single_crit = adapter.traverse(single_crit)
+ crit = adapter.traverse(crit)
- current_adapter = self._get_current_adapter()
if current_adapter:
- single_crit = sql_util._deep_annotate(
- single_crit, {"_orm_adapt": True}
- )
- single_crit = current_adapter(single_crit, False)
- self._where_criteria += (single_crit,)
+ crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+ crit = current_adapter(crit, False)
+ self._where_criteria += (crit,)
def _column_descriptions(query_or_select_stmt, compile_state=None):
@@ -2205,9 +2247,13 @@ class _MapperEntity(_QueryEntity):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
- if single_table_crit is not None:
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
ext_info = self.entity_zero
- compile_state.single_inh_entities[ext_info] = (
+ compile_state.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
)
@@ -2528,8 +2574,13 @@ class _ORMColumnEntity(_ColumnEntity):
ezero = self.entity_zero
single_table_crit = self.mapper._single_table_criterion
- if single_table_crit is not None:
- compile_state.single_inh_entities[ezero] = (
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+
+ compile_state.extra_criteria_entities[ezero] = (
ezero,
ezero._adapter if ezero.is_aliased_class else None,
)
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index 4cf820ae3..068c85073 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -480,6 +480,32 @@ class PropComparator(operators.ColumnOperators):
return self.operate(PropComparator.of_type_op, class_)
+ def and_(self, *criteria):
+ """Add additional criteria to the ON clause that's represented by this
+ relationship attribute.
+
+ E.g.::
+
+
+ stmt = select(User).join(
+ User.addresses.and_(Address.email_address != 'foo')
+ )
+
+ stmt = select(User).options(
+ joinedload(User.addresses.and_(Address.email_address != 'foo'))
+ )
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`loader_option_criteria`
+
+ :func:`.with_loader_criteria`
+
+ """
+ return self.operate(operators.and_, *criteria)
+
def any(self, criterion=None, **kwargs):
r"""Return true if this collection contains any member that meets the
given criterion.
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index d60c03bdc..68ca0365b 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1997,6 +1997,20 @@ class Query(
filter(a1.email_address == 'ed@foo.com').\
filter(a2.email_address == 'ed@bar.com')
+ **Augmenting Built-in ON Clauses**
+
+ As a substitute for providing a full custom ON condition for an
+ existing relationship, the :meth:`_orm.PropComparator.and_` function
+ may be applied to a relationship attribute to augment additional
+ criteria into the ON clause; the additional criteria will be combined
+ with the default criteria using AND::
+
+ q = session.query(User).join(
+ User.addresses.and_(Address.email_address != 'foo@bar.com')
+ )
+
+ .. versionadded:: 1.4
+
**Joining to Tables and Subqueries**
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
index cb490b7d7..794b9422c 100644
--- a/lib/sqlalchemy/orm/relationships.py
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -1115,9 +1115,15 @@ class RelationshipProperty(StrategizedProperty):
"""
_of_type = None
+ _extra_criteria = ()
def __init__(
- self, prop, parentmapper, adapt_to_entity=None, of_type=None
+ self,
+ prop,
+ parentmapper,
+ adapt_to_entity=None,
+ of_type=None,
+ extra_criteria=(),
):
"""Construction of :class:`.RelationshipProperty.Comparator`
is internal to the ORM's attribute mechanics.
@@ -1128,6 +1134,7 @@ class RelationshipProperty(StrategizedProperty):
self._adapt_to_entity = adapt_to_entity
if of_type:
self._of_type = of_type
+ self._extra_criteria = extra_criteria
def adapt_to_entity(self, adapt_to_entity):
return self.__class__(
@@ -1191,6 +1198,7 @@ class RelationshipProperty(StrategizedProperty):
source_polymorphic=True,
of_type_entity=of_type_entity,
alias_secondary=True,
+ extra_criteria=self._extra_criteria,
)
if sj is not None:
return pj & sj
@@ -1202,12 +1210,30 @@ class RelationshipProperty(StrategizedProperty):
See :meth:`.PropComparator.of_type` for an example.
+
"""
return RelationshipProperty.Comparator(
self.property,
self._parententity,
adapt_to_entity=self._adapt_to_entity,
of_type=cls,
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ """Add AND criteria.
+
+ See :meth:`.PropComparator.and_` for an example.
+
+ .. versionadded:: 1.4
+
+ """
+ return RelationshipProperty.Comparator(
+ self.property,
+ self._parententity,
+ adapt_to_entity=self._adapt_to_entity,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
)
def in_(self, other):
@@ -2439,6 +2465,7 @@ class RelationshipProperty(StrategizedProperty):
dest_selectable=None,
of_type_entity=None,
alias_secondary=False,
+ extra_criteria=(),
):
aliased = False
@@ -2489,7 +2516,11 @@ class RelationshipProperty(StrategizedProperty):
target_adapter,
dest_selectable,
) = self._join_condition.join_targets(
- source_selectable, dest_selectable, aliased, single_crit
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit,
+ extra_criteria,
)
if source_selectable is None:
source_selectable = self.parent.local_table
@@ -3427,7 +3458,12 @@ class JoinCondition(object):
)
def join_targets(
- self, source_selectable, dest_selectable, aliased, single_crit=None
+ self,
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit=None,
+ extra_criteria=(),
):
"""Given a source and destination selectable, create a
join between them.
@@ -3463,6 +3499,12 @@ class JoinCondition(object):
else:
primaryjoin = primaryjoin & single_crit
+ if extra_criteria:
+ if secondaryjoin is not None:
+ secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
+ else:
+ primaryjoin = primaryjoin & sql.and_(*extra_criteria)
+
if aliased:
if secondary is not None:
secondary = secondary._anonymous_fromclause(flat=True)
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 339c57bdc..e9d4ac2c6 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -102,23 +102,26 @@ CLOSED = util.symbol("CLOSED")
class ORMExecuteState(util.MemoizedSlots):
- """Stateful object used for the :meth:`.SessionEvents.do_orm_execute`
+ """Represents a call to the :meth:`_orm.Session.execute` method, as passed
+ to the :meth:`.SessionEvents.do_orm_execute` event hook.
.. versionadded:: 1.4
+
"""
__slots__ = (
"session",
"statement",
"parameters",
- "_execution_options",
- "_merged_execution_options",
+ "execution_options",
+ "local_execution_options",
"bind_arguments",
"_compile_state_cls",
"_starting_event_idx",
"_events_todo",
"_future",
+ "_update_execution_options",
)
def __init__(
@@ -135,7 +138,10 @@ class ORMExecuteState(util.MemoizedSlots):
self.session = session
self.statement = statement
self.parameters = parameters
- self._execution_options = execution_options
+ self.local_execution_options = execution_options
+ self.execution_options = statement._execution_options.union(
+ execution_options
+ )
self.bind_arguments = bind_arguments
self._compile_state_cls = compile_state_cls
self._events_todo = list(events_todo)
@@ -182,9 +188,8 @@ class ORMExecuteState(util.MemoizedSlots):
.. seealso::
- :ref:`examples_caching` - includes example use of the
- :meth:`.SessionEvents.do_orm_execute` hook as well as the
- :meth:`.ORMExecuteState.invoke_query` method.
+ :ref:`do_orm_execute_re_executing` - background and examples on the
+ appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`.
"""
@@ -203,11 +208,9 @@ class ORMExecuteState(util.MemoizedSlots):
else:
_params = self.parameters
+ _execution_options = self.local_execution_options
if execution_options:
- _execution_options = dict(self._execution_options)
- _execution_options.update(execution_options)
- else:
- _execution_options = self._execution_options
+ _execution_options = _execution_options.union(execution_options)
return self.session.execute(
statement,
@@ -255,42 +258,9 @@ class ORMExecuteState(util.MemoizedSlots):
def _is_crud(self):
return isinstance(self.statement, (dml.Update, dml.Delete))
- @property
- def execution_options(self):
- """Placeholder for execution options.
-
- Raises an informative message, as there are local options
- vs. merged options that can be viewed, via the
- :attr:`.ORMExecuteState.local_execution_options` and
- :attr:`.ORMExecuteState.merged_execution_options` methods.
-
-
- """
- raise AttributeError(
- "Please use .local_execution_options or "
- ".merged_execution_options"
- )
-
- @property
- def local_execution_options(self):
- """Dictionary view of the execution options passed to the
- :meth:`.Session.execute` method. This does not include options
- that may be associated with the statement being invoked.
-
- """
- return util.immutabledict(self._execution_options)
-
- @property
- def merged_execution_options(self):
- """Dictionary view of all execution options merged together;
- this includes those of the statement as well as those passed to
- :meth:`.Session.execute`, with the local options taking precedence.
-
- """
- return self._merged_execution_options
-
- def _memoized_attr__merged_execution_options(self):
- return self.statement._execution_options.union(self._execution_options)
+ def update_execution_options(self, **opts):
+ # TODO: no coverage
+ self.local_execution_options = self.local_execution_options.union(opts)
def _orm_compile_options(self):
opts = self.statement._compile_options
@@ -329,6 +299,20 @@ class ORMExecuteState(util.MemoizedSlots):
return None
@property
+ def is_relationship_load(self):
+ """Return True if this load is loading objects on behalf of a
+ relationship.
+
+ This means, the loader in effect is either a LazyLoader,
+ SelectInLoader, SubqueryLoader, or similar, and the entire
+ SELECT statement being emitted is on behalf of a relationship
+ load.
+
+ """
+ path = self.loader_strategy_path
+ return path is not None and not path.is_root
+
+ @property
def load_options(self):
"""Return the load_options that will be used for this execution."""
@@ -337,7 +321,7 @@ class ORMExecuteState(util.MemoizedSlots):
"This ORM execution is not against a SELECT statement "
"so there are no load options."
)
- return self._execution_options.get(
+ return self.execution_options.get(
"_sa_orm_load_options", context.QueryContext.default_load_options
)
@@ -351,7 +335,7 @@ class ORMExecuteState(util.MemoizedSlots):
"This ORM execution is not against an UPDATE or DELETE "
"statement so there are no update options."
)
- return self._execution_options.get(
+ return self.execution_options.get(
"_sa_orm_update_options",
persistence.BulkUDCompileState.default_update_options,
)
@@ -1003,8 +987,6 @@ class Session(_SessionClassMethods):
:ref:`migration_20_toplevel`
- :ref:`migration_20_result_rows`
-
:param info: optional dictionary of arbitrary data to be associated
with this :class:`.Session`. Is available via the
:attr:`.Session.info` attribute. Note the dictionary is copied at
@@ -1282,7 +1264,7 @@ class Session(_SessionClassMethods):
the operation will release the current SAVEPOINT but not commit
the outermost database transaction.
- If :term:`2.x-style` use is in effect via the
+ If :term:`2.0-style` use is in effect via the
:paramref:`_orm.Session.future` flag, the outermost database
transaction is committed unconditionally, automatically releasing any
SAVEPOINTs in effect.
@@ -1416,7 +1398,7 @@ class Session(_SessionClassMethods):
self,
statement,
params=None,
- execution_options=util.immutabledict(),
+ execution_options=util.EMPTY_DICT,
bind_arguments=None,
future=False,
_parent_execute_state=None,
@@ -1576,6 +1558,8 @@ class Session(_SessionClassMethods):
else:
compile_state_cls = None
+ execution_options = util.coerce_to_immutabledict(execution_options)
+
if compile_state_cls is not None:
(
statement,
@@ -1591,8 +1575,11 @@ class Session(_SessionClassMethods):
else:
bind_arguments.setdefault("clause", statement)
if future:
- execution_options = util.immutabledict().merge_with(
- execution_options, {"future_result": True}
+ # not sure if immutabledict is working w/ this syntax
+ # execution_options =
+ # execution_options.union(future_result=True)
+ execution_options = execution_options.union(
+ {"future_result": True}
)
if _parent_execute_state:
@@ -1619,6 +1606,10 @@ class Session(_SessionClassMethods):
if result:
return result
+ # TODO: coverage for this pattern
+ statement = orm_exec_state.statement
+ execution_options = orm_exec_state.local_execution_options
+
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind, close_with_result=True)
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
index 44f303fee..53166bd91 100644
--- a/lib/sqlalchemy/orm/strategies.py
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -1975,6 +1975,7 @@ class JoinedLoader(AbstractRelationshipLoader):
clauses,
innerjoin,
chained_from_outerjoin,
+ loadopt._extra_criteria if loadopt else (),
)
)
@@ -1993,6 +1994,7 @@ class JoinedLoader(AbstractRelationshipLoader):
clauses,
innerjoin,
chained_from_outerjoin,
+ extra_criteria,
):
if parentmapper is None:
localparent = query_entity.mapper
@@ -2081,6 +2083,17 @@ class JoinedLoader(AbstractRelationshipLoader):
or query_entity.entity_zero.represents_outer_join
)
+ extra_join_criteria = extra_criteria
+ additional_entity_criteria = compile_state.global_attributes.get(
+ ("additional_entity_criteria", self.mapper), ()
+ )
+ if additional_entity_criteria:
+ extra_join_criteria += tuple(
+ ae._resolve_where_criteria(self.mapper)
+ for ae in additional_entity_criteria
+ if ae.propagate_to_loaders
+ )
+
if attach_on_outside:
# this is the "classic" eager join case.
eagerjoin = orm_util._ORMJoin(
@@ -2092,11 +2105,12 @@ class JoinedLoader(AbstractRelationshipLoader):
or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
_left_memo=self.parent,
_right_memo=self.mapper,
+ _extra_criteria=extra_join_criteria,
)
else:
# all other cases are innerjoin=='nested' approach
eagerjoin = self._splice_nested_inner_join(
- path, towrap, clauses, onclause,
+ path, towrap, clauses, onclause, extra_join_criteria
)
compile_state.eager_joins[query_entity_key] = eagerjoin
@@ -2128,7 +2142,7 @@ class JoinedLoader(AbstractRelationshipLoader):
)
def _splice_nested_inner_join(
- self, path, join_obj, clauses, onclause, splicing=False
+ self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
):
if splicing is False:
@@ -2137,7 +2151,12 @@ class JoinedLoader(AbstractRelationshipLoader):
assert isinstance(join_obj, orm_util._ORMJoin)
elif isinstance(join_obj, sql.selectable.FromGrouping):
return self._splice_nested_inner_join(
- path, join_obj.element, clauses, onclause, splicing,
+ path,
+ join_obj.element,
+ clauses,
+ onclause,
+ extra_criteria,
+ splicing,
)
elif not isinstance(join_obj, orm_util._ORMJoin):
if path[-2] is splicing:
@@ -2148,18 +2167,29 @@ class JoinedLoader(AbstractRelationshipLoader):
isouter=False,
_left_memo=splicing,
_right_memo=path[-1].mapper,
+ _extra_criteria=extra_criteria,
)
else:
# only here if splicing == True
return None
target_join = self._splice_nested_inner_join(
- path, join_obj.right, clauses, onclause, join_obj._right_memo,
+ path,
+ join_obj.right,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._right_memo,
)
if target_join is None:
right_splice = False
target_join = self._splice_nested_inner_join(
- path, join_obj.left, clauses, onclause, join_obj._left_memo,
+ path,
+ join_obj.left,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._left_memo,
)
if target_join is None:
# should only return None when recursively called,
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index b405153b9..b3913ec5b 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -78,6 +78,7 @@ class Load(Generative, LoaderOption):
("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
(
"_context_cache_key",
visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
@@ -101,6 +102,7 @@ class Load(Generative, LoaderOption):
load.context = {}
load.local_opts = {}
load._of_type = None
+ load._extra_criteria = ()
return load
@property
@@ -124,6 +126,7 @@ class Load(Generative, LoaderOption):
strategy = None
propagate_to_loaders = False
_of_type = None
+ _extra_criteria = ()
def process_compile_state(self, compile_state):
if not compile_state.compile_options._enable_eagerloads:
@@ -248,6 +251,9 @@ class Load(Generative, LoaderOption):
else:
return None
+ if attr._extra_criteria:
+ self._extra_criteria = attr._extra_criteria
+
if getattr(attr, "_of_type", None):
ac = attr._of_type
ext_info = of_type_info = inspect(ac)
@@ -356,6 +362,7 @@ class Load(Generative, LoaderOption):
cloned = self._clone_for_bind_strategy(attr, strategy, "relationship")
self.path = cloned.path
self._of_type = cloned._of_type
+ self._extra_criteria = cloned._extra_criteria
cloned.is_class_strategy = self.is_class_strategy = False
self.propagate_to_loaders = cloned.propagate_to_loaders
@@ -413,6 +420,7 @@ class Load(Generative, LoaderOption):
if existing:
if merge_opts:
existing.local_opts.update(self.local_opts)
+ existing._extra_criteria += self._extra_criteria
else:
path.set(context, "loader", self)
else:
@@ -420,6 +428,7 @@ class Load(Generative, LoaderOption):
path.set(context, "loader", self)
if existing and existing.is_opts_only:
self.local_opts.update(existing.local_opts)
+ existing._extra_criteria += self._extra_criteria
def _set_path_strategy(self):
if not self.is_class_strategy and self.path.has_entity:
@@ -507,11 +516,13 @@ class _UnboundLoad(Load):
self.path = ()
self._to_bind = []
self.local_opts = {}
+ self._extra_criteria = ()
_cache_key_traversal = [
("path", visitors.ExtendedInternalTraversal.dp_multi_list),
("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
]
@@ -576,6 +587,7 @@ class _UnboundLoad(Load):
if attr:
path = path + (attr,)
self.path = path
+ self._extra_criteria = getattr(attr, "_extra_criteria", ())
return path
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
index 71ee29597..82fad0815 100644
--- a/lib/sqlalchemy/orm/util.py
+++ b/lib/sqlalchemy/orm/util.py
@@ -23,6 +23,7 @@ from .base import object_state # noqa
from .base import state_attribute_str # noqa
from .base import state_class_str # noqa
from .base import state_str # noqa
+from .interfaces import LoaderOption
from .interfaces import MapperProperty # noqa
from .interfaces import ORMColumnsClauseRole
from .interfaces import ORMEntityColumnsClauseRole
@@ -38,6 +39,7 @@ from ..engine.result import result_tuple
from ..sql import base as sql_base
from ..sql import coercions
from ..sql import expression
+from ..sql import lambdas
from ..sql import roles
from ..sql import util as sql_util
from ..sql import visitors
@@ -854,6 +856,184 @@ class AliasedInsp(
return "aliased(%s)" % (self._target.__name__,)
+class LoaderCriteriaOption(LoaderOption):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ :class:`_orm.LoaderCriteriaOption` is invoked using the
+ :func:`_orm.with_loader_criteria` function; see that function for
+ details.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _traverse_internals = [
+ ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("where_criteria", visitors.InternalTraversal.dp_clauseelement),
+ ("include_aliases", visitors.InternalTraversal.dp_boolean),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(
+ self,
+ entity_or_base,
+ where_criteria,
+ loader_only=False,
+ include_aliases=False,
+ propagate_to_loaders=True,
+ ):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ .. versionadded:: 1.4
+
+ The :func:`_orm.with_loader_criteria` option is intended to add
+ limiting criteria to a particular kind of entity in a query,
+ **globally**, meaning it will apply to the entity as it appears
+ in the SELECT query as well as within any subqueries, join
+ conditions, and relationship loads, including both eager and lazy
+ loaders, without the need for it to be specified in any particular
+ part of the query. The rendering logic uses the same system used by
+ single table inheritance to ensure a certain discriminator is applied
+ to a table.
+
+ E.g., using :term:`2.0-style` queries, we can limit the way the
+ ``User.addresses`` collection is loaded, regardless of the kind
+ of loading used::
+
+ from sqlalchemy.orm import with_loader_criteria
+
+ stmt = select(User).options(
+ selectinload(User.addresses),
+ with_loader_criteria(Address, Address.email_address != 'foo'))
+ )
+
+ Above, the "selectinload" for ``User.addresses`` will apply the
+ given filtering criteria to the WHERE clause.
+
+ Another example, where the filtering will be applied to the
+ ON clause of the join, in this example using :term:`1.x style`
+ queries::
+
+ q = session.query(User).outerjoin(User.addresses).options(
+ with_loader_criteria(Address, Address.email_address != 'foo'))
+ )
+
+ The primary purpose of :func:`_orm.with_loader_criteria` is to use
+ it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler
+ to ensure that all occurrences of a particular entity are filtered
+ in a certain way, such as filtering for access control roles. It
+ also can be used to apply criteria to relationship loads. In the
+ example below, we can apply a certain set of rules to all queries
+ emitted by a particular :class:`_orm.Session`::
+
+ session = Session(bind=engine)
+
+ @event.listens_for("do_orm_execute", session)
+ def _add_filtering_criteria(execute_state):
+ execute_state.statement = execute_state.statement.options(
+ with_loader_criteria(
+ SecurityRole,
+ lambda cls: cls.role.in_(['some_role']),
+ include_aliases=True
+ )
+ )
+
+ The given class will expand to include all mapped subclass and
+ need not itself be a mapped class.
+
+
+ :param entity_or_base: a mapped class, or a class that is a super
+ class of a particular set of mapped classes, to which the rule
+ will apply.
+
+ :param where_criteria: a Core SQL expression that applies limiting
+ criteria. This may also be a "lambda:" or Python function that
+ accepts a target class as an argument, when the given class is
+ a base with many different mapped subclasses.
+
+ :param include_aliases: if True, apply the rule to :func:`_orm.aliased`
+ constructs as well.
+
+ :param propagate_to_loaders: defaults to True, apply to relationship
+ loaders such as lazy loaders.
+
+
+ .. seealso::
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :func:`_orm.with_loader_criteria`.
+
+ :ref:`do_orm_execute_global_criteria` - basic example on how to
+ combine :func:`_orm.with_loader_criteria` with the
+ :meth:`_orm.SessionEvents.do_orm_execute` event.
+
+ """
+ entity = inspection.inspect(entity_or_base, False)
+ if entity is None:
+ self.root_entity = entity_or_base
+ self.entity = None
+ else:
+ self.root_entity = None
+ self.entity = entity
+
+ if callable(where_criteria):
+ self.deferred_where_criteria = True
+ self.where_criteria = lambdas.DeferredLambdaElement(
+ where_criteria,
+ roles.WhereHavingRole,
+ lambda_args=(
+ self.root_entity
+ if self.root_entity is not None
+ else self.entity.entity,
+ ),
+ )
+ else:
+ self.deferred_where_criteria = False
+ self.where_criteria = coercions.expect(
+ roles.WhereHavingRole, where_criteria
+ )
+
+ self.include_aliases = include_aliases
+ self.propagate_to_loaders = propagate_to_loaders
+
+ def _all_mappers(self):
+ if self.entity:
+ for ent in self.entity.mapper.self_and_descendants:
+ yield ent
+ else:
+ stack = list(self.root_entity.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ ent = inspection.inspect(subclass)
+ if ent:
+ for mp in ent.mapper.self_and_descendants:
+ yield mp
+ else:
+ stack.extend(subclass.__subclasses__())
+
+ def _resolve_where_criteria(self, ext_info):
+ if self.deferred_where_criteria:
+ return self.where_criteria._resolve_with_args(ext_info.entity)
+ else:
+ return self.where_criteria
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ # if options to limit the criteria to immediate query only,
+ # use compile_state.attributes instead
+
+ for mp in self._all_mappers():
+ load_criteria = compile_state.global_attributes.setdefault(
+ ("additional_entity_criteria", mp), []
+ )
+
+ load_criteria.append(self)
+
+
inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
inspection._inspects(AliasedInsp)(lambda target: target)
@@ -1270,6 +1450,7 @@ class _ORMJoin(expression.Join):
full=False,
_left_memo=None,
_right_memo=None,
+ _extra_criteria=(),
):
left_info = inspection.inspect(left)
@@ -1291,6 +1472,7 @@ class _ORMJoin(expression.Join):
if isinstance(onclause, attributes.QueryableAttribute):
on_selectable = onclause.comparator._source_selectable()
prop = onclause.property
+ _extra_criteria += onclause._extra_criteria
elif isinstance(onclause, MapperProperty):
# used internally by joined eager loader...possibly not ideal
prop = onclause
@@ -1319,6 +1501,7 @@ class _ORMJoin(expression.Join):
source_polymorphic=True,
of_type_entity=right_info,
alias_secondary=True,
+ extra_criteria=_extra_criteria,
)
if sj is not None:
@@ -1331,6 +1514,7 @@ class _ORMJoin(expression.Join):
onclause = sj
else:
onclause = pj
+
self._target_adapter = target_adapter
expression.Join.__init__(self, left, right, onclause, isouter, full)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index ac4055bdf..b8984316c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -792,6 +792,10 @@ class SQLCompiler(Compiled):
def prefetch(self):
return list(self.insert_prefetch + self.update_prefetch)
+ @util.memoized_property
+ def _global_attributes(self):
+ return {}
+
@util.memoized_instancemethod
def _init_cte_state(self):
"""Initialize collections related to CTEs only if
diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py
index 6279dcf55..c8e83bbd7 100644
--- a/test/ext/test_baked.py
+++ b/test/ext/test_baked.py
@@ -1017,8 +1017,8 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest):
if ckey:
break
else:
- if "_cache_key" in orm_context.merged_execution_options:
- ckey = orm_context.merged_execution_options["_cache_key"]
+ if "_cache_key" in orm_context.execution_options:
+ ckey = orm_context.execution_options["_cache_key"]
if ckey is not None:
return get_value(
diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py
index e33e95cc0..86e0bd360 100644
--- a/test/orm/inheritance/test_polymorphic_rel.py
+++ b/test/orm/inheritance/test_polymorphic_rel.py
@@ -1302,6 +1302,28 @@ class _PolymorphicTestBase(object):
[e1, e3],
)
+ def test_join_and_thru_polymorphic_nonaliased_one(self):
+ sess = create_session()
+ eq_(
+ sess.query(Company)
+ .join(Company.employees)
+ .join(Person.paperwork.and_(Paperwork.description.like("%#2%")))
+ .all(),
+ [c1],
+ )
+
+ def test_join_and_thru_polymorphic_aliased_one(self):
+ sess = create_session()
+ ea = aliased(Person)
+ pa = aliased(Paperwork)
+ eq_(
+ sess.query(Company)
+ .join(ea, Company.employees)
+ .join(pa, ea.paperwork.and_(pa.description.like("%#2%")))
+ .all(),
+ [c1],
+ )
+
def test_join_through_polymorphic_nonaliased_one(self):
sess = create_session()
eq_(
diff --git a/test/orm/test_bundle.py b/test/orm/test_bundle.py
index f4af84094..9d1d0b61b 100644
--- a/test/orm/test_bundle.py
+++ b/test/orm/test_bundle.py
@@ -3,6 +3,7 @@ from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
+from sqlalchemy import testing
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Bundle
from sqlalchemy.orm import mapper
@@ -186,6 +187,35 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
],
)
+ def test_multi_bundle_future(self):
+ Data = self.classes.Data
+ Other = self.classes.Other
+
+ d1 = aliased(Data)
+
+ b1 = Bundle("b1", d1.d1, d1.d2)
+ b2 = Bundle("b2", Data.d1, Other.o1)
+
+ sess = Session(testing.db, future=True)
+
+ stmt = (
+ select(b1, b2)
+ .join(Data.others)
+ .join(d1, d1.id == Data.id)
+ .filter(b1.c.d1 == "d3d1")
+ )
+
+ eq_(
+ sess.execute(stmt).all(),
+ [
+ (("d3d1", "d3d2"), ("d3d1", "d3o0")),
+ (("d3d1", "d3d2"), ("d3d1", "d3o1")),
+ (("d3d1", "d3d2"), ("d3d1", "d3o2")),
+ (("d3d1", "d3d2"), ("d3d1", "d3o3")),
+ (("d3d1", "d3d2"), ("d3d1", "d3o4")),
+ ],
+ )
+
def test_single_entity(self):
Data = self.classes.Data
sess = Session()
@@ -197,6 +227,18 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
[("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")],
)
+ def test_single_entity_future(self):
+ Data = self.classes.Data
+ sess = Session(testing.db, future=True)
+
+ b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True)
+
+ stmt = select(b1).filter(b1.c.d1.between("d3d1", "d5d1"))
+ eq_(
+ sess.execute(stmt).scalars().all(),
+ [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")],
+ )
+
def test_single_entity_flag_but_multi_entities(self):
Data = self.classes.Data
sess = Session()
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py
index 02b1b9fbf..45a60a5cb 100644
--- a/test/orm/test_cache_key.py
+++ b/test/orm/test_cache_key.py
@@ -15,6 +15,7 @@ from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import subqueryload
+from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.orm import with_polymorphic
from sqlalchemy.sql.base import CacheableOptions
from sqlalchemy.sql.visitors import InternalTraversal
@@ -65,6 +66,62 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
compare_values=True,
)
+ def test_loader_criteria(self):
+ User, Address = self.classes("User", "Address")
+
+ from sqlalchemy import Column, Integer, String
+
+ class Foo(object):
+ id = Column(Integer)
+ name = Column(String)
+
+ self._run_cache_key_fixture(
+ lambda: (
+ with_loader_criteria(User, User.name != "somename"),
+ with_loader_criteria(User, User.id != 5),
+ with_loader_criteria(User, lambda cls: cls.id == 10),
+ with_loader_criteria(Address, Address.id != 5),
+ with_loader_criteria(Foo, lambda cls: cls.id == 10),
+ ),
+ compare_values=True,
+ )
+
+ def test_loader_criteria_bound_param_thing(self):
+ from sqlalchemy import Column, Integer
+
+ class Foo(object):
+ id = Column(Integer)
+
+ def go(param):
+ return with_loader_criteria(Foo, lambda cls: cls.id == param)
+
+ g1 = go(10)
+ g2 = go(20)
+
+ ck1 = g1._generate_cache_key()
+ ck2 = g2._generate_cache_key()
+
+ eq_(ck1.key, ck2.key)
+ eq_(ck1.bindparams[0].key, ck2.bindparams[0].key)
+ eq_(ck1.bindparams[0].value, 10)
+ eq_(ck2.bindparams[0].value, 20)
+
+ def test_instrumented_attributes(self):
+ User, Address, Keyword, Order, Item = self.classes(
+ "User", "Address", "Keyword", "Order", "Item"
+ )
+
+ self._run_cache_key_fixture(
+ lambda: (
+ User.addresses,
+ User.addresses.of_type(aliased(Address)),
+ User.orders,
+ User.orders.and_(Order.id != 5),
+ User.orders.and_(Order.description != "somename"),
+ ),
+ compare_values=True,
+ )
+
def test_unbound_options(self):
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
@@ -75,6 +132,10 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
joinedload(User.addresses),
joinedload(User.addresses.of_type(aliased(Address))),
joinedload("addresses"),
+ joinedload(User.orders),
+ joinedload(User.orders.and_(Order.id != 5)),
+ joinedload(User.orders.and_(Order.id == 5)),
+ joinedload(User.orders.and_(Order.description != "somename")),
joinedload(User.orders).selectinload("items"),
joinedload(User.orders).selectinload(Order.items),
defer(User.id),
@@ -110,6 +171,10 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
User.addresses.of_type(aliased(Address))
),
Load(User).joinedload(User.orders),
+ Load(User).joinedload(User.orders.and_(Order.id != 5)),
+ Load(User).joinedload(
+ User.orders.and_(Order.description != "somename")
+ ),
Load(User).defer(User.id),
Load(User).subqueryload("addresses"),
Load(Address).defer("id"),
@@ -169,6 +234,9 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
select(User).join(Address, User.addresses),
select(User).join(a1, User.addresses),
select(User).join(User.addresses.of_type(a1)),
+ select(User).join(
+ User.addresses.and_(Address.email_address == "foo")
+ ),
select(User)
.join(Address, User.addresses)
.join_from(User, Order),
diff --git a/test/orm/test_events.py b/test/orm/test_events.py
index b68e0d2e6..df48cfe63 100644
--- a/test/orm/test_events.py
+++ b/test/orm/test_events.py
@@ -2,6 +2,8 @@ import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy import literal_column
+from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext.declarative import declarative_base
@@ -47,6 +49,170 @@ class _RemoveListeners(object):
super(_RemoveListeners, self).teardown()
+class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
+ run_setup_mappers = "once"
+ run_inserts = "once"
+ run_deletes = None
+
+ @classmethod
+ def setup_mappers(cls):
+ cls._setup_stock_mapping()
+
+ def _caching_session_fixture(self):
+
+ cache = {}
+
+ maker = sessionmaker(testing.db, future=True)
+
+ def get_value(cache_key, cache, createfunc):
+ if cache_key in cache:
+ return cache[cache_key]()
+ else:
+ cache[cache_key] = retval = createfunc().freeze()
+ return retval()
+
+ @event.listens_for(maker, "do_orm_execute", retval=True)
+ def do_orm_execute(orm_context):
+ ckey = None
+ for opt in orm_context.user_defined_options:
+ ckey = opt.get_cache_key(orm_context)
+ if ckey:
+ break
+ else:
+ if "cache_key" in orm_context.execution_options:
+ ckey = orm_context.execution_options["cache_key"]
+
+ if ckey is not None:
+ return get_value(ckey, cache, orm_context.invoke_statement,)
+
+ return maker()
+
+ def test_cache_option(self):
+ User, Address = self.classes("User", "Address")
+
+ with self.sql_execution_asserter(testing.db) as asserter:
+
+ with self._caching_session_fixture() as session:
+ stmt = (
+ select(User)
+ .where(User.id == 7)
+ .execution_options(cache_key="user7")
+ )
+
+ result = session.execute(stmt)
+
+ eq_(
+ result.scalars().all(),
+ [User(id=7, addresses=[Address(id=1)])],
+ )
+
+ result = session.execute(stmt)
+
+ eq_(
+ result.scalars().all(),
+ [User(id=7, addresses=[Address(id=1)])],
+ )
+
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT users.id, users.name FROM users "
+ "WHERE users.id = :id_1",
+ [{"id_1": 7}],
+ ),
+ CompiledSQL(
+ "SELECT addresses.id AS addresses_id, addresses.user_id AS "
+ "addresses_user_id, "
+ "addresses.email_address AS addresses_email_address "
+ "FROM addresses WHERE :param_1 = addresses.user_id "
+ "ORDER BY addresses.id",
+ [{"param_1": 7}],
+ ),
+ )
+
+ def test_chained_events_one(self):
+
+ sess = Session(testing.db, future=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ ctx.update_execution_options(one=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def two(ctx):
+ ctx.update_execution_options(two=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def three(ctx):
+ ctx.update_execution_options(three=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def four(ctx):
+ ctx.update_execution_options(four=True)
+
+ result = sess.execute(select(literal_column("1")))
+
+ eq_(
+ result.context.execution_options,
+ {
+ "four": True,
+ "future_result": True,
+ "one": True,
+ "three": True,
+ "two": True,
+ },
+ )
+
+ def test_chained_events_two(self):
+
+ sess = Session(testing.db, future=True)
+
+ def added(ctx):
+ ctx.update_execution_options(added_evt=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def one(ctx):
+ ctx.update_execution_options(one=True)
+
+ @event.listens_for(sess, "do_orm_execute", retval=True)
+ def two(ctx):
+ ctx.update_execution_options(two=True)
+ return ctx.invoke_statement(
+ statement=ctx.statement.execution_options(statement_two=True)
+ )
+
+ @event.listens_for(sess, "do_orm_execute")
+ def three(ctx):
+ ctx.update_execution_options(three=True)
+
+ @event.listens_for(sess, "do_orm_execute")
+ def four(ctx):
+ ctx.update_execution_options(four=True)
+ return ctx.invoke_statement(
+ statement=ctx.statement.execution_options(statement_four=True)
+ )
+
+ @event.listens_for(sess, "do_orm_execute")
+ def five(ctx):
+ ctx.update_execution_options(five=True)
+
+ result = sess.execute(select(literal_column("1")), _add_event=added)
+
+ eq_(
+ result.context.execution_options,
+ {
+ "statement_two": True,
+ "statement_four": True,
+ "future_result": True,
+ "one": True,
+ "two": True,
+ "three": True,
+ "four": True,
+ "five": True,
+ "added_evt": True,
+ },
+ )
+
+
class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
run_inserts = None
diff --git a/test/orm/test_options.py b/test/orm/test_options.py
index 208db9d85..b5a6e3b29 100644
--- a/test/orm/test_options.py
+++ b/test/orm/test_options.py
@@ -1391,6 +1391,7 @@ class PickleTest(PathTest, QueryTest):
"propagate_to_loaders": True,
"_of_type": None,
"_to_bind": to_bind,
+ "_extra_criteria": (),
},
)
diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py
new file mode 100644
index 000000000..c4bcf0404
--- /dev/null
+++ b/test/orm/test_relationship_criteria.py
@@ -0,0 +1,867 @@
+import datetime
+import random
+
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import event
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import orm
+from sqlalchemy import select
+from sqlalchemy import sql
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import with_loader_criteria
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing.assertsql import CompiledSQL
+from test.orm import _fixtures
+
+
+class _Fixtures(_fixtures.FixtureTest):
+ @testing.fixture
+ def user_address_fixture(self):
+ users, Address, addresses, User = (
+ self.tables.users,
+ self.classes.Address,
+ self.tables.addresses,
+ self.classes.User,
+ )
+
+ mapper(
+ User,
+ users,
+ properties={
+ "addresses": relationship(
+ mapper(Address, addresses), order_by=Address.id
+ )
+ },
+ )
+ return User, Address
+
+ @testing.fixture
+ def order_item_fixture(self):
+ Order, Item = self.classes("Order", "Item")
+ orders, items, order_items = self.tables(
+ "orders", "items", "order_items"
+ )
+
+ mapper(
+ Order,
+ orders,
+ properties={
+ # m2m
+ "items": relationship(
+ Item, secondary=order_items, order_by=items.c.id
+ ),
+ },
+ )
+ mapper(Item, items)
+
+ return Order, Item
+
+ @testing.fixture
+ def mixin_fixture(self):
+ users = self.tables.users
+
+ class HasFoob(object):
+ name = Column(String)
+
+ class UserWFoob(HasFoob, self.Comparable):
+ pass
+
+ mapper(
+ UserWFoob, users,
+ )
+ return HasFoob, UserWFoob
+
+
+class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
+ """
+ combinations:
+
+
+ with_loader_criteria
+ # for these we have mapper_criteria
+
+ select(mapper) # select_mapper
+ select(mapper.col, mapper.col) # select_mapper_col
+ select(func.count()).select_from(mapper) # select_from_mapper
+ select(a).join(mapper, a.target) # select_join_mapper
+ select(a).options(joinedload(a.target)) # select_joinedload_mapper
+
+
+ # for these we have aliased_criteria, inclaliased_criteria
+
+ select(aliased) # select_aliased
+ select(aliased.col, aliased.col) # select_aliased_col
+ select(func.count()).select_from(aliased) # select_from_aliased
+ select(a).join(aliased, a.target) # select_join_aliased
+ select(a).options(joinedload(a.target.of_type(aliased))
+ # select_joinedload_aliased
+
+ """
+
+ __dialect__ = "default"
+
+ def test_select_mapper_mapper_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ stmt = select(User).options(
+ with_loader_criteria(User, User.name != "name")
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name "
+ "FROM users WHERE users.name != :name_1",
+ )
+
+ def test_select_from_mapper_mapper_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ stmt = (
+ select(sql.func.count())
+ .select_from(User)
+ .options(with_loader_criteria(User, User.name != "name"))
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT count(*) AS count_1 FROM users "
+ "WHERE users.name != :name_1",
+ )
+
+ def test_select_mapper_columns_mapper_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ stmt = select(User.id, User.name).options(
+ with_loader_criteria(User, User.name != "name")
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name "
+ "FROM users WHERE users.name != :name_1",
+ )
+
+ def test_select_join_mapper_mapper_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ stmt = (
+ select(User)
+ .join(User.addresses)
+ .options(
+ with_loader_criteria(Address, Address.email_address != "name")
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name FROM users "
+ "JOIN addresses ON users.id = addresses.user_id "
+ "AND addresses.email_address != :email_address_1",
+ )
+
+ def test_select_joinm2m_mapper_mapper_criteria(self, order_item_fixture):
+ Order, Item = order_item_fixture
+
+ stmt = (
+ select(Order)
+ .join(Order.items)
+ .options(
+ with_loader_criteria(Item, Item.description != "description")
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT orders.id, orders.user_id, orders.address_id, "
+ "orders.description, orders.isopen FROM orders "
+ "JOIN order_items AS order_items_1 "
+ "ON orders.id = order_items_1.order_id "
+ "JOIN items ON items.id = order_items_1.item_id "
+ "AND items.description != :description_1",
+ )
+
+ def test_select_joinedload_mapper_mapper_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ stmt = select(User).options(
+ joinedload(User.addresses),
+ with_loader_criteria(Address, Address.email_address != "name"),
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name, addresses_1.id AS id_1, "
+ "addresses_1.user_id, addresses_1.email_address "
+ "FROM users LEFT OUTER JOIN addresses AS addresses_1 "
+ "ON users.id = addresses_1.user_id "
+ "AND addresses_1.email_address != :email_address_1 "
+ "ORDER BY addresses_1.id",
+ )
+
+ def test_select_selectinload_mapper_mapper_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ stmt = select(User).options(
+ selectinload(User.addresses),
+ with_loader_criteria(Address, Address.email_address != "name"),
+ )
+
+ s = Session(testing.db, future=True)
+
+ with self.sql_execution_asserter() as asserter:
+
+ s.execute(stmt).all()
+
+ asserter.assert_(
+ CompiledSQL("SELECT users.id, users.name FROM users", [],),
+ CompiledSQL(
+ "SELECT addresses.user_id AS addresses_user_id, addresses.id "
+ "AS addresses_id, addresses.email_address "
+ "AS addresses_email_address FROM addresses "
+ "WHERE addresses.user_id IN ([POSTCOMPILE_primary_keys]) "
+ "AND addresses.email_address != :email_address_1 "
+ "ORDER BY addresses.id",
+ [{"primary_keys": [7, 8, 9, 10], "email_address_1": "name"}],
+ ),
+ )
+
+ def test_select_lazyload_mapper_mapper_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ stmt = (
+ select(User)
+ .options(
+ with_loader_criteria(Address, Address.email_address != "name"),
+ )
+ .order_by(User.id)
+ )
+
+ s = Session(testing.db, future=True)
+
+ with self.sql_execution_asserter() as asserter:
+ for u in s.execute(stmt).scalars():
+ u.addresses
+
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT users.id, users.name FROM users ORDER BY users.id", [],
+ ),
+ CompiledSQL(
+ "SELECT addresses.id AS addresses_id, "
+ "addresses.user_id AS addresses_user_id, "
+ "addresses.email_address AS addresses_email_address "
+ "FROM addresses WHERE :param_1 = addresses.user_id "
+ "AND addresses.email_address != :email_address_1 "
+ "ORDER BY addresses.id",
+ [{"param_1": 7, "email_address_1": "name"}],
+ ),
+ CompiledSQL(
+ "SELECT addresses.id AS addresses_id, "
+ "addresses.user_id AS addresses_user_id, "
+ "addresses.email_address AS addresses_email_address "
+ "FROM addresses WHERE :param_1 = addresses.user_id "
+ "AND addresses.email_address != :email_address_1 "
+ "ORDER BY addresses.id",
+ [{"param_1": 8, "email_address_1": "name"}],
+ ),
+ CompiledSQL(
+ "SELECT addresses.id AS addresses_id, "
+ "addresses.user_id AS addresses_user_id, "
+ "addresses.email_address AS addresses_email_address "
+ "FROM addresses WHERE :param_1 = addresses.user_id "
+ "AND addresses.email_address != :email_address_1 "
+ "ORDER BY addresses.id",
+ [{"param_1": 9, "email_address_1": "name"}],
+ ),
+ CompiledSQL(
+ "SELECT addresses.id AS addresses_id, "
+ "addresses.user_id AS addresses_user_id, "
+ "addresses.email_address AS addresses_email_address "
+ "FROM addresses WHERE :param_1 = addresses.user_id "
+ "AND addresses.email_address != :email_address_1 "
+ "ORDER BY addresses.id",
+ [{"param_1": 10, "email_address_1": "name"}],
+ ),
+ )
+
+ def test_select_aliased_inclaliased_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ u1 = aliased(User)
+ stmt = select(u1).options(
+ with_loader_criteria(
+ User, User.name != "name", include_aliases=True
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users_1.id, users_1.name "
+ "FROM users AS users_1 WHERE users_1.name != :name_1",
+ )
+
+ def test_select_from_aliased_inclaliased_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ u1 = aliased(User)
+ stmt = (
+ select(sql.func.count())
+ .select_from(u1)
+ .options(
+ with_loader_criteria(
+ User, User.name != "name", include_aliases=True
+ )
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT count(*) AS count_1 FROM users AS users_1 "
+ "WHERE users_1.name != :name_1",
+ )
+
+ def test_select_aliased_columns_inclaliased_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ u1 = aliased(User)
+ stmt = select(u1.id, u1.name).options(
+ with_loader_criteria(
+ User, User.name != "name", include_aliases=True
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users_1.id, users_1.name "
+ "FROM users AS users_1 WHERE users_1.name != :name_1",
+ )
+
+ def test_select_join_aliased_inclaliased_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ a1 = aliased(Address)
+ stmt = (
+ select(User)
+ .join(User.addresses.of_type(a1))
+ .options(
+ with_loader_criteria(
+ Address,
+ Address.email_address != "name",
+ include_aliases=True,
+ )
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name FROM users "
+ "JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id "
+ "AND addresses_1.email_address != :email_address_1",
+ )
+
+ def test_select_joinm2m_aliased_inclaliased_criteria(
+ self, order_item_fixture
+ ):
+ Order, Item = order_item_fixture
+
+ i1 = aliased(Item)
+
+ stmt = (
+ select(Order)
+ .join(Order.items.of_type(i1))
+ .options(
+ with_loader_criteria(
+ Item,
+ Item.description != "description",
+ include_aliases=True,
+ )
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT orders.id, orders.user_id, orders.address_id, "
+ "orders.description, orders.isopen FROM orders "
+ "JOIN order_items AS order_items_1 "
+ "ON orders.id = order_items_1.order_id "
+ "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
+ "AND items_1.description != :description_1",
+ )
+
+ def test_select_aliased_aliased_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ u1 = aliased(User)
+ stmt = select(u1).options(with_loader_criteria(u1, u1.name != "name"))
+
+ self.assert_compile(
+ stmt,
+ "SELECT users_1.id, users_1.name "
+ "FROM users AS users_1 WHERE users_1.name != :name_1",
+ )
+
+ def test_select_aliased_columns_aliased_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ u1 = aliased(User)
+ stmt = select(u1.id, u1.name).options(
+ with_loader_criteria(u1, u1.name != "name")
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users_1.id, users_1.name "
+ "FROM users AS users_1 WHERE users_1.name != :name_1",
+ )
+
+ def test_joinedload_global_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ s = Session(testing.db, future=True)
+
+ stmt = select(User).options(
+ joinedload(User.addresses),
+ with_loader_criteria(Address, Address.email_address != "email"),
+ )
+
+ with self.sql_execution_asserter() as asserter:
+
+ s.execute(stmt)
+
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT users.id, users.name, addresses_1.id AS id_1, "
+ "addresses_1.user_id, addresses_1.email_address FROM "
+ "users LEFT OUTER JOIN addresses AS addresses_1 "
+ "ON users.id = addresses_1.user_id "
+ "AND addresses_1.email_address != :email_address_1 "
+ "ORDER BY addresses_1.id",
+ [{"email_address_1": "email"}],
+ ),
+ )
+
+ def test_query_count_global_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ s = Session(testing.db)
+
+ q = s.query(User).options(with_loader_criteria(User, User.id != 8))
+
+ with self.sql_execution_asserter() as asserter:
+ q.count()
+
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT count(*) AS count_1 FROM (SELECT "
+ "users.id AS users_id, users.name AS users_name "
+ "FROM users WHERE users.id != :id_1) AS anon_1",
+ [{"id_1": 8}],
+ ),
+ )
+
+ def test_query_count_after_the_fact_global_criteria(
+ self, user_address_fixture
+ ):
+ User, Address = user_address_fixture
+
+ s = Session(testing.db)
+
+ # this essentially tests that the query.from_self() which takes
+ # place in count() is one that can still be affected by
+ # the loader criteria, meaning it has to be an ORM query
+
+ q = s.query(User)
+
+ @event.listens_for(s, "do_orm_execute")
+ def add_criteria(orm_context):
+ orm_context.statement = orm_context.statement.options(
+ with_loader_criteria(User, User.id != 8)
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ q.count()
+
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT count(*) AS count_1 FROM (SELECT "
+ "users.id AS users_id, users.name AS users_name "
+ "FROM users WHERE users.id != :id_1) AS anon_1",
+ [{"id_1": 8}],
+ ),
+ )
+
+ def test_select_count_subquery_global_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ stmt = select(User).subquery()
+
+ stmt = (
+ select(sql.func.count())
+ .select_from(stmt)
+ .options(with_loader_criteria(User, User.id != 8))
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT count(*) AS count_1 FROM (SELECT users.id AS id, "
+ "users.name AS name FROM users WHERE users.id != :id_1) AS anon_1",
+ )
+
+ def test_query_outerjoin_global_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ s = Session(testing.db)
+
+ q = (
+ s.query(User, Address)
+ .outerjoin(User.addresses)
+ .options(
+ with_loader_criteria(
+ Address, ~Address.email_address.like("ed@%"),
+ )
+ )
+ .order_by(User.id)
+ )
+
+ self.assert_compile(
+ q,
+ "SELECT users.id AS users_id, users.name AS users_name, "
+ "addresses.id AS addresses_id, "
+ "addresses.user_id AS addresses_user_id, "
+ "addresses.email_address AS addresses_email_address "
+ "FROM users LEFT OUTER JOIN addresses "
+ "ON users.id = addresses.user_id AND "
+ "addresses.email_address NOT LIKE :email_address_1 "
+ "ORDER BY users.id",
+ )
+ eq_(
+ q.all(),
+ [
+ (User(id=7), Address(id=1)),
+ (User(id=8), None), # three addresses not here
+ (User(id=9), Address(id=5)),
+ (User(id=10), None),
+ ],
+ )
+
+ def test_caching_and_binds_lambda(self, mixin_fixture):
+ HasFoob, UserWFoob = mixin_fixture
+
+ statement = select(UserWFoob).filter(UserWFoob.id < 10)
+
+ def go(value):
+ return statement.options(
+ with_loader_criteria(
+ HasFoob,
+ lambda cls: cls.name == value,
+ include_aliases=True,
+ )
+ )
+
+ s = Session(testing.db, future=True)
+
+ for i in range(10):
+ name = random.choice(["ed", "fred", "jack"])
+ stmt = go(name)
+
+ eq_(s.execute(stmt).scalars().all(), [UserWFoob(name=name)])
+
+
+class TemporalFixtureTest(testing.fixtures.DeclarativeMappedTest):
+ @classmethod
+ def setup_classes(cls):
+ class HasTemporal(object):
+ """Mixin that identifies a class as having a timestamp column"""
+
+ timestamp = Column(
+ DateTime, default=datetime.datetime.utcnow, nullable=False
+ )
+
+ cls.HasTemporal = HasTemporal
+
+ def temporal_range(range_lower, range_upper):
+ return with_loader_criteria(
+ HasTemporal,
+ lambda cls: cls.timestamp.between(range_lower, range_upper),
+ include_aliases=True,
+ )
+
+ cls.temporal_range = staticmethod(temporal_range)
+
+ class Parent(HasTemporal, cls.DeclarativeBasic):
+ __tablename__ = "parent"
+ id = Column(Integer, primary_key=True)
+ children = relationship("Child", order_by="Child.id")
+
+ class Child(HasTemporal, cls.DeclarativeBasic):
+ __tablename__ = "child"
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(
+ Integer, ForeignKey("parent.id"), nullable=False
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ Parent, Child = cls.classes("Parent", "Child")
+
+ sess = Session(connection)
+ c1, c2, c3, c4, c5 = [
+ Child(timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00)),
+ Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
+ Child(timestamp=datetime.datetime(2009, 10, 20, 12, 00, 00)),
+ Child(timestamp=datetime.datetime(2009, 10, 12, 12, 00, 00)),
+ Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
+ ]
+
+ p1 = Parent(
+ timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00),
+ children=[c1, c2, c3],
+ )
+ p2 = Parent(
+ timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00),
+ children=[c4, c5],
+ )
+
+ sess.add_all([p1, p2])
+ sess.commit()
+
+ @testing.combinations((True,), (False,), argnames="use_caching")
+ @testing.combinations(
+ (None,),
+ (orm.lazyload,),
+ (orm.joinedload,),
+ (orm.subqueryload,),
+ (orm.selectinload,),
+ argnames="loader_strategy",
+ )
+ def test_same_relatinship_load_different_range(
+ self, use_caching, loader_strategy
+ ):
+ """This is the first test that exercises lazy loading, which uses
+ a lambda select, which then needs to transform the select to have
+ different bound parameters if it's not cached (or generate a working
+ list of parameters if it is), which then calls into a
+ with_loader_crieria that itself has another lambda inside of it,
+ which means we have to traverse and replace that lambda's expression,
+ but we can't evaluate it until compile time, so the inner lambda
+ holds onto the "transform" function so it can run it as needed.
+ this makes use of a new feature in visitors that exports a
+ "run this traversal later" function.
+
+ All of these individual features, cloning lambdaelements,
+ running replacement traversals later, are very new and need a lot
+ of tests, most likely in test/sql/test_lambdas.py.
+
+ the test is from the "temporal_range" example which is the whole
+ use case this feature is designed for and it is a whopper.
+
+
+ """
+ Parent, Child = self.classes("Parent", "Child")
+ temporal_range = self.temporal_range
+
+ if use_caching:
+ Parent.children.property.bake_queries = True
+ eng = testing.db
+ else:
+ Parent.children.property.bake_queries = False
+ eng = testing.db.execution_options(compiled_cache=None)
+
+ sess = Session(eng, future=True)
+
+ if loader_strategy:
+ loader_options = (loader_strategy(Parent.children),)
+ else:
+ loader_options = ()
+
+ p1 = sess.execute(
+ select(Parent).filter(
+ Parent.timestamp == datetime.datetime(2009, 10, 15, 12, 00, 00)
+ )
+ ).scalar()
+ c1, c2 = p1.children[0:2]
+ c2_id = c2.id
+
+ p2 = sess.execute(
+ select(Parent).filter(
+ Parent.timestamp == datetime.datetime(2009, 10, 17, 12, 00, 00)
+ )
+ ).scalar()
+ c5 = p2.children[1]
+
+ parents = (
+ sess.execute(
+ select(Parent)
+ .execution_options(populate_existing=True)
+ .options(
+ temporal_range(
+ datetime.datetime(2009, 10, 16, 12, 00, 00),
+ datetime.datetime(2009, 10, 18, 12, 00, 00),
+ ),
+ *loader_options
+ )
+ )
+ .scalars()
+ .all()
+ )
+
+ assert parents[0] == p2
+ assert parents[0].children == [c5]
+
+ parents = (
+ sess.execute(
+ select(Parent)
+ .execution_options(populate_existing=True)
+ .join(Parent.children)
+ .filter(Child.id == c2_id)
+ .options(
+ temporal_range(
+ datetime.datetime(2009, 10, 15, 11, 00, 00),
+ datetime.datetime(2009, 10, 18, 12, 00, 00),
+ ),
+ *loader_options
+ )
+ )
+ .scalars()
+ .all()
+ )
+
+ assert parents[0] == p1
+ assert parents[0].children == [c1, c2]
+
+
+class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ @testing.fixture
+ def user_address_fixture(self):
+ users, Address, addresses, User = (
+ self.tables.users,
+ self.classes.Address,
+ self.tables.addresses,
+ self.classes.User,
+ )
+
+ mapper(
+ User,
+ users,
+ properties={
+ "addresses": relationship(
+ mapper(Address, addresses), order_by=Address.id
+ )
+ },
+ )
+ return User, Address
+
+ def test_joinedload_local_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ s = Session(testing.db, future=True)
+
+ stmt = select(User).options(
+ joinedload(User.addresses.and_(Address.email_address != "email")),
+ )
+
+ with self.sql_execution_asserter() as asserter:
+
+ s.execute(stmt)
+
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT users.id, users.name, addresses_1.id AS id_1, "
+ "addresses_1.user_id, addresses_1.email_address FROM "
+ "users LEFT OUTER JOIN addresses AS addresses_1 "
+ "ON users.id = addresses_1.user_id "
+ "AND addresses_1.email_address != :email_address_1 "
+ "ORDER BY addresses_1.id",
+ [{"email_address_1": "email"}],
+ ),
+ )
+
+ def test_query_join_local_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ s = Session(testing.db)
+
+ q = s.query(User).join(
+ User.addresses.and_(Address.email_address != "email")
+ )
+
+ self.assert_compile(
+ q,
+ "SELECT users.id AS users_id, users.name AS users_name "
+ "FROM users JOIN addresses ON users.id = addresses.user_id "
+ "AND addresses.email_address != :email_address_1",
+ )
+
+ def test_select_join_local_criteria(self, user_address_fixture):
+ User, Address = user_address_fixture
+
+ stmt = select(User).join(
+ User.addresses.and_(Address.email_address != "email")
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name FROM users JOIN addresses "
+ "ON users.id = addresses.user_id "
+ "AND addresses.email_address != :email_address_1",
+ )
+
+ def test_select_joinm2m_local_criteria(self, order_item_fixture):
+ Order, Item = order_item_fixture
+
+ stmt = select(Order).join(
+ Order.items.and_(Item.description != "description")
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT orders.id, orders.user_id, orders.address_id, "
+ "orders.description, orders.isopen "
+ "FROM orders JOIN order_items AS order_items_1 "
+ "ON orders.id = order_items_1.order_id "
+ "JOIN items ON items.id = order_items_1.item_id "
+ "AND items.description != :description_1",
+ )
+
+ def test_select_joinm2m_aliased_local_criteria(self, order_item_fixture):
+ Order, Item = order_item_fixture
+
+ i1 = aliased(Item)
+ stmt = select(Order).join(
+ Order.items.of_type(i1).and_(i1.description != "description")
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT orders.id, orders.user_id, orders.address_id, "
+ "orders.description, orders.isopen "
+ "FROM orders JOIN order_items AS order_items_1 "
+ "ON orders.id = order_items_1.order_id "
+ "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
+ "AND items_1.description != :description_1",
+ )
diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py
index b573accbd..7aad2cab8 100644
--- a/test/sql/test_compare.py
+++ b/test/sql/test_compare.py
@@ -1512,3 +1512,23 @@ class CompareClausesTest(fixtures.TestBase):
is_true(x_p_a.compare(x_p))
is_true(x_p.compare(x_p_a))
is_false(x_p_a.compare(x_a))
+
+
+class ExecutableFlagsTest(fixtures.TestBase):
+ @testing.combinations(
+ (select(column("a")),),
+ (table("q", column("a")).insert(),),
+ (table("q", column("a")).update(),),
+ (table("q", column("a")).delete(),),
+ (lambda_stmt(lambda: select(column("a"))),),
+ )
+ def test_is_select(self, case):
+ if isinstance(case, LambdaElement):
+ resolved_case = case._resolved
+ else:
+ resolved_case = case
+
+ if isinstance(resolved_case, Select):
+ is_true(case.is_select)
+ else:
+ is_false(case.is_select)