from sqlalchemy.testing import fixtures
from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message
from sqlalchemy.sql import table, column, select, func, literal
from sqlalchemy.dialects import mssql
from sqlalchemy.engine import default
from sqlalchemy.exc import CompileError


class CTETest(fixtures.TestBase, AssertsCompiledSQL):
    """Compilation tests for common table expressions (``WITH`` clauses).

    Each test builds a statement containing one or more CTEs and checks
    the exact SQL produced by ``assert_compile``.  Statements compile
    against ``__dialect__`` ('default') unless an assertion passes
    ``dialect=`` explicitly.
    """

    __dialect__ = 'default'

    @staticmethod
    def _positional_dialect():
        """Return a fresh default dialect using the positional 'numeric'
        paramstyle, so bound parameters render as :1, :2, ...
        """
        dialect = default.DefaultDialect()
        dialect.positional = True
        dialect.paramstyle = 'numeric'
        return dialect

    def test_nonrecursive(self):
        """Two non-recursive CTEs, the second selecting from the first."""
        orders = table('orders',
                       column('region'),
                       column('amount'),
                       column('product'),
                       column('quantity')
                       )

        regional_sales = select([
            orders.c.region,
            func.sum(orders.c.amount).label('total_sales')
        ]).group_by(orders.c.region).cte("regional_sales")

        top_regions = select([regional_sales.c.region]).\
            where(
                regional_sales.c.total_sales > select([
                    func.sum(regional_sales.c.total_sales) / 10
                ])
            ).cte("top_regions")

        s = select([
            orders.c.region,
            orders.c.product,
            func.sum(orders.c.quantity).label("product_units"),
            func.sum(orders.c.amount).label("product_sales")
        ]).where(orders.c.region.in_(
            select([top_regions.c.region])
        )).group_by(orders.c.region, orders.c.product)

        # needs to render regional_sales first as top_regions
        # refers to it
        self.assert_compile(
            s,
            "WITH regional_sales AS (SELECT orders.region AS region, "
            "sum(orders.amount) AS total_sales FROM orders "
            "GROUP BY orders.region), "
            "top_regions AS (SELECT "
            "regional_sales.region AS region FROM regional_sales "
            "WHERE regional_sales.total_sales > "
            "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
            "anon_1 FROM regional_sales)) "
            "SELECT orders.region, orders.product, "
            "sum(orders.quantity) AS product_units, "
            "sum(orders.amount) AS product_sales "
            "FROM orders WHERE orders.region "
            "IN (SELECT top_regions.region FROM top_regions) "
            "GROUP BY orders.region, orders.product"
        )

    def test_recursive(self):
        """A recursive CTE unioned with a SELECT against an alias of
        itself, then joined back to its base table.  Also checks that the
        MSSQL dialect omits the RECURSIVE keyword.
        """
        parts = table('parts',
                      column('part'),
                      column('sub_part'),
                      column('quantity'),
                      )

        included_parts = select([
            parts.c.sub_part,
            parts.c.part,
            parts.c.quantity]).\
            where(parts.c.part == 'our part').\
            cte(recursive=True)

        incl_alias = included_parts.alias()
        parts_alias = parts.alias()
        included_parts = included_parts.union(
            select([
                parts_alias.c.sub_part,
                parts_alias.c.part,
                parts_alias.c.quantity]).
            where(parts_alias.c.part == incl_alias.c.sub_part)
        )

        s = select([
            included_parts.c.sub_part,
            func.sum(included_parts.c.quantity).
            label('total_quantity')]).\
            select_from(included_parts.join(
                parts, included_parts.c.part == parts.c.part)).\
            group_by(included_parts.c.sub_part)
        self.assert_compile(
            s,
            "WITH RECURSIVE anon_1(sub_part, part, quantity) "
            "AS (SELECT parts.sub_part AS sub_part, parts.part "
            "AS part, parts.quantity AS quantity FROM parts "
            "WHERE parts.part = :part_1 UNION "
            "SELECT parts_1.sub_part AS sub_part, "
            "parts_1.part AS part, parts_1.quantity "
            "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
            "WHERE parts_1.part = anon_2.sub_part) "
            "SELECT anon_1.sub_part, "
            "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
            "JOIN parts ON anon_1.part = parts.part "
            "GROUP BY anon_1.sub_part")

        # quick check that the "WITH RECURSIVE" varies per
        # dialect
        self.assert_compile(
            s,
            "WITH anon_1(sub_part, part, quantity) "
            "AS (SELECT parts.sub_part AS sub_part, parts.part "
            "AS part, parts.quantity AS quantity FROM parts "
            "WHERE parts.part = :part_1 UNION "
            "SELECT parts_1.sub_part AS sub_part, "
            "parts_1.part AS part, parts_1.quantity "
            "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
            "WHERE parts_1.part = anon_2.sub_part) "
            "SELECT anon_1.sub_part, "
            "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
            "JOIN parts ON anon_1.part = parts.part "
            "GROUP BY anon_1.sub_part",
            dialect=mssql.dialect())

    def test_recursive_union_no_alias_one(self):
        """A recursive CTE whose recursive member refers to the CTE
        directly rather than through an alias.
        """
        s1 = select([literal(0).label("x")])
        cte = s1.cte(name="cte", recursive=True)
        cte = cte.union_all(
            select([cte.c.x + 1]).where(cte.c.x < 10)
        )
        s2 = select([cte])
        self.assert_compile(s2,
                            "WITH RECURSIVE cte(x) AS "
                            "(SELECT :param_1 AS x UNION ALL "
                            "SELECT cte.x + :x_1 AS anon_1 "
                            "FROM cte WHERE cte.x < :x_2) "
                            "SELECT cte.x FROM cte"
                            )

    def test_recursive_union_no_alias_two(self):
        """

        pg's example:

            WITH RECURSIVE t(n) AS (
                VALUES (1)
              UNION ALL
                SELECT n+1 FROM t WHERE n < 100
            )
            SELECT sum(n) FROM t;

        """

        # I know, this is the PG VALUES keyword,
        # we're cheating here.  also yes we need the SELECT,
        # sorry PG.
        t = select([func.values(1).label("n")]).cte("t", recursive=True)
        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
        s = select([func.sum(t.c.n)])
        self.assert_compile(s,
                            "WITH RECURSIVE t(n) AS "
                            "(SELECT values(:values_1) AS n "
                            "UNION ALL SELECT t.n + :n_1 AS anon_1 "
                            "FROM t "
                            "WHERE t.n < :n_2) "
                            "SELECT sum(t.n) AS sum_1 FROM t"
                            )

    def test_recursive_union_no_alias_three(self):
        # like test one, but let's refer to the CTE
        # in a sibling CTE.

        s1 = select([literal(0).label("x")])
        cte = s1.cte(name="cte", recursive=True)

        # "bar" is built only after union_all(), so it selects from the
        # complete recursive "cte".  Building it before the union is the
        # different scenario covered by test_recursive_union_no_alias_four.
        cte = cte.union_all(
            select([cte.c.x + 1]).where(cte.c.x < 10)
        )
        bar = select([cte]).cte('bar')

        s2 = select([cte, bar])
        self.assert_compile(s2,
                            "WITH RECURSIVE cte(x) AS "
                            "(SELECT :param_1 AS x UNION ALL "
                            "SELECT cte.x + :x_1 AS anon_1 "
                            "FROM cte WHERE cte.x < :x_2), "
                            "bar AS (SELECT cte.x AS x FROM cte) "
                            "SELECT cte.x, bar.x FROM cte, bar"
                            )

    def test_recursive_union_no_alias_four(self):
        # like tests one and three, but "bar" refers to the previous
        # (pre-union) version of "cte".  here we test how the compiler
        # resolves multiple instances of "cte".

        s1 = select([literal(0).label("x")])
        cte = s1.cte(name="cte", recursive=True)

        bar = select([cte]).cte('bar')
        cte = cte.union_all(
            select([cte.c.x + 1]).where(cte.c.x < 10)
        )

        # outer cte rendered first, then bar, which
        # includes "inner" cte
        s2 = select([cte, bar])
        self.assert_compile(s2,
                            "WITH RECURSIVE cte(x) AS "
                            "(SELECT :param_1 AS x UNION ALL "
                            "SELECT cte.x + :x_1 AS anon_1 "
                            "FROM cte WHERE cte.x < :x_2), "
                            "bar AS (SELECT cte.x AS x FROM cte) "
                            "SELECT cte.x, bar.x FROM cte, bar"
                            )

        # bar rendered, only includes "inner" cte,
        # "outer" cte isn't present
        s2 = select([bar])
        self.assert_compile(s2,
                            "WITH RECURSIVE cte(x) AS "
                            "(SELECT :param_1 AS x), "
                            "bar AS (SELECT cte.x AS x FROM cte) "
                            "SELECT bar.x FROM bar"
                            )

        # bar rendered, but then the "outer"
        # cte is rendered.
        s2 = select([bar, cte])
        self.assert_compile(
            s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
            "cte(x) AS "
            "(SELECT :param_1 AS x UNION ALL "
            "SELECT cte.x + :x_1 AS anon_1 "
            "FROM cte WHERE cte.x < :x_2) "
            "SELECT bar.x, cte.x FROM bar, cte")

    def test_conflicting_names(self):
        """test a flat out name conflict: two unrelated CTEs sharing a
        name raise CompileError when the statement is compiled."""

        s1 = select([1])
        c1 = s1.cte(name='cte1', recursive=True)
        s2 = select([1])
        c2 = s2.cte(name='cte1', recursive=True)

        s = select([c1, c2])
        assert_raises_message(
            CompileError,
            "Multiple, unrelated CTEs found "
            "with the same name: 'cte1'",
            s.compile
        )

    def test_union(self):
        """A CTE referenced from both sides of a UNION ALL is rendered
        once, ahead of the compound statement."""
        orders = table('orders',
                       column('region'),
                       column('amount'),
                       )

        regional_sales = select([
            orders.c.region,
            orders.c.amount
        ]).cte("regional_sales")

        s = select(
            [regional_sales.c.region]).where(
            regional_sales.c.amount > 500
        )

        self.assert_compile(s,
                            "WITH regional_sales AS "
                            "(SELECT orders.region AS region, "
                            "orders.amount AS amount FROM orders) "
                            "SELECT regional_sales.region "
                            "FROM regional_sales WHERE "
                            "regional_sales.amount > :amount_1")

        s = s.union_all(
            select([regional_sales.c.region]).
            where(
                regional_sales.c.amount < 300
            )
        )
        self.assert_compile(s,
                            "WITH regional_sales AS "
                            "(SELECT orders.region AS region, "
                            "orders.amount AS amount FROM orders) "
                            "SELECT regional_sales.region FROM regional_sales "
                            "WHERE regional_sales.amount > :amount_1 "
                            "UNION ALL SELECT regional_sales.region "
                            "FROM regional_sales WHERE "
                            "regional_sales.amount < :amount_2")

    def test_reserved_quote(self):
        """The reserved word ``order`` is quoted in a recursive CTE's
        column list as well as in its body and the enclosing SELECT."""
        orders = table('orders',
                       column('order'),
                       )
        s = select([orders.c.order]).cte("regional_sales", recursive=True)
        s = select([s.c.order])
        self.assert_compile(s,
                            'WITH RECURSIVE regional_sales("order") AS '
                            '(SELECT orders."order" AS "order" '
                            "FROM orders)"
                            ' SELECT regional_sales."order" '
                            "FROM regional_sales"
                            )

    def test_multi_subq_quote(self):
        """A CTE whose name needs quoting ("CTE") is quoted consistently
        when referenced from two separate subqueries, and the WITH clause
        is rendered once."""
        cte = select([literal(1).label("id")]).cte(name='CTE')

        s1 = select([cte.c.id]).alias()
        s2 = select([cte.c.id]).alias()

        s = select([s1, s2])
        self.assert_compile(
            s,
            'WITH "CTE" AS (SELECT :param_1 AS id) '
            'SELECT anon_1.id, anon_2.id FROM '
            '(SELECT "CTE".id AS id FROM "CTE") AS anon_1, '
            '(SELECT "CTE".id AS id FROM "CTE") AS anon_2'
        )

    def test_positional_binds(self):
        """With a positional paramstyle, binds inside the CTE are numbered
        before those of the statement that uses it, including in a UNION."""
        orders = table('orders',
                       column('order'),
                       )
        s = select([orders.c.order, literal("x")]).cte("regional_sales")
        s = select([s.c.order, literal("y")])
        dialect = self._positional_dialect()
        self.assert_compile(
            s,
            'WITH regional_sales AS (SELECT orders."order" '
            'AS "order", :1 AS anon_2 FROM orders) SELECT '
            'regional_sales."order", :2 AS anon_1 FROM regional_sales',
            checkpositional=(
                'x',
                'y'),
            dialect=dialect)

        self.assert_compile(
            s.union(s),
            'WITH regional_sales AS (SELECT orders."order" '
            'AS "order", :1 AS anon_2 FROM orders) SELECT '
            'regional_sales."order", :2 AS anon_1 FROM regional_sales '
            'UNION SELECT regional_sales."order", :3 AS anon_1 '
            'FROM regional_sales',
            checkpositional=(
                'x',
                'y',
                'y'),
            dialect=dialect)

        s = select([orders.c.order]).\
            where(orders.c.order == 'x').cte("regional_sales")
        s = select([s.c.order]).where(s.c.order == "y")
        self.assert_compile(
            s,
            'WITH regional_sales AS (SELECT orders."order" AS '
            '"order" FROM orders WHERE orders."order" = :1) '
            'SELECT regional_sales."order" FROM regional_sales '
            'WHERE regional_sales."order" = :2',
            checkpositional=(
                'x',
                'y'),
            dialect=dialect)

    def test_positional_binds_2(self):
        """Positional bind ordering across two CTEs, where the second CTE
        refers to the first both directly and through an alias."""
        orders = table('orders',
                       column('order'),
                       )
        dialect = self._positional_dialect()
        s1 = select([orders.c.order]).where(orders.c.order == 'x').\
            cte("regional_sales_1")

        s1a = s1.alias()

        s2 = select([orders.c.order == 'y', s1a.c.order,
                     orders.c.order, s1.c.order]).\
            where(orders.c.order == 'z').\
            cte("regional_sales_2")

        s3 = select([s2])

        self.assert_compile(
            s3,
            'WITH regional_sales_1 AS (SELECT orders."order" AS "order" '
            'FROM orders WHERE orders."order" = :1), regional_sales_2 AS '
            '(SELECT orders."order" = :2 AS anon_1, '
            'anon_2."order" AS "order", '
            'orders."order" AS "order", '
            'regional_sales_1."order" AS "order" FROM orders, '
            'regional_sales_1 '
            'AS anon_2, regional_sales_1 '
            'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, '
            'regional_sales_2."order" FROM regional_sales_2',
            checkpositional=('x', 'y', 'z'), dialect=dialect)

    def test_positional_binds_2_asliteral(self):
        """Same statement as test_positional_binds_2 compiled with
        literal_binds=True: values are inlined and no positional
        parameters remain."""
        orders = table('orders',
                       column('order'),
                       )
        dialect = self._positional_dialect()
        s1 = select([orders.c.order]).where(orders.c.order == 'x').\
            cte("regional_sales_1")

        s1a = s1.alias()

        s2 = select([orders.c.order == 'y', s1a.c.order,
                     orders.c.order, s1.c.order]).\
            where(orders.c.order == 'z').\
            cte("regional_sales_2")

        s3 = select([s2])

        self.assert_compile(
            s3,
            'WITH regional_sales_1 AS '
            '(SELECT orders."order" AS "order" '
            'FROM orders '
            'WHERE orders."order" = \'x\'), '
            'regional_sales_2 AS '
            '(SELECT orders."order" = \'y\' AS anon_1, '
            'anon_2."order" AS "order", orders."order" AS "order", '
            'regional_sales_1."order" AS "order" '
            'FROM orders, regional_sales_1 AS anon_2, regional_sales_1 '
            'WHERE orders."order" = \'z\') '
            'SELECT regional_sales_2.anon_1, regional_sales_2."order" '
            'FROM regional_sales_2',
            checkpositional=(), dialect=dialect,
            literal_binds=True)

    def test_all_aliases(self):
        """Every reference to the CTE goes through an alias; the WITH
        clause is still rendered once under the CTE's own name."""
        orders = table('order', column('order'))
        s = select([orders.c.order]).cte("regional_sales")

        r1 = s.alias()
        r2 = s.alias()

        s2 = select([r1, r2]).where(r1.c.order > r2.c.order)

        self.assert_compile(
            s2,
            'WITH regional_sales AS (SELECT "order"."order" '
            'AS "order" FROM "order") '
            'SELECT anon_1."order", anon_2."order" '
            'FROM regional_sales AS anon_1, '
            'regional_sales AS anon_2 WHERE anon_1."order" > anon_2."order"'
        )

        s3 = select(
            [orders]).select_from(
            orders.join(
                r1,
                r1.c.order == orders.c.order))

        self.assert_compile(
            s3,
            'WITH regional_sales AS '
            '(SELECT "order"."order" AS "order" '
            'FROM "order")'
            ' SELECT "order"."order" '
            'FROM "order" JOIN regional_sales AS anon_1 '
            'ON anon_1."order" = "order"."order"'
        )

    def test_suffixes(self):
        """Dialect-specific CTE suffixes are rendered only when compiling
        for the dialect they were registered with."""
        orders = table('order', column('order'))
        s = select([orders.c.order]).cte("regional_sales")
        s = s.suffix_with("pg suffix", dialect='postgresql')
        s = s.suffix_with('oracle suffix', dialect='oracle')
        stmt = select([orders]).where(orders.c.order > s.c.order)

        self.assert_compile(
            stmt,
            'WITH regional_sales AS (SELECT "order"."order" AS "order" '
            'FROM "order") SELECT "order"."order" FROM "order", '
            'regional_sales WHERE "order"."order" > regional_sales."order"'
        )

        self.assert_compile(
            stmt,
            'WITH regional_sales AS (SELECT "order"."order" AS "order" '
            'FROM "order") oracle suffix SELECT "order"."order" FROM "order", '
            'regional_sales WHERE "order"."order" > regional_sales."order"',
            dialect='oracle'
        )

        self.assert_compile(
            stmt,
            'WITH regional_sales AS (SELECT "order"."order" AS "order" '
            'FROM "order") pg suffix SELECT "order"."order" FROM "order", '
            'regional_sales WHERE "order"."order" > regional_sales."order"',
            dialect='postgresql'
        )