diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-15 15:20:21 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-10-16 08:47:47 -0400 |
commit | 2b966de4196c8271934769337780f7d504d431cf (patch) | |
tree | 608cf4c6400faf6dccefbaefbcdd2e0db1e9bdae /lib/sqlalchemy/sql/crud.py | |
parent | e8da50ce0f0474bc89cee15603931760cb6c55ce (diff) | |
download | sqlalchemy-2b966de4196c8271934769337780f7d504d431cf.tar.gz |
accommodate arbitrary embedded params in insertmanyvalues
Fixed bug in new "insertmanyvalues" feature where INSERT that included a
subquery with :func:`_sql.bindparam` inside of it would fail to render
correctly in "insertmanyvalues" format. This affected psycopg2 most
directly as "insertmanyvalues" is used unconditionally with this driver.
Fixes: #8639
Change-Id: I67903fa86afe208899d4f23f940e0727d1be2ce3
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 129 |
1 files changed, 98 insertions, 31 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 22fffb73a..31d127c2c 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -18,12 +18,14 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Iterable from typing import List from typing import MutableMapping from typing import NamedTuple from typing import Optional from typing import overload from typing import Sequence +from typing import Set from typing import Tuple from typing import TYPE_CHECKING from typing import Union @@ -49,6 +51,7 @@ if TYPE_CHECKING: from .dml import DMLState from .dml import ValuesBase from .elements import ColumnElement + from .elements import KeyedColumnElement from .schema import _SQLExprDefault REQUIRED = util.symbol( @@ -74,18 +77,32 @@ def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: return c -_CrudParamSequence = Sequence[ - Tuple[ - "ColumnElement[Any]", - str, - Optional[Union[str, "_SQLExprDefault"]], - ] +_CrudParamElement = Tuple[ + "ColumnElement[Any]", + str, + Optional[Union[str, "_SQLExprDefault"]], + Iterable[str], +] +_CrudParamElementStr = Tuple[ + "KeyedColumnElement[Any]", + str, + str, + Iterable[str], ] +_CrudParamElementSQLExpr = Tuple[ + "ColumnClause[Any]", + str, + "_SQLExprDefault", + Iterable[str], +] + +_CrudParamSequence = List[_CrudParamElement] class _CrudParams(NamedTuple): single_params: _CrudParamSequence - all_multi_params: List[Sequence[Tuple[ColumnClause[Any], str, str]]] + + all_multi_params: List[Sequence[_CrudParamElementStr]] def _get_crud_params( @@ -175,6 +192,7 @@ def _get_crud_params( c, compiler.preparer.format_column(c), _create_bind_param(compiler, c, None, required=True), + (c.key,), ) for c in stmt.table.columns ], @@ -220,9 +238,7 @@ def _get_crud_params( ) # create a list of column assignment clauses as tuples - values: List[ - Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]] - ] = [] + values: List[_CrudParamElement] = [] if stmt_parameter_tuples is not None: _get_stmt_parameter_tuples_params( @@ -307,7 +323,10 @@ def _get_crud_params( compiler, stmt, compile_state, - cast("Sequence[Tuple[ColumnClause[Any], str, str]]", values), + cast( + "Sequence[_CrudParamElementStr]", + values, + ), cast("Callable[..., str]", _column_as_key), kw, ) @@ -326,6 +345,7 @@ def _get_crud_params( _as_dml_column(stmt.table.columns[0]), compiler.preparer.format_column(stmt.table.columns[0]), compiler.dialect.default_metavalue_token, + (), ) ] @@ -488,7 +508,7 @@ def _scan_insert_from_select_cols( compiler.stack[-1]["insert_from_select"] = stmt.select - add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = [] + add_select_cols: List[_CrudParamElementSQLExpr] = [] if stmt.include_insert_from_select_defaults: col_set = set(cols) for col in stmt.table.columns: @@ -499,7 +519,7 @@ def _scan_insert_from_select_cols( col_key = _getattr_col_key(c) if col_key in parameters and col_key not in check_columns: parameters.pop(col_key) - values.append((c, compiler.preparer.format_column(c), None)) + values.append((c, compiler.preparer.format_column(c), None, ())) else: _append_param_insert_select_hasdefault( compiler, stmt, c, add_select_cols, kw @@ -513,7 +533,7 @@ def _scan_insert_from_select_cols( f"Can't extend statement for INSERT..FROM SELECT to include " f"additional default-holding column(s) " f"""{ - ', '.join(repr(key) for _, key, _ in add_select_cols) + ', '.join(repr(key) for _, key, _, _ in add_select_cols) }. Convert the selectable to a subquery() first, or pass """ "include_defaults=False to Insert.from_select() to skip these " "columns." @@ -521,7 +541,7 @@ def _scan_insert_from_select_cols( ins_from_select = ins_from_select._generate() # copy raw_columns ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [ - expr for _, _, expr in add_select_cols + expr for _, _, expr, _ in add_select_cols ] compiler.stack[-1]["insert_from_select"] = ins_from_select @@ -748,6 +768,8 @@ def _append_param_parameter( c, use_table=compile_state.include_table_with_column_exprs ) + accumulated_bind_names: Set[str] = set() + if coercions._is_literal(value): if ( @@ -772,6 +794,7 @@ def _append_param_parameter( if not _compile_state_isinsert(compile_state) or not compile_state._has_multi_parameters else "%s_m0" % _col_bind_name(c), + accumulate_bind_names=accumulated_bind_names, **kw, ) elif value._is_bind_parameter: @@ -796,11 +819,16 @@ def _append_param_parameter( if not _compile_state_isinsert(compile_state) or not compile_state._has_multi_parameters else "%s_m0" % _col_bind_name(c), + accumulate_bind_names=accumulated_bind_names, **kw, ) else: # value is a SQL expression - value = compiler.process(value.self_group(), **kw) + value = compiler.process( + value.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) if compile_state.isupdate: if implicit_return_defaults and c in implicit_return_defaults: @@ -828,7 +856,7 @@ def _append_param_parameter( compiler.postfetch.append(c) - values.append((c, col_value, value)) + values.append((c, col_value, value, accumulated_bind_names)) def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): @@ -843,20 +871,32 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): not c.default.optional or not compiler.dialect.sequences_optional ): + accumulated_bind_names: Set[str] = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default, **kw), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) compiler.implicit_returning.append(c) elif c.default.is_clause_element: + accumulated_bind_names = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default.arg.self_group(), **kw), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) compiler.implicit_returning.append(c) @@ -869,6 +909,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): c, compiler.preparer.format_column(c), _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) elif c is stmt.table._autoincrement_column or c.server_default is not None: @@ -936,6 +977,7 @@ def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw): c, compiler.preparer.format_column(c), _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) elif ( @@ -962,11 +1004,17 @@ def _append_param_insert_hasdefault( if compiler.dialect.supports_sequences and ( not c.default.optional or not compiler.dialect.sequences_optional ): + accumulated_bind_names: Set[str] = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default, **kw), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) if implicit_return_defaults and c in implicit_return_defaults: @@ -974,11 +1022,17 @@ def _append_param_insert_hasdefault( elif not c.primary_key: compiler.postfetch.append(c) elif c.default.is_clause_element: + accumulated_bind_names = set() values.append( ( c, compiler.preparer.format_column(c), - compiler.process(c.default.arg.self_group(), **kw), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, ) ) @@ -993,6 +1047,7 @@ def _append_param_insert_hasdefault( c, compiler.preparer.format_column(c), _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) @@ -1001,7 +1056,7 @@ def _append_param_insert_select_hasdefault( compiler: SQLCompiler, stmt: ValuesBase, c: ColumnClause[Any], - values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]], + values: List[_CrudParamElementSQLExpr], kw: Dict[str, Any], ) -> None: @@ -1014,6 +1069,7 @@ def _append_param_insert_select_hasdefault( c, compiler.preparer.format_column(c), c.default.next_value(), + (), ) ) elif default_is_clause_element(c.default): @@ -1022,6 +1078,7 @@ def _append_param_insert_select_hasdefault( c, compiler.preparer.format_column(c), c.default.arg.self_group(), + (), ) ) else: @@ -1032,6 +1089,7 @@ def _append_param_insert_select_hasdefault( _create_insert_prefetch_bind_param( compiler, c, process=False, **kw ), + (c.key,), ) ) @@ -1051,6 +1109,7 @@ def _append_param_update( use_table=include_table, ), compiler.process(c.onupdate.arg.self_group(), **kw), + (), ) ) if implicit_return_defaults and c in implicit_return_defaults: @@ -1066,6 +1125,7 @@ def _append_param_update( use_table=include_table, ), _create_update_prefetch_bind_param(compiler, c, **kw), + (c.key,), ) ) elif c.server_onupdate is not None: @@ -1177,7 +1237,7 @@ class _multiparam_column(elements.ColumnElement[Any]): def _process_multiparam_default_bind( compiler: SQLCompiler, stmt: ValuesBase, - c: ColumnClause[Any], + c: KeyedColumnElement[Any], index: int, kw: Dict[str, Any], ) -> str: @@ -1243,14 +1303,18 @@ def _get_update_multitable_params( name=_col_bind_name(c), **kw, # TODO: no test coverage for literal binds here ) + accumulated_bind_names: Iterable[str] = (c.key,) elif value._is_bind_parameter: + cbn = _col_bind_name(c) value = _handle_values_anonymous_param( - compiler, c, value, name=_col_bind_name(c), **kw + compiler, c, value, name=cbn, **kw ) + accumulated_bind_names = (cbn,) else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) - values.append((c, col_value, value)) + accumulated_bind_names = () + values.append((c, col_value, value, accumulated_bind_names)) # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: @@ -1266,6 +1330,7 @@ def _get_update_multitable_params( compiler.process( c.onupdate.arg.self_group(), **kw ), + (), ) ) compiler.postfetch.append(c) @@ -1277,6 +1342,7 @@ def _get_update_multitable_params( _create_update_prefetch_bind_param( compiler, c, name=_col_bind_name(c), **kw ), + (c.key,), ) ) elif c.server_onupdate is not None: @@ -1287,21 +1353,21 @@ def _extend_values_for_multiparams( compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState, - initial_values: Sequence[Tuple[ColumnClause[Any], str, str]], + initial_values: Sequence[_CrudParamElementStr], _column_as_key: Callable[..., str], kw: Dict[str, Any], -) -> List[Sequence[Tuple[ColumnClause[Any], str, str]]]: +) -> List[Sequence[_CrudParamElementStr]]: values_0 = initial_values values = [initial_values] mp = compile_state._multi_parameters assert mp is not None for i, row in enumerate(mp[1:]): - extension: List[Tuple[ColumnClause[Any], str, str]] = [] + extension: List[_CrudParamElementStr] = [] row = {_column_as_key(key): v for key, v in row.items()} - for (col, col_expr, param) in values_0: + for (col, col_expr, param, accumulated_names) in values_0: if col.key in row: key = col.key @@ -1320,7 +1386,7 @@ def _extend_values_for_multiparams( compiler, stmt, col, i, kw ) - extension.append((col, col_expr, new_param)) + extension.append((col, col_expr, new_param, accumulated_names)) values.append(extension) @@ -1366,7 +1432,8 @@ def _get_stmt_parameter_tuples_params( v = compiler.process(v.self_group(), **kw) - values.append((k, col_expr, v)) + # TODO: not sure if accumulated_bind_names applies here + values.append((k, col_expr, v, ())) def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): |