diff options
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 59 |
1 files changed, 50 insertions, 9 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6370b1227..5d05cbc29 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1761,11 +1761,12 @@ class SQLCompiler(Compiled): '=' + c[1] for c in colparams ) - if update_stmt._returning: - self.returning = update_stmt._returning + if self.returning or update_stmt._returning: + if not self.returning: + self.returning = update_stmt._returning if self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, update_stmt._returning) + update_stmt, self.returning) if extra_froms: extra_from_text = self.update_from_clause( @@ -1785,7 +1786,7 @@ class SQLCompiler(Compiled): if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( - update_stmt, update_stmt._returning) + update_stmt, self.returning) self.stack.pop(-1) @@ -1866,6 +1867,19 @@ class SQLCompiler(Compiled): self.dialect.implicit_returning and \ stmt.table.implicit_returning + if self.isinsert: + implicit_return_defaults = implicit_returning and stmt._return_defaults + elif self.isupdate: + implicit_return_defaults = self.dialect.implicit_returning and \ + stmt.table.implicit_returning and \ + stmt._return_defaults + + if implicit_return_defaults: + if stmt._return_defaults is True: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults) + postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} @@ -1928,6 +1942,10 @@ class SQLCompiler(Compiled): elif c.primary_key and implicit_returning: self.returning.append(c) value = self.process(value.self_group()) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + value = self.process(value.self_group()) else: self.postfetch.append(c) value = self.process(value.self_group()) @@ -1984,14 +2002,20 @@ class SQLCompiler(Compiled): not self.dialect.sequences_optional): proc = self.process(c.default) values.append((c, proc)) - if not c.primary_key: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: self.postfetch.append(c) elif c.default.is_clause_element: values.append( (c, self.process(c.default.arg.self_group())) ) - if not c.primary_key: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: # dont add primary key column to postfetch self.postfetch.append(c) else: @@ -2000,8 +2024,14 @@ class SQLCompiler(Compiled): ) self.prefetch.append(c) elif c.server_default is not None: - if not c.primary_key: + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + elif not c.primary_key: self.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) elif self.isupdate: if c.onupdate is not None and not c.onupdate.is_sequence: @@ -2009,14 +2039,25 @@ class SQLCompiler(Compiled): values.append( (c, self.process(c.onupdate.arg.self_group())) ) - self.postfetch.append(c) + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + else: + self.postfetch.append(c) else: values.append( (c, self._create_crud_bind_param(c, None)) ) self.prefetch.append(c) elif c.server_onupdate is not None: - self.postfetch.append(c) + if implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) + else: + self.postfetch.append(c) + elif implicit_return_defaults and \ + c in implicit_return_defaults: + self.returning.append(c) if parameters and stmt_parameters: check = set(parameters).intersection( |