diff options
Diffstat (limited to 'lib/sqlalchemy/sql/crud.py')
-rw-r--r-- | lib/sqlalchemy/sql/crud.py | 440 |
1 files changed, 262 insertions, 178 deletions
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 999d48a55..602b91a25 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -15,7 +15,9 @@ from . import dml from . import elements import operator -REQUIRED = util.symbol('REQUIRED', """ +REQUIRED = util.symbol( + "REQUIRED", + """ Placeholder for the value within a :class:`.BindParameter` which is required to be present when the statement is passed to :meth:`.Connection.execute`. @@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert` or :func:`.expression.update` statement is compiled without parameter values present. -""") +""", +) -ISINSERT = util.symbol('ISINSERT') -ISUPDATE = util.symbol('ISUPDATE') -ISDELETE = util.symbol('ISDELETE') +ISINSERT = util.symbol("ISINSERT") +ISUPDATE = util.symbol("ISUPDATE") +ISDELETE = util.symbol("ISDELETE") def _setup_crud_params(compiler, stmt, local_stmt_type, **kw): @@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw): # compiled params - return binds for all columns if compiler.column_keys is None and stmt.parameters is None: return [ - (c, _create_bind_param( - compiler, c, None, required=True)) + (c, _create_bind_param(compiler, c, None, required=True)) for c in stmt.table.columns ] @@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw): # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account - _column_as_key, _getattr_col_key, _col_bind_name = \ - _key_getters_for_crud_column(compiler, stmt) + _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column( + compiler, stmt + ) # if we have statement parameters - set defaults in the # compiled params if compiler.column_keys is None: parameters = {} else: - parameters = dict((_column_as_key(key), REQUIRED) - for key in compiler.column_keys - if not stmt_parameters or - key not in stmt_parameters) + parameters = dict( + (_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if not stmt_parameters or key not in stmt_parameters + ) # create a list of column assignment clauses as tuples values = [] if stmt_parameters is not None: _get_stmt_parameters_params( - compiler, - parameters, stmt_parameters, _column_as_key, values, kw) + compiler, parameters, stmt_parameters, _column_as_key, values, kw + ) check_columns = {} @@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw): # statements if compiler.isupdate and stmt._extra_froms and stmt_parameters: _get_multitable_params( - compiler, stmt, stmt_parameters, check_columns, - _col_bind_name, _getattr_col_key, values, kw) + compiler, + stmt, + stmt_parameters, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) if compiler.isinsert and stmt.select_names: _scan_insert_from_select_cols( - compiler, stmt, parameters, - _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) else: _scan_cols( - compiler, stmt, parameters, - _getattr_col_key, _column_as_key, - _col_bind_name, check_columns, values, kw) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) if parameters and stmt_parameters: - check = set(parameters).intersection( - _column_as_key(k) for k in stmt_parameters - ).difference(check_columns) + check = ( + set(parameters) + .intersection(_column_as_key(k) for k in stmt_parameters) + .difference(check_columns) + ) if check: raise exc.CompileError( - "Unconsumed column names: %s" % - (", ".join("%s" % c for c in check)) + "Unconsumed column names: %s" + % (", ".join("%s" % c for c in check)) ) if stmt._has_multi_parameters: @@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw): def _create_bind_param( - compiler, col, value, process=True, - required=False, name=None, **kw): + compiler, col, value, process=True, required=False, name=None, **kw +): if name is None: name = col.key bindparam = elements.BindParameter( - name, value, type_=col.type, required=required) + name, value, type_=col.type, required=required + ) bindparam._is_crud = True if process: bindparam = bindparam._compiler_dispatch(compiler, **kw) @@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt): def _column_as_key(key): str_key = elements._column_as_key(key) - if hasattr(key, 'table') and key.table in _et: + if hasattr(key, "table") and key.table in _et: return (key.table.name, str_key) else: return str_key @@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt): def _scan_insert_from_select_cols( - compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): - - need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid = \ - _get_returning_modifiers(compiler, stmt) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers( + compiler, stmt + ) - cols = [stmt.table.c[_column_as_key(name)] - for name in stmt.select_names] + cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names] compiler._insert_from_select = stmt.select @@ -228,32 +263,39 @@ def _scan_insert_from_select_cols( values.append((c, None)) else: _append_param_insert_select_hasdefault( - compiler, stmt, c, add_select_cols, kw) + compiler, stmt, c, add_select_cols, kw + ) if add_select_cols: values.extend(add_select_cols) compiler._insert_from_select = compiler._insert_from_select._generate() - compiler._insert_from_select._raw_columns = \ - tuple(compiler._insert_from_select._raw_columns) + tuple( - expr for col, expr in add_select_cols) + compiler._insert_from_select._raw_columns = tuple( + compiler._insert_from_select._raw_columns + ) + tuple(expr for col, expr in add_select_cols) def _scan_cols( - compiler, stmt, parameters, _getattr_col_key, - _column_as_key, _col_bind_name, check_columns, values, kw): - - need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid = \ - _get_returning_modifiers(compiler, stmt) + compiler, + stmt, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers( + compiler, stmt + ) if stmt._parameter_ordering: parameter_ordering = [ _column_as_key(key) for key in stmt._parameter_ordering ] ordered_keys = set(parameter_ordering) - cols = [ - stmt.table.c[key] for key in parameter_ordering - ] + [ + cols = [stmt.table.c[key] for key in parameter_ordering] + [ c for c in stmt.table.c if c.key not in ordered_keys ] else: @@ -265,72 +307,95 @@ def _scan_cols( if col_key in parameters and col_key not in check_columns: _append_param_parameter( - compiler, stmt, c, col_key, parameters, _col_bind_name, - implicit_returning, implicit_return_defaults, values, kw) + compiler, + stmt, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, + ) elif compiler.isinsert: - if c.primary_key and \ - need_pks and \ - ( - implicit_returning or - not postfetch_lastrowid or - c is not stmt.table._autoincrement_column - ): + if ( + c.primary_key + and need_pks + and ( + implicit_returning + or not postfetch_lastrowid + or c is not stmt.table._autoincrement_column + ) + ): if implicit_returning: _append_param_insert_pk_returning( - compiler, stmt, c, values, kw) + compiler, stmt, c, values, kw + ) else: _append_param_insert_pk(compiler, stmt, c, values, kw) elif c.default is not None: _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, - values, kw) + compiler, stmt, c, implicit_return_defaults, values, kw + ) elif c.server_default is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) - elif implicit_return_defaults and \ - c in implicit_return_defaults: + elif implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) - elif c.primary_key and \ - c is not stmt.table._autoincrement_column and \ - not c.nullable: + elif ( + c.primary_key + and c is not stmt.table._autoincrement_column + and not c.nullable + ): _warn_pk_with_no_anticipated_value(c) elif compiler.isupdate: _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw) + compiler, stmt, c, implicit_return_defaults, values, kw + ) def _append_param_parameter( - compiler, stmt, c, col_key, parameters, _col_bind_name, - implicit_returning, implicit_return_defaults, values, kw): + compiler, + stmt, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, +): value = parameters.pop(col_key) if elements._is_literal(value): value = _create_bind_param( - compiler, c, value, required=value is REQUIRED, + compiler, + c, + value, + required=value is REQUIRED, name=_col_bind_name(c) if not stmt._has_multi_parameters else "%s_m0" % _col_bind_name(c), **kw ) else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: + if isinstance(value, elements.BindParameter) and value.type._isnull: value = value._clone() value.type = c.type if c.primary_key and implicit_returning: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) - elif implicit_return_defaults and \ - c in implicit_return_defaults: + elif implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) else: @@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): """ if c.default is not None: if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ): proc = compiler.process(c.default, **kw) values.append((c, proc)) compiler.returning.append(c) elif c.default.is_clause_element: values.append( - (c, compiler.process( - c.default.arg.self_group(), **kw)) + (c, compiler.process(c.default.arg.self_group(), **kw)) ) compiler.returning.append(c) else: - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) elif c is stmt.table._autoincrement_column or c.server_default is not None: compiler.returning.append(c) elif not c.nullable: @@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement): self.type = original.type def __eq__(self, other): - return isinstance(other, _multiparam_column) and \ - other.key == self.key and \ - other.original == self.original + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) def _process_multiparam_default_bind(compiler, stmt, c, index, kw): @@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw): raise exc.CompileError( "INSERT value for column %s is explicitly rendered as a bound" "parameter in the VALUES clause; " - "a Python-side value or SQL expression is required" % c) + "a Python-side value or SQL expression is required" % c + ) elif c.default.is_clause_element: return compiler.process(c.default.arg.self_group(), **kw) else: @@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): """ if ( - ( - # column has a Python-side default - c.default is not None and - ( - # and it won't be a Sequence - not c.default.is_sequence or - compiler.dialect.supports_sequences - ) - ) - or - ( - # column is the "autoincrement column" - c is stmt.table._autoincrement_column and - ( - # and it's either a "sequence" or a - # pre-executable "autoincrement" sequence - compiler.dialect.supports_sequences or - compiler.dialect.preexecute_autoincrement_sequences - ) - ) - ): - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) + # column has a Python-side default + c.default is not None + and ( + # and it won't be a Sequence + not c.default.is_sequence + or compiler.dialect.supports_sequences ) + ) or ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column + and ( + # and it's either a "sequence" or a + # pre-executable "autoincrement" sequence + compiler.dialect.supports_sequences + or compiler.dialect.preexecute_autoincrement_sequences + ) + ): + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) elif c.default is None and c.server_default is None and not c.nullable: # no .default, no .server_default, not autoincrement, we have # no indication this primary key column will have any value @@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw): def _append_param_insert_hasdefault( - compiler, stmt, c, implicit_return_defaults, values, kw): + compiler, stmt, c, implicit_return_defaults, values, kw +): if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): proc = compiler.process(c.default, **kw) values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: compiler.postfetch.append(c) @@ -488,25 +548,21 @@ def _append_param_insert_hasdefault( proc = compiler.process(c.default.arg.self_group(), **kw) values.append((c, proc)) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) elif not c.primary_key: # don't add primary key column to postfetch compiler.postfetch.append(c) else: - values.append( - (c, _create_insert_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_insert_prefetch_bind_param(compiler, c))) -def _append_param_insert_select_hasdefault( - compiler, stmt, c, values, kw): +def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): if c.default.is_sequence: - if compiler.dialect.supports_sequences and \ - (not c.default.optional or - not compiler.dialect.sequences_optional): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): proc = c.default values.append((c, proc.next_value())) elif c.default.is_clause_element: @@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault( def _append_param_update( - compiler, stmt, c, implicit_return_defaults, values, kw): + compiler, stmt, c, implicit_return_defaults, values, kw +): if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process( - c.onupdate.arg.self_group(), **kw)) + (c, compiler.process(c.onupdate.arg.self_group(), **kw)) ) - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) else: - values.append( - (c, _create_update_prefetch_bind_param(compiler, c)) - ) + values.append((c, _create_update_prefetch_bind_param(compiler, c))) elif c.server_onupdate is not None: - if implicit_return_defaults and \ - c in implicit_return_defaults: + if implicit_return_defaults and c in implicit_return_defaults: compiler.returning.append(c) else: compiler.postfetch.append(c) - elif implicit_return_defaults and \ - stmt._return_defaults is not True and \ - c in implicit_return_defaults: + elif ( + implicit_return_defaults + and stmt._return_defaults is not True + and c in implicit_return_defaults + ): compiler.returning.append(c) def _get_multitable_params( - compiler, stmt, stmt_parameters, check_columns, - _col_bind_name, _getattr_col_key, values, kw): + compiler, + stmt, + stmt_parameters, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, +): normalized_params = dict( (elements._clause_element_as_expr(c), param) @@ -565,8 +626,12 @@ def _get_multitable_params( value = normalized_params[c] if elements._is_literal(value): value = _create_bind_param( - compiler, c, value, required=value is REQUIRED, - name=_col_bind_name(c)) + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c), + ) else: compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) @@ -577,20 +642,25 @@ def _get_multitable_params( for c in t.c: if c in normalized_params: continue - elif (c.onupdate is not None and not - c.onupdate.is_sequence): + elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, compiler.process( - c.onupdate.arg.self_group(), - **kw) - ) + ( + c, + compiler.process( + c.onupdate.arg.self_group(), **kw + ), + ) ) compiler.postfetch.append(c) else: values.append( - (c, _create_update_prefetch_bind_param( - compiler, c, name=_col_bind_name(c))) + ( + c, + _create_update_prefetch_bind_param( + compiler, c, name=_col_bind_name(c) + ), + ) ) elif c.server_onupdate is not None: compiler.postfetch.append(c) @@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): if elements._is_literal(row[key]): new_param = _create_bind_param( - compiler, col, row[key], - name="%s_m%d" % (col.key, i + 1), **kw + compiler, + col, + row[key], + name="%s_m%d" % (col.key, i + 1), + **kw ) else: new_param = compiler.process(row[key].self_group(), **kw) @@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw): def _get_stmt_parameters_params( - compiler, parameters, stmt_parameters, _column_as_key, values, kw): + compiler, parameters, stmt_parameters, _column_as_key, values, kw +): for k, v in stmt_parameters.items(): colkey = _column_as_key(k) if colkey is not None: @@ -637,8 +711,8 @@ def _get_stmt_parameters_params( # coercing right side to bound param if elements._is_literal(v): v = compiler.process( - elements.BindParameter(None, v, type_=k.type), - **kw) + elements.BindParameter(None, v, type_=k.type), **kw + ) else: v = compiler.process(v.self_group(), **kw) @@ -646,22 +720,27 @@ def _get_stmt_parameters_params( def _get_returning_modifiers(compiler, stmt): - need_pks = compiler.isinsert and \ - not compiler.inline and \ - not stmt._returning and \ - not stmt._has_multi_parameters + need_pks = ( + compiler.isinsert + and not compiler.inline + and not stmt._returning + and not stmt._has_multi_parameters + ) - implicit_returning = need_pks and \ - compiler.dialect.implicit_returning and \ - stmt.table.implicit_returning + implicit_returning = ( + need_pks + and compiler.dialect.implicit_returning + and stmt.table.implicit_returning + ) if compiler.isinsert: - implicit_return_defaults = (implicit_returning and - stmt._return_defaults) + implicit_return_defaults = implicit_returning and stmt._return_defaults elif compiler.isupdate: - implicit_return_defaults = (compiler.dialect.implicit_returning and - stmt.table.implicit_returning and - stmt._return_defaults) + implicit_return_defaults = ( + compiler.dialect.implicit_returning + and stmt.table.implicit_returning + and stmt._return_defaults + ) else: # this line is unused, currently we are always # isinsert or isupdate @@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt): postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid - return need_pks, implicit_returning, \ - implicit_return_defaults, postfetch_lastrowid + return ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) def _warn_pk_with_no_anticipated_value(c): @@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c): "nor does it indicate 'autoincrement=True' or 'nullable=True', " "and no explicit value is passed. " "Primary key columns typically may not store NULL." - % - (c.table.fullname, c.name, c.table.fullname)) + % (c.table.fullname, c.name, c.table.fullname) + ) if len(c.table.primary_key) > 1: msg += ( " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " @@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c): "keys if AUTO_INCREMENT/SERIAL/IDENTITY " "behavior is expected for one of the columns in the primary key. " "CREATE TABLE statements are impacted by this change as well on " - "most backends.") + "most backends." + ) util.warn(msg) |