summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py110
1 files changed, 80 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5c5bfad55..4448f7c7b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -28,6 +28,7 @@ from . import schema, sqltypes, operators, functions, \
from .. import util, exc
import decimal
import itertools
+import operator
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -1771,7 +1772,7 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, extra_froms, **kw)
+ colparams = self._get_colparams(update_stmt, **kw)
if update_stmt._hints:
dialect_hints = dict([
@@ -1840,7 +1841,40 @@ class SQLCompiler(Compiled):
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
- def _get_colparams(self, stmt, extra_tables=None, **kw):
+ @util.memoized_property
+ def _key_getters_for_crud_column(self):
+ if self.isupdate and self.statement._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(self.statement._extra_froms)
+ def _column_as_key(key):
+ str_key = elements._column_as_key(key)
+ if hasattr(key, 'table') and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = elements._column_as_key
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
+
+ def _get_colparams(self, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -1869,12 +1903,18 @@ class SQLCompiler(Compiled):
else:
stmt_parameters = stmt.parameters
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ _column_as_key, _getattr_col_key, _col_bind_name = \
+ self._key_getters_for_crud_column
+
# if we have statement parameters - set defaults in the
# compiled params
if self.column_keys is None:
parameters = {}
else:
- parameters = dict((elements._column_as_key(key), REQUIRED)
+ parameters = dict((_column_as_key(key), REQUIRED)
for key in self.column_keys
if not stmt_parameters or
key not in stmt_parameters)
@@ -1884,7 +1924,7 @@ class SQLCompiler(Compiled):
if stmt_parameters is not None:
for k, v in stmt_parameters.items():
- colkey = elements._column_as_key(k)
+ colkey = _column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
else:
@@ -1892,7 +1932,9 @@ class SQLCompiler(Compiled):
# add it to values() in an "as-is" state,
# coercing right side to bound param
if elements._is_literal(v):
- v = self.process(elements.BindParameter(None, v, type_=k.type), **kw)
+ v = self.process(
+ elements.BindParameter(None, v, type_=k.type),
+ **kw)
else:
v = self.process(v.self_group(), **kw)
@@ -1922,24 +1964,25 @@ class SQLCompiler(Compiled):
postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
check_columns = {}
+
# special logic that only occurs for multi-table UPDATE
# statements
- if extra_tables and stmt_parameters:
+ if self.isupdate and stmt._extra_froms and stmt_parameters:
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
for c, param in stmt_parameters.items()
)
- assert self.isupdate
affected_tables = set()
- for t in extra_tables:
+ for t in stmt._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
- check_columns[c.key] = c
+ check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
if elements._is_literal(value):
value = self._create_crud_bind_param(
- c, value, required=value is REQUIRED)
+ c, value, required=value is REQUIRED,
+ name=_col_bind_name(c))
else:
self.postfetch.append(c)
value = self.process(value.self_group(), **kw)
@@ -1954,12 +1997,18 @@ class SQLCompiler(Compiled):
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group(), **kw))
+ (c, self.process(
+ c.onupdate.arg.self_group(),
+ **kw)
+ )
)
self.postfetch.append(c)
else:
values.append(
- (c, self._create_crud_bind_param(c, None))
+ (c, self._create_crud_bind_param(
+ c, None, name=_col_bind_name(c)
+ )
+ )
)
self.prefetch.append(c)
elif c.server_onupdate is not None:
@@ -1968,7 +2017,7 @@ class SQLCompiler(Compiled):
if self.isinsert and stmt.select_names:
# for an insert from select, we can only use names that
# are given, so only select for those names.
- cols = (stmt.table.c[elements._column_as_key(name)]
+ cols = (stmt.table.c[_column_as_key(name)]
for name in stmt.select_names)
else:
# iterate through all table columns to maintain
@@ -1976,14 +2025,15 @@ class SQLCompiler(Compiled):
cols = stmt.table.columns
for c in cols:
- if c.key in parameters and c.key not in check_columns:
- value = parameters.pop(c.key)
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ value = parameters.pop(col_key)
if elements._is_literal(value):
value = self._create_crud_bind_param(
c, value, required=value is REQUIRED,
- name=c.key
+ name=_col_bind_name(c)
if not stmt._has_multi_parameters
- else "%s_0" % c.key
+ else "%s_0" % _col_bind_name(c)
)
else:
if isinstance(value, elements.BindParameter) and \
@@ -2119,12 +2169,12 @@ class SQLCompiler(Compiled):
if parameters and stmt_parameters:
check = set(parameters).intersection(
- elements._column_as_key(k) for k in stmt.parameters
+ _column_as_key(k) for k in stmt.parameters
).difference(check_columns)
if check:
raise exc.CompileError(
"Unconsumed column names: %s" %
- (", ".join(check))
+ (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
@@ -2133,17 +2183,17 @@ class SQLCompiler(Compiled):
values.extend(
[
- (
- c,
- self._create_crud_bind_param(
- c, row[c.key],
- name="%s_%d" % (c.key, i + 1)
- )
- if c.key in row else param
- )
- for (c, param) in values_0
- ]
- for i, row in enumerate(stmt.parameters[1:])
+ (
+ c,
+ self._create_crud_bind_param(
+ c, row[c.key],
+ name="%s_%d" % (c.key, i + 1)
+ )
+ if c.key in row else param
+ )
+ for (c, param) in values_0
+ ]
+ for i, row in enumerate(stmt.parameters[1:])
)
return values