diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-02-11 12:12:19 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-02-11 12:27:28 -0500 |
commit | e5f1a3fb7dc1888ed187fdeae8171e4ff322dab6 (patch) | |
tree | 320ef9285c4a4477ab90d838c216cba979bc4fc9 /lib/sqlalchemy | |
parent | 287aaa9d416b4f72179da320af0624b9ebc43846 (diff) | |
download | sqlalchemy-e5f1a3fb7dc1888ed187fdeae8171e4ff322dab6.tar.gz |
- CTE functionality has been expanded to support all DML, allowing
INSERT, UPDATE, and DELETE statements to both specify their own
WITH clause, as well as for these statements themselves to be
CTE expressions when they include a RETURNING clause.
fixes #2551
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/orm/query.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 94 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 47 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 9 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 288 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 2 |
7 files changed, 275 insertions, 173 deletions
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8a25f570a..ad7b9130b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -478,8 +478,6 @@ class Query(object): """Return the full SELECT statement represented by this :class:`.Query` represented as a common table expression (CTE). - .. versionadded:: 0.7.6 - Parameters and usage are the same as those of the :meth:`.SelectBase.cte` method; see that method for further details. @@ -528,7 +526,7 @@ class Query(object): .. seealso:: - :meth:`.SelectBase.cte` + :meth:`.HasCTE.cte` """ return self.enable_eagerloads(False).\ diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cc9a49a91..a2fc0fe68 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -418,6 +418,11 @@ class SQLCompiler(Compiled): self.truncated_names = {} Compiled.__init__(self, dialect, statement, **kwargs) + if ( + self.isinsert or self.isupdate or self.isdelete + ) and statement._returning: + self.returning = statement._returning + if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and self._is_toplevel_select(select): + if self.ctes and toplevel: text = self._render_cte_clause() + text if select._suffixes: @@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled): else: return text - def _is_toplevel_select(self, select): - """Return True if the stack is placed at the given select, and - is also the outermost SELECT, meaning there is either no stack - before this one, or the enclosing stack is a topmost INSERT. - - """ - return ( - self.stack[-1]['selectable'] is select and - ( - len(self.stack) == 1 or self.isinsert and len(self.stack) == 2 - and self.statement is self.stack[0]['selectable'] - ) - ) - def _setup_select_hints(self, select): byfrom = dict([ (from_, hinttext % { @@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled): ) return dialect_hints, table_text - def visit_insert(self, insert_stmt, **kw): + def visit_insert(self, insert_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set(), "asfrom_froms": set(), "selectable": insert_stmt}) - self.isinsert = True - crud_params = crud._get_crud_params(self, insert_stmt, **kw) + crud_params = crud._setup_crud_params( + self, insert_stmt, crud.ISINSERT, **kw) if not crud_params and \ not self.dialect.supports_default_values and \ @@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled): for c in crud_params_single]) if self.returning or insert_stmt._returning: - self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause( - insert_stmt, self.returning) + insert_stmt, self.returning or insert_stmt._returning) if self.returning_precedes_values: text += " " + returning_clause + else: + returning_clause = None if insert_stmt.select is not None: text += " %s" % self.process(self._insert_from_select, **kw) @@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in crud_params]) - if self.returning and not self.returning_precedes_values: + if returning_clause and not self.returning_precedes_values: text += " " + returning_clause + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def update_limit_clause(self, update_stmt): """Provide a hook for MySQL to add LIMIT to the UPDATE""" @@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled): MySQL overrides this. """ - return from_table._compiler_dispatch(self, asfrom=True, - iscrud=True, **kw) + kw['asfrom'] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) def update_from_clause(self, update_stmt, from_table, extra_froms, @@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled): fromhints=from_hints, **kw) for t in extra_froms) - def visit_update(self, update_stmt, **kw): + def visit_update(self, update_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set([update_stmt.table]), "asfrom_froms": set([update_stmt.table]), "selectable": update_stmt}) - self.isupdate = True - extra_froms = update_stmt._extra_froms text = "UPDATE " @@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, **kw) + crud_params = crud._setup_crud_params( + self, update_stmt, crud.ISUPDATE, **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( @@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled): ) if self.returning or update_stmt._returning: - if not self.returning: - self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if self.returning and not self.returning_precedes_values: + if (self.returning or update_stmt._returning) and \ + not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) + + if self.ctes and toplevel: + text = self._render_cte_clause() + text self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text @util.memoized_property def _key_getters_for_crud_column(self): - return crud._key_getters_for_crud_column(self) + return crud._key_getters_for_crud_column(self, self.statement) + + def visit_delete(self, delete_stmt, asfrom=False, **kw): + toplevel = not self.stack - def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), "asfrom_froms": set([delete_stmt.table]), "selectable": delete_stmt}) - self.isdelete = True + + crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) text = "DELETE " @@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled): text += table_text if delete_stmt._returning: - self.returning = delete_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) @@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if self.returning and not self.returning_precedes_values: + if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index a01b72e61..58cd80995 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -25,6 +25,41 @@ values present. """) +ISINSERT = util.symbol('ISINSERT') +ISUPDATE = util.symbol('ISUPDATE') +ISDELETE = util.symbol('ISDELETE') + + +def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): + restore_isinsert = compiler.isinsert + restore_isupdate = compiler.isupdate + restore_isdelete = compiler.isdelete + + should_restore = ( + restore_isinsert or restore_isupdate or restore_isdelete + ) or len(compiler.stack) > 1 + + if local_stmt_type is ISINSERT: + compiler.isupdate = False + compiler.isinsert = True + elif local_stmt_type is ISUPDATE: + compiler.isupdate = True + compiler.isinsert = False + elif local_stmt_type is ISDELETE: + if not should_restore: + compiler.isdelete = True + else: + assert False, "ISINSERT, ISUPDATE, or ISDELETE expected" + + try: + if local_stmt_type in (ISINSERT, ISUPDATE): + return _get_crud_params(compiler, stmt, **kw) + finally: + if should_restore: + compiler.isinsert = restore_isinsert + compiler.isupdate = restore_isupdate + compiler.isdelete = restore_isdelete + def _get_crud_params(compiler, stmt, **kw): """create a set of tuples representing column/string pairs for use @@ -59,7 +94,7 @@ def _get_crud_params(compiler, stmt, **kw): # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler) + _key_getters_for_crud_column(compiler, stmt) # if we have statement parameters - set defaults in the # compiled params @@ -128,15 +163,15 @@ def _create_bind_param( return bindparam -def _key_getters_for_crud_column(compiler): - if compiler.isupdate and compiler.statement._extra_froms: +def _key_getters_for_crud_column(compiler, stmt): + if compiler.isupdate and stmt._extra_froms: # when extra tables are present, refer to the columns # in those extra tables as table-qualified, including in # dictionaries and when rendering bind param names. # the "main" table of the statement remains unqualified, # allowing the most compatibility with a non-multi-table # statement. - _et = set(compiler.statement._extra_froms) + _et = set(stmt._extra_froms) def _column_as_key(key): str_key = elements._column_as_key(key) @@ -609,7 +644,9 @@ def _get_returning_modifiers(compiler, stmt): stmt.table.implicit_returning and stmt._return_defaults) else: - implicit_return_defaults = False + # this line is unused, currently we are always + # isinsert or isupdate + implicit_return_defaults = False # pragma: no cover if implicit_return_defaults: if stmt._return_defaults is True: diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 7b506f9db..8f368dcdb 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -9,15 +9,18 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`. """ -from .base import Executable, _generative, _from_objects, DialectKWArgs +from .base import Executable, _generative, _from_objects, DialectKWArgs, \ + ColumnCollection from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \ _column_as_key -from .selectable import _interpret_as_from, _interpret_as_select, HasPrefixes +from .selectable import _interpret_as_from, _interpret_as_select, \ + HasPrefixes, HasCTE from .. import util from .. import exc -class UpdateBase(DialectKWArgs, HasPrefixes, Executable, ClauseElement): +class UpdateBase( + HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements. """ diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index a63eac2f8..36f7f7fe1 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -46,8 +46,8 @@ from .base import ColumnCollection, Generative, Executable, \ from .selectable import Alias, Join, Select, Selectable, TableClause, \ CompoundSelect, CTE, FromClause, FromGrouping, SelectBase, \ - alias, GenerativeSelect, \ - subquery, HasPrefixes, HasSuffixes, Exists, ScalarSelect, TextAsFrom + alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \ + Exists, ScalarSelect, TextAsFrom from .dml import Insert, Update, Delete, UpdateBase, ValuesBase diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c3906c2f2..fcd22a786 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1195,6 +1195,15 @@ class CTE(Generative, HasSuffixes, Alias): self._suffixes = _suffixes super(CTE, self).__init__(selectable, name=name) + @util.dependencies("sqlalchemy.sql.dml") + def _populate_column_collection(self, dml): + if isinstance(self.element, dml.UpdateBase): + for col in self.element._returning: + col._make_proxy(self) + else: + for col in self.element.columns._all_columns: + col._make_proxy(self) + def alias(self, name=None, flat=False): return CTE( self.original, @@ -1223,6 +1232,164 @@ class CTE(Generative, HasSuffixes, Alias): ) +class HasCTE(object): + """Mixin that declares a class to include CTE support. + + .. versionadded:: 1.1 + + """ + + def cte(self, name=None, recursive=False): + """Return a new :class:`.CTE`, or Common Table Expression instance. + + Common table expressions are a SQL standard whereby SELECT + statements can draw upon secondary statements specified along + with the primary statement, using a clause called "WITH". + Special semantics regarding UNION can also be employed to + allow "recursive" queries, where a SELECT statement can draw + upon the set of rows that have previously been selected. + + CTEs can also be applied to DML constructs UPDATE, INSERT + and DELETE on some databases, both as a source of CTE rows + when combined with RETURNING, as well as a consumer of + CTE rows. + + SQLAlchemy detects :class:`.CTE` objects, which are treated + similarly to :class:`.Alias` objects, as special elements + to be delivered to the FROM clause of the statement as well + as to a WITH clause at the top of the statement. + + .. versionchanged:: 1.1 Added support for UPDATE/INSERT/DELETE as + CTE, CTEs added to UPDATE/INSERT/DELETE. + + :param name: name given to the common table expression. Like + :meth:`._FromClause.alias`, the name can be left as ``None`` + in which case an anonymous symbol will be used at query + compile time. + :param recursive: if ``True``, will render ``WITH RECURSIVE``. + A recursive common table expression is intended to be used in + conjunction with UNION ALL in order to derive rows + from those already selected. + + The following examples include two from Postgresql's documentation at + http://www.postgresql.org/docs/current/static/queries-with.html, + as well as additional examples. + + Example 1, non recursive:: + + from sqlalchemy import (Table, Column, String, Integer, + MetaData, select, func) + + metadata = MetaData() + + orders = Table('orders', metadata, + Column('region', String), + Column('amount', Integer), + Column('product', String), + Column('quantity', Integer) + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + + top_regions = select([regional_sales.c.region]).\\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + statement = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + result = conn.execute(statement).fetchall() + + Example 2, WITH RECURSIVE:: + + from sqlalchemy import (Table, Column, String, Integer, + MetaData, select, func) + + metadata = MetaData() + + parts = Table('parts', metadata, + Column('part', String), + Column('sub_part', String), + Column('quantity', Integer), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\\ + where(parts.c.part=='our part').\\ + cte(recursive=True) + + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union_all( + select([ + parts_alias.c.sub_part, + parts_alias.c.part, + parts_alias.c.quantity + ]). + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + statement = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ]).\\ + group_by(included_parts.c.sub_part) + + result = conn.execute(statement).fetchall() + + Example 3, an upsert using UPDATE and INSERT with CTEs:: + + orders = table( + 'orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + upsert = ( + orders.update() + .where(orders.c.region == 'Region1') + .values(amount=1.0, product='Product1', quantity=1) + .returning(*(orders.c._all_columns)).cte('upsert')) + + insert = orders.insert().from_select( + orders.c.keys(), + select([ + literal('Region1'), literal(1.0), + literal('Product1'), literal(1) + ).where(exists(upsert.select())) + ) + + connection.execute(insert) + + .. seealso:: + + :meth:`.orm.query.Query.cte` - ORM version of + :meth:`.HasCTE.cte`. + + """ + return CTE(self, name=name, recursive=recursive) + + class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' @@ -1497,7 +1664,7 @@ class ForUpdateArg(ClauseElement): self.of = None -class SelectBase(Executable, FromClause): +class SelectBase(HasCTE, Executable, FromClause): """Base class for SELECT statements. @@ -1531,125 +1698,6 @@ class SelectBase(Executable, FromClause): """ return self.as_scalar().label(name) - def cte(self, name=None, recursive=False): - """Return a new :class:`.CTE`, or Common Table Expression instance. - - Common table expressions are a SQL standard whereby SELECT - statements can draw upon secondary statements specified along - with the primary statement, using a clause called "WITH". - Special semantics regarding UNION can also be employed to - allow "recursive" queries, where a SELECT statement can draw - upon the set of rows that have previously been selected. - - SQLAlchemy detects :class:`.CTE` objects, which are treated - similarly to :class:`.Alias` objects, as special elements - to be delivered to the FROM clause of the statement as well - as to a WITH clause at the top of the statement. - - .. versionadded:: 0.7.6 - - :param name: name given to the common table expression. Like - :meth:`._FromClause.alias`, the name can be left as ``None`` - in which case an anonymous symbol will be used at query - compile time. - :param recursive: if ``True``, will render ``WITH RECURSIVE``. - A recursive common table expression is intended to be used in - conjunction with UNION ALL in order to derive rows - from those already selected. - - The following examples illustrate two examples from - Postgresql's documentation at - http://www.postgresql.org/docs/8.4/static/queries-with.html. - - Example 1, non recursive:: - - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) - - metadata = MetaData() - - orders = Table('orders', metadata, - Column('region', String), - Column('amount', Integer), - Column('product', String), - Column('quantity', Integer) - ) - - regional_sales = select([ - orders.c.region, - func.sum(orders.c.amount).label('total_sales') - ]).group_by(orders.c.region).cte("regional_sales") - - - top_regions = select([regional_sales.c.region]).\\ - where( - regional_sales.c.total_sales > - select([ - func.sum(regional_sales.c.total_sales)/10 - ]) - ).cte("top_regions") - - statement = select([ - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), - func.sum(orders.c.amount).label("product_sales") - ]).where(orders.c.region.in_( - select([top_regions.c.region]) - )).group_by(orders.c.region, orders.c.product) - - result = conn.execute(statement).fetchall() - - Example 2, WITH RECURSIVE:: - - from sqlalchemy import (Table, Column, String, Integer, - MetaData, select, func) - - metadata = MetaData() - - parts = Table('parts', metadata, - Column('part', String), - Column('sub_part', String), - Column('quantity', Integer), - ) - - included_parts = select([ - parts.c.sub_part, - parts.c.part, - parts.c.quantity]).\\ - where(parts.c.part=='our part').\\ - cte(recursive=True) - - - incl_alias = included_parts.alias() - parts_alias = parts.alias() - included_parts = included_parts.union_all( - select([ - parts_alias.c.sub_part, - parts_alias.c.part, - parts_alias.c.quantity - ]). - where(parts_alias.c.part==incl_alias.c.sub_part) - ) - - statement = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity). - label('total_quantity') - ]).\\ - group_by(included_parts.c.sub_part) - - result = conn.execute(statement).fetchall() - - - .. seealso:: - - :meth:`.orm.query.Query.cte` - ORM version of - :meth:`.SelectBase.cte`. - - """ - return CTE(self, name=name, recursive=recursive) - @_generative @util.deprecated('0.6', message="``autocommit()`` is deprecated. Use " diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index bb5a96256..21f9f68fb 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -296,6 +296,8 @@ class AssertsCompiledSQL(object): dialect = config.db.dialect elif dialect == 'default': dialect = default.DefaultDialect() + elif dialect == 'default_enhanced': + dialect = default.StrCompileDialect() elif isinstance(dialect, util.string_types): dialect = url.URL(dialect).get_dialect()() |