summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/loading.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/loading.py')
-rw-r--r--lib/sqlalchemy/orm/loading.py62
1 files changed, 50 insertions, 12 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 1a5ea5fe6..5d78a5580 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -89,7 +89,13 @@ def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]:
"""
context.runid = _new_runid()
- context.post_load_paths = {}
+
+ if context.top_level_context:
+ is_top_level = False
+ context.post_load_paths = context.top_level_context.post_load_paths
+ else:
+ is_top_level = True
+ context.post_load_paths = {}
compile_state = context.compile_state
filtered = compile_state._has_mapper_entities
@@ -190,8 +196,28 @@ def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]:
tuple([proc(row) for proc in process]) for row in fetch
]
- for path, post_load in context.post_load_paths.items():
- post_load.invoke(context, path)
+ # if we are the originating load from a query, meaning we
+ # aren't being called as a result of a nested "post load",
+ # iterate through all the collected post loaders and fire them
+ # off. Previously this used to work recursively, however that
+ # prevented deeply nested structures from being loadable
+ if is_top_level:
+ if yield_per:
+ # if using yield per, memoize the state of the
+ # collection so that it can be restored
+ top_level_post_loads = list(
+ context.post_load_paths.items()
+ )
+
+ while context.post_load_paths:
+ post_loads = list(context.post_load_paths.items())
+ context.post_load_paths.clear()
+ for path, post_load in post_loads:
+ post_load.invoke(context, path)
+
+ if yield_per:
+ context.post_load_paths.clear()
+ context.post_load_paths.update(top_level_post_loads)
yield rows
@@ -747,7 +773,6 @@ def _instance_processor(
"quick": [],
"deferred": [],
"expire": [],
- "delayed": [],
"existing": [],
"eager": [],
}
@@ -1180,8 +1205,7 @@ def _populate_full(
for key, populator in populators["new"]:
populator(state, dict_, row)
- for key, populator in populators["delayed"]:
- populator(state, dict_, row)
+
elif load_path != state.load_path:
# new load path, e.g. object is present in more than one
# column position in a series of rows
@@ -1233,9 +1257,7 @@ def _populate_partial(
for key, populator in populators["new"]:
if key in to_load:
populator(state, dict_, row)
- for key, populator in populators["delayed"]:
- if key in to_load:
- populator(state, dict_, row)
+
for key, populator in populators["eager"]:
if key not in unloaded:
populator(state, dict_, row)
@@ -1371,14 +1393,23 @@ class PostLoad:
if not self.states:
return
path = path_registry.PathRegistry.coerce(path)
- for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
+ for (
+ effective_context,
+ token,
+ limit_to_mapper,
+ loader,
+ arg,
+ kw,
+ ) in self.loaders.values():
states = [
(state, overwrite)
for state, overwrite in self.states.items()
if state.manager.mapper.isa(limit_to_mapper)
]
if states:
- loader(context, path, states, self.load_keys, *arg, **kw)
+ loader(
+ effective_context, path, states, self.load_keys, *arg, **kw
+ )
self.states.clear()
@classmethod
@@ -1403,7 +1434,14 @@ class PostLoad:
pl = context.post_load_paths[path.path]
else:
pl = context.post_load_paths[path.path] = PostLoad()
- pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw)
+ pl.loaders[token] = (
+ context,
+ token,
+ limit_to_mapper,
+ loader_callable,
+ arg,
+ kw,
+ )
def load_scalar_attributes(mapper, state, attribute_names, passive):