diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 65 |
1 files changed, 47 insertions, 18 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 91a3f70c9..f6db2c4b2 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -87,6 +87,7 @@ def _get_crud_params( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, + toplevel: bool, **kw: Any, ) -> _CrudParams: """create a set of tuples representing column/string pairs for use @@ -99,10 +100,33 @@ def _get_crud_params( """ + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. + + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + compiler.postfetch = [] compiler.insert_prefetch = [] compiler.update_prefetch = [] - compiler.returning = [] + compiler.implicit_returning = [] # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for @@ -213,6 +237,7 @@ def _get_crud_params( _col_bind_name, check_columns, values, + toplevel, kw, ) else: @@ -226,6 +251,7 @@ def _get_crud_params( _col_bind_name, check_columns, values, + toplevel, kw, ) @@ -419,6 +445,7 @@ def _scan_insert_from_select_cols( _col_bind_name, check_columns, values, + toplevel, kw, ): @@ -427,7 +454,7 @@ def _scan_insert_from_select_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state) + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] @@ -472,6 +499,7 @@ def _scan_cols( _col_bind_name, check_columns, values, + toplevel, kw, ): ( @@ -479,7 +507,7 @@ def _scan_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, - ) = _get_returning_modifiers(compiler, stmt, compile_state) + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) if compile_state._parameter_ordering: parameter_ordering = [ @@ -556,11 +584,11 @@ def _scan_cols( # column has a DDL-level default, and is either not a pk # column or we don't need the pk. if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif ( c.primary_key and c is not stmt.table._autoincrement_column @@ -628,7 +656,7 @@ def _append_param_parameter( if compile_state.isupdate: if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) @@ -636,12 +664,12 @@ def _append_param_parameter( if c.primary_key: if implicit_returning: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif compiler.dialect.postfetch_lastrowid: compiler.postfetch_lastrowid = True elif implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: # postfetch specifically means, "we can SELECT the row we just @@ -674,7 +702,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.process(c.default, **kw), ) ) - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif c.default.is_clause_element: values.append( ( @@ -683,7 +711,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): compiler.process(c.default.arg.self_group(), **kw), ) ) - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: # client side default. OK we can't use RETURNING, need to # do a "prefetch", which in fact fetches the default value @@ -696,7 +724,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): ) ) elif c is stmt.table._autoincrement_column or c.server_default is not None: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -794,7 +822,7 @@ def _append_param_insert_hasdefault( ) ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: @@ -807,7 +835,7 @@ def _append_param_insert_hasdefault( ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch compiler.postfetch.append(c) @@ -870,7 +898,7 @@ def _append_param_update( ) ) if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) else: @@ -886,7 +914,7 @@ def _append_param_update( ) elif c.server_onupdate is not None: if implicit_return_defaults and c in implicit_return_defaults: - compiler.returning.append(c) + compiler.implicit_returning.append(c) else: compiler.postfetch.append(c) elif ( @@ -894,7 +922,7 @@ def _append_param_update( and (stmt._return_defaults_columns or not stmt._return_defaults) and c in implicit_return_defaults ): - compiler.returning.append(c) + compiler.implicit_returning.append(c) @overload @@ -1195,10 +1223,11 @@ def _get_stmt_parameter_tuples_params( values.append((k, col_expr, v)) -def _get_returning_modifiers(compiler, stmt, compile_state): +def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): need_pks = ( - compile_state.isinsert + toplevel + and compile_state.isinsert and not stmt._inline and ( not compiler.for_executemany |