summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2011-11-21 22:00:50 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2011-11-21 22:00:50 -0500
commitef79d1ae3b404780d17e8615426eeb39be1ac670 (patch)
treefcb916d2e1bf7123a87c93b7a426186f1d5e63e4 /lib/sqlalchemy/sql/compiler.py
parent0c3a53d433d0adddfd16831380f8aea5d1fad176 (diff)
downloadsqlalchemy-ef79d1ae3b404780d17e8615426eeb39be1ac670.tar.gz
passes for all three, includes multi col system with mysql
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py38
1 files changed, 32 insertions, 6 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b77591912..92c0c7b38 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -994,8 +994,7 @@ class SQLCompiler(engine.Compiled):
def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms)
- def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
- return False
+ render_table_with_column_in_update = False
def visit_update(self, update_stmt, **kw):
self.stack.append({'from': set([update_stmt.table])})
@@ -1014,9 +1013,12 @@ class SQLCompiler(engine.Compiled):
# if hasattr(c[1], '_from_objects'):
# extra_froms.update(c[1]._from_objects)
- text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+ text = "UPDATE " + self.update_tables_clause(
+ update_stmt,
+ update_stmt.table,
+ extra_froms, **kw)
- if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms):
+ if extra_froms and self.render_table_with_column_in_update:
text += ' SET ' + \
', '.join(
self.visit_column(c[0]) +
@@ -1038,7 +1040,10 @@ class SQLCompiler(engine.Compiled):
update_stmt, update_stmt._returning)
if extra_froms:
- extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+ extra_from_text = self.update_from_clause(
+ update_stmt,
+ update_stmt.table,
+ extra_froms, **kw)
if extra_from_text:
text += " " + extra_from_text
@@ -1104,6 +1109,7 @@ class SQLCompiler(engine.Compiled):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(sql._column_as_key(k), v)
+
# create a list of column assignment clauses as tuples
values = []
@@ -1117,11 +1123,31 @@ class SQLCompiler(engine.Compiled):
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
+ check_columns = {}
+ if extra_tables and stmt.parameters:
+ for t in extra_tables:
+ for c in t.c:
+ if c in stmt.parameters:
+ check_columns[c.key] = c
+
+ for c in check_columns.values():
+ value = stmt.parameters[c]
+ if sql._is_literal(value):
+ value = self._create_crud_bind_param(
+ c, value, required=value is required)
+ elif c.primary_key and implicit_returning:
+ self.returning.append(c)
+ value = self.process(value.self_group())
+ else:
+ self.postfetch.append(c)
+ value = self.process(value.self_group())
+ values.append((c, value))
+
# iterating through columns at the top to maintain ordering.
# otherwise we might iterate through individual sets of
# "defaults", "primary key cols", etc.
for c in stmt.table.columns:
- if c.key in parameters:
+ if c.key in parameters and c.key not in check_columns:
value = parameters[c.key]
if sql._is_literal(value):
value = self._create_crud_bind_param(