From 1607b74f8527905ecdc6133b4b4166a9ed675e09 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 3 Mar 2012 13:00:44 -0500 Subject: - [feature] Added cte() method to Query, invokes common table expression support from the Core (see below). [ticket:1859] - [feature] Added support for SQL standard common table expressions (CTE), allowing SELECT objects as the CTE source (DML not yet supported). This is invoked via the cte() method on any select() construct. [ticket:1859] --- lib/sqlalchemy/sql/compiler.py | 58 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) (limited to 'lib/sqlalchemy/sql/compiler.py') diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b955c5608..e8f86634d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + # true if the paramstyle is positional self.positional = dialect.positional if self.positional: @@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, **kwargs): + if isinstance(cte.name, sql._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + if cte.cte_alias: + if isinstance(cte.cte_alias, sql._truncated_label): + cte_alias = self._truncated_identifier("alias", cte.cte_alias) + else: + cte_alias = cte.cte_alias + if not cte.cte_alias and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, sql.Select): + col_source = cte.original + elif isinstance(cte.original, sql.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c.key for c in util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join(recur_cols)) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.ctes[cte] = text + if asfrom: + if cte.cte_alias: + text = self.preparer.format_alias(cte, cte_alias) + text += " AS " + cte_name + else: + return self.preparer.format_alias(cte, cte_name) + return text + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: @@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled): if select.for_update: text += self.for_update_clause(select) + if self.ctes and \ + compound_index==1 and not entry: + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + text = cte_text + text + self.stack.pop(-1) if asfrom and parens: @@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled): else: return text + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list. -- cgit v1.2.1