diff options
author | Eric Masseran <eric.masseran@gmail.com> | 2021-09-13 13:45:57 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-09-17 11:55:48 -0400 |
commit | a3884f36f691df81fb5a1c795fe7ecc0c83507b6 (patch) | |
tree | 4a512ccdf713a87cc1e4a0b2a79208d712daa118 /lib/sqlalchemy/sql/compiler.py | |
parent | f85dd7b9f1ca4ba30f58d939b2ae003feaa34c8f (diff) | |
download | sqlalchemy-a3884f36f691df81fb5a1c795fe7ecc0c83507b6.tar.gz |
Implement nesting CTE
Added new parameter :meth:`_sql.HasCte.cte.nesting` to the
:class:`_sql.CTE` constructor and :meth:`_sql.HasCTE.cte` method, which
flags the CTE as one which should remain nested within an enclosing CTE,
rather than being moved to the top level of the outermost SELECT. While in
the vast majority of cases there is no difference in SQL functionality,
users have identified various edge-cases where true nesting of CTE
constructs is desirable. Much thanks to Eric Masseran for lots of work on
this intricate feature.
Fixes: #4123
Closes: #6709
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/6709
Pull-request-sha: 64ab2f6ea269f2dcf37376a13ea38c48c5226fb6
Change-Id: Ic4dc25ab763af96d96632369e01527d48a654149
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 129 |
1 files changed, 105 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5153f54d1..333ed36f4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -841,7 +841,9 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT self.ctes = util.OrderedDict() + # Detect same CTE references self.ctes_by_name = {} + self.level_by_ctes = {} self.ctes_recursive = False if self.positional: self.cte_positional = {} @@ -1830,8 +1832,14 @@ class SQLCompiler(Compiled): if cs._has_row_limiting_clause: text += self._row_limit_clause(cs, **kwargs) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) self.stack.pop(-1) return text @@ -2507,17 +2515,24 @@ class SQLCompiler(Compiled): ): self._init_cte_state() + cte_level = len(self.stack) if cte.nesting else 1 + kwargs["visiting_cte"] = cte - if isinstance(cte.name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte.name) - else: - cte_name = cte.name + + cte_name = cte.name + + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) is_new_cte = True embedded_in_current_named_cte = False - if cte_name in self.ctes_by_name: - existing_cte = self.ctes_by_name[cte_name] + if cte in self.level_by_ctes: + cte_level = self.level_by_ctes[cte] + + cte_level_name = (cte_level, cte_name) + if cte_level_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, @@ -2529,6 +2544,7 @@ class SQLCompiler(Compiled): # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] + del self.level_by_ctes[existing_cte] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " @@ -2548,7 +2564,7 @@ class SQLCompiler(Compiled): cte_pre_alias_name = None if is_new_cte: - self.ctes_by_name[cte_name] = cte + self.ctes_by_name[cte_level_name] = cte if ( "autocommit" in cte.element._execution_options @@ -2633,6 +2649,7 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text + self.level_by_ctes[cte] = cte_level if asfrom: if from_linter: @@ -3084,6 +3101,7 @@ class SQLCompiler(Compiled): self, select_stmt, asfrom=False, + insert_into=False, fromhints=None, compound_index=None, select_wraps_for=None, @@ -3112,6 +3130,8 @@ class SQLCompiler(Compiled): if toplevel and not self.compile_state: self.compile_state = compile_state + is_embedded_select = compound_index is not None or insert_into + # translate step for Oracle, SQL Server which often need to # restructure the SELECT to allow for LIMIT/OFFSET and possibly # other conditions @@ -3260,8 +3280,13 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + # In compound query, CTEs are shared at the compound level + if not is_embedded_select: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause(nesting_level=nesting_level) + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3433,14 +3458,55 @@ class SQLCompiler(Compiled): clause += " " return clause - def _render_cte_clause(self): + def _render_cte_clause( + self, + nesting_level=None, + include_following_stack=False, + ): + """ + include_following_stack + Also render the nesting CTEs on the next stack. Useful for + SQL structures like UNION or INSERT that can wrap SELECT + statements containing nesting CTEs. + """ + if not self.ctes: + return "" + + if nesting_level and nesting_level > 1: + ctes = util.OrderedDict() + for cte in list(self.ctes.keys()): + cte_level = self.level_by_ctes[cte] + is_rendered_level = cte_level == nesting_level or ( + include_following_stack and cte_level == nesting_level + 1 + ) + if not (cte.nesting and is_rendered_level): + continue + + ctes[cte] = self.ctes[cte] + + del self.ctes[cte] + del self.level_by_ctes[cte] + + cte_name = cte.name + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) + + del self.ctes_by_name[(cte_level, cte_name)] + else: + ctes = self.ctes + + if not ctes: + return "" + + ctes_recursive = any([cte.recursive for cte in ctes]) + if self.positional: self.positiontup = ( - sum([self.cte_positional[cte] for cte in self.ctes], []) + sum([self.cte_positional[cte] for cte in ctes], []) + self.positiontup ) - cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join([txt for txt in self.ctes.values()]) + cte_text = self.get_cte_preamble(ctes_recursive) + " " + cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " return cte_text @@ -3689,11 +3755,18 @@ class SQLCompiler(Compiled): if insert_stmt.select is not None: # placed here by crud.py select_text = self.process( - self.stack[-1]["insert_from_select"], **kw + self.stack[-1]["insert_from_select"], insert_into=True, **kw ) - if self.ctes and toplevel and self.dialect.cte_follows_insert: - text += " %s%s" % (self._render_cte_clause(), select_text) + if self.ctes and self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text += " %s%s" % ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ), + select_text, + ) else: text += " %s" % select_text elif not crud_params and supports_default_values: @@ -3731,8 +3804,14 @@ class SQLCompiler(Compiled): if returning_clause and not self.returning_precedes_values: text += " " + returning_clause - if self.ctes and toplevel and not self.dialect.cte_follows_insert: - text = self._render_cte_clause() + text + if self.ctes and not self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) self.stack.pop(-1) @@ -3865,8 +3944,9 @@ class SQLCompiler(Compiled): update_stmt, self.returning or update_stmt._returning ) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -3968,8 +4048,9 @@ class SQLCompiler(Compiled): delete_stmt, delete_stmt._returning ) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) |