summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py2
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py1
-rw-r--r--lib/sqlalchemy/engine/default.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py9
-rw-r--r--lib/sqlalchemy/testing/requirements.py11
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py1
-rw-r--r--lib/sqlalchemy/testing/suite/test_cte.py193
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py2
8 files changed, 217 insertions, 3 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index c8a3d3322..62753e1a5 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1684,6 +1684,8 @@ class MySQLDialect(default.DefaultDialect):
default_paramstyle = 'format'
colspecs = colspecs
+ cte_follows_insert = True
+
statement_compiler = MySQLCompiler
ddl_compiler = MySQLDDLCompiler
type_compiler = MySQLTypeCompiler
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 39acbf28d..356c2a2bf 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -1030,6 +1030,7 @@ class OracleDialect(default.DefaultDialect):
max_identifier_length = 30
supports_simple_order_by_label = False
+ cte_follows_insert = True
supports_sequences = True
sequences_optional = False
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 4d5f338bf..54fb25c16 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -60,6 +60,7 @@ class DefaultDialect(interfaces.Dialect):
implicit_returning = False
supports_right_nested_joins = True
+ cte_follows_insert = False
supports_native_enum = False
supports_native_boolean = False
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index a442c65fd..0b98dc51c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2105,7 +2105,12 @@ class SQLCompiler(Compiled):
returning_clause = None
if insert_stmt.select is not None:
- text += " %s" % self.process(self._insert_from_select, **kw)
+ select_text = self.process(self._insert_from_select, **kw)
+
+ if self.ctes and toplevel and self.dialect.cte_follows_insert:
+ text += " %s%s" % (self._render_cte_clause(), select_text)
+ else:
+ text += " %s" % select_text
elif not crud_params and supports_default_values:
text += " DEFAULT VALUES"
elif insert_stmt._has_multi_parameters:
@@ -2130,7 +2135,7 @@ class SQLCompiler(Compiled):
if returning_clause and not self.returning_precedes_values:
text += " " + returning_clause
- if self.ctes and toplevel:
+ if self.ctes and toplevel and not self.dialect.cte_follows_insert:
text = self._render_cte_clause() + text
self.stack.pop(-1)
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index b509c94d6..19d80e028 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -180,9 +180,18 @@ class SuiteRequirements(Requirements):
return exclusions.closed()
@property
+ def ctes_with_update_delete(self):
+ """target database supports CTES that ride on top of a normal UPDATE
+ or DELETE statement which refers to the CTE in a correlated subquery.
+
+ """
+
+ return exclusions.closed()
+
+ @property
def ctes_on_dml(self):
"""target database supports CTES which consist of INSERT, UPDATE
- or DELETE"""
+ or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
return exclusions.closed()
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
index 9eeffd4cb..748d9722d 100644
--- a/lib/sqlalchemy/testing/suite/__init__.py
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -1,4 +1,5 @@
+from sqlalchemy.testing.suite.test_cte import *
from sqlalchemy.testing.suite.test_dialect import *
from sqlalchemy.testing.suite.test_ddl import *
from sqlalchemy.testing.suite.test_insert import *
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644
index 000000000..cc72278e6
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -0,0 +1,193 @@
+from .. import fixtures, config
+from ..assertions import eq_
+
+from sqlalchemy import Integer, String, select
+from sqlalchemy import ForeignKey
+from sqlalchemy import testing
+
+from ..schema import Table, Column
+
+
+class CTETest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = 'ctes',
+
+ run_inserts = 'each'
+ run_deletes = 'each'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("some_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column("parent_id", ForeignKey("some_table.id")))
+
+ Table("some_other_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column("parent_id", Integer))
+
+ @classmethod
+ def insert_data(cls):
+ config.db.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "d1", "parent_id": None},
+ {"id": 2, "data": "d2", "parent_id": 1},
+ {"id": 3, "data": "d3", "parent_id": 1},
+ {"id": 4, "data": "d4", "parent_id": 3},
+ {"id": 5, "data": "d5", "parent_id": 3}
+ ]
+ )
+
+ def test_select_nonrecursive_round_trip(self):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+ result = conn.execute(
+ select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
+ )
+ eq_(result.fetchall(), [("d4", )])
+
+ def test_select_recursive_round_trip(self):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])).cte(
+ "some_cte", recursive=True)
+
+ cte_alias = cte.alias("c1")
+ st1 = some_table.alias()
+ # note that SQL Server requires this to be UNION ALL,
+ # can't be UNION
+ cte = cte.union_all(
+ select([st1]).where(st1.c.id == cte_alias.c.parent_id)
+ )
+ result = conn.execute(
+ select([cte.c.data]).where(
+ cte.c.data != "d2").order_by(cte.c.data.desc())
+ )
+ eq_(
+ result.fetchall(),
+ [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+ )
+
+ def test_insert_from_select_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"],
+ select([cte])
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.update_from
+ def test_update_from_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.update().values(parent_id=5).where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (2, "d2", 5),
+ (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
+ ]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.delete_from
+ def test_delete_from_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (5, "d5", 3)
+ ]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ def test_delete_scalar_subq_round_trip(self):
+
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data ==
+ select([cte.c.data]).where(
+ cte.c.id == some_other_table.c.id)
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (5, "d5", 3)
+ ]
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
index d9755c8f9..05b9162de 100644
--- a/lib/sqlalchemy/testing/suite/test_select.py
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -511,3 +511,5 @@ class LikeFunctionsTest(fixtures.TablesTest):
col = self.tables.some_table.c.data
self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
+
+