diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 193 |
1 files changed, 151 insertions, 42 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 85112f850..7d0616da7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -65,7 +65,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: return [ - (c, _create_bind_param(compiler, c, None, required=True)) + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) for c in stmt.table.columns ] @@ -90,18 +94,20 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ) check_columns = {} # special logic that only occurs for multi-table UPDATE # statements - if ( - compile_state.isupdate - and compile_state._extra_froms - and stmt_parameters - ): + if compile_state.isupdate and compile_state.is_multitable: _get_multitable_params( compiler, stmt, @@ -162,7 +168,13 @@ def _get_crud_params(compiler, stmt, compile_state, **kw): # into INSERT (firstcol) VALUES (DEFAULT) which can be turned # into an in-place multi values. This supports # insert_executemany_returning mode :) - values = [(stmt.table.columns[0], "DEFAULT")] + values = [ + ( + stmt.table.columns[0], + compiler.preparer.format_column(stmt.table.columns[0]), + "DEFAULT", + ) + ] return values @@ -286,7 +298,7 @@ def _scan_insert_from_select_cols( col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: parameters.pop(col_key) - values.append((c, None)) + values.append((c, compiler.preparer.format_column(c), None)) else: _append_param_insert_select_hasdefault( compiler, stmt, c, add_select_cols, kw @@ -297,7 +309,7 @@ def _scan_insert_from_select_cols( compiler._insert_from_select = compiler._insert_from_select._generate() compiler._insert_from_select._raw_columns = tuple( compiler._insert_from_select._raw_columns - ) + tuple(expr for col, expr in add_select_cols) + ) + tuple(expr for col, col_expr, expr in add_select_cols) def _scan_cols( @@ -390,7 +402,13 @@ def _scan_cols( elif compile_state.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, ) @@ -410,6 +428,10 @@ def _append_param_parameter( value = parameters.pop(col_key) + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -446,7 +468,7 @@ def _append_param_parameter( if not c.primary_key: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): @@ -472,16 +494,31 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process(c.default.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) ) compiler.returning.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -490,14 +527,22 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): _warn_pk_with_no_anticipated_value(c) -def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_insert_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.insert_prefetch.append(c) return param -def _create_update_prefetch_bind_param(compiler, c, process=True, name=None): - param = _create_bind_param(compiler, c, None, process=process, name=name) +def _create_update_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) compiler.update_prefetch.append(c) return param @@ -539,9 +584,9 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): else: col = _multiparam_column(c, index) if isinstance(stmt, dml.Insert): - return _create_insert_prefetch_bind_param(compiler, col) + return _create_insert_prefetch_bind_param(compiler, col, **kw) else: - return _create_update_prefetch_bind_param(compiler, col) + return _create_update_prefetch_bind_param(compiler, col, **kw) def _append_param_insert_pk(compiler, stmt, c, values, kw): @@ -582,7 +627,13 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): or compiler.dialect.preexecute_autoincrement_sequences ) ): - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -597,15 +648,25 @@ def _append_param_insert_hasdefault( if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = compiler.process(c.default, **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: - proc = compiler.process(c.default.arg.self_group(), **kw) - values.append((c, proc)) + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) + ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -613,7 +674,13 @@ def _append_param_insert_hasdefault( # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append((c, _create_insert_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): @@ -622,32 +689,55 @@ def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): - proc = c.default - values.append((c, proc.next_value())) + values.append( + (c, compiler.preparer.format_column(c), c.default.next_value()) + ) elif c.default.is_clause_element: - proc = c.default.arg.self_group() - values.append((c, proc)) + values.append( + (c, compiler.preparer.format_column(c), c.default.arg.self_group()) + ) else: values.append( - (c, _create_insert_prefetch_bind_param(compiler, c, process=False)) + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + ) ) def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw ): + include_table = compile_state.include_table_with_column_exprs if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process(c.onupdate.arg.self_group(), **kw)) + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + ) ) if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append((c, _create_update_prefetch_bind_param(compiler, c))) + values.append( + ( + c, + compiler.preparer.format_column( + c, use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + ) + ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) @@ -676,6 +766,9 @@ def _get_multitable_params( (coercions.expect(roles.DMLColumnRole, c), param) for c, param in stmt_parameters.items() ) + + include_table = compile_state.include_table_with_column_exprs + affected_tables = set() for t in compile_state._extra_froms: for c in t.c: @@ -683,6 +776,8 @@ def _get_multitable_params( affected_tables.add(t) check_columns[_getattr_col_key(c)] = c value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -699,7 +794,7 @@ def _get_multitable_params( else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, value)) + values.append((c, col_value, value)) # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: @@ -711,6 +806,7 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), compiler.process( c.onupdate.arg.self_group(), **kw ), @@ -721,8 +817,9 @@ def _get_multitable_params( values.append( ( c, + compiler.process(c, include_table=include_table), _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c) + compiler, c, name=_col_bind_name(c), **kw ), ) ) @@ -736,7 +833,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] - for (col, param) in values_0: + for (col, col_expr, param) in values_0: if col in row or col.key in row: key = col if col in row else col.key @@ -755,7 +852,7 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): compiler, stmt, col, i, kw ) - extension.append((col, new_param)) + extension.append((col, col_expr, new_param)) values.append(extension) @@ -763,8 +860,15 @@ def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw + compiler, + compile_state, + parameters, + stmt_parameters, + _column_as_key, + values, + kw, ): + for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not None: @@ -773,6 +877,11 @@ def _get_stmt_parameters_params( # a non-Column expression on the left side; # add it to values() in an "as-is" state, # coercing right side to bound param + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + if coercions._is_literal(v): v = compiler.process( elements.BindParameter(None, v, type_=k.type), **kw @@ -780,7 +889,7 @@ def _get_stmt_parameters_params( else: v = compiler.process(v.self_group(), **kw) - values.append((k, v)) + values.append((k, col_expr, v)) def _get_returning_modifiers(compiler, stmt, compile_state): |