summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES17
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py45
-rw-r--r--lib/sqlalchemy/orm/collections.py55
-rw-r--r--test/ext/associationproxy.py32
-rw-r--r--test/orm/collection.py162
5 files changed, 256 insertions, 55 deletions
diff --git a/CHANGES b/CHANGES
index b839ec7df..4deba3b49 100644
--- a/CHANGES
+++ b/CHANGES
@@ -4,22 +4,23 @@ CHANGES
0.4.3
-----
- orm
- - added very rudimentary yielding iterator behavior to Query. Call
- query.yield_per(<number of rows>) and evaluate the Query in an
+ - Added very rudimentary yielding iterator behavior to Query. Call
+ query.yield_per(<number of rows>) and evaluate the Query in an
iterative context; every collection of N rows will be packaged up
- and yielded. Use this method with extreme caution since it does
+ and yielded. Use this method with extreme caution since it does
not attempt to reconcile eagerly loaded collections across
result batch boundaries, nor will it behave nicely if the same
- instance occurs in more than one batch. This means that an eagerly
+ instance occurs in more than one batch. This means that an eagerly
loaded collection will get cleared out if it's referenced in more than
one batch, and in all cases attributes will be overwritten on instances
that occur in more than one batch.
-- dialects
+ - Fixed in-place set mutation operators for set collections and association
+ proxied sets. [ticket:920]
- - PostgreSQL
- - Fixed the missing call to subtype result processor for the PGArray
- type. [ticket:913]
+- dialects
+ - Fixed the missing call to subtype result processor for the PGArray
+ type. [ticket:913]
0.4.2
-----
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
index 472bd1b2c..c5a2b4d07 100644
--- a/lib/sqlalchemy/ext/associationproxy.py
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -176,8 +176,9 @@ class AssociationProxy(object):
self._scalar_set(target, values)
else:
proxy = self.__get__(obj, None)
- proxy.clear()
- self._set(proxy, values)
+ if proxy is not values:
+ proxy.clear()
+ self._set(proxy, values)
def __delete__(self, obj):
delattr(obj, self.key)
@@ -653,7 +654,12 @@ class _AssociationSet(object):
for value in other:
self.add(value)
- __ior__ = update
+ def __ior__(self, other):
+ if util.duck_type_collection(other) is not set:
+ return NotImplemented
+ for value in other:
+ self.add(value)
+ return self
def _set(self):
return util.Set(iter(self))
@@ -672,7 +678,12 @@ class _AssociationSet(object):
for value in other:
self.discard(value)
- __isub__ = difference_update
+ def __isub__(self, other):
+ if util.duck_type_collection(other) is not set:
+ return NotImplemented
+ for value in other:
+ self.discard(value)
+ return self
def intersection(self, other):
return util.Set(self).intersection(other)
@@ -689,7 +700,18 @@ class _AssociationSet(object):
for value in add:
self.add(value)
- __iand__ = intersection_update
+ def __iand__(self, other):
+ if util.duck_type_collection(other) is not set:
+ return NotImplemented
+ want, have = self.intersection(other), util.Set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
def symmetric_difference(self, other):
return util.Set(self).symmetric_difference(other)
@@ -706,7 +728,18 @@ class _AssociationSet(object):
for value in add:
self.add(value)
- __ixor__ = symmetric_difference_update
+ def __ixor__(self, other):
+ if util.duck_type_collection(other) is not set:
+ return NotImplemented
+ want, have = self.symmetric_difference(other), util.Set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
def issubset(self, other):
return util.Set(self).issubset(other)
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
index ddbf6f005..106601640 100644
--- a/lib/sqlalchemy/orm/collections.py
+++ b/lib/sqlalchemy/orm/collections.py
@@ -1138,7 +1138,17 @@ def _set_decorators():
self.add(item)
_tidy(update)
return update
- __ior__ = update
+
+ def __ior__(fn):
+ def __ior__(self, value):
+ if sautil.duck_type_collection(value) is not set:
+ return NotImplemented
+ for item in value:
+ if item not in self:
+ self.add(item)
+ return self
+ _tidy(__ior__)
+ return __ior__
def difference_update(fn):
def difference_update(self, value):
@@ -1146,7 +1156,16 @@ def _set_decorators():
self.discard(item)
_tidy(difference_update)
return difference_update
- __isub__ = difference_update
+
+ def __isub__(fn):
+ def __isub__(self, value):
+ if sautil.duck_type_collection(value) is not set:
+ return NotImplemented
+ for item in value:
+ self.discard(item)
+ return self
+ _tidy(__isub__)
+ return __isub__
def intersection_update(fn):
def intersection_update(self, other):
@@ -1159,7 +1178,21 @@ def _set_decorators():
self.add(item)
_tidy(intersection_update)
return intersection_update
- __iand__ = intersection_update
+
+ def __iand__(fn):
+ def __iand__(self, other):
+ if sautil.duck_type_collection(other) is not set:
+ return NotImplemented
+ want, have = self.intersection(other), sautil.Set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ return self
+ _tidy(__iand__)
+ return __iand__
def symmetric_difference_update(fn):
def symmetric_difference_update(self, other):
@@ -1172,7 +1205,21 @@ def _set_decorators():
self.add(item)
_tidy(symmetric_difference_update)
return symmetric_difference_update
- __ixor__ = symmetric_difference_update
+
+ def __ixor__(fn):
+ def __ixor__(self, other):
+ if sautil.duck_type_collection(other) is not set:
+ return NotImplemented
+ want, have = self.symmetric_difference(other), sautil.Set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ return self
+ _tidy(__ixor__)
+ return __ixor__
l = locals().copy()
l.pop('_tidy')
diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py
index fe8b40255..b3ce69a97 100644
--- a/test/ext/associationproxy.py
+++ b/test/ext/associationproxy.py
@@ -485,6 +485,38 @@ class SetTest(_CollectionOperations):
print 'got', repr(p.children)
raise
+ # in-place mutations
+ for op in ('|=', '-=', '&=', '^='):
+ for base in (['a', 'b', 'c'], []):
+ for other in (set(['a','b','c']), set(['a','b','c','d']),
+ set(['a']), set(['a','b']),
+ set(['c','d']), set(['e', 'f', 'g']),
+ set()):
+ p = Parent('p')
+ p.children = base[:]
+ control = set(base[:])
+
+ exec "p.children %s other" % op
+ exec "control %s other" % op
+
+ try:
+ self.assert_(p.children == control)
+ except:
+ print 'Test %s %s %s:' % (set(base), op, other)
+ print 'want', repr(control)
+ print 'got', repr(p.children)
+ raise
+
+ p = self.roundtrip(p)
+
+ try:
+ self.assert_(p.children == control)
+ except:
+ print 'Test %s %s %s:' % (base, op, other)
+ print 'want', repr(control)
+ print 'got', repr(p.children)
+ raise
+
class CustomSetTest(SetTest):
def __init__(self, *args, **kw):
diff --git a/test/orm/collection.py b/test/orm/collection.py
index 43b2f41e2..6e50a8512 100644
--- a/test/orm/collection.py
+++ b/test/orm/collection.py
@@ -74,12 +74,12 @@ class CollectionsTest(PersistTest):
adapter.append_with_event(e1)
assert_eq()
-
+
adapter.append_without_event(e2)
assert_ne()
canary.data.add(e2)
assert_eq()
-
+
adapter.remove_without_event(e2)
assert_ne()
canary.data.remove(e2)
@@ -91,7 +91,7 @@ class CollectionsTest(PersistTest):
def _test_list(self, typecallable, creator=entity_maker):
class Foo(object):
pass
-
+
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
@@ -106,7 +106,7 @@ class CollectionsTest(PersistTest):
self.assert_(set(direct) == canary.data)
self.assert_(set(adapter) == canary.data)
self.assert_(direct == control)
-
+
# assume append() is available for list tests
e = creator()
direct.append(e)
@@ -122,7 +122,7 @@ class CollectionsTest(PersistTest):
e = creator()
direct.append(e)
control.append(e)
-
+
e = creator()
direct[0] = e
control[0] = e
@@ -174,7 +174,7 @@ class CollectionsTest(PersistTest):
e = creator()
direct.append(e)
control.append(e)
-
+
direct.remove(e)
control.remove(e)
assert_eq()
@@ -204,7 +204,7 @@ class CollectionsTest(PersistTest):
direct[1::2] = values
control[1::2] = values
assert_eq()
-
+
if hasattr(direct, '__delslice__'):
for i in range(1, 4):
e = creator()
@@ -212,7 +212,7 @@ class CollectionsTest(PersistTest):
control.append(e)
del direct[-1:]
- del control[-1:]
+ del control[-1:]
assert_eq()
del direct[1:2]
@@ -321,7 +321,7 @@ class CollectionsTest(PersistTest):
return self.data == other
def __repr__(self):
return 'ListLike(%s)' % repr(self.data)
-
+
self._test_adapter(ListLike)
self._test_list(ListLike)
self._test_list_bulk(ListLike)
@@ -348,7 +348,7 @@ class CollectionsTest(PersistTest):
return self.data == other
def __repr__(self):
return 'ListIsh(%s)' % repr(self.data)
-
+
self._test_adapter(ListIsh)
self._test_list(ListIsh)
self._test_list_bulk(ListIsh)
@@ -382,7 +382,7 @@ class CollectionsTest(PersistTest):
for item in list(direct):
direct.remove(item)
control.clear()
-
+
# assume add() is available for list tests
addall(creator())
@@ -420,17 +420,35 @@ class CollectionsTest(PersistTest):
direct.discard(e)
self.assert_(e not in canary.removed)
assert_eq()
-
+
if hasattr(direct, 'update'):
+ zap()
e = creator()
addall(e)
-
+
values = set([e, creator(), creator()])
direct.update(values)
control.update(values)
assert_eq()
+ if hasattr(direct, '__ior__'):
+ zap()
+ e = creator()
+ addall(e)
+
+ values = set([e, creator(), creator()])
+
+ direct |= values
+ control |= values
+ assert_eq()
+
+ try:
+ direct |= [e, creator()]
+ assert False
+ except TypeError:
+ assert True
+
if hasattr(direct, 'clear'):
addall(creator(), creator())
direct.clear()
@@ -439,6 +457,7 @@ class CollectionsTest(PersistTest):
if hasattr(direct, 'difference_update'):
zap()
+ e = creator()
addall(creator(), creator())
values = set([creator()])
@@ -450,6 +469,26 @@ class CollectionsTest(PersistTest):
control.difference_update(values)
assert_eq()
+ if hasattr(direct, '__isub__'):
+ zap()
+ e = creator()
+ addall(creator(), creator())
+ values = set([creator()])
+
+ direct -= values
+ control -= values
+ assert_eq()
+ values.update(set([e, creator()]))
+ direct -= values
+ control -= values
+ assert_eq()
+
+ try:
+ direct -= [e, creator()]
+ assert False
+ except TypeError:
+ assert True
+
if hasattr(direct, 'intersection_update'):
zap()
e = creator()
@@ -465,6 +504,27 @@ class CollectionsTest(PersistTest):
control.intersection_update(values)
assert_eq()
+ if hasattr(direct, '__iand__'):
+ zap()
+ e = creator()
+ addall(e, creator(), creator())
+ values = set(control)
+
+ direct &= values
+ control &= values
+ assert_eq()
+
+ values.update(set([e, creator()]))
+ direct &= values
+ control &= values
+ assert_eq()
+
+ try:
+ direct &= [e, creator()]
+ assert False
+ except TypeError:
+ assert True
+
if hasattr(direct, 'symmetric_difference_update'):
zap()
e = creator()
@@ -487,6 +547,34 @@ class CollectionsTest(PersistTest):
control.symmetric_difference_update(values)
assert_eq()
+ if hasattr(direct, '__ixor__'):
+ zap()
+ e = creator()
+ addall(e, creator(), creator())
+
+ values = set([e, creator()])
+ direct ^= values
+ control ^= values
+ assert_eq()
+
+ e = creator()
+ addall(e)
+ values = set([e])
+ direct ^= values
+ control ^= values
+ assert_eq()
+
+ values = set()
+ direct ^= values
+ control ^= values
+ assert_eq()
+
+ try:
+ direct ^= [e, creator()]
+ assert False
+ except TypeError:
+ assert True
+
def _test_set_bulk(self, typecallable, creator=entity_maker):
class Foo(object):
pass
@@ -513,7 +601,7 @@ class CollectionsTest(PersistTest):
self.assert_(obj.attr == set([e2]))
self.assert_(e1 in canary.removed)
self.assert_(e2 in canary.added)
-
+
e3 = creator()
real_set = set([e3])
obj.attr = real_set
@@ -521,7 +609,7 @@ class CollectionsTest(PersistTest):
self.assert_(obj.attr == set([e3]))
self.assert_(e2 in canary.removed)
self.assert_(e3 in canary.added)
-
+
e4 = creator()
try:
obj.attr = [e4]
@@ -620,7 +708,7 @@ class CollectionsTest(PersistTest):
for item in list(adapter):
direct.remove(item)
control.clear()
-
+
# assume an 'set' method is available for tests
addall(creator())
@@ -655,7 +743,7 @@ class CollectionsTest(PersistTest):
direct.clear()
control.clear()
assert_eq()
-
+
direct.clear()
control.clear()
assert_eq()
@@ -678,7 +766,7 @@ class CollectionsTest(PersistTest):
zap()
e = creator()
addall(e)
-
+
direct.popitem()
control.popitem()
assert_eq()
@@ -907,7 +995,7 @@ class CollectionsTest(PersistTest):
def _test_object(self, typecallable, creator=entity_maker):
class Foo(object):
pass
-
+
canary = Canary()
attributes.register_class(Foo)
attributes.register_attribute(Foo, 'attr', True, extension=canary,
@@ -933,7 +1021,7 @@ class CollectionsTest(PersistTest):
direct.zark(e)
control.remove(e)
assert_eq()
-
+
e = creator()
direct.maybe_zark(e)
control.discard(e)
@@ -1035,7 +1123,7 @@ class CollectionsTest(PersistTest):
@collection.removes_return()
def pop(self, key):
return self.data.pop()
-
+
@collection.iterator
def __iter__(self):
return iter(self.data)
@@ -1136,14 +1224,14 @@ class CollectionsTest(PersistTest):
col1.append(e3)
self.assert_(e3 not in canary.data)
self.assert_(collections.collection_adapter(col1) is None)
-
+
obj.attr[0] = e3
self.assert_(e3 in canary.data)
class DictHelpersTest(ORMTest):
def define_tables(self, metadata):
global parents, children, Parent, Child
-
+
parents = Table('parents', metadata,
Column('id', Integer, primary_key=True),
Column('label', String))
@@ -1170,7 +1258,7 @@ class DictHelpersTest(ORMTest):
'children': relation(Child, collection_class=collection_class,
cascade="all, delete-orphan")
})
-
+
p = Parent()
p.children['foo'] = Child('foo', 'value')
p.children['bar'] = Child('bar', 'value')
@@ -1187,15 +1275,15 @@ class DictHelpersTest(ORMTest):
collections.collection_adapter(p.children).append_with_event(
Child('foo', 'newvalue'))
-
+
session.flush()
session.clear()
-
+
p = session.query(Parent).get(pid)
-
+
self.assert_(set(p.children.keys()) == set(['foo', 'bar']))
self.assert_(p.children['foo'].id != cid)
-
+
self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
session.flush()
session.clear()
@@ -1205,7 +1293,7 @@ class DictHelpersTest(ORMTest):
collections.collection_adapter(p.children).remove_with_event(
p.children['foo'])
-
+
self.assert_(len(list(collections.collection_adapter(p.children))) == 1)
session.flush()
session.clear()
@@ -1220,7 +1308,7 @@ class DictHelpersTest(ORMTest):
p = session.query(Parent).get(pid)
self.assert_(len(list(collections.collection_adapter(p.children))) == 0)
-
+
def _test_composite_mapped(self, collection_class):
mapper(Child, children)
@@ -1228,7 +1316,7 @@ class DictHelpersTest(ORMTest):
'children': relation(Child, collection_class=collection_class,
cascade="all, delete-orphan")
})
-
+
p = Parent()
p.children[('foo', '1')] = Child('foo', '1', 'value 1')
p.children[('foo', '2')] = Child('foo', '2', 'value 2')
@@ -1238,7 +1326,7 @@ class DictHelpersTest(ORMTest):
session.flush()
pid = p.id
session.clear()
-
+
p = session.query(Parent).get(pid)
self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
@@ -1246,17 +1334,17 @@ class DictHelpersTest(ORMTest):
collections.collection_adapter(p.children).append_with_event(
Child('foo', '1', 'newvalue'))
-
+
session.flush()
session.clear()
-
+
p = session.query(Parent).get(pid)
-
+
self.assert_(set(p.children.keys()) == set([('foo', '1'), ('foo', '2')]))
self.assert_(p.children[('foo', '1')].id != cid)
-
+
self.assert_(len(list(collections.collection_adapter(p.children))) == 2)
-
+
def test_mapped_collection(self):
collection_class = collections.mapped_collection(lambda c: c.a)
self._test_scalar_mapped(collection_class)