summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-03-03 13:00:44 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2012-03-03 13:00:44 -0500
commit1607b74f8527905ecdc6133b4b4166a9ed675e09 (patch)
treecd752b16ab90c4864a071689c57f3ff946f8b241 /lib/sqlalchemy/sql/compiler.py
parent4d43079e34a66c3718127266bc5eaa3041c69447 (diff)
downloadsqlalchemy-1607b74f8527905ecdc6133b4b4166a9ed675e09.tar.gz
- [feature] Added cte() method to Query,
invokes common table expression support from the Core (see below). [ticket:1859] - [feature] Added support for SQL standard common table expressions (CTE), allowing SELECT objects as the CTE source (DML not yet supported). This is invoked via the cte() method on any select() construct. [ticket:1859]
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b955c5608..e8f86634d 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -252,6 +252,10 @@ class SQLCompiler(engine.Compiled):
# column targeting
self.result_map = {}
+ # collect CTEs to tack on top of a SELECT
+ self.ctes = util.OrderedDict()
+ self.ctes_recursive = False
+
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
@@ -749,6 +753,45 @@ class SQLCompiler(engine.Compiled):
else:
return self.bindtemplate % {'name':name}
+ def visit_cte(self, cte, asfrom=False, ashint=False,
+ fromhints=None, **kwargs):
+ if isinstance(cte.name, sql._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte.name)
+ else:
+ cte_name = cte.name
+ if cte.cte_alias:
+ if isinstance(cte.cte_alias, sql._truncated_label):
+ cte_alias = self._truncated_identifier("alias", cte.cte_alias)
+ else:
+ cte_alias = cte.cte_alias
+ if not cte.cte_alias and cte not in self.ctes:
+ if cte.recursive:
+ self.ctes_recursive = True
+ text = self.preparer.format_alias(cte, cte_name)
+ if cte.recursive:
+ if isinstance(cte.original, sql.Select):
+ col_source = cte.original
+ elif isinstance(cte.original, sql.CompoundSelect):
+ col_source = cte.original.selects[0]
+ else:
+ assert False
+ recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
+ if c is not None]
+
+ text += "(%s)" % (", ".join(recur_cols))
+ text += " AS \n" + \
+ cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
+ self.ctes[cte] = text
+ if asfrom:
+ if cte.cte_alias:
+ text = self.preparer.format_alias(cte, cte_alias)
+ text += " AS " + cte_name
+ else:
+ return self.preparer.format_alias(cte, cte_name)
+ return text
+
def visit_alias(self, alias, asfrom=False, ashint=False,
fromhints=None, **kwargs):
if asfrom or ashint:
@@ -909,6 +952,15 @@ class SQLCompiler(engine.Compiled):
if select.for_update:
text += self.for_update_clause(select)
+ if self.ctes and \
+ compound_index==1 and not entry:
+ cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+ cte_text += ", \n".join(
+ [txt for txt in self.ctes.values()]
+ )
+ cte_text += "\n "
+ text = cte_text + text
+
self.stack.pop(-1)
if asfrom and parens:
@@ -916,6 +968,12 @@ class SQLCompiler(engine.Compiled):
else:
return text
+ def get_cte_preamble(self, recursive):
+ if recursive:
+ return "WITH RECURSIVE"
+ else:
+ return "WITH"
+
def get_select_precolumns(self, select):
"""Called when building a ``SELECT`` statement, position is just
before column list.