summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-02-11 12:12:19 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2016-02-11 12:27:28 -0500
commite5f1a3fb7dc1888ed187fdeae8171e4ff322dab6 (patch)
tree320ef9285c4a4477ab90d838c216cba979bc4fc9 /lib/sqlalchemy/sql/compiler.py
parent287aaa9d416b4f72179da320af0624b9ebc43846 (diff)
downloadsqlalchemy-e5f1a3fb7dc1888ed187fdeae8171e4ff322dab6.tar.gz
- CTE functionality has been expanded to support all DML, allowing
INSERT, UPDATE, and DELETE statements to both specify their own WITH clause, as well as for these statements themselves to be CTE expressions when they include a RETURNING clause. fixes #2551
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py94
1 files changed, 54 insertions, 40 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index cc9a49a91..a2fc0fe68 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -418,6 +418,11 @@ class SQLCompiler(Compiled):
self.truncated_names = {}
Compiled.__init__(self, dialect, statement, **kwargs)
+ if (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and statement._returning:
+ self.returning = statement._returning
+
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes and self._is_toplevel_select(select):
+ if self.ctes and toplevel:
text = self._render_cte_clause() + text
if select._suffixes:
@@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled):
else:
return text
- def _is_toplevel_select(self, select):
- """Return True if the stack is placed at the given select, and
- is also the outermost SELECT, meaning there is either no stack
- before this one, or the enclosing stack is a topmost INSERT.
-
- """
- return (
- self.stack[-1]['selectable'] is select and
- (
- len(self.stack) == 1 or self.isinsert and len(self.stack) == 2
- and self.statement is self.stack[0]['selectable']
- )
- )
-
def _setup_select_hints(self, select):
byfrom = dict([
(from_, hinttext % {
@@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled):
)
return dialect_hints, table_text
- def visit_insert(self, insert_stmt, **kw):
+ def visit_insert(self, insert_stmt, asfrom=False, **kw):
+ toplevel = not self.stack
+
self.stack.append(
{'correlate_froms': set(),
"asfrom_froms": set(),
"selectable": insert_stmt})
- self.isinsert = True
- crud_params = crud._get_crud_params(self, insert_stmt, **kw)
+ crud_params = crud._setup_crud_params(
+ self, insert_stmt, crud.ISINSERT, **kw)
if not crud_params and \
not self.dialect.supports_default_values and \
@@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled):
for c in crud_params_single])
if self.returning or insert_stmt._returning:
- self.returning = self.returning or insert_stmt._returning
returning_clause = self.returning_clause(
- insert_stmt, self.returning)
+ insert_stmt, self.returning or insert_stmt._returning)
if self.returning_precedes_values:
text += " " + returning_clause
+ else:
+ returning_clause = None
if insert_stmt.select is not None:
text += " %s" % self.process(self._insert_from_select, **kw)
@@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled):
text += " VALUES (%s)" % \
', '.join([c[1] for c in crud_params])
- if self.returning and not self.returning_precedes_values:
+ if returning_clause and not self.returning_precedes_values:
text += " " + returning_clause
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
- return text
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
def update_limit_clause(self, update_stmt):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
@@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled):
MySQL overrides this.
"""
- return from_table._compiler_dispatch(self, asfrom=True,
- iscrud=True, **kw)
+ kw['asfrom'] = True
+ return from_table._compiler_dispatch(self, iscrud=True, **kw)
def update_from_clause(self, update_stmt,
from_table, extra_froms,
@@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled):
fromhints=from_hints, **kw)
for t in extra_froms)
- def visit_update(self, update_stmt, **kw):
+ def visit_update(self, update_stmt, asfrom=False, **kw):
+ toplevel = not self.stack
+
self.stack.append(
{'correlate_froms': set([update_stmt.table]),
"asfrom_froms": set([update_stmt.table]),
"selectable": update_stmt})
- self.isupdate = True
-
extra_froms = update_stmt._extra_froms
text = "UPDATE "
@@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- crud_params = crud._get_crud_params(self, update_stmt, **kw)
+ crud_params = crud._setup_crud_params(
+ self, update_stmt, crud.ISUPDATE, **kw)
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
@@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled):
)
if self.returning or update_stmt._returning:
- if not self.returning:
- self.returning = update_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, self.returning or update_stmt._returning)
if extra_froms:
extra_from_text = self.update_from_clause(
@@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if self.returning and not self.returning_precedes_values:
+ if (self.returning or update_stmt._returning) and \
+ not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, self.returning or update_stmt._returning)
+
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
self.stack.pop(-1)
- return text
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
@util.memoized_property
def _key_getters_for_crud_column(self):
- return crud._key_getters_for_crud_column(self)
+ return crud._key_getters_for_crud_column(self, self.statement)
+
+ def visit_delete(self, delete_stmt, asfrom=False, **kw):
+ toplevel = not self.stack
- def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
"asfrom_froms": set([delete_stmt.table]),
"selectable": delete_stmt})
- self.isdelete = True
+
+ crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
text = "DELETE "
@@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled):
text += table_text
if delete_stmt._returning:
- self.returning = delete_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
@@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled):
if t:
text += " WHERE " + t
- if self.returning and not self.returning_precedes_values:
+ if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
- return text
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
def visit_savepoint(self, savepoint_stmt):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)