summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/context.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-06-15 15:13:34 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2021-06-17 09:48:52 -0400
commit5b3e887f46afdbee312d5efd2a14f7c9b7eeac65 (patch)
tree7c12dd2686dc3d26222383d39527b24613e49da3 /lib/sqlalchemy/orm/context.py
parent29fbbd9cebf5d4a4f21d01a74bcfb6dce923fe1b (diff)
downloadsqlalchemy-5b3e887f46afdbee312d5efd2a14f7c9b7eeac65.tar.gz
memoize current options and joins w with_entities/with_only_cols
Fixed further regressions in the same area as that of :ticket:`6052` where loader options as well as invocations of methods like :meth:`_orm.Query.join` would fail if the left side of the statement for which the option/join depends upon were replaced by using the :meth:`_orm.Query.with_entities` method, or when using 2.0 style queries when using the :meth:`_sql.Select.with_only_columns` method. A new set of state has been added to the objects which tracks the "left" entities that the options / join were made against which is memoized when the lead entities are changed. Fixes: #6503 Fixes: #6253 Change-Id: I211b2af98b0b20d1263fb15dc513884dcc5de6a4
Diffstat (limited to 'lib/sqlalchemy/orm/context.py')
-rw-r--r--lib/sqlalchemy/orm/context.py203
1 files changed, 162 insertions, 41 deletions
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index e4448f953..321eeada0 100644
--- a/lib/sqlalchemy/orm/context.py
+++ b/lib/sqlalchemy/orm/context.py
@@ -322,10 +322,16 @@ class ORMCompileState(CompileState):
return loading.instances(result, querycontext)
@property
- def _mapper_entities(self):
- return (
+ def _lead_mapper_entities(self):
+ """return all _MapperEntity objects in the lead entities collection.
+
+ Does **not** include entities that have been replaced by
+ with_entities(), with_only_columns()
+
+ """
+ return [
ent for ent in self._entities if isinstance(ent, _MapperEntity)
- )
+ ]
def _create_with_polymorphic_adapter(self, ext_info, selectable):
if (
@@ -405,7 +411,9 @@ class ORMFromStatementCompileState(ORMCompileState):
self.use_legacy_query_style,
)
- _QueryEntity.to_compile_state(self, statement_container._raw_columns)
+ _QueryEntity.to_compile_state(
+ self, statement_container._raw_columns, self._entities
+ )
self.current_path = statement_container._compile_options._current_path
@@ -477,6 +485,8 @@ class ORMFromStatementCompileState(ORMCompileState):
class ORMSelectCompileState(ORMCompileState, SelectState):
_joinpath = _joinpoint = _EMPTY_DICT
+ _memoized_entities = _EMPTY_DICT
+
_from_obj_alias = None
_has_mapper_entities = False
@@ -572,15 +582,48 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
statement._label_style, self.use_legacy_query_style
)
- _QueryEntity.to_compile_state(self, select_statement._raw_columns)
+ if select_statement._memoized_select_entities:
+ self._memoized_entities = {
+ memoized_entities: _QueryEntity.to_compile_state(
+ self,
+ memoized_entities._raw_columns,
+ [],
+ )
+ for memoized_entities in (
+ select_statement._memoized_select_entities
+ )
+ }
+
+ _QueryEntity.to_compile_state(
+ self, select_statement._raw_columns, self._entities
+ )
self.current_path = select_statement._compile_options._current_path
self.eager_order_by = ()
- if toplevel and select_statement._with_options:
+ if toplevel and (
+ select_statement._with_options
+ or select_statement._memoized_select_entities
+ ):
self.attributes = {"_unbound_load_dedupes": set()}
+ for (
+ memoized_entities
+ ) in select_statement._memoized_select_entities:
+ for opt in memoized_entities._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state_replaced_entities(
+ self,
+ [
+ ent
+ for ent in self._memoized_entities[
+ memoized_entities
+ ]
+ if isinstance(ent, _MapperEntity)
+ ],
+ )
+
for opt in self.select_statement._with_options:
if opt._is_compile_state:
opt.process_compile_state(self)
@@ -626,11 +669,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
if self.compile_options._set_base_alias:
self._set_select_from_alias()
+ for memoized_entities in query._memoized_select_entities:
+ if memoized_entities._setup_joins:
+ self._join(
+ memoized_entities._setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+ if memoized_entities._legacy_setup_joins:
+ self._legacy_join(
+ memoized_entities._legacy_setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+
if query._setup_joins:
- self._join(query._setup_joins)
+ self._join(query._setup_joins, self._entities)
if query._legacy_setup_joins:
- self._legacy_join(query._legacy_setup_joins)
+ self._legacy_join(query._legacy_setup_joins, self._entities)
current_adapter = self._get_current_adapter()
@@ -782,7 +837,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# entities will also set up polymorphic adapters for mappers
# that have with_polymorphic configured
- _QueryEntity.to_compile_state(self, query._raw_columns)
+ _QueryEntity.to_compile_state(self, query._raw_columns, self._entities)
return self
@classmethod
@@ -921,7 +976,18 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _all_equivs(self):
equivs = {}
- for ent in self._mapper_entities:
+
+ for memoized_entities in self._memoized_entities.values():
+ for ent in [
+ ent
+ for ent in memoized_entities
+ if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+
+ for ent in [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]:
equivs.update(ent.mapper._equivalent_columns)
return equivs
@@ -1211,7 +1277,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
return _adapt_clause
- def _join(self, args):
+ def _join(self, args, entities_collection):
for (right, onclause, from_, flags) in args:
isouter = flags["isouter"]
full = flags["full"]
@@ -1316,6 +1382,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
+ entities_collection,
left,
right,
onclause,
@@ -1326,7 +1393,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
full,
)
- def _legacy_join(self, args):
+ def _legacy_join(self, args, entities_collection):
"""consumes arguments from join() or outerjoin(), places them into a
consistent format with which to form the actual JOIN constructs.
@@ -1474,6 +1541,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# figure out the final "left" and "right" sides and create an
# ORMJoin to add to our _from_obj tuple
self._join_left_to_right(
+ entities_collection,
left,
right,
onclause,
@@ -1489,6 +1557,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
def _join_left_to_right(
self,
+ entities_collection,
left,
right,
onclause,
@@ -1513,7 +1582,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
left,
replace_from_obj_index,
use_entity_index,
- ) = self._join_determine_implicit_left_side(left, right, onclause)
+ ) = self._join_determine_implicit_left_side(
+ entities_collection, left, right, onclause
+ )
else:
# left is given via a relationship/name, or as explicit left side.
# Determine where in our
@@ -1522,7 +1593,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
(
replace_from_obj_index,
use_entity_index,
- ) = self._join_place_explicit_left_side(left)
+ ) = self._join_place_explicit_left_side(entities_collection, left)
if left is right and not create_aliases:
raise sa_exc.InvalidRequestError(
@@ -1568,9 +1639,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# entity_zero.selectable, but if with_polymorphic() were used
# might be distinct
assert isinstance(
- self._entities[use_entity_index], _MapperEntity
+ entities_collection[use_entity_index], _MapperEntity
)
- left_clause = self._entities[use_entity_index].selectable
+ left_clause = entities_collection[use_entity_index].selectable
else:
left_clause = left
@@ -1585,7 +1656,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
)
]
- def _join_determine_implicit_left_side(self, left, right, onclause):
+ def _join_determine_implicit_left_side(
+ self, entities_collection, left, right, onclause
+ ):
"""When join conditions don't express the left side explicitly,
determine if an existing FROM or entity in this query
can serve as the left hand side.
@@ -1635,12 +1708,12 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
"to help resolve the ambiguity." % (right,)
)
- elif self._entities:
+ elif entities_collection:
# we have no explicit FROMs, so the implicit left has to
# come from our list of entities.
potential = {}
- for entity_index, ent in enumerate(self._entities):
+ for entity_index, ent in enumerate(entities_collection):
entity = ent.entity_zero_or_selectable
if entity is None:
continue
@@ -1689,7 +1762,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
return left, replace_from_obj_index, use_entity_index
- def _join_place_explicit_left_side(self, left):
+ def _join_place_explicit_left_side(self, entities_collection, left):
"""When join conditions express a left side explicitly, determine
where in our existing list of FROM clauses we should join towards,
or if we need to make a new join, and if so is it from one of our
@@ -1743,10 +1816,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
# aliasing / adaptation rules present on that entity if any
if (
replace_from_obj_index is None
- and self._entities
+ and entities_collection
and hasattr(l_info, "mapper")
):
- for idx, ent in enumerate(self._entities):
+ for idx, ent in enumerate(entities_collection):
# TODO: should we be checking for multiple mapper entities
# matching?
if isinstance(ent, _MapperEntity) and ent.corresponds_to(left):
@@ -2194,11 +2267,14 @@ class _QueryEntity(object):
__slots__ = ()
@classmethod
- def to_compile_state(cls, compile_state, entities):
+ def to_compile_state(cls, compile_state, entities, entities_collection):
+
for idx, entity in enumerate(entities):
if entity._is_lambda_element:
if entity._is_sequence:
- cls.to_compile_state(compile_state, entity._resolved)
+ cls.to_compile_state(
+ compile_state, entity._resolved, entities_collection
+ )
continue
else:
entity = entity._resolved
@@ -2206,26 +2282,38 @@ class _QueryEntity(object):
if entity.is_clause_element:
if entity.is_selectable:
if "parententity" in entity._annotations:
- _MapperEntity(compile_state, entity)
+ _MapperEntity(
+ compile_state, entity, entities_collection
+ )
else:
_ColumnEntity._for_columns(
- compile_state, entity._select_iterable, idx
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
)
else:
if entity._annotations.get("bundle", False):
- _BundleEntity(compile_state, entity)
+ _BundleEntity(
+ compile_state, entity, entities_collection
+ )
elif entity._is_clause_list:
# this is legacy only - test_composites.py
# test_query_cols_legacy
_ColumnEntity._for_columns(
- compile_state, entity._select_iterable, idx
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
)
else:
_ColumnEntity._for_columns(
- compile_state, [entity], idx
+ compile_state, [entity], entities_collection, idx
)
elif entity.is_bundle:
- _BundleEntity(compile_state, entity)
+ _BundleEntity(compile_state, entity, entities_collection)
+
+ return entities_collection
class _MapperEntity(_QueryEntity):
@@ -2244,8 +2332,8 @@ class _MapperEntity(_QueryEntity):
"_polymorphic_discriminator",
)
- def __init__(self, compile_state, entity):
- compile_state._entities.append(self)
+ def __init__(self, compile_state, entity, entities_collection):
+ entities_collection.append(self)
if compile_state._primary_entity is None:
compile_state._primary_entity = self
compile_state._has_mapper_entities = True
@@ -2418,7 +2506,12 @@ class _BundleEntity(_QueryEntity):
)
def __init__(
- self, compile_state, expr, setup_entities=True, parent_bundle=None
+ self,
+ compile_state,
+ expr,
+ entities_collection,
+ setup_entities=True,
+ parent_bundle=None,
):
compile_state._has_orm_entities = True
@@ -2426,7 +2519,7 @@ class _BundleEntity(_QueryEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
if isinstance(
expr, (attributes.QueryableAttribute, interfaces.PropComparator)
@@ -2443,12 +2536,26 @@ class _BundleEntity(_QueryEntity):
if setup_entities:
for expr in bundle.exprs:
if "bundle" in expr._annotations:
- _BundleEntity(compile_state, expr, parent_bundle=self)
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ parent_bundle=self,
+ )
elif isinstance(expr, Bundle):
- _BundleEntity(compile_state, expr, parent_bundle=self)
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ parent_bundle=self,
+ )
else:
_ORMColumnEntity._for_columns(
- compile_state, [expr], None, parent_bundle=self
+ compile_state,
+ [expr],
+ entities_collection,
+ None,
+ parent_bundle=self,
)
self.supports_single_entity = self.bundle.single_entity
@@ -2516,7 +2623,12 @@ class _ColumnEntity(_QueryEntity):
@classmethod
def _for_columns(
- cls, compile_state, columns, raw_column_index, parent_bundle=None
+ cls,
+ compile_state,
+ columns,
+ entities_collection,
+ raw_column_index,
+ parent_bundle=None,
):
for column in columns:
annotations = column._annotations
@@ -2532,6 +2644,7 @@ class _ColumnEntity(_QueryEntity):
_IdentityTokenEntity(
compile_state,
column,
+ entities_collection,
_entity,
raw_column_index,
parent_bundle=parent_bundle,
@@ -2540,6 +2653,7 @@ class _ColumnEntity(_QueryEntity):
_ORMColumnEntity(
compile_state,
column,
+ entities_collection,
_entity,
raw_column_index,
parent_bundle=parent_bundle,
@@ -2548,6 +2662,7 @@ class _ColumnEntity(_QueryEntity):
_RawColumnEntity(
compile_state,
column,
+ entities_collection,
raw_column_index,
parent_bundle=parent_bundle,
)
@@ -2630,7 +2745,12 @@ class _RawColumnEntity(_ColumnEntity):
)
def __init__(
- self, compile_state, column, raw_column_index, parent_bundle=None
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ parent_bundle=None,
):
self.expr = column
self.raw_column_index = raw_column_index
@@ -2643,7 +2763,7 @@ class _RawColumnEntity(_ColumnEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
self.column = column
self.entity_zero_or_selectable = (
@@ -2690,6 +2810,7 @@ class _ORMColumnEntity(_ColumnEntity):
self,
compile_state,
column,
+ entities_collection,
parententity,
raw_column_index,
parent_bundle=None,
@@ -2729,7 +2850,7 @@ class _ORMColumnEntity(_ColumnEntity):
if parent_bundle:
parent_bundle._entities.append(self)
else:
- compile_state._entities.append(self)
+ entities_collection.append(self)
compile_state._has_orm_entities = True