summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/strategy_options.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/strategy_options.py')
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py108
1 files changed, 97 insertions, 11 deletions
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
index df13f05db..d3f456969 100644
--- a/lib/sqlalchemy/orm/strategy_options.py
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -13,7 +13,7 @@ from .attributes import QueryableAttribute
from .. import util
from ..sql.base import _generative, Generative
from .. import exc as sa_exc, inspect
-from .base import _is_aliased_class, _class_to_mapper
+from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class
from . import util as orm_util
from .path_registry import PathRegistry, TokenRegistry, \
_WILDCARD_TOKEN, _DEFAULT_TOKEN
@@ -63,6 +63,7 @@ class Load(Generative, MapperOption):
self.context = util.OrderedDict()
self.local_opts = {}
self._of_type = None
+ self.is_class_strategy = False
@classmethod
def for_existing_path(cls, path):
@@ -127,6 +128,7 @@ class Load(Generative, MapperOption):
return cloned
is_opts_only = False
+ is_class_strategy = False
strategy = None
propagate_to_loaders = False
@@ -148,6 +150,7 @@ class Load(Generative, MapperOption):
def _generate_path(self, path, attr, wildcard_key, raiseerr=True):
self._of_type = None
+
if raiseerr and not path.has_entity:
if isinstance(path, TokenRegistry):
raise sa_exc.ArgumentError(
@@ -187,6 +190,14 @@ class Load(Generative, MapperOption):
attr = attr.property
path = path[attr]
+ elif _is_mapped_class(attr):
+ if not attr.common_parent(path.mapper):
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Attribute '%s' does not "
+ "link from element '%s'" % (attr, path.entity))
+ else:
+ return None
else:
prop = attr.property
@@ -246,6 +257,7 @@ class Load(Generative, MapperOption):
self, attr, strategy, propagate_to_loaders=True):
strategy = self._coerce_strat(strategy)
+ self.is_class_strategy = False
self.propagate_to_loaders = propagate_to_loaders
# if the path is a wildcard, this will set propagate_to_loaders=False
self._generate_path(self.path, attr, "relationship")
@@ -257,6 +269,7 @@ class Load(Generative, MapperOption):
def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False):
strategy = self._coerce_strat(strategy)
+ self.is_class_strategy = False
for attr in attrs:
cloned = self._generate()
cloned.strategy = strategy
@@ -267,6 +280,31 @@ class Load(Generative, MapperOption):
if opts_only:
cloned.is_opts_only = True
cloned._set_path_strategy()
+ self.is_class_strategy = False
+
+ @_generative
+ def set_generic_strategy(self, attrs, strategy):
+ strategy = self._coerce_strat(strategy)
+
+ for attr in attrs:
+ path = self._generate_path(self.path, attr, None)
+ cloned = self._generate()
+ cloned.strategy = strategy
+ cloned.path = path
+ cloned.propagate_to_loaders = True
+ cloned._set_path_strategy()
+
+ @_generative
+ def set_class_strategy(self, strategy, opts):
+ strategy = self._coerce_strat(strategy)
+ cloned = self._generate()
+ cloned.is_class_strategy = True
+ path = cloned._generate_path(self.path, None, None)
+ cloned.strategy = strategy
+ cloned.path = path
+ cloned.propagate_to_loaders = True
+ cloned._set_path_strategy()
+ cloned.local_opts.update(opts)
def _set_for_path(self, context, path, replace=True, merge_opts=False):
if merge_opts or not replace:
@@ -284,7 +322,7 @@ class Load(Generative, MapperOption):
self.local_opts.update(existing.local_opts)
def _set_path_strategy(self):
- if self.path.has_entity:
+ if not self.is_class_strategy and self.path.has_entity:
effective_path = self.path.parent
else:
effective_path = self.path
@@ -367,7 +405,10 @@ class _UnboundLoad(Load):
if attr == _DEFAULT_TOKEN:
self.propagate_to_loaders = False
attr = "%s:%s" % (wildcard_key, attr)
- path = path + (attr, )
+ if path and _is_mapped_class(path[-1]) and not self.is_class_strategy:
+ path = path[0:-1]
+ if attr:
+ path = path + (attr, )
self.path = path
return path
@@ -502,7 +543,12 @@ class _UnboundLoad(Load):
(User, User.orders.property, Order, Order.items.property))
"""
+
start_path = self.path
+
+ if self.is_class_strategy and current_path:
+ start_path += (entities[0], )
+
# _current_path implies we're in a
# secondary load with an existing path
@@ -517,7 +563,8 @@ class _UnboundLoad(Load):
token = start_path[0]
if isinstance(token, util.string_types):
- entity = self._find_entity_basestring(entities, token, raiseerr)
+ entity = self._find_entity_basestring(
+ entities, token, raiseerr)
elif isinstance(token, PropComparator):
prop = token.property
entity = self._find_entity_prop_comparator(
@@ -525,7 +572,10 @@ class _UnboundLoad(Load):
prop.key,
token._parententity,
raiseerr)
-
+ elif self.is_class_strategy and _is_mapped_class(token):
+ entity = inspect(token)
+ if entity not in entities:
+ entity = None
else:
raise sa_exc.ArgumentError(
"mapper option expects "
@@ -541,7 +591,6 @@ class _UnboundLoad(Load):
# we just located, then go through the rest of our path
# tokens and populate into the Load().
loader = Load(path_element)
-
if context is not None:
loader.context = context
else:
@@ -549,16 +598,19 @@ class _UnboundLoad(Load):
loader.strategy = self.strategy
loader.is_opts_only = self.is_opts_only
+ loader.is_class_strategy = self.is_class_strategy
path = loader.path
- for token in start_path:
- if not loader._generate_path(
- loader.path, token, None, raiseerr):
- return
+
+ if not loader.is_class_strategy:
+ for token in start_path:
+ if not loader._generate_path(
+ loader.path, token, None, raiseerr):
+ return
loader.local_opts.update(self.local_opts)
- if loader.path.has_entity:
+ if not loader.is_class_strategy and loader.path.has_entity:
effective_path = loader.path.parent
else:
effective_path = loader.path
@@ -1289,3 +1341,37 @@ def undefer_group(loadopt, name):
@undefer_group._add_unbound_fn
def undefer_group(name):
return _UnboundLoad().undefer_group(name)
+
+
+@loader_option()
+def selectin_polymorphic(loadopt, classes):
+ """Indicate an eager load should take place for all attributes
+ specific to a subclass.
+
+ This uses an additional SELECT with IN against all matched primary
+ key values, and is the per-query analogue to the ``"selectin"``
+ setting on the :paramref:`.mapper.polymorphic_load` parameter.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`inheritance_polymorphic_load`
+
+ """
+ loadopt.set_class_strategy(
+ {"selectinload_polymorphic": True},
+ opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))}
+ )
+ return loadopt
+
+
+@selectin_polymorphic._add_unbound_fn
+def selectin_polymorphic(base_cls, classes):
+ ul = _UnboundLoad()
+ ul.is_class_strategy = True
+ ul.path = (inspect(base_cls), )
+ ul.selectin_polymorphic(
+ classes
+ )
+ return ul