summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/session.py')
-rw-r--r--lib/sqlalchemy/orm/session.py54
1 files changed, 31 insertions, 23 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 00a7d55e5..cbfb0c1d6 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -299,14 +299,14 @@ class SessionTransaction(object):
self.session._expunge_state(s)
for s in self.session.identity_map.all_states():
- _expire_state(s, None)
+ _expire_state(s, None, instance_dict=self.session.identity_map)
def _remove_snapshot(self):
assert self._is_transaction_boundary
if not self.nested and self.session.expire_on_commit:
for s in self.session.identity_map.all_states():
- _expire_state(s, None)
+ _expire_state(s, None, instance_dict=self.session.identity_map)
def _connection_for_bind(self, bind):
self._assert_is_active()
@@ -900,7 +900,7 @@ class Session(object):
def _finalize_loaded(self, states):
for state, dict_ in states.items():
- state.commit_all(dict_)
+ state.commit_all(dict_, self.identity_map)
def refresh(self, instance, attribute_names=None):
"""Refresh the attributes on the given instance.
@@ -935,7 +935,7 @@ class Session(object):
"""Expires all persistent instances within this Session."""
for state in self.identity_map.all_states():
- _expire_state(state, None)
+ _expire_state(state, None, instance_dict=self.identity_map)
def expire(self, instance, attribute_names=None):
"""Expire the attributes on an instance.
@@ -956,14 +956,14 @@ class Session(object):
raise exc.UnmappedInstanceError(instance)
self._validate_persistent(state)
if attribute_names:
- _expire_state(state, attribute_names=attribute_names)
+ _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map)
else:
# pre-fetch the full cascade since the expire is going to
# remove associations
cascaded = list(_cascade_state_iterator('refresh-expire', state))
- _expire_state(state, None)
+ _expire_state(state, None, instance_dict=self.identity_map)
for (state, m, o) in cascaded:
- _expire_state(state, None)
+ _expire_state(state, None, instance_dict=self.identity_map)
def prune(self):
"""Remove unreferenced instances cached in the identity map.
@@ -1022,8 +1022,8 @@ class Session(object):
state.key = instance_key
self.identity_map.replace(state)
- state.commit_all(state.dict)
-
+ state.commit_all(state.dict, self.identity_map)
+
# remove from new last, might be the last strong ref
if state in self._new:
if self._enable_transaction_accounting and self.transaction:
@@ -1211,7 +1211,7 @@ class Session(object):
prop.merge(self, instance, merged, dont_load, _recursive)
if dont_load:
- attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history
+ attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map) # remove any history
if new_instance:
merged_state._run_on_load(merged)
@@ -1360,10 +1360,9 @@ class Session(object):
not self._deleted and not self._new):
return
-
dirty = self._dirty_states
if not dirty and not self._deleted and not self._new:
- self.identity_map.modified = False
+ self.identity_map._modified.clear()
return
flush_context = UOWTransaction(self)
@@ -1389,15 +1388,19 @@ class Session(object):
raise exc.UnmappedInstanceError(o)
objset.add(state)
else:
- # or just everything
- objset = set(self.identity_map.all_states()).union(new)
+ objset = None
# store objects whose fate has been decided
processed = set()
# put all saves/updates into the flush context. detect top-level
# orphans and throw them into deleted.
- for state in new.union(dirty).intersection(objset).difference(deleted):
+ if objset:
+ proc = new.union(dirty).intersection(objset).difference(deleted)
+ else:
+ proc = new.union(dirty).difference(deleted)
+
+ for state in proc:
is_orphan = _state_mapper(state)._is_orphan(state)
if is_orphan and not _state_has_identity(state):
path = ", nor ".join(
@@ -1413,7 +1416,11 @@ class Session(object):
processed.add(state)
# put all remaining deletes into the flush context.
- for state in deleted.intersection(objset).difference(processed):
+ if objset:
+ proc = deleted.intersection(objset).difference(processed)
+ else:
+ proc = deleted.difference(processed)
+ for state in proc:
flush_context.register_object(state, isdelete=True)
if len(flush_context.tasks) == 0:
@@ -1433,9 +1440,13 @@ class Session(object):
flush_context.finalize_flush_changes()
- if not objects:
- self.identity_map.modified = False
-
+ # useful assertions:
+ #if not objects:
+ # assert not self.identity_map._modified
+ #else:
+ # assert self.identity_map._modified == self.identity_map._modified.difference(objects)
+ #self.identity_map._modified.clear()
+
for ext in self.extensions:
ext.after_flush_postexec(self, flush_context)
@@ -1484,10 +1495,7 @@ class Session(object):
those that were possibly deleted.
"""
- return util.IdentitySet(
- [state
- for state in self.identity_map.all_states()
- if state.modified])
+ return self.identity_map._dirty_states()
@property
def dirty(self):