diff options
Diffstat (limited to 'lib/sqlalchemy/orm/loading.py')
-rw-r--r-- | lib/sqlalchemy/orm/loading.py | 381 |
1 files changed, 241 insertions, 140 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 0a6f8023a..96eddcb32 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -37,32 +37,35 @@ def instances(query, cursor, context): filtered = query._has_mapper_entities - single_entity = not query._only_return_tuples and \ - len(query._entities) == 1 and \ - query._entities[0].supports_single_entity + single_entity = ( + not query._only_return_tuples + and len(query._entities) == 1 + and query._entities[0].supports_single_entity + ) if filtered: if single_entity: filter_fn = id else: + def filter_fn(row): return tuple( - id(item) - if ent.use_id_for_hash - else item + id(item) if ent.use_id_for_hash else item for ent, item in zip(query._entities, row) ) try: - (process, labels) = \ - list(zip(*[ - query_entity.row_processor(query, - context, cursor) - for query_entity in query._entities - ])) + (process, labels) = list( + zip( + *[ + query_entity.row_processor(query, context, cursor) + for query_entity in query._entities + ] + ) + ) if not single_entity: - keyed_tuple = util.lightweight_named_tuple('result', labels) + keyed_tuple = util.lightweight_named_tuple("result", labels) while True: context.partials = {} @@ -78,11 +81,12 @@ def instances(query, cursor, context): proc = process[0] rows = [proc(row) for row in fetch] else: - rows = [keyed_tuple([proc(row) for proc in process]) - for row in fetch] + rows = [ + keyed_tuple([proc(row) for proc in process]) + for row in fetch + ] - for path, post_load in \ - context.post_load_paths.items(): + for path, post_load in context.post_load_paths.items(): post_load.invoke(context, path) if filtered: @@ -113,19 +117,27 @@ def merge_result(querylib, query, iterator, load=True): single_entity = len(query._entities) == 1 if single_entity: if isinstance(query._entities[0], querylib._MapperEntity): - result = [session._merge( - attributes.instance_state(instance), - attributes.instance_dict(instance), - load=load, _recursive={}, _resolve_conflict_map={}) - for instance in iterator] + result = [ + session._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + for instance in iterator + ] else: result = list(iterator) else: - mapped_entities = [i for i, e in enumerate(query._entities) - if isinstance(e, querylib._MapperEntity)] + mapped_entities = [ + i + for i, e in enumerate(query._entities) + if isinstance(e, querylib._MapperEntity) + ] result = [] keys = [ent._label_name for ent in query._entities] - keyed_tuple = util.lightweight_named_tuple('result', keys) + keyed_tuple = util.lightweight_named_tuple("result", keys) for row in iterator: newrow = list(row) for i in mapped_entities: @@ -133,7 +145,10 @@ def merge_result(querylib, query, iterator, load=True): newrow[i] = session._merge( attributes.instance_state(newrow[i]), attributes.instance_dict(newrow[i]), - load=load, _recursive={}, _resolve_conflict_map={}) + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) result.append(keyed_tuple(newrow)) return iter(result) @@ -170,9 +185,9 @@ def get_from_identity(session, key, passive): return None -def load_on_ident(query, key, - refresh_state=None, with_for_update=None, - only_load_props=None): +def load_on_ident( + query, key, refresh_state=None, with_for_update=None, only_load_props=None +): """Load the given identity key from the database.""" if key is not None: @@ -182,16 +197,23 @@ def load_on_ident(query, key, ident = identity_token = None return load_on_pk_identity( - query, ident, refresh_state=refresh_state, + query, + ident, + refresh_state=refresh_state, with_for_update=with_for_update, only_load_props=only_load_props, - identity_token=identity_token + identity_token=identity_token, ) -def load_on_pk_identity(query, primary_key_identity, - refresh_state=None, with_for_update=None, - only_load_props=None, identity_token=None): +def load_on_pk_identity( + query, + primary_key_identity, + refresh_state=None, + with_for_update=None, + only_load_props=None, + identity_token=None, +): """Load the given primary key identity from the database.""" @@ -209,22 +231,28 @@ def load_on_pk_identity(query, primary_key_identity, # None present in ident - turn those comparisons # into "IS NULL" if None in primary_key_identity: - nones = set([ - _get_params[col].key for col, value in - zip(mapper.primary_key, primary_key_identity) - if value is None - ]) - _get_clause = sql_util.adapt_criterion_to_null( - _get_clause, nones) + nones = set( + [ + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + ] + ) + _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) _get_clause = q._adapt_clause(_get_clause, True, False) q._criterion = _get_clause - params = dict([ - (_get_params[primary_key].key, id_val) - for id_val, primary_key - in zip(primary_key_identity, mapper.primary_key) - ]) + params = dict( + [ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + ] + ) q._params = params @@ -243,7 +271,8 @@ def load_on_pk_identity(query, primary_key_identity, version_check=version_check, only_load_props=only_load_props, refresh_state=refresh_state, - identity_token=identity_token) + identity_token=identity_token, + ) q._order_by = None try: @@ -253,27 +282,31 @@ def load_on_pk_identity(query, primary_key_identity, def _setup_entity_query( - context, mapper, query_entity, - path, adapter, column_collection, - with_polymorphic=None, only_load_props=None, - polymorphic_discriminator=None, **kw): + context, + mapper, + query_entity, + path, + adapter, + column_collection, + with_polymorphic=None, + only_load_props=None, + polymorphic_discriminator=None, + **kw +): if with_polymorphic: poly_properties = mapper._iterate_polymorphic_properties( - with_polymorphic) + with_polymorphic + ) else: poly_properties = mapper._polymorphic_properties quick_populators = {} - path.set( - context.attributes, - "memoized_setups", - quick_populators) + path.set(context.attributes, "memoized_setups", quick_populators) for value in poly_properties: - if only_load_props and \ - value.key not in only_load_props: + if only_load_props and value.key not in only_load_props: continue value.setup( context, @@ -286,9 +319,10 @@ def _setup_entity_query( **kw ) - if polymorphic_discriminator is not None and \ - polymorphic_discriminator \ - is not mapper.polymorphic_on: + if ( + polymorphic_discriminator is not None + and polymorphic_discriminator is not mapper.polymorphic_on + ): if adapter: pd = adapter.columns[polymorphic_discriminator] @@ -298,10 +332,16 @@ def _setup_entity_query( def _instance_processor( - mapper, context, result, path, adapter, - only_load_props=None, refresh_state=None, - polymorphic_discriminator=None, - _polymorphic_from=None): + mapper, + context, + result, + path, + adapter, + only_load_props=None, + refresh_state=None, + polymorphic_discriminator=None, + _polymorphic_from=None, +): """Produce a mapper level row processor callable which processes rows into mapped instances.""" @@ -322,11 +362,11 @@ def _instance_processor( props = mapper._prop_set if only_load_props is not None: - props = props.intersection( - mapper._props[k] for k in only_load_props) + props = props.intersection(mapper._props[k] for k in only_load_props) quick_populators = path.get( - context.attributes, "memoized_setups", _none_set) + context.attributes, "memoized_setups", _none_set + ) for prop in props: if prop in quick_populators: @@ -334,7 +374,8 @@ def _instance_processor( col = quick_populators[prop] if col is _DEFER_FOR_STATE: populators["new"].append( - (prop.key, prop._deferred_column_loader)) + (prop.key, prop._deferred_column_loader) + ) elif col is _SET_DEFERRED_EXPIRED: # note that in this path, we are no longer # searching in the result to see if the column might @@ -366,14 +407,19 @@ def _instance_processor( # will iterate through all of its columns # to see if one fits prop.create_row_processor( - context, path, mapper, result, adapter, populators) + context, path, mapper, result, adapter, populators + ) else: prop.create_row_processor( - context, path, mapper, result, adapter, populators) + context, path, mapper, result, adapter, populators + ) propagate_options = context.propagate_options - load_path = context.query._current_path + path \ - if context.query._current_path.path else path + load_path = ( + context.query._current_path + path + if context.query._current_path.path + else path + ) session_identity_map = context.session.identity_map @@ -391,18 +437,18 @@ def _instance_processor( identity_token = context.identity_token if not refresh_state and _polymorphic_from is not None: - key = ('loader', path.path) - if ( - key in context.attributes and - context.attributes[key].strategy == - (('selectinload_polymorphic', True), ) + key = ("loader", path.path) + if key in context.attributes and context.attributes[key].strategy == ( + ("selectinload_polymorphic", True), ): selectin_load_via = mapper._should_selectin_load( - context.attributes[key].local_opts['entities'], - _polymorphic_from) + context.attributes[key].local_opts["entities"], + _polymorphic_from, + ) else: selectin_load_via = mapper._should_selectin_load( - None, _polymorphic_from) + None, _polymorphic_from + ) if selectin_load_via and selectin_load_via is not _polymorphic_from: # only_load_props goes w/ refresh_state only, and in a refresh @@ -413,9 +459,13 @@ def _instance_processor( callable_ = _load_subclass_via_in(context, path, selectin_load_via) PostLoad.callable_for_path( - context, load_path, selectin_load_via.mapper, + context, + load_path, + selectin_load_via.mapper, + selectin_load_via, + callable_, selectin_load_via, - callable_, selectin_load_via) + ) post_load = PostLoad.for_context(context, load_path, only_load_props) @@ -425,8 +475,9 @@ def _instance_processor( # super-rare condition; a refresh is being called # on a non-instance-key instance; this is meant to only # occur within a flush() - refresh_identity_key = \ - mapper._identity_key_from_state(refresh_state) + refresh_identity_key = mapper._identity_key_from_state( + refresh_state + ) else: refresh_identity_key = None @@ -452,7 +503,7 @@ def _instance_processor( identitykey = ( identity_class, tuple([row[column] for column in pk_cols]), - identity_token + identity_token, ) instance = session_identity_map.get(identitykey) @@ -507,8 +558,16 @@ def _instance_processor( state.load_path = load_path _populate_full( - context, row, state, dict_, isnew, load_path, - loaded_instance, populate_existing, populators) + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + populate_existing, + populators, + ) if isnew: if loaded_instance: @@ -518,7 +577,8 @@ def _instance_processor( loaded_as_persistent(context.session, state.obj()) elif refresh_evt: state.manager.dispatch.refresh( - state, context, only_load_props) + state, context, only_load_props + ) if populate_existing or state.modified: if refresh_state and only_load_props: @@ -542,13 +602,19 @@ def _instance_processor( # and add to the "context.partials" collection. to_load = _populate_partial( - context, row, state, dict_, isnew, load_path, - unloaded, populators) + context, + row, + state, + dict_, + isnew, + load_path, + unloaded, + populators, + ) if isnew: if refresh_evt: - state.manager.dispatch.refresh( - state, context, to_load) + state.manager.dispatch.refresh(state, context, to_load) state._commit(dict_, to_load) @@ -561,8 +627,14 @@ def _instance_processor( # if we are doing polymorphic, dispatch to a different _instance() # method specific to the subclass mapper _instance = _decorate_polymorphic_switch( - _instance, context, mapper, result, path, - polymorphic_discriminator, adapter) + _instance, + context, + mapper, + result, + path, + polymorphic_discriminator, + adapter, + ) return _instance @@ -581,14 +653,13 @@ def _load_subclass_via_in(context, path, entity): orig_query = context.query q2 = q._with_lazyload_options( - (enable_opt, ) + orig_query._with_options + (disable_opt, ), - path.parent, cache_path=path + (enable_opt,) + orig_query._with_options + (disable_opt,), + path.parent, + cache_path=path, ) if orig_query._populate_existing: - q2.add_criteria( - lambda q: q.populate_existing() - ) + q2.add_criteria(lambda q: q.populate_existing()) q2(context.session).params( primary_keys=[ @@ -601,8 +672,16 @@ def _load_subclass_via_in(context, path, entity): def _populate_full( - context, row, state, dict_, isnew, load_path, - loaded_instance, populate_existing, populators): + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + populate_existing, + populators, +): if isnew: # first time we are seeing a row with this identity. state.runid = context.runid @@ -650,8 +729,8 @@ def _populate_full( def _populate_partial( - context, row, state, dict_, isnew, load_path, - unloaded, populators): + context, row, state, dict_, isnew, load_path, unloaded, populators +): if not isnew: to_load = context.partials[state] @@ -693,19 +772,32 @@ def _validate_version_id(mapper, state, dict_, row, adapter): if adapter: version_id_col = adapter.columns[version_id_col] - if mapper._get_state_attr_by_column( - state, dict_, mapper.version_id_col) != row[version_id_col]: + if ( + mapper._get_state_attr_by_column(state, dict_, mapper.version_id_col) + != row[version_id_col] + ): raise orm_exc.StaleDataError( "Instance '%s' has version id '%s' which " "does not match database-loaded version id '%s'." - % (state_str(state), mapper._get_state_attr_by_column( - state, dict_, mapper.version_id_col), - row[version_id_col])) + % ( + state_str(state), + mapper._get_state_attr_by_column( + state, dict_, mapper.version_id_col + ), + row[version_id_col], + ) + ) def _decorate_polymorphic_switch( - instance_fn, context, mapper, result, path, - polymorphic_discriminator, adapter): + instance_fn, + context, + mapper, + result, + path, + polymorphic_discriminator, + adapter, +): if polymorphic_discriminator is not None: polymorphic_on = polymorphic_discriminator else: @@ -721,19 +813,22 @@ def _decorate_polymorphic_switch( sub_mapper = mapper.polymorphic_map[discriminator] except KeyError: raise AssertionError( - "No such polymorphic_identity %r is defined" % - discriminator) + "No such polymorphic_identity %r is defined" % discriminator + ) else: if sub_mapper is mapper: return None return _instance_processor( - sub_mapper, context, result, - path, adapter, _polymorphic_from=mapper) + sub_mapper, + context, + result, + path, + adapter, + _polymorphic_from=mapper, + ) - polymorphic_instances = util.PopulateDict( - configure_subclass_mapper - ) + polymorphic_instances = util.PopulateDict(configure_subclass_mapper) def polymorphic_instance(row): discriminator = row[polymorphic_on] @@ -742,6 +837,7 @@ def _decorate_polymorphic_switch( if _instance: return _instance(row) return instance_fn(row) + return polymorphic_instance @@ -749,7 +845,8 @@ class PostLoad(object): """Track loaders and states for "post load" operations. """ - __slots__ = 'loaders', 'states', 'load_keys' + + __slots__ = "loaders", "states", "load_keys" def __init__(self): self.loaders = {} @@ -770,8 +867,7 @@ class PostLoad(object): for token, limit_to_mapper, loader, arg, kw in self.loaders.values(): states = [ (state, overwrite) - for state, overwrite - in self.states.items() + for state, overwrite in self.states.items() if state.manager.mapper.isa(limit_to_mapper) ] if states: @@ -787,13 +883,15 @@ class PostLoad(object): @classmethod def path_exists(self, context, path, key): - return path.path in context.post_load_paths and \ - key in context.post_load_paths[path.path].loaders + return ( + path.path in context.post_load_paths + and key in context.post_load_paths[path.path].loaders + ) @classmethod def callable_for_path( - cls, context, path, limit_to_mapper, token, - loader_callable, *arg, **kw): + cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw + ): if path.path in context.post_load_paths: pl = context.post_load_paths[path.path] else: @@ -809,8 +907,8 @@ def load_scalar_attributes(mapper, state, attribute_names): if not session: raise orm_exc.DetachedInstanceError( "Instance %s is not bound to a Session; " - "attribute refresh operation cannot proceed" % - (state_str(state))) + "attribute refresh operation cannot proceed" % (state_str(state)) + ) has_key = bool(state.key) @@ -833,13 +931,12 @@ def load_scalar_attributes(mapper, state, attribute_names): statement = mapper._optimized_get_statement(state, attribute_names) if statement is not None: result = load_on_ident( - session.query(mapper). - options( - strategy_options.Load(mapper).undefer("*") - ).from_statement(statement), + session.query(mapper) + .options(strategy_options.Load(mapper).undefer("*")) + .from_statement(statement), None, only_load_props=attribute_names, - refresh_state=state + refresh_state=state, ) if result is False: @@ -850,30 +947,34 @@ def load_scalar_attributes(mapper, state, attribute_names): # object is becoming persistent but hasn't yet been assigned # an identity_key. # check here to ensure we have the attrs we need. - pk_attrs = [mapper._columntoproperty[col].key - for col in mapper.primary_key] + pk_attrs = [ + mapper._columntoproperty[col].key for col in mapper.primary_key + ] if state.expired_attributes.intersection(pk_attrs): raise sa_exc.InvalidRequestError( "Instance %s cannot be refreshed - it's not " " persistent and does not " - "contain a full primary key." % state_str(state)) + "contain a full primary key." % state_str(state) + ) identity_key = mapper._identity_key_from_state(state) - if (_none_set.issubset(identity_key) and - not mapper.allow_partial_pks) or \ - _none_set.issuperset(identity_key): + if ( + _none_set.issubset(identity_key) and not mapper.allow_partial_pks + ) or _none_set.issuperset(identity_key): util.warn_limited( "Instance %s to be refreshed doesn't " "contain a full primary key - can't be refreshed " "(and shouldn't be expired, either).", - state_str(state)) + state_str(state), + ) return result = load_on_ident( session.query(mapper), identity_key, refresh_state=state, - only_load_props=attribute_names) + only_load_props=attribute_names, + ) # if instance is pending, a refresh operation # may not complete (even if PK attributes are assigned) |