diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 129 |
1 files changed, 104 insertions, 25 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 04b62d1ff..563f61c04 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -34,6 +34,7 @@ from . import coercions from . import dml from . import elements from . import roles +from .base import _DefaultDescriptionTuple from .dml import isinsert as _compile_state_isinsert from .elements import ColumnClause from .schema import default_is_clause_element @@ -53,6 +54,7 @@ if TYPE_CHECKING: from .elements import ColumnElement from .elements import KeyedColumnElement from .schema import _SQLExprDefault + from .schema import Column REQUIRED = util.symbol( "REQUIRED", @@ -79,20 +81,22 @@ def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: _CrudParamElement = Tuple[ "ColumnElement[Any]", - str, - Optional[Union[str, "_SQLExprDefault"]], + str, # column name + Optional[ + Union[str, "_SQLExprDefault"] + ], # bound parameter string or SQL expression to apply Iterable[str], ] _CrudParamElementStr = Tuple[ "KeyedColumnElement[Any]", str, # column name - str, # placeholder + str, # bound parameter string Iterable[str], ] _CrudParamElementSQLExpr = Tuple[ "ColumnClause[Any]", str, - "_SQLExprDefault", + "_SQLExprDefault", # SQL expression to apply Iterable[str], ] @@ -101,8 +105,10 @@ _CrudParamSequence = List[_CrudParamElement] class _CrudParams(NamedTuple): single_params: _CrudParamSequence - all_multi_params: List[Sequence[_CrudParamElementStr]] + is_default_metavalue_only: bool = False + use_insertmanyvalues: bool = False + use_sentinel_columns: Optional[Sequence[Column[Any]]] = None def _get_crud_params( @@ -206,6 +212,7 @@ def _get_crud_params( (c.key,), ) for c in stmt.table.columns + if not c._omit_from_statements ], [], ) @@ -301,8 +308,10 @@ def _get_crud_params( toplevel, kw, ) + use_insertmanyvalues = False + use_sentinel_columns = None else: - _scan_cols( + use_insertmanyvalues, use_sentinel_columns = _scan_cols( compiler, stmt, compile_state, @@ -328,6 +337,8 @@ def _get_crud_params( % (", ".join("%s" % (c,) for c in check)) ) + is_default_metavalue_only = False + if ( _compile_state_isinsert(compile_state) and compile_state._has_multi_parameters @@ -363,8 +374,15 @@ def _get_crud_params( (), ) ] - - return _CrudParams(values, []) + is_default_metavalue_only = True + + return _CrudParams( + values, + [], + is_default_metavalue_only=is_default_metavalue_only, + use_insertmanyvalues=use_insertmanyvalues, + use_sentinel_columns=use_sentinel_columns, + ) @overload @@ -527,7 +545,19 @@ def _scan_insert_from_select_cols( if stmt.include_insert_from_select_defaults: col_set = set(cols) for col in stmt.table.columns: - if col not in col_set and col.default: + # omit columns that were not in the SELECT statement. + # this will omit columns marked as omit_from_statements naturally, + # as long as that col was not explicit in the SELECT. + # if an omit_from_statements col has a "default" on it, then + # we need to include it, as these defaults should still fire off. + # but, if it has that default and it's the "sentinel" default, + # we don't do sentinel default operations for insert_from_select + # here so we again omit it. + if ( + col not in col_set + and col.default + and not col.default.is_sentinel + ): cols.append(col) for c in cols: @@ -579,6 +609,8 @@ def _scan_cols( implicit_returning, implicit_return_defaults, postfetch_lastrowid, + use_insertmanyvalues, + use_sentinel_columns, ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) assert compile_state.isupdate or compile_state.isinsert @@ -672,9 +704,12 @@ def _scan_cols( elif c.default is not None: # column has a default, but it's not a pk column, or it is but # we don't need to get the pk back. - _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, values, kw - ) + if not c.default.is_sentinel or ( + use_sentinel_columns is not None + ): + _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw + ) elif c.server_default is not None: # column has a DDL-level default, and is either not a pk @@ -730,6 +765,8 @@ def _scan_cols( if c in remaining_supplemental ) + return (use_insertmanyvalues, use_sentinel_columns) + def _setup_delete_return_defaults( compiler, @@ -744,7 +781,7 @@ def _setup_delete_return_defaults( toplevel, kw, ): - (_, _, implicit_return_defaults, _) = _get_returning_modifiers( + (_, _, implicit_return_defaults, *_) = _get_returning_modifiers( compiler, stmt, compile_state, toplevel ) @@ -1248,6 +1285,18 @@ class _multiparam_column(elements.ColumnElement[Any]): and other.original == self.original ) + @util.memoized_property + def _default_description_tuple(self) -> _DefaultDescriptionTuple: + """used by default.py -> _process_execute_defaults()""" + + return _DefaultDescriptionTuple._from_column_default(self.default) + + @util.memoized_property + def _onupdate_description_tuple(self) -> _DefaultDescriptionTuple: + """used by default.py -> _process_execute_defaults()""" + + return _DefaultDescriptionTuple._from_column_default(self.onupdate) + def _process_multiparam_default_bind( compiler: SQLCompiler, @@ -1459,16 +1508,15 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): """ + dialect = compiler.dialect + need_pks = ( toplevel and _compile_state_isinsert(compile_state) and not stmt._inline and ( not compiler.for_executemany - or ( - compiler.dialect.insert_executemany_returning - and stmt._return_defaults - ) + or (dialect.insert_executemany_returning and stmt._return_defaults) ) and not stmt._returning # and (not stmt._returning or stmt._return_defaults) @@ -1479,7 +1527,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): # after the INSERT if that's all we need. postfetch_lastrowid = ( need_pks - and compiler.dialect.postfetch_lastrowid + and dialect.postfetch_lastrowid and stmt.table._autoincrement_column is not None ) @@ -1491,7 +1539,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): need_pks # the dialect can veto it if it just doesnt support RETURNING # with INSERT - and compiler.dialect.insert_returning + and dialect.insert_returning # user-defined implicit_returning on Table can veto it and compile_state._primary_table.implicit_returning # the compile_state can veto it (SQlite uses this to disable @@ -1506,10 +1554,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): # and a lot of weird use cases are supported by it. # SQLite lastrowid times 3x faster than returning, # Mariadb lastrowid 2x faster than returning - ( - not postfetch_lastrowid - or compiler.dialect.favor_returning_over_lastrowid - ) + (not postfetch_lastrowid or dialect.favor_returning_over_lastrowid) or compile_state._has_multi_parameters or stmt._return_defaults ) @@ -1521,25 +1566,57 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): should_implicit_return_defaults = ( implicit_returning and stmt._return_defaults ) + explicit_returning = should_implicit_return_defaults or stmt._returning + use_insertmanyvalues = ( + toplevel + and compiler.for_executemany + and dialect.use_insertmanyvalues + and ( + explicit_returning or dialect.use_insertmanyvalues_wo_returning + ) + ) + + use_sentinel_columns = None + if ( + use_insertmanyvalues + and explicit_returning + and stmt._sort_by_parameter_order + ): + use_sentinel_columns = compiler._get_sentinel_column_for_table( + stmt.table + ) + elif compile_state.isupdate: should_implicit_return_defaults = ( stmt._return_defaults and compile_state._primary_table.implicit_returning and compile_state._supports_implicit_returning - and compiler.dialect.update_returning + and dialect.update_returning ) + use_insertmanyvalues = False + use_sentinel_columns = None elif compile_state.isdelete: should_implicit_return_defaults = ( stmt._return_defaults and compile_state._primary_table.implicit_returning and compile_state._supports_implicit_returning - and compiler.dialect.delete_returning + and dialect.delete_returning ) + use_insertmanyvalues = False + use_sentinel_columns = None else: should_implicit_return_defaults = False # pragma: no cover + use_insertmanyvalues = False + use_sentinel_columns = None if should_implicit_return_defaults: if not stmt._return_defaults_columns: + # TODO: this is weird. See #9685 where we have to + # take an extra step to prevent this from happening. why + # would this ever be *all* columns? but if we set to blank, then + # that seems to break things also in the ORM. So we should + # try to clean this up and figure out what return_defaults + # needs to do w/ the ORM etc. here implicit_return_defaults = set(stmt.table.c) else: implicit_return_defaults = set(stmt._return_defaults_columns) @@ -1551,6 +1628,8 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): implicit_returning or should_implicit_return_defaults, implicit_return_defaults, postfetch_lastrowid, + use_insertmanyvalues, + use_sentinel_columns, ) |