summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py119
1 files changed, 101 insertions, 18 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index b13377a59..22fffb73a 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -150,6 +150,22 @@ def _get_crud_params(
"return_defaults() simultaneously"
)
+ if compile_state.isdelete:
+ _setup_delete_return_defaults(
+ compiler,
+ stmt,
+ compile_state,
+ (),
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ (),
+ (),
+ toplevel,
+ kw,
+ )
+ return _CrudParams([], [])
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
@@ -466,13 +482,6 @@ def _scan_insert_from_select_cols(
kw,
):
- (
- need_pks,
- implicit_returning,
- implicit_return_defaults,
- postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
-
cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
assert compiler.stack[-1]["selectable"] is stmt
@@ -537,6 +546,8 @@ def _scan_cols(
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
+ assert compile_state.isupdate or compile_state.isinsert
+
if compile_state._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in compile_state._parameter_ordering
@@ -563,6 +574,13 @@ def _scan_cols(
else:
autoincrement_col = insert_null_pk_still_autoincrements = None
+ if stmt._supplemental_returning:
+ supplemental_returning = set(stmt._supplemental_returning)
+ else:
+ supplemental_returning = set()
+
+ compiler_implicit_returning = compiler.implicit_returning
+
for c in cols:
# scan through every column in the target table
@@ -627,11 +645,13 @@ def _scan_cols(
# column has a DDL-level default, and is either not a pk
# column or we don't need the pk.
if implicit_return_defaults and c in implicit_return_defaults:
- compiler.implicit_returning.append(c)
+ compiler_implicit_returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
+
elif implicit_return_defaults and c in implicit_return_defaults:
- compiler.implicit_returning.append(c)
+ compiler_implicit_returning.append(c)
+
elif (
c.primary_key
and c is not stmt.table._autoincrement_column
@@ -652,6 +672,59 @@ def _scan_cols(
kw,
)
+ # adding supplemental cols to implicit_returning in table
+ # order so that order is maintained between multiple INSERT
+ # statements which may have different parameters included, but all
+ # have the same RETURNING clause
+ if (
+ c in supplemental_returning
+ and c not in compiler_implicit_returning
+ ):
+ compiler_implicit_returning.append(c)
+
+ if supplemental_returning:
+ # we should have gotten every col into implicit_returning,
+ # however supplemental returning can also have SQL functions etc.
+ # in it
+ remaining_supplemental = supplemental_returning.difference(
+ compiler_implicit_returning
+ )
+ compiler_implicit_returning.extend(
+ c
+ for c in stmt._supplemental_returning
+ if c in remaining_supplemental
+ )
+
+
+def _setup_delete_return_defaults(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ toplevel,
+ kw,
+):
+ (_, _, implicit_return_defaults, _) = _get_returning_modifiers(
+ compiler, stmt, compile_state, toplevel
+ )
+
+ if not implicit_return_defaults:
+ return
+
+ if stmt._return_defaults_columns:
+ compiler.implicit_returning.extend(implicit_return_defaults)
+
+ if stmt._supplemental_returning:
+ ir_set = set(compiler.implicit_returning)
+ compiler.implicit_returning.extend(
+ c for c in stmt._supplemental_returning if c not in ir_set
+ )
+
def _append_param_parameter(
compiler,
@@ -743,7 +816,7 @@ def _append_param_parameter(
elif compiler.dialect.postfetch_lastrowid:
compiler.postfetch_lastrowid = True
- elif implicit_return_defaults and c in implicit_return_defaults:
+ elif implicit_return_defaults and (c in implicit_return_defaults):
compiler.implicit_returning.append(c)
else:
@@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
INSERT or UPDATE statement after it's invoked.
"""
+
need_pks = (
toplevel
and _compile_state_isinsert(compile_state)
@@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
)
)
and not stmt._returning
+ # and (not stmt._returning or stmt._return_defaults)
and not compile_state._has_multi_parameters
)
@@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
or stmt._return_defaults
)
)
-
if implicit_returning:
postfetch_lastrowid = False
if _compile_state_isinsert(compile_state):
- implicit_return_defaults = implicit_returning and stmt._return_defaults
+ should_implicit_return_defaults = (
+ implicit_returning and stmt._return_defaults
+ )
elif compile_state.isupdate:
- implicit_return_defaults = (
+ should_implicit_return_defaults = (
stmt._return_defaults
and compile_state._primary_table.implicit_returning
and compile_state._supports_implicit_returning
and compiler.dialect.update_returning
)
+ elif compile_state.isdelete:
+ should_implicit_return_defaults = (
+ stmt._return_defaults
+ and compile_state._primary_table.implicit_returning
+ and compile_state._supports_implicit_returning
+ and compiler.dialect.delete_returning
+ )
else:
- # this line is unused, currently we are always
- # isinsert or isupdate
- implicit_return_defaults = False # pragma: no cover
+ should_implicit_return_defaults = False # pragma: no cover
- if implicit_return_defaults:
+ if should_implicit_return_defaults:
if not stmt._return_defaults_columns:
implicit_return_defaults = set(stmt.table.c)
else:
implicit_return_defaults = set(stmt._return_defaults_columns)
+ else:
+ implicit_return_defaults = None
return (
need_pks,
- implicit_returning,
+ implicit_returning or should_implicit_return_defaults,
implicit_return_defaults,
postfetch_lastrowid,
)