diff options
Diffstat (limited to 'lib/sqlalchemy/orm/loading.py')
| -rw-r--r-- | lib/sqlalchemy/orm/loading.py | 62 |
1 files changed, 50 insertions, 12 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 1a5ea5fe6..5d78a5580 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -89,7 +89,13 @@ def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: """ context.runid = _new_runid() - context.post_load_paths = {} + + if context.top_level_context: + is_top_level = False + context.post_load_paths = context.top_level_context.post_load_paths + else: + is_top_level = True + context.post_load_paths = {} compile_state = context.compile_state filtered = compile_state._has_mapper_entities @@ -190,8 +196,28 @@ def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: tuple([proc(row) for proc in process]) for row in fetch ] - for path, post_load in context.post_load_paths.items(): - post_load.invoke(context, path) + # if we are the originating load from a query, meaning we + # aren't being called as a result of a nested "post load", + # iterate through all the collected post loaders and fire them + # off. Previously this used to work recursively, however that + # prevented deeply nested structures from being loadable + if is_top_level: + if yield_per: + # if using yield per, memoize the state of the + # collection so that it can be restored + top_level_post_loads = list( + context.post_load_paths.items() + ) + + while context.post_load_paths: + post_loads = list(context.post_load_paths.items()) + context.post_load_paths.clear() + for path, post_load in post_loads: + post_load.invoke(context, path) + + if yield_per: + context.post_load_paths.clear() + context.post_load_paths.update(top_level_post_loads) yield rows @@ -747,7 +773,6 @@ def _instance_processor( "quick": [], "deferred": [], "expire": [], - "delayed": [], "existing": [], "eager": [], } @@ -1180,8 +1205,7 @@ def _populate_full( for key, populator in populators["new"]: populator(state, dict_, row) - for key, populator in populators["delayed"]: - populator(state, dict_, row) + elif load_path != state.load_path: # new load path, e.g. object is present in more than one # column position in a series of rows @@ -1233,9 +1257,7 @@ def _populate_partial( for key, populator in populators["new"]: if key in to_load: populator(state, dict_, row) - for key, populator in populators["delayed"]: - if key in to_load: - populator(state, dict_, row) + for key, populator in populators["eager"]: if key not in unloaded: populator(state, dict_, row) @@ -1371,14 +1393,23 @@ class PostLoad: if not self.states: return path = path_registry.PathRegistry.coerce(path) - for token, limit_to_mapper, loader, arg, kw in self.loaders.values(): + for ( + effective_context, + token, + limit_to_mapper, + loader, + arg, + kw, + ) in self.loaders.values(): states = [ (state, overwrite) for state, overwrite in self.states.items() if state.manager.mapper.isa(limit_to_mapper) ] if states: - loader(context, path, states, self.load_keys, *arg, **kw) + loader( + effective_context, path, states, self.load_keys, *arg, **kw + ) self.states.clear() @classmethod @@ -1403,7 +1434,14 @@ class PostLoad: pl = context.post_load_paths[path.path] else: pl = context.post_load_paths[path.path] = PostLoad() - pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw) + pl.loaders[token] = ( + context, + token, + limit_to_mapper, + loader_callable, + arg, + kw, + ) def load_scalar_attributes(mapper, state, attribute_names, passive): |
