summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py93
1 files changed, 53 insertions, 40 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 81151a26b..b13377a59 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -32,6 +32,7 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .dml import isinsert as _compile_state_isinsert
from .elements import ColumnClause
from .schema import default_is_clause_element
from .schema import default_is_sequence
@@ -73,19 +74,18 @@ def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
return c
-class _CrudParams(NamedTuple):
- single_params: Sequence[
- Tuple[ColumnElement[Any], str, Optional[Union[str, _SQLExprDefault]]]
- ]
- all_multi_params: List[
- Sequence[
- Tuple[
- ColumnClause[Any],
- str,
- str,
- ]
- ]
+_CrudParamSequence = Sequence[
+ Tuple[
+ "ColumnElement[Any]",
+ str,
+ Optional[Union[str, "_SQLExprDefault"]],
]
+]
+
+
+class _CrudParams(NamedTuple):
+ single_params: _CrudParamSequence
+ all_multi_params: List[Sequence[Tuple[ColumnClause[Any], str, str]]]
def _get_crud_params(
@@ -144,6 +144,12 @@ def _get_crud_params(
compiler._get_bind_name_for_col = _col_bind_name
+ if stmt._returning and stmt._return_defaults:
+ raise exc.CompileError(
+ "Can't compile statement that includes returning() and "
+ "return_defaults() simultaneously"
+ )
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
@@ -164,7 +170,10 @@ def _get_crud_params(
]
spd: Optional[MutableMapping[_DMLColumnElement, Any]]
- if compile_state._has_multi_parameters:
+ if (
+ _compile_state_isinsert(compile_state)
+ and compile_state._has_multi_parameters
+ ):
mp = compile_state._multi_parameters
assert mp is not None
spd = mp[0]
@@ -227,7 +236,7 @@ def _get_crud_params(
kw,
)
- if compile_state.isinsert and stmt._select_names:
+ if _compile_state_isinsert(compile_state) and stmt._select_names:
# is an insert from select, is not a multiparams
assert not compile_state._has_multi_parameters
@@ -272,7 +281,10 @@ def _get_crud_params(
% (", ".join("%s" % (c,) for c in check))
)
- if compile_state._has_multi_parameters:
+ if (
+ _compile_state_isinsert(compile_state)
+ and compile_state._has_multi_parameters
+ ):
# is a multiparams, is not an insert from a select
assert not stmt._select_names
multi_extended_values = _extend_values_for_multiparams(
@@ -297,7 +309,7 @@ def _get_crud_params(
(
_as_dml_column(stmt.table.columns[0]),
compiler.preparer.format_column(stmt.table.columns[0]),
- "DEFAULT",
+ compiler.dialect.default_metavalue_token,
)
]
@@ -500,7 +512,7 @@ def _scan_insert_from_select_cols(
ins_from_select = ins_from_select._generate()
# copy raw_columns
ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [
- expr for col, col_expr, expr in add_select_cols
+ expr for _, _, expr in add_select_cols
]
compiler.stack[-1]["insert_from_select"] = ins_from_select
@@ -539,7 +551,8 @@ def _scan_cols(
else:
cols = stmt.table.columns
- if compile_state.isinsert and not compile_state._has_multi_parameters:
+ isinsert = _compile_state_isinsert(compile_state)
+ if isinsert and not compile_state._has_multi_parameters:
# new rules for #7998. fetch lastrowid or implicit returning
# for autoincrement column even if parameter is NULL, for DBs that
# override NULL param for primary key (sqlite, mysql/mariadb)
@@ -575,7 +588,7 @@ def _scan_cols(
kw,
)
- elif compile_state.isinsert:
+ elif isinsert:
# no parameter is present and it's an insert.
if c.primary_key and need_pks:
@@ -683,7 +696,8 @@ def _append_param_parameter(
value,
required=value is REQUIRED,
name=_col_bind_name(c)
- if not compile_state._has_multi_parameters
+ if not _compile_state_isinsert(compile_state)
+ or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw,
)
@@ -706,7 +720,8 @@ def _append_param_parameter(
c,
value,
name=_col_bind_name(c)
- if not compile_state._has_multi_parameters
+ if not _compile_state_isinsert(compile_state)
+ or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw,
)
@@ -922,11 +937,19 @@ def _append_param_insert_select_hasdefault(
not c.default.optional or not compiler.dialect.sequences_optional
):
values.append(
- (c, compiler.preparer.format_column(c), c.default.next_value())
+ (
+ c,
+ compiler.preparer.format_column(c),
+ c.default.next_value(),
+ )
)
elif default_is_clause_element(c.default):
values.append(
- (c, compiler.preparer.format_column(c), c.default.arg.self_group())
+ (
+ c,
+ compiler.preparer.format_column(c),
+ c.default.arg.self_group(),
+ )
)
else:
values.append(
@@ -1105,14 +1128,10 @@ def _process_multiparam_default_bind(
return compiler.process(c.default, **kw)
else:
col = _multiparam_column(c, index)
- if isinstance(stmt, dml.Insert):
- return _create_insert_prefetch_bind_param(
- compiler, col, process=True, **kw
- )
- else:
- return _create_update_prefetch_bind_param(
- compiler, col, process=True, **kw
- )
+ assert isinstance(stmt, dml.Insert)
+ return _create_insert_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
def _get_update_multitable_params(
@@ -1205,13 +1224,7 @@ def _extend_values_for_multiparams(
mp = compile_state._multi_parameters
assert mp is not None
for i, row in enumerate(mp[1:]):
- extension: List[
- Tuple[
- ColumnClause[Any],
- str,
- str,
- ]
- ] = []
+ extension: List[Tuple[ColumnClause[Any], str, str]] = []
row = {_column_as_key(key): v for key, v in row.items()}
@@ -1292,7 +1305,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
"""
need_pks = (
toplevel
- and compile_state.isinsert
+ and _compile_state_isinsert(compile_state)
and not stmt._inline
and (
not compiler.for_executemany
@@ -1348,7 +1361,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
if implicit_returning:
postfetch_lastrowid = False
- if compile_state.isinsert:
+ if _compile_state_isinsert(compile_state):
implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compile_state.isupdate:
implicit_return_defaults = (