diff options
Diffstat (limited to 'lib/sqlalchemy/engine/default.py')
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 59 |
1 files changed, 37 insertions, 22 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 3ed2d5ee8..1bb575984 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -593,12 +593,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self._is_implicit_returning = bool( compiled.returning and not compiled.statement._returning) - if not self.isdelete: - if self.compiled.prefetch: - if self.executemany: - self._process_executemany_defaults() - else: - self._process_executesingle_defaults() + if self.compiled.insert_prefetch or self.compiled.update_prefetch: + if self.executemany: + self._process_executemany_defaults() + else: + self._process_executesingle_defaults() processors = compiled._bind_processors @@ -712,7 +711,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext): @util.memoized_property def prefetch_cols(self): - return self.compiled.prefetch + if self.isinsert: + return self.compiled.insert_prefetch + elif self.isupdate: + return self.compiled.update_prefetch + else: + return () @util.memoized_property def returning_cols(self): @@ -1007,46 +1011,57 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _process_executemany_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch scalar_defaults = {} + insert_prefetch = self.compiled.insert_prefetch + update_prefetch = self.compiled.update_prefetch + # pre-determine scalar Python-side defaults # to avoid many calls of get_insert_default()/ # get_update_default() - for c in prefetch: - if self.isinsert and c.default and c.default.is_scalar: + for c in insert_prefetch: + if c.default and c.default.is_scalar: scalar_defaults[c] = c.default.arg - elif self.isupdate and c.onupdate and c.onupdate.is_scalar: + for c in update_prefetch: + if c.onupdate and c.onupdate.is_scalar: scalar_defaults[c] = c.onupdate.arg for param in self.compiled_parameters: self.current_parameters = param - for c in prefetch: + for c in insert_prefetch: if c in scalar_defaults: val = scalar_defaults[c] - elif self.isinsert: + else: val = self.get_insert_default(c) + if val is not None: + param[key_getter(c)] = val + for c in update_prefetch: + if c in scalar_defaults: + val = scalar_defaults[c] else: val = self.get_update_default(c) if val is not None: param[key_getter(c)] = val + del self.current_parameters def _process_executesingle_defaults(self): key_getter = self.compiled._key_getters_for_crud_column[2] - prefetch = self.compiled.prefetch self.current_parameters = compiled_parameters = \ self.compiled_parameters[0] - for c in prefetch: - if self.isinsert: - if c.default and \ - not c.default.is_sequence and c.default.is_scalar: - val = c.default.arg - else: - val = self.get_insert_default(c) + for c in self.compiled.insert_prefetch: + if c.default and \ + not c.default.is_sequence and c.default.is_scalar: + val = c.default.arg else: - val = self.get_update_default(c) + val = self.get_insert_default(c) + + if val is not None: + compiled_parameters[key_getter(c)] = val + + for c in self.compiled.update_prefetch: + val = self.get_update_default(c) if val is not None: compiled_parameters[key_getter(c)] = val |