summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2016-04-12 15:57:20 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2017-06-15 18:58:29 -0400
commit64b0760faa45a26c727a76b9fda97f2b4ea15417 (patch)
tree595986b7b08eee231c18b7252b06adbcb0333789 /lib
parent7af05fcc9387cea4172cc35eb6a198776488f90d (diff)
downloadsqlalchemy-64b0760faa45a26c727a76b9fda97f2b4ea15417.tar.gz
Add all versioning logic to _post_update()
An UPDATE emitted as a result of the :paramref:`.relationship.post_update` feature will now integrate with the versioning feature to both bump the version id of the row as well as assert that the existing version number was matched. Fixes: #3496 Change-Id: I865405dd6069f1c1e3b0d27a4980e9374e059f97
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/orm/persistence.py110
1 files changed, 93 insertions, 17 deletions
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 0de64011a..924b9e1c9 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -212,15 +212,22 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
continue
update = (
- (state, state_dict, sub_mapper, connection)
+ (
+ state, state_dict, sub_mapper, connection,
+ mapper._get_committed_state_attr_by_column(
+ state, state_dict, mapper.version_id_col
+ ) if mapper.version_id_col is not None else None
+ )
for
state, state_dict, sub_mapper, connection in states_to_update
if table in sub_mapper._pks_by_table
)
- update = _collect_post_update_commands(base_mapper, uowtransaction,
- table, update,
- post_update_cols)
+ update = _collect_post_update_commands(
+ base_mapper, uowtransaction,
+ table, update,
+ post_update_cols
+ )
_emit_post_update_statements(base_mapper, uowtransaction,
cached_connections,
@@ -576,7 +583,8 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
"""
- for state, state_dict, mapper, connection in states_to_update:
+ for state, state_dict, mapper, connection, \
+ update_version_id in states_to_update:
# assert table in mapper._pks_by_table
@@ -601,6 +609,16 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table,
params[col.key] = value
hasdata = True
if hasdata:
+ if update_version_id is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
+
+ col = mapper.version_id_col
+ params[col._label] = update_version_id
+
+ if bool(state.key) and col.key not in params and \
+ mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(update_version_id)
+ params[col.key] = val
yield state, state_dict, mapper, connection, params
@@ -870,6 +888,9 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
+ needs_version_id = mapper.version_id_col is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]
+
def update_stmt():
clause = sql.and_()
@@ -877,7 +898,18 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
clause.clauses.append(col == sql.bindparam(col._label,
type_=col.type))
- return table.update(clause)
+ if needs_version_id:
+ clause.clauses.append(
+ mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type))
+
+ stmt = table.update(clause)
+
+ if mapper.version_id_col is not None:
+ stmt = stmt.return_defaults(mapper.version_id_col)
+
+ return stmt
statement = base_mapper._memo(('post_update', table), update_stmt)
@@ -885,23 +917,63 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
# list of states to guarantee row access order, but
# also group them into common (connection, cols) sets
# to support executemany().
- for key, grouper in groupby(
+ for key, records in groupby(
update, lambda rec: (
rec[3], # connection
set(rec[4]), # parameter keys
)
):
- grouper = list(grouper)
+ rows = 0
+
+ records = list(records)
connection = key[0]
- multiparams = [
- params for state, state_dict, mapper_rec, conn, params in grouper]
- c = cached_connections[connection].\
- execute(statement, multiparams)
- for state, state_dict, mapper_rec, connection, params in grouper:
- _postfetch_post_update(
- mapper, uowtransaction, state, state_dict,
- c, c.context.compiled_parameters[0])
+ assert_singlerow = connection.dialect.supports_sane_rowcount
+ assert_multirow = assert_singlerow and \
+ connection.dialect.supports_sane_multi_rowcount
+ allow_multirow = not needs_version_id or assert_multirow
+
+ if not allow_multirow:
+ check_rowcount = assert_singlerow
+ for state, state_dict, mapper_rec, \
+ connection, params in records:
+ c = cached_connections[connection].\
+ execute(statement, params)
+ _postfetch_post_update(
+ mapper_rec, uowtransaction, table, state, state_dict,
+ c, c.context.compiled_parameters[0])
+ rows += c.rowcount
+ else:
+ multiparams = [
+ params for
+ state, state_dict, mapper_rec, conn, params in records]
+
+ check_rowcount = assert_multirow or (
+ assert_singlerow and
+ len(multiparams) == 1
+ )
+
+ c = cached_connections[connection].\
+ execute(statement, multiparams)
+
+ rows += c.rowcount
+ for state, state_dict, mapper_rec, \
+ connection, params in records:
+ _postfetch_post_update(
+ mapper_rec, uowtransaction, table, state, state_dict,
+ c, c.context.compiled_parameters[0])
+
+ if check_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched." %
+ (table.description, len(records), rows))
+
+ elif needs_version_id:
+ util.warn("Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description)
def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
@@ -1045,11 +1117,15 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
"Instance does not contain a non-NULL version value")
-def _postfetch_post_update(mapper, uowtransaction,
+def _postfetch_post_update(mapper, uowtransaction, table,
state, dict_, result, params):
prefetch_cols = result.context.compiled.prefetch
postfetch_cols = result.context.compiled.postfetch
+ if mapper.version_id_col is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
+ prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
if refresh_flush:
load_evt_attrs = []