diff options
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 54 |
1 files changed, 31 insertions, 23 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 00a7d55e5..cbfb0c1d6 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -299,14 +299,14 @@ class SessionTransaction(object): self.session._expunge_state(s) for s in self.session.identity_map.all_states(): - _expire_state(s, None) + _expire_state(s, None, instance_dict=self.session.identity_map) def _remove_snapshot(self): assert self._is_transaction_boundary if not self.nested and self.session.expire_on_commit: for s in self.session.identity_map.all_states(): - _expire_state(s, None) + _expire_state(s, None, instance_dict=self.session.identity_map) def _connection_for_bind(self, bind): self._assert_is_active() @@ -900,7 +900,7 @@ class Session(object): def _finalize_loaded(self, states): for state, dict_ in states.items(): - state.commit_all(dict_) + state.commit_all(dict_, self.identity_map) def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. @@ -935,7 +935,7 @@ class Session(object): """Expires all persistent instances within this Session.""" for state in self.identity_map.all_states(): - _expire_state(state, None) + _expire_state(state, None, instance_dict=self.identity_map) def expire(self, instance, attribute_names=None): """Expire the attributes on an instance. @@ -956,14 +956,14 @@ class Session(object): raise exc.UnmappedInstanceError(instance) self._validate_persistent(state) if attribute_names: - _expire_state(state, attribute_names=attribute_names) + _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map) else: # pre-fetch the full cascade since the expire is going to # remove associations cascaded = list(_cascade_state_iterator('refresh-expire', state)) - _expire_state(state, None) + _expire_state(state, None, instance_dict=self.identity_map) for (state, m, o) in cascaded: - _expire_state(state, None) + _expire_state(state, None, instance_dict=self.identity_map) def prune(self): """Remove unreferenced instances cached in the identity map. @@ -1022,8 +1022,8 @@ class Session(object): state.key = instance_key self.identity_map.replace(state) - state.commit_all(state.dict) - + state.commit_all(state.dict, self.identity_map) + # remove from new last, might be the last strong ref if state in self._new: if self._enable_transaction_accounting and self.transaction: @@ -1211,7 +1211,7 @@ class Session(object): prop.merge(self, instance, merged, dont_load, _recursive) if dont_load: - attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history + attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map) # remove any history if new_instance: merged_state._run_on_load(merged) @@ -1360,10 +1360,9 @@ class Session(object): not self._deleted and not self._new): return - dirty = self._dirty_states if not dirty and not self._deleted and not self._new: - self.identity_map.modified = False + self.identity_map._modified.clear() return flush_context = UOWTransaction(self) @@ -1389,15 +1388,19 @@ class Session(object): raise exc.UnmappedInstanceError(o) objset.add(state) else: - # or just everything - objset = set(self.identity_map.all_states()).union(new) + objset = None # store objects whose fate has been decided processed = set() # put all saves/updates into the flush context. detect top-level # orphans and throw them into deleted. - for state in new.union(dirty).intersection(objset).difference(deleted): + if objset: + proc = new.union(dirty).intersection(objset).difference(deleted) + else: + proc = new.union(dirty).difference(deleted) + + for state in proc: is_orphan = _state_mapper(state)._is_orphan(state) if is_orphan and not _state_has_identity(state): path = ", nor ".join( @@ -1413,7 +1416,11 @@ class Session(object): processed.add(state) # put all remaining deletes into the flush context. - for state in deleted.intersection(objset).difference(processed): + if objset: + proc = deleted.intersection(objset).difference(processed) + else: + proc = deleted.difference(processed) + for state in proc: flush_context.register_object(state, isdelete=True) if len(flush_context.tasks) == 0: @@ -1433,9 +1440,13 @@ class Session(object): flush_context.finalize_flush_changes() - if not objects: - self.identity_map.modified = False - + # useful assertions: + #if not objects: + # assert not self.identity_map._modified + #else: + # assert self.identity_map._modified == self.identity_map._modified.difference(objects) + #self.identity_map._modified.clear() + for ext in self.extensions: ext.after_flush_postexec(self, flush_context) @@ -1484,10 +1495,7 @@ class Session(object): those that were possibly deleted. """ - return util.IdentitySet( - [state - for state in self.identity_map.all_states() - if state.modified]) + return self.identity_map._dirty_states() @property def dirty(self): |