diff options
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 115 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 8 |
2 files changed, 94 insertions, 29 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 50cf9b477..7ac279ee2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -236,8 +236,8 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) OPERATORS = { # binary @@ -973,6 +973,7 @@ class SQLCompiler(Compiled): debugging use cases. """ + positiontup_level: Optional[Dict[str, int]] = None inline: bool = False @@ -988,6 +989,8 @@ class SQLCompiler(Compiled): ctes_recursive: bool cte_positional: Dict[CTE, List[str]] + cte_level: Dict[CTE, int] + cte_order: Dict[Optional[CTE], List[CTE]] def __init__( self, @@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: + self.positiontup_level = {} self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] @@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled): self.ctes_recursive = False if self.positional: self.cte_positional = {} + self.cte_level = {} + self.cte_order = collections.defaultdict(list) return ctes @@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled): ) def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs @@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled): range_ = None return "%s OVER (%s)" % ( - over.element._compiler_dispatch(self, **kwargs), + text, " ".join( [ "%s BY %s" @@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) # type: ignore[union-attr] - elif not escaped_from: + self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): # not quite the translate use case as we want to @@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes: - # In compound query, CTEs are shared at the compound level - if not is_embedded_select: - nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause(nesting_level=nesting_level) + text + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kwargs.get("visiting_cte"), ) + + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, + visiting_cte=None, ): """ include_following_stack @@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled): if not ctes: return "" - ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: - assert self.positiontup is not None - self.positiontup = ( - list( - itertools.chain.from_iterable( - self.cte_positional[cte] for cte in ctes - ) + self.cte_order[visiting_cte].extend(ctes) + + if visiting_cte is None and self.cte_order: + assert self.positiontup is not None + + def get_nested_positional(cte): + if cte in self.cte_order: + children = self.cte_order.pop(cte) + to_add = list( + itertools.chain.from_iterable( + get_nested_positional(child_cte) + for child_cte in children + ) + ) + if cte in self.cte_positional: + return reorder_positional( + self.cte_positional[cte], + to_add, + self.cte_level[children[0]], + ) + else: + return to_add + else: + return self.cte_positional.get(cte, []) + + def reorder_positional(pos, to_add, level): + if not level: + return to_add + pos + index = 0 + for index, name in enumerate(reversed(pos)): + if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 + break + return pos[:-index] + to_add + pos[-index:] + + to_add = get_nested_positional(None) + self.positiontup = reorder_positional( + self.positiontup, to_add, nesting_level ) - + self.positiontup - ) cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) @@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 97336d416..2dcc611fa 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2052,9 +2052,7 @@ class CTE( else: self.element._generate_fromclause_column_proxies(self) - def alias( - self, name: Optional[str] = None, flat: bool = False - ) -> NamedFromClause: + def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -2078,7 +2076,7 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, *other): + def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2107,7 +2105,7 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, *other): + def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. |