diff options
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 193 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 10 | ||||
-rw-r--r-- | test/sql/test_update.py | 68 |
4 files changed, 234 insertions, 56 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8e273f67c..542bf58ac 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3286,7 +3286,7 @@ class SQLCompiler(Compiled): if crud_params_single or not supports_default_values: text += " (%s)" % ", ".join( - [preparer.format_column(c[0]) for c in crud_params_single] + [expr for c, expr, value in crud_params_single] ) if self.returning or insert_stmt._returning: @@ -3311,12 +3311,15 @@ class SQLCompiler(Compiled): elif compile_state._has_multi_parameters: text += " VALUES %s" % ( ", ".join( - "(%s)" % (", ".join(c[1] for c in crud_param_set)) + "(%s)" + % (", ".join(value for c, expr, value in crud_param_set)) for crud_param_set in crud_params ) ) else: - insert_single_values_expr = ", ".join([c[1] for c in crud_params]) + insert_single_values_expr = ", ".join( + [value for c, expr, value in crud_params] + ) text += " VALUES (%s)" % insert_single_values_expr if toplevel: self.insert_single_values_expr = insert_single_values_expr @@ -3424,15 +3427,7 @@ class SQLCompiler(Compiled): text += table_text text += " SET " - include_table = ( - is_multitable and self.render_table_with_column_in_update_from - ) - text += ", ".join( - c[0]._compiler_dispatch(self, include_table=include_table) - + "=" - + c[1] - for c in crud_params - ) + text += ", ".join(expr + "=" + value for c, expr, value in crud_params) if self.returning or update_stmt._returning: if self.returning_precedes_values: diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 85112f850..7d0616da7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: return [ - (c, _create_bind_param(compiler, c, None, required=True)) + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) for c in stmt.table.columns ] @@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ) check_columns = {} # special logic that only occurs for multi-table UPDATE # statements - if ( - compile_state.isupdate - and compile_state._extra_froms - and stmt_parameters - ): + if compile_state.isupdate and compile_state.is_multitable: _get_multitable_params( compiler, stmt, @@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # into INSERT (firstcol) VALUES (DEFAULT) which can be turned # into an in-place multi values. This supports # insert_executemany_returning mode :) - values = [(stmt.table.columns[0], "DEFAULT")] + values = [ + ( + stmt.table.columns[0], + compiler.preparer.format_column(stmt.table.columns[0]), + "DEFAULT", + ) + ] return values @@ -286,7 +298,7 @@ def _scan_insert_from_select_cols( col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: parameters.pop(col_key) - values.append((c, None)) + values.append((c, compiler.preparer.format_column(c), None)) else: _append_param_insert_select_hasdefault( compiler, stmt, c, add_select_cols, kw @@ -297,7 +309,7 @@ def _scan_insert_from_select_cols( compiler._insert_from_select = compiler._insert_from_select._generate() compiler._insert_from_select._raw_columns = tuple( compiler._insert_from_select._raw_columns - ) + tuple(expr for col, expr in add_select_cols) + ) + tuple(expr for col, col_expr, expr in add_select_cols) def _scan_cols( @@ -390,7 +402,13 @@ def _scan_cols( elif compile_state.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, ) @@ -410,6 +428,10 @@ def _append_param_parameter( value = parameters.pop(col_key) + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -446,7 +468,7 @@ def _append_param_parameter( if not c.primary_key: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): @@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process(c.default.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) ) compiler.returning.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): _warn_pk_with_no_anticipated_value(c) -def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_insert_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.insert_prefetch.append(c) return param -def _create_update_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_update_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.update_prefetch.append(c) return param @@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): else: col = _multiparam_column(c, index) if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col) + return _create_insert_prefetch_bind_param(compiler, col, **kw) else: - return _create_update_prefetch_bind_param(compiler, col) + return _create_update_prefetch_bind_param(compiler, col, **kw) def _append_param_insert_pk(compiler, stmt, c, values, kw): @@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): or compiler.dialect.preexecute_autoincrement_sequences ) ): - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -597,15 +648,25 @@ def _append_param_insert_hasdefault( if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: - proc = compiler.process(c.default.arg.self_group(), **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -613,7 +674,13 @@ def _append_param_insert_hasdefault( # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): @@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = c.default - values.append((c, proc.next_value())) + values.append( + (c, compiler.preparer.format_column(c), c.default.next_value()) + ) elif c.default.is_clause_element: - proc = c.default.arg.self_group() - values.append((c, proc)) + values.append( + (c, compiler.preparer.format_column(c), c.default.arg.self_group()) + ) else: values.append( - (c, _create_insert_prefetch_bind_param(compiler, c, process=False)) + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + ) ) def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw ): + include_table = compile_state.include_table_with_column_exprs if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process(c.onupdate.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + ) ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append((c, _create_update_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -676,6 +766,9 @@ def _get_multitable_params( (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) + + include_table = compile_state.include_table_with_column_exprs + affected_tables = set() for t in compile_state._extra_froms: for c in t.c: @@ -683,6 +776,8 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -699,7 +794,7 @@ def _get_multitable_params( else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: @@ -711,6 +806,7 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), compiler.process( c.onupdate.arg.self_group(), **kw ), @@ -721,8 +817,9 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c) + compiler, c, name=_col_bind_name(c), **kw ), ) ) @@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] - for (col, param) in values_0: + for (col, col_expr, param) in values_0: if col in row or col.key in row: key = col if col in row else col.key @@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): compiler, stmt, col, i, kw ) - extension.append((col, new_param)) + extension.append((col, col_expr, new_param)) values.append(extension) @@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ): + for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not None: @@ -773,6 +877,11 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw @@ -780,7 +889,7 @@ def _get_stmt_parameters_params( else: v = compiler.process(v.self_group(), **kw) - values.append((k, v)) + values.append((k, col_expr, v)) def _get_returning_modifiers(compiler, stmt, compile_state): diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 21476c1f9..930865898 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -133,6 +133,8 @@ class DMLState(CompileState): class InsertDMLState(DMLState): isinsert = True + include_table_with_column_exprs = False + def __init__(self, statement, compiler, **kw): self.statement = statement @@ -149,6 +151,8 @@ class InsertDMLState(DMLState): class UpdateDMLState(DMLState): isupdate = True + include_table_with_column_exprs = False + def __init__(self, statement, compiler, **kw): self.statement = statement self.isupdate = True @@ -159,7 +163,11 @@ class UpdateDMLState(DMLState): self._process_values(statement) elif statement._multi_values: self._process_multi_values(statement) - self._extra_froms = self._make_extra_froms(statement) + self._extra_froms = ef = self._make_extra_froms(statement) + self.is_multitable = mt = ef and self._dict_parameters + self.include_table_with_column_exprs = ( + mt and compiler.render_table_with_column_in_update_from + ) @CompileState.plugin_for("default", "delete") diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 18e9da654..5db5fed11 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -840,7 +840,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=mysql.dialect(), ) - def test_update_to_expression(self): + def test_update_to_expression_one(self): """test update from an expression. this logic is triggered currently by a left side that doesn't @@ -856,6 +856,72 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): "UPDATE mytable SET foo(myid)=:param_1", ) + def test_update_to_expression_two(self): + """test update from an expression. + + this logic is triggered currently by a left side that doesn't + have a key. The current supported use case is updating the index + of a PostgreSQL ARRAY type. + + """ + + from sqlalchemy import ARRAY + + t = table( + "foo", + column("data1", ARRAY(Integer)), + column("data2", ARRAY(Integer)), + ) + + stmt = t.update().values({t.c.data1[5]: 7, t.c.data2[10]: 18}) + dialect = default.StrCompileDialect() + dialect.paramstyle = "qmark" + dialect.positional = True + self.assert_compile( + stmt, + "UPDATE foo SET data1[?]=?, data2[?]=?", + dialect=dialect, + checkpositional=(5, 7, 10, 18), + ) + + def test_update_to_expression_three(self): + # this test is from test_defaults but exercises a particular + # parameter ordering issue + metadata = MetaData() + + q = Table( + "q", + metadata, + Column("x", Integer, default=2), + Column("y", Integer, onupdate=5), + Column("z", Integer), + ) + + p = Table( + "p", + metadata, + Column("s", Integer), + Column("t", Integer), + Column("u", Integer, onupdate=1), + ) + + cte = ( + q.update().where(q.c.z == 1).values(x=7).returning(q.c.z).cte("c") + ) + stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z) + + dialect = default.StrCompileDialect() + dialect.paramstyle = "qmark" + dialect.positional = True + + self.assert_compile( + stmt, + "WITH c AS (UPDATE q SET x=?, y=? WHERE q.z = ? RETURNING q.z) " + "SELECT p.s, c.z FROM p, c WHERE p.s = c.z", + checkpositional=(7, None, 1), + dialect=dialect, + ) + def test_update_bound_ordering(self): """test that bound parameters between the UPDATE and FROM clauses order correctly in different SQL compilation scenarios. |