summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/sqlalchemy/sql/compiler.py19
-rw-r--r--lib/sqlalchemy/sql/crud.py193
-rw-r--r--lib/sqlalchemy/sql/dml.py10
-rw-r--r--test/sql/test_update.py68
4 files changed, 234 insertions, 56 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 8e273f67c..542bf58ac 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -3286,7 +3286,7 @@ class SQLCompiler(Compiled):
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
- [preparer.format_column(c[0]) for c in crud_params_single]
+ [expr for c, expr, value in crud_params_single]
)
if self.returning or insert_stmt._returning:
@@ -3311,12 +3311,15 @@ class SQLCompiler(Compiled):
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
- "(%s)" % (", ".join(c[1] for c in crud_param_set))
+ "(%s)"
+ % (", ".join(value for c, expr, value in crud_param_set))
for crud_param_set in crud_params
)
)
else:
- insert_single_values_expr = ", ".join([c[1] for c in crud_params])
+ insert_single_values_expr = ", ".join(
+ [value for c, expr, value in crud_params]
+ )
text += " VALUES (%s)" % insert_single_values_expr
if toplevel:
self.insert_single_values_expr = insert_single_values_expr
@@ -3424,15 +3427,7 @@ class SQLCompiler(Compiled):
text += table_text
text += " SET "
- include_table = (
- is_multitable and self.render_table_with_column_in_update_from
- )
- text += ", ".join(
- c[0]._compiler_dispatch(self, include_table=include_table)
- + "="
- + c[1]
- for c in crud_params
- )
+ text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 85112f850..7d0616da7 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
return [
- (c, _create_bind_param(compiler, c, None, required=True))
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_bind_param(compiler, c, None, required=True),
+ )
for c in stmt.table.columns
]
@@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
if stmt_parameters is not None:
_get_stmt_parameters_params(
- compiler, parameters, stmt_parameters, _column_as_key, values, kw
+ compiler,
+ compile_state,
+ parameters,
+ stmt_parameters,
+ _column_as_key,
+ values,
+ kw,
)
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# statements
- if (
- compile_state.isupdate
- and compile_state._extra_froms
- and stmt_parameters
- ):
+ if compile_state.isupdate and compile_state.is_multitable:
_get_multitable_params(
compiler,
stmt,
@@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
# into an in-place multi values. This supports
# insert_executemany_returning mode :)
- values = [(stmt.table.columns[0], "DEFAULT")]
+ values = [
+ (
+ stmt.table.columns[0],
+ compiler.preparer.format_column(stmt.table.columns[0]),
+ "DEFAULT",
+ )
+ ]
return values
@@ -286,7 +298,7 @@ def _scan_insert_from_select_cols(
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
parameters.pop(col_key)
- values.append((c, None))
+ values.append((c, compiler.preparer.format_column(c), None))
else:
_append_param_insert_select_hasdefault(
compiler, stmt, c, add_select_cols, kw
@@ -297,7 +309,7 @@ def _scan_insert_from_select_cols(
compiler._insert_from_select = compiler._insert_from_select._generate()
compiler._insert_from_select._raw_columns = tuple(
compiler._insert_from_select._raw_columns
- ) + tuple(expr for col, expr in add_select_cols)
+ ) + tuple(expr for col, col_expr, expr in add_select_cols)
def _scan_cols(
@@ -390,7 +402,13 @@ def _scan_cols(
elif compile_state.isupdate:
_append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw
+ compiler,
+ compile_state,
+ stmt,
+ c,
+ implicit_return_defaults,
+ values,
+ kw,
)
@@ -410,6 +428,10 @@ def _append_param_parameter(
value = parameters.pop(col_key)
+ col_value = compiler.preparer.format_column(
+ c, use_table=compile_state.include_table_with_column_exprs
+ )
+
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -446,7 +468,7 @@ def _append_param_parameter(
if not c.primary_key:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
- values.append((c, value))
+ values.append((c, col_value, value))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
@@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
not c.default.optional
or not compiler.dialect.sequences_optional
):
- proc = compiler.process(c.default, **kw)
- values.append((c, proc))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default, **kw),
+ )
+ )
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
- (c, compiler.process(c.default.arg.self_group(), **kw))
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default.arg.self_group(), **kw),
+ )
)
compiler.returning.append(c)
else:
- values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
@@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
_warn_pk_with_no_anticipated_value(c)
-def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
- param = _create_bind_param(compiler, c, None, process=process, name=name)
+def _create_insert_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
compiler.insert_prefetch.append(c)
return param
-def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
- param = _create_bind_param(compiler, c, None, process=process, name=name)
+def _create_update_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
compiler.update_prefetch.append(c)
return param
@@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
- return _create_insert_prefetch_bind_param(compiler, col)
+ return _create_insert_prefetch_bind_param(compiler, col, **kw)
else:
- return _create_update_prefetch_bind_param(compiler, col)
+ return _create_update_prefetch_bind_param(compiler, col, **kw)
def _append_param_insert_pk(compiler, stmt, c, values, kw):
@@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
or compiler.dialect.preexecute_autoincrement_sequences
)
):
- values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
@@ -597,15 +648,25 @@ def _append_param_insert_hasdefault(
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
- proc = compiler.process(c.default, **kw)
- values.append((c, proc))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default, **kw),
+ )
+ )
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
- proc = compiler.process(c.default.arg.self_group(), **kw)
- values.append((c, proc))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default.arg.self_group(), **kw),
+ )
+ )
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
@@ -613,7 +674,13 @@ def _append_param_insert_hasdefault(
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
- values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
@@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
- proc = c.default
- values.append((c, proc.next_value()))
+ values.append(
+ (c, compiler.preparer.format_column(c), c.default.next_value())
+ )
elif c.default.is_clause_element:
- proc = c.default.arg.self_group()
- values.append((c, proc))
+ values.append(
+ (c, compiler.preparer.format_column(c), c.default.arg.self_group())
+ )
else:
values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(
+ compiler, c, process=False, **kw
+ ),
+ )
)
def _append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw
+ compiler, compile_state, stmt, c, implicit_return_defaults, values, kw
):
+ include_table = compile_state.include_table_with_column_exprs
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(c.onupdate.arg.self_group(), **kw))
+ (
+ c,
+ compiler.preparer.format_column(
+ c, use_table=include_table,
+ ),
+ compiler.process(c.onupdate.arg.self_group(), **kw),
+ )
)
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
- values.append((c, _create_update_prefetch_bind_param(compiler, c)))
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(
+ c, use_table=include_table,
+ ),
+ _create_update_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
elif c.server_onupdate is not None:
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
@@ -676,6 +766,9 @@ def _get_multitable_params(
(coercions.expect(roles.DMLColumnRole, c), param)
for c, param in stmt_parameters.items()
)
+
+ include_table = compile_state.include_table_with_column_exprs
+
affected_tables = set()
for t in compile_state._extra_froms:
for c in t.c:
@@ -683,6 +776,8 @@ def _get_multitable_params(
affected_tables.add(t)
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
+
+ col_value = compiler.process(c, include_table=include_table)
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -699,7 +794,7 @@ def _get_multitable_params(
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
- values.append((c, value))
+ values.append((c, col_value, value))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
@@ -711,6 +806,7 @@ def _get_multitable_params(
values.append(
(
c,
+ compiler.process(c, include_table=include_table),
compiler.process(
c.onupdate.arg.self_group(), **kw
),
@@ -721,8 +817,9 @@ def _get_multitable_params(
values.append(
(
c,
+ compiler.process(c, include_table=include_table),
_create_update_prefetch_bind_param(
- compiler, c, name=_col_bind_name(c)
+ compiler, c, name=_col_bind_name(c), **kw
),
)
)
@@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
for i, row in enumerate(compile_state._multi_parameters[1:]):
extension = []
- for (col, param) in values_0:
+ for (col, col_expr, param) in values_0:
if col in row or col.key in row:
key = col if col in row else col.key
@@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
compiler, stmt, col, i, kw
)
- extension.append((col, new_param))
+ extension.append((col, col_expr, new_param))
values.append(extension)
@@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
def _get_stmt_parameters_params(
- compiler, parameters, stmt_parameters, _column_as_key, values, kw
+ compiler,
+ compile_state,
+ parameters,
+ stmt_parameters,
+ _column_as_key,
+ values,
+ kw,
):
+
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
@@ -773,6 +877,11 @@ def _get_stmt_parameters_params(
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
+
+ col_expr = compiler.process(
+ k, include_table=compile_state.include_table_with_column_exprs
+ )
+
if coercions._is_literal(v):
v = compiler.process(
elements.BindParameter(None, v, type_=k.type), **kw
@@ -780,7 +889,7 @@ def _get_stmt_parameters_params(
else:
v = compiler.process(v.self_group(), **kw)
- values.append((k, v))
+ values.append((k, col_expr, v))
def _get_returning_modifiers(compiler, stmt, compile_state):
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 21476c1f9..930865898 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -133,6 +133,8 @@ class DMLState(CompileState):
class InsertDMLState(DMLState):
isinsert = True
+ include_table_with_column_exprs = False
+
def __init__(self, statement, compiler, **kw):
self.statement = statement
@@ -149,6 +151,8 @@ class InsertDMLState(DMLState):
class UpdateDMLState(DMLState):
isupdate = True
+ include_table_with_column_exprs = False
+
def __init__(self, statement, compiler, **kw):
self.statement = statement
self.isupdate = True
@@ -159,7 +163,11 @@ class UpdateDMLState(DMLState):
self._process_values(statement)
elif statement._multi_values:
self._process_multi_values(statement)
- self._extra_froms = self._make_extra_froms(statement)
+ self._extra_froms = ef = self._make_extra_froms(statement)
+ self.is_multitable = mt = ef and self._dict_parameters
+ self.include_table_with_column_exprs = (
+ mt and compiler.render_table_with_column_in_update_from
+ )
@CompileState.plugin_for("default", "delete")
diff --git a/test/sql/test_update.py b/test/sql/test_update.py
index 18e9da654..5db5fed11 100644
--- a/test/sql/test_update.py
+++ b/test/sql/test_update.py
@@ -840,7 +840,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
dialect=mysql.dialect(),
)
- def test_update_to_expression(self):
+ def test_update_to_expression_one(self):
"""test update from an expression.
this logic is triggered currently by a left side that doesn't
@@ -856,6 +856,72 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
"UPDATE mytable SET foo(myid)=:param_1",
)
+ def test_update_to_expression_two(self):
+ """test update from an expression.
+
+ this logic is triggered currently by a left side that doesn't
+ have a key. The current supported use case is updating the index
+ of a PostgreSQL ARRAY type.
+
+ """
+
+ from sqlalchemy import ARRAY
+
+ t = table(
+ "foo",
+ column("data1", ARRAY(Integer)),
+ column("data2", ARRAY(Integer)),
+ )
+
+ stmt = t.update().values({t.c.data1[5]: 7, t.c.data2[10]: 18})
+ dialect = default.StrCompileDialect()
+ dialect.paramstyle = "qmark"
+ dialect.positional = True
+ self.assert_compile(
+ stmt,
+ "UPDATE foo SET data1[?]=?, data2[?]=?",
+ dialect=dialect,
+ checkpositional=(5, 7, 10, 18),
+ )
+
+ def test_update_to_expression_three(self):
+ # this test is from test_defaults but exercises a particular
+ # parameter ordering issue
+ metadata = MetaData()
+
+ q = Table(
+ "q",
+ metadata,
+ Column("x", Integer, default=2),
+ Column("y", Integer, onupdate=5),
+ Column("z", Integer),
+ )
+
+ p = Table(
+ "p",
+ metadata,
+ Column("s", Integer),
+ Column("t", Integer),
+ Column("u", Integer, onupdate=1),
+ )
+
+ cte = (
+ q.update().where(q.c.z == 1).values(x=7).returning(q.c.z).cte("c")
+ )
+ stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z)
+
+ dialect = default.StrCompileDialect()
+ dialect.paramstyle = "qmark"
+ dialect.positional = True
+
+ self.assert_compile(
+ stmt,
+ "WITH c AS (UPDATE q SET x=?, y=? WHERE q.z = ? RETURNING q.z) "
+ "SELECT p.s, c.z FROM p, c WHERE p.s = c.z",
+ checkpositional=(7, None, 1),
+ dialect=dialect,
+ )
+
def test_update_bound_ordering(self):
"""test that bound parameters between the UPDATE and FROM clauses
order correctly in different SQL compilation scenarios.