diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 129 |
1 files changed, 105 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5153f54d1..333ed36f4 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -841,7 +841,9 @@ class SQLCompiler(Compiled): """ # collect CTEs to tack on top of a SELECT self.ctes = util.OrderedDict() + # Detect same CTE references self.ctes_by_name = {} + self.level_by_ctes = {} self.ctes_recursive = False if self.positional: self.cte_positional = {} @@ -1830,8 +1832,14 @@ class SQLCompiler(Compiled): if cs._has_row_limiting_clause: text += self._row_limit_clause(cs, **kwargs) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) self.stack.pop(-1) return text @@ -2507,17 +2515,24 @@ class SQLCompiler(Compiled): ): self._init_cte_state() + cte_level = len(self.stack) if cte.nesting else 1 + kwargs["visiting_cte"] = cte - if isinstance(cte.name, elements._truncated_label): - cte_name = self._truncated_identifier("alias", cte.name) - else: - cte_name = cte.name + + cte_name = cte.name + + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) is_new_cte = True embedded_in_current_named_cte = False - if cte_name in self.ctes_by_name: - existing_cte = self.ctes_by_name[cte_name] + if cte in self.level_by_ctes: + cte_level = self.level_by_ctes[cte] + + cte_level_name = (cte_level, cte_name) + if cte_level_name in self.ctes_by_name: + existing_cte = self.ctes_by_name[cte_level_name] embedded_in_current_named_cte = visiting_cte is existing_cte # we've generated a same-named CTE that we are enclosed in, @@ -2529,6 +2544,7 @@ class SQLCompiler(Compiled): # enclosed in us - we take precedence, so # discard the text for the "inner". del self.ctes[existing_cte] + del self.level_by_ctes[existing_cte] else: raise exc.CompileError( "Multiple, unrelated CTEs found with " @@ -2548,7 +2564,7 @@ class SQLCompiler(Compiled): cte_pre_alias_name = None if is_new_cte: - self.ctes_by_name[cte_name] = cte + self.ctes_by_name[cte_level_name] = cte if ( "autocommit" in cte.element._execution_options @@ -2633,6 +2649,7 @@ class SQLCompiler(Compiled): ) self.ctes[cte] = text + self.level_by_ctes[cte] = cte_level if asfrom: if from_linter: @@ -3084,6 +3101,7 @@ class SQLCompiler(Compiled): self, select_stmt, asfrom=False, + insert_into=False, fromhints=None, compound_index=None, select_wraps_for=None, @@ -3112,6 +3130,8 @@ class SQLCompiler(Compiled): if toplevel and not self.compile_state: self.compile_state = compile_state + is_embedded_select = compound_index is not None or insert_into + # translate step for Oracle, SQL Server which often need to # restructure the SELECT to allow for LIMIT/OFFSET and possibly # other conditions @@ -3260,8 +3280,13 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + # In compound query, CTEs are shared at the compound level + if not is_embedded_select: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause(nesting_level=nesting_level) + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3433,14 +3458,55 @@ class SQLCompiler(Compiled): clause += " " return clause - def _render_cte_clause(self): + def _render_cte_clause( + self, + nesting_level=None, + include_following_stack=False, + ): + """ + include_following_stack + Also render the nesting CTEs on the next stack. Useful for + SQL structures like UNION or INSERT that can wrap SELECT + statements containing nesting CTEs. + """ + if not self.ctes: + return "" + + if nesting_level and nesting_level > 1: + ctes = util.OrderedDict() + for cte in list(self.ctes.keys()): + cte_level = self.level_by_ctes[cte] + is_rendered_level = cte_level == nesting_level or ( + include_following_stack and cte_level == nesting_level + 1 + ) + if not (cte.nesting and is_rendered_level): + continue + + ctes[cte] = self.ctes[cte] + + del self.ctes[cte] + del self.level_by_ctes[cte] + + cte_name = cte.name + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) + + del self.ctes_by_name[(cte_level, cte_name)] + else: + ctes = self.ctes + + if not ctes: + return "" + + ctes_recursive = any([cte.recursive for cte in ctes]) + if self.positional: self.positiontup = ( - sum([self.cte_positional[cte] for cte in self.ctes], []) + sum([self.cte_positional[cte] for cte in ctes], []) + self.positiontup ) - cte_text = self.get_cte_preamble(self.ctes_recursive) + " " - cte_text += ", \n".join([txt for txt in self.ctes.values()]) + cte_text = self.get_cte_preamble(ctes_recursive) + " " + cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " return cte_text @@ -3689,11 +3755,18 @@ class SQLCompiler(Compiled): if insert_stmt.select is not None: # placed here by crud.py select_text = self.process( - self.stack[-1]["insert_from_select"], **kw + self.stack[-1]["insert_from_select"], insert_into=True, **kw ) - if self.ctes and toplevel and self.dialect.cte_follows_insert: - text += " %s%s" % (self._render_cte_clause(), select_text) + if self.ctes and self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text += " %s%s" % ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ), + select_text, + ) else: text += " %s" % select_text elif not crud_params and supports_default_values: @@ -3731,8 +3804,14 @@ class SQLCompiler(Compiled): if returning_clause and not self.returning_precedes_values: text += " " + returning_clause - if self.ctes and toplevel and not self.dialect.cte_follows_insert: - text = self._render_cte_clause() + text + if self.ctes and not self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) self.stack.pop(-1) @@ -3865,8 +3944,9 @@ class SQLCompiler(Compiled): update_stmt, self.returning or update_stmt._returning ) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) @@ -3968,8 +4048,9 @@ class SQLCompiler(Compiled): delete_stmt, delete_stmt._returning ) - if self.ctes and toplevel: - text = self._render_cte_clause() + text + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text self.stack.pop(-1) |