summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/orm/session.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 9b0c31e6a..faa9e5a83 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -209,6 +209,7 @@ class SessionTransaction(object):
self._new = self._parent._new
self._deleted = self._parent._deleted
self._dirty = self._parent._dirty
+ self._key_switches = self._parent._key_switches
return
if not self.session._flushing:
@@ -217,6 +218,7 @@ class SessionTransaction(object):
self._new = weakref.WeakKeyDictionary()
self._deleted = weakref.WeakKeyDictionary()
self._dirty = weakref.WeakKeyDictionary()
+ self._key_switches = weakref.WeakKeyDictionary()
def _restore_snapshot(self, dirty_only=False):
assert self._is_transaction_boundary
@@ -226,11 +228,16 @@ class SessionTransaction(object):
if s.key:
del s.key
+ for s, (oldkey, newkey) in self._key_switches.items():
+ self.session.identity_map.discard(s)
+ s.key = oldkey
+ self.session.identity_map.replace(s)
+
for s in set(self._deleted).union(self.session._deleted):
if s.deleted:
#assert s in self._deleted
del s.deleted
- self.session._update_impl(s)
+ self.session._update_impl(s, discard_existing=True)
assert not self.session._deleted
@@ -1280,6 +1287,11 @@ class Session(_SessionClassMethods):
# state has already replaced this one in the identity
# map (see test/orm/test_naturalpks.py ReversePKsTest)
self.identity_map.discard(state)
+ if state in self.transaction._key_switches:
+ orig_key = self.transaction._key_switches[0]
+ else:
+ orig_key = state.key
+ self.transaction._key_switches[state] = (orig_key, instance_key)
state.key = instance_key
self.identity_map.replace(state)
@@ -1558,7 +1570,7 @@ class Session(_SessionClassMethods):
state.insert_order = len(self._new)
self._attach(state)
- def _update_impl(self, state):
+ def _update_impl(self, state, discard_existing=False):
if (self.identity_map.contains_state(state) and
state not in self._deleted):
return
@@ -1576,7 +1588,10 @@ class Session(_SessionClassMethods):
)
self._before_attach(state)
self._deleted.pop(state, None)
- self.identity_map.add(state)
+ if discard_existing:
+ self.identity_map.replace(state)
+ else:
+ self.identity_map.add(state)
self._attach(state)
def _save_or_update_impl(self, state):