diff options
-rw-r--r-- | CHANGES | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 10 | ||||
-rw-r--r-- | test/orm/test_update_delete.py | 13 |
3 files changed, 23 insertions, 7 deletions
@@ -142,10 +142,9 @@ underneath "0.7.xx". and if the parent table is referenced in the WHERE clause, the compiler will call upon UPDATE..FROM syntax as allowed by the dialect - to satisfy the WHERE clause. Target columns - must still be in the target table i.e. - does not support MySQL's multi-table update - feature (even though this is in Core). + to satisfy the WHERE clause. MySQL's multi-table + update feature is also supported if columns + are specified by object in the "values" dicitionary. PG's DELETE..USING is also not available in Core yet. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c56b7fc37..fd9718f1f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1446,14 +1446,18 @@ class SQLCompiler(engine.Compiled): # special logic that only occurs for multi-table UPDATE # statements if extra_tables and stmt.parameters: + normalized_params = dict( + (sql._clause_element_as_expr(c), param) + for c, param in stmt.parameters.items() + ) assert self.isupdate affected_tables = set() for t in extra_tables: for c in t.c: - if c in stmt.parameters: + if c in normalized_params: affected_tables.add(t) check_columns[c.key] = c - value = stmt.parameters[c] + value = normalized_params[c] if sql._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is required) @@ -1466,7 +1470,7 @@ class SQLCompiler(engine.Compiled): # server_onupdate for these for t in affected_tables: for c in t.c: - if c in stmt.parameters: + if c in normalized_params: continue elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index e6a429c90..e259c5229 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -642,3 +642,16 @@ class InheritTest(fixtures.DeclarativeMappedTest): set([('e1', 'e1', ), ('e2', 'e5')]) ) + @testing.only_on('mysql', 'Multi table update') + def test_update_from_multitable(self): + Engineer = self.classes.Engineer + Person = self.classes.Person + s = Session(testing.db) + s.query(Engineer).filter(Engineer.id == Person.id).\ + filter(Person.name == 'e2').update({Person.name: 'e22', + Engineer.engineer_name: 'e55'}) + + eq_( + set(s.query(Person.name, Engineer.engineer_name)), + set([('e1', 'e1', ), ('e22', 'e55')]) + ) |