summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2008-04-01 17:13:09 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2008-04-01 17:13:09 +0000
commitad231da3b83bcdad4446690fa37fbe03408a40d6 (patch)
treeaf2efdd02ba000e2b994b7e191f964bed336a841
parent1e0a91fe81cd8cf38603c7147ebf2e79301be6f5 (diff)
downloadsqlalchemy-ad231da3b83bcdad4446690fa37fbe03408a40d6.tar.gz
- merge() may actually work now, though we've heard that before...
- merge() uses the priamry key attributes on the object if _instance_key not present. so merging works for instances that dont have an instnace_key, will still issue UPDATE for existing rows. - improved collection behavior for merge() - will remove elements from a destination collection that are not in the source. - fixed naive set-mutation issue in Select._get_display_froms - simplified fixtures.Base a bit
-rw-r--r--CHANGES27
-rw-r--r--lib/sqlalchemy/orm/attributes.py2
-rw-r--r--lib/sqlalchemy/orm/properties.py14
-rw-r--r--lib/sqlalchemy/orm/session.py22
-rw-r--r--lib/sqlalchemy/sql/expression.py6
-rw-r--r--test/orm/merge.py178
-rw-r--r--test/testlib/fixtures.py16
7 files changed, 197 insertions, 68 deletions
diff --git a/CHANGES b/CHANGES
index 1fc3cda0c..3450bb906 100644
--- a/CHANGES
+++ b/CHANGES
@@ -5,6 +5,25 @@ CHANGES
0.4.5
=====
- orm
+ - a small change in behavior to session.merge() - existing
+ objects are checked for based on primary key attributes,
+ not necessarily _instance_key. So the widely requested
+ capability, that:
+
+ x = MyObject(id=1)
+ x = sess.merge(x)
+
+ will in fact load MyObject with id #1 from the database
+ if present, is now available. merge() still
+ copies the state of the given object to the persistent
+ one, so an example like the above would typically have
+ copied "None" from all attributes of "x" onto the persistent
+ copy. These can be reverted using session.expire(x).
+
+ - also fixed behavior in merge() whereby collection elements
+ present on the destination but not the merged collection
+ were not being removed from the destination.
+
- Added a more aggressive check for "uncompiled mappers",
helps particularly with declarative layer [ticket:995]
@@ -158,7 +177,13 @@ CHANGES
- random() is now a generic sql function and will compile to
the database's random implementation, if any.
-
+
+ - fixed an issue in select() regarding its generation of
+ FROM clauses, in rare circumstances two clauses could
+ be produced when one was intended to cancel out the
+ other. Some ORM queries with lots of eager loads
+ might have seen this symptom.
+
- declarative extension
- The "synonym" function is now directly usable with
"declarative". Pass in the decorated property using the
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index a511c9bbb..f57298d7c 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -569,7 +569,7 @@ class CollectionAttributeImpl(AttributeImpl):
self.fire_remove_event(state, value, initiator)
else:
collection.remove_with_event(value, initiator)
-
+
def set(self, state, value, initiator):
"""Set a value on the given object.
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
index 970e49ea4..d050f40d7 100644
--- a/lib/sqlalchemy/orm/properties.py
+++ b/lib/sqlalchemy/orm/properties.py
@@ -416,7 +416,6 @@ class PropertyLoader(StrategizedProperty):
return
if not "merge" in self.cascade:
- # TODO: lazy callable should merge to the new instance
dest._state.expire_attributes([self.key])
return
@@ -425,15 +424,18 @@ class PropertyLoader(StrategizedProperty):
return
if self.uselist:
- dest_list = attributes.init_collection(dest, self.key)
+ dest_list = []
for current in instances:
_recursive[(current, self)] = True
obj = session.merge(current, entity_name=self.mapper.entity_name, dont_load=dont_load, _recursive=_recursive)
if obj is not None:
- if dont_load:
- dest_list.append_without_event(obj)
- else:
- dest_list.append_with_event(obj)
+ dest_list.append(obj)
+ if dont_load:
+ coll = attributes.init_collection(dest, self.key)
+ for c in dest_list:
+ coll.append_without_event(c)
+ else:
+ getattr(dest.__class__, self.key).impl._set_iterable(dest._state, dest_list)
else:
current = instances[0]
if current is not None:
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 391bc925b..b7a4aa911 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -955,8 +955,10 @@ class Session(object):
if key is None:
if dont_load:
raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. flush() all changes on mapped instances before merging with dont_load=True.")
- merged = attributes.new_instance(mapper.class_)
- else:
+ key = mapper.identity_key_from_instance(instance)
+
+ merged = None
+ if key:
if key in self.identity_map:
merged = self.identity_map[key]
elif dont_load:
@@ -969,15 +971,19 @@ class Session(object):
self._update_impl(merged, entity_name=mapper.entity_name)
else:
merged = self.get(mapper.class_, key[1])
- if merged is None:
- raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(instance))
+
+ if merged is None:
+ merged = attributes.new_instance(mapper.class_)
+ self.save(merged, entity_name=mapper.entity_name)
+
_recursive[instance] = merged
+
for prop in mapper.iterate_properties:
prop.merge(self, instance, merged, dont_load, _recursive)
- if key is None:
- self.save(merged, entity_name=mapper.entity_name)
- elif dont_load:
- merged._state.commit_all()
+
+ if dont_load:
+ merged._state.commit_all() # remove any history
+
return merged
def identity_key(cls, *args, **kwargs):
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 2cd10720a..758f75ebe 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -3096,9 +3096,9 @@ class Select(_SelectBaseMixin, FromClause):
if self._froms:
froms.update(self._froms)
-
- for f in froms:
- froms.difference_update(f._hide_froms)
+
+ toremove = itertools.chain(*[f._hide_froms for f in froms])
+ froms.difference_update(toremove)
if len(froms) > 1 or self.__correlate:
if self.__correlate:
diff --git a/test/orm/merge.py b/test/orm/merge.py
index b9a247289..fd61ccc28 100644
--- a/test/orm/merge.py
+++ b/test/orm/merge.py
@@ -3,6 +3,7 @@ from sqlalchemy import *
from sqlalchemy import exceptions
from sqlalchemy.orm import *
from sqlalchemy.orm import mapperlib
+from sqlalchemy.util import OrderedSet
from testlib import *
from testlib import fixtures
from testlib.tables import *
@@ -12,31 +13,139 @@ class MergeTest(TestBase, AssertsExecutionResults):
"""tests session.merge() functionality"""
def setUpAll(self):
tables.create()
+
def tearDownAll(self):
tables.drop()
+
def tearDown(self):
clear_mappers()
tables.delete()
- def setUp(self):
- pass
- def test_unsaved(self):
- """test merge of a single transient entity."""
+ def test_transient_to_pending(self):
+ class User(fixtures.Base):
+ pass
mapper(User, users)
sess = create_session()
- u = User()
- u.user_id = 7
- u.user_name = "fred"
+ u = User(user_id=7, user_name='fred')
u2 = sess.merge(u)
assert u2 in sess
- assert u2.user_id == 7
- assert u2.user_name == 'fred'
+ self.assertEquals(u2, User(user_id=7, user_name='fred'))
sess.flush()
sess.clear()
- u2 = sess.query(User).get(7)
- assert u2.user_name == 'fred'
+ self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred'))
+
+ def test_transient_to_pending_collection(self):
+ class User(fixtures.Base):
+ pass
+ class Address(fixtures.Base):
+ pass
+ mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+ mapper(Address, addresses)
+ u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+ Address(address_id=1, email_address='fred1'),
+ Address(address_id=2, email_address='fred2'),
+ ]))
+ sess = create_session()
+ sess.merge(u)
+ sess.flush()
+ sess.clear()
+
+ self.assertEquals(sess.query(User).one(),
+ User(user_id=7, user_name='fred', addresses=OrderedSet([
+ Address(address_id=1, email_address='fred1'),
+ Address(address_id=2, email_address='fred2'),
+ ]))
+ )
+
+ def test_transient_to_persistent(self):
+ class User(fixtures.Base):
+ pass
+ mapper(User, users)
+ sess = create_session()
+ u = User(user_id=7, user_name='fred')
+ sess.save(u)
+ sess.flush()
+ sess.clear()
+
+ u2 = User(user_id=7, user_name='fred jones')
+ u2 = sess.merge(u2)
+ sess.flush()
+ sess.clear()
+ self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred jones'))
+
+ def test_transient_to_persistent_collection(self):
+ class User(fixtures.Base):
+ pass
+ class Address(fixtures.Base):
+ pass
+ mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+ mapper(Address, addresses)
+
+ u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+ Address(address_id=1, email_address='fred1'),
+ Address(address_id=2, email_address='fred2'),
+ ]))
+ sess = create_session()
+ sess.save(u)
+ sess.flush()
+ sess.clear()
+
+ u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+ Address(address_id=3, email_address='fred3'),
+ Address(address_id=4, email_address='fred4'),
+ ]))
+
+ u = sess.merge(u)
+ self.assertEquals(u,
+ User(user_id=7, user_name='fred', addresses=OrderedSet([
+ Address(address_id=3, email_address='fred3'),
+ Address(address_id=4, email_address='fred4'),
+ ]))
+ )
+ sess.flush()
+ sess.clear()
+ self.assertEquals(sess.query(User).one(),
+ User(user_id=7, user_name='fred', addresses=OrderedSet([
+ Address(address_id=3, email_address='fred3'),
+ Address(address_id=4, email_address='fred4'),
+ ]))
+ )
+
+ def test_detached_to_persistent_collection(self):
+ class User(fixtures.Base):
+ pass
+ class Address(fixtures.Base):
+ pass
+ mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+ mapper(Address, addresses)
+
+ a = Address(address_id=1, email_address='fred1')
+ u = User(user_id=7, user_name='fred', addresses=OrderedSet([
+ a,
+ Address(address_id=2, email_address='fred2'),
+ ]))
+ sess = create_session()
+ sess.save(u)
+ sess.flush()
+ sess.clear()
+
+ u.user_name='fred jones'
+ u.addresses.add(Address(address_id=3, email_address='fred3'))
+ u.addresses.remove(a)
+
+ u = sess.merge(u)
+ sess.flush()
+ sess.clear()
+
+ self.assertEquals(sess.query(User).first(),
+ User(user_id=7, user_name='fred jones', addresses=OrderedSet([
+ Address(address_id=2, email_address='fred2'),
+ Address(address_id=3, email_address='fred3'),
+ ]))
+ )
+
def test_unsaved_cascade(self):
"""test merge of a transient entity with two child transient entities, with a bidirectional relation."""
@@ -63,18 +172,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
u2 = sess.query(User).get(7)
self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
- def test_transient_dontload(self):
- mapper(User, users)
-
- sess = create_session()
- u = User()
- try:
- u2 = sess.merge(u, dont_load=True)
- assert False
- except exceptions.InvalidRequestError, err:
- assert str(err) == "merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. flush() all changes on mapped instances before merging with dont_load=True."
-
- def test_saved_cascade(self):
+ def test_attribute_cascade(self):
"""test merge of a persistent entity with two child persistent entities."""
class User(fixtures.Base):
@@ -132,7 +230,6 @@ class MergeTest(TestBase, AssertsExecutionResults):
# test with "dontload" merge
sess5 = create_session()
- print "------------------"
u = sess5.merge(u, dont_load=True)
assert len(u.addresses)
for a in u.addresses:
@@ -158,7 +255,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
assert u2.user_name == 'fred2'
assert u2.addresses[1].email_address == 'afafds'
- def test_saved_cascade_2(self):
+ def test_one_to_many_cascade(self):
mapper(Order, orders, properties={
'items':relation(mapper(Item, orderitems))
@@ -197,8 +294,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess2.merge(o)
assert o2.customer.user_name == 'also fred'
- def test_saved_cascade_3(self):
- """test merge of a persistent entity with one_to_one relationship"""
+ def test_one_to_one_cascade(self):
mapper(User, users, properties={
'address':relation(mapper(Address, addresses),uselist = False)
@@ -221,6 +317,14 @@ class MergeTest(TestBase, AssertsExecutionResults):
u3 = sess.merge(u2)
+ def test_transient_dontload(self):
+ mapper(User, users)
+
+ sess = create_session()
+ u = User()
+ self.assertRaisesMessage(exceptions.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+
+
def test_dontload_with_backrefs(self):
"""test that dontload populates relations in both directions without requiring a load"""
@@ -254,8 +358,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
self.assertEquals(u.addresses[1].user, User(user_id=7, user_name='fred'))
- def test_noload_with_eager(self):
- """this test illustrates that with noload=True, we can't just
+ def test_dontload_with_eager(self):
+ """this test illustrates that with dont_load=True, we can't just
copy the committed_state of the merged instance over; since it references collection objects
which themselves are to be merged. This committed_state would instead need to be piecemeal
'converted' to represent the correct objects.
@@ -286,8 +390,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess3.flush()
self.assert_sql_count(testing.db, go, 0)
- def test_noload_disallows_dirty(self):
- """noload doesnt support 'dirty' objects right now (see test_noload_with_eager()).
+ def test_dont_load_disallows_dirty(self):
+ """dont_load doesnt support 'dirty' objects right now (see test_dont_load_with_eager()).
Therefore lets assert it."""
mapper(User, users)
@@ -315,8 +419,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
sess3.flush()
self.assert_sql_count(testing.db, go, 0)
- def test_noload_sets_entityname(self):
- """test that a noload-merged entity has entity_name set, has_mapper() passes, and lazyloads work"""
+ def test_dont_load_sets_entityname(self):
+ """test that a dont_load-merged entity has entity_name set, has_mapper() passes, and lazyloads work"""
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses),uselist = True)
})
@@ -346,7 +450,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
assert len(u2.addresses) == 1
self.assert_sql_count(testing.db, go, 1)
- def test_noload_sets_backrefs(self):
+ def test_dont_load_sets_backrefs(self):
mapper(User, users, properties={
'addresses':relation(mapper(Address, addresses),backref='user')
})
@@ -370,10 +474,10 @@ class MergeTest(TestBase, AssertsExecutionResults):
assert u2.addresses[0].user is u2
self.assert_sql_count(testing.db, go, 0)
- def test_noload_preserves_parents(self):
- """test that merge with noload does not trigger a 'delete-orphan' operation.
+ def test_dont_load_preserves_parents(self):
+ """test that merge with dont_load does not trigger a 'delete-orphan' operation.
- merge with noload sets attributes without using events. this means the
+ merge with dont_load sets attributes without using events. this means the
'hasparent' flag is not propagated to the newly merged instance. in fact this
works out OK, because the '_state.parents' collection on the newly
merged instance is empty; since the mapper doesn't see an active 'False' setting
diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py
index bbd27a39f..a1aa717e9 100644
--- a/test/testlib/fixtures.py
+++ b/test/testlib/fixtures.py
@@ -54,19 +54,11 @@ class Base(object):
except AttributeError:
#print "b class does not have attribute named '%s'" % attr
return False
- #print "other:", battr
- if not hasattr(value, '__len__'):
- value = list(iter(value))
- battr = list(iter(battr))
- if len(value) != len(battr):
- #print "Length of collection '%s' does not match that of b" % attr
- return False
- for (us, them) in zip(value, battr):
- if us != them:
- #print "1. Attribute named '%s' does not match b" % attr
- return False
- else:
+
+ if list(value) == list(battr):
continue
+ else:
+ return False
else:
if value is not None:
if value != getattr(b, attr, None):