diff options
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 11 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/__init__.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_cte.py | 193 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 2 |
8 files changed, 217 insertions, 3 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index c8a3d3322..62753e1a5 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1684,6 +1684,8 @@ class MySQLDialect(default.DefaultDialect): default_paramstyle = 'format' colspecs = colspecs + cte_follows_insert = True + statement_compiler = MySQLCompiler ddl_compiler = MySQLDDLCompiler type_compiler = MySQLTypeCompiler diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 39acbf28d..356c2a2bf 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1030,6 +1030,7 @@ class OracleDialect(default.DefaultDialect): max_identifier_length = 30 supports_simple_order_by_label = False + cte_follows_insert = True supports_sequences = True sequences_optional = False diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4d5f338bf..54fb25c16 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -60,6 +60,7 @@ class DefaultDialect(interfaces.Dialect): implicit_returning = False supports_right_nested_joins = True + cte_follows_insert = False supports_native_enum = False supports_native_boolean = False diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a442c65fd..0b98dc51c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2105,7 +2105,12 @@ class SQLCompiler(Compiled): returning_clause = None if insert_stmt.select is not None: - text += " %s" % self.process(self._insert_from_select, **kw) + select_text = self.process(self._insert_from_select, **kw) + + if self.ctes and toplevel and self.dialect.cte_follows_insert: + text += " %s%s" % (self._render_cte_clause(), select_text) + else: + text += " %s" % select_text elif not crud_params and supports_default_values: text += " DEFAULT VALUES" elif insert_stmt._has_multi_parameters: @@ -2130,7 +2135,7 @@ class SQLCompiler(Compiled): if returning_clause and not self.returning_precedes_values: text += " " + returning_clause - if self.ctes and toplevel: + if self.ctes and toplevel and not self.dialect.cte_follows_insert: text = self._render_cte_clause() + text self.stack.pop(-1) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index b509c94d6..19d80e028 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -180,9 +180,18 @@ class SuiteRequirements(Requirements): return exclusions.closed() @property + def ctes_with_update_delete(self): + """target database supports CTES that ride on top of a normal UPDATE + or DELETE statement which refers to the CTE in a correlated subquery. + + """ + + return exclusions.closed() + + @property def ctes_on_dml(self): """target database supports CTES which consist of INSERT, UPDATE - or DELETE""" + or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)""" return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py index 9eeffd4cb..748d9722d 100644 --- a/lib/sqlalchemy/testing/suite/__init__.py +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -1,4 +1,5 @@ +from sqlalchemy.testing.suite.test_cte import * from sqlalchemy.testing.suite.test_dialect import * from sqlalchemy.testing.suite.test_ddl import * from sqlalchemy.testing.suite.test_insert import * diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py new file mode 100644 index 000000000..cc72278e6 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -0,0 +1,193 @@ +from .. import fixtures, config +from ..assertions import eq_ + +from sqlalchemy import Integer, String, select +from sqlalchemy import ForeignKey +from sqlalchemy import testing + +from ..schema import Table, Column + + +class CTETest(fixtures.TablesTest): + __backend__ = True + __requires__ = 'ctes', + + run_inserts = 'each' + run_deletes = 'each' + + @classmethod + def define_tables(cls, metadata): + Table("some_table", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column("parent_id", ForeignKey("some_table.id"))) + + Table("some_other_table", metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column("parent_id", Integer)) + + @classmethod + def insert_data(cls): + config.db.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "d1", "parent_id": None}, + {"id": 2, "data": "d2", "parent_id": 1}, + {"id": 3, "data": "d3", "parent_id": 1}, + {"id": 4, "data": "d4", "parent_id": 3}, + {"id": 5, "data": "d5", "parent_id": 3} + ] + ) + + def test_select_nonrecursive_round_trip(self): + some_table = self.tables.some_table + + with config.db.connect() as conn: + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte") + result = conn.execute( + select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4", )]) + + def test_select_recursive_round_trip(self): + some_table = self.tables.some_table + + with config.db.connect() as conn: + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"])).cte( + "some_cte", recursive=True) + + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select([st1]).where(st1.c.id == cte_alias.c.parent_id) + ) + result = conn.execute( + select([cte.c.data]).where( + cte.c.data != "d2").order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)] + ) + + def test_insert_from_select_round_trip(self): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], + select([cte]) + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)] + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.update_from + def test_update_from_round_trip(self): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + conn.execute( + some_other_table.insert().from_select( + ['id', 'data', 'parent_id'], + select([some_table]) + ) + ) + + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.update().values(parent_id=5).where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), (2, "d2", 5), + (3, "d3", 5), (4, "d4", 5), (5, "d5", 3) + ] + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.delete_from + def test_delete_from_round_trip(self): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + conn.execute( + some_other_table.insert().from_select( + ['id', 'data', 'parent_id'], + select([some_table]) + ) + ) + + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), (5, "d5", 3) + ] + ) + + @testing.requires.ctes_with_update_delete + def test_delete_scalar_subq_round_trip(self): + + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + with config.db.connect() as conn: + conn.execute( + some_other_table.insert().from_select( + ['id', 'data', 'parent_id'], + select([some_table]) + ) + ) + + cte = select([some_table]).where( + some_table.c.data.in_(["d2", "d3", "d4"]) + ).cte("some_cte") + conn.execute( + some_other_table.delete().where( + some_other_table.c.data == + select([cte.c.data]).where( + cte.c.id == some_other_table.c.id) + ) + ) + eq_( + conn.execute( + select([some_other_table]).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), (5, "d5", 3) + ] + ) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index d9755c8f9..05b9162de 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -511,3 +511,5 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + |