summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2022-11-19 20:39:10 +0100
committerFederico Caselli <cfederico87@gmail.com>2022-12-01 23:50:30 +0100
commit0f2baae6bf72353f785bad394684f2d6fa53e0ef (patch)
tree4d7c2cd6e8a73106aa4f95105968cf6e3fded813 /lib/sqlalchemy/sql/compiler.py
parentc440c920aecd6593974e5a0d37cdb9069e5d3e57 (diff)
downloadsqlalchemy-0f2baae6bf72353f785bad394684f2d6fa53e0ef.tar.gz
Fix positional compiling bugs
Fixed a series of issues regarding positionally rendered bound parameters, such as those used for SQLite, asyncpg, MySQL and others. Some compiled forms would not maintain the order of parameters correctly, such as the PostgreSQL ``regexp_replace()`` function as well as within the "nesting" feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. Fixes: #8827 Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py115
1 files changed, 91 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 50cf9b477..7ac279ee2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -236,8 +236,8 @@ BIND_TEMPLATES = {
}
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
OPERATORS = {
# binary
@@ -973,6 +973,7 @@ class SQLCompiler(Compiled):
debugging use cases.
"""
+ positiontup_level: Optional[Dict[str, int]] = None
inline: bool = False
@@ -988,6 +989,8 @@ class SQLCompiler(Compiled):
ctes_recursive: bool
cte_positional: Dict[CTE, List[str]]
+ cte_level: Dict[CTE, int]
+ cte_order: Dict[Optional[CTE], List[CTE]]
def __init__(
self,
@@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled):
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
+ self.positiontup_level = {}
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
@@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled):
self.ctes_recursive = False
if self.positional:
self.cte_positional = {}
+ self.cte_level = {}
+ self.cte_order = collections.defaultdict(list)
return ctes
@@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled):
text = self.process(taf.element, **kw)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled):
)
def visit_over(self, over, **kwargs):
+ text = over.element._compiler_dispatch(self, **kwargs)
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
over.range_, **kwargs
@@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled):
range_ = None
return "%s OVER (%s)" % (
- over.element._compiler_dispatch(self, **kwargs),
+ text,
" ".join(
[
"%s BY %s"
@@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ text
)
@@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled):
positional_names.append(name)
else:
self.positiontup.append(name) # type: ignore[union-attr]
- elif not escaped_from:
+ self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501
+ if not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
# not quite the translate use case as we want to
@@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = new_level_name + (
cte_opts,
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
else:
cte_level = len(self.stack) if nesting else 1
@@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = cte_level_name + (
cte_opts,
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes:
- # In compound query, CTEs are shared at the compound level
- if not is_embedded_select:
- nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(nesting_level=nesting_level) + text
+ # In compound query, CTEs are shared at the compound level
+ if self.ctes and (not is_embedded_select or toplevel):
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ + text
+ )
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
@@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled):
self,
nesting_level=None,
include_following_stack=False,
+ visiting_cte=None,
):
"""
include_following_stack
@@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled):
if not ctes:
return ""
-
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
- assert self.positiontup is not None
- self.positiontup = (
- list(
- itertools.chain.from_iterable(
- self.cte_positional[cte] for cte in ctes
- )
+ self.cte_order[visiting_cte].extend(ctes)
+
+ if visiting_cte is None and self.cte_order:
+ assert self.positiontup is not None
+
+ def get_nested_positional(cte):
+ if cte in self.cte_order:
+ children = self.cte_order.pop(cte)
+ to_add = list(
+ itertools.chain.from_iterable(
+ get_nested_positional(child_cte)
+ for child_cte in children
+ )
+ )
+ if cte in self.cte_positional:
+ return reorder_positional(
+ self.cte_positional[cte],
+ to_add,
+ self.cte_level[children[0]],
+ )
+ else:
+ return to_add
+ else:
+ return self.cte_positional.get(cte, [])
+
+ def reorder_positional(pos, to_add, level):
+ if not level:
+ return to_add + pos
+ index = 0
+ for index, name in enumerate(reversed(pos)):
+ if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501
+ break
+ return pos[:-index] + to_add + pos[-index:]
+
+ to_add = get_nested_positional(None)
+ self.positiontup = reorder_positional(
+ self.positiontup, to_add, nesting_level
)
- + self.positiontup
- )
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
@@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
),
select_text,
)
@@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
)
+ text
)
@@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)