summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b955c5608..e8f86634d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled):
# column targeting
self.result_map = {}
+ # collect CTEs to tack on top of a SELECT
+ self.ctes = util.OrderedDict()
+ self.ctes_recursive = False
+
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
@@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled):
else:
return self.bindtemplate % {'name':name}
+ def visit_cte(self, cte, asfrom=False, ashint=False,
+ fromhints=None, **kwargs):
+ if isinstance(cte.name, sql._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte.name)
+ else:
+ cte_name = cte.name
+ if cte.cte_alias:
+ if isinstance(cte.cte_alias, sql._truncated_label):
+ cte_alias = self._truncated_identifier("alias", cte.cte_alias)
+ else:
+ cte_alias = cte.cte_alias
+ if not cte.cte_alias and cte not in self.ctes:
+ if cte.recursive:
+ self.ctes_recursive = True
+ text = self.preparer.format_alias(cte, cte_name)
+ if cte.recursive:
+ if isinstance(cte.original, sql.Select):
+ col_source = cte.original
+ elif isinstance(cte.original, sql.CompoundSelect):
+ col_source = cte.original.selects[0]
+ else:
+ assert False
+ recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
+ if c is not None]
+
+ text += "(%s)" % (", ".join(recur_cols))
+ text += " AS \n" + \
+ cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
+ self.ctes[cte] = text
+ if asfrom:
+ if cte.cte_alias:
+ text = self.preparer.format_alias(cte, cte_alias)
+ text += " AS " + cte_name
+ else:
+ return self.preparer.format_alias(cte, cte_name)
+ return text
+
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
@@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled):
if select.for_update:
text += self.for_update_clause(select)
+ if self.ctes and \
+ compound_index==1 and not entry:
+ cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+ cte_text += ", \n".join(
+ [txt for txt in self.ctes.values()]
+ )
+ cte_text += "\n "
+ text = cte_text + text
+
self.stack.pop(-1)
if asfrom and parens:
@@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled):
else:
return text
+ def get_cte_preamble(self, recursive):
+ if recursive:
+ return "WITH RECURSIVE"
+ else:
+ return "WITH"
+
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list.