diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2016-04-12 15:57:20 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2017-06-15 18:58:29 -0400 |
commit | 64b0760faa45a26c727a76b9fda97f2b4ea15417 (patch) | |
tree | 595986b7b08eee231c18b7252b06adbcb0333789 /lib | |
parent | 7af05fcc9387cea4172cc35eb6a198776488f90d (diff) | |
download | sqlalchemy-64b0760faa45a26c727a76b9fda97f2b4ea15417.tar.gz |
Add all versioning logic to _post_update()
An UPDATE emitted as a result of the
:paramref:`.relationship.post_update` feature will now integrate with
the versioning feature to both bump the version id of the row as well
as assert that the existing version number was matched.
Fixes: #3496
Change-Id: I865405dd6069f1c1e3b0d27a4980e9374e059f97
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/orm/persistence.py | 110 |
1 files changed, 93 insertions, 17 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0de64011a..924b9e1c9 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -212,15 +212,22 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): continue update = ( - (state, state_dict, sub_mapper, connection) + ( + state, state_dict, sub_mapper, connection, + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) if mapper.version_id_col is not None else None + ) for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table ) - update = _collect_post_update_commands(base_mapper, uowtransaction, - table, update, - post_update_cols) + update = _collect_post_update_commands( + base_mapper, uowtransaction, + table, update, + post_update_cols + ) _emit_post_update_statements(base_mapper, uowtransaction, cached_connections, @@ -576,7 +583,8 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, """ - for state, state_dict, mapper, connection in states_to_update: + for state, state_dict, mapper, connection, \ + update_version_id in states_to_update: # assert table in mapper._pks_by_table @@ -601,6 +609,16 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, params[col.key] = value hasdata = True if hasdata: + if update_version_id is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: + + col = mapper.version_id_col + params[col._label] = update_version_id + + if bool(state.key) and col.key not in params and \ + mapper.version_id_generator is not False: + val = mapper.version_id_generator(update_version_id) + params[col.key] = val yield state, state_dict, mapper, connection, params @@ -870,6 +888,9 @@ def _emit_post_update_statements(base_mapper, uowtransaction, """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" + needs_version_id = mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table] + def update_stmt(): clause = sql.and_() @@ -877,7 +898,18 @@ def _emit_post_update_statements(base_mapper, uowtransaction, clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) - return table.update(clause) + if needs_version_id: + clause.clauses.append( + mapper.version_id_col == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type)) + + stmt = table.update(clause) + + if mapper.version_id_col is not None: + stmt = stmt.return_defaults(mapper.version_id_col) + + return stmt statement = base_mapper._memo(('post_update', table), update_stmt) @@ -885,23 +917,63 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # list of states to guarantee row access order, but # also group them into common (connection, cols) sets # to support executemany(). - for key, grouper in groupby( + for key, records in groupby( update, lambda rec: ( rec[3], # connection set(rec[4]), # parameter keys ) ): - grouper = list(grouper) + rows = 0 + + records = list(records) connection = key[0] - multiparams = [ - params for state, state_dict, mapper_rec, conn, params in grouper] - c = cached_connections[connection].\ - execute(statement, multiparams) - for state, state_dict, mapper_rec, connection, params in grouper: - _postfetch_post_update( - mapper, uowtransaction, state, state_dict, - c, c.context.compiled_parameters[0]) + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = not needs_version_id or assert_multirow + + if not allow_multirow: + check_rowcount = assert_singlerow + for state, state_dict, mapper_rec, \ + connection, params in records: + c = cached_connections[connection].\ + execute(statement, params) + _postfetch_post_update( + mapper_rec, uowtransaction, table, state, state_dict, + c, c.context.compiled_parameters[0]) + rows += c.rowcount + else: + multiparams = [ + params for + state, state_dict, mapper_rec, conn, params in records] + + check_rowcount = assert_multirow or ( + assert_singlerow and + len(multiparams) == 1 + ) + + c = cached_connections[connection].\ + execute(statement, multiparams) + + rows += c.rowcount + for state, state_dict, mapper_rec, \ + connection, params in records: + _postfetch_post_update( + mapper_rec, uowtransaction, table, state, state_dict, + c, c.context.compiled_parameters[0]) + + if check_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." % + (table.description, len(records), rows)) + + elif needs_version_id: + util.warn("Dialect %s does not support updated rowcount " + "- versioning cannot be verified." % + c.dialect.dialect_description) def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, @@ -1045,11 +1117,15 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): "Instance does not contain a non-NULL version value") -def _postfetch_post_update(mapper, uowtransaction, +def _postfetch_post_update(mapper, uowtransaction, table, state, dict_, result, params): prefetch_cols = result.context.compiled.prefetch postfetch_cols = result.context.compiled.postfetch + if mapper.version_id_col is not None and \ + mapper.version_id_col in mapper._cols_by_table[table]: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) if refresh_flush: load_evt_attrs = [] |