diff options
author | Federico Caselli <cfederico87@gmail.com> | 2022-12-02 11:58:40 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-12-05 09:59:01 -0500 |
commit | 06c234d037bdab48e716d6c5f5dc200095269474 (patch) | |
tree | 8ed48e0627f0e4816b7e26f9e6294330f1ba19d6 /lib/sqlalchemy/sql | |
parent | 9058593e0b28cee0211251de6604e4601ff69a00 (diff) | |
download | sqlalchemy-06c234d037bdab48e716d6c5f5dc200095269474.tar.gz |
Rewrite positional handling, test for "numeric"
Changed how the positional compilation is performed. It's rendered by the compiler
the same as the pyformat compilation. The string is then processed to replace
the placeholders with the correct ones, and to obtain the correct order of the
parameters.
This vastly simplifies the computation of the order of the parameters, that in
case of nested CTE is very hard to compute correctly.
Reworked how numeric paramstyle behavers:
- added support for repeated parameter, without duplicating them like in normal
positional dialects
- implement insertmany support. This requires that the dialect supports out of
order placehoders, since all parameters that are not part of the VALUES clauses
are placed at the beginning of the parameter tuple
- support for different identifiers for a numeric parameter. It's for example
possible to use postgresql style placeholder $1, $2, etc
Added two new dialect based on sqlite to test "numeric" fully using
both :1 style and $1 style. Includes a workaround for SQLite's
not-really-correct numeric implementation.
Changed parmstyle of asyncpg dialect to use numeric, rendering with its native
$ identifiers
Fixes: #8926
Fixes: #8849
Change-Id: I7c640467d49adfe6d795cc84296fc7403dcad4d6
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 470 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 4 |
2 files changed, 306 insertions, 168 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7ac279ee2..d7358ad3b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -227,11 +227,13 @@ FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) +_pyformat_template = "%%(%(name)s)s" BIND_TEMPLATES = { - "pyformat": "%%(%(name)s)s", + "pyformat": _pyformat_template, "qmark": "?", "format": "%%s", "numeric": ":[_POSITION]", + "numeric_dollar": "$[_POSITION]", "named": ":%(name)s", } @@ -420,6 +422,22 @@ class _InsertManyValues(NamedTuple): num_positional_params_counted: int +class CompilerState(IntEnum): + COMPILING = 0 + """statement is present, compilation phase in progress""" + + STRING_APPLIED = 1 + """statement is present, string form of the statement has been applied. + + Additional processors by subclasses may still be pending. + + """ + + NO_STATEMENT = 2 + """compiler does not have a statement to compile, is used + for method access""" + + class Linting(IntEnum): """represent preferences for the 'SQL linting' feature. @@ -527,6 +545,14 @@ class Compiled: defaults. """ + statement: Optional[ClauseElement] = None + "The statement to compile." + string: str = "" + "The string representation of the ``statement``" + + state: CompilerState + """description of the compiler's state""" + is_sql = False is_ddl = False @@ -618,7 +644,6 @@ class Compiled: """ - self.dialect = dialect self.preparer = self.dialect.identifier_preparer if schema_translate_map: @@ -628,6 +653,7 @@ class Compiled: ) if statement is not None: + self.state = CompilerState.COMPILING self.statement = statement self.can_execute = statement.supports_execution self._annotations = statement._annotations @@ -641,6 +667,11 @@ class Compiled: self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) + + self.state = CompilerState.STRING_APPLIED + else: + self.state = CompilerState.NO_STATEMENT + self._gen_time = perf_counter() def _execute_on_connection( @@ -672,7 +703,10 @@ class Compiled: def __str__(self) -> str: """Return the string text of the generated SQL or DDL.""" - return self.string or "" + if self.state is CompilerState.STRING_APPLIED: + return self.string + else: + return "" def construct_params( self, @@ -859,6 +893,19 @@ class SQLCompiler(Compiled): driver/DB enforces this """ + bindtemplate: str + """template to render bound parameters based on paramstyle.""" + + compilation_bindtemplate: str + """template used by compiler to render parameters before positional + paramstyle application""" + + _numeric_binds_identifier_char: str + """Character that's used to as the identifier of a numerical bind param. + For example if this char is set to ``$``, numerical binds will be rendered + in the form ``$1, $2, $3``. + """ + _result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -967,13 +1014,17 @@ class SQLCompiler(Compiled): and is combined with the :attr:`_sql.Compiled.params` dictionary to render parameters. + This sequence always contains the unescaped name of the parameters. + .. seealso:: :ref:`faq_sql_expression_string` - includes a usage example for debugging use cases. """ - positiontup_level: Optional[Dict[str, int]] = None + _values_bindparam: Optional[List[str]] = None + + _visited_bindparam: Optional[List[str]] = None inline: bool = False @@ -988,9 +1039,12 @@ class SQLCompiler(Compiled): level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] ctes_recursive: bool - cte_positional: Dict[CTE, List[str]] - cte_level: Dict[CTE, int] - cte_order: Dict[Optional[CTE], List[CTE]] + + _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]") + _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s") + _positional_pattern = re.compile( + f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" + ) def __init__( self, @@ -1055,10 +1109,15 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: - self.positiontup_level = {} - self.positiontup = [] - self._numeric_binds = dialect.paramstyle == "numeric" - self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + self._numeric_binds = nb = dialect.paramstyle.startswith("numeric") + if nb: + self._numeric_binds_identifier_char = ( + "$" if dialect.paramstyle == "numeric_dollar" else ":" + ) + + self.compilation_bindtemplate = _pyformat_template + else: + self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.ctes = None @@ -1095,11 +1154,17 @@ class SQLCompiler(Compiled): ): self.inline = True - if self.positional and self._numeric_binds: - self._apply_numbered_params() + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + if self.state is CompilerState.STRING_APPLIED: + if self.positional: + if self._numeric_binds: + self._process_numeric() + else: + self._process_positional() - if self._render_postcompile: - self._process_parameters_for_postcompile(_populate_self=True) + if self._render_postcompile: + self._process_parameters_for_postcompile(_populate_self=True) @property def insert_single_values_expr(self) -> Optional[str]: @@ -1135,7 +1200,7 @@ class SQLCompiler(Compiled): """ if self.implicit_returning: return self.implicit_returning - elif is_dml(self.statement): + elif self.statement is not None and is_dml(self.statement): return [ c for c in self.statement._all_selected_columns @@ -1217,10 +1282,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte = {} self.ctes_recursive = False - if self.positional: - self.cte_positional = {} - self.cte_level = {} - self.cte_order = collections.defaultdict(list) return ctes @@ -1248,12 +1309,145 @@ class SQLCompiler(Compiled): ordered_columns, ) - def _apply_numbered_params(self): - poscount = itertools.count(1) + def _process_positional(self): + assert not self.positiontup + assert self.state is CompilerState.STRING_APPLIED + assert not self._numeric_binds + + if self.dialect.paramstyle == "format": + placeholder = "%s" + else: + assert self.dialect.paramstyle == "qmark" + placeholder = "?" + + positions = [] + + def find_position(m: re.Match[str]) -> str: + normal_bind = m.group(1) + if normal_bind: + positions.append(normal_bind) + return placeholder + else: + # this a post-compile bind + positions.append(m.group(2)) + return m.group(0) + self.string = re.sub( - r"\[_POSITION\]", lambda m: str(next(poscount)), self.string + self._positional_pattern, find_position, self.string ) + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + self.positiontup = [ + reverse_escape.get(name, name) for name in positions + ] + else: + self.positiontup = positions + + if self._insertmanyvalues: + positions = [] + single_values_expr = re.sub( + self._positional_pattern, + find_position, + self._insertmanyvalues.single_values_expr, + ) + insert_crud_params = [ + ( + v[0], + v[1], + re.sub(self._positional_pattern, find_position, v[2]), + v[3], + ) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + single_values_expr=single_values_expr, + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + + def _process_numeric(self): + assert self._numeric_binds + assert self.state is CompilerState.STRING_APPLIED + + num = 1 + param_pos: Dict[str, str] = {} + order: Iterable[str] + if self._insertmanyvalues and self._values_bindparam is not None: + # bindparams that are not in values are always placed first. + # this avoids the need of changing them when using executemany + # values () () + order = itertools.chain( + ( + name + for name in self.bind_names.values() + if name not in self._values_bindparam + ), + self.bind_names.values(), + ) + else: + order = self.bind_names.values() + + for bind_name in order: + if bind_name in param_pos: + continue + bind = self.binds[bind_name] + if ( + bind in self.post_compile_params + or bind in self.literal_execute_params + ): + # set to None to just mark the in positiontup, it will not + # be replaced below. + param_pos[bind_name] = None # type: ignore + else: + ph = f"{self._numeric_binds_identifier_char}{num}" + num += 1 + param_pos[bind_name] = ph + + self.next_numeric_pos = num + + self.positiontup = list(param_pos) + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + param_pos = { + self.escaped_bind_names.get(name, name): pos + for name, pos in param_pos.items() + } + + # Can't use format here since % chars are not escaped. + self.string = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], self.string + ) + + if self._insertmanyvalues: + single_values_expr = ( + # format is ok here since single_values_expr includes only + # place-holders + self._insertmanyvalues.single_values_expr + % param_pos + ) + insert_crud_params = [ + (v[0], v[1], "%s", v[3]) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + # This has the numbers (:1, :2) + single_values_expr=single_values_expr, + # The single binds are instead %s so they can be formatted + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + @util.memoized_property def _bind_processors( self, @@ -1492,39 +1686,30 @@ class SQLCompiler(Compiled): new_processors: Dict[str, _BindProcessorType[Any]] = {} - if self.positional and self._numeric_binds: - # I'm not familiar with any DBAPI that uses 'numeric'. - # strategy would likely be to make use of numbers greater than - # the highest number present; then for expanding parameters, - # append them to the end of the parameter list. that way - # we avoid having to renumber all the existing parameters. - raise NotImplementedError( - "'post-compile' bind parameters are not supported with " - "the 'numeric' paramstyle at this time." - ) - replacement_expressions: Dict[str, Any] = {} to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: - # self.bind_names, self.binds, self._bind_processors + # self.bind_names, self.binds, self._bind_processors, self.positiontup # # *escaped* parameter names in: # construct_params(), replacement_expressions + numeric_positiontup: Optional[List[str]] = None + if self.positional and self.positiontup is not None: names: Iterable[str] = self.positiontup + if self._numeric_binds: + numeric_positiontup = [] else: names = self.bind_names.values() + ebn = self.escaped_bind_names for name in names: - escaped_name = ( - self.escaped_bind_names.get(name, name) - if self.escaped_bind_names - else name - ) + escaped_name = ebn.get(name, name) if ebn else name parameter = self.binds[name] + if parameter in self.literal_execute_params: if escaped_name not in replacement_expressions: value = parameters.pop(escaped_name) @@ -1555,10 +1740,10 @@ class SQLCompiler(Compiled): # in the escaped_bind_names dictionary. values = parameters.pop(name) - leep = self._literal_execute_expanding_parameter - to_update, replacement_expr = leep( + leep_res = self._literal_execute_expanding_parameter( escaped_name, parameter, values ) + (to_update, replacement_expr) = leep_res to_update_sets[escaped_name] = to_update replacement_expressions[escaped_name] = replacement_expr @@ -1583,7 +1768,14 @@ class SQLCompiler(Compiled): for key, _ in to_update if name in single_processors ) - if positiontup is not None: + if numeric_positiontup is not None: + numeric_positiontup.extend( + name for name, _ in to_update + ) + elif positiontup is not None: + # to_update has escaped names, but that's ok since + # these are new names, that aren't in the + # escaped_bind_names dict. positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ expand_key for expand_key, _ in to_update @@ -1607,11 +1799,23 @@ class SQLCompiler(Compiled): return expr statement = re.sub( - r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", - process_expanding, - self.string, + self._post_compile_pattern, process_expanding, self.string ) + if numeric_positiontup is not None: + assert positiontup is not None + param_pos = { + key: f"{self._numeric_binds_identifier_char}{num}" + for num, key in enumerate( + numeric_positiontup, self.next_numeric_pos + ) + } + # Can't use format here since % chars are not escaped. + statement = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], statement + ) + positiontup.extend(numeric_positiontup) + expanded_state = ExpandedState( statement, parameters, @@ -2109,13 +2313,7 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -2411,7 +2609,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -2625,6 +2822,11 @@ class SQLCompiler(Compiled): dialect = self.dialect typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect) + if self._numeric_binds: + bind_template = self.compilation_bindtemplate + else: + bind_template = self.bindtemplate + if ( self.dialect._bind_typing_render_casts and typ_dialect_impl.render_bind_cast @@ -2634,13 +2836,13 @@ class SQLCompiler(Compiled): return self.render_bind_cast( parameter.type, typ_dialect_impl, - self.bindtemplate % {"name": name}, + bind_template % {"name": name}, ) else: def _render_bindtemplate(name): - return self.bindtemplate % {"name": name} + return bind_template % {"name": name} if not values: to_update = [] @@ -3224,7 +3426,6 @@ class SQLCompiler(Compiled): def bindparam_string( self, name: str, - positional_names: Optional[List[str]] = None, post_compile: bool = False, expanding: bool = False, escaped_from: Optional[str] = None, @@ -3232,12 +3433,9 @@ class SQLCompiler(Compiled): **kw: Any, ) -> str: - if self.positional: - if positional_names is not None: - positional_names.append(name) - else: - self.positiontup.append(name) # type: ignore[union-attr] - self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if self._visited_bindparam is not None: + self._visited_bindparam.append(name) + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -3271,6 +3469,8 @@ class SQLCompiler(Compiled): if type_impl.render_literal_cast: ret = self.render_bind_cast(bindparam_type, type_impl, ret) return ret + elif self.state is CompilerState.COMPILING: + ret = self.compilation_bindtemplate % {"name": name} else: ret = self.bindtemplate % {"name": name} @@ -3349,8 +3549,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) - if self.positional: - self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3414,8 +3612,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) - if self.positional: - self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3455,9 +3651,6 @@ class SQLCompiler(Compiled): ) ) - if self.positional: - kwargs["positional_names"] = self.cte_positional[cte] = [] - assert kwargs.get("subquery", False) is False if not self.stack: @@ -4152,13 +4345,7 @@ class SQLCompiler(Compiled): # In compound query, CTEs are shared at the compound level if self.ctes and (not is_embedded_select or toplevel): nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kwargs.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4332,7 +4519,6 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, - visiting_cte=None, ): """ include_following_stack @@ -4367,46 +4553,6 @@ class SQLCompiler(Compiled): return "" ctes_recursive = any([cte.recursive for cte in ctes]) - if self.positional: - self.cte_order[visiting_cte].extend(ctes) - - if visiting_cte is None and self.cte_order: - assert self.positiontup is not None - - def get_nested_positional(cte): - if cte in self.cte_order: - children = self.cte_order.pop(cte) - to_add = list( - itertools.chain.from_iterable( - get_nested_positional(child_cte) - for child_cte in children - ) - ) - if cte in self.cte_positional: - return reorder_positional( - self.cte_positional[cte], - to_add, - self.cte_level[children[0]], - ) - else: - return to_add - else: - return self.cte_positional.get(cte, []) - - def reorder_positional(pos, to_add, level): - if not level: - return to_add + pos - index = 0 - for index, name in enumerate(reversed(pos)): - if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 - break - return pos[:-index] + to_add + pos[-index:] - - to_add = get_nested_positional(None) - self.positiontup = reorder_positional( - self.positiontup, to_add, nesting_level - ) - cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4762,6 +4908,11 @@ class SQLCompiler(Compiled): keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " + if self._numeric_binds: + escaped = re.escape(self._numeric_binds_identifier_char) + executemany_values_w_comma = re.sub( + rf"{escaped}\d+", "%s", executemany_values_w_comma + ) while batches: batch = batches[0:batch_size] @@ -4794,25 +4945,37 @@ class SQLCompiler(Compiled): num_ins_params = imv.num_positional_params_counted + batch_iterator: Iterable[Tuple[Any, ...]] if num_ins_params == len(batch[0]): extra_params = () - batch_iterator: Iterable[Tuple[Any, ...]] = batch - elif self.returning_precedes_values: + batch_iterator = batch + elif self.returning_precedes_values or self._numeric_binds: extra_params = batch[0][:-num_ins_params] batch_iterator = (b[-num_ins_params:] for b in batch) else: extra_params = batch[0][num_ins_params:] batch_iterator = (b[:num_ins_params] for b in batch) + values_string = (executemany_values_w_comma * len(batch))[:-2] + if self._numeric_binds and num_ins_params > 0: + # need to format here, since statement may contain + # unescaped %, while values_string contains just (%s, %s) + start = len(extra_params) + 1 + end = num_ins_params * len(batch) + start + positions = tuple( + f"{self._numeric_binds_identifier_char}{i}" + for i in range(start, end) + ) + values_string = values_string % positions + replaced_statement = statement.replace( - "__EXECMANY_TOKEN__", - (executemany_values_w_comma * len(batch))[:-2], + "__EXECMANY_TOKEN__", values_string ) replaced_parameters = tuple( itertools.chain.from_iterable(batch_iterator) ) - if self.returning_precedes_values: + if self.returning_precedes_values or self._numeric_binds: replaced_parameters = extra_params + replaced_parameters else: replaced_parameters = replaced_parameters + extra_params @@ -4869,23 +5032,30 @@ class SQLCompiler(Compiled): } ) - positiontup_before = positiontup_after = 0 + counted_bindparam = 0 # for positional, insertmanyvalues needs to know how many # bound parameters are in the VALUES sequence; there's no simple # rule because default expressions etc. can have zero or more # params inside them. After multiple attempts to figure this out, - # this very simplistic "count before, then count after" works and is + # this very simplistic "count after" works and is # likely the least amount of callcounts, though looks clumsy - if self.positiontup: - positiontup_before = len(self.positiontup) + if self.positional: + self._visited_bindparam = [] crud_params_struct = crud._get_crud_params( self, insert_stmt, compile_state, toplevel, **kw ) - if self.positiontup: - positiontup_after = len(self.positiontup) + if self.positional: + assert self._visited_bindparam is not None + counted_bindparam = len(self._visited_bindparam) + if self._numeric_binds: + if self._values_bindparam is not None: + self._values_bindparam += self._visited_bindparam + else: + self._values_bindparam = self._visited_bindparam + self._visited_bindparam = None crud_params_single = crud_params_struct.single_params @@ -4940,31 +5110,13 @@ class SQLCompiler(Compiled): if self.implicit_returning or insert_stmt._returning: - # if returning clause is rendered first, capture bound parameters - # while visiting and place them prior to the VALUES oriented - # bound parameters, when using positional parameter scheme - rpv = self.returning_precedes_values - flip_pt = rpv and self.positional - if flip_pt: - pt: Optional[List[str]] = self.positiontup - temp_pt: Optional[List[str]] - self.positiontup = temp_pt = [] - else: - temp_pt = pt = None - returning_clause = self.returning_clause( insert_stmt, self.implicit_returning or insert_stmt._returning, populate_result_map=toplevel, ) - if flip_pt: - if TYPE_CHECKING: - assert temp_pt is not None - assert pt is not None - self.positiontup = temp_pt + pt - - if rpv: + if self.returning_precedes_values: text += " " + returning_clause else: @@ -4982,7 +5134,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4999,7 +5150,7 @@ class SQLCompiler(Compiled): cast( "List[crud._CrudParamElementStr]", crud_params_single ), - (positiontup_after - positiontup_before), + counted_bindparam, ) elif compile_state._has_multi_parameters: text += " VALUES %s" % ( @@ -5033,7 +5184,7 @@ class SQLCompiler(Compiled): "List[crud._CrudParamElementStr]", crud_params_single, ), - positiontup_after - positiontup_before, + counted_bindparam, ) if insert_stmt._post_values_clause is not None: @@ -5052,7 +5203,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5201,13 +5351,7 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -5321,13 +5465,7 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 017ff7baa..ae1b032ae 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -85,8 +85,8 @@ _CrudParamElement = Tuple[ ] _CrudParamElementStr = Tuple[ "KeyedColumnElement[Any]", - str, - str, + str, # column name + str, # placeholder Iterable[str], ] _CrudParamElementSQLExpr = Tuple[ |