diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 110 |
1 files changed, 80 insertions, 30 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5c5bfad55..4448f7c7b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -28,6 +28,7 @@ from . import schema, sqltypes, operators, functions, \ from .. import util, exc import decimal import itertools +import operator RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -1771,7 +1772,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, extra_froms, **kw) + colparams = self._get_colparams(update_stmt, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1840,7 +1841,40 @@ class SQLCompiler(Compiled): bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None, **kw): + @util.memoized_property + def _key_getters_for_crud_column(self): + if self.isupdate and self.statement._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(self.statement._extra_froms) + def _column_as_key(key): + str_key = elements._column_as_key(key) + if hasattr(key, 'table') and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = elements._column_as_key + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + def _get_colparams(self, stmt, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1869,12 +1903,18 @@ class SQLCompiler(Compiled): else: stmt_parameters = stmt.parameters + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + _column_as_key, _getattr_col_key, _col_bind_name = \ + self._key_getters_for_crud_column + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: - parameters = dict((elements._column_as_key(key), REQUIRED) + parameters = dict((_column_as_key(key), REQUIRED) for key in self.column_keys if not stmt_parameters or key not in stmt_parameters) @@ -1884,7 +1924,7 @@ class SQLCompiler(Compiled): if stmt_parameters is not None: for k, v in stmt_parameters.items(): - colkey = elements._column_as_key(k) + colkey = _column_as_key(k) if colkey is not None: parameters.setdefault(colkey, v) else: @@ -1892,7 +1932,9 @@ class SQLCompiler(Compiled): # add it to values() in an "as-is" state, # coercing right side to bound param if elements._is_literal(v): - v = self.process(elements.BindParameter(None, v, type_=k.type), **kw) + v = self.process( + elements.BindParameter(None, v, type_=k.type), + **kw) else: v = self.process(v.self_group(), **kw) @@ -1922,24 +1964,25 @@ class SQLCompiler(Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} + # special logic that only occurs for multi-table UPDATE # statements - if extra_tables and stmt_parameters: + if self.isupdate and stmt._extra_froms and stmt_parameters: normalized_params = dict( (elements._clause_element_as_expr(c), param) for c, param in stmt_parameters.items() ) - assert self.isupdate affected_tables = set() - for t in extra_tables: + for t in stmt._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) - check_columns[c.key] = c + check_columns[_getattr_col_key(c)] = c value = normalized_params[c] if elements._is_literal(value): value = self._create_crud_bind_param( - c, value, required=value is REQUIRED) + c, value, required=value is REQUIRED, + name=_col_bind_name(c)) else: self.postfetch.append(c) value = self.process(value.self_group(), **kw) @@ -1954,12 +1997,18 @@ class SQLCompiler(Compiled): elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group(), **kw)) + (c, self.process( + c.onupdate.arg.self_group(), + **kw) + ) ) self.postfetch.append(c) else: values.append( - (c, self._create_crud_bind_param(c, None)) + (c, self._create_crud_bind_param( + c, None, name=_col_bind_name(c) + ) + ) ) self.prefetch.append(c) elif c.server_onupdate is not None: @@ -1968,7 +2017,7 @@ class SQLCompiler(Compiled): if self.isinsert and stmt.select_names: # for an insert from select, we can only use names that # are given, so only select for those names. - cols = (stmt.table.c[elements._column_as_key(name)] + cols = (stmt.table.c[_column_as_key(name)] for name in stmt.select_names) else: # iterate through all table columns to maintain @@ -1976,14 +2025,15 @@ class SQLCompiler(Compiled): cols = stmt.table.columns for c in cols: - if c.key in parameters and c.key not in check_columns: - value = parameters.pop(c.key) + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + value = parameters.pop(col_key) if elements._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is REQUIRED, - name=c.key + name=_col_bind_name(c) if not stmt._has_multi_parameters - else "%s_0" % c.key + else "%s_0" % _col_bind_name(c) ) else: if isinstance(value, elements.BindParameter) and \ @@ -2119,12 +2169,12 @@ class SQLCompiler(Compiled): if parameters and stmt_parameters: check = set(parameters).intersection( - elements._column_as_key(k) for k in stmt.parameters + _column_as_key(k) for k in stmt.parameters ).difference(check_columns) if check: raise exc.CompileError( "Unconsumed column names: %s" % - (", ".join(check)) + (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -2133,17 +2183,17 @@ class SQLCompiler(Compiled): values.extend( [ - ( - c, - self._create_crud_bind_param( - c, row[c.key], - name="%s_%d" % (c.key, i + 1) - ) - if c.key in row else param - ) - for (c, param) in values_0 - ] - for i, row in enumerate(stmt.parameters[1:]) + ( + c, + self._create_crud_bind_param( + c, row[c.key], + name="%s_%d" % (c.key, i + 1) + ) + if c.key in row else param + ) + for (c, param) in values_0 + ] + for i, row in enumerate(stmt.parameters[1:]) ) return values |