diff options
Diffstat (limited to 'lib/sqlalchemy/orm/state.py')
-rw-r--r-- | lib/sqlalchemy/orm/state.py | 233 |
1 files changed, 129 insertions, 104 deletions
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 944dc8177..c36d8817b 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -18,8 +18,16 @@ from .. import inspection from .. import exc as sa_exc from . import exc as orm_exc, interfaces from .path_registry import PathRegistry -from .base import PASSIVE_NO_RESULT, SQL_OK, NEVER_SET, ATTR_WAS_SET, \ - NO_VALUE, PASSIVE_NO_INITIALIZE, INIT_OK, PASSIVE_OFF +from .base import ( + PASSIVE_NO_RESULT, + SQL_OK, + NEVER_SET, + ATTR_WAS_SET, + NO_VALUE, + PASSIVE_NO_INITIALIZE, + INIT_OK, + PASSIVE_OFF, +) from . import base @@ -106,10 +114,7 @@ class InstanceState(interfaces.InspectionAttrInfo): """ return util.ImmutableProperties( - dict( - (key, AttributeState(self, key)) - for key in self.manager - ) + dict((key, AttributeState(self, key)) for key in self.manager) ) @property @@ -121,8 +126,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is None and \ - not self._attached + return self.key is None and not self._attached @property def pending(self): @@ -134,8 +138,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is None and \ - self._attached + return self.key is None and self._attached @property def deleted(self): @@ -164,8 +167,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is not None and \ - self._attached and self._deleted + return self.key is not None and self._attached and self._deleted @property def was_deleted(self): @@ -210,8 +212,7 @@ class InstanceState(interfaces.InspectionAttrInfo): :ref:`session_object_states` """ - return self.key is not None and \ - self._attached and not self._deleted + return self.key is not None and self._attached and not self._deleted @property def detached(self): @@ -227,8 +228,10 @@ class InstanceState(interfaces.InspectionAttrInfo): @property @util.dependencies("sqlalchemy.orm.session") def _attached(self, sessionlib): - return self.session_id is not None and \ - self.session_id in sessionlib._sessions + return ( + self.session_id is not None + and self.session_id in sessionlib._sessions + ) def _track_last_known_value(self, key): """Track the last known value of a particular key after expiration @@ -323,14 +326,14 @@ class InstanceState(interfaces.InspectionAttrInfo): @classmethod def _detach_states(self, states, session, to_transient=False): - persistent_to_detached = \ + persistent_to_detached = ( session.dispatch.persistent_to_detached or None - deleted_to_detached = \ - session.dispatch.deleted_to_detached or None - pending_to_transient = \ - session.dispatch.pending_to_transient or None - persistent_to_transient = \ + ) + deleted_to_detached = session.dispatch.deleted_to_detached or None + pending_to_transient = session.dispatch.pending_to_transient or None + persistent_to_transient = ( session.dispatch.persistent_to_transient or None + ) for state in states: deleted = state._deleted @@ -448,23 +451,33 @@ class InstanceState(interfaces.InspectionAttrInfo): return self._pending_mutations[key] def __getstate__(self): - state_dict = {'instance': self.obj()} + state_dict = {"instance": self.obj()} state_dict.update( - (k, self.__dict__[k]) for k in ( - 'committed_state', '_pending_mutations', 'modified', - 'expired', 'callables', 'key', 'parents', 'load_options', - 'class_', 'expired_attributes', 'info' - ) if k in self.__dict__ + (k, self.__dict__[k]) + for k in ( + "committed_state", + "_pending_mutations", + "modified", + "expired", + "callables", + "key", + "parents", + "load_options", + "class_", + "expired_attributes", + "info", + ) + if k in self.__dict__ ) if self.load_path: - state_dict['load_path'] = self.load_path.serialize() + state_dict["load_path"] = self.load_path.serialize() - state_dict['manager'] = self.manager._serialize(self, state_dict) + state_dict["manager"] = self.manager._serialize(self, state_dict) return state_dict def __setstate__(self, state_dict): - inst = state_dict['instance'] + inst = state_dict["instance"] if inst is not None: self.obj = weakref.ref(inst, self._cleanup) self.class_ = inst.__class__ @@ -473,20 +486,20 @@ class InstanceState(interfaces.InspectionAttrInfo): # due to storage of state in "parents". "class_" # also new. self.obj = None - self.class_ = state_dict['class_'] - - self.committed_state = state_dict.get('committed_state', {}) - self._pending_mutations = state_dict.get('_pending_mutations', {}) - self.parents = state_dict.get('parents', {}) - self.modified = state_dict.get('modified', False) - self.expired = state_dict.get('expired', False) - if 'info' in state_dict: - self.info.update(state_dict['info']) - if 'callables' in state_dict: - self.callables = state_dict['callables'] + self.class_ = state_dict["class_"] + + self.committed_state = state_dict.get("committed_state", {}) + self._pending_mutations = state_dict.get("_pending_mutations", {}) + self.parents = state_dict.get("parents", {}) + self.modified = state_dict.get("modified", False) + self.expired = state_dict.get("expired", False) + if "info" in state_dict: + self.info.update(state_dict["info"]) + if "callables" in state_dict: + self.callables = state_dict["callables"] try: - self.expired_attributes = state_dict['expired_attributes'] + self.expired_attributes = state_dict["expired_attributes"] except KeyError: self.expired_attributes = set() # 0.9 and earlier compat @@ -495,30 +508,31 @@ class InstanceState(interfaces.InspectionAttrInfo): self.expired_attributes.add(k) del self.callables[k] else: - if 'expired_attributes' in state_dict: - self.expired_attributes = state_dict['expired_attributes'] + if "expired_attributes" in state_dict: + self.expired_attributes = state_dict["expired_attributes"] else: self.expired_attributes = set() - self.__dict__.update([ - (k, state_dict[k]) for k in ( - 'key', 'load_options' - ) if k in state_dict - ]) + self.__dict__.update( + [ + (k, state_dict[k]) + for k in ("key", "load_options") + if k in state_dict + ] + ) if self.key: try: self.identity_token = self.key[2] except IndexError: # 1.1 and earlier compat before identity_token assert len(self.key) == 2 - self.key = self.key + (None, ) + self.key = self.key + (None,) self.identity_token = None - if 'load_path' in state_dict: - self.load_path = PathRegistry.\ - deserialize(state_dict['load_path']) + if "load_path" in state_dict: + self.load_path = PathRegistry.deserialize(state_dict["load_path"]) - state_dict['manager'](self, inst, state_dict) + state_dict["manager"](self, inst, state_dict) def _reset(self, dict_, key): """Remove the given attribute and any @@ -532,25 +546,29 @@ class InstanceState(interfaces.InspectionAttrInfo): self.callables.pop(key, None) def _copy_callables(self, from_): - if 'callables' in from_.__dict__: + if "callables" in from_.__dict__: self.callables = dict(from_.callables) @classmethod def _instance_level_callable_processor(cls, manager, fn, key): impl = manager[key].impl if impl.collection: + def _set_callable(state, dict_, row): - if 'callables' not in state.__dict__: + if "callables" not in state.__dict__: state.callables = {} old = dict_.pop(key, None) if old is not None: impl._invalidate_collection(old) state.callables[key] = fn + else: + def _set_callable(state, dict_, row): - if 'callables' not in state.__dict__: + if "callables" not in state.__dict__: state.callables = {} state.callables[key] = fn + return _set_callable def _expire(self, dict_, modified_set): @@ -563,15 +581,18 @@ class InstanceState(interfaces.InspectionAttrInfo): self._strong_obj = None - if '_pending_mutations' in self.__dict__: - del self.__dict__['_pending_mutations'] + if "_pending_mutations" in self.__dict__: + del self.__dict__["_pending_mutations"] - if 'parents' in self.__dict__: - del self.__dict__['parents'] + if "parents" in self.__dict__: + del self.__dict__["parents"] self.expired_attributes.update( - [impl.key for impl in self.manager._scalar_loader_impls - if impl.expire_missing or impl.key in dict_] + [ + impl.key + for impl in self.manager._scalar_loader_impls + if impl.expire_missing or impl.key in dict_ + ] ) if self.callables: @@ -584,8 +605,7 @@ class InstanceState(interfaces.InspectionAttrInfo): if self._last_known_values: self._last_known_values.update( - (k, dict_[k]) for k in self._last_known_values - if k in dict_ + (k, dict_[k]) for k in self._last_known_values if k in dict_ ) for key in self.manager._all_key_set.intersection(dict_): @@ -594,17 +614,14 @@ class InstanceState(interfaces.InspectionAttrInfo): self.manager.dispatch.expire(self, None) def _expire_attributes(self, dict_, attribute_names, no_loader=False): - pending = self.__dict__.get('_pending_mutations', None) + pending = self.__dict__.get("_pending_mutations", None) callables = self.callables for key in attribute_names: impl = self.manager[key].impl if impl.accepts_scalar_loader: - if no_loader and ( - impl.callable_ or - key in callables - ): + if no_loader and (impl.callable_ or key in callables): continue self.expired_attributes.add(key) @@ -614,8 +631,11 @@ class InstanceState(interfaces.InspectionAttrInfo): if impl.collection and old is not NO_VALUE: impl._invalidate_collection(old) - if self._last_known_values and key in self._last_known_values \ - and old is not NO_VALUE: + if ( + self._last_known_values + and key in self._last_known_values + and old is not NO_VALUE + ): self._last_known_values[key] = old self.committed_state.pop(key, None) @@ -634,8 +654,7 @@ class InstanceState(interfaces.InspectionAttrInfo): if not passive & SQL_OK: return PASSIVE_NO_RESULT - toload = self.expired_attributes.\ - intersection(self.unmodified) + toload = self.expired_attributes.intersection(self.unmodified) self.manager.deferred_scalar_loader(self, toload) @@ -656,9 +675,11 @@ class InstanceState(interfaces.InspectionAttrInfo): def unmodified_intersection(self, keys): """Return self.unmodified.intersection(keys).""" - - return set(keys).intersection(self.manager).\ - difference(self.committed_state) + return ( + set(keys) + .intersection(self.manager) + .difference(self.committed_state) + ) @property def unloaded(self): @@ -668,9 +689,11 @@ class InstanceState(interfaces.InspectionAttrInfo): was never populated or modified. """ - return set(self.manager).\ - difference(self.committed_state).\ - difference(self.dict) + return ( + set(self.manager) + .difference(self.committed_state) + .difference(self.dict) + ) @property def unloaded_expirable(self): @@ -681,13 +704,16 @@ class InstanceState(interfaces.InspectionAttrInfo): """ return self.unloaded.intersection( - attr for attr in self.manager - if self.manager[attr].impl.expire_missing) + attr + for attr in self.manager + if self.manager[attr].impl.expire_missing + ) @property def _unloaded_non_object(self): return self.unloaded.intersection( - attr for attr in self.manager + attr + for attr in self.manager if self.manager[attr].impl.accepts_scalar_loader ) @@ -695,14 +721,16 @@ class InstanceState(interfaces.InspectionAttrInfo): return None def _modified_event( - self, dict_, attr, previous, collection=False, is_userland=False): + self, dict_, attr, previous, collection=False, is_userland=False + ): if attr: if not attr.send_modified_events: return if is_userland and attr.key not in dict_: raise sa_exc.InvalidRequestError( "Can't flag attribute '%s' modified; it's not present in " - "the object state" % attr.key) + "the object state" % attr.key + ) if attr.key not in self.committed_state or is_userland: if collection: if previous is NEVER_SET: @@ -718,8 +746,7 @@ class InstanceState(interfaces.InspectionAttrInfo): # assert self._strong_obj is None or self.modified - if (self.session_id and self._strong_obj is None) \ - or not self.modified: + if (self.session_id and self._strong_obj is None) or not self.modified: self.modified = True instance_dict = self._instance_dict() if instance_dict: @@ -737,10 +764,8 @@ class InstanceState(interfaces.InspectionAttrInfo): "Can't emit change event for attribute '%s' - " "parent object of type %s has been garbage " "collected." - % ( - self.manager[attr.key], - base.state_class_str(self) - )) + % (self.manager[attr.key], base.state_class_str(self)) + ) def _commit(self, dict_, keys): """Commit attributes. @@ -758,17 +783,18 @@ class InstanceState(interfaces.InspectionAttrInfo): self.expired = False self.expired_attributes.difference_update( - set(keys).intersection(dict_)) + set(keys).intersection(dict_) + ) # the per-keys commit removes object-level callables, # while that of commit_all does not. it's not clear # if this behavior has a clear rationale, however tests do # ensure this is what it does. if self.callables: - for key in set(self.callables).\ - intersection(keys).\ - intersection(dict_): - del self.callables[key] + for key in ( + set(self.callables).intersection(keys).intersection(dict_) + ): + del self.callables[key] def _commit_all(self, dict_, instance_dict=None): """commit all attributes unconditionally. @@ -797,8 +823,8 @@ class InstanceState(interfaces.InspectionAttrInfo): state.committed_state.clear() - if '_pending_mutations' in state_dict: - del state_dict['_pending_mutations'] + if "_pending_mutations" in state_dict: + del state_dict["_pending_mutations"] state.expired_attributes.difference_update(dict_) @@ -848,7 +874,8 @@ class AttributeState(object): """ return self.state.manager[self.key].__get__( - self.state.obj(), self.state.class_) + self.state.obj(), self.state.class_ + ) @property def history(self): @@ -866,8 +893,7 @@ class AttributeState(object): :func:`.attributes.get_history` - underlying function """ - return self.state.get_history(self.key, - PASSIVE_NO_INITIALIZE) + return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE) def load_history(self): """Return the current pre-flush change history for @@ -885,8 +911,7 @@ class AttributeState(object): .. versionadded:: 0.9.0 """ - return self.state.get_history(self.key, - PASSIVE_OFF ^ INIT_OK) + return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK) class PendingCollection(object): |