summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/loading.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/loading.py')
-rw-r--r--lib/sqlalchemy/orm/loading.py381
1 files changed, 241 insertions, 140 deletions
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index 0a6f8023a..96eddcb32 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -37,32 +37,35 @@ def instances(query, cursor, context):
filtered = query._has_mapper_entities
- single_entity = not query._only_return_tuples and \
- len(query._entities) == 1 and \
- query._entities[0].supports_single_entity
+ single_entity = (
+ not query._only_return_tuples
+ and len(query._entities) == 1
+ and query._entities[0].supports_single_entity
+ )
if filtered:
if single_entity:
filter_fn = id
else:
+
def filter_fn(row):
return tuple(
- id(item)
- if ent.use_id_for_hash
- else item
+ id(item) if ent.use_id_for_hash else item
for ent, item in zip(query._entities, row)
)
try:
- (process, labels) = \
- list(zip(*[
- query_entity.row_processor(query,
- context, cursor)
- for query_entity in query._entities
- ]))
+ (process, labels) = list(
+ zip(
+ *[
+ query_entity.row_processor(query, context, cursor)
+ for query_entity in query._entities
+ ]
+ )
+ )
if not single_entity:
- keyed_tuple = util.lightweight_named_tuple('result', labels)
+ keyed_tuple = util.lightweight_named_tuple("result", labels)
while True:
context.partials = {}
@@ -78,11 +81,12 @@ def instances(query, cursor, context):
proc = process[0]
rows = [proc(row) for row in fetch]
else:
- rows = [keyed_tuple([proc(row) for proc in process])
- for row in fetch]
+ rows = [
+ keyed_tuple([proc(row) for proc in process])
+ for row in fetch
+ ]
- for path, post_load in \
- context.post_load_paths.items():
+ for path, post_load in context.post_load_paths.items():
post_load.invoke(context, path)
if filtered:
@@ -113,19 +117,27 @@ def merge_result(querylib, query, iterator, load=True):
single_entity = len(query._entities) == 1
if single_entity:
if isinstance(query._entities[0], querylib._MapperEntity):
- result = [session._merge(
- attributes.instance_state(instance),
- attributes.instance_dict(instance),
- load=load, _recursive={}, _resolve_conflict_map={})
- for instance in iterator]
+ result = [
+ session._merge(
+ attributes.instance_state(instance),
+ attributes.instance_dict(instance),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+ for instance in iterator
+ ]
else:
result = list(iterator)
else:
- mapped_entities = [i for i, e in enumerate(query._entities)
- if isinstance(e, querylib._MapperEntity)]
+ mapped_entities = [
+ i
+ for i, e in enumerate(query._entities)
+ if isinstance(e, querylib._MapperEntity)
+ ]
result = []
keys = [ent._label_name for ent in query._entities]
- keyed_tuple = util.lightweight_named_tuple('result', keys)
+ keyed_tuple = util.lightweight_named_tuple("result", keys)
for row in iterator:
newrow = list(row)
for i in mapped_entities:
@@ -133,7 +145,10 @@ def merge_result(querylib, query, iterator, load=True):
newrow[i] = session._merge(
attributes.instance_state(newrow[i]),
attributes.instance_dict(newrow[i]),
- load=load, _recursive={}, _resolve_conflict_map={})
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
result.append(keyed_tuple(newrow))
return iter(result)
@@ -170,9 +185,9 @@ def get_from_identity(session, key, passive):
return None
-def load_on_ident(query, key,
- refresh_state=None, with_for_update=None,
- only_load_props=None):
+def load_on_ident(
+ query, key, refresh_state=None, with_for_update=None, only_load_props=None
+):
"""Load the given identity key from the database."""
if key is not None:
@@ -182,16 +197,23 @@ def load_on_ident(query, key,
ident = identity_token = None
return load_on_pk_identity(
- query, ident, refresh_state=refresh_state,
+ query,
+ ident,
+ refresh_state=refresh_state,
with_for_update=with_for_update,
only_load_props=only_load_props,
- identity_token=identity_token
+ identity_token=identity_token,
)
-def load_on_pk_identity(query, primary_key_identity,
- refresh_state=None, with_for_update=None,
- only_load_props=None, identity_token=None):
+def load_on_pk_identity(
+ query,
+ primary_key_identity,
+ refresh_state=None,
+ with_for_update=None,
+ only_load_props=None,
+ identity_token=None,
+):
"""Load the given primary key identity from the database."""
@@ -209,22 +231,28 @@ def load_on_pk_identity(query, primary_key_identity,
# None present in ident - turn those comparisons
# into "IS NULL"
if None in primary_key_identity:
- nones = set([
- _get_params[col].key for col, value in
- zip(mapper.primary_key, primary_key_identity)
- if value is None
- ])
- _get_clause = sql_util.adapt_criterion_to_null(
- _get_clause, nones)
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+ _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones)
_get_clause = q._adapt_clause(_get_clause, True, False)
q._criterion = _get_clause
- params = dict([
- (_get_params[primary_key].key, id_val)
- for id_val, primary_key
- in zip(primary_key_identity, mapper.primary_key)
- ])
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
q._params = params
@@ -243,7 +271,8 @@ def load_on_pk_identity(query, primary_key_identity,
version_check=version_check,
only_load_props=only_load_props,
refresh_state=refresh_state,
- identity_token=identity_token)
+ identity_token=identity_token,
+ )
q._order_by = None
try:
@@ -253,27 +282,31 @@ def load_on_pk_identity(query, primary_key_identity,
def _setup_entity_query(
- context, mapper, query_entity,
- path, adapter, column_collection,
- with_polymorphic=None, only_load_props=None,
- polymorphic_discriminator=None, **kw):
+ context,
+ mapper,
+ query_entity,
+ path,
+ adapter,
+ column_collection,
+ with_polymorphic=None,
+ only_load_props=None,
+ polymorphic_discriminator=None,
+ **kw
+):
if with_polymorphic:
poly_properties = mapper._iterate_polymorphic_properties(
- with_polymorphic)
+ with_polymorphic
+ )
else:
poly_properties = mapper._polymorphic_properties
quick_populators = {}
- path.set(
- context.attributes,
- "memoized_setups",
- quick_populators)
+ path.set(context.attributes, "memoized_setups", quick_populators)
for value in poly_properties:
- if only_load_props and \
- value.key not in only_load_props:
+ if only_load_props and value.key not in only_load_props:
continue
value.setup(
context,
@@ -286,9 +319,10 @@ def _setup_entity_query(
**kw
)
- if polymorphic_discriminator is not None and \
- polymorphic_discriminator \
- is not mapper.polymorphic_on:
+ if (
+ polymorphic_discriminator is not None
+ and polymorphic_discriminator is not mapper.polymorphic_on
+ ):
if adapter:
pd = adapter.columns[polymorphic_discriminator]
@@ -298,10 +332,16 @@ def _setup_entity_query(
def _instance_processor(
- mapper, context, result, path, adapter,
- only_load_props=None, refresh_state=None,
- polymorphic_discriminator=None,
- _polymorphic_from=None):
+ mapper,
+ context,
+ result,
+ path,
+ adapter,
+ only_load_props=None,
+ refresh_state=None,
+ polymorphic_discriminator=None,
+ _polymorphic_from=None,
+):
"""Produce a mapper level row processor callable
which processes rows into mapped instances."""
@@ -322,11 +362,11 @@ def _instance_processor(
props = mapper._prop_set
if only_load_props is not None:
- props = props.intersection(
- mapper._props[k] for k in only_load_props)
+ props = props.intersection(mapper._props[k] for k in only_load_props)
quick_populators = path.get(
- context.attributes, "memoized_setups", _none_set)
+ context.attributes, "memoized_setups", _none_set
+ )
for prop in props:
if prop in quick_populators:
@@ -334,7 +374,8 @@ def _instance_processor(
col = quick_populators[prop]
if col is _DEFER_FOR_STATE:
populators["new"].append(
- (prop.key, prop._deferred_column_loader))
+ (prop.key, prop._deferred_column_loader)
+ )
elif col is _SET_DEFERRED_EXPIRED:
# note that in this path, we are no longer
# searching in the result to see if the column might
@@ -366,14 +407,19 @@ def _instance_processor(
# will iterate through all of its columns
# to see if one fits
prop.create_row_processor(
- context, path, mapper, result, adapter, populators)
+ context, path, mapper, result, adapter, populators
+ )
else:
prop.create_row_processor(
- context, path, mapper, result, adapter, populators)
+ context, path, mapper, result, adapter, populators
+ )
propagate_options = context.propagate_options
- load_path = context.query._current_path + path \
- if context.query._current_path.path else path
+ load_path = (
+ context.query._current_path + path
+ if context.query._current_path.path
+ else path
+ )
session_identity_map = context.session.identity_map
@@ -391,18 +437,18 @@ def _instance_processor(
identity_token = context.identity_token
if not refresh_state and _polymorphic_from is not None:
- key = ('loader', path.path)
- if (
- key in context.attributes and
- context.attributes[key].strategy ==
- (('selectinload_polymorphic', True), )
+ key = ("loader", path.path)
+ if key in context.attributes and context.attributes[key].strategy == (
+ ("selectinload_polymorphic", True),
):
selectin_load_via = mapper._should_selectin_load(
- context.attributes[key].local_opts['entities'],
- _polymorphic_from)
+ context.attributes[key].local_opts["entities"],
+ _polymorphic_from,
+ )
else:
selectin_load_via = mapper._should_selectin_load(
- None, _polymorphic_from)
+ None, _polymorphic_from
+ )
if selectin_load_via and selectin_load_via is not _polymorphic_from:
# only_load_props goes w/ refresh_state only, and in a refresh
@@ -413,9 +459,13 @@ def _instance_processor(
callable_ = _load_subclass_via_in(context, path, selectin_load_via)
PostLoad.callable_for_path(
- context, load_path, selectin_load_via.mapper,
+ context,
+ load_path,
+ selectin_load_via.mapper,
+ selectin_load_via,
+ callable_,
selectin_load_via,
- callable_, selectin_load_via)
+ )
post_load = PostLoad.for_context(context, load_path, only_load_props)
@@ -425,8 +475,9 @@ def _instance_processor(
# super-rare condition; a refresh is being called
# on a non-instance-key instance; this is meant to only
# occur within a flush()
- refresh_identity_key = \
- mapper._identity_key_from_state(refresh_state)
+ refresh_identity_key = mapper._identity_key_from_state(
+ refresh_state
+ )
else:
refresh_identity_key = None
@@ -452,7 +503,7 @@ def _instance_processor(
identitykey = (
identity_class,
tuple([row[column] for column in pk_cols]),
- identity_token
+ identity_token,
)
instance = session_identity_map.get(identitykey)
@@ -507,8 +558,16 @@ def _instance_processor(
state.load_path = load_path
_populate_full(
- context, row, state, dict_, isnew, load_path,
- loaded_instance, populate_existing, populators)
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ populate_existing,
+ populators,
+ )
if isnew:
if loaded_instance:
@@ -518,7 +577,8 @@ def _instance_processor(
loaded_as_persistent(context.session, state.obj())
elif refresh_evt:
state.manager.dispatch.refresh(
- state, context, only_load_props)
+ state, context, only_load_props
+ )
if populate_existing or state.modified:
if refresh_state and only_load_props:
@@ -542,13 +602,19 @@ def _instance_processor(
# and add to the "context.partials" collection.
to_load = _populate_partial(
- context, row, state, dict_, isnew, load_path,
- unloaded, populators)
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ unloaded,
+ populators,
+ )
if isnew:
if refresh_evt:
- state.manager.dispatch.refresh(
- state, context, to_load)
+ state.manager.dispatch.refresh(state, context, to_load)
state._commit(dict_, to_load)
@@ -561,8 +627,14 @@ def _instance_processor(
# if we are doing polymorphic, dispatch to a different _instance()
# method specific to the subclass mapper
_instance = _decorate_polymorphic_switch(
- _instance, context, mapper, result, path,
- polymorphic_discriminator, adapter)
+ _instance,
+ context,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+ )
return _instance
@@ -581,14 +653,13 @@ def _load_subclass_via_in(context, path, entity):
orig_query = context.query
q2 = q._with_lazyload_options(
- (enable_opt, ) + orig_query._with_options + (disable_opt, ),
- path.parent, cache_path=path
+ (enable_opt,) + orig_query._with_options + (disable_opt,),
+ path.parent,
+ cache_path=path,
)
if orig_query._populate_existing:
- q2.add_criteria(
- lambda q: q.populate_existing()
- )
+ q2.add_criteria(lambda q: q.populate_existing())
q2(context.session).params(
primary_keys=[
@@ -601,8 +672,16 @@ def _load_subclass_via_in(context, path, entity):
def _populate_full(
- context, row, state, dict_, isnew, load_path,
- loaded_instance, populate_existing, populators):
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ populate_existing,
+ populators,
+):
if isnew:
# first time we are seeing a row with this identity.
state.runid = context.runid
@@ -650,8 +729,8 @@ def _populate_full(
def _populate_partial(
- context, row, state, dict_, isnew, load_path,
- unloaded, populators):
+ context, row, state, dict_, isnew, load_path, unloaded, populators
+):
if not isnew:
to_load = context.partials[state]
@@ -693,19 +772,32 @@ def _validate_version_id(mapper, state, dict_, row, adapter):
if adapter:
version_id_col = adapter.columns[version_id_col]
- if mapper._get_state_attr_by_column(
- state, dict_, mapper.version_id_col) != row[version_id_col]:
+ if (
+ mapper._get_state_attr_by_column(state, dict_, mapper.version_id_col)
+ != row[version_id_col]
+ ):
raise orm_exc.StaleDataError(
"Instance '%s' has version id '%s' which "
"does not match database-loaded version id '%s'."
- % (state_str(state), mapper._get_state_attr_by_column(
- state, dict_, mapper.version_id_col),
- row[version_id_col]))
+ % (
+ state_str(state),
+ mapper._get_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ ),
+ row[version_id_col],
+ )
+ )
def _decorate_polymorphic_switch(
- instance_fn, context, mapper, result, path,
- polymorphic_discriminator, adapter):
+ instance_fn,
+ context,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+):
if polymorphic_discriminator is not None:
polymorphic_on = polymorphic_discriminator
else:
@@ -721,19 +813,22 @@ def _decorate_polymorphic_switch(
sub_mapper = mapper.polymorphic_map[discriminator]
except KeyError:
raise AssertionError(
- "No such polymorphic_identity %r is defined" %
- discriminator)
+ "No such polymorphic_identity %r is defined" % discriminator
+ )
else:
if sub_mapper is mapper:
return None
return _instance_processor(
- sub_mapper, context, result,
- path, adapter, _polymorphic_from=mapper)
+ sub_mapper,
+ context,
+ result,
+ path,
+ adapter,
+ _polymorphic_from=mapper,
+ )
- polymorphic_instances = util.PopulateDict(
- configure_subclass_mapper
- )
+ polymorphic_instances = util.PopulateDict(configure_subclass_mapper)
def polymorphic_instance(row):
discriminator = row[polymorphic_on]
@@ -742,6 +837,7 @@ def _decorate_polymorphic_switch(
if _instance:
return _instance(row)
return instance_fn(row)
+
return polymorphic_instance
@@ -749,7 +845,8 @@ class PostLoad(object):
"""Track loaders and states for "post load" operations.
"""
- __slots__ = 'loaders', 'states', 'load_keys'
+
+ __slots__ = "loaders", "states", "load_keys"
def __init__(self):
self.loaders = {}
@@ -770,8 +867,7 @@ class PostLoad(object):
for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
states = [
(state, overwrite)
- for state, overwrite
- in self.states.items()
+ for state, overwrite in self.states.items()
if state.manager.mapper.isa(limit_to_mapper)
]
if states:
@@ -787,13 +883,15 @@ class PostLoad(object):
@classmethod
def path_exists(self, context, path, key):
- return path.path in context.post_load_paths and \
- key in context.post_load_paths[path.path].loaders
+ return (
+ path.path in context.post_load_paths
+ and key in context.post_load_paths[path.path].loaders
+ )
@classmethod
def callable_for_path(
- cls, context, path, limit_to_mapper, token,
- loader_callable, *arg, **kw):
+ cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw
+ ):
if path.path in context.post_load_paths:
pl = context.post_load_paths[path.path]
else:
@@ -809,8 +907,8 @@ def load_scalar_attributes(mapper, state, attribute_names):
if not session:
raise orm_exc.DetachedInstanceError(
"Instance %s is not bound to a Session; "
- "attribute refresh operation cannot proceed" %
- (state_str(state)))
+ "attribute refresh operation cannot proceed" % (state_str(state))
+ )
has_key = bool(state.key)
@@ -833,13 +931,12 @@ def load_scalar_attributes(mapper, state, attribute_names):
statement = mapper._optimized_get_statement(state, attribute_names)
if statement is not None:
result = load_on_ident(
- session.query(mapper).
- options(
- strategy_options.Load(mapper).undefer("*")
- ).from_statement(statement),
+ session.query(mapper)
+ .options(strategy_options.Load(mapper).undefer("*"))
+ .from_statement(statement),
None,
only_load_props=attribute_names,
- refresh_state=state
+ refresh_state=state,
)
if result is False:
@@ -850,30 +947,34 @@ def load_scalar_attributes(mapper, state, attribute_names):
# object is becoming persistent but hasn't yet been assigned
# an identity_key.
# check here to ensure we have the attrs we need.
- pk_attrs = [mapper._columntoproperty[col].key
- for col in mapper.primary_key]
+ pk_attrs = [
+ mapper._columntoproperty[col].key for col in mapper.primary_key
+ ]
if state.expired_attributes.intersection(pk_attrs):
raise sa_exc.InvalidRequestError(
"Instance %s cannot be refreshed - it's not "
" persistent and does not "
- "contain a full primary key." % state_str(state))
+ "contain a full primary key." % state_str(state)
+ )
identity_key = mapper._identity_key_from_state(state)
- if (_none_set.issubset(identity_key) and
- not mapper.allow_partial_pks) or \
- _none_set.issuperset(identity_key):
+ if (
+ _none_set.issubset(identity_key) and not mapper.allow_partial_pks
+ ) or _none_set.issuperset(identity_key):
util.warn_limited(
"Instance %s to be refreshed doesn't "
"contain a full primary key - can't be refreshed "
"(and shouldn't be expired, either).",
- state_str(state))
+ state_str(state),
+ )
return
result = load_on_ident(
session.query(mapper),
identity_key,
refresh_state=state,
- only_load_props=attribute_names)
+ only_load_props=attribute_names,
+ )
# if instance is pending, a refresh operation
# may not complete (even if PK attributes are assigned)