diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 119 |
1 files changed, 101 insertions, 18 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index b13377a59..22fffb73a 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -150,6 +150,22 @@ def _get_crud_params( "return_defaults() simultaneously" ) + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if compiler.column_keys is None and compile_state._no_parameters: @@ -466,13 +482,6 @@ def _scan_insert_from_select_cols( kw, ): - ( - need_pks, - implicit_returning, - implicit_return_defaults, - postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) - cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] assert compiler.stack[-1]["selectable"] is stmt @@ -537,6 +546,8 @@ def _scan_cols( postfetch_lastrowid, ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) + assert compile_state.isupdate or compile_state.isinsert + if compile_state._parameter_ordering: parameter_ordering = [ _column_as_key(key) for key in compile_state._parameter_ordering @@ -563,6 +574,13 @@ def _scan_cols( else: autoincrement_col = insert_null_pk_still_autoincrements = None + if stmt._supplemental_returning: + supplemental_returning = set(stmt._supplemental_returning) + else: + supplemental_returning = set() + + compiler_implicit_returning = compiler.implicit_returning + for c in cols: # scan through every column in the target table @@ -627,11 +645,13 @@ def _scan_cols( # column has a DDL-level default, and is either not a pk # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: - compiler.implicit_returning.append(c) + compiler_implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) + elif implicit_return_defaults and c in implicit_return_defaults: - compiler.implicit_returning.append(c) + compiler_implicit_returning.append(c) + elif ( c.primary_key and c is not stmt.table._autoincrement_column @@ -652,6 +672,59 @@ def _scan_cols( kw, ) + # adding supplemental cols to implicit_returning in table + # order so that order is maintained between multiple INSERT + # statements which may have different parameters included, but all + # have the same RETURNING clause + if ( + c in supplemental_returning + and c not in compiler_implicit_returning + ): + compiler_implicit_returning.append(c) + + if supplemental_returning: + # we should have gotten every col into implicit_returning, + # however supplemental returning can also have SQL functions etc. + # in it + remaining_supplemental = supplemental_returning.difference( + compiler_implicit_returning + ) + compiler_implicit_returning.extend( + c + for c in stmt._supplemental_returning + if c in remaining_supplemental + ) + + +def _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + (_, _, implicit_return_defaults, _) = _get_returning_modifiers( + compiler, stmt, compile_state, toplevel + ) + + if not implicit_return_defaults: + return + + if stmt._return_defaults_columns: + compiler.implicit_returning.extend(implicit_return_defaults) + + if stmt._supplemental_returning: + ir_set = set(compiler.implicit_returning) + compiler.implicit_returning.extend( + c for c in stmt._supplemental_returning if c not in ir_set + ) + def _append_param_parameter( compiler, @@ -743,7 +816,7 @@ def _append_param_parameter( elif compiler.dialect.postfetch_lastrowid: compiler.postfetch_lastrowid = True - elif implicit_return_defaults and c in implicit_return_defaults: + elif implicit_return_defaults and (c in implicit_return_defaults): compiler.implicit_returning.append(c) else: @@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): INSERT or UPDATE statement after it's invoked. """ + need_pks = ( toplevel and _compile_state_isinsert(compile_state) @@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): ) ) and not stmt._returning + # and (not stmt._returning or stmt._return_defaults) and not compile_state._has_multi_parameters ) @@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): or stmt._return_defaults ) ) - if implicit_returning: postfetch_lastrowid = False if _compile_state_isinsert(compile_state): - implicit_return_defaults = implicit_returning and stmt._return_defaults + should_implicit_return_defaults = ( + implicit_returning and stmt._return_defaults + ) elif compile_state.isupdate: - implicit_return_defaults = ( + should_implicit_return_defaults = ( stmt._return_defaults and compile_state._primary_table.implicit_returning and compile_state._supports_implicit_returning and compiler.dialect.update_returning ) + elif compile_state.isdelete: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and compiler.dialect.delete_returning + ) else: - # this line is unused, currently we are always - # isinsert or isupdate - implicit_return_defaults = False # pragma: no cover + should_implicit_return_defaults = False # pragma: no cover - if implicit_return_defaults: + if should_implicit_return_defaults: if not stmt._return_defaults_columns: implicit_return_defaults = set(stmt.table.c) else: implicit_return_defaults = set(stmt._return_defaults_columns) + else: + implicit_return_defaults = None return ( need_pks, - implicit_returning, + implicit_returning or should_implicit_return_defaults, implicit_return_defaults, postfetch_lastrowid, ) |