summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py181
1 files changed, 103 insertions, 78 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index e474952ce..2827a5817 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -16,6 +16,7 @@ from . import coercions
from . import dml
from . import elements
from . import roles
+from .elements import ClauseElement
from .. import exc
from .. import util
@@ -33,45 +34,8 @@ values present.
""",
)
-ISINSERT = util.symbol("ISINSERT")
-ISUPDATE = util.symbol("ISUPDATE")
-ISDELETE = util.symbol("ISDELETE")
-
-def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
- restore_isinsert = compiler.isinsert
- restore_isupdate = compiler.isupdate
- restore_isdelete = compiler.isdelete
-
- should_restore = (
- (restore_isinsert or restore_isupdate or restore_isdelete)
- or len(compiler.stack) > 1
- or "visiting_cte" in kw
- )
-
- if local_stmt_type is ISINSERT:
- compiler.isupdate = False
- compiler.isinsert = True
- elif local_stmt_type is ISUPDATE:
- compiler.isupdate = True
- compiler.isinsert = False
- elif local_stmt_type is ISDELETE:
- if not should_restore:
- compiler.isdelete = True
- else:
- assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
-
- try:
- if local_stmt_type in (ISINSERT, ISUPDATE):
- return _get_crud_params(compiler, stmt, **kw)
- finally:
- if should_restore:
- compiler.isinsert = restore_isinsert
- compiler.isupdate = restore_isupdate
- compiler.isdelete = restore_isdelete
-
-
-def _get_crud_params(compiler, stmt, **kw):
+def _get_crud_params(compiler, stmt, compile_state, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -87,27 +51,29 @@ def _get_crud_params(compiler, stmt, **kw):
compiler.update_prefetch = []
compiler.returning = []
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ (
+ _column_as_key,
+ _getattr_col_key,
+ _col_bind_name,
+ ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+
+ compiler._key_getters_for_crud_column = getters
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
- if compiler.column_keys is None and stmt.parameters is None:
+ if compiler.column_keys is None and compile_state._no_parameters:
return [
(c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
- if stmt._has_multi_parameters:
- stmt_parameters = stmt.parameters[0]
+ if compile_state._has_multi_parameters:
+ stmt_parameters = compile_state._multi_parameters[0]
else:
- stmt_parameters = stmt.parameters
-
- # getters - these are normally just column.key,
- # but in the case of mysql multi-table update, the rules for
- # .key must conditionally take tablename into account
- (
- _column_as_key,
- _getattr_col_key,
- _col_bind_name,
- ) = _key_getters_for_crud_column(compiler, stmt)
+ stmt_parameters = compile_state._dict_parameters
# if we have statement parameters - set defaults in the
# compiled params
@@ -132,10 +98,15 @@ def _get_crud_params(compiler, stmt, **kw):
# special logic that only occurs for multi-table UPDATE
# statements
- if compiler.isupdate and stmt._extra_froms and stmt_parameters:
+ if (
+ compile_state.isupdate
+ and compile_state._extra_froms
+ and stmt_parameters
+ ):
_get_multitable_params(
compiler,
stmt,
+ compile_state,
stmt_parameters,
check_columns,
_col_bind_name,
@@ -144,10 +115,11 @@ def _get_crud_params(compiler, stmt, **kw):
kw,
)
- if compiler.isinsert and stmt.select_names:
+ if compile_state.isinsert and stmt._select_names:
_scan_insert_from_select_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -160,6 +132,7 @@ def _get_crud_params(compiler, stmt, **kw):
_scan_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -181,8 +154,10 @@ def _get_crud_params(compiler, stmt, **kw):
% (", ".join("%s" % (c,) for c in check))
)
- if stmt._has_multi_parameters:
- values = _extend_values_for_multiparams(compiler, stmt, values, kw)
+ if compile_state._has_multi_parameters:
+ values = _extend_values_for_multiparams(
+ compiler, stmt, compile_state, values, kw
+ )
return values
@@ -201,15 +176,46 @@ def _create_bind_param(
return bindparam
-def _key_getters_for_crud_column(compiler, stmt):
- if compiler.isupdate and stmt._extra_froms:
+def _handle_values_anonymous_param(compiler, col, value, name, **kw):
+ # the insert() and update() constructs as of 1.4 will now produce anonymous
+ # bindparam() objects in the values() collections up front when given plain
+ # literal values. This is so that cache key behaviors, which need to
+ # produce bound parameters in deterministic order without invoking any
+ # compilation here, can be applied to these constructs when they include
+ # values() (but not yet multi-values, which are not included in caching
+ # right now).
+ #
+ # in order to produce the desired "crud" style name for these parameters,
+ # which will also be targetable in engine/default.py through the usual
+ # conventions, apply our desired name to these unique parameters by
+ # populating the compiler truncated names cache with the desired name,
+ # rather than having
+ # compiler.visit_bindparam()->compiler._truncated_identifier make up a
+ # name. Saves on call counts also.
+ if value.unique and isinstance(value.key, elements._truncated_label):
+ compiler.truncated_names[("bindparam", value.key)] = name
+
+ if value.type._isnull:
+ # either unique parameter, or other bound parameters that were
+ # passed in directly
+ # clone using base ClauseElement to retain unique key
+ value = ClauseElement._clone(value)
+
+ # set type to that of the column unconditionally
+ value.type = col.type
+
+ return value._compiler_dispatch(compiler, **kw)
+
+
+def _key_getters_for_crud_column(compiler, stmt, compile_state):
+ if compile_state.isupdate and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
- _et = set(stmt._extra_froms)
+ _et = set(compile_state._extra_froms)
c_key_role = functools.partial(
coercions.expect_as_key, roles.DMLColumnRole
@@ -246,6 +252,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -260,9 +267,9 @@ def _scan_insert_from_select_cols(
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt)
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
- cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
compiler._insert_from_select = stmt.select
@@ -294,6 +301,7 @@ def _scan_insert_from_select_cols(
def _scan_cols(
compiler,
stmt,
+ compile_state,
parameters,
_getattr_col_key,
_column_as_key,
@@ -308,11 +316,11 @@ def _scan_cols(
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt)
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
- if stmt._parameter_ordering:
+ if compile_state._parameter_ordering:
parameter_ordering = [
- _column_as_key(key) for key in stmt._parameter_ordering
+ _column_as_key(key) for key in compile_state._parameter_ordering
]
ordered_keys = set(parameter_ordering)
cols = [stmt.table.c[key] for key in parameter_ordering] + [
@@ -329,6 +337,7 @@ def _scan_cols(
_append_param_parameter(
compiler,
stmt,
+ compile_state,
c,
col_key,
parameters,
@@ -339,7 +348,7 @@ def _scan_cols(
kw,
)
- elif compiler.isinsert:
+ elif compile_state.isinsert:
if (
c.primary_key
and need_pks
@@ -377,7 +386,7 @@ def _scan_cols(
):
_warn_pk_with_no_anticipated_value(c)
- elif compiler.isupdate:
+ elif compile_state.isupdate:
_append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
)
@@ -386,6 +395,7 @@ def _scan_cols(
def _append_param_parameter(
compiler,
stmt,
+ compile_state,
c,
col_key,
parameters,
@@ -395,7 +405,9 @@ def _append_param_parameter(
values,
kw,
):
+
value = parameters.pop(col_key)
+
if coercions._is_literal(value):
value = _create_bind_param(
compiler,
@@ -403,15 +415,21 @@ def _append_param_parameter(
value,
required=value is REQUIRED,
name=_col_bind_name(c)
- if not stmt._has_multi_parameters
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler,
+ c,
+ value,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and value.type._isnull:
- value = value._clone()
- value.type = c.type
-
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -644,6 +662,7 @@ def _append_param_update(
def _get_multitable_params(
compiler,
stmt,
+ compile_state,
stmt_parameters,
check_columns,
_col_bind_name,
@@ -656,7 +675,7 @@ def _get_multitable_params(
for c, param in stmt_parameters.items()
)
affected_tables = set()
- for t in stmt._extra_froms:
+ for t in compile_state._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
@@ -669,6 +688,11 @@ def _get_multitable_params(
value,
required=value is REQUIRED,
name=_col_bind_name(c),
+ **kw # TODO: no test coverage for literal binds here
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler, c, value, name=_col_bind_name(c), **kw
)
else:
compiler.postfetch.append(c)
@@ -704,11 +728,11 @@ def _get_multitable_params(
compiler.postfetch.append(c)
-def _extend_values_for_multiparams(compiler, stmt, values, kw):
+def _extend_values_for_multiparams(compiler, stmt, compile_state, values, kw):
values_0 = values
values = [values]
- for i, row in enumerate(stmt.parameters[1:]):
+ for i, row in enumerate(compile_state._multi_parameters[1:]):
extension = []
for (col, param) in values_0:
if col in row or col.key in row:
@@ -757,12 +781,13 @@ def _get_stmt_parameters_params(
values.append((k, v))
-def _get_returning_modifiers(compiler, stmt):
+def _get_returning_modifiers(compiler, stmt, compile_state):
+
need_pks = (
- compiler.isinsert
+ compile_state.isinsert
and not compiler.inline
and not stmt._returning
- and not stmt._has_multi_parameters
+ and not compile_state._has_multi_parameters
)
implicit_returning = (
@@ -771,9 +796,9 @@ def _get_returning_modifiers(compiler, stmt):
and stmt.table.implicit_returning
)
- if compiler.isinsert:
+ if compile_state.isinsert:
implicit_return_defaults = implicit_returning and stmt._return_defaults
- elif compiler.isupdate:
+ elif compile_state.isupdate:
implicit_return_defaults = (
compiler.dialect.implicit_returning
and stmt.table.implicit_returning