summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py115
-rw-r--r--lib/sqlalchemy/sql/selectable.py8
2 files changed, 94 insertions, 29 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 50cf9b477..7ac279ee2 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -236,8 +236,8 @@ BIND_TEMPLATES = {
}
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
OPERATORS = {
# binary
@@ -973,6 +973,7 @@ class SQLCompiler(Compiled):
debugging use cases.
"""
+ positiontup_level: Optional[Dict[str, int]] = None
inline: bool = False
@@ -988,6 +989,8 @@ class SQLCompiler(Compiled):
ctes_recursive: bool
cte_positional: Dict[CTE, List[str]]
+ cte_level: Dict[CTE, int]
+ cte_order: Dict[Optional[CTE], List[CTE]]
def __init__(
self,
@@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled):
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
+ self.positiontup_level = {}
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
@@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled):
self.ctes_recursive = False
if self.positional:
self.cte_positional = {}
+ self.cte_level = {}
+ self.cte_order = collections.defaultdict(list)
return ctes
@@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled):
text = self.process(taf.element, **kw)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled):
)
def visit_over(self, over, **kwargs):
+ text = over.element._compiler_dispatch(self, **kwargs)
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
over.range_, **kwargs
@@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled):
range_ = None
return "%s OVER (%s)" % (
- over.element._compiler_dispatch(self, **kwargs),
+ text,
" ".join(
[
"%s BY %s"
@@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ text
)
@@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled):
positional_names.append(name)
else:
self.positiontup.append(name) # type: ignore[union-attr]
- elif not escaped_from:
+ self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501
+ if not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
# not quite the translate use case as we want to
@@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = new_level_name + (
cte_opts,
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
else:
cte_level = len(self.stack) if nesting else 1
@@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled):
self.level_name_by_cte[_reference_cte] = cte_level_name + (
cte_opts,
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes:
- # In compound query, CTEs are shared at the compound level
- if not is_embedded_select:
- nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(nesting_level=nesting_level) + text
+ # In compound query, CTEs are shared at the compound level
+ if self.ctes and (not is_embedded_select or toplevel):
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ + text
+ )
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
@@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled):
self,
nesting_level=None,
include_following_stack=False,
+ visiting_cte=None,
):
"""
include_following_stack
@@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled):
if not ctes:
return ""
-
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
- assert self.positiontup is not None
- self.positiontup = (
- list(
- itertools.chain.from_iterable(
- self.cte_positional[cte] for cte in ctes
- )
+ self.cte_order[visiting_cte].extend(ctes)
+
+ if visiting_cte is None and self.cte_order:
+ assert self.positiontup is not None
+
+ def get_nested_positional(cte):
+ if cte in self.cte_order:
+ children = self.cte_order.pop(cte)
+ to_add = list(
+ itertools.chain.from_iterable(
+ get_nested_positional(child_cte)
+ for child_cte in children
+ )
+ )
+ if cte in self.cte_positional:
+ return reorder_positional(
+ self.cte_positional[cte],
+ to_add,
+ self.cte_level[children[0]],
+ )
+ else:
+ return to_add
+ else:
+ return self.cte_positional.get(cte, [])
+
+ def reorder_positional(pos, to_add, level):
+ if not level:
+ return to_add + pos
+ index = 0
+ for index, name in enumerate(reversed(pos)):
+ if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501
+ break
+ return pos[:-index] + to_add + pos[-index:]
+
+ to_add = get_nested_positional(None)
+ self.positiontup = reorder_positional(
+ self.positiontup, to_add, nesting_level
)
- + self.positiontup
- )
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
@@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled):
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
),
select_text,
)
@@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled):
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
)
+ text
)
@@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled):
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 97336d416..2dcc611fa 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -2052,9 +2052,7 @@ class CTE(
else:
self.element._generate_fromclause_column_proxies(self)
- def alias(
- self, name: Optional[str] = None, flat: bool = False
- ) -> NamedFromClause:
+ def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE:
"""Return an :class:`_expression.Alias` of this
:class:`_expression.CTE`.
@@ -2078,7 +2076,7 @@ class CTE(
_suffixes=self._suffixes,
)
- def union(self, *other):
+ def union(self, *other: _SelectStatementForCompoundArgument) -> CTE:
r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
of the original CTE against the given selectables provided
as positional arguments.
@@ -2107,7 +2105,7 @@ class CTE(
_suffixes=self._suffixes,
)
- def union_all(self, *other):
+ def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE:
r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
of the original CTE against the given selectables provided
as positional arguments.