diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 470 |
1 files changed, 304 insertions, 166 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 7ac279ee2..d7358ad3b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -227,11 +227,13 @@ FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE) BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE) +_pyformat_template = "%%(%(name)s)s" BIND_TEMPLATES = { - "pyformat": "%%(%(name)s)s", + "pyformat": _pyformat_template, "qmark": "?", "format": "%%s", "numeric": ":[_POSITION]", + "numeric_dollar": "$[_POSITION]", "named": ":%(name)s", } @@ -420,6 +422,22 @@ class _InsertManyValues(NamedTuple): num_positional_params_counted: int +class CompilerState(IntEnum): + COMPILING = 0 + """statement is present, compilation phase in progress""" + + STRING_APPLIED = 1 + """statement is present, string form of the statement has been applied. + + Additional processors by subclasses may still be pending. + + """ + + NO_STATEMENT = 2 + """compiler does not have a statement to compile, is used + for method access""" + + class Linting(IntEnum): """represent preferences for the 'SQL linting' feature. @@ -527,6 +545,14 @@ class Compiled: defaults. """ + statement: Optional[ClauseElement] = None + "The statement to compile." 
+ string: str = "" + "The string representation of the ``statement``" + + state: CompilerState + """description of the compiler's state""" + is_sql = False is_ddl = False @@ -618,7 +644,6 @@ class Compiled: """ - self.dialect = dialect self.preparer = self.dialect.identifier_preparer if schema_translate_map: @@ -628,6 +653,7 @@ class Compiled: ) if statement is not None: + self.state = CompilerState.COMPILING self.statement = statement self.can_execute = statement.supports_execution self._annotations = statement._annotations @@ -641,6 +667,11 @@ class Compiled: self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) + + self.state = CompilerState.STRING_APPLIED + else: + self.state = CompilerState.NO_STATEMENT + self._gen_time = perf_counter() def _execute_on_connection( @@ -672,7 +703,10 @@ class Compiled: def __str__(self) -> str: """Return the string text of the generated SQL or DDL.""" - return self.string or "" + if self.state is CompilerState.STRING_APPLIED: + return self.string + else: + return "" def construct_params( self, @@ -859,6 +893,19 @@ class SQLCompiler(Compiled): driver/DB enforces this """ + bindtemplate: str + """template to render bound parameters based on paramstyle.""" + + compilation_bindtemplate: str + """template used by compiler to render parameters before positional + paramstyle application""" + + _numeric_binds_identifier_char: str + """Character that's used as the identifier of a numerical bind param. + For example if this char is set to ``$``, numerical binds will be rendered + in the form ``$1, $2, $3``. + """ + _result_columns: List[ResultColumnsEntry] """relates label names in the final SQL to a tuple of local column/label name, ColumnElement object (if any) and @@ -967,13 +1014,17 @@ class SQLCompiler(Compiled): and is combined with the :attr:`_sql.Compiled.params` dictionary to render parameters. + This sequence always contains the unescaped name of the parameters. + .. 
seealso:: :ref:`faq_sql_expression_string` - includes a usage example for debugging use cases. """ - positiontup_level: Optional[Dict[str, int]] = None + _values_bindparam: Optional[List[str]] = None + + _visited_bindparam: Optional[List[str]] = None inline: bool = False @@ -988,9 +1039,12 @@ class SQLCompiler(Compiled): level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] ctes_recursive: bool - cte_positional: Dict[CTE, List[str]] - cte_level: Dict[CTE, int] - cte_order: Dict[Optional[CTE], List[CTE]] + + _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]") + _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s") + _positional_pattern = re.compile( + f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" + ) def __init__( self, @@ -1055,10 +1109,15 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: - self.positiontup_level = {} - self.positiontup = [] - self._numeric_binds = dialect.paramstyle == "numeric" - self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + self._numeric_binds = nb = dialect.paramstyle.startswith("numeric") + if nb: + self._numeric_binds_identifier_char = ( + "$" if dialect.paramstyle == "numeric_dollar" else ":" + ) + + self.compilation_bindtemplate = _pyformat_template + else: + self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.ctes = None @@ -1095,11 +1154,17 @@ class SQLCompiler(Compiled): ): self.inline = True - if self.positional and self._numeric_binds: - self._apply_numbered_params() + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + if self.state is CompilerState.STRING_APPLIED: + if self.positional: + if self._numeric_binds: + self._process_numeric() + else: + self._process_positional() - if self._render_postcompile: - self._process_parameters_for_postcompile(_populate_self=True) + if self._render_postcompile: + self._process_parameters_for_postcompile(_populate_self=True) @property def 
insert_single_values_expr(self) -> Optional[str]: @@ -1135,7 +1200,7 @@ class SQLCompiler(Compiled): """ if self.implicit_returning: return self.implicit_returning - elif is_dml(self.statement): + elif self.statement is not None and is_dml(self.statement): return [ c for c in self.statement._all_selected_columns @@ -1217,10 +1282,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte = {} self.ctes_recursive = False - if self.positional: - self.cte_positional = {} - self.cte_level = {} - self.cte_order = collections.defaultdict(list) return ctes @@ -1248,12 +1309,145 @@ class SQLCompiler(Compiled): ordered_columns, ) - def _apply_numbered_params(self): - poscount = itertools.count(1) + def _process_positional(self): + assert not self.positiontup + assert self.state is CompilerState.STRING_APPLIED + assert not self._numeric_binds + + if self.dialect.paramstyle == "format": + placeholder = "%s" + else: + assert self.dialect.paramstyle == "qmark" + placeholder = "?" + + positions = [] + + def find_position(m: re.Match[str]) -> str: + normal_bind = m.group(1) + if normal_bind: + positions.append(normal_bind) + return placeholder + else: + # this is a post-compile bind + positions.append(m.group(2)) + return m.group(0) + self.string = re.sub( - r"\[_POSITION\]", lambda m: str(next(poscount)), self.string + self._positional_pattern, find_position, self.string ) + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + self.positiontup = [ + reverse_escape.get(name, name) for name in positions + ] + else: + self.positiontup = positions + + if self._insertmanyvalues: + positions = [] + single_values_expr = re.sub( + self._positional_pattern, + find_position, + self._insertmanyvalues.single_values_expr, + ) + insert_crud_params = [ + ( + v[0], + v[1], + re.sub(self._positional_pattern, find_position, v[2]), + v[3], + ) + for v in self._insertmanyvalues.insert_crud_params 
+ ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + single_values_expr=single_values_expr, + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + + def _process_numeric(self): + assert self._numeric_binds + assert self.state is CompilerState.STRING_APPLIED + + num = 1 + param_pos: Dict[str, str] = {} + order: Iterable[str] + if self._insertmanyvalues and self._values_bindparam is not None: + # bindparams that are not in values are always placed first. + # this avoids the need of changing them when using executemany + # values () () + order = itertools.chain( + ( + name + for name in self.bind_names.values() + if name not in self._values_bindparam + ), + self.bind_names.values(), + ) + else: + order = self.bind_names.values() + + for bind_name in order: + if bind_name in param_pos: + continue + bind = self.binds[bind_name] + if ( + bind in self.post_compile_params + or bind in self.literal_execute_params + ): + # set to None to just mark it in positiontup, it will not + # be replaced below. + param_pos[bind_name] = None # type: ignore + else: + ph = f"{self._numeric_binds_identifier_char}{num}" + num += 1 + param_pos[bind_name] = ph + + self.next_numeric_pos = num + + self.positiontup = list(param_pos) + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + param_pos = { + self.escaped_bind_names.get(name, name): pos + for name, pos in param_pos.items() + } + + # Can't use format here since % chars are not escaped. 
+ self.string = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], self.string + ) + + if self._insertmanyvalues: + single_values_expr = ( + # format is ok here since single_values_expr includes only + # place-holders + self._insertmanyvalues.single_values_expr + % param_pos + ) + insert_crud_params = [ + (v[0], v[1], "%s", v[3]) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=self._insertmanyvalues.is_default_expr, + # This has the numbers (:1, :2) + single_values_expr=single_values_expr, + # The single binds are instead %s so they can be formatted + insert_crud_params=insert_crud_params, + num_positional_params_counted=( + self._insertmanyvalues.num_positional_params_counted + ), + ) + @util.memoized_property def _bind_processors( self, @@ -1492,39 +1686,30 @@ class SQLCompiler(Compiled): new_processors: Dict[str, _BindProcessorType[Any]] = {} - if self.positional and self._numeric_binds: - # I'm not familiar with any DBAPI that uses 'numeric'. - # strategy would likely be to make use of numbers greater than - # the highest number present; then for expanding parameters, - # append them to the end of the parameter list. that way - # we avoid having to renumber all the existing parameters. - raise NotImplementedError( - "'post-compile' bind parameters are not supported with " - "the 'numeric' paramstyle at this time." 
- ) - replacement_expressions: Dict[str, Any] = {} to_update_sets: Dict[str, Any] = {} # notes: # *unescaped* parameter names in: - # self.bind_names, self.binds, self._bind_processors + # self.bind_names, self.binds, self._bind_processors, self.positiontup # # *escaped* parameter names in: # construct_params(), replacement_expressions + numeric_positiontup: Optional[List[str]] = None + if self.positional and self.positiontup is not None: names: Iterable[str] = self.positiontup + if self._numeric_binds: + numeric_positiontup = [] else: names = self.bind_names.values() + ebn = self.escaped_bind_names for name in names: - escaped_name = ( - self.escaped_bind_names.get(name, name) - if self.escaped_bind_names - else name - ) + escaped_name = ebn.get(name, name) if ebn else name parameter = self.binds[name] + if parameter in self.literal_execute_params: if escaped_name not in replacement_expressions: value = parameters.pop(escaped_name) @@ -1555,10 +1740,10 @@ class SQLCompiler(Compiled): # in the escaped_bind_names dictionary. values = parameters.pop(name) - leep = self._literal_execute_expanding_parameter - to_update, replacement_expr = leep( + leep_res = self._literal_execute_expanding_parameter( escaped_name, parameter, values ) + (to_update, replacement_expr) = leep_res to_update_sets[escaped_name] = to_update replacement_expressions[escaped_name] = replacement_expr @@ -1583,7 +1768,14 @@ class SQLCompiler(Compiled): for key, _ in to_update if name in single_processors ) - if positiontup is not None: + if numeric_positiontup is not None: + numeric_positiontup.extend( + name for name, _ in to_update + ) + elif positiontup is not None: + # to_update has escaped names, but that's ok since + # these are new names, that aren't in the + # escaped_bind_names dict. 
positiontup.extend(name for name, _ in to_update) expanded_parameters[name] = [ expand_key for expand_key, _ in to_update @@ -1607,11 +1799,23 @@ class SQLCompiler(Compiled): return expr statement = re.sub( - r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", - process_expanding, - self.string, + self._post_compile_pattern, process_expanding, self.string ) + if numeric_positiontup is not None: + assert positiontup is not None + param_pos = { + key: f"{self._numeric_binds_identifier_char}{num}" + for num, key in enumerate( + numeric_positiontup, self.next_numeric_pos + ) + } + # Can't use format here since % chars are not escaped. + statement = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], statement + ) + positiontup.extend(numeric_positiontup) + expanded_state = ExpandedState( statement, parameters, @@ -2109,13 +2313,7 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -2411,7 +2609,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -2625,6 +2822,11 @@ class SQLCompiler(Compiled): dialect = self.dialect typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect) + if self._numeric_binds: + bind_template = self.compilation_bindtemplate + else: + bind_template = self.bindtemplate + if ( self.dialect._bind_typing_render_casts and typ_dialect_impl.render_bind_cast @@ -2634,13 +2836,13 @@ class SQLCompiler(Compiled): return self.render_bind_cast( parameter.type, typ_dialect_impl, - self.bindtemplate % {"name": name}, + bind_template % {"name": name}, ) else: def _render_bindtemplate(name): - return self.bindtemplate % {"name": name} + 
return bind_template % {"name": name} if not values: to_update = [] @@ -3224,7 +3426,6 @@ class SQLCompiler(Compiled): def bindparam_string( self, name: str, - positional_names: Optional[List[str]] = None, post_compile: bool = False, expanding: bool = False, escaped_from: Optional[str] = None, @@ -3232,12 +3433,9 @@ class SQLCompiler(Compiled): **kw: Any, ) -> str: - if self.positional: - if positional_names is not None: - positional_names.append(name) - else: - self.positiontup.append(name) # type: ignore[union-attr] - self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if self._visited_bindparam is not None: + self._visited_bindparam.append(name) + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): @@ -3271,6 +3469,8 @@ class SQLCompiler(Compiled): if type_impl.render_literal_cast: ret = self.render_bind_cast(bindparam_type, type_impl, ret) return ret + elif self.state is CompilerState.COMPILING: + ret = self.compilation_bindtemplate % {"name": name} else: ret = self.bindtemplate % {"name": name} @@ -3349,8 +3549,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) - if self.positional: - self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3414,8 +3612,6 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) - if self.positional: - self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3455,9 +3651,6 @@ class SQLCompiler(Compiled): ) ) - if self.positional: - kwargs["positional_names"] = self.cte_positional[cte] = [] - assert kwargs.get("subquery", False) is False if not self.stack: @@ -4152,13 +4345,7 @@ class SQLCompiler(Compiled): # In compound query, CTEs are shared at the compound level if self.ctes and (not is_embedded_select or toplevel): nesting_level = len(self.stack) if not toplevel else None - text = ( - 
self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kwargs.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4332,7 +4519,6 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, - visiting_cte=None, ): """ include_following_stack @@ -4367,46 +4553,6 @@ class SQLCompiler(Compiled): return "" ctes_recursive = any([cte.recursive for cte in ctes]) - if self.positional: - self.cte_order[visiting_cte].extend(ctes) - - if visiting_cte is None and self.cte_order: - assert self.positiontup is not None - - def get_nested_positional(cte): - if cte in self.cte_order: - children = self.cte_order.pop(cte) - to_add = list( - itertools.chain.from_iterable( - get_nested_positional(child_cte) - for child_cte in children - ) - ) - if cte in self.cte_positional: - return reorder_positional( - self.cte_positional[cte], - to_add, - self.cte_level[children[0]], - ) - else: - return to_add - else: - return self.cte_positional.get(cte, []) - - def reorder_positional(pos, to_add, level): - if not level: - return to_add + pos - index = 0 - for index, name in enumerate(reversed(pos)): - if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 - break - return pos[:-index] + to_add + pos[-index:] - - to_add = get_nested_positional(None) - self.positiontup = reorder_positional( - self.positiontup, to_add, nesting_level - ) - cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -4762,6 +4908,11 @@ class SQLCompiler(Compiled): keys_to_replace = set() base_parameters = {} executemany_values_w_comma = f"({imv.single_values_expr}), " + if self._numeric_binds: + escaped = re.escape(self._numeric_binds_identifier_char) + executemany_values_w_comma = re.sub( + rf"{escaped}\d+", "%s", executemany_values_w_comma + ) while batches: 
batch = batches[0:batch_size] @@ -4794,25 +4945,37 @@ class SQLCompiler(Compiled): num_ins_params = imv.num_positional_params_counted + batch_iterator: Iterable[Tuple[Any, ...]] if num_ins_params == len(batch[0]): extra_params = () - batch_iterator: Iterable[Tuple[Any, ...]] = batch - elif self.returning_precedes_values: + batch_iterator = batch + elif self.returning_precedes_values or self._numeric_binds: extra_params = batch[0][:-num_ins_params] batch_iterator = (b[-num_ins_params:] for b in batch) else: extra_params = batch[0][num_ins_params:] batch_iterator = (b[:num_ins_params] for b in batch) + values_string = (executemany_values_w_comma * len(batch))[:-2] + if self._numeric_binds and num_ins_params > 0: + # need to format here, since statement may contain + # unescaped %, while values_string contains just (%s, %s) + start = len(extra_params) + 1 + end = num_ins_params * len(batch) + start + positions = tuple( + f"{self._numeric_binds_identifier_char}{i}" + for i in range(start, end) + ) + values_string = values_string % positions + replaced_statement = statement.replace( - "__EXECMANY_TOKEN__", - (executemany_values_w_comma * len(batch))[:-2], + "__EXECMANY_TOKEN__", values_string ) replaced_parameters = tuple( itertools.chain.from_iterable(batch_iterator) ) - if self.returning_precedes_values: + if self.returning_precedes_values or self._numeric_binds: replaced_parameters = extra_params + replaced_parameters else: replaced_parameters = replaced_parameters + extra_params @@ -4869,23 +5032,30 @@ class SQLCompiler(Compiled): } ) - positiontup_before = positiontup_after = 0 + counted_bindparam = 0 # for positional, insertmanyvalues needs to know how many # bound parameters are in the VALUES sequence; there's no simple # rule because default expressions etc. can have zero or more # params inside them. 
After multiple attempts to figure this out, - # this very simplistic "count before, then count after" works and is + # this very simplistic "count after" works and is # likely the least amount of callcounts, though looks clumsy - if self.positiontup: - positiontup_before = len(self.positiontup) + if self.positional: + self._visited_bindparam = [] crud_params_struct = crud._get_crud_params( self, insert_stmt, compile_state, toplevel, **kw ) - if self.positiontup: - positiontup_after = len(self.positiontup) + if self.positional: + assert self._visited_bindparam is not None + counted_bindparam = len(self._visited_bindparam) + if self._numeric_binds: + if self._values_bindparam is not None: + self._values_bindparam += self._visited_bindparam + else: + self._values_bindparam = self._visited_bindparam + self._visited_bindparam = None crud_params_single = crud_params_struct.single_params @@ -4940,31 +5110,13 @@ class SQLCompiler(Compiled): if self.implicit_returning or insert_stmt._returning: - # if returning clause is rendered first, capture bound parameters - # while visiting and place them prior to the VALUES oriented - # bound parameters, when using positional parameter scheme - rpv = self.returning_precedes_values - flip_pt = rpv and self.positional - if flip_pt: - pt: Optional[List[str]] = self.positiontup - temp_pt: Optional[List[str]] - self.positiontup = temp_pt = [] - else: - temp_pt = pt = None - returning_clause = self.returning_clause( insert_stmt, self.implicit_returning or insert_stmt._returning, populate_result_map=toplevel, ) - if flip_pt: - if TYPE_CHECKING: - assert temp_pt is not None - assert pt is not None - self.positiontup = temp_pt + pt - - if rpv: + if self.returning_precedes_values: text += " " + returning_clause else: @@ -4982,7 +5134,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4999,7 +5150,7 @@ class 
SQLCompiler(Compiled): cast( "List[crud._CrudParamElementStr]", crud_params_single ), - (positiontup_after - positiontup_before), + counted_bindparam, ) elif compile_state._has_multi_parameters: text += " VALUES %s" % ( @@ -5033,7 +5184,7 @@ class SQLCompiler(Compiled): "List[crud._CrudParamElementStr]", crud_params_single, ), - positiontup_after - positiontup_before, + counted_bindparam, ) if insert_stmt._post_values_clause is not None: @@ -5052,7 +5203,6 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, - visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5201,13 +5351,7 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -5321,13 +5465,7 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause( - nesting_level=nesting_level, - visiting_cte=kw.get("visiting_cte"), - ) - + text - ) + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) |