summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/crud.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r--lib/sqlalchemy/sql/crud.py440
1 files changed, 262 insertions, 178 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 999d48a55..602b91a25 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -15,7 +15,9 @@ from . import dml
from . import elements
import operator
-REQUIRED = util.symbol('REQUIRED', """
+REQUIRED = util.symbol(
+ "REQUIRED",
+ """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.
@@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.
-""")
+""",
+)
-ISINSERT = util.symbol('ISINSERT')
-ISUPDATE = util.symbol('ISUPDATE')
-ISDELETE = util.symbol('ISDELETE')
+ISINSERT = util.symbol("ISINSERT")
+ISUPDATE = util.symbol("ISUPDATE")
+ISDELETE = util.symbol("ISDELETE")
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
@@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw):
# compiled params - return binds for all columns
if compiler.column_keys is None and stmt.parameters is None:
return [
- (c, _create_bind_param(
- compiler, c, None, required=True))
+ (c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
@@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw):
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- _key_getters_for_crud_column(compiler, stmt)
+ _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column(
+ compiler, stmt
+ )
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict(
+ (_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or key not in stmt_parameters
+ )
# create a list of column assignment clauses as tuples
values = []
if stmt_parameters is not None:
_get_stmt_parameters_params(
- compiler,
- parameters, stmt_parameters, _column_as_key, values, kw)
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+ )
check_columns = {}
@@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw):
# statements
if compiler.isupdate and stmt._extra_froms and stmt_parameters:
_get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw)
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+ )
if compiler.isinsert and stmt.select_names:
_scan_insert_from_select_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
else:
_scan_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt_parameters
- ).difference(check_columns)
+ check = (
+ set(parameters)
+ .intersection(_column_as_key(k) for k in stmt_parameters)
+ .difference(check_columns)
+ )
if check:
raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
+ "Unconsumed column names: %s"
+ % (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
@@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw):
def _create_bind_param(
- compiler, col, value, process=True,
- required=False, name=None, **kw):
+ compiler, col, value, process=True, required=False, name=None, **kw
+):
if name is None:
name = col.key
bindparam = elements.BindParameter(
- name, value, type_=col.type, required=required)
+ name, value, type_=col.type, required=required
+ )
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
@@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _column_as_key(key):
str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
+ if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
@@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
- cols = [stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
compiler._insert_from_select = stmt.select
@@ -228,32 +263,39 @@ def _scan_insert_from_select_cols(
values.append((c, None))
else:
_append_param_insert_select_hasdefault(
- compiler, stmt, c, add_select_cols, kw)
+ compiler, stmt, c, add_select_cols, kw
+ )
if add_select_cols:
values.extend(add_select_cols)
compiler._insert_from_select = compiler._insert_from_select._generate()
- compiler._insert_from_select._raw_columns = \
- tuple(compiler._insert_from_select._raw_columns) + tuple(
- expr for col, expr in add_select_cols)
+ compiler._insert_from_select._raw_columns = tuple(
+ compiler._insert_from_select._raw_columns
+ ) + tuple(expr for col, expr in add_select_cols)
def _scan_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
if stmt._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in stmt._parameter_ordering
]
ordered_keys = set(parameter_ordering)
- cols = [
- stmt.table.c[key] for key in parameter_ordering
- ] + [
+ cols = [stmt.table.c[key] for key in parameter_ordering] + [
c for c in stmt.table.c if c.key not in ordered_keys
]
else:
@@ -265,72 +307,95 @@ def _scan_cols(
if col_key in parameters and col_key not in check_columns:
_append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw)
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
elif compiler.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
+ if (
+ c.primary_key
+ and need_pks
+ and (
+ implicit_returning
+ or not postfetch_lastrowid
+ or c is not stmt.table._autoincrement_column
+ )
+ ):
if implicit_returning:
_append_param_insert_pk_returning(
- compiler, stmt, c, values, kw)
+ compiler, stmt, c, values, kw
+ )
else:
_append_param_insert_pk(compiler, stmt, c, values, kw)
elif c.default is not None:
_append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults,
- values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
- elif c.primary_key and \
- c is not stmt.table._autoincrement_column and \
- not c.nullable:
+ elif (
+ c.primary_key
+ and c is not stmt.table._autoincrement_column
+ and not c.nullable
+ ):
_warn_pk_with_no_anticipated_value(c)
elif compiler.isupdate:
_append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
def _append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw):
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
name=_col_bind_name(c)
if not stmt._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
+ if isinstance(value, elements.BindParameter) and value.type._isnull:
value = value._clone()
value.type = c.type
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
@@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
"""
if c.default is not None:
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
- (c, compiler.process(
- c.default.arg.self_group(), **kw))
+ (c, compiler.process(c.default.arg.self_group(), **kw))
)
compiler.returning.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
@@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement):
self.type = original.type
def __eq__(self, other):
- return isinstance(other, _multiparam_column) and \
- other.key == self.key and \
- other.original == self.original
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
@@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
- "a Python-side value or SQL expression is required" % c)
+ "a Python-side value or SQL expression is required" % c
+ )
elif c.default.is_clause_element:
return compiler.process(c.default.arg.self_group(), **kw)
else:
@@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
"""
if (
- (
- # column has a Python-side default
- c.default is not None and
- (
- # and it won't be a Sequence
- not c.default.is_sequence or
- compiler.dialect.supports_sequences
- )
- )
- or
- (
- # column is the "autoincrement column"
- c is stmt.table._autoincrement_column and
- (
- # and it's either a "sequence" or a
- # pre-executable "autoincrement" sequence
- compiler.dialect.supports_sequences or
- compiler.dialect.preexecute_autoincrement_sequences
- )
- )
- ):
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
+ # column has a Python-side default
+ c.default is not None
+ and (
+ # and it won't be a Sequence
+ not c.default.is_sequence
+ or compiler.dialect.supports_sequences
)
+ ) or (
+ # column is the "autoincrement column"
+ c is stmt.table._autoincrement_column
+ and (
+ # and it's either a "sequence" or a
+ # pre-executable "autoincrement" sequence
+ compiler.dialect.supports_sequences
+ or compiler.dialect.preexecute_autoincrement_sequences
+ )
+ ):
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
@@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
def _append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
@@ -488,25 +548,21 @@ def _append_param_insert_hasdefault(
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
-def _append_param_insert_select_hasdefault(
- compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = c.default
values.append((c, proc.next_value()))
elif c.default.is_clause_element:
@@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault(
def _append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(), **kw))
+ (c, compiler.process(c.onupdate.arg.self_group(), **kw))
)
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_update_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_update_prefetch_bind_param(compiler, c)))
elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- stmt._return_defaults is not True and \
- c in implicit_return_defaults:
+ elif (
+ implicit_return_defaults
+ and stmt._return_defaults is not True
+ and c in implicit_return_defaults
+ ):
compiler.returning.append(c)
def _get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw):
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+):
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
@@ -565,8 +626,12 @@ def _get_multitable_params(
value = normalized_params[c]
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c),
+ )
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -577,20 +642,25 @@ def _get_multitable_params(
for c in t.c:
if c in normalized_params:
continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
+ (
+ c,
+ compiler.process(
+ c.onupdate.arg.self_group(), **kw
+ ),
+ )
)
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_update_prefetch_bind_param(
- compiler, c, name=_col_bind_name(c)))
+ (
+ c,
+ _create_update_prefetch_bind_param(
+ compiler, c, name=_col_bind_name(c)
+ ),
+ )
)
elif c.server_onupdate is not None:
compiler.postfetch.append(c)
@@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
if elements._is_literal(row[key]):
new_param = _create_bind_param(
- compiler, col, row[key],
- name="%s_m%d" % (col.key, i + 1), **kw
+ compiler,
+ col,
+ row[key],
+ name="%s_m%d" % (col.key, i + 1),
+ **kw
)
else:
new_param = compiler.process(row[key].self_group(), **kw)
@@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
def _get_stmt_parameters_params(
- compiler, parameters, stmt_parameters, _column_as_key, values, kw):
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
@@ -637,8 +711,8 @@ def _get_stmt_parameters_params(
# coercing right side to bound param
if elements._is_literal(v):
v = compiler.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
+ elements.BindParameter(None, v, type_=k.type), **kw
+ )
else:
v = compiler.process(v.self_group(), **kw)
@@ -646,22 +720,27 @@ def _get_stmt_parameters_params(
def _get_returning_modifiers(compiler, stmt):
- need_pks = compiler.isinsert and \
- not compiler.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
+ need_pks = (
+ compiler.isinsert
+ and not compiler.inline
+ and not stmt._returning
+ and not stmt._has_multi_parameters
+ )
- implicit_returning = need_pks and \
- compiler.dialect.implicit_returning and \
- stmt.table.implicit_returning
+ implicit_returning = (
+ need_pks
+ and compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ )
if compiler.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compiler.isupdate:
- implicit_return_defaults = (compiler.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = (
+ compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ and stmt._return_defaults
+ )
else:
# this line is unused, currently we are always
# isinsert or isupdate
@@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt):
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
- return need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid
+ return (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ )
def _warn_pk_with_no_anticipated_value(c):
@@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c):
"nor does it indicate 'autoincrement=True' or 'nullable=True', "
"and no explicit value is passed. "
"Primary key columns typically may not store NULL."
- %
- (c.table.fullname, c.name, c.table.fullname))
+ % (c.table.fullname, c.name, c.table.fullname)
+ )
if len(c.table.primary_key) > 1:
msg += (
" Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
@@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c):
"keys if AUTO_INCREMENT/SERIAL/IDENTITY "
"behavior is expected for one of the columns in the primary key. "
"CREATE TABLE statements are impacted by this change as well on "
- "most backends.")
+ "most backends."
+ )
util.warn(msg)