summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/mapper.py86
-rw-r--r--lib/sqlalchemy/util.py2
-rw-r--r--test/mapper.py28
3 files changed, 83 insertions, 33 deletions
diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py
index 40968412d..d117553f3 100644
--- a/lib/sqlalchemy/mapper.py
+++ b/lib/sqlalchemy/mapper.py
@@ -231,6 +231,9 @@ class Mapper(object):
def foo():
for table in self.tables:
params = {}
+ # TODO: prepare the insert() and update() - (1) within the code or
+ # (2) as a real prepared statement, just once, and put them somewhere for
+ # an external loop to grab onto them
for primary_key in table.primary_keys:
if self._getattrbycolumn(obj, primary_key) is None:
statement = table.insert()
@@ -430,37 +433,45 @@ class PropertyLoader(MapperProperty):
self.primaryjoin = match_primaries(parent.selectable, self.target)
def save(self, obj, traverse, refetch):
- # if a mapping table does not exist, save a row for all objects
- # in our list normally, setting their primary keys
- # else, determine the foreign key column in our table, set it to the parent
- # of all child objects before saving
- # if a mapping table exists, determine the two foreign key columns
- # in the mapping table, set the two values, and insert that row, for
- # each row in the list
- if self.secondary is None:
- setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, obj)
- childlist = getattr(obj, self.key)
- if not isinstance(childlist, util.HistoryArraySet):
- childlist = util.HistoryArraySet(childlist)
- clean_setattr(obj, self.key, childlist)
- for child in childlist.added_items():
- setter.child = child
- self.primaryjoin.accept_visitor(setter)
- child.dirty = True
- for child in childlist.deleted_items():
- setter.child = child
- setter.clearkeys = True
- self.primaryjoin.accept_visitor(setter)
- child.dirty = True
- self.mapper.save(child)
- for child in childlist:
- self.mapper.save(child)
- # TODO: if transaction fails state is invalid
- # use unit of work ?
- childlist.clear_history()
- else:
- raise "TODO"
+ # saves child objects
+
+ # TODO: put association table inserts/deletes into one batch
+ #if self.secondary is not None:
+ # secondary_delete = self.secondary.delete(sql.and_([c == bindparam(c.key) for c in setter.secondary.c]))
+
+ setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj)
+ childlist = getattr(obj, self.key)
+ if not isinstance(childlist, util.HistoryArraySet):
+ childlist = util.HistoryArraySet(childlist)
+ clean_setattr(obj, self.key, childlist)
+ for child in childlist.deleted_items():
+ setter.child = child
+ setter.clearkeys = True
+ self.primaryjoin.accept_visitor(setter)
+ child.dirty = True
+ self.mapper.save(child)
+ if self.secondary is not None:
+ self.secondaryjoin.accept_visitor(setter)
+ # TODO: prepare this above
+ statement = self.secondary.delete(sql.and_(*[c == setter.associationrow[c.key] for c in self.secondary.c]))
+ statement.echo = self.mapper.echo
+ statement.execute()
+ for child in childlist.added_items():
+ setter.child = child
+ self.primaryjoin.accept_visitor(setter)
+ child.dirty = True
self.mapper.save(child)
+ if self.secondary is not None:
+ self.secondaryjoin.accept_visitor(setter)
+ # TODO: prepare this above
+ statement = self.secondary.insert()
+ statement.echo = self.mapper.echo
+ statement.execute(**setter.associationrow)
+ for child in childlist.unchanged_items():
+ self.mapper.save(child)
+ # TODO: if transaction fails state is invalid
+ # use unit of work ?
+ childlist.clear_history()
def delete(self):
@@ -615,17 +626,20 @@ class TableFinder(sql.ClauseVisitor):
self.tables.append(table)
class ForeignKeySetter(sql.ClauseVisitor):
- def __init__(self, parentmapper, childmapper, primarytable, secondarytable, obj):
+ def __init__(self, parentmapper, childmapper, primarytable, secondarytable, associationtable, obj):
self.parentmapper = parentmapper
self.childmapper = childmapper
self.primarytable = primarytable
self.secondarytable = secondarytable
+ self.associationtable = associationtable
self.obj = obj
+ self.associationrow = {}
self.clearkeys = False
self.child = None
def visit_binary(self, binary):
if binary.operator == '=':
+ # TODO: this code is silly
if binary.left.table == self.primarytable and binary.right.table == self.secondarytable:
if self.clearkeys:
self.childmapper._setattrbycolumn(self.child, binary.right, None)
@@ -636,7 +650,15 @@ class ForeignKeySetter(sql.ClauseVisitor):
self.childmapper._setattrbycolumn(self.child, binary.left, None)
else:
self.childmapper._setattrbycolumn(self.child, binary.left, self.parentmapper._getattrbycolumn(self.obj, binary.right))
-
+ elif binary.right.table == self.associationtable and binary.left.table == self.primarytable:
+ self.associationrow[binary.right.key] = self.parentmapper._getattrbycolumn(self.obj, binary.left)
+ elif binary.left.table == self.associationtable and binary.right.table == self.primarytable:
+ self.associationrow[binary.left.key] = self.parentmapper._getattrbycolumn(self.obj, binary.right)
+ elif binary.right.table == self.associationtable and binary.left.table == self.secondarytable:
+ self.associationrow[binary.right.key] = self.childmapper._getattrbycolumn(self.child, binary.left)
+ elif binary.left.table == self.associationtable and binary.right.table == self.secondarytable:
+ self.associationrow[binary.left.key] = self.childmapper._getattrbycolumn(self.child, binary.right)
+
class LazyIzer(sql.ClauseVisitor):
"""converts an expression which refers to a table column into an
expression refers to a Bind Param, i.e. a specific value.
diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py
index 1c86bf8c3..dd6d03045 100644
--- a/lib/sqlalchemy/util.py
+++ b/lib/sqlalchemy/util.py
@@ -157,6 +157,8 @@ class HistoryArraySet(UserList.UserList):
return [key for key, value in self.records.iteritems() if value is True]
def deleted_items(self):
return [key for key, value in self.records.iteritems() if value is False]
+ def unchanged_items(self):
+ return [key for key, value in self.records.iteritems() if value is None]
def append_nohistory(self, item):
if not self.records.has_key(item):
self.records[item] = None
diff --git a/test/mapper.py b/test/mapper.py
index ccd2a9fe4..7f129cf94 100644
--- a/test/mapper.py
+++ b/test/mapper.py
@@ -375,6 +375,32 @@ class SaveTest(PersistTest):
addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(a.address_id, a2.address_id)).execute()).fetchall()
self.assert_(addresstable[0].row == (a.address_id, u.user_id, 'one2many@test.org'))
self.assert_(addresstable[1].row == (a2.address_id, None, 'lala@test.org'))
-
+
+ def testmanytomany(self):
+ items = orderitems
+
+ m = mapper(Item, items, properties = dict(
+ keywords = relation(Keyword, keywords, itemkeywords, lazy = False),
+ ), echo = True)
+
+ keywordmapper = mapper(Keyword, keywords)
+
+ item = Item()
+ item.item_name = 'item1'
+ item.keywords = []
+ k = Keyword()
+ k.name = 'purple'
+ item.keywords.append(k)
+ klist = keywordmapper.select(keywords.c.name.in_('blue', 'big', 'round'))
+ for k in klist:
+ item.keywords.append(k)
+ m.save(item)
+ print repr(m.select(items.c.item_id == item.item_id))
+
+ del item.keywords[2]
+ del item.keywords[2]
+ m.save(item)
+ print repr(m.select(items.c.item_id == item.item_id))
+
if __name__ == "__main__":
unittest.main()