summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py66
1 files changed, 56 insertions, 10 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 913e4d433..81151a26b 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -568,6 +568,7 @@ def _scan_cols(
_col_bind_name,
implicit_returning,
implicit_return_defaults,
+ postfetch_lastrowid,
values,
autoincrement_col,
insert_null_pk_still_autoincrements,
@@ -649,6 +650,7 @@ def _append_param_parameter(
_col_bind_name,
implicit_returning,
implicit_return_defaults,
+ postfetch_lastrowid,
values,
autoincrement_col,
insert_null_pk_still_autoincrements,
@@ -668,11 +670,12 @@ def _append_param_parameter(
and c is autoincrement_col
):
# support use case for #7998, fetch autoincrement cols
- # even if value was given
- if implicit_returning:
- compiler.implicit_returning.append(c)
- elif compiler.dialect.postfetch_lastrowid:
+ # even if value was given.
+
+ if postfetch_lastrowid:
compiler.postfetch_lastrowid = True
+ elif implicit_returning:
+ compiler.implicit_returning.append(c)
value = _create_bind_param(
compiler,
@@ -1281,7 +1284,12 @@ def _get_stmt_parameter_tuples_params(
def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
+ """determines RETURNING strategy, if any, for the statement.
+
+ This is where it's determined what we need to fetch from the
+ INSERT or UPDATE statement after it's invoked.
+ """
need_pks = (
toplevel
and compile_state.isinsert
@@ -1296,19 +1304,58 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
and not stmt._returning
and not compile_state._has_multi_parameters
)
+
+ # check if we have access to simple cursor.lastrowid. we can use that
+ # after the INSERT if that's all we need.
+ postfetch_lastrowid = (
+ need_pks
+ and compiler.dialect.postfetch_lastrowid
+ and stmt.table._autoincrement_column is not None
+ )
+
+ # see if we want to add RETURNING to an INSERT in order to get
+ # primary key columns back. This would be instead of postfetch_lastrowid
+ # if that's set.
implicit_returning = (
+ # statement itself can veto it
need_pks
- and compiler.dialect.implicit_returning
- and stmt.table.implicit_returning
+ # the dialect can veto it if it just doesnt support RETURNING
+ # with INSERT
+ and compiler.dialect.insert_returning
+ # user-defined implicit_returning on Table can veto it
+ and compile_state._primary_table.implicit_returning
+ # the compile_state can veto it (SQlite uses this to disable
+ # RETURNING for an ON CONFLICT insert, as SQLite does not return
+ # for rows that were updated, which is wrong)
+ and compile_state._supports_implicit_returning
+ and (
+ # since we support MariaDB and SQLite which also support lastrowid,
+ # decide if we should use lastrowid or RETURNING. for insert
+ # that didnt call return_defaults() and has just one set of
+ # parameters, we can use lastrowid. this is more "traditional"
+ # and a lot of weird use cases are supported by it.
+ # SQLite lastrowid times 3x faster than returning,
+ # Mariadb lastrowid 2x faster than returning
+ (
+ not postfetch_lastrowid
+ or compiler.dialect.favor_returning_over_lastrowid
+ )
+ or compile_state._has_multi_parameters
+ or stmt._return_defaults
+ )
)
+ if implicit_returning:
+ postfetch_lastrowid = False
+
if compile_state.isinsert:
implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compile_state.isupdate:
implicit_return_defaults = (
- compiler.dialect.implicit_returning
- and stmt.table.implicit_returning
- and stmt._return_defaults
+ stmt._return_defaults
+ and compile_state._primary_table.implicit_returning
+ and compile_state._supports_implicit_returning
+ and compiler.dialect.update_returning
)
else:
# this line is unused, currently we are always
@@ -1321,7 +1368,6 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
else:
implicit_return_defaults = set(stmt._return_defaults_columns)
- postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
return (
need_pks,
implicit_returning,