summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/orm/properties.py36
-rw-r--r--lib/sqlalchemy/orm/session.py6
-rw-r--r--test/orm/merge.py10
3 files changed, 33 insertions, 19 deletions
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index db43a8e27..4fd9a3e9b 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -38,7 +38,7 @@ class SynonymProperty(MapperProperty):
return s
return getattr(obj, self.name)
setattr(self.parent.class_, self.key, SynonymProp())
- def merge(self, session, source, dest):
+ def merge(self, session, source, dest, _recursive):
pass
class ColumnProperty(StrategizedProperty):
@@ -61,7 +61,7 @@ class ColumnProperty(StrategizedProperty):
setattr(object, self.key, value)
def get_history(self, obj, passive=False):
return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive)
- def merge(self, session, source, dest):
+ def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
def compare(self, value):
return self.columns[0] == value
@@ -127,20 +127,26 @@ class PropertyLoader(StrategizedProperty):
def __str__(self):
return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper)
- def merge(self, session, source, dest):
- if not "merge" in self.cascade:
+ def merge(self, session, source, dest, _recursive):
+ if not "merge" in self.cascade or source in _recursive:
return
- childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
- if childlist is None:
- return
- if self.uselist:
- # sets a blank list according to the correct list class
- dest_list = getattr(self.parent.class_, self.key).initialize(dest)
- for current in list(childlist):
- dest_list.append(session.merge(current))
- else:
- setattr(dest, self.key, session.merge(current))
-
+ _recursive.add(source)
+ try:
+ childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True)
+ if childlist is None:
+ return
+ if self.uselist:
+ # sets a blank list according to the correct list class
+ dest_list = getattr(self.parent.class_, self.key).initialize(dest)
+ for current in list(childlist):
+ dest_list.append(session.merge(current, _recursive=_recursive))
+ else:
+ current = list(childlist)[0]
+ if current is not None:
+ setattr(dest, self.key, session.merge(current, _recursive=_recursive))
+ finally:
+ _recursive.remove(source)
+
def cascade_iterator(self, type, object, recursive, halt_on=None):
if not type in self.cascade:
return
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 829220688..f2a718177 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -323,7 +323,7 @@ class Session(object):
for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)):
self.uow.register_deleted(c)
- def merge(self, object, entity_name=None):
+ def merge(self, object, entity_name=None, _recursive=None):
"""copy the state of the given object onto the persistent object with the same identifier.
If there is no persistent instance currently associated with the session, it will be loaded.
@@ -331,6 +331,8 @@ class Session(object):
a newly persistent instance. The given instance does not become associated with the session.
This operation cascades to associated instances if the association is mapped with cascade="merge".
"""
+ if _recursive is None:
+ _recursive = util.Set()
mapper = _object_mapper(object)
key = getattr(object, '_instance_key', None)
if key is None:
@@ -341,7 +343,7 @@ class Session(object):
else:
merged = self.get(mapper.class_, key[1])
for prop in mapper.props.values():
- prop.merge(self, object, merged)
+ prop.merge(self, object, merged, _recursive)
if key is None:
self.save(merged)
return merged
diff --git a/test/orm/merge.py b/test/orm/merge.py
index 7a62b147c..cb36cc3b5 100644
--- a/test/orm/merge.py
+++ b/test/orm/merge.py
@@ -60,7 +60,7 @@ class MergeTest(AssertMixin):
def test_saved_cascade(self):
"""test merge of a persistent entity with two child persistent entities."""
mapper(User, users, properties={
- 'addresses':relation(mapper(Address, addresses))
+ 'addresses':relation(mapper(Address, addresses), backref='user')
})
sess = create_session()
@@ -108,7 +108,7 @@ class MergeTest(AssertMixin):
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses)),
- 'orders':relation(Order)
+ 'orders':relation(Order, backref='customer')
})
sess = create_session()
@@ -132,6 +132,12 @@ class MergeTest(AssertMixin):
u.orders[0].items[1].item_name = 'item 2 modified'
sess2.merge(u)
assert u2.orders[0].items[1].item_name == 'item 2 modified'
+
+ sess2 = create_session()
+ o2 = sess2.query(Order).get(o.order_id)
+ o.customer.user_name = 'also fred'
+ sess2.merge(o)
+ assert o2.customer.user_name == 'also fred'
if __name__ == "__main__":