diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-08 13:03:17 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-08-08 13:34:27 -0400 |
commit | c0685e5f419e203acd5e46f25c90e851e30e6f03 (patch) | |
tree | 42b73a6205a8b03e2ba232c21e43d5affb8657d2 /lib/sqlalchemy/sql/crud.py | |
parent | 302e8dee82718df6c3a46de4c5283bdae51a650a (diff) | |
download | sqlalchemy-c0685e5f419e203acd5e46f25c90e851e30e6f03.tar.gz |
render INSERT/UPDATE column expressions up front; pass state
Fixes related to rendering of complex UPDATE DML
which was not correctly preserving positional parameter
order in conjunction with DML features that are only known
to work on the PostgreSQL database. Both pg8000
and asyncpg use positional parameters which is why these
issues are suddenly apparent.
crud.py now takes on the task of rendering the column
expressions for SET or VALUES so that for the very unusual
case that the column expression is a compound expression
that includes a bound parameter (namely an array index),
the bound parameter order is preserved.
Additionally, crud.py passes through the positional_names
keyword argument into bindparam_string() which is necessary
when CTEs are being rendered, as PG supports complex
CTE / INSERT / UPDATE scenarios.
Change-Id: I7f03920500e19b721636b84594de78a5bfdcbc82
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 193 |
1 files changed, 151 insertions, 42 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 85112f850..7d0616da7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: return [ - (c, _create_bind_param(compiler, c, None, required=True)) + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) for c in stmt.table.columns ] @@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ) check_columns = {} # special logic that only occurs for multi-table UPDATE # statements - if ( - compile_state.isupdate - and compile_state._extra_froms - and stmt_parameters - ): + if compile_state.isupdate and compile_state.is_multitable: _get_multitable_params( compiler, stmt, @@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # into INSERT (firstcol) VALUES (DEFAULT) which can be turned # into an in-place multi values. This supports # insert_executemany_returning mode :) - values = [(stmt.table.columns[0], "DEFAULT")] + values = [ + ( + stmt.table.columns[0], + compiler.preparer.format_column(stmt.table.columns[0]), + "DEFAULT", + ) + ] return values @@ -286,7 +298,7 @@ def _scan_insert_from_select_cols( col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: parameters.pop(col_key) - values.append((c, None)) + values.append((c, compiler.preparer.format_column(c), None)) else: _append_param_insert_select_hasdefault( compiler, stmt, c, add_select_cols, kw @@ -297,7 +309,7 @@ def _scan_insert_from_select_cols( compiler._insert_from_select = compiler._insert_from_select._generate() compiler._insert_from_select._raw_columns = tuple( compiler._insert_from_select._raw_columns - ) + tuple(expr for col, expr in add_select_cols) + ) + tuple(expr for col, col_expr, expr in add_select_cols) def _scan_cols( @@ -390,7 +402,13 @@ def _scan_cols( elif compile_state.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, ) @@ -410,6 +428,10 @@ def _append_param_parameter( value = parameters.pop(col_key) + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -446,7 +468,7 @@ def _append_param_parameter( if not c.primary_key: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): @@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process(c.default.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) ) compiler.returning.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): _warn_pk_with_no_anticipated_value(c) -def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_insert_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.insert_prefetch.append(c) return param -def _create_update_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_update_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.update_prefetch.append(c) return param @@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): else: col = _multiparam_column(c, index) if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col) + return _create_insert_prefetch_bind_param(compiler, col, **kw) else: - return _create_update_prefetch_bind_param(compiler, col) + return _create_update_prefetch_bind_param(compiler, col, **kw) def _append_param_insert_pk(compiler, stmt, c, values, kw): @@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): or compiler.dialect.preexecute_autoincrement_sequences ) ): - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -597,15 +648,25 @@ def _append_param_insert_hasdefault( if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: - proc = compiler.process(c.default.arg.self_group(), **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -613,7 +674,13 @@ def _append_param_insert_hasdefault( # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): @@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = c.default - values.append((c, proc.next_value())) + values.append( + (c, compiler.preparer.format_column(c), c.default.next_value()) + ) elif c.default.is_clause_element: - proc = c.default.arg.self_group() - values.append((c, proc)) + values.append( + (c, compiler.preparer.format_column(c), c.default.arg.self_group()) + ) else: values.append( - (c, _create_insert_prefetch_bind_param(compiler, c, process=False)) + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + ) ) def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw ): + include_table = compile_state.include_table_with_column_exprs if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process(c.onupdate.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + ) ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append((c, _create_update_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -676,6 +766,9 @@ def _get_multitable_params( (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) + + include_table = compile_state.include_table_with_column_exprs + affected_tables = set() for t in compile_state._extra_froms: for c in t.c: @@ -683,6 +776,8 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -699,7 +794,7 @@ def _get_multitable_params( else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: @@ -711,6 +806,7 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), compiler.process( c.onupdate.arg.self_group(), **kw ), @@ -721,8 +817,9 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c) + compiler, c, name=_col_bind_name(c), **kw ), ) ) @@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] - for (col, param) in values_0: + for (col, col_expr, param) in values_0: if col in row or col.key in row: key = col if col in row else col.key @@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): compiler, stmt, col, i, kw ) - extension.append((col, new_param)) + extension.append((col, col_expr, new_param)) values.append(extension) @@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ): + for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not None: @@ -773,6 +877,11 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw @@ -780,7 +889,7 @@ def _get_stmt_parameters_params( else: v = compiler.process(v.self_group(), **kw) - values.append((k, v)) + values.append((k, col_expr, v)) def _get_returning_modifiers(compiler, stmt, compile_state): |