summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/context.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-08-05 21:47:43 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-08-05 22:13:11 -0400
commitc7b489b25802f7a25ef78d0731411295c611cc1c (patch)
treef5e3b66ab8eb8bb7398c0195fa2b2f1de8ab91c4 /lib/sqlalchemy/orm/context.py
parent71a3ccbdef0d88e9231b7de9c51e4ed60b3b7181 (diff)
downloadsqlalchemy-c7b489b25802f7a25ef78d0731411295c611cc1c.tar.gz
Implement relationship AND criteria; global loader criteria
Added the ability to add arbitrary criteria to the ON clause generated by a relationship attribute in a query, which applies to methods such as :meth:`_query.Query.join` as well as loader options like :func:`_orm.joinedload`. Additionally, a "global" version of the option allows limiting criteria to be applied to particular entities in a query globally. Documentation is minimal at this point, new examples will be coming in a subsequent commit. Some adjustments to execution options in how they are represented in the ORMExecuteState as well as well as a few ORM tests that forgot to get merged in a preceding commit. Fixes: #4472 Change-Id: I2b8fc57092dedf35ebd16f6343ad0f0d7d332beb
Diffstat (limited to 'lib/sqlalchemy/orm/context.py')
-rw-r--r--lib/sqlalchemy/orm/context.py107
1 files changed, 79 insertions, 28 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index 96725e55b..a35b2f9fd 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -11,9 +11,9 @@ from .base import _is_aliased_class
from .interfaces import ORMColumnsClauseRole
from .path_registry import PathRegistry
from .util import _entity_corresponds_to
+from .util import _ORMJoin
from .util import aliased
from .util import Bundle
-from .util import join as orm_join
from .util import ORMAdapter
from .. import exc as sa_exc
from .. import future
@@ -78,7 +78,6 @@ class QueryContext(object):
_yield_per = None
_refresh_state = None
_lazy_loaded_from = None
- _params = _EMPTY_DICT
def __init__(
self,
@@ -308,6 +307,9 @@ class ORMFromStatementCompileState(ORMCompileState):
multi_row_eager_loaders = False
compound_eager_adapter = None
+ extra_criteria_entities = _EMPTY_DICT
+ eager_joins = _EMPTY_DICT
+
@classmethod
def create_for_statement(cls, statement_container, compiler, **kw):
@@ -338,6 +340,7 @@ class ORMFromStatementCompileState(ORMCompileState):
if toplevel and statement_container._with_options:
self.attributes = {"_unbound_load_dedupes": set()}
+ self.global_attributes = compiler._global_attributes
for opt in statement_container._with_options:
if opt._is_compile_state:
@@ -345,6 +348,7 @@ class ORMFromStatementCompileState(ORMCompileState):
else:
self.attributes = {}
+ self.global_attributes = compiler._global_attributes
if statement_container._with_context_options:
for fn, key in statement_container._with_context_options:
@@ -352,8 +356,6 @@ class ORMFromStatementCompileState(ORMCompileState):
self.primary_columns = []
self.secondary_columns = []
- self.eager_joins = {}
- self.single_inh_entities = {}
self.create_eager_joins = []
self._fallback_from_clauses = []
@@ -423,11 +425,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def create_for_statement(cls, statement, compiler, **kw):
"""compiler hook, we arrive here from compiler.visit_select() only."""
+ self = cls.__new__(cls)
+
if compiler is not None:
toplevel = not compiler.stack
compiler._rewrites_selected_columns = True
+ self.global_attributes = compiler._global_attributes
else:
toplevel = True
+ self.global_attributes = {}
select_statement = statement
@@ -437,8 +443,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement._compile_options
)
- self = cls.__new__(cls)
-
if select_statement._execution_options:
# execution options should not impact the compilation of a
# query, and at the moment subqueryloader is putting some things
@@ -516,7 +520,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.primary_columns = []
self.secondary_columns = []
self.eager_joins = {}
- self.single_inh_entities = {}
+ self.extra_criteria_entities = {}
self.create_eager_joins = []
self._fallback_from_clauses = []
@@ -634,7 +638,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if self.compile_options._enable_single_crit:
- self._adjust_for_single_inheritance()
+ self._adjust_for_extra_criteria()
if not self.primary_columns:
if self.compile_options._only_load_props:
@@ -1408,6 +1412,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left, right, onclause, prop, create_aliases, aliased_generation
)
+ if not r_info.is_selectable:
+ extra_criteria = self._get_extra_criteria(r_info)
+ else:
+ extra_criteria = ()
+
if replace_from_obj_index is not None:
# splice into an existing element in the
# self._from_obj list
@@ -1416,12 +1425,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
self.from_clauses = (
self.from_clauses[:replace_from_obj_index]
+ [
- orm_join(
+ _ORMJoin(
left_clause,
right,
onclause,
isouter=outerjoin,
full=full,
+ _extra_criteria=extra_criteria,
)
]
+ self.from_clauses[replace_from_obj_index + 1 :]
@@ -1440,8 +1450,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left_clause = left
self.from_clauses = self.from_clauses + [
- orm_join(
- left_clause, r_info, onclause, isouter=outerjoin, full=full
+ _ORMJoin(
+ left_clause,
+ r_info,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
)
]
@@ -1848,8 +1863,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
or kwargs.get("group_by", False)
)
- def _adjust_for_single_inheritance(self):
- """Apply single-table-inheritance filtering.
+ def _get_extra_criteria(self, ext_info):
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in self.global_attributes:
+ return tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in self.global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+ else:
+ return ()
+
+ def _adjust_for_extra_criteria(self):
+ """Apply extra criteria filtering.
For all distinct single-table-inheritance mappers represented in
the columns clause of this query, as well as the "select from entity",
@@ -1857,38 +1887,50 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
clause of the given QueryContext such that only the appropriate
subtypes are selected from the total results.
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ associated with the global context.
+
"""
for fromclause in self.from_clauses:
ext_info = fromclause._annotations.get("parententity", None)
if (
ext_info
- and ext_info.mapper._single_table_criterion is not None
- and ext_info not in self.single_inh_entities
+ and (
+ ext_info.mapper._single_table_criterion is not None
+ or ("additional_entity_criteria", ext_info.mapper)
+ in self.global_attributes
+ )
+ and ext_info not in self.extra_criteria_entities
):
- self.single_inh_entities[ext_info] = (
+ self.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
)
- search = set(self.single_inh_entities.values())
+ search = set(self.extra_criteria_entities.values())
for (ext_info, adapter) in search:
if ext_info in self._join_entities:
continue
+
single_crit = ext_info.mapper._single_table_criterion
+
+ additional_entity_criteria = self._get_extra_criteria(ext_info)
+
if single_crit is not None:
+ additional_entity_criteria += (single_crit,)
+
+ current_adapter = self._get_current_adapter()
+ for crit in additional_entity_criteria:
if adapter:
- single_crit = adapter.traverse(single_crit)
+ crit = adapter.traverse(crit)
- current_adapter = self._get_current_adapter()
if current_adapter:
- single_crit = sql_util._deep_annotate(
- single_crit, {"_orm_adapt": True}
- )
- single_crit = current_adapter(single_crit, False)
- self._where_criteria += (single_crit,)
+ crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+ crit = current_adapter(crit, False)
+ self._where_criteria += (crit,)
def _column_descriptions(query_or_select_stmt, compile_state=None):
@@ -2205,9 +2247,13 @@ class _MapperEntity(_QueryEntity):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
- if single_table_crit is not None:
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
ext_info = self.entity_zero
- compile_state.single_inh_entities[ext_info] = (
+ compile_state.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
)
@@ -2528,8 +2574,13 @@ class _ORMColumnEntity(_ColumnEntity):
ezero = self.entity_zero
single_table_crit = self.mapper._single_table_criterion
- if single_table_crit is not None:
- compile_state.single_inh_entities[ezero] = (
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+
+ compile_state.extra_criteria_entities[ezero] = (
ezero,
ezero._adapter if ezero.is_aliased_class else None,
)