diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 43 |
1 files changed, 41 insertions, 2 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index d3329c391..6a5785043 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -528,6 +528,17 @@ def _scan_cols( else: cols = stmt.table.columns + if compile_state.isinsert and not compile_state._has_multi_parameters: + # new rules for #7998. fetch lastrowid or implicit returning + # for autoincrement column even if parameter is NULL, for DBs that + # override NULL param for primary key (sqlite, mysql/mariadb) + autoincrement_col = stmt.table._autoincrement_column + insert_null_pk_still_autoincrements = ( + compiler.dialect.insert_null_pk_still_autoincrements + ) + else: + autoincrement_col = insert_null_pk_still_autoincrements = None + for c in cols: # scan through every column in the target table @@ -547,6 +558,8 @@ def _scan_cols( implicit_returning, implicit_return_defaults, values, + autoincrement_col, + insert_null_pk_still_autoincrements, kw, ) @@ -626,6 +639,8 @@ def _append_param_parameter( implicit_returning, implicit_return_defaults, values, + autoincrement_col, + insert_null_pk_still_autoincrements, kw, ): value = parameters.pop(col_key) @@ -635,6 +650,19 @@ def _append_param_parameter( ) if coercions._is_literal(value): + + if ( + insert_null_pk_still_autoincrements + and c.primary_key + and c is autoincrement_col + ): + # support use case for #7998, fetch autoincrement cols + # even if value was given + if implicit_returning: + compiler.implicit_returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + value = _create_bind_param( compiler, c, @@ -646,6 +674,19 @@ def _append_param_parameter( **kw, ) elif value._is_bind_parameter: + if ( + insert_null_pk_still_autoincrements + and value.value is None + and c.primary_key + and c is autoincrement_col + ): + # support use case for #7998, fetch autoincrement cols + # even if value was given + if implicit_returning: + compiler.implicit_returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + value = _handle_values_anonymous_param( compiler, c, @@ -1244,7 +1285,6 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): and not stmt._returning and not compile_state._has_multi_parameters ) - implicit_returning = ( need_pks and compiler.dialect.implicit_returning @@ -1271,7 +1311,6 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): implicit_return_defaults = set(stmt._return_defaults_columns) postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid - return ( need_pks, implicit_returning, |