summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py94
1 files changed, 54 insertions, 40 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index cc9a49a91..a2fc0fe68 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -418,6 +418,11 @@ class SQLCompiler(Compiled):
self.truncated_names = {}
Compiled.__init__(self, dialect, statement, **kwargs)
+ if (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and statement._returning:
+ self.returning = statement._returning
+
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
@@ -1659,7 +1664,7 @@ class SQLCompiler(Compiled):
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes and self._is_toplevel_select(select):
+ if self.ctes and toplevel:
text = self._render_cte_clause() + text
if select._suffixes:
@@ -1673,20 +1678,6 @@ class SQLCompiler(Compiled):
else:
return text
- def _is_toplevel_select(self, select):
- """Return True if the stack is placed at the given select, and
- is also the outermost SELECT, meaning there is either no stack
- before this one, or the enclosing stack is a topmost INSERT.
-
- """
- return (
- self.stack[-1]['selectable'] is select and
- (
- len(self.stack) == 1 or self.isinsert and len(self.stack) == 2
- and self.statement is self.stack[0]['selectable']
- )
- )
-
def _setup_select_hints(self, select):
byfrom = dict([
(from_, hinttext % {
@@ -1876,14 +1867,16 @@ class SQLCompiler(Compiled):
)
return dialect_hints, table_text
- def visit_insert(self, insert_stmt, **kw):
+ def visit_insert(self, insert_stmt, asfrom=False, **kw):
+ toplevel = not self.stack
+
self.stack.append(
{'correlate_froms': set(),
"asfrom_froms": set(),
"selectable": insert_stmt})
- self.isinsert = True
- crud_params = crud._get_crud_params(self, insert_stmt, **kw)
+ crud_params = crud._setup_crud_params(
+ self, insert_stmt, crud.ISINSERT, **kw)
if not crud_params and \
not self.dialect.supports_default_values and \
@@ -1929,12 +1922,13 @@ class SQLCompiler(Compiled):
for c in crud_params_single])
if self.returning or insert_stmt._returning:
- self.returning = self.returning or insert_stmt._returning
returning_clause = self.returning_clause(
- insert_stmt, self.returning)
+ insert_stmt, self.returning or insert_stmt._returning)
if self.returning_precedes_values:
text += " " + returning_clause
+ else:
+ returning_clause = None
if insert_stmt.select is not None:
text += " %s" % self.process(self._insert_from_select, **kw)
@@ -1953,12 +1947,18 @@ class SQLCompiler(Compiled):
text += " VALUES (%s)" % \
', '.join([c[1] for c in crud_params])
- if self.returning and not self.returning_precedes_values:
+ if returning_clause and not self.returning_precedes_values:
text += " " + returning_clause
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
- return text
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
def update_limit_clause(self, update_stmt):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
@@ -1972,8 +1972,8 @@ class SQLCompiler(Compiled):
MySQL overrides this.
"""
- return from_table._compiler_dispatch(self, asfrom=True,
- iscrud=True, **kw)
+ kw['asfrom'] = True
+ return from_table._compiler_dispatch(self, iscrud=True, **kw)
def update_from_clause(self, update_stmt,
from_table, extra_froms,
@@ -1990,14 +1990,14 @@ class SQLCompiler(Compiled):
fromhints=from_hints, **kw)
for t in extra_froms)
- def visit_update(self, update_stmt, **kw):
+ def visit_update(self, update_stmt, asfrom=False, **kw):
+ toplevel = not self.stack
+
self.stack.append(
{'correlate_froms': set([update_stmt.table]),
"asfrom_froms": set([update_stmt.table]),
"selectable": update_stmt})
- self.isupdate = True
-
extra_froms = update_stmt._extra_froms
text = "UPDATE "
@@ -2009,7 +2009,8 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- crud_params = crud._get_crud_params(self, update_stmt, **kw)
+ crud_params = crud._setup_crud_params(
+ self, update_stmt, crud.ISUPDATE, **kw)
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
@@ -2029,11 +2030,9 @@ class SQLCompiler(Compiled):
)
if self.returning or update_stmt._returning:
- if not self.returning:
- self.returning = update_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, self.returning or update_stmt._returning)
if extra_froms:
extra_from_text = self.update_from_clause(
@@ -2053,23 +2052,33 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if self.returning and not self.returning_precedes_values:
+ if (self.returning or update_stmt._returning) and \
+ not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning)
+ update_stmt, self.returning or update_stmt._returning)
+
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
self.stack.pop(-1)
- return text
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
@util.memoized_property
def _key_getters_for_crud_column(self):
- return crud._key_getters_for_crud_column(self)
+ return crud._key_getters_for_crud_column(self, self.statement)
+
+ def visit_delete(self, delete_stmt, asfrom=False, **kw):
+ toplevel = not self.stack
- def visit_delete(self, delete_stmt, **kw):
self.stack.append({'correlate_froms': set([delete_stmt.table]),
"asfrom_froms": set([delete_stmt.table]),
"selectable": delete_stmt})
- self.isdelete = True
+
+ crud._setup_crud_params(self, delete_stmt, crud.ISDELETE, **kw)
text = "DELETE "
@@ -2088,7 +2097,6 @@ class SQLCompiler(Compiled):
text += table_text
if delete_stmt._returning:
- self.returning = delete_stmt._returning
if self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
@@ -2098,13 +2106,19 @@ class SQLCompiler(Compiled):
if t:
text += " WHERE " + t
- if self.returning and not self.returning_precedes_values:
+ if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt, delete_stmt._returning)
+ if self.ctes and toplevel:
+ text = self._render_cte_clause() + text
+
self.stack.pop(-1)
- return text
+ if asfrom:
+ return "(" + text + ")"
+ else:
+ return text
def visit_savepoint(self, savepoint_stmt):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)