diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-02-11 12:12:19 -0500 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-02-11 12:27:28 -0500 |
commit | e5f1a3fb7dc1888ed187fdeae8171e4ff322dab6 (patch) | |
tree | 320ef9285c4a4477ab90d838c216cba979bc4fc9 /lib/sqlalchemy/sql/compiler.py | |
parent | 287aaa9d416b4f72179da320af0624b9ebc43846 (diff) | |
download | sqlalchemy-e5f1a3fb7dc1888ed187fdeae8171e4ff322dab6.tar.gz |
- CTE functionality has been expanded to support all DML, allowing
INSERT, UPDATE, and DELETE statements to both specify their own
WITH clause, as well as for these statements themselves to be
CTE expressions when they include a RETURNING clause.
fixes #2551
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 94 |
1 files changed, 54 insertions, 40 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index cc9a49a91..a2fc0fe68 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -418,6 +418,11 @@ class SQLCompiler(Compiled): self.truncated_names = {} Compiled.__init__(self, dialect, statement, **kwargs) + if ( + self.isinsert or self.isupdate or self.isdelete + ) and statement._returning: + self.returning = statement._returning + if self.positional and dialect.paramstyle == 'numeric': self._apply_numbered_params() @@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes and self._is_toplevel_select(select): + if self.ctes and toplevel: text = self._render_cte_clause() + text if select._suffixes: @@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled): else: return text - def _is_toplevel_select(self, select): - """Return True if the stack is placed at the given select, and - is also the outermost SELECT, meaning there is either no stack - before this one, or the enclosing stack is a topmost INSERT. - - """ - return ( - self.stack[-1]['selectable'] is select and - ( - len(self.stack) == 1 or self.isinsert and len(self.stack) == 2 - and self.statement is self.stack[0]['selectable'] - ) - ) - def _setup_select_hints(self, select): byfrom = dict([ (from_, hinttext % { @@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled): ) return dialect_hints, table_text - def visit_insert(self, insert_stmt, **kw): + def visit_insert(self, insert_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set(), "asfrom_froms": set(), "selectable": insert_stmt}) - self.isinsert = True - crud_params = crud._get_crud_params(self, insert_stmt, **kw) + crud_params = crud._setup_crud_params( + self, insert_stmt, crud.ISINSERT, **kw) if not crud_params and \ not self.dialect.supports_default_values and \ @@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled): for c in crud_params_single]) if self.returning or insert_stmt._returning: - self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause( - insert_stmt, self.returning) + insert_stmt, self.returning or insert_stmt._returning) if self.returning_precedes_values: text += " " + returning_clause + else: + returning_clause = None if insert_stmt.select is not None: text += " %s" % self.process(self._insert_from_select, **kw) @@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in crud_params]) - if self.returning and not self.returning_precedes_values: + if returning_clause and not self.returning_precedes_values: text += " " + returning_clause + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def update_limit_clause(self, update_stmt): """Provide a hook for MySQL to add LIMIT to the UPDATE""" @@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled): MySQL overrides this. """ - return from_table._compiler_dispatch(self, asfrom=True, - iscrud=True, **kw) + kw['asfrom'] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) def update_from_clause(self, update_stmt, from_table, extra_froms, @@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled): fromhints=from_hints, **kw) for t in extra_froms) - def visit_update(self, update_stmt, **kw): + def visit_update(self, update_stmt, asfrom=False, **kw): + toplevel = not self.stack + self.stack.append( {'correlate_froms': set([update_stmt.table]), "asfrom_froms": set([update_stmt.table]), "selectable": update_stmt}) - self.isupdate = True - extra_froms = update_stmt._extra_froms text = "UPDATE " @@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - crud_params = crud._get_crud_params(self, update_stmt, **kw) + crud_params = crud._setup_crud_params( + self, update_stmt, crud.ISUPDATE, **kw) if update_stmt._hints: dialect_hints, table_text = self._setup_crud_hints( @@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled): ) if self.returning or update_stmt._returning: - if not self.returning: - self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled): if limit_clause: text += " " + limit_clause - if self.returning and not self.returning_precedes_values: + if (self.returning or update_stmt._returning) and \ + not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, self.returning) + update_stmt, self.returning or update_stmt._returning) + + if self.ctes and toplevel: + text = self._render_cte_clause() + text self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text @util.memoized_property def _key_getters_for_crud_column(self): - return crud._key_getters_for_crud_column(self) + return crud._key_getters_for_crud_column(self, self.statement) + + def visit_delete(self, delete_stmt, asfrom=False, **kw): + toplevel = not self.stack - def visit_delete(self, delete_stmt, **kw): self.stack.append({'correlate_froms': set([delete_stmt.table]), "asfrom_froms": set([delete_stmt.table]), "selectable": delete_stmt}) - self.isdelete = True + + crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw) text = "DELETE " @@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled): text += table_text if delete_stmt._returning: - self.returning = delete_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) @@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled): if t: text += " WHERE " + t - if self.returning and not self.returning_precedes_values: + if delete_stmt._returning and not self.returning_precedes_values: text += " " + self.returning_clause( delete_stmt, delete_stmt._returning) + if self.ctes and toplevel: + text = self._render_cte_clause() + text + self.stack.pop(-1) - return text + if asfrom: + return "(" + text + ")" + else: + return text def visit_savepoint(self, savepoint_stmt): return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) |