summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py470
1 files changed, 304 insertions, 166 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 7ac279ee2..d7358ad3b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -227,11 +227,13 @@ FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
+_pyformat_template = "%%(%(name)s)s"
BIND_TEMPLATES = {
- "pyformat": "%%(%(name)s)s",
+ "pyformat": _pyformat_template,
"qmark": "?",
"format": "%%s",
"numeric": ":[_POSITION]",
+ "numeric_dollar": "$[_POSITION]",
"named": ":%(name)s",
}
@@ -420,6 +422,22 @@ class _InsertManyValues(NamedTuple):
num_positional_params_counted: int
+class CompilerState(IntEnum):
+ COMPILING = 0
+ """statement is present, compilation phase in progress"""
+
+ STRING_APPLIED = 1
+ """statement is present, string form of the statement has been applied.
+
+ Additional processors by subclasses may still be pending.
+
+ """
+
+ NO_STATEMENT = 2
+ """compiler does not have a statement to compile, is used
+ for method access"""
+
+
class Linting(IntEnum):
"""represent preferences for the 'SQL linting' feature.
@@ -527,6 +545,14 @@ class Compiled:
defaults.
"""
+ statement: Optional[ClauseElement] = None
+ "The statement to compile."
+ string: str = ""
+ "The string representation of the ``statement``"
+
+ state: CompilerState
+ """description of the compiler's state"""
+
is_sql = False
is_ddl = False
@@ -618,7 +644,6 @@ class Compiled:
"""
-
self.dialect = dialect
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
@@ -628,6 +653,7 @@ class Compiled:
)
if statement is not None:
+ self.state = CompilerState.COMPILING
self.statement = statement
self.can_execute = statement.supports_execution
self._annotations = statement._annotations
@@ -641,6 +667,11 @@ class Compiled:
self.string = self.preparer._render_schema_translates(
self.string, schema_translate_map
)
+
+ self.state = CompilerState.STRING_APPLIED
+ else:
+ self.state = CompilerState.NO_STATEMENT
+
self._gen_time = perf_counter()
def _execute_on_connection(
@@ -672,7 +703,10 @@ class Compiled:
def __str__(self) -> str:
"""Return the string text of the generated SQL or DDL."""
- return self.string or ""
+ if self.state is CompilerState.STRING_APPLIED:
+ return self.string
+ else:
+ return ""
def construct_params(
self,
@@ -859,6 +893,19 @@ class SQLCompiler(Compiled):
driver/DB enforces this
"""
+ bindtemplate: str
+ """template to render bound parameters based on paramstyle."""
+
+ compilation_bindtemplate: str
+ """template used by compiler to render parameters before positional
+ paramstyle application"""
+
+ _numeric_binds_identifier_char: str
+ """Character that's used to as the identifier of a numerical bind param.
+ For example if this char is set to ``$``, numerical binds will be rendered
+ in the form ``$1, $2, $3``.
+ """
+
_result_columns: List[ResultColumnsEntry]
"""relates label names in the final SQL to a tuple of local
column/label name, ColumnElement object (if any) and
@@ -967,13 +1014,17 @@ class SQLCompiler(Compiled):
and is combined with the :attr:`_sql.Compiled.params` dictionary to
render parameters.
+ This sequence always contains the unescaped name of the parameters.
+
.. seealso::
:ref:`faq_sql_expression_string` - includes a usage example for
debugging use cases.
"""
- positiontup_level: Optional[Dict[str, int]] = None
+ _values_bindparam: Optional[List[str]] = None
+
+ _visited_bindparam: Optional[List[str]] = None
inline: bool = False
@@ -988,9 +1039,12 @@ class SQLCompiler(Compiled):
level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
ctes_recursive: bool
- cte_positional: Dict[CTE, List[str]]
- cte_level: Dict[CTE, int]
- cte_order: Dict[Optional[CTE], List[CTE]]
+
+ _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
+ _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
+ _positional_pattern = re.compile(
+ f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
+ )
def __init__(
self,
@@ -1055,10 +1109,15 @@ class SQLCompiler(Compiled):
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
- self.positiontup_level = {}
- self.positiontup = []
- self._numeric_binds = dialect.paramstyle == "numeric"
- self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+ self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
+ if nb:
+ self._numeric_binds_identifier_char = (
+ "$" if dialect.paramstyle == "numeric_dollar" else ":"
+ )
+
+ self.compilation_bindtemplate = _pyformat_template
+ else:
+ self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.ctes = None
@@ -1095,11 +1154,17 @@ class SQLCompiler(Compiled):
):
self.inline = True
- if self.positional and self._numeric_binds:
- self._apply_numbered_params()
+ self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+
+ if self.state is CompilerState.STRING_APPLIED:
+ if self.positional:
+ if self._numeric_binds:
+ self._process_numeric()
+ else:
+ self._process_positional()
- if self._render_postcompile:
- self._process_parameters_for_postcompile(_populate_self=True)
+ if self._render_postcompile:
+ self._process_parameters_for_postcompile(_populate_self=True)
@property
def insert_single_values_expr(self) -> Optional[str]:
@@ -1135,7 +1200,7 @@ class SQLCompiler(Compiled):
"""
if self.implicit_returning:
return self.implicit_returning
- elif is_dml(self.statement):
+ elif self.statement is not None and is_dml(self.statement):
return [
c
for c in self.statement._all_selected_columns
@@ -1217,10 +1282,6 @@ class SQLCompiler(Compiled):
self.level_name_by_cte = {}
self.ctes_recursive = False
- if self.positional:
- self.cte_positional = {}
- self.cte_level = {}
- self.cte_order = collections.defaultdict(list)
return ctes
@@ -1248,12 +1309,145 @@ class SQLCompiler(Compiled):
ordered_columns,
)
- def _apply_numbered_params(self):
- poscount = itertools.count(1)
+ def _process_positional(self):
+ assert not self.positiontup
+ assert self.state is CompilerState.STRING_APPLIED
+ assert not self._numeric_binds
+
+ if self.dialect.paramstyle == "format":
+ placeholder = "%s"
+ else:
+ assert self.dialect.paramstyle == "qmark"
+ placeholder = "?"
+
+ positions = []
+
+ def find_position(m: re.Match[str]) -> str:
+ normal_bind = m.group(1)
+ if normal_bind:
+ positions.append(normal_bind)
+ return placeholder
+ else:
+ # this a post-compile bind
+ positions.append(m.group(2))
+ return m.group(0)
+
self.string = re.sub(
- r"\[_POSITION\]", lambda m: str(next(poscount)), self.string
+ self._positional_pattern, find_position, self.string
)
+ if self.escaped_bind_names:
+ reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
+ assert len(self.escaped_bind_names) == len(reverse_escape)
+ self.positiontup = [
+ reverse_escape.get(name, name) for name in positions
+ ]
+ else:
+ self.positiontup = positions
+
+ if self._insertmanyvalues:
+ positions = []
+ single_values_expr = re.sub(
+ self._positional_pattern,
+ find_position,
+ self._insertmanyvalues.single_values_expr,
+ )
+ insert_crud_params = [
+ (
+ v[0],
+ v[1],
+ re.sub(self._positional_pattern, find_position, v[2]),
+ v[3],
+ )
+ for v in self._insertmanyvalues.insert_crud_params
+ ]
+
+ self._insertmanyvalues = _InsertManyValues(
+ is_default_expr=self._insertmanyvalues.is_default_expr,
+ single_values_expr=single_values_expr,
+ insert_crud_params=insert_crud_params,
+ num_positional_params_counted=(
+ self._insertmanyvalues.num_positional_params_counted
+ ),
+ )
+
+ def _process_numeric(self):
+ assert self._numeric_binds
+ assert self.state is CompilerState.STRING_APPLIED
+
+ num = 1
+ param_pos: Dict[str, str] = {}
+ order: Iterable[str]
+ if self._insertmanyvalues and self._values_bindparam is not None:
+ # bindparams that are not in values are always placed first.
+ # this avoids the need of changing them when using executemany
+ # values () ()
+ order = itertools.chain(
+ (
+ name
+ for name in self.bind_names.values()
+ if name not in self._values_bindparam
+ ),
+ self.bind_names.values(),
+ )
+ else:
+ order = self.bind_names.values()
+
+ for bind_name in order:
+ if bind_name in param_pos:
+ continue
+ bind = self.binds[bind_name]
+ if (
+ bind in self.post_compile_params
+ or bind in self.literal_execute_params
+ ):
+ # set to None to just mark the in positiontup, it will not
+ # be replaced below.
+ param_pos[bind_name] = None # type: ignore
+ else:
+ ph = f"{self._numeric_binds_identifier_char}{num}"
+ num += 1
+ param_pos[bind_name] = ph
+
+ self.next_numeric_pos = num
+
+ self.positiontup = list(param_pos)
+ if self.escaped_bind_names:
+ reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
+ assert len(self.escaped_bind_names) == len(reverse_escape)
+ param_pos = {
+ self.escaped_bind_names.get(name, name): pos
+ for name, pos in param_pos.items()
+ }
+
+ # Can't use format here since % chars are not escaped.
+ self.string = self._pyformat_pattern.sub(
+ lambda m: param_pos[m.group(1)], self.string
+ )
+
+ if self._insertmanyvalues:
+ single_values_expr = (
+ # format is ok here since single_values_expr includes only
+ # place-holders
+ self._insertmanyvalues.single_values_expr
+ % param_pos
+ )
+ insert_crud_params = [
+ (v[0], v[1], "%s", v[3])
+ for v in self._insertmanyvalues.insert_crud_params
+ ]
+
+ self._insertmanyvalues = _InsertManyValues(
+ is_default_expr=self._insertmanyvalues.is_default_expr,
+ # This has the numbers (:1, :2)
+ single_values_expr=single_values_expr,
+ # The single binds are instead %s so they can be formatted
+ insert_crud_params=insert_crud_params,
+ num_positional_params_counted=(
+ self._insertmanyvalues.num_positional_params_counted
+ ),
+ )
+
@util.memoized_property
def _bind_processors(
self,
@@ -1492,39 +1686,30 @@ class SQLCompiler(Compiled):
new_processors: Dict[str, _BindProcessorType[Any]] = {}
- if self.positional and self._numeric_binds:
- # I'm not familiar with any DBAPI that uses 'numeric'.
- # strategy would likely be to make use of numbers greater than
- # the highest number present; then for expanding parameters,
- # append them to the end of the parameter list. that way
- # we avoid having to renumber all the existing parameters.
- raise NotImplementedError(
- "'post-compile' bind parameters are not supported with "
- "the 'numeric' paramstyle at this time."
- )
-
replacement_expressions: Dict[str, Any] = {}
to_update_sets: Dict[str, Any] = {}
# notes:
# *unescaped* parameter names in:
- # self.bind_names, self.binds, self._bind_processors
+ # self.bind_names, self.binds, self._bind_processors, self.positiontup
#
# *escaped* parameter names in:
# construct_params(), replacement_expressions
+ numeric_positiontup: Optional[List[str]] = None
+
if self.positional and self.positiontup is not None:
names: Iterable[str] = self.positiontup
+ if self._numeric_binds:
+ numeric_positiontup = []
else:
names = self.bind_names.values()
+ ebn = self.escaped_bind_names
for name in names:
- escaped_name = (
- self.escaped_bind_names.get(name, name)
- if self.escaped_bind_names
- else name
- )
+ escaped_name = ebn.get(name, name) if ebn else name
parameter = self.binds[name]
+
if parameter in self.literal_execute_params:
if escaped_name not in replacement_expressions:
value = parameters.pop(escaped_name)
@@ -1555,10 +1740,10 @@ class SQLCompiler(Compiled):
# in the escaped_bind_names dictionary.
values = parameters.pop(name)
- leep = self._literal_execute_expanding_parameter
- to_update, replacement_expr = leep(
+ leep_res = self._literal_execute_expanding_parameter(
escaped_name, parameter, values
)
+ (to_update, replacement_expr) = leep_res
to_update_sets[escaped_name] = to_update
replacement_expressions[escaped_name] = replacement_expr
@@ -1583,7 +1768,14 @@ class SQLCompiler(Compiled):
for key, _ in to_update
if name in single_processors
)
- if positiontup is not None:
+ if numeric_positiontup is not None:
+ numeric_positiontup.extend(
+ name for name, _ in to_update
+ )
+ elif positiontup is not None:
+ # to_update has escaped names, but that's ok since
+ # these are new names, that aren't in the
+ # escaped_bind_names dict.
positiontup.extend(name for name, _ in to_update)
expanded_parameters[name] = [
expand_key for expand_key, _ in to_update
@@ -1607,11 +1799,23 @@ class SQLCompiler(Compiled):
return expr
statement = re.sub(
- r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
- process_expanding,
- self.string,
+ self._post_compile_pattern, process_expanding, self.string
)
+ if numeric_positiontup is not None:
+ assert positiontup is not None
+ param_pos = {
+ key: f"{self._numeric_binds_identifier_char}{num}"
+ for num, key in enumerate(
+ numeric_positiontup, self.next_numeric_pos
+ )
+ }
+ # Can't use format here since % chars are not escaped.
+ statement = self._pyformat_pattern.sub(
+ lambda m: param_pos[m.group(1)], statement
+ )
+ positiontup.extend(numeric_positiontup)
+
expanded_state = ExpandedState(
statement,
parameters,
@@ -2109,13 +2313,7 @@ class SQLCompiler(Compiled):
text = self.process(taf.element, **kw)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kw.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)
@@ -2411,7 +2609,6 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
- visiting_cte=kwargs.get("visiting_cte"),
)
+ text
)
@@ -2625,6 +2822,11 @@ class SQLCompiler(Compiled):
dialect = self.dialect
typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
+ if self._numeric_binds:
+ bind_template = self.compilation_bindtemplate
+ else:
+ bind_template = self.bindtemplate
+
if (
self.dialect._bind_typing_render_casts
and typ_dialect_impl.render_bind_cast
@@ -2634,13 +2836,13 @@ class SQLCompiler(Compiled):
return self.render_bind_cast(
parameter.type,
typ_dialect_impl,
- self.bindtemplate % {"name": name},
+ bind_template % {"name": name},
)
else:
def _render_bindtemplate(name):
- return self.bindtemplate % {"name": name}
+ return bind_template % {"name": name}
if not values:
to_update = []
@@ -3224,7 +3426,6 @@ class SQLCompiler(Compiled):
def bindparam_string(
self,
name: str,
- positional_names: Optional[List[str]] = None,
post_compile: bool = False,
expanding: bool = False,
escaped_from: Optional[str] = None,
@@ -3232,12 +3433,9 @@ class SQLCompiler(Compiled):
**kw: Any,
) -> str:
- if self.positional:
- if positional_names is not None:
- positional_names.append(name)
- else:
- self.positiontup.append(name) # type: ignore[union-attr]
- self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501
+ if self._visited_bindparam is not None:
+ self._visited_bindparam.append(name)
+
if not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
@@ -3271,6 +3469,8 @@ class SQLCompiler(Compiled):
if type_impl.render_literal_cast:
ret = self.render_bind_cast(bindparam_type, type_impl, ret)
return ret
+ elif self.state is CompilerState.COMPILING:
+ ret = self.compilation_bindtemplate % {"name": name}
else:
ret = self.bindtemplate % {"name": name}
@@ -3349,8 +3549,6 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = new_level_name + (
cte_opts,
)
- if self.positional:
- self.cte_level[cte] = cte_level
else:
cte_level = len(self.stack) if nesting else 1
@@ -3414,8 +3612,6 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = cte_level_name + (
cte_opts,
)
- if self.positional:
- self.cte_level[cte] = cte_level
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -3455,9 +3651,6 @@ class SQLCompiler(Compiled):
)
)
- if self.positional:
- kwargs["positional_names"] = self.cte_positional[cte] = []
-
assert kwargs.get("subquery", False) is False
if not self.stack:
@@ -4152,13 +4345,7 @@ class SQLCompiler(Compiled):
# In compound query, CTEs are shared at the compound level
if self.ctes and (not is_embedded_select or toplevel):
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kwargs.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
@@ -4332,7 +4519,6 @@ class SQLCompiler(Compiled):
self,
nesting_level=None,
include_following_stack=False,
- visiting_cte=None,
):
"""
include_following_stack
@@ -4367,46 +4553,6 @@ class SQLCompiler(Compiled):
return ""
ctes_recursive = any([cte.recursive for cte in ctes])
- if self.positional:
- self.cte_order[visiting_cte].extend(ctes)
-
- if visiting_cte is None and self.cte_order:
- assert self.positiontup is not None
-
- def get_nested_positional(cte):
- if cte in self.cte_order:
- children = self.cte_order.pop(cte)
- to_add = list(
- itertools.chain.from_iterable(
- get_nested_positional(child_cte)
- for child_cte in children
- )
- )
- if cte in self.cte_positional:
- return reorder_positional(
- self.cte_positional[cte],
- to_add,
- self.cte_level[children[0]],
- )
- else:
- return to_add
- else:
- return self.cte_positional.get(cte, [])
-
- def reorder_positional(pos, to_add, level):
- if not level:
- return to_add + pos
- index = 0
- for index, name in enumerate(reversed(pos)):
- if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501
- break
- return pos[:-index] + to_add + pos[-index:]
-
- to_add = get_nested_positional(None)
- self.positiontup = reorder_positional(
- self.positiontup, to_add, nesting_level
- )
-
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
@@ -4762,6 +4908,11 @@ class SQLCompiler(Compiled):
keys_to_replace = set()
base_parameters = {}
executemany_values_w_comma = f"({imv.single_values_expr}), "
+ if self._numeric_binds:
+ escaped = re.escape(self._numeric_binds_identifier_char)
+ executemany_values_w_comma = re.sub(
+ rf"{escaped}\d+", "%s", executemany_values_w_comma
+ )
while batches:
batch = batches[0:batch_size]
@@ -4794,25 +4945,37 @@ class SQLCompiler(Compiled):
num_ins_params = imv.num_positional_params_counted
+ batch_iterator: Iterable[Tuple[Any, ...]]
if num_ins_params == len(batch[0]):
extra_params = ()
- batch_iterator: Iterable[Tuple[Any, ...]] = batch
- elif self.returning_precedes_values:
+ batch_iterator = batch
+ elif self.returning_precedes_values or self._numeric_binds:
extra_params = batch[0][:-num_ins_params]
batch_iterator = (b[-num_ins_params:] for b in batch)
else:
extra_params = batch[0][num_ins_params:]
batch_iterator = (b[:num_ins_params] for b in batch)
+ values_string = (executemany_values_w_comma * len(batch))[:-2]
+ if self._numeric_binds and num_ins_params > 0:
+ # need to format here, since statement may contain
+ # unescaped %, while values_string contains just (%s, %s)
+ start = len(extra_params) + 1
+ end = num_ins_params * len(batch) + start
+ positions = tuple(
+ f"{self._numeric_binds_identifier_char}{i}"
+ for i in range(start, end)
+ )
+ values_string = values_string % positions
+
replaced_statement = statement.replace(
- "__EXECMANY_TOKEN__",
- (executemany_values_w_comma * len(batch))[:-2],
+ "__EXECMANY_TOKEN__", values_string
)
replaced_parameters = tuple(
itertools.chain.from_iterable(batch_iterator)
)
- if self.returning_precedes_values:
+ if self.returning_precedes_values or self._numeric_binds:
replaced_parameters = extra_params + replaced_parameters
else:
replaced_parameters = replaced_parameters + extra_params
@@ -4869,23 +5032,30 @@ class SQLCompiler(Compiled):
}
)
- positiontup_before = positiontup_after = 0
+ counted_bindparam = 0
# for positional, insertmanyvalues needs to know how many
# bound parameters are in the VALUES sequence; there's no simple
# rule because default expressions etc. can have zero or more
# params inside them. After multiple attempts to figure this out,
- # this very simplistic "count before, then count after" works and is
+ # this very simplistic "count after" works and is
# likely the least amount of callcounts, though looks clumsy
- if self.positiontup:
- positiontup_before = len(self.positiontup)
+ if self.positional:
+ self._visited_bindparam = []
crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, toplevel, **kw
)
- if self.positiontup:
- positiontup_after = len(self.positiontup)
+ if self.positional:
+ assert self._visited_bindparam is not None
+ counted_bindparam = len(self._visited_bindparam)
+ if self._numeric_binds:
+ if self._values_bindparam is not None:
+ self._values_bindparam += self._visited_bindparam
+ else:
+ self._values_bindparam = self._visited_bindparam
+ self._visited_bindparam = None
crud_params_single = crud_params_struct.single_params
@@ -4940,31 +5110,13 @@ class SQLCompiler(Compiled):
if self.implicit_returning or insert_stmt._returning:
- # if returning clause is rendered first, capture bound parameters
- # while visiting and place them prior to the VALUES oriented
- # bound parameters, when using positional parameter scheme
- rpv = self.returning_precedes_values
- flip_pt = rpv and self.positional
- if flip_pt:
- pt: Optional[List[str]] = self.positiontup
- temp_pt: Optional[List[str]]
- self.positiontup = temp_pt = []
- else:
- temp_pt = pt = None
-
returning_clause = self.returning_clause(
insert_stmt,
self.implicit_returning or insert_stmt._returning,
populate_result_map=toplevel,
)
- if flip_pt:
- if TYPE_CHECKING:
- assert temp_pt is not None
- assert pt is not None
- self.positiontup = temp_pt + pt
-
- if rpv:
+ if self.returning_precedes_values:
text += " " + returning_clause
else:
@@ -4982,7 +5134,6 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
- visiting_cte=kw.get("visiting_cte"),
),
select_text,
)
@@ -4999,7 +5150,7 @@ class SQLCompiler(Compiled):
cast(
"List[crud._CrudParamElementStr]", crud_params_single
),
- (positiontup_after - positiontup_before),
+ counted_bindparam,
)
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
@@ -5033,7 +5184,7 @@ class SQLCompiler(Compiled):
"List[crud._CrudParamElementStr]",
crud_params_single,
),
- positiontup_after - positiontup_before,
+ counted_bindparam,
)
if insert_stmt._post_values_clause is not None:
@@ -5052,7 +5203,6 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
- visiting_cte=kw.get("visiting_cte"),
)
+ text
)
@@ -5201,13 +5351,7 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kw.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)
@@ -5321,13 +5465,7 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(
- nesting_level=nesting_level,
- visiting_cte=kw.get("visiting_cte"),
- )
- + text
- )
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)