diff options
Diffstat (limited to 'lib/sqlalchemy/orm/strategy_options.py')
-rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 108 |
1 files changed, 97 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index df13f05db..d3f456969 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -13,7 +13,7 @@ from .attributes import QueryableAttribute from .. import util from ..sql.base import _generative, Generative from .. import exc as sa_exc, inspect -from .base import _is_aliased_class, _class_to_mapper +from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class from . import util as orm_util from .path_registry import PathRegistry, TokenRegistry, \ _WILDCARD_TOKEN, _DEFAULT_TOKEN @@ -63,6 +63,7 @@ class Load(Generative, MapperOption): self.context = util.OrderedDict() self.local_opts = {} self._of_type = None + self.is_class_strategy = False @classmethod def for_existing_path(cls, path): @@ -127,6 +128,7 @@ class Load(Generative, MapperOption): return cloned is_opts_only = False + is_class_strategy = False strategy = None propagate_to_loaders = False @@ -148,6 +150,7 @@ class Load(Generative, MapperOption): def _generate_path(self, path, attr, wildcard_key, raiseerr=True): self._of_type = None + if raiseerr and not path.has_entity: if isinstance(path, TokenRegistry): raise sa_exc.ArgumentError( @@ -187,6 +190,14 @@ class Load(Generative, MapperOption): attr = attr.property path = path[attr] + elif _is_mapped_class(attr): + if not attr.common_parent(path.mapper): + if raiseerr: + raise sa_exc.ArgumentError( + "Attribute '%s' does not " + "link from element '%s'" % (attr, path.entity)) + else: + return None else: prop = attr.property @@ -246,6 +257,7 @@ class Load(Generative, MapperOption): self, attr, strategy, propagate_to_loaders=True): strategy = self._coerce_strat(strategy) + self.is_class_strategy = False self.propagate_to_loaders = propagate_to_loaders # if the path is a wildcard, this will set propagate_to_loaders=False self._generate_path(self.path, attr, "relationship") @@ -257,6 +269,7 @@ class Load(Generative, MapperOption): def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False): strategy = self._coerce_strat(strategy) + self.is_class_strategy = False for attr in attrs: cloned = self._generate() cloned.strategy = strategy @@ -267,6 +280,31 @@ class Load(Generative, MapperOption): if opts_only: cloned.is_opts_only = True cloned._set_path_strategy() + self.is_class_strategy = False + + @_generative + def set_generic_strategy(self, attrs, strategy): + strategy = self._coerce_strat(strategy) + + for attr in attrs: + path = self._generate_path(self.path, attr, None) + cloned = self._generate() + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + cloned._set_path_strategy() + + @_generative + def set_class_strategy(self, strategy, opts): + strategy = self._coerce_strat(strategy) + cloned = self._generate() + cloned.is_class_strategy = True + path = cloned._generate_path(self.path, None, None) + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + cloned._set_path_strategy() + cloned.local_opts.update(opts) def _set_for_path(self, context, path, replace=True, merge_opts=False): if merge_opts or not replace: @@ -284,7 +322,7 @@ class Load(Generative, MapperOption): self.local_opts.update(existing.local_opts) def _set_path_strategy(self): - if self.path.has_entity: + if not self.is_class_strategy and self.path.has_entity: effective_path = self.path.parent else: effective_path = self.path @@ -367,7 +405,10 @@ class _UnboundLoad(Load): if attr == _DEFAULT_TOKEN: self.propagate_to_loaders = False attr = "%s:%s" % (wildcard_key, attr) - path = path + (attr, ) + if path and _is_mapped_class(path[-1]) and not self.is_class_strategy: + path = path[0:-1] + if attr: + path = path + (attr, ) self.path = path return path @@ -502,7 +543,12 @@ class _UnboundLoad(Load): (User, User.orders.property, Order, Order.items.property)) """ + start_path = self.path + + if self.is_class_strategy and current_path: + start_path += (entities[0], ) + # _current_path implies we're in a # secondary load with an existing path @@ -517,7 +563,8 @@ class _UnboundLoad(Load): token = start_path[0] if isinstance(token, util.string_types): - entity = self._find_entity_basestring(entities, token, raiseerr) + entity = self._find_entity_basestring( + entities, token, raiseerr) elif isinstance(token, PropComparator): prop = token.property entity = self._find_entity_prop_comparator( @@ -525,7 +572,10 @@ class _UnboundLoad(Load): prop.key, token._parententity, raiseerr) - + elif self.is_class_strategy and _is_mapped_class(token): + entity = inspect(token) + if entity not in entities: + entity = None else: raise sa_exc.ArgumentError( "mapper option expects " @@ -541,7 +591,6 @@ class _UnboundLoad(Load): # we just located, then go through the rest of our path # tokens and populate into the Load(). loader = Load(path_element) - if context is not None: loader.context = context else: @@ -549,16 +598,19 @@ class _UnboundLoad(Load): loader.strategy = self.strategy loader.is_opts_only = self.is_opts_only + loader.is_class_strategy = self.is_class_strategy path = loader.path - for token in start_path: - if not loader._generate_path( - loader.path, token, None, raiseerr): - return + + if not loader.is_class_strategy: + for token in start_path: + if not loader._generate_path( + loader.path, token, None, raiseerr): + return loader.local_opts.update(self.local_opts) - if loader.path.has_entity: + if not loader.is_class_strategy and loader.path.has_entity: effective_path = loader.path.parent else: effective_path = loader.path @@ -1289,3 +1341,37 @@ def undefer_group(loadopt, name): @undefer_group._add_unbound_fn def undefer_group(name): return _UnboundLoad().undefer_group(name) + + +@loader_option() +def selectin_polymorphic(loadopt, classes): + """Indicate an eager load should take place for all attributes + specific to a subclass. + + This uses an additional SELECT with IN against all matched primary + key values, and is the per-query analogue to the ``"selectin"`` + setting on the :paramref:`.mapper.polymorphic_load` parameter. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`inheritance_polymorphic_load` + + """ + loadopt.set_class_strategy( + {"selectinload_polymorphic": True}, + opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))} + ) + return loadopt + + +@selectin_polymorphic._add_unbound_fn +def selectin_polymorphic(base_cls, classes): + ul = _UnboundLoad() + ul.is_class_strategy = True + ul.path = (inspect(base_cls), ) + ul.selectin_polymorphic( + classes + ) + return ul |