diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 60 |
1 files changed, 43 insertions, 17 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 0717a8fef..63b9e44b3 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -635,13 +635,7 @@ class DefaultExecutionContext(base.ExecutionContext): ipk.append(row[c]) self._inserted_primary_key = ipk - - def last_inserted_params(self): - return self._last_inserted_params - - def last_updated_params(self): - return self._last_updated_params - + def lastrow_has_defaults(self): return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols) @@ -714,7 +708,32 @@ class DefaultExecutionContext(base.ExecutionContext): return None else: return self._exec_default(column.onupdate) - + + @util.memoized_property + def _inserted_primary_key(self): + + if not self.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() expression construct.") + elif self._is_explicit_returning: + raise exc.InvalidRequestError( + "Can't call inserted_primary_key when returning() " + "is used.") + + + # lazyily evaluate inserted_primary_key for executemany. + # for execute(), its already in __dict__. + if self.executemany: + return [ + [compiled_parameters.get(c.key, None) + for c in self.compiled.\ + statement.table.primary_key + ] for compiled_parameters in self.compiled_parameters + ] + else: + # _inserted_primary_key should be calced here + assert False + def __process_defaults(self): """Generate default values for compiled insert/update statements, and generate inserted_primary_key collection. @@ -746,6 +765,11 @@ class DefaultExecutionContext(base.ExecutionContext): param[c.key] = val del self.current_parameters + if self.isinsert: + self.last_inserted_params = self.compiled_parameters + else: + self.last_updated_params = self.compiled_parameters + else: self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] @@ -759,18 +783,20 @@ class DefaultExecutionContext(base.ExecutionContext): if val is not None: compiled_parameters[c.key] = val del self.current_parameters - - if self.isinsert: + + if self.isinsert and not self._is_explicit_returning: self._inserted_primary_key = [ - compiled_parameters.get(c.key, None) - for c in self.compiled.\ + self.compiled_parameters[0].get(c.key, None) + for c in self.compiled.\ statement.table.primary_key - ] - self._last_inserted_params = compiled_parameters + ] + + if self.isinsert: + self.last_inserted_params = compiled_parameters else: - self._last_updated_params = compiled_parameters + self.last_updated_params = compiled_parameters - self.postfetch_cols = self.compiled.postfetch - self.prefetch_cols = self.compiled.prefetch + self.postfetch_cols = self.compiled.postfetch + self.prefetch_cols = self.compiled.prefetch DefaultDialect.execution_ctx_cls = DefaultExecutionContext |