summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-06-13 18:21:42 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-06-13 18:21:42 -0400
commit1ac57f0b52e3e89097129931d46ebbbb39ee7927 (patch)
tree548f6e6982fa3501d6b5a5cf7abd1537a5a68a68
parenta29245e247698160172e42e2154180997b81b8ba (diff)
downloadsqlalchemy-1ac57f0b52e3e89097129931d46ebbbb39ee7927.tar.gz
- [bug] Repaired common table expression
rendering to function correctly when the SELECT statement contains UNION or other compound expressions, courtesy btbuilder. [ticket:2490]
-rw-r--r--CHANGES7
-rw-r--r--lib/sqlalchemy/sql/compiler.py19
-rw-r--r--test/sql/test_compiler.py44
3 files changed, 62 insertions, 8 deletions
diff --git a/CHANGES b/CHANGES
index 517f5f553..d16ce9777 100644
--- a/CHANGES
+++ b/CHANGES
@@ -315,6 +315,13 @@ CHANGES
also in 0.7.7, [ticket:2499] also
in 0.7.8.
+ - [bug] Repaired common table expression
+ rendering to function correctly when the
+ SELECT statement contains UNION or other
+ compound expressions, courtesy btbuilder.
+ [ticket:2490] Also in 0.7.8.
+
+
- sqlite
- [feature] the SQLite date and time types
have been overhauled to support a more open
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 05cc70aba..fd7e7d773 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -572,6 +572,10 @@ class SQLCompiler(engine.Compiled):
text += (cs._limit is not None or cs._offset is not None) and \
self.limit_clause(cs) or ""
+ if self.ctes and \
+ compound_index==1 and not entry:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
if asfrom and parens:
return "(" + text + ")"
@@ -968,12 +972,7 @@ class SQLCompiler(engine.Compiled):
if self.ctes and \
compound_index==1 and not entry:
- cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
- cte_text += ", \n".join(
- [txt for txt in self.ctes.values()]
- )
- cte_text += "\n "
- text = cte_text + text
+ text = self._render_cte_clause() + text
self.stack.pop(-1)
@@ -982,6 +981,14 @@ class SQLCompiler(engine.Compiled):
else:
return text
+ def _render_cte_clause(self):
+ cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+ cte_text += ", \n".join(
+ [txt for txt in self.ctes.values()]
+ )
+ cte_text += "\n "
+ return cte_text
+
def get_cte_preamble(self, recursive):
if recursive:
return "WITH RECURSIVE"
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index dfdf8bd87..61dcf61ab 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -345,7 +345,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
eq_(s.positiontup, ['a', 'b', 'c'])
def test_nested_label_targeting(self):
- """test nested anonymous label generation.
+ """test nested anonymous label generation.
"""
s1 = table1.select()
@@ -1203,7 +1203,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT mytable.myid, mytable.name, mytable.description "
"FROM mytable WHERE mytable.myid = %(myid_1)s FOR SHARE",
dialect=postgresql.dialect())
-
+
self.assert_compile(
table1.select(table1.c.myid==7, for_update="read_nowait"),
"SELECT mytable.myid, mytable.name, mytable.description "
@@ -2446,6 +2446,46 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
dialect=mssql.dialect()
)
+ def test_cte_union(self):
+ orders = table('orders',
+ column('region'),
+ column('amount'),
+ )
+
+ regional_sales = select([
+ orders.c.region,
+ orders.c.amount
+ ]).cte("regional_sales")
+
+ s = select([regional_sales.c.region]).\
+ where(
+ regional_sales.c.amount > 500
+ )
+
+ self.assert_compile(s,
+ "WITH regional_sales AS "
+ "(SELECT orders.region AS region, "
+ "orders.amount AS amount FROM orders) "
+ "SELECT regional_sales.region "
+ "FROM regional_sales WHERE "
+ "regional_sales.amount > :amount_1")
+
+ s = s.union_all(
+ select([regional_sales.c.region]).\
+ where(
+ regional_sales.c.amount < 300
+ )
+ )
+ self.assert_compile(s,
+ "WITH regional_sales AS "
+ "(SELECT orders.region AS region, "
+ "orders.amount AS amount FROM orders) "
+ "SELECT regional_sales.region FROM regional_sales "
+ "WHERE regional_sales.amount > :amount_1 "
+ "UNION ALL SELECT regional_sales.region "
+ "FROM regional_sales WHERE "
+ "regional_sales.amount < :amount_2")
+
def test_date_between(self):
import datetime
table = Table('dt', metadata,