diff options
author | Federico Caselli <cfederico87@gmail.com> | 2022-11-19 20:39:10 +0100 |
---|---|---|
committer | Federico Caselli <cfederico87@gmail.com> | 2022-12-01 23:50:30 +0100 |
commit | 0f2baae6bf72353f785bad394684f2d6fa53e0ef (patch) | |
tree | 4d7c2cd6e8a73106aa4f95105968cf6e3fded813 /lib/sqlalchemy/sql/compiler.py | |
parent | c440c920aecd6593974e5a0d37cdb9069e5d3e57 (diff) | |
download | sqlalchemy-0f2baae6bf72353f785bad394684f2d6fa53e0ef.tar.gz |
Fix positional compiling bugs
Fixed a series of issues regarding positionally rendered bound parameters,
such as those used for SQLite, asyncpg, MySQL and others. Some compiled
forms would not maintain the order of parameters correctly, such as the
PostgreSQL ``regexp_replace()`` function as well as within the "nesting"
feature of the :class:`.CTE` construct first introduced in :ticket:`4123`.
Fixes: #8827
Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 115 |
1 files changed, 91 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 50cf9b477..7ac279ee2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -236,8 +236,8 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) OPERATORS = { # binary @@ -973,6 +973,7 @@ class SQLCompiler(Compiled): debugging use cases. """ + positiontup_level: Optional[Dict[str, int]] = None inline: bool = False @@ -988,6 +989,8 @@ class SQLCompiler(Compiled): ctes_recursive: bool cte_positional: Dict[CTE, List[str]] + cte_level: Dict[CTE, int] + cte_order: Dict[Optional[CTE], List[CTE]] def __init__( self, @@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: + self.positiontup_level = {} self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] @@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled): self.ctes_recursive = False if self.positional: self.cte_positional = {} + self.cte_level = {} + self.cte_order = collections.defaultdict(list) return ctes @@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled): ) def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs @@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled): range_ = None return "%s OVER (%s)" % ( - over.element._compiler_dispatch(self, **kwargs), + text, " ".join( [ "%s BY %s" @@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) # type: ignore[union-attr] - elif not escaped_from: + self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): # not quite the translate use case as we want to @@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes: - # In compound query, CTEs are shared at the compound level - if not is_embedded_select: - nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause(nesting_level=nesting_level) + text + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kwargs.get("visiting_cte"), ) + + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, + visiting_cte=None, ): """ include_following_stack @@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled): if not ctes: return "" - ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: - assert self.positiontup is not None - self.positiontup = ( - list( - itertools.chain.from_iterable( - self.cte_positional[cte] for cte in ctes - ) + self.cte_order[visiting_cte].extend(ctes) + + if visiting_cte is None and self.cte_order: + assert self.positiontup is not None + + def get_nested_positional(cte): + if cte in self.cte_order: + children = self.cte_order.pop(cte) + to_add = list( + itertools.chain.from_iterable( + get_nested_positional(child_cte) + for child_cte in children + ) + ) + if cte in self.cte_positional: + return reorder_positional( + self.cte_positional[cte], + to_add, + self.cte_level[children[0]], + ) + else: + return to_add + else: + return self.cte_positional.get(cte, []) + + def reorder_positional(pos, to_add, level): + if not level: + return to_add + pos + index = 0 + for index, name in enumerate(reversed(pos)): + if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 + break + return pos[:-index] + to_add + pos[-index:] + + to_add = get_nested_positional(None) + self.positiontup = reorder_positional( + self.positiontup, to_add, nesting_level ) - + self.positiontup - ) cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) @@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) |