diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-11-21 22:00:50 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-11-21 22:00:50 -0500 |
commit | ef79d1ae3b404780d17e8615426eeb39be1ac670 (patch) | |
tree | fcb916d2e1bf7123a87c93b7a426186f1d5e63e4 /lib/sqlalchemy/sql/compiler.py | |
parent | 0c3a53d433d0adddfd16831380f8aea5d1fad176 (diff) | |
download | sqlalchemy-ef79d1ae3b404780d17e8615426eeb39be1ac670.tar.gz |
passes for all three, includes multi col system with mysql
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b77591912..92c0c7b38 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -994,8 +994,7 @@ class SQLCompiler(engine.Compiled): def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms) - def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms): - return False + render_table_with_column_in_update = False def visit_update(self, update_stmt, **kw): self.stack.append({'from': set([update_stmt.table])}) @@ -1014,9 +1013,12 @@ class SQLCompiler(engine.Compiled): # if hasattr(c[1], '_from_objects'): # extra_froms.update(c[1]._from_objects) - text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) + text = "UPDATE " + self.update_tables_clause( + update_stmt, + update_stmt.table, + extra_froms, **kw) - if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms): + if extra_froms and self.render_table_with_column_in_update: text += ' SET ' + \ ', '.join( self.visit_column(c[0]) + @@ -1038,7 +1040,10 @@ class SQLCompiler(engine.Compiled): update_stmt, update_stmt._returning) if extra_froms: - extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw) + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + extra_froms, **kw) if extra_from_text: text += " " + extra_from_text @@ -1104,6 +1109,7 @@ class SQLCompiler(engine.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(sql._column_as_key(k), v) + # create a list of column assignment clauses as tuples values = [] @@ -1117,11 +1123,31 @@ class SQLCompiler(engine.Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid + check_columns = {} + if extra_tables and stmt.parameters: + for t in extra_tables: + for c in t.c: + if c in stmt.parameters: + check_columns[c.key] = c + + for c in check_columns.values(): + value = stmt.parameters[c] + if sql._is_literal(value): + value = self._create_crud_bind_param( + c, value, required=value is required) + elif c.primary_key and implicit_returning: + self.returning.append(c) + value = self.process(value.self_group()) + else: + self.postfetch.append(c) + value = self.process(value.self_group()) + values.append((c, value)) + # iterating through columns at the top to maintain ordering. # otherwise we might iterate through individual sets of # "defaults", "primary key cols", etc. for c in stmt.table.columns: - if c.key in parameters: + if c.key in parameters and c.key not in check_columns: value = parameters[c.key] if sql._is_literal(value): value = self._create_crud_bind_param( |