diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 181 |
1 files changed, 103 insertions, 78 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e474952ce..2827a5817 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -16,6 +16,7 @@ from . import coercions from . import dml from . import elements from . import roles +from .elements import ClauseElement from .. import exc from .. import util @@ -33,45 +34,8 @@ values present. """, ) -ISINSERT = util.symbol("ISINSERT") -ISUPDATE = util.symbol("ISUPDATE") -ISDELETE = util.symbol("ISDELETE") - -def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): - restore_isinsert = compiler.isinsert - restore_isupdate = compiler.isupdate - restore_isdelete = compiler.isdelete - - should_restore = ( - (restore_isinsert or restore_isupdate or restore_isdelete) - or len(compiler.stack) > 1 - or "visiting_cte" in kw - ) - - if local_stmt_type is ISINSERT: - compiler.isupdate = False - compiler.isinsert = True - elif local_stmt_type is ISUPDATE: - compiler.isupdate = True - compiler.isinsert = False - elif local_stmt_type is ISDELETE: - if not should_restore: - compiler.isdelete = True - else: - assert False, "ISINSERT, ISUPDATE, or ISDELETE expected" - - try: - if local_stmt_type in (ISINSERT, ISUPDATE): - return _get_crud_params(compiler, stmt, **kw) - finally: - if should_restore: - compiler.isinsert = restore_isinsert - compiler.isupdate = restore_isupdate - compiler.isdelete = restore_isdelete - - -def _get_crud_params(compiler, stmt, **kw): +def _get_crud_params(compiler, stmt, compile_state, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -87,27 +51,29 @@ def _get_crud_params(compiler, stmt, **kw): compiler.update_prefetch = [] compiler.returning = [] + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._key_getters_for_crud_column = getters + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns - if compiler.column_keys is None and stmt.parameters is None: + if compiler.column_keys is None and compile_state._no_parameters: return [ (c, _create_bind_param(compiler, c, None, required=True)) for c in stmt.table.columns ] - if stmt._has_multi_parameters: - stmt_parameters = stmt.parameters[0] + if compile_state._has_multi_parameters: + stmt_parameters = compile_state._multi_parameters[0] else: - stmt_parameters = stmt.parameters - - # getters - these are normally just column.key, - # but in the case of mysql multi-table update, the rules for - # .key must conditionally take tablename into account - ( - _column_as_key, - _getattr_col_key, - _col_bind_name, - ) = _key_getters_for_crud_column(compiler, stmt) + stmt_parameters = compile_state._dict_parameters # if we have statement parameters - set defaults in the # compiled params @@ -132,10 +98,15 @@ def _get_crud_params(compiler, stmt, **kw): # special logic that only occurs for multi-table UPDATE # statements - if compiler.isupdate and stmt._extra_froms and stmt_parameters: + if ( + compile_state.isupdate + and compile_state._extra_froms + and stmt_parameters + ): _get_multitable_params( compiler, stmt, + compile_state, stmt_parameters, check_columns, _col_bind_name, @@ -144,10 +115,11 @@ def _get_crud_params(compiler, stmt, **kw): kw, ) - if compiler.isinsert and stmt.select_names: + if compile_state.isinsert and stmt._select_names: _scan_insert_from_select_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -160,6 +132,7 @@ def _get_crud_params(compiler, stmt, **kw): _scan_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -181,8 +154,10 @@ def _get_crud_params(compiler, stmt, **kw): % (", ".join("%s" % (c,) for c in check)) ) - if stmt._has_multi_parameters: - values = _extend_values_for_multiparams(compiler, stmt, values, kw) + if compile_state._has_multi_parameters: + values = _extend_values_for_multiparams( + compiler, stmt, compile_state, values, kw + ) return values @@ -201,15 +176,46 @@ def _create_bind_param( return bindparam -def _key_getters_for_crud_column(compiler, stmt): - if compiler.isupdate and stmt._extra_froms: +def _handle_values_anonymous_param(compiler, col, value, name, **kw): + # the insert() and update() constructs as of 1.4 will now produce anonymous + # bindparam() objects in the values() collections up front when given plain + # literal values. This is so that cache key behaviors, which need to + # produce bound parameters in deterministic order without invoking any + # compilation here, can be applied to these constructs when they include + # values() (but not yet multi-values, which are not included in caching + # right now). + # + # in order to produce the desired "crud" style name for these parameters, + # which will also be targetable in engine/default.py through the usual + # conventions, apply our desired name to these unique parameters by + # populating the compiler truncated names cache with the desired name, + # rather than having + # compiler.visit_bindparam()->compiler._truncated_identifier make up a + # name. Saves on call counts also. + if value.unique and isinstance(value.key, elements._truncated_label): + compiler.truncated_names[("bindparam", value.key)] = name + + if value.type._isnull: + # either unique parameter, or other bound parameters that were + # passed in directly + # clone using base ClauseElement to retain unique key + value = ClauseElement._clone(value) + + # set type to that of the column unconditionally + value.type = col.type + + return value._compiler_dispatch(compiler, **kw) + + +def _key_getters_for_crud_column(compiler, stmt, compile_state): + if compile_state.isupdate and compile_state._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. # the "main" table of the statement remains unqualified, # allowing the most compatibility with a non-multi-table # statement. - _et = set(stmt._extra_froms) + _et = set(compile_state._extra_froms) c_key_role = functools.partial( coercions.expect_as_key, roles.DMLColumnRole @@ -246,6 +252,7 @@ def _key_getters_for_crud_column(compiler, stmt): def _scan_insert_from_select_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -260,9 +267,9 @@ def _scan_insert_from_select_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt) + ) = _get_returning_modifiers(compiler, stmt, compile_state) - cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names] + cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] compiler._insert_from_select = stmt.select @@ -294,6 +301,7 @@ def _scan_insert_from_select_cols( def _scan_cols( compiler, stmt, + compile_state, parameters, _getattr_col_key, _column_as_key, @@ -308,11 +316,11 @@ def _scan_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt) + ) = _get_returning_modifiers(compiler, stmt, compile_state) - if stmt._parameter_ordering: + if compile_state._parameter_ordering: parameter_ordering = [ - _column_as_key(key) for key in stmt._parameter_ordering + _column_as_key(key) for key in compile_state._parameter_ordering ] ordered_keys = set(parameter_ordering) cols = [stmt.table.c[key] for key in parameter_ordering] + [ @@ -329,6 +337,7 @@ def _scan_cols( _append_param_parameter( compiler, stmt, + compile_state, c, col_key, parameters, @@ -339,7 +348,7 @@ def _scan_cols( kw, ) - elif compiler.isinsert: + elif compile_state.isinsert: if ( c.primary_key and need_pks @@ -377,7 +386,7 @@ def _scan_cols( ): _warn_pk_with_no_anticipated_value(c) - elif compiler.isupdate: + elif compile_state.isupdate: _append_param_update( compiler, stmt, c, implicit_return_defaults, values, kw ) @@ -386,6 +395,7 @@ def _scan_cols( def _append_param_parameter( compiler, stmt, + compile_state, c, col_key, parameters, @@ -395,7 +405,9 @@ def _append_param_parameter( values, kw, ): + value = parameters.pop(col_key) + if coercions._is_literal(value): value = _create_bind_param( compiler, @@ -403,15 +415,21 @@ def _append_param_parameter( value, required=value is REQUIRED, name=_col_bind_name(c) - if not stmt._has_multi_parameters + if not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c), + **kw + ) + elif value._is_bind_parameter: + value = _handle_values_anonymous_param( + compiler, + c, + value, + name=_col_bind_name(c) + if not compile_state._has_multi_parameters else "%s_m0" % _col_bind_name(c), **kw ) else: - if isinstance(value, elements.BindParameter) and value.type._isnull: - value = value._clone() - value.type = c.type - if c.primary_key and implicit_returning: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) @@ -644,6 +662,7 @@ def _append_param_update( def _get_multitable_params( compiler, stmt, + compile_state, stmt_parameters, check_columns, _col_bind_name, @@ -656,7 +675,7 @@ def _get_multitable_params( for c, param in stmt_parameters.items() ) affected_tables = set() - for t in stmt._extra_froms: + for t in compile_state._extra_froms: for c in t.c: if c in normalized_params: affected_tables.add(t) @@ -669,6 +688,11 @@ def _get_multitable_params( value, required=value is REQUIRED, name=_col_bind_name(c), + **kw # TODO: no test coverage for literal binds here + ) + elif value._is_bind_parameter: + value = _handle_values_anonymous_param( + compiler, c, value, name=_col_bind_name(c), **kw ) else: compiler.postfetch.append(c) @@ -704,11 +728,11 @@ def _get_multitable_params( compiler.postfetch.append(c) -def _extend_values_for_multiparams(compiler, stmt, values, kw): +def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw): values_0 = values values = [values] - for i, row in enumerate(stmt.parameters[1:]): + for i, row in enumerate(compile_state._multi_parameters[1:]): extension = [] for (col, param) in values_0: if col in row or col.key in row: @@ -757,12 +781,13 @@ def _get_stmt_parameters_params( values.append((k, v)) -def _get_returning_modifiers(compiler, stmt): +def _get_returning_modifiers(compiler, stmt, compile_state): + need_pks = ( - compiler.isinsert + compile_state.isinsert and not compiler.inline and not stmt._returning - and not stmt._has_multi_parameters + and not compile_state._has_multi_parameters ) implicit_returning = ( @@ -771,9 +796,9 @@ def _get_returning_modifiers(compiler, stmt): and stmt.table.implicit_returning ) - if compiler.isinsert: + if compile_state.isinsert: implicit_return_defaults = implicit_returning and stmt._return_defaults - elif compiler.isupdate: + elif compile_state.isupdate: implicit_return_defaults = ( compiler.dialect.implicit_returning and stmt.table.implicit_returning |