diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-03-28 11:00:37 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-06-05 11:27:00 -0400 |
commit | bb6a1f690d4a749df44a1ef329b66f71205968fe (patch) | |
tree | 90aac9e592df3a769f5397f84a14b911e4cb52f1 /lib | |
parent | 6bb97495baa640c6f03d1b50affd664cb903dee3 (diff) | |
download | sqlalchemy-bb6a1f690d4a749df44a1ef329b66f71205968fe.tar.gz |
selectin polymorphic loading
Added a new style of mapper-level inheritance loading
"polymorphic selectin". This style of loading
emits queries for each subclass in an inheritance
hierarchy subsequent to the load of the base
object type, using IN to specify the desired
primary key values.
Fixes: #3948
Change-Id: I59e071c6142354a3f95730046e3dcdfc0e2c4de5
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/ext/baked.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/__init__.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/loading.py | 53 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 111 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 18 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/strategy_options.py | 108 | ||||
-rw-r--r-- | lib/sqlalchemy/orm/util.py | 21 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 5 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertsql.py | 48 |
9 files changed, 340 insertions, 37 deletions
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py index ba3c2aed0..c0fe963ac 100644 --- a/lib/sqlalchemy/ext/baked.py +++ b/lib/sqlalchemy/ext/baked.py @@ -154,7 +154,7 @@ class BakedQuery(object): self._spoiled = True return self - def _add_lazyload_options(self, options, effective_path): + def _add_lazyload_options(self, options, effective_path, cache_path=None): """Used by per-state lazy loaders to add options to the "lazy load" query from a parent query. @@ -166,13 +166,16 @@ class BakedQuery(object): key = () - if effective_path.path[0].is_aliased_class: + if not cache_path: + cache_path = effective_path + + if cache_path.path[0].is_aliased_class: # paths that are against an AliasedClass are unsafe to cache # with since the AliasedClass is an ad-hoc object. self.spoil() else: for opt in options: - cache_key = opt._generate_cache_key(effective_path) + cache_key = opt._generate_cache_key(cache_path) if cache_key is False: self.spoil() elif cache_key is not None: @@ -181,7 +184,7 @@ class BakedQuery(object): self.add_criteria( lambda q: q._with_current_path(effective_path). _conditional_options(*options), - effective_path.path, key + cache_path.path, key ) def _retrieve_baked_query(self, session): diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index adfe2360a..7ecd5b67e 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -246,6 +246,7 @@ immediateload = strategy_options.immediateload._unbound_fn noload = strategy_options.noload._unbound_fn raiseload = strategy_options.raiseload._unbound_fn defaultload = strategy_options.defaultload._unbound_fn +selectin_polymorphic = strategy_options.selectin_polymorphic._unbound_fn from .strategy_options import Load @@ -268,6 +269,7 @@ def __go(lcls): from .. import util as sa_util from . import dynamic from . import events + from . import loading import inspect as _inspect __all__ = sorted(name for name, obj in lcls.items() diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 7feec660d..48c0db851 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -19,6 +19,7 @@ from . import attributes, exc as orm_exc from ..sql import util as sql_util from . import strategy_options from . import path_registry +from .. import sql from .util import _none_set, state_str from .base import _SET_DEFERRED_EXPIRED, _DEFER_FOR_STATE @@ -353,6 +354,27 @@ def _instance_processor( session_id = context.session.hash_key version_check = context.version_check runid = context.runid + + if not refresh_state and _polymorphic_from is not None: + key = ('loader', path.path) + if ( + key in context.attributes and + context.attributes[key].strategy == + (('selectinload_polymorphic', True), ) and + mapper in context.attributes[key].local_opts['mappers'] + ) or mapper.polymorphic_load == 'selectin': + + # only_load_props goes w/ refresh_state only, and in a refresh + # we are a single row query for the exact entity; polymorphic + # loading does not apply + assert only_load_props is None + + callable_ = _load_subclass_via_in(context, path, mapper) + + PostLoad.callable_for_path( + context, load_path, mapper, + callable_, mapper) + post_load = PostLoad.for_context(context, load_path, only_load_props) if refresh_state: @@ -501,6 +523,37 @@ def _instance_processor( return _instance +@util.dependencies("sqlalchemy.ext.baked") +def _load_subclass_via_in(baked, context, path, mapper): + + zero_idx = len(mapper.base_mapper.primary_key) == 1 + + q, enable_opt, disable_opt = mapper._subclass_load_via_in + + def do_load(context, path, states, load_only, effective_entity): + orig_query = context.query + + q._add_lazyload_options( + (enable_opt, ) + orig_query._with_options + (disable_opt, ), + path.parent, cache_path=path + ) + + if orig_query._populate_existing: + q.add_criteria( + lambda q: q.populate_existing() + ) + + q(context.session).params( + primary_keys=[ + state.key[1][0] if zero_idx else state.key[1] + for state, load_attrs in states + if state.mapper.isa(mapper) + ] + ).all() + + return do_load + + def _populate_full( context, row, state, dict_, isnew, load_path, loaded_instance, populate_existing, populators): diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 6bf86d0ef..1042442c0 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -106,6 +106,7 @@ class Mapper(InspectionAttr): polymorphic_identity=None, concrete=False, with_polymorphic=None, + polymorphic_load=None, allow_partial_pks=True, batch=True, column_prefix=None, @@ -381,6 +382,27 @@ class Mapper(InspectionAttr): :paramref:`.mapper.passive_deletes` - supporting ON DELETE CASCADE for joined-table inheritance mappers + :param polymorphic_load: Specifies "polymorphic loading" behavior + for a subclass in an inheritance hierarchy (joined and single + table inheritance only). Valid values are: + + * "'inline'" - specifies this class should be part of the + "with_polymorphic" mappers, e.g. its columns will be included + in a SELECT query against the base. + + * "'selectin'" - specifies that when instances of this class + are loaded, an additional SELECT will be emitted to retrieve + the columns specific to this subclass. The SELECT uses + IN to fetch multiple subclasses at once. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`with_polymorphic_mapper_config` + + :ref:`polymorphic_selectin` + :param polymorphic_on: Specifies the column, attribute, or SQL expression used to determine the target class for an incoming row, when inheriting classes are present. @@ -622,8 +644,6 @@ class Mapper(InspectionAttr): else: self.confirm_deleted_rows = confirm_deleted_rows - self._set_with_polymorphic(with_polymorphic) - if isinstance(self.local_table, expression.SelectBase): raise sa_exc.InvalidRequestError( "When mapping against a select() construct, map against " @@ -632,11 +652,8 @@ class Mapper(InspectionAttr): "SELECT from a subquery that does not have an alias." ) - if self.with_polymorphic and \ - isinstance(self.with_polymorphic[1], - expression.SelectBase): - self.with_polymorphic = (self.with_polymorphic[0], - self.with_polymorphic[1].alias()) + self._set_with_polymorphic(with_polymorphic) + self.polymorphic_load = polymorphic_load # our 'polymorphic identity', a string name that when located in a # result set row indicates this Mapper should be used to construct @@ -1037,6 +1054,19 @@ class Mapper(InspectionAttr): ) self.polymorphic_map[self.polymorphic_identity] = self + if self.polymorphic_load and self.concrete: + raise exc.ArgumentError( + "polymorphic_load is not currently supported " + "with concrete table inheritance") + if self.polymorphic_load == 'inline': + self.inherits._add_with_polymorphic_subclass(self) + elif self.polymorphic_load == 'selectin': + pass + elif self.polymorphic_load is not None: + raise sa_exc.ArgumentError( + "unknown argument for polymorphic_load: %r" % + self.polymorphic_load) + else: self._all_tables = set() self.base_mapper = self @@ -1077,9 +1107,22 @@ class Mapper(InspectionAttr): expression.SelectBase): self.with_polymorphic = (self.with_polymorphic[0], self.with_polymorphic[1].alias()) + if self.configured: self._expire_memoizations() + def _add_with_polymorphic_subclass(self, mapper): + subcl = mapper.class_ + if self.with_polymorphic is None: + self._set_with_polymorphic((subcl,)) + elif self.with_polymorphic[0] != '*': + self._set_with_polymorphic( + ( + self.with_polymorphic[0] + (subcl, ), + self.with_polymorphic[1] + ) + ) + def _set_concrete_base(self, mapper): """Set the given :class:`.Mapper` as the 'inherits' for this :class:`.Mapper`, assuming this :class:`.Mapper` is concrete @@ -2663,6 +2706,60 @@ class Mapper(InspectionAttr): cols.extend(props[key].columns) return sql.select(cols, cond, use_labels=True) + @_memoized_configured_property + @util.dependencies( + "sqlalchemy.ext.baked", + "sqlalchemy.orm.strategy_options") + def _subclass_load_via_in(self, baked, strategy_options): + """Assemble a BakedQuery that can load the columns local to + this subclass as a SELECT with IN. + + """ + assert self.inherits + + polymorphic_prop = self._columntoproperty[ + self.polymorphic_on] + keep_props = set( + [polymorphic_prop] + self._identity_key_props) + + disable_opt = strategy_options.Load(self) + enable_opt = strategy_options.Load(self) + + for prop in self.attrs: + if prop.parent is self or prop in keep_props: + # "enable" options, to turn on the properties that we want to + # load by default (subject to options from the query) + enable_opt.set_generic_strategy( + (prop.key, ), + dict(prop.strategy_key) + ) + else: + # "disable" options, to turn off the properties from the + # superclass that we *don't* want to load, applied after + # the options from the query to override them + disable_opt.set_generic_strategy( + (prop.key, ), + {"do_nothing": True} + ) + + if len(self.primary_key) > 1: + in_expr = sql.tuple_(*self.primary_key) + else: + in_expr = self.primary_key[0] + + q = baked.BakedQuery( + self._compiled_cache, + lambda session: session.query(self), + (self, ) + ) + q += lambda q: q.filter( + in_expr.in_( + sql.bindparam('primary_keys', expanding=True) + ) + ).order_by(*self.primary_key) + + return q, enable_opt, disable_opt + def cascade_iterator(self, type_, state, halt_on=None): """Iterate each element and its mapper in an object graph, for all relationships that meet the given cascade rule. diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index dc69ae99d..e48462d35 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -196,6 +196,7 @@ class ColumnLoader(LoaderStrategy): @log.class_logger @properties.ColumnProperty.strategy_for(deferred=True, instrument=True) +@properties.ColumnProperty.strategy_for(do_nothing=True) class DeferredColumnLoader(LoaderStrategy): """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" @@ -336,6 +337,18 @@ class AbstractRelationshipLoader(LoaderStrategy): @log.class_logger +@properties.RelationshipProperty.strategy_for(do_nothing=True) +class DoNothingLoader(LoaderStrategy): + """Relationship loader that makes no change to the object's state. + + Compared to NoLoader, this loader does not initialize the + collection/attribute to empty/none; the usual default LazyLoader will + take effect. + + """ + + +@log.class_logger @properties.RelationshipProperty.strategy_for(lazy="noload") @properties.RelationshipProperty.strategy_for(lazy=None) class NoLoader(AbstractRelationshipLoader): @@ -711,6 +724,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): self, context, path, loadopt, mapper, result, adapter, populators): key = self.key + if not self.is_class_level: # we are not the primary manager for this attribute # on this class - set up a @@ -1804,6 +1818,9 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): selectin_path = ( context.query._current_path or orm_util.PathRegistry.root) + path + if not orm_util._entity_isa(path[-1], self.parent): + return + if loading.PostLoad.path_exists(context, selectin_path, self.key): return @@ -1914,6 +1931,7 @@ class SelectInLoader(AbstractRelationshipLoader, util.MemoizedSlots): } for key, state, overwrite in chunk: + if not overwrite and self.key in state.dict: continue diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index df13f05db..d3f456969 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -13,7 +13,7 @@ from .attributes import QueryableAttribute from .. import util from ..sql.base import _generative, Generative from .. import exc as sa_exc, inspect -from .base import _is_aliased_class, _class_to_mapper +from .base import _is_aliased_class, _class_to_mapper, _is_mapped_class from . import util as orm_util from .path_registry import PathRegistry, TokenRegistry, \ _WILDCARD_TOKEN, _DEFAULT_TOKEN @@ -63,6 +63,7 @@ class Load(Generative, MapperOption): self.context = util.OrderedDict() self.local_opts = {} self._of_type = None + self.is_class_strategy = False @classmethod def for_existing_path(cls, path): @@ -127,6 +128,7 @@ class Load(Generative, MapperOption): return cloned is_opts_only = False + is_class_strategy = False strategy = None propagate_to_loaders = False @@ -148,6 +150,7 @@ class Load(Generative, MapperOption): def _generate_path(self, path, attr, wildcard_key, raiseerr=True): self._of_type = None + if raiseerr and not path.has_entity: if isinstance(path, TokenRegistry): raise sa_exc.ArgumentError( @@ -187,6 +190,14 @@ class Load(Generative, MapperOption): attr = attr.property path = path[attr] + elif _is_mapped_class(attr): + if not attr.common_parent(path.mapper): + if raiseerr: + raise sa_exc.ArgumentError( + "Attribute '%s' does not " + "link from element '%s'" % (attr, path.entity)) + else: + return None else: prop = attr.property @@ -246,6 +257,7 @@ class Load(Generative, MapperOption): self, attr, strategy, propagate_to_loaders=True): strategy = self._coerce_strat(strategy) + self.is_class_strategy = False self.propagate_to_loaders = propagate_to_loaders # if the path is a wildcard, this will set propagate_to_loaders=False self._generate_path(self.path, attr, "relationship") @@ -257,6 +269,7 @@ class Load(Generative, MapperOption): def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False): strategy = self._coerce_strat(strategy) + self.is_class_strategy = False for attr in attrs: cloned = self._generate() cloned.strategy = strategy @@ -267,6 +280,31 @@ class Load(Generative, MapperOption): if opts_only: cloned.is_opts_only = True cloned._set_path_strategy() + self.is_class_strategy = False + + @_generative + def set_generic_strategy(self, attrs, strategy): + strategy = self._coerce_strat(strategy) + + for attr in attrs: + path = self._generate_path(self.path, attr, None) + cloned = self._generate() + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + cloned._set_path_strategy() + + @_generative + def set_class_strategy(self, strategy, opts): + strategy = self._coerce_strat(strategy) + cloned = self._generate() + cloned.is_class_strategy = True + path = cloned._generate_path(self.path, None, None) + cloned.strategy = strategy + cloned.path = path + cloned.propagate_to_loaders = True + cloned._set_path_strategy() + cloned.local_opts.update(opts) def _set_for_path(self, context, path, replace=True, merge_opts=False): if merge_opts or not replace: @@ -284,7 +322,7 @@ class Load(Generative, MapperOption): self.local_opts.update(existing.local_opts) def _set_path_strategy(self): - if self.path.has_entity: + if not self.is_class_strategy and self.path.has_entity: effective_path = self.path.parent else: effective_path = self.path @@ -367,7 +405,10 @@ class _UnboundLoad(Load): if attr == _DEFAULT_TOKEN: self.propagate_to_loaders = False attr = "%s:%s" % (wildcard_key, attr) - path = path + (attr, ) + if path and _is_mapped_class(path[-1]) and not self.is_class_strategy: + path = path[0:-1] + if attr: + path = path + (attr, ) self.path = path return path @@ -502,7 +543,12 @@ class _UnboundLoad(Load): (User, User.orders.property, Order, Order.items.property)) """ + start_path = self.path + + if self.is_class_strategy and current_path: + start_path += (entities[0], ) + # _current_path implies we're in a # secondary load with an existing path @@ -517,7 +563,8 @@ class _UnboundLoad(Load): token = start_path[0] if isinstance(token, util.string_types): - entity = self._find_entity_basestring(entities, token, raiseerr) + entity = self._find_entity_basestring( + entities, token, raiseerr) elif isinstance(token, PropComparator): prop = token.property entity = self._find_entity_prop_comparator( @@ -525,7 +572,10 @@ class _UnboundLoad(Load): prop.key, token._parententity, raiseerr) - + elif self.is_class_strategy and _is_mapped_class(token): + entity = inspect(token) + if entity not in entities: + entity = None else: raise sa_exc.ArgumentError( "mapper option expects " @@ -541,7 +591,6 @@ class _UnboundLoad(Load): # we just located, then go through the rest of our path # tokens and populate into the Load(). loader = Load(path_element) - if context is not None: loader.context = context else: @@ -549,16 +598,19 @@ class _UnboundLoad(Load): loader.strategy = self.strategy loader.is_opts_only = self.is_opts_only + loader.is_class_strategy = self.is_class_strategy path = loader.path - for token in start_path: - if not loader._generate_path( - loader.path, token, None, raiseerr): - return + + if not loader.is_class_strategy: + for token in start_path: + if not loader._generate_path( + loader.path, token, None, raiseerr): + return loader.local_opts.update(self.local_opts) - if loader.path.has_entity: + if not loader.is_class_strategy and loader.path.has_entity: effective_path = loader.path.parent else: effective_path = loader.path @@ -1289,3 +1341,37 @@ def undefer_group(loadopt, name): @undefer_group._add_unbound_fn def undefer_group(name): return _UnboundLoad().undefer_group(name) + + +@loader_option() +def selectin_polymorphic(loadopt, classes): + """Indicate an eager load should take place for all attributes + specific to a subclass. + + This uses an additional SELECT with IN against all matched primary + key values, and is the per-query analogue to the ``"selectin"`` + setting on the :paramref:`.mapper.polymorphic_load` parameter. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`inheritance_polymorphic_load` + + """ + loadopt.set_class_strategy( + {"selectinload_polymorphic": True}, + opts={"mappers": tuple(sorted((inspect(cls) for cls in classes), key=id))} + ) + return loadopt + + +@selectin_polymorphic._add_unbound_fn +def selectin_polymorphic(base_cls, classes): + ul = _UnboundLoad() + ul.is_class_strategy = True + ul.path = (inspect(base_cls), ) + ul.selectin_polymorphic( + classes + ) + return ul diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 9a397ccf3..4267b79fb 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1043,7 +1043,13 @@ def was_deleted(object): state = attributes.instance_state(object) return state.was_deleted + def _entity_corresponds_to(given, entity): + """determine if 'given' corresponds to 'entity', in terms + of an entity passed to Query that would match the same entity + being referred to elsewhere in the query. + + """ if entity.is_aliased_class: if given.is_aliased_class: if entity._base_alias is given._base_alias: @@ -1057,6 +1063,21 @@ def _entity_corresponds_to(given, entity): return entity.common_parent(given) + +def _entity_isa(given, mapper): + """determine if 'given' "is a" mapper, in terms of the given + would load rows of type 'mapper'. + + """ + if given.is_aliased_class: + return mapper in given.with_polymorphic_mappers or \ + given.mapper.isa(mapper) + elif given.with_polymorphic_mappers: + return mapper in given.with_polymorphic_mappers + else: + return given.isa(mapper) + + def randomize_unitofwork(): """Use random-ordering sets within the unit of work in order to detect unit of work sorting issues. diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index dfea33dc7..c0854ea55 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -497,8 +497,9 @@ class AssertsExecutionResults(object): def assert_sql_execution(self, db, callable_, *rules): with self.sql_execution_asserter(db) as asserter: - callable_() + result = callable_() asserter.assert_(*rules) + return result def assert_sql(self, db, callable_, rules): @@ -512,7 +513,7 @@ class AssertsExecutionResults(object): newrule = assertsql.CompiledSQL(*rule) newrules.append(newrule) - self.assert_sql_execution(db, callable_, *newrules) + return self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): self.assert_sql_execution( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e39b6315d..86d850733 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -282,6 +282,32 @@ class AllOf(AssertRule): self.errormessage = list(self.rules)[0].errormessage +class EachOf(AssertRule): + + def __init__(self, *rules): + self.rules = list(rules) + + def process_statement(self, execute_observed): + while self.rules: + rule = self.rules[0] + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.pop(0) + elif rule.errormessage: + self.errormessage = rule.errormessage + if rule.consume_statement: + break + + if not self.rules: + self.is_consumed = True + + def no_more_statements(self): + if self.rules and not self.rules[0].is_consumed: + self.rules[0].no_more_statements() + elif self.rules: + super(EachOf, self).no_more_statements() + + class Or(AllOf): def process_statement(self, execute_observed): @@ -319,24 +345,20 @@ class SQLAsserter(object): del self.accumulated def assert_(self, *rules): - rules = list(rules) - observed = list(self._final) + rule = EachOf(*rules) - while observed and rules: - rule = rules[0] - rule.process_statement(observed[0]) + observed = list(self._final) + while observed: + statement = observed.pop(0) + rule.process_statement(statement) if rule.is_consumed: - rules.pop(0) + break elif rule.errormessage: assert False, rule.errormessage - - if rule.consume_statement: - observed.pop(0) - - if not observed and rules: - rules[0].no_more_statements() - elif not rules and observed: + if observed: assert False, "Additional SQL statements remain" + elif not rule.is_consumed: + rule.no_more_statements() @contextlib.contextmanager |