summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-10-15 15:20:21 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2022-10-16 08:47:47 -0400
commit2b966de4196c8271934769337780f7d504d431cf (patch)
tree608cf4c6400faf6dccefbaefbcdd2e0db1e9bdae /lib/sqlalchemy/sql/crud.py
parente8da50ce0f0474bc89cee15603931760cb6c55ce (diff)
downloadsqlalchemy-2b966de4196c8271934769337780f7d504d431cf.tar.gz
accommodate arbitrary embedded params in insertmanyvalues
Fixed bug in new "insertmanyvalues" feature where INSERT that included a subquery with :func:`_sql.bindparam` inside of it would fail to render correctly in "insertmanyvalues" format. This affected psycopg2 most directly as "insertmanyvalues" is used unconditionally with this driver. Fixes: #8639 Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py129
1 files changed, 98 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 22fffb73a..31d127c2c 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -18,12 +18,14 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
+from typing import Iterable
from typing import List
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
from typing import Sequence
+from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
@@ -49,6 +51,7 @@ if TYPE_CHECKING:
from .dml import DMLState
from .dml import ValuesBase
from .elements import ColumnElement
+ from .elements import KeyedColumnElement
from .schema import _SQLExprDefault
REQUIRED = util.symbol(
@@ -74,18 +77,32 @@ def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]:
return c
-_CrudParamSequence = Sequence[
- Tuple[
- "ColumnElement[Any]",
- str,
- Optional[Union[str, "_SQLExprDefault"]],
- ]
+_CrudParamElement = Tuple[
+ "ColumnElement[Any]",
+ str,
+ Optional[Union[str, "_SQLExprDefault"]],
+ Iterable[str],
+]
+_CrudParamElementStr = Tuple[
+ "KeyedColumnElement[Any]",
+ str,
+ str,
+ Iterable[str],
]
+_CrudParamElementSQLExpr = Tuple[
+ "ColumnClause[Any]",
+ str,
+ "_SQLExprDefault",
+ Iterable[str],
+]
+
+_CrudParamSequence = List[_CrudParamElement]
class _CrudParams(NamedTuple):
single_params: _CrudParamSequence
- all_multi_params: List[Sequence[Tuple[ColumnClause[Any], str, str]]]
+
+ all_multi_params: List[Sequence[_CrudParamElementStr]]
def _get_crud_params(
@@ -175,6 +192,7 @@ def _get_crud_params(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
+ (c.key,),
)
for c in stmt.table.columns
],
@@ -220,9 +238,7 @@ def _get_crud_params(
)
# create a list of column assignment clauses as tuples
- values: List[
- Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
- ] = []
+ values: List[_CrudParamElement] = []
if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
@@ -307,7 +323,10 @@ def _get_crud_params(
compiler,
stmt,
compile_state,
- cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values),
+ cast(
+ "Sequence[_CrudParamElementStr]",
+ values,
+ ),
cast("Callable[..., str]", _column_as_key),
kw,
)
@@ -326,6 +345,7 @@ def _get_crud_params(
_as_dml_column(stmt.table.columns[0]),
compiler.preparer.format_column(stmt.table.columns[0]),
compiler.dialect.default_metavalue_token,
+ (),
)
]
@@ -488,7 +508,7 @@ def _scan_insert_from_select_cols(
compiler.stack[-1]["insert_from_select"] = stmt.select
- add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
+ add_select_cols: List[_CrudParamElementSQLExpr] = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
@@ -499,7 +519,7 @@ def _scan_insert_from_select_cols(
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
parameters.pop(col_key)
- values.append((c, compiler.preparer.format_column(c), None))
+ values.append((c, compiler.preparer.format_column(c), None, ()))
else:
_append_param_insert_select_hasdefault(
compiler, stmt, c, add_select_cols, kw
@@ -513,7 +533,7 @@ def _scan_insert_from_select_cols(
f"Can't extend statement for INSERT..FROM SELECT to include "
f"additional default-holding column(s) "
f"""{
- ', '.join(repr(key) for _, key, _ in add_select_cols)
+ ', '.join(repr(key) for _, key, _, _ in add_select_cols)
}. Convert the selectable to a subquery() first, or pass """
"include_defaults=False to Insert.from_select() to skip these "
"columns."
@@ -521,7 +541,7 @@ def _scan_insert_from_select_cols(
ins_from_select = ins_from_select._generate()
# copy raw_columns
ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [
- expr for _, _, expr in add_select_cols
+ expr for _, _, expr, _ in add_select_cols
]
compiler.stack[-1]["insert_from_select"] = ins_from_select
@@ -748,6 +768,8 @@ def _append_param_parameter(
c, use_table=compile_state.include_table_with_column_exprs
)
+ accumulated_bind_names: Set[str] = set()
+
if coercions._is_literal(value):
if (
@@ -772,6 +794,7 @@ def _append_param_parameter(
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
+ accumulate_bind_names=accumulated_bind_names,
**kw,
)
elif value._is_bind_parameter:
@@ -796,11 +819,16 @@ def _append_param_parameter(
if not _compile_state_isinsert(compile_state)
or not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
+ accumulate_bind_names=accumulated_bind_names,
**kw,
)
else:
# value is a SQL expression
- value = compiler.process(value.self_group(), **kw)
+ value = compiler.process(
+ value.self_group(),
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ )
if compile_state.isupdate:
if implicit_return_defaults and c in implicit_return_defaults:
@@ -828,7 +856,7 @@ def _append_param_parameter(
compiler.postfetch.append(c)
- values.append((c, col_value, value))
+ values.append((c, col_value, value, accumulated_bind_names))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
@@ -843,20 +871,32 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
not c.default.optional
or not compiler.dialect.sequences_optional
):
+ accumulated_bind_names: Set[str] = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default, **kw),
+ compiler.process(
+ c.default,
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
compiler.implicit_returning.append(c)
elif c.default.is_clause_element:
+ accumulated_bind_names = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default.arg.self_group(), **kw),
+ compiler.process(
+ c.default.arg.self_group(),
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
compiler.implicit_returning.append(c)
@@ -869,6 +909,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
elif c is stmt.table._autoincrement_column or c.server_default is not None:
@@ -936,6 +977,7 @@ def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw):
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
elif (
@@ -962,11 +1004,17 @@ def _append_param_insert_hasdefault(
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
+ accumulated_bind_names: Set[str] = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default, **kw),
+ compiler.process(
+ c.default,
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
if implicit_return_defaults and c in implicit_return_defaults:
@@ -974,11 +1022,17 @@ def _append_param_insert_hasdefault(
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
+ accumulated_bind_names = set()
values.append(
(
c,
compiler.preparer.format_column(c),
- compiler.process(c.default.arg.self_group(), **kw),
+ compiler.process(
+ c.default.arg.self_group(),
+ accumulate_bind_names=accumulated_bind_names,
+ **kw,
+ ),
+ accumulated_bind_names,
)
)
@@ -993,6 +1047,7 @@ def _append_param_insert_hasdefault(
c,
compiler.preparer.format_column(c),
_create_insert_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
@@ -1001,7 +1056,7 @@ def _append_param_insert_select_hasdefault(
compiler: SQLCompiler,
stmt: ValuesBase,
c: ColumnClause[Any],
- values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+ values: List[_CrudParamElementSQLExpr],
kw: Dict[str, Any],
) -> None:
@@ -1014,6 +1069,7 @@ def _append_param_insert_select_hasdefault(
c,
compiler.preparer.format_column(c),
c.default.next_value(),
+ (),
)
)
elif default_is_clause_element(c.default):
@@ -1022,6 +1078,7 @@ def _append_param_insert_select_hasdefault(
c,
compiler.preparer.format_column(c),
c.default.arg.self_group(),
+ (),
)
)
else:
@@ -1032,6 +1089,7 @@ def _append_param_insert_select_hasdefault(
_create_insert_prefetch_bind_param(
compiler, c, process=False, **kw
),
+ (c.key,),
)
)
@@ -1051,6 +1109,7 @@ def _append_param_update(
use_table=include_table,
),
compiler.process(c.onupdate.arg.self_group(), **kw),
+ (),
)
)
if implicit_return_defaults and c in implicit_return_defaults:
@@ -1066,6 +1125,7 @@ def _append_param_update(
use_table=include_table,
),
_create_update_prefetch_bind_param(compiler, c, **kw),
+ (c.key,),
)
)
elif c.server_onupdate is not None:
@@ -1177,7 +1237,7 @@ class _multiparam_column(elements.ColumnElement[Any]):
def _process_multiparam_default_bind(
compiler: SQLCompiler,
stmt: ValuesBase,
- c: ColumnClause[Any],
+ c: KeyedColumnElement[Any],
index: int,
kw: Dict[str, Any],
) -> str:
@@ -1243,14 +1303,18 @@ def _get_update_multitable_params(
name=_col_bind_name(c),
**kw, # TODO: no test coverage for literal binds here
)
+ accumulated_bind_names: Iterable[str] = (c.key,)
elif value._is_bind_parameter:
+ cbn = _col_bind_name(c)
value = _handle_values_anonymous_param(
- compiler, c, value, name=_col_bind_name(c), **kw
+ compiler, c, value, name=cbn, **kw
)
+ accumulated_bind_names = (cbn,)
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
- values.append((c, col_value, value))
+ accumulated_bind_names = ()
+ values.append((c, col_value, value, accumulated_bind_names))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
@@ -1266,6 +1330,7 @@ def _get_update_multitable_params(
compiler.process(
c.onupdate.arg.self_group(), **kw
),
+ (),
)
)
compiler.postfetch.append(c)
@@ -1277,6 +1342,7 @@ def _get_update_multitable_params(
_create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c), **kw
),
+ (c.key,),
)
)
elif c.server_onupdate is not None:
@@ -1287,21 +1353,21 @@ def _extend_values_for_multiparams(
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
- initial_values: Sequence[Tuple[ColumnClause[Any], str, str]],
+ initial_values: Sequence[_CrudParamElementStr],
_column_as_key: Callable[..., str],
kw: Dict[str, Any],
-) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]:
+) -> List[Sequence[_CrudParamElementStr]]:
values_0 = initial_values
values = [initial_values]
mp = compile_state._multi_parameters
assert mp is not None
for i, row in enumerate(mp[1:]):
- extension: List[Tuple[ColumnClause[Any], str, str]] = []
+ extension: List[_CrudParamElementStr] = []
row = {_column_as_key(key): v for key, v in row.items()}
- for (col, col_expr, param) in values_0:
+ for (col, col_expr, param, accumulated_names) in values_0:
if col.key in row:
key = col.key
@@ -1320,7 +1386,7 @@ def _extend_values_for_multiparams(
compiler, stmt, col, i, kw
)
- extension.append((col, col_expr, new_param))
+ extension.append((col, col_expr, new_param, accumulated_names))
values.append(extension)
@@ -1366,7 +1432,8 @@ def _get_stmt_parameter_tuples_params(
v = compiler.process(v.self_group(), **kw)
- values.append((k, col_expr, v))
+ # TODO: not sure if accumulated_bind_names applies here
+ values.append((k, col_expr, v, ()))
def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):