diff options
author | Jenkins <jenkins@review.openstack.org> | 2015-03-23 15:19:22 +0000 |
---|---|---|
committer | Gerrit Code Review <review@openstack.org> | 2015-03-23 15:19:22 +0000 |
commit | 3cf218d0961cf1e0dc2c6f58602c307f43ac6ad9 (patch) | |
tree | 188e99af080ba541387b20bada25917841e845ad | |
parent | adfb47387655d13e1c2fb7da0b63781500c7c69c (diff) | |
parent | e0baed656edf470a22d7a5f1610b70017002b611 (diff) | |
download | oslo-db-3cf218d0961cf1e0dc2c6f58602c307f43ac6ad9.tar.gz |
Merge "Implement generic update-on-match feature"
-rw-r--r-- | oslo_db/sqlalchemy/session.py | 22 | ||||
-rw-r--r-- | oslo_db/sqlalchemy/update_match.py | 508 | ||||
-rw-r--r-- | oslo_db/sqlalchemy/utils.py | 12 | ||||
-rw-r--r-- | oslo_db/tests/sqlalchemy/test_update_match.py | 445 |
4 files changed, 987 insertions, 0 deletions
diff --git a/oslo_db/sqlalchemy/session.py b/oslo_db/sqlalchemy/session.py index 6a0355d..ce347ce 100644 --- a/oslo_db/sqlalchemy/session.py +++ b/oslo_db/sqlalchemy/session.py @@ -297,6 +297,7 @@ from oslo_db import exception from oslo_db import options from oslo_db.sqlalchemy import compat from oslo_db.sqlalchemy import exc_filters +from oslo_db.sqlalchemy import update_match from oslo_db.sqlalchemy import utils LOG = logging.getLogger(__name__) @@ -599,6 +600,27 @@ class Query(sqlalchemy.orm.query.Query): 'deleted_at': timeutils.utcnow()}, synchronize_session=synchronize_session) + def update_returning_pk(self, values, surrogate_key): + """Perform an UPDATE, returning the primary key of the matched row. + + This is a method-version of + oslo_db.sqlalchemy.update_match.update_returning_pk(); see that + function for usage details. + + """ + return update_match.update_returning_pk(self, values, surrogate_key) + + def update_on_match(self, specimen, surrogate_key, values, **kw): + """Emit an UPDATE statement matching the given specimen. + + This is a method-version of + oslo_db.sqlalchemy.update_match.update_on_match(); see that function + for usage details. + + """ + return update_match.update_on_match( + self, specimen, surrogate_key, values, **kw) + class Session(sqlalchemy.orm.session.Session): """Custom Session class to avoid SqlAlchemy Session monkey patching.""" diff --git a/oslo_db/sqlalchemy/update_match.py b/oslo_db/sqlalchemy/update_match.py new file mode 100644 index 0000000..692a72c --- /dev/null +++ b/oslo_db/sqlalchemy/update_match.py @@ -0,0 +1,508 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy + +from sqlalchemy import inspect +from sqlalchemy import orm +from sqlalchemy import sql +from sqlalchemy import types as sqltypes + +from oslo_db.sqlalchemy import utils + + +def update_on_match( + query, + specimen, + surrogate_key, + values=None, + attempts=3, + include_only=None, + process_query=None, + handle_failure=None +): + """Emit an UPDATE statement matching the given specimen. + + E.g.:: + + with enginefacade.writer() as session: + specimen = MyInstance( + uuid='ccea54f', + interface_id='ad33fea', + vm_state='SOME_VM_STATE', + ) + + values = { + 'vm_state': 'SOME_NEW_VM_STATE' + } + + base_query = model_query( + context, models.Instance, + project_only=True, session=session) + + hostname_query = model_query( + context, models.Instance, session=session, + read_deleted='no'). + filter(func.lower(models.Instance.hostname) == 'SOMEHOSTNAME') + + surrogate_key = ('uuid', ) + + def process_query(query): + return query.where(~exists(hostname_query)) + + def handle_failure(query): + try: + instance = base_query.one() + except NoResultFound: + raise exception.InstanceNotFound(instance_id=instance_uuid) + + if session.query(hostname_query.exists()).scalar(): + raise exception.InstanceExists( + name=values['hostname'].lower()) + + # try again + return False + + peristent_instance = base_query.update_on_match( + specimen, + surrogate_key, + values=values, + process_query=process_query, + handle_failure=handle_failure + ) + + The UPDATE statement is constructed against the given specimen + using those values which are present to construct a WHERE clause. + If the specimen contains additional values to be ignored, the + ``include_only`` parameter may be passed which indicates a sequence + of attributes to use when constructing the WHERE. + + The UPDATE is performed against an ORM Query, which is created from + the given ``Session``, or alternatively by passing the ```query`` + parameter referring to an existing query. + + Before the query is invoked, it is also passed through the callable + sent as ``process_query``, if present. This hook allows additional + criteria to be added to the query after it is created but before + invocation. + + The function will then invoke the UPDATE statement and check for + "success" one or more times, up to a maximum of that passed as + ``attempts``. + + The initial check for "success" from the UPDATE statement is that the + number of rows returned matches 1. If zero rows are matched, then + the UPDATE statement is assumed to have "failed", and the failure handling + phase begins. + + The failure handling phase involves invoking the given ``handle_failure`` + function, if any. This handler can perform additional queries to attempt + to figure out why the UPDATE didn't match any rows. The handler, + upon detection of the exact failure condition, should throw an exception + to exit; if it doesn't, it has the option of returning True or False, + where False means the error was not handled, and True means that there + was not in fact an error, and the function should return successfully. + + If the failure handler is not present, or returns False after ``attempts`` + number of attempts, then the function overall raises CantUpdateException. + If the handler returns True, then the function returns with no error. + + The return value of the function is a persistent version of the given + specimen; this may be the specimen itself, if no matching object were + already present in the session; otherwise, the existing object is + returned, with the state of the specimen merged into it. The returned + persistent object will have the given values populated into the object. + + The object is is returned as "persistent", meaning that it is + associated with the given + Session and has an identity key (that is, a real primary key + value). + + In order to produce this identity key, a strategy must be used to + determine it as efficiently and safely as possible: + + 1. If the given specimen already contained its primary key attributes + fully populated, then these attributes were used as criteria in the + UPDATE, so we have the primary key value; it is populated directly. + + 2. If the target backend supports RETURNING, then when the update() query + is performed with a RETURNING clause so that the matching primary key + is returned atomically. This currently includes Postgresql, Oracle + and others (notably not MySQL or SQLite). + + 3. If the target backend is MySQL, and the given model uses a + single-column, AUTO_INCREMENT integer primary key value (as is + the case for Nova), MySQL's recommended approach of making use + of ``LAST_INSERT_ID(expr)`` is used to atomically acquire the + matching primary key value within the scope of the UPDATE + statement, then it fetched immediately following by using + ``SELECT LAST_INSERT_ID()``. + http://dev.mysql.com/doc/refman/5.0/en/information-\ + functions.html#function_last-insert-id + + 4. Otherwise, for composite keys on MySQL or other backends such + as SQLite, the row as UPDATED must be re-fetched in order to + acquire the primary key value. The ``surrogate_key`` + parameter is used for this in order to re-fetch the row; this + is a column name with a known, unique value where + the object can be fetched. + + + """ + + if values is None: + values = {} + + entity = inspect(specimen) + mapper = entity.mapper + assert \ + [desc['type'] for desc in query.column_descriptions] == \ + [mapper.class_], "Query does not match given specimen" + + criteria = manufacture_entity_criteria( + specimen, include_only=include_only, exclude=[surrogate_key]) + + query = query.filter(criteria) + + if process_query: + query = process_query(query) + + surrogate_key_arg = ( + surrogate_key, entity.attrs[surrogate_key].loaded_value) + pk_value = None + + for attempt in range(attempts): + try: + pk_value = query.update_returning_pk(values, surrogate_key_arg) + except MultiRowsMatched: + raise + except NoRowsMatched: + if handle_failure and handle_failure(query): + break + else: + break + else: + raise NoRowsMatched("Zero rows matched for %d attempts" % attempts) + + if pk_value is None: + pk_value = entity.mapper.primary_key_from_instance(specimen) + + # NOTE(mdbooth): Can't pass the original specimen object here as it might + # have lists of multiple potential values rather than actual values. + values = copy.copy(values) + values[surrogate_key] = surrogate_key_arg[1] + persistent_obj = manufacture_persistent_object( + query.session, specimen.__class__(), values, pk_value) + + return persistent_obj + + +def manufacture_persistent_object( + session, specimen, values=None, primary_key=None): + """Make an ORM-mapped object persistent in a Session without SQL. + + The persistent object is returned. + + If a matching object is already present in the given session, the specimen + is merged into it and the persistent object returned. Otherwise, the + specimen itself is made persistent and is returned. + + The object must contain a full primary key, or provide it via the values or + primary_key parameters. The object is peristed to the Session in a "clean" + state with no pending changes. + + :param session: A Session object. + + :param specimen: a mapped object which is typically transient. + + :param values: a dictionary of values to be applied to the specimen, + in addition to the state that's already on it. The attributes will be + set such that no history is created; the object remains clean. + + :param primary_key: optional tuple-based primary key. This will also + be applied to the instance if present. + + + """ + state = inspect(specimen) + mapper = state.mapper + + for k, v in values.items(): + orm.attributes.set_committed_value(specimen, k, v) + + pk_attrs = [ + mapper.get_property_by_column(col).key + for col in mapper.primary_key + ] + + if primary_key is not None: + for key, value in zip(pk_attrs, primary_key): + orm.attributes.set_committed_value( + specimen, + key, + value + ) + + for key in pk_attrs: + if state.attrs[key].loaded_value is orm.attributes.NO_VALUE: + raise ValueError("full primary key must be present") + + orm.make_transient_to_detached(specimen) + + if state.key not in session.identity_map: + session.add(specimen) + return specimen + else: + return session.merge(specimen, load=False) + + +def manufacture_entity_criteria(entity, include_only=None, exclude=None): + """Given a mapped instance, produce a WHERE clause. + + The attributes set upon the instance will be combined to produce + a SQL expression using the mapped SQL expressions as the base + of comparison. + + Values on the instance may be set as tuples in which case the + criteria will produce an IN clause. None is also acceptable as a + scalar or tuple entry, which will produce IS NULL that is properly + joined with an OR against an IN expression if appropriate. + + :param entity: a mapped entity. + + :param include_only: optional sequence of keys to limit which + keys are included. + + :param exclude: sequence of keys to exclude + + """ + + state = inspect(entity) + exclude = set(exclude) if exclude is not None else set() + + existing = dict( + (attr.key, attr.loaded_value) + for attr in state.attrs + if attr.loaded_value is not orm.attributes.NO_VALUE + and attr.key not in exclude + ) + if include_only: + existing = dict( + (k, existing[k]) + for k in set(existing).intersection(include_only) + ) + + return manufacture_criteria(state.mapper, existing) + + +def manufacture_criteria(mapped, values): + """Given a mapper/class and a namespace of values, produce a WHERE clause. + + The class should be a mapped class and the entries in the dictionary + correspond to mapped attribute names on the class. + + A value may also be a tuple in which case that particular attribute + will be compared to a tuple using IN. The scalar value or + tuple can also contain None which translates to an IS NULL, that is + properly joined with OR against an IN expression if appropriate. + + :param cls: a mapped class, or actual :class:`.Mapper` object. + + :param values: dictionary of values. + + """ + + mapper = inspect(mapped) + + # organize keys using mapped attribute ordering, which is deterministic + value_keys = set(values) + keys = [k for k in mapper.column_attrs.keys() if k in value_keys] + return sql.and_(*[ + _sql_crit(mapper.column_attrs[key].expression, values[key]) + for key in keys + ]) + + +def _sql_crit(expression, value): + """Produce an equality expression against the given value. + + This takes into account a value that is actually a collection + of values, as well as a value of None or collection that contains + None. + + """ + + values = utils.to_list(value, default=(None, )) + if len(values) == 1: + if values[0] is None: + return expression == sql.null() + else: + return expression == values[0] + elif _none_set.intersection(values): + return sql.or_( + expression == sql.null(), + _sql_crit(expression, set(values).difference(_none_set)) + ) + else: + return expression.in_(values) + + +def update_returning_pk(query, values, surrogate_key): + """Perform an UPDATE, returning the primary key of the matched row. + + The primary key is returned using a selection of strategies: + + * if the database supports RETURNING, RETURNING is used to retrieve + the primary key values inline. + + * If the database is MySQL and the entity is mapped to a single integer + primary key column, MySQL's last_insert_id() function is used + inline within the UPDATE and then upon a second SELECT to get the + value. + + * Otherwise, a "refetch" strategy is used, where a given "surrogate" + key value (typically a UUID column on the entity) is used to run + a new SELECT against that UUID. This UUID is also placed into + the UPDATE query to ensure the row matches. + + :param query: a Query object with existing criterion, against a single + entity. + + :param values: a dictionary of values to be updated on the row. + + :param surrogate_key: a tuple of (attrname, value), referring to a + UNIQUE attribute that will also match the row. This attribute is used + to retrieve the row via a SELECT when no optimized strategy exists. + + :return: the primary key, returned as a tuple. + Is only returned if rows matched is one. Otherwise, CantUpdateException + is raised. + + """ + + entity = query.column_descriptions[0]['type'] + mapper = inspect(entity).mapper + session = query.session + + bind = session.connection(mapper=mapper) + if bind.dialect.implicit_returning: + pk_strategy = _pk_strategy_returning + elif bind.dialect.name == 'mysql' and \ + len(mapper.primary_key) == 1 and \ + isinstance( + mapper.primary_key[0].type, sqltypes.Integer): + pk_strategy = _pk_strategy_mysql_last_insert_id + else: + pk_strategy = _pk_strategy_refetch + + return pk_strategy(query, mapper, values, surrogate_key) + + +def _assert_single_row(rows_updated): + if rows_updated == 1: + return rows_updated + elif rows_updated > 1: + raise MultiRowsMatched("%d rows matched; expected one" % rows_updated) + else: + raise NoRowsMatched("No rows matched the UPDATE") + + +def _pk_strategy_refetch(query, mapper, values, surrogate_key): + + surrogate_key_name, surrogate_key_value = surrogate_key + surrogate_key_col = mapper.attrs[surrogate_key_name].expression + + rowcount = query.\ + filter(surrogate_key_col == surrogate_key_value).\ + update(values, synchronize_session=False) + + _assert_single_row(rowcount) + # SELECT my_table.id AS my_table_id FROM my_table + # WHERE my_table.y = ? AND my_table.z = ? + # LIMIT ? OFFSET ? + fetch_query = query.session.query( + *mapper.primary_key).filter( + surrogate_key_col == surrogate_key_value) + + primary_key = fetch_query.one() + + return primary_key + + +def _pk_strategy_returning(query, mapper, values, surrogate_key): + surrogate_key_name, surrogate_key_value = surrogate_key + surrogate_key_col = mapper.attrs[surrogate_key_name].expression + + update_stmt = _update_stmt_from_query(mapper, query, values) + update_stmt = update_stmt.where(surrogate_key_col == surrogate_key_value) + update_stmt = update_stmt.returning(*mapper.primary_key) + + # UPDATE my_table SET x=%(x)s, z=%(z)s WHERE my_table.y = %(y_1)s + # AND my_table.z = %(z_1)s RETURNING my_table.id + result = query.session.execute(update_stmt) + rowcount = result.rowcount + _assert_single_row(rowcount) + primary_key = tuple(result.first()) + + return primary_key + + +def _pk_strategy_mysql_last_insert_id(query, mapper, values, surrogate_key): + + surrogate_key_name, surrogate_key_value = surrogate_key + surrogate_key_col = mapper.attrs[surrogate_key_name].expression + + surrogate_pk_col = mapper.primary_key[0] + update_stmt = _update_stmt_from_query(mapper, query, values) + update_stmt = update_stmt.where(surrogate_key_col == surrogate_key_value) + update_stmt = update_stmt.values( + {surrogate_pk_col: sql.func.last_insert_id(surrogate_pk_col)}) + + # UPDATE my_table SET id=last_insert_id(my_table.id), + # x=%s, z=%s WHERE my_table.y = %s AND my_table.z = %s + result = query.session.execute(update_stmt) + rowcount = result.rowcount + _assert_single_row(rowcount) + # SELECT last_insert_id() AS last_insert_id_1 + primary_key = query.session.scalar(sql.func.last_insert_id()), + + return primary_key + + +def _update_stmt_from_query(mapper, query, values): + upd_values = dict( + ( + mapper.column_attrs[key], value + ) for key, value in values.items() + ) + query = query.enable_eagerloads(False) + context = query._compile_context() + primary_table = context.statement.froms[0] + update_stmt = sql.update(primary_table, + context.whereclause, + upd_values) + return update_stmt + + +_none_set = frozenset([None]) + + +class CantUpdateException(Exception): + pass + + +class NoRowsMatched(CantUpdateException): + pass + + +class MultiRowsMatched(CantUpdateException): + pass diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py index 932d409..e7c5b4a 100644 --- a/oslo_db/sqlalchemy/utils.py +++ b/oslo_db/sqlalchemy/utils.py @@ -188,6 +188,18 @@ def paginate_query(query, model, limit, sort_keys, marker=None, return query +def to_list(x, default=None): + if x is None: + return default + if not isinstance(x, collections.Iterable) or \ + isinstance(x, six.string_types): + return [x] + elif isinstance(x, list): + return x + else: + return list(x) + + def _read_deleted_filter(query, db_model, deleted): if 'deleted' not in db_model.__table__.columns: raise ValueError(_("There is no `deleted` column in `%s` table. " diff --git a/oslo_db/tests/sqlalchemy/test_update_match.py b/oslo_db/tests/sqlalchemy/test_update_match.py new file mode 100644 index 0000000..ecc7af7 --- /dev/null +++ b/oslo_db/tests/sqlalchemy/test_update_match.py @@ -0,0 +1,445 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from oslotest import base as oslo_test_base +from sqlalchemy.ext import declarative +from sqlalchemy import schema +from sqlalchemy import sql +from sqlalchemy import types as sqltypes + +from oslo_db.sqlalchemy import test_base +from oslo_db.sqlalchemy import update_match + +Base = declarative.declarative_base() + + +class MyModel(Base): + __tablename__ = 'my_table' + + id = schema.Column(sqltypes.Integer, primary_key=True) + uuid = schema.Column(sqltypes.String(36), nullable=False, unique=True) + x = schema.Column(sqltypes.Integer) + y = schema.Column(sqltypes.String(40)) + z = schema.Column(sqltypes.String(40)) + + +class ManufactureCriteriaTest(oslo_test_base.BaseTestCase): + def test_instance_criteria_basic(self): + specimen = MyModel( + y='y1', z='z3', + uuid='136254d5-3869-408f-9da7-190e0072641a' + ) + self.assertEqual( + "my_table.uuid = :uuid_1 AND my_table.y = :y_1 " + "AND my_table.z = :z_1", + str(update_match.manufacture_entity_criteria(specimen).compile()) + ) + + def test_instance_criteria_basic_wnone(self): + specimen = MyModel( + y='y1', z=None, + uuid='136254d5-3869-408f-9da7-190e0072641a' + ) + self.assertEqual( + "my_table.uuid = :uuid_1 AND my_table.y = :y_1 " + "AND my_table.z IS NULL", + str(update_match.manufacture_entity_criteria(specimen).compile()) + ) + + def test_instance_criteria_tuples(self): + specimen = MyModel( + y='y1', z=('z1', 'z2'), + ) + self.assertEqual( + "my_table.y = :y_1 AND my_table.z IN (:z_1, :z_2)", + str(update_match.manufacture_entity_criteria(specimen).compile()) + ) + + def test_instance_criteria_tuples_wnone(self): + specimen = MyModel( + y='y1', z=('z1', 'z2', None), + ) + self.assertEqual( + "my_table.y = :y_1 AND (my_table.z IS NULL OR " + "my_table.z IN (:z_1, :z_2))", + str(update_match.manufacture_entity_criteria(specimen).compile()) + ) + + def test_instance_criteria_none_list(self): + specimen = MyModel( + y='y1', z=[None], + ) + self.assertEqual( + "my_table.y = :y_1 AND my_table.z IS NULL", + str(update_match.manufacture_entity_criteria(specimen).compile()) + ) + + +class UpdateMatchTest(test_base.DbTestCase): + def setUp(self): + super(UpdateMatchTest, self).setUp() + Base.metadata.create_all(self.engine) + self.addCleanup(Base.metadata.drop_all, self.engine) + # self.engine.echo = 'debug' + self.session = self.sessionmaker(autocommit=False) + self.addCleanup(self.session.close) + self.session.add_all([ + MyModel( + id=1, + uuid='23cb9224-9f8e-40fe-bd3c-e7577b7af37d', + x=5, y='y1', z='z1'), + MyModel( + id=2, + uuid='136254d5-3869-408f-9da7-190e0072641a', + x=6, y='y1', z='z2'), + MyModel( + id=3, + uuid='094eb162-d5df-494b-a458-a91a1b2d2c65', + x=7, y='y1', z='z1'), + MyModel( + id=4, + uuid='94659b3f-ea1f-4ffd-998d-93b28f7f5b70', + x=8, y='y2', z='z2'), + MyModel( + id=5, + uuid='bdf3893c-ee3c-40a0-bc79-960adb6cd1d4', + x=8, y='y2', z=None), + ]) + + self.session.commit() + + def _assert_row(self, pk, values): + row = self.session.execute( + sql.select([MyModel.__table__]).where(MyModel.__table__.c.id == pk) + ).first() + values['id'] = pk + self.assertEqual(values, dict(row)) + + def test_update_specimen_successful(self): + uuid = '136254d5-3869-408f-9da7-190e0072641a' + + specimen = MyModel( + y='y1', z='z2', uuid=uuid + ) + + result = self.session.query(MyModel).update_on_match( + specimen, + 'uuid', + values={'x': 9, 'z': 'z3'} + ) + + self.assertEqual(uuid, result.uuid) + self.assertEqual(2, result.id) + self.assertEqual('z3', result.z) + self.assertIn(result, self.session) + + self._assert_row( + 2, + { + 'uuid': '136254d5-3869-408f-9da7-190e0072641a', + 'x': 9, 'y': 'y1', 'z': 'z3' + } + ) + + def test_update_specimen_include_only(self): + uuid = '136254d5-3869-408f-9da7-190e0072641a' + + specimen = MyModel( + y='y9', z='z5', x=6, uuid=uuid + ) + + # Query the object first to test that we merge when the object is + # already cached in the session. + self.session.query(MyModel).filter(MyModel.uuid == uuid).one() + + result = self.session.query(MyModel).update_on_match( + specimen, + 'uuid', + values={'x': 9, 'z': 'z3'}, + include_only=('x', ) + ) + + self.assertEqual(uuid, result.uuid) + self.assertEqual(2, result.id) + self.assertEqual('z3', result.z) + self.assertIn(result, self.session) + self.assertNotIn(result, self.session.dirty) + + self._assert_row( + 2, + { + 'uuid': '136254d5-3869-408f-9da7-190e0072641a', + 'x': 9, 'y': 'y1', 'z': 'z3' + } + ) + + def test_update_specimen_no_rows(self): + specimen = MyModel( + y='y1', z='z3', + uuid='136254d5-3869-408f-9da7-190e0072641a' + ) + + exc = self.assertRaises( + update_match.NoRowsMatched, + self.session.query(MyModel).update_on_match, + specimen, 'uuid', values={'x': 9, 'z': 'z3'} + ) + + self.assertEqual("Zero rows matched for 3 attempts", exc.args[0]) + + def test_update_specimen_process_query_no_rows(self): + specimen = MyModel( + y='y1', z='z2', + uuid='136254d5-3869-408f-9da7-190e0072641a' + ) + + def process_query(query): + return query.filter_by(x=10) + + exc = self.assertRaises( + update_match.NoRowsMatched, + self.session.query(MyModel).update_on_match, + specimen, 'uuid', values={'x': 9, 'z': 'z3'}, + process_query=process_query + ) + + self.assertEqual("Zero rows matched for 3 attempts", exc.args[0]) + + def test_update_specimen_given_query_no_rows(self): + specimen = MyModel( + y='y1', z='z2', + uuid='136254d5-3869-408f-9da7-190e0072641a' + ) + + query = self.session.query(MyModel).filter_by(x=10) + + exc = self.assertRaises( + update_match.NoRowsMatched, + query.update_on_match, + specimen, 'uuid', values={'x': 9, 'z': 'z3'}, + ) + + self.assertEqual("Zero rows matched for 3 attempts", exc.args[0]) + + def test_update_specimen_multi_rows(self): + specimen = MyModel( + y='y1', z='z1', + ) + + exc = self.assertRaises( + update_match.MultiRowsMatched, + self.session.query(MyModel).update_on_match, + specimen, 'y', values={'x': 9, 'z': 'z3'} + ) + + self.assertEqual("2 rows matched; expected one", exc.args[0]) + + def test_update_specimen_query_mismatch_error(self): + specimen = MyModel( + y='y1' + ) + q = self.session.query(MyModel.x, MyModel.y) + exc = self.assertRaises( + AssertionError, + q.update_on_match, + specimen, 'y', values={'x': 9, 'z': 'z3'}, + ) + + self.assertEqual("Query does not match given specimen", exc.args[0]) + + def test_custom_handle_failure_raise_new(self): + class MyException(Exception): + pass + + def handle_failure(query): + # ensure the query is usable + result = query.count() + self.assertEqual(0, result) + + raise MyException("test: %d" % result) + + specimen = MyModel( + y='y1', z='z3', + uuid='136254d5-3869-408f-9da7-190e0072641a' + ) + + exc = self.assertRaises( + MyException, + self.session.query(MyModel).update_on_match, + specimen, 'uuid', values={'x': 9, 'z': 'z3'}, + handle_failure=handle_failure + ) + + self.assertEqual("test: 0", exc.args[0]) + + def test_custom_handle_failure_cancel_raise(self): + uuid = '136254d5-3869-408f-9da7-190e0072641a' + + class MyException(Exception): + pass + + def handle_failure(query): + # ensure the query is usable + result = query.count() + self.assertEqual(0, result) + + return True + + specimen = MyModel( + id=2, y='y1', z='z3', uuid=uuid + ) + + result = self.session.query(MyModel).update_on_match( + specimen, 'uuid', values={'x': 9, 'z': 'z3'}, + handle_failure=handle_failure + ) + self.assertEqual(uuid, result.uuid) + self.assertEqual(2, result.id) + self.assertEqual('z3', result.z) + self.assertEqual(9, result.x) + self.assertIn(result, self.session) + + def test_update_specimen_on_none_successful(self): + uuid = 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4' + + specimen = MyModel( + y='y2', z=None, uuid=uuid + ) + + result = self.session.query(MyModel).update_on_match( + specimen, + 'uuid', + values={'x': 9, 'z': 'z3'}, + ) + + self.assertIn(result, self.session) + self.assertEqual(uuid, result.uuid) + self.assertEqual(5, result.id) + self.assertEqual('z3', result.z) + self._assert_row( + 5, + { + 'uuid': 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4', + 'x': 9, 'y': 'y2', 'z': 'z3' + } + ) + + def test_update_specimen_on_multiple_nonnone_successful(self): + uuid = '094eb162-d5df-494b-a458-a91a1b2d2c65' + + specimen = MyModel( + y=('y1', 'y2'), x=(5, 7), uuid=uuid + ) + + result = self.session.query(MyModel).update_on_match( + specimen, + 'uuid', + values={'x': 9, 'z': 'z3'}, + ) + + self.assertIn(result, self.session) + self.assertEqual(uuid, result.uuid) + self.assertEqual(3, result.id) + self.assertEqual('z3', result.z) + self._assert_row( + 3, + { + 'uuid': '094eb162-d5df-494b-a458-a91a1b2d2c65', + 'x': 9, 'y': 'y1', 'z': 'z3' + } + ) + + def test_update_specimen_on_multiple_wnone_successful(self): + uuid = 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4' + specimen = MyModel( + y=('y1', 'y2'), x=(8, 7), z=('z1', 'z2', None), uuid=uuid + ) + + result = self.session.query(MyModel).update_on_match( + specimen, + 'uuid', + values={'x': 9, 'z': 'z3'}, + ) + + self.assertIn(result, self.session) + self.assertEqual(uuid, result.uuid) + self.assertEqual(5, result.id) + self.assertEqual('z3', result.z) + self._assert_row( + 5, + { + 'uuid': 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4', + 'x': 9, 'y': 'y2', 'z': 'z3' + } + ) + + def test_update_returning_pk_matched(self): + pk = self.session.query(MyModel).\ + filter_by(y='y1', z='z2').update_returning_pk( + {'x': 9, 'z': 'z3'}, + ('uuid', '136254d5-3869-408f-9da7-190e0072641a') + ) + + self.assertEqual((2,), pk) + self._assert_row( + 2, + { + 'uuid': '136254d5-3869-408f-9da7-190e0072641a', + 'x': 9, 'y': 'y1', 'z': 'z3' + } + ) + + def test_update_returning_wrong_uuid(self): + exc = self.assertRaises( + update_match.NoRowsMatched, + self.session.query(MyModel). + filter_by(y='y1', z='z2').update_returning_pk, + {'x': 9, 'z': 'z3'}, + ('uuid', '23cb9224-9f8e-40fe-bd3c-e7577b7af37d') + ) + + self.assertEqual("No rows matched the UPDATE", exc.args[0]) + + def test_update_returning_no_rows(self): + exc = self.assertRaises( + update_match.NoRowsMatched, + self.session.query(MyModel). + filter_by(y='y1', z='z3').update_returning_pk, + {'x': 9, 'z': 'z3'}, + ('uuid', '136254d5-3869-408f-9da7-190e0072641a') + ) + + self.assertEqual("No rows matched the UPDATE", exc.args[0]) + + def test_update_multiple_rows(self): + exc = self.assertRaises( + update_match.MultiRowsMatched, + self.session.query(MyModel). + filter_by(y='y1', z='z1').update_returning_pk, + {'x': 9, 'z': 'z3'}, + ('y', 'y1') + ) + + self.assertEqual("2 rows matched; expected one", exc.args[0]) + + +class PGUpdateMatchTest( + UpdateMatchTest, + test_base.PostgreSQLOpportunisticTestCase): + pass + + +class MySQLUpdateMatchTest( + UpdateMatchTest, + test_base.MySQLOpportunisticTestCase): + pass |