diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/session.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 9b0c31e6a..faa9e5a83 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -209,6 +209,7 @@ class SessionTransaction(object): self._new = self._parent._new self._deleted = self._parent._deleted self._dirty = self._parent._dirty + self._key_switches = self._parent._key_switches return if not self.session._flushing: @@ -217,6 +218,7 @@ class SessionTransaction(object): self._new = weakref.WeakKeyDictionary() self._deleted = weakref.WeakKeyDictionary() self._dirty = weakref.WeakKeyDictionary() + self._key_switches = weakref.WeakKeyDictionary() def _restore_snapshot(self, dirty_only=False): assert self._is_transaction_boundary @@ -226,11 +228,16 @@ class SessionTransaction(object): if s.key: del s.key + for s, (oldkey, newkey) in self._key_switches.items(): + self.session.identity_map.discard(s) + s.key = oldkey + self.session.identity_map.replace(s) + for s in set(self._deleted).union(self.session._deleted): if s.deleted: #assert s in self._deleted del s.deleted - self.session._update_impl(s) + self.session._update_impl(s, discard_existing=True) assert not self.session._deleted @@ -1280,6 +1287,11 @@ class Session(_SessionClassMethods): # state has already replaced this one in the identity # map (see test/orm/test_naturalpks.py ReversePKsTest) self.identity_map.discard(state) + if state in self.transaction._key_switches: + orig_key = self.transaction._key_switches[0] + else: + orig_key = state.key + self.transaction._key_switches[state] = (orig_key, instance_key) state.key = instance_key self.identity_map.replace(state) @@ -1558,7 +1570,7 @@ class Session(_SessionClassMethods): state.insert_order = len(self._new) self._attach(state) - def _update_impl(self, state): + def _update_impl(self, state, discard_existing=False): if (self.identity_map.contains_state(state) and state not in self._deleted): return @@ -1576,7 +1588,10 @@ class Session(_SessionClassMethods): ) self._before_attach(state) self._deleted.pop(state, None) - self.identity_map.add(state) + if discard_existing: + self.identity_map.replace(state) + else: + self.identity_map.add(state) self._attach(state) def _save_or_update_impl(self, state): |