summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/engine/default.py59
-rw-r--r--lib/sqlalchemy/sql/compiler.py6
-rw-r--r--lib/sqlalchemy/sql/crud.py35
3 files changed, 66 insertions, 34 deletions
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 3ed2d5ee8..1bb575984 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -593,12 +593,11 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self._is_implicit_returning = bool(
compiled.returning and not compiled.statement._returning)
- if not self.isdelete:
- if self.compiled.prefetch:
- if self.executemany:
- self._process_executemany_defaults()
- else:
- self._process_executesingle_defaults()
+ if self.compiled.insert_prefetch or self.compiled.update_prefetch:
+ if self.executemany:
+ self._process_executemany_defaults()
+ else:
+ self._process_executesingle_defaults()
processors = compiled._bind_processors
@@ -712,7 +711,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
@util.memoized_property
def prefetch_cols(self):
- return self.compiled.prefetch
+ if self.isinsert:
+ return self.compiled.insert_prefetch
+ elif self.isupdate:
+ return self.compiled.update_prefetch
+ else:
+ return ()
@util.memoized_property
def returning_cols(self):
@@ -1007,46 +1011,57 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
def _process_executemany_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- prefetch = self.compiled.prefetch
scalar_defaults = {}
+ insert_prefetch = self.compiled.insert_prefetch
+ update_prefetch = self.compiled.update_prefetch
+
# pre-determine scalar Python-side defaults
# to avoid many calls of get_insert_default()/
# get_update_default()
- for c in prefetch:
- if self.isinsert and c.default and c.default.is_scalar:
+ for c in insert_prefetch:
+ if c.default and c.default.is_scalar:
scalar_defaults[c] = c.default.arg
- elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
+ for c in update_prefetch:
+ if c.onupdate and c.onupdate.is_scalar:
scalar_defaults[c] = c.onupdate.arg
for param in self.compiled_parameters:
self.current_parameters = param
- for c in prefetch:
+ for c in insert_prefetch:
if c in scalar_defaults:
val = scalar_defaults[c]
- elif self.isinsert:
+ else:
val = self.get_insert_default(c)
+ if val is not None:
+ param[key_getter(c)] = val
+ for c in update_prefetch:
+ if c in scalar_defaults:
+ val = scalar_defaults[c]
else:
val = self.get_update_default(c)
if val is not None:
param[key_getter(c)] = val
+
del self.current_parameters
def _process_executesingle_defaults(self):
key_getter = self.compiled._key_getters_for_crud_column[2]
- prefetch = self.compiled.prefetch
self.current_parameters = compiled_parameters = \
self.compiled_parameters[0]
- for c in prefetch:
- if self.isinsert:
- if c.default and \
- not c.default.is_sequence and c.default.is_scalar:
- val = c.default.arg
- else:
- val = self.get_insert_default(c)
+ for c in self.compiled.insert_prefetch:
+ if c.default and \
+ not c.default.is_sequence and c.default.is_scalar:
+ val = c.default.arg
else:
- val = self.get_update_default(c)
+ val = self.get_insert_default(c)
+
+ if val is not None:
+ compiled_parameters[key_getter(c)] = val
+
+ for c in self.compiled.update_prefetch:
+ val = self.get_update_default(c)
if val is not None:
compiled_parameters[key_getter(c)] = val
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 16ca7f959..095c84f03 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -359,6 +359,8 @@ class SQLCompiler(Compiled):
True unless using an unordered TextAsFrom.
"""
+ insert_prefetch = update_prefetch = ()
+
def __init__(self, dialect, statement, column_keys=None,
inline=False, **kwargs):
"""Construct a new :class:`.SQLCompiler` object.
@@ -428,6 +430,10 @@ class SQLCompiler(Compiled):
if self.positional and dialect.paramstyle == 'numeric':
self._apply_numbered_params()
+ @property
+ def prefetch(self):
+ return list(self.insert_prefetch + self.update_prefetch)
+
@util.memoized_instancemethod
def _init_cte_state(self):
"""Initialize collections related to CTEs only if
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 70e03d220..f770fc513 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -11,6 +11,7 @@ within INSERT and UPDATE statements.
"""
from .. import util
from .. import exc
+from . import dml
from . import elements
import operator
@@ -73,7 +74,8 @@ def _get_crud_params(compiler, stmt, **kw):
"""
compiler.postfetch = []
- compiler.prefetch = []
+ compiler.insert_prefetch = []
+ compiler.update_prefetch = []
compiler.returning = []
# no parameters in the statement, no parameters in the
@@ -370,7 +372,7 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
compiler.returning.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_insert_prefetch_bind_param(compiler, c))
)
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
@@ -380,9 +382,15 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
_raise_pk_with_no_anticipated_value(c)
-def _create_prefetch_bind_param(compiler, c, process=True, name=None):
+def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
- compiler.prefetch.append(c)
+ compiler.insert_prefetch.append(c)
+ return param
+
+
+def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
+ param = _create_bind_param(compiler, c, None, process=process, name=name)
+ compiler.update_prefetch.append(c)
return param
@@ -399,7 +407,7 @@ class _multiparam_column(elements.ColumnElement):
other.original == self.original
-def _process_multiparam_default_bind(compiler, c, index, kw):
+def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
if not c.default:
raise exc.CompileError(
@@ -410,7 +418,10 @@ def _process_multiparam_default_bind(compiler, c, index, kw):
return compiler.process(c.default.arg.self_group(), **kw)
else:
col = _multiparam_column(c, index)
- return _create_prefetch_bind_param(compiler, col)
+ if isinstance(stmt, dml.Insert):
+ return _create_insert_prefetch_bind_param(compiler, col)
+ else:
+ return _create_update_prefetch_bind_param(compiler, col)
def _append_param_insert_pk(compiler, stmt, c, values, kw):
@@ -448,7 +459,7 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
)
):
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_insert_prefetch_bind_param(compiler, c))
)
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
@@ -482,7 +493,7 @@ def _append_param_insert_hasdefault(
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_insert_prefetch_bind_param(compiler, c))
)
@@ -500,7 +511,7 @@ def _append_param_insert_select_hasdefault(
values.append((c, proc))
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c, process=False))
+ (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
)
@@ -520,7 +531,7 @@ def _append_param_update(
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(compiler, c))
+ (c, _create_update_prefetch_bind_param(compiler, c))
)
elif c.server_onupdate is not None:
if implicit_return_defaults and \
@@ -575,7 +586,7 @@ def _get_multitable_params(
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_prefetch_bind_param(
+ (c, _create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c)))
)
elif c.server_onupdate is not None:
@@ -597,7 +608,7 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
else compiler.process(
row[c.key].self_group(), **kw))
if c.key in row else
- _process_multiparam_default_bind(compiler, c, i, kw)
+ _process_multiparam_default_bind(compiler, stmt, c, i, kw)
)
for (c, param) in values_0
]