summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-11-21 20:40:31 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2011-11-21 20:40:31 -0500
commit0c3a53d433d0adddfd16831380f8aea5d1fad176 (patch)
tree0c8dad33b2d636ef119a6feb4ffa0c0f3afd8e61 /lib/sqlalchemy/sql/compiler.py
parentc0c42af4e0ef8acd651cc66e84ec636c14ab53a5 (diff)
downloadsqlalchemy-0c3a53d433d0adddfd16831380f8aea5d1fad176.tar.gz
sort of muscling this out, mysql a PITA
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py54
1 files changed, 47 insertions, 7 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8d7f2aab9..b77591912 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -985,15 +985,46 @@ class SQLCompiler(engine.Compiled):
return text
- def visit_update(self, update_stmt):
+ def update_limit_clause(self, update_stmt):
+ return None
+
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return self.preparer.format_table(from_table)
+
+ def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms)
+
+ def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
+ return False
+
+ def visit_update(self, update_stmt, **kw):
self.stack.append({'from': set([update_stmt.table])})
self.isupdate = True
- colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+ if update_stmt._whereclause is not None:
+ extra_froms = set(update_stmt._whereclause._from_objects).\
+ difference([update_stmt.table])
+ else:
+ extra_froms = set()
+
+ colparams = self._get_colparams(update_stmt, extra_froms)
+
+ #for c in colparams:
+ # if hasattr(c[1], '_from_objects'):
+ # extra_froms.update(c[1]._from_objects)
- text += ' SET ' + \
+ text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+
+ if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms):
+ text += ' SET ' + \
+ ', '.join(
+ self.visit_column(c[0]) +
+ '=' + c[1]
+ for c in colparams
+ )
+ else:
+ text += ' SET ' + \
', '.join(
self.preparer.quote(c[0].name, c[0].quote) +
'=' + c[1]
@@ -1006,9 +1037,18 @@ class SQLCompiler(engine.Compiled):
text += " " + self.returning_clause(
update_stmt, update_stmt._returning)
+ if extra_froms:
+ extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+ if extra_from_text:
+ text += " " + extra_from_text
+
if update_stmt._whereclause is not None:
text += " WHERE " + self.process(update_stmt._whereclause)
+ limit_clause = self.update_limit_clause(update_stmt)
+ if limit_clause:
+ text += " " + limit_clause
+
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt, update_stmt._returning)
@@ -1024,7 +1064,7 @@ class SQLCompiler(engine.Compiled):
return bindparam._compiler_dispatch(self)
- def _get_colparams(self, stmt):
+ def _get_colparams(self, stmt, extra_tables=None):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -1100,7 +1140,7 @@ class SQLCompiler(engine.Compiled):
(
implicit_returning or
not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
+ c is not t._autoincrement_column
):
if implicit_returning:
@@ -1127,7 +1167,7 @@ class SQLCompiler(engine.Compiled):
self.returning.append(c)
else:
if c.default is not None or \
- c is stmt.table._autoincrement_column and (
+ c is t._autoincrement_column and (
self.dialect.supports_sequences or
self.dialect.preexecute_autoincrement_sequences
):