diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-11-21 20:40:31 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2011-11-21 20:40:31 -0500 |
commit | 0c3a53d433d0adddfd16831380f8aea5d1fad176 (patch) | |
tree | 0c8dad33b2d636ef119a6feb4ffa0c0f3afd8e61 /lib/sqlalchemy/sql/compiler.py | |
parent | c0c42af4e0ef8acd651cc66e84ec636c14ab53a5 (diff) | |
download | sqlalchemy-0c3a53d433d0adddfd16831380f8aea5d1fad176.tar.gz |
sort of muscling this out, mysql a PITA
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 54 |
1 files changed, 47 insertions, 7 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8d7f2aab9..b77591912 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -985,15 +985,46 @@ class SQLCompiler(engine.Compiled): return text - def visit_update(self, update_stmt): + def update_limit_clause(self, update_stmt): + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return self.preparer.format_table(from_table) + + def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms) + + def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms): + return False + + def visit_update(self, update_stmt, **kw): self.stack.append({'from': set([update_stmt.table])}) self.isupdate = True - colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + if update_stmt._whereclause is not None: + extra_froms = set(update_stmt._whereclause._from_objects).\ + difference([update_stmt.table]) + else: + extra_froms = set() + + colparams = self._get_colparams(update_stmt, extra_froms) + + #for c in colparams: + # if hasattr(c[1], '_from_objects'): + # extra_froms.update(c[1]._from_objects) - text += ' SET ' + \ + text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) + + if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms): + text += ' SET ' + \ + ', '.join( + self.visit_column(c[0]) + + '=' + c[1] + for c in colparams + ) + else: + text += ' SET ' + \ ', '.join( self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] @@ -1006,9 +1037,18 @@ class SQLCompiler(engine.Compiled): text += " " + self.returning_clause( update_stmt, update_stmt._returning) + if extra_froms: + extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw) + if extra_from_text: + text += " " + extra_from_text + if update_stmt._whereclause is not None: text += " WHERE " + self.process(update_stmt._whereclause) + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, update_stmt._returning) @@ -1024,7 +1064,7 @@ class SQLCompiler(engine.Compiled): return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt): + def _get_colparams(self, stmt, extra_tables=None): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1100,7 +1140,7 @@ class SQLCompiler(engine.Compiled): ( implicit_returning or not postfetch_lastrowid or - c is not stmt.table._autoincrement_column + c is not t._autoincrement_column ): if implicit_returning: @@ -1127,7 +1167,7 @@ class SQLCompiler(engine.Compiled): self.returning.append(c) else: if c.default is not None or \ - c is stmt.table._autoincrement_column and ( + c is t._autoincrement_column and ( self.dialect.supports_sequences or self.dialect.preexecute_autoincrement_sequences ): |