diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-03-03 13:00:44 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-03-03 13:00:44 -0500 |
commit | 1607b74f8527905ecdc6133b4b4166a9ed675e09 (patch) | |
tree | cd752b16ab90c4864a071689c57f3ff946f8b241 /lib/sqlalchemy/sql/compiler.py | |
parent | 4d43079e34a66c3718127266bc5eaa3041c69447 (diff) | |
download | sqlalchemy-1607b74f8527905ecdc6133b4b4166a9ed675e09.tar.gz |
- [feature] Added cte() method to Query,
invokes common table expression support
from the Core (see below). [ticket:1859]
- [feature] Added support for SQL standard
common table expressions (CTE), allowing
SELECT objects as the CTE source (DML
not yet supported). This is invoked via
the cte() method on any select() construct.
[ticket:1859]
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b955c5608..e8f86634d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + # true if the paramstyle is positional self.positional = dialect.positional if self.positional: @@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled): else: return self.bindtemplate % {'name':name} + def visit_cte(self, cte, asfrom=False, ashint=False, + fromhints=None, **kwargs): + if isinstance(cte.name, sql._truncated_label): + cte_name = self._truncated_identifier("alias", cte.name) + else: + cte_name = cte.name + if cte.cte_alias: + if isinstance(cte.cte_alias, sql._truncated_label): + cte_alias = self._truncated_identifier("alias", cte.cte_alias) + else: + cte_alias = cte.cte_alias + if not cte.cte_alias and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.original, sql.Select): + col_source = cte.original + elif isinstance(cte.original, sql.CompoundSelect): + col_source = cte.original.selects[0] + else: + assert False + recur_cols = [c.key for c in util.unique_list(col_source.inner_columns) + if c is not None] + + text += "(%s)" % (", ".join(recur_cols)) + text += " AS \n" + \ + cte.original._compiler_dispatch( + self, asfrom=True, **kwargs + ) + self.ctes[cte] = text + if asfrom: + if cte.cte_alias: + text = self.preparer.format_alias(cte, cte_alias) + text += " AS " + cte_name + else: + return self.preparer.format_alias(cte, cte_name) + return text + def visit_alias(self, alias, asfrom=False, ashint=False, fromhints=None, **kwargs): if asfrom or ashint: @@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled): if select.for_update: text += self.for_update_clause(select) + if self.ctes and \ + compound_index==1 and not entry: + cte_text = self.get_cte_preamble(self.ctes_recursive) + " " + cte_text += ", \n".join( + [txt for txt in self.ctes.values()] + ) + cte_text += "\n " + text = cte_text + text + self.stack.pop(-1) if asfrom and parens: @@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled): else: return text + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + def get_select_precolumns(self, select): """Called when building a ``SELECT`` statement, position is just before column list. |