summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py129
1 files changed, 105 insertions, 24 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5153f54d1..333ed36f4 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -841,7 +841,9 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
self.ctes = util.OrderedDict()
+ # Detect same CTE references
self.ctes_by_name = {}
+ self.level_by_ctes = {}
self.ctes_recursive = False
if self.positional:
self.cte_positional = {}
@@ -1830,8 +1832,14 @@ class SQLCompiler(Compiled):
if cs._has_row_limiting_clause:
text += self._row_limit_clause(cs, **kwargs)
- if self.ctes and toplevel:
- text = self._render_cte_clause() + text
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
self.stack.pop(-1)
return text
@@ -2507,17 +2515,24 @@ class SQLCompiler(Compiled):
):
self._init_cte_state()
+ cte_level = len(self.stack) if cte.nesting else 1
+
kwargs["visiting_cte"] = cte
- if isinstance(cte.name, elements._truncated_label):
- cte_name = self._truncated_identifier("alias", cte.name)
- else:
- cte_name = cte.name
+
+ cte_name = cte.name
+
+ if isinstance(cte_name, elements._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte_name)
is_new_cte = True
embedded_in_current_named_cte = False
- if cte_name in self.ctes_by_name:
- existing_cte = self.ctes_by_name[cte_name]
+ if cte in self.level_by_ctes:
+ cte_level = self.level_by_ctes[cte]
+
+ cte_level_name = (cte_level, cte_name)
+ if cte_level_name in self.ctes_by_name:
+ existing_cte = self.ctes_by_name[cte_level_name]
embedded_in_current_named_cte = visiting_cte is existing_cte
# we've generated a same-named CTE that we are enclosed in,
@@ -2529,6 +2544,7 @@ class SQLCompiler(Compiled):
# enclosed in us - we take precedence, so
# discard the text for the "inner".
del self.ctes[existing_cte]
+ del self.level_by_ctes[existing_cte]
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
@@ -2548,7 +2564,7 @@ class SQLCompiler(Compiled):
cte_pre_alias_name = None
if is_new_cte:
- self.ctes_by_name[cte_name] = cte
+ self.ctes_by_name[cte_level_name] = cte
if (
"autocommit" in cte.element._execution_options
@@ -2633,6 +2649,7 @@ class SQLCompiler(Compiled):
)
self.ctes[cte] = text
+ self.level_by_ctes[cte] = cte_level
if asfrom:
if from_linter:
@@ -3084,6 +3101,7 @@ class SQLCompiler(Compiled):
self,
select_stmt,
asfrom=False,
+ insert_into=False,
fromhints=None,
compound_index=None,
select_wraps_for=None,
@@ -3112,6 +3130,8 @@ class SQLCompiler(Compiled):
if toplevel and not self.compile_state:
self.compile_state = compile_state
+ is_embedded_select = compound_index is not None or insert_into
+
# translate step for Oracle, SQL Server which often need to
# restructure the SELECT to allow for LIMIT/OFFSET and possibly
# other conditions
@@ -3260,8 +3280,13 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes and toplevel:
- text = self._render_cte_clause() + text
+ if self.ctes:
+ # In compound query, CTEs are shared at the compound level
+ if not is_embedded_select:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(nesting_level=nesting_level) + text
+ )
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
@@ -3433,14 +3458,55 @@ class SQLCompiler(Compiled):
clause += " "
return clause
- def _render_cte_clause(self):
+ def _render_cte_clause(
+ self,
+ nesting_level=None,
+ include_following_stack=False,
+ ):
+ """
+ include_following_stack
+ Also render the nesting CTEs on the next stack. Useful for
+ SQL structures like UNION or INSERT that can wrap SELECT
+ statements containing nesting CTEs.
+ """
+ if not self.ctes:
+ return ""
+
+ if nesting_level and nesting_level > 1:
+ ctes = util.OrderedDict()
+ for cte in list(self.ctes.keys()):
+ cte_level = self.level_by_ctes[cte]
+ is_rendered_level = cte_level == nesting_level or (
+ include_following_stack and cte_level == nesting_level + 1
+ )
+ if not (cte.nesting and is_rendered_level):
+ continue
+
+ ctes[cte] = self.ctes[cte]
+
+ del self.ctes[cte]
+ del self.level_by_ctes[cte]
+
+ cte_name = cte.name
+ if isinstance(cte_name, elements._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte_name)
+
+ del self.ctes_by_name[(cte_level, cte_name)]
+ else:
+ ctes = self.ctes
+
+ if not ctes:
+ return ""
+
+ ctes_recursive = any([cte.recursive for cte in ctes])
+
if self.positional:
self.positiontup = (
- sum([self.cte_positional[cte] for cte in self.ctes], [])
+ sum([self.cte_positional[cte] for cte in ctes], [])
+ self.positiontup
)
- cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
- cte_text += ", \n".join([txt for txt in self.ctes.values()])
+ cte_text = self.get_cte_preamble(ctes_recursive) + " "
+ cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
return cte_text
@@ -3689,11 +3755,18 @@ class SQLCompiler(Compiled):
if insert_stmt.select is not None:
# placed here by crud.py
select_text = self.process(
- self.stack[-1]["insert_from_select"], **kw
+ self.stack[-1]["insert_from_select"], insert_into=True, **kw
)
- if self.ctes and toplevel and self.dialect.cte_follows_insert:
- text += " %s%s" % (self._render_cte_clause(), select_text)
+ if self.ctes and self.dialect.cte_follows_insert:
+ nesting_level = len(self.stack) if not toplevel else None
+ text += " %s%s" % (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ ),
+ select_text,
+ )
else:
text += " %s" % select_text
elif not crud_params and supports_default_values:
@@ -3731,8 +3804,14 @@ class SQLCompiler(Compiled):
if returning_clause and not self.returning_precedes_values:
text += " " + returning_clause
- if self.ctes and toplevel and not self.dialect.cte_follows_insert:
- text = self._render_cte_clause() + text
+ if self.ctes and not self.dialect.cte_follows_insert:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
self.stack.pop(-1)
@@ -3865,8 +3944,9 @@ class SQLCompiler(Compiled):
update_stmt, self.returning or update_stmt._returning
)
- if self.ctes and toplevel:
- text = self._render_cte_clause() + text
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)
@@ -3968,8 +4048,9 @@ class SQLCompiler(Compiled):
delete_stmt, delete_stmt._returning
)
- if self.ctes and toplevel:
- text = self._render_cte_clause() + text
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
self.stack.pop(-1)