diff options
Diffstat (limited to 'lib/sqlalchemy/orm/loading.py')
-rw-r--r-- | lib/sqlalchemy/orm/loading.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index a23cafac2..8a20bf0dd 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -394,7 +394,8 @@ def _instance_processor( callable_ = _load_subclass_via_in(context, path, selectin_load_via) PostLoad.callable_for_path( - context, load_path, selectin_load_via, + context, load_path, selectin_load_via.mapper, + selectin_load_via, callable_, selectin_load_via) post_load = PostLoad.for_context(context, load_path, only_load_props) @@ -574,7 +575,6 @@ def _load_subclass_via_in(context, path, entity): primary_keys=[ state.key[1][0] if zero_idx else state.key[1] for state, load_attrs in states - if state.mapper.isa(mapper) ] ).all() @@ -738,16 +738,25 @@ class PostLoad(object): self.load_keys = None def add_state(self, state, overwrite): + # the states for a polymorphic load here are all shared + # within a single PostLoad object among multiple subtypes. + # Filtering of callables on a per-subclass basis needs to be done at + # the invocation level self.states[state] = overwrite def invoke(self, context, path): if not self.states: return path = path_registry.PathRegistry.coerce(path) - for key, loader, arg, kw in self.loaders.values(): + for token, limit_to_mapper, loader, arg, kw in self.loaders.values(): + states = [ + (state, overwrite) + for state, overwrite + in self.states.items() + if state.manager.mapper.isa(limit_to_mapper) + ] loader( - context, path, self.states.items(), - self.load_keys, *arg, **kw) + context, path, states, self.load_keys, *arg, **kw) self.states.clear() @classmethod @@ -764,12 +773,13 @@ class PostLoad(object): @classmethod def callable_for_path( - cls, context, path, attr_key, loader_callable, *arg, **kw): + cls, context, path, limit_to_mapper, token, + loader_callable, *arg, **kw): if path.path in context.post_load_paths: pl = context.post_load_paths[path.path] else: pl = context.post_load_paths[path.path] = PostLoad() - pl.loaders[attr_key] = (attr_key, loader_callable, arg, kw) + pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw) def load_scalar_attributes(mapper, state, attribute_names): |