diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 350 |
1 files changed, 328 insertions, 22 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1d13ffa9a..201324a2a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -94,6 +94,7 @@ if typing.TYPE_CHECKING: from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import Label from .functions import Function from .selectable import AliasedReturnsRows @@ -390,6 +391,19 @@ class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): class ExpandedState(NamedTuple): + """represents state to use when producing "expanded" and + "post compile" bound parameters for a statement. + + "expanded" parameters are parameters that are generated at + statement execution time to suit a number of parameters passed, the most + prominent example being the individual elements inside of an IN expression. + + "post compile" parameters are parameters where the SQL literal value + will be rendered into the SQL statement at execution time, rather than + being passed as separate parameters to the driver. + + """ + statement: str additional_parameters: _CoreSingleExecuteParams processors: Mapping[str, _BindProcessorType[Any]] @@ -397,7 +411,23 @@ class ExpandedState(NamedTuple): parameter_expansion: Mapping[str, List[str]] +class _InsertManyValues(NamedTuple): + """represents state to use for executing an "insertmanyvalues" statement""" + + is_default_expr: bool + single_values_expr: str + insert_crud_params: List[Tuple[KeyedColumnElement[Any], str, str]] + num_positional_params_counted: int + + class Linting(IntEnum): + """represent preferences for the 'SQL linting' feature. + + this feature currently includes support for flagging cartesian products + in SQL statements. + + """ + NO_LINTING = 0 "Disable all linting." @@ -419,6 +449,9 @@ NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple( class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + """represents current state for the "cartesian product" detection + feature.""" + def lint(self, start=None): froms = self.froms if not froms: @@ -762,8 +795,6 @@ class SQLCompiler(Compiled): is_sql = True - _result_columns: List[ResultColumnsEntry] - compound_keywords = COMPOUND_KEYWORDS isdelete: bool = False @@ -810,12 +841,6 @@ class SQLCompiler(Compiled): """major statements such as SELECT, INSERT, UPDATE, DELETE are tracked in this stack using an entry format.""" - result_columns: List[ResultColumnsEntry] - """relates label names in the final SQL to a tuple of local - column/label name, ColumnElement object (if any) and - TypeEngine. CursorResult uses this for type processing and - column targeting""" - returning_precedes_values: bool = False """set to True classwide to generate RETURNING clauses before the VALUES or WHERE clause (i.e. MSSQL) @@ -835,6 +860,12 @@ class SQLCompiler(Compiled): driver/DB enforces this """ + _result_columns: List[ResultColumnsEntry] + """relates label names in the final SQL to a tuple of local + column/label name, ColumnElement object (if any) and + TypeEngine. CursorResult uses this for type processing and + column targeting""" + _textual_ordered_columns: bool = False """tell the result object that the column names as rendered are important, but they are also "ordered" vs. what is in the compiled object here. @@ -881,14 +912,9 @@ class SQLCompiler(Compiled): """ - insert_single_values_expr: Optional[str] = None - """When an INSERT is compiled with a single set of parameters inside - a VALUES expression, the string is assigned here, where it can be - used for insert batching schemes to rewrite the VALUES expression. + _insertmanyvalues: Optional[_InsertManyValues] = None - .. versionadded:: 1.3.8 - - """ + _insert_crud_params: Optional[crud._CrudParamSequence] = None literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() """bindparameter objects that are rendered as literal values at statement @@ -1072,6 +1098,25 @@ class SQLCompiler(Compiled): if self._render_postcompile: self._process_parameters_for_postcompile(_populate_self=True) + @property + def insert_single_values_expr(self) -> Optional[str]: + """When an INSERT is compiled with a single set of parameters inside + a VALUES expression, the string is assigned here, where it can be + used for insert batching schemes to rewrite the VALUES expression. + + .. versionadded:: 1.3.8 + + .. versionchanged:: 2.0 This collection is no longer used by + SQLAlchemy's built-in dialects, in favor of the currently + internal ``_insertmanyvalues`` collection that is used only by + :class:`.SQLCompiler`. + + """ + if self._insertmanyvalues is None: + return None + else: + return self._insertmanyvalues.single_values_expr + @util.ro_memoized_property def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]: """The effective "returning" columns for INSERT, UPDATE or DELETE. @@ -1620,10 +1665,13 @@ class SQLCompiler(Compiled): param_key_getter = self._within_exec_param_key_getter + assert self.compile_state is not None + statement = self.compile_state.statement + if TYPE_CHECKING: - assert isinstance(self.statement, Insert) + assert isinstance(statement, Insert) - table = self.statement.table + table = statement.table getters = [ (operator.methodcaller("get", param_key_getter(col), None), col) @@ -1697,11 +1745,14 @@ class SQLCompiler(Compiled): else: result = util.preloaded.engine_result + assert self.compile_state is not None + statement = self.compile_state.statement + if TYPE_CHECKING: - assert isinstance(self.statement, Insert) + assert isinstance(statement, Insert) param_key_getter = self._within_exec_param_key_getter - table = self.statement.table + table = statement.table returning = self.implicit_returning assert returning is not None @@ -4506,7 +4557,202 @@ class SQLCompiler(Compiled): ) return dialect_hints, table_text + def _insert_stmt_should_use_insertmanyvalues(self, statement): + return ( + self.dialect.supports_multivalues_insert + and self.dialect.use_insertmanyvalues + # note self.implicit_returning or self._result_columns + # implies self.dialect.insert_returning capability + and ( + self.dialect.use_insertmanyvalues_wo_returning + or self.implicit_returning + or self._result_columns + ) + ) + + def _deliver_insertmanyvalues_batches( + self, statement, parameters, generic_setinputsizes, batch_size + ): + imv = self._insertmanyvalues + assert imv is not None + + executemany_values = f"({imv.single_values_expr})" + + lenparams = len(parameters) + if imv.is_default_expr and not self.dialect.supports_default_metavalue: + # backend doesn't support + # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ... + # at the moment this is basically SQL Server due to + # not being able to use DEFAULT for identity column + # just yield out that many single statements! still + # faster than a whole connection.execute() call ;) + # + # note we still are taking advantage of the fact that we know + # we are using RETURNING. The generalized approach of fetching + # cursor.lastrowid etc. still goes through the more heavyweight + # "ExecutionContext per statement" system as it isn't usable + # as a generic "RETURNING" approach + for batchnum, param in enumerate(parameters, 1): + yield ( + statement, + param, + generic_setinputsizes, + batchnum, + lenparams, + ) + return + else: + statement = statement.replace( + executemany_values, "__EXECMANY_TOKEN__" + ) + + # Use optional insertmanyvalues_max_parameters + # to further shrink the batch size so that there are no more than + # insertmanyvalues_max_parameters params. + # Currently used by SQL Server, which limits statements to 2100 bound + # parameters (actually 2099). + max_params = self.dialect.insertmanyvalues_max_parameters + if max_params: + total_num_of_params = len(self.bind_names) + num_params_per_batch = len(imv.insert_crud_params) + num_params_outside_of_batch = ( + total_num_of_params - num_params_per_batch + ) + batch_size = min( + batch_size, + ( + (max_params - num_params_outside_of_batch) + // num_params_per_batch + ), + ) + + batches = list(parameters) + + processed_setinputsizes = None + batchnum = 1 + total_batches = lenparams // batch_size + ( + 1 if lenparams % batch_size else 0 + ) + + insert_crud_params = imv.insert_crud_params + assert insert_crud_params is not None + + escaped_bind_names: Mapping[str, str] + if not self.positional: + if self.escaped_bind_names: + escaped_bind_names = self.escaped_bind_names + else: + escaped_bind_names = {} + + all_keys = set(parameters[0]) + + escaped_insert_crud_params: Sequence[Any] = [ + (escaped_bind_names.get(col.key, col.key), formatted) + for col, _, formatted in insert_crud_params + ] + + keys_to_replace = all_keys.intersection( + key for key, _ in escaped_insert_crud_params + ) + base_parameters = { + key: parameters[0][key] + for key in all_keys.difference(keys_to_replace) + } + executemany_values_w_comma = "" + else: + escaped_insert_crud_params = () + keys_to_replace = set() + base_parameters = {} + executemany_values_w_comma = f"({imv.single_values_expr}), " + + while batches: + batch = batches[0:batch_size] + batches[0:batch_size] = [] + + if generic_setinputsizes: + # if setinputsizes is present, expand this collection to + # suit the batch length as well + # currently this will be mssql+pyodbc for internal dialects + processed_setinputsizes = [ + (new_key, len_, typ) + for new_key, len_, typ in ( + (f"{key}_{index}", len_, typ) + for index in range(len(batch)) + for key, len_, typ in generic_setinputsizes + ) + ] + + replaced_parameters: Any + if self.positional: + # the assumption here is that any parameters that are not + # in the VALUES clause are expected to be parameterized + # expressions in the RETURNING (or maybe ON CONFLICT) clause. + # So based on + # which sequence comes first in the compiler's INSERT + # statement tells us where to expand the parameters. + + # otherwise we probably shouldn't be doing insertmanyvalues + # on the statement. + + num_ins_params = imv.num_positional_params_counted + + if num_ins_params == len(batch[0]): + extra_params = () + batch_iterator: Iterable[Tuple[Any, ...]] = batch + elif self.returning_precedes_values: + extra_params = batch[0][:-num_ins_params] + batch_iterator = (b[-num_ins_params:] for b in batch) + else: + extra_params = batch[0][num_ins_params:] + batch_iterator = (b[:num_ins_params] for b in batch) + + replaced_statement = statement.replace( + "__EXECMANY_TOKEN__", + (executemany_values_w_comma * len(batch))[:-2], + ) + + replaced_parameters = tuple( + itertools.chain.from_iterable(batch_iterator) + ) + if self.returning_precedes_values: + replaced_parameters = extra_params + replaced_parameters + else: + replaced_parameters = replaced_parameters + extra_params + else: + replaced_values_clauses = [] + replaced_parameters = base_parameters.copy() + + for i, param in enumerate(batch): + new_tokens = [ + formatted.replace(key, f"{key}__{i}") + if key in param + else formatted + for key, formatted in escaped_insert_crud_params + ] + replaced_values_clauses.append( + f"({', '.join(new_tokens)})" + ) + + replaced_parameters.update( + {f"{key}__{i}": param[key] for key in keys_to_replace} + ) + + replaced_statement = statement.replace( + "__EXECMANY_TOKEN__", + ", ".join(replaced_values_clauses), + ) + + yield ( + replaced_statement, + replaced_parameters, + processed_setinputsizes, + batchnum, + total_batches, + ) + batchnum += 1 + def visit_insert(self, insert_stmt, **kw): + compile_state = insert_stmt._compile_state_factory( insert_stmt, self, **kw ) @@ -4529,9 +4775,24 @@ class SQLCompiler(Compiled): } ) + positiontup_before = positiontup_after = 0 + + # for positional, insertmanyvalues needs to know how many + # bound parameters are in the VALUES sequence; there's no simple + # rule because default expressions etc. can have zero or more + # params inside them. After multiple attempts to figure this out, + # this very simplistic "count before, then count after" works and is + # likely the least amount of callcounts, though looks clumsy + if self.positiontup: + positiontup_before = len(self.positiontup) + crud_params_struct = crud._get_crud_params( self, insert_stmt, compile_state, toplevel, **kw ) + + if self.positiontup: + positiontup_after = len(self.positiontup) + crud_params_single = crud_params_struct.single_params if ( @@ -4584,14 +4845,34 @@ class SQLCompiler(Compiled): ) if self.implicit_returning or insert_stmt._returning: + + # if returning clause is rendered first, capture bound parameters + # while visiting and place them prior to the VALUES oriented + # bound parameters, when using positional parameter scheme + rpv = self.returning_precedes_values + flip_pt = rpv and self.positional + if flip_pt: + pt: Optional[List[str]] = self.positiontup + temp_pt: Optional[List[str]] + self.positiontup = temp_pt = [] + else: + temp_pt = pt = None + returning_clause = self.returning_clause( insert_stmt, self.implicit_returning or insert_stmt._returning, populate_result_map=toplevel, ) - if self.returning_precedes_values: + if flip_pt: + if TYPE_CHECKING: + assert temp_pt is not None + assert pt is not None + self.positiontup = temp_pt + pt + + if rpv: text += " " + returning_clause + else: returning_clause = None @@ -4614,6 +4895,18 @@ class SQLCompiler(Compiled): text += " %s" % select_text elif not crud_params_single and supports_default_values: text += " DEFAULT VALUES" + if toplevel and self._insert_stmt_should_use_insertmanyvalues( + insert_stmt + ): + self._insertmanyvalues = _InsertManyValues( + True, + self.dialect.default_metavalue_token, + cast( + "List[Tuple[KeyedColumnElement[Any], str, str]]", + crud_params_single, + ), + (positiontup_after - positiontup_before), + ) elif compile_state._has_multi_parameters: text += " VALUES %s" % ( ", ".join( @@ -4623,6 +4916,8 @@ class SQLCompiler(Compiled): ) ) else: + # TODO: why is third element of crud_params_single not str + # already? insert_single_values_expr = ", ".join( [ value @@ -4631,9 +4926,20 @@ class SQLCompiler(Compiled): ) ] ) + text += " VALUES (%s)" % insert_single_values_expr - if toplevel: - self.insert_single_values_expr = insert_single_values_expr + if toplevel and self._insert_stmt_should_use_insertmanyvalues( + insert_stmt + ): + self._insertmanyvalues = _InsertManyValues( + False, + insert_single_values_expr, + cast( + "List[Tuple[KeyedColumnElement[Any], str, str]]", + crud_params_single, + ), + positiontup_after - positiontup_before, + ) if insert_stmt._post_values_clause is not None: post_values_clause = self.process( |