summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-01-09 18:59:11 -0500
committerMatthew Booth <mbooth@redhat.com>2015-03-05 14:06:59 +0000
commite0baed656edf470a22d7a5f1610b70017002b611 (patch)
tree111731863e36761be1036545e0f5a71d6e2afd15
parent7bfdb6a704855984ae35a1c6ef063782a4f7bf1d (diff)
downloadoslo-db-e0baed656edf470a22d7a5f1610b70017002b611.tar.gz
Implement generic update-on-match feature
This feature provides the query.update_on_match() and query.update_returning_pk() methods, as well as the manufacture_persistent_object(), manufacture_entity_criteria(), and manufacture_criteria() utility functions. query.update_on_match() is used to UPDATE a row based on a variety of criteria, and to then return a fully persistent object state representing the row that was matched. It essentially intends to provide an UPDATE that is guaranteed to have matched a specific row in the presence of potential race conditions without using any locking, and to then return a record of that row as if it had been SELECTed. query.update_returning_pk() is a public method that also serves as part of the implementation of query.update_on_match(); this method delivers an UPDATE statement such that the primary key of the single row matched is returned; if zero or multiple rows are matched, and error is raised. To handle this, several backend-specific strategies are provided, which are automatically selected based on the best available. The lowest strategy performs a re-SELECT, but still assumes there's a simple unique column to be queried on, as is currently the use case in Nova (uuid is present). On Postgresql, MySQL and other databases besides SQLite and possibly DB2, more atomic strategies are used. Change-Id: I059f4ae6e72cfa6681a179314144214639f283ef
-rw-r--r--oslo_db/sqlalchemy/session.py22
-rw-r--r--oslo_db/sqlalchemy/update_match.py508
-rw-r--r--oslo_db/sqlalchemy/utils.py12
-rw-r--r--oslo_db/tests/sqlalchemy/test_update_match.py445
4 files changed, 987 insertions, 0 deletions
diff --git a/oslo_db/sqlalchemy/session.py b/oslo_db/sqlalchemy/session.py
index 7e33075..7c83d8d 100644
--- a/oslo_db/sqlalchemy/session.py
+++ b/oslo_db/sqlalchemy/session.py
@@ -295,6 +295,7 @@ from oslo_db import exception
from oslo_db import options
from oslo_db.sqlalchemy import compat
from oslo_db.sqlalchemy import exc_filters
+from oslo_db.sqlalchemy import update_match
from oslo_db.sqlalchemy import utils
LOG = logging.getLogger(__name__)
@@ -595,6 +596,27 @@ class Query(sqlalchemy.orm.query.Query):
'deleted_at': timeutils.utcnow()},
synchronize_session=synchronize_session)
+ def update_returning_pk(self, values, surrogate_key):
+ """Perform an UPDATE, returning the primary key of the matched row.
+
+ This is a method-version of
+ oslo_db.sqlalchemy.update_match.update_returning_pk(); see that
+ function for usage details.
+
+ """
+ return update_match.update_returning_pk(self, values, surrogate_key)
+
+ def update_on_match(self, specimen, surrogate_key, values, **kw):
+ """Emit an UPDATE statement matching the given specimen.
+
+ This is a method-version of
+ oslo_db.sqlalchemy.update_match.update_on_match(); see that function
+ for usage details.
+
+ """
+ return update_match.update_on_match(
+ self, specimen, surrogate_key, values, **kw)
+
class Session(sqlalchemy.orm.session.Session):
"""Custom Session class to avoid SqlAlchemy Session monkey patching."""
diff --git a/oslo_db/sqlalchemy/update_match.py b/oslo_db/sqlalchemy/update_match.py
new file mode 100644
index 0000000..692a72c
--- /dev/null
+++ b/oslo_db/sqlalchemy/update_match.py
@@ -0,0 +1,508 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import copy
+
+from sqlalchemy import inspect
+from sqlalchemy import orm
+from sqlalchemy import sql
+from sqlalchemy import types as sqltypes
+
+from oslo_db.sqlalchemy import utils
+
+
+def update_on_match(
+ query,
+ specimen,
+ surrogate_key,
+ values=None,
+ attempts=3,
+ include_only=None,
+ process_query=None,
+ handle_failure=None
+):
+ """Emit an UPDATE statement matching the given specimen.
+
+ E.g.::
+
+ with enginefacade.writer() as session:
+ specimen = MyInstance(
+ uuid='ccea54f',
+ interface_id='ad33fea',
+ vm_state='SOME_VM_STATE',
+ )
+
+ values = {
+ 'vm_state': 'SOME_NEW_VM_STATE'
+ }
+
+ base_query = model_query(
+ context, models.Instance,
+ project_only=True, session=session)
+
+ hostname_query = model_query(
+ context, models.Instance, session=session,
+ read_deleted='no').
+ filter(func.lower(models.Instance.hostname) == 'SOMEHOSTNAME')
+
+ surrogate_key = ('uuid', )
+
+ def process_query(query):
+ return query.where(~exists(hostname_query))
+
+ def handle_failure(query):
+ try:
+ instance = base_query.one()
+ except NoResultFound:
+ raise exception.InstanceNotFound(instance_id=instance_uuid)
+
+ if session.query(hostname_query.exists()).scalar():
+ raise exception.InstanceExists(
+ name=values['hostname'].lower())
+
+ # try again
+ return False
+
+ peristent_instance = base_query.update_on_match(
+ specimen,
+ surrogate_key,
+ values=values,
+ process_query=process_query,
+ handle_failure=handle_failure
+ )
+
+ The UPDATE statement is constructed against the given specimen
+ using those values which are present to construct a WHERE clause.
+ If the specimen contains additional values to be ignored, the
+ ``include_only`` parameter may be passed which indicates a sequence
+ of attributes to use when constructing the WHERE.
+
+ The UPDATE is performed against an ORM Query, which is created from
+ the given ``Session``, or alternatively by passing the ```query``
+ parameter referring to an existing query.
+
+ Before the query is invoked, it is also passed through the callable
+ sent as ``process_query``, if present. This hook allows additional
+ criteria to be added to the query after it is created but before
+ invocation.
+
+ The function will then invoke the UPDATE statement and check for
+ "success" one or more times, up to a maximum of that passed as
+ ``attempts``.
+
+ The initial check for "success" from the UPDATE statement is that the
+ number of rows returned matches 1. If zero rows are matched, then
+ the UPDATE statement is assumed to have "failed", and the failure handling
+ phase begins.
+
+ The failure handling phase involves invoking the given ``handle_failure``
+ function, if any. This handler can perform additional queries to attempt
+ to figure out why the UPDATE didn't match any rows. The handler,
+ upon detection of the exact failure condition, should throw an exception
+ to exit; if it doesn't, it has the option of returning True or False,
+ where False means the error was not handled, and True means that there
+ was not in fact an error, and the function should return successfully.
+
+ If the failure handler is not present, or returns False after ``attempts``
+ number of attempts, then the function overall raises CantUpdateException.
+ If the handler returns True, then the function returns with no error.
+
+ The return value of the function is a persistent version of the given
+ specimen; this may be the specimen itself, if no matching object were
+ already present in the session; otherwise, the existing object is
+ returned, with the state of the specimen merged into it. The returned
+ persistent object will have the given values populated into the object.
+
+ The object is is returned as "persistent", meaning that it is
+ associated with the given
+ Session and has an identity key (that is, a real primary key
+ value).
+
+ In order to produce this identity key, a strategy must be used to
+ determine it as efficiently and safely as possible:
+
+ 1. If the given specimen already contained its primary key attributes
+ fully populated, then these attributes were used as criteria in the
+ UPDATE, so we have the primary key value; it is populated directly.
+
+ 2. If the target backend supports RETURNING, then when the update() query
+ is performed with a RETURNING clause so that the matching primary key
+ is returned atomically. This currently includes Postgresql, Oracle
+ and others (notably not MySQL or SQLite).
+
+ 3. If the target backend is MySQL, and the given model uses a
+ single-column, AUTO_INCREMENT integer primary key value (as is
+ the case for Nova), MySQL's recommended approach of making use
+ of ``LAST_INSERT_ID(expr)`` is used to atomically acquire the
+ matching primary key value within the scope of the UPDATE
+ statement, then it fetched immediately following by using
+ ``SELECT LAST_INSERT_ID()``.
+ http://dev.mysql.com/doc/refman/5.0/en/information-\
+ functions.html#function_last-insert-id
+
+ 4. Otherwise, for composite keys on MySQL or other backends such
+ as SQLite, the row as UPDATED must be re-fetched in order to
+ acquire the primary key value. The ``surrogate_key``
+ parameter is used for this in order to re-fetch the row; this
+ is a column name with a known, unique value where
+ the object can be fetched.
+
+
+ """
+
+ if values is None:
+ values = {}
+
+ entity = inspect(specimen)
+ mapper = entity.mapper
+ assert \
+ [desc['type'] for desc in query.column_descriptions] == \
+ [mapper.class_], "Query does not match given specimen"
+
+ criteria = manufacture_entity_criteria(
+ specimen, include_only=include_only, exclude=[surrogate_key])
+
+ query = query.filter(criteria)
+
+ if process_query:
+ query = process_query(query)
+
+ surrogate_key_arg = (
+ surrogate_key, entity.attrs[surrogate_key].loaded_value)
+ pk_value = None
+
+ for attempt in range(attempts):
+ try:
+ pk_value = query.update_returning_pk(values, surrogate_key_arg)
+ except MultiRowsMatched:
+ raise
+ except NoRowsMatched:
+ if handle_failure and handle_failure(query):
+ break
+ else:
+ break
+ else:
+ raise NoRowsMatched("Zero rows matched for %d attempts" % attempts)
+
+ if pk_value is None:
+ pk_value = entity.mapper.primary_key_from_instance(specimen)
+
+ # NOTE(mdbooth): Can't pass the original specimen object here as it might
+ # have lists of multiple potential values rather than actual values.
+ values = copy.copy(values)
+ values[surrogate_key] = surrogate_key_arg[1]
+ persistent_obj = manufacture_persistent_object(
+ query.session, specimen.__class__(), values, pk_value)
+
+ return persistent_obj
+
+
+def manufacture_persistent_object(
+ session, specimen, values=None, primary_key=None):
+ """Make an ORM-mapped object persistent in a Session without SQL.
+
+ The persistent object is returned.
+
+ If a matching object is already present in the given session, the specimen
+ is merged into it and the persistent object returned. Otherwise, the
+ specimen itself is made persistent and is returned.
+
+ The object must contain a full primary key, or provide it via the values or
+ primary_key parameters. The object is peristed to the Session in a "clean"
+ state with no pending changes.
+
+ :param session: A Session object.
+
+ :param specimen: a mapped object which is typically transient.
+
+ :param values: a dictionary of values to be applied to the specimen,
+ in addition to the state that's already on it. The attributes will be
+ set such that no history is created; the object remains clean.
+
+ :param primary_key: optional tuple-based primary key. This will also
+ be applied to the instance if present.
+
+
+ """
+ state = inspect(specimen)
+ mapper = state.mapper
+
+ for k, v in values.items():
+ orm.attributes.set_committed_value(specimen, k, v)
+
+ pk_attrs = [
+ mapper.get_property_by_column(col).key
+ for col in mapper.primary_key
+ ]
+
+ if primary_key is not None:
+ for key, value in zip(pk_attrs, primary_key):
+ orm.attributes.set_committed_value(
+ specimen,
+ key,
+ value
+ )
+
+ for key in pk_attrs:
+ if state.attrs[key].loaded_value is orm.attributes.NO_VALUE:
+ raise ValueError("full primary key must be present")
+
+ orm.make_transient_to_detached(specimen)
+
+ if state.key not in session.identity_map:
+ session.add(specimen)
+ return specimen
+ else:
+ return session.merge(specimen, load=False)
+
+
+def manufacture_entity_criteria(entity, include_only=None, exclude=None):
+ """Given a mapped instance, produce a WHERE clause.
+
+ The attributes set upon the instance will be combined to produce
+ a SQL expression using the mapped SQL expressions as the base
+ of comparison.
+
+ Values on the instance may be set as tuples in which case the
+ criteria will produce an IN clause. None is also acceptable as a
+ scalar or tuple entry, which will produce IS NULL that is properly
+ joined with an OR against an IN expression if appropriate.
+
+ :param entity: a mapped entity.
+
+ :param include_only: optional sequence of keys to limit which
+ keys are included.
+
+ :param exclude: sequence of keys to exclude
+
+ """
+
+ state = inspect(entity)
+ exclude = set(exclude) if exclude is not None else set()
+
+ existing = dict(
+ (attr.key, attr.loaded_value)
+ for attr in state.attrs
+ if attr.loaded_value is not orm.attributes.NO_VALUE
+ and attr.key not in exclude
+ )
+ if include_only:
+ existing = dict(
+ (k, existing[k])
+ for k in set(existing).intersection(include_only)
+ )
+
+ return manufacture_criteria(state.mapper, existing)
+
+
+def manufacture_criteria(mapped, values):
+ """Given a mapper/class and a namespace of values, produce a WHERE clause.
+
+ The class should be a mapped class and the entries in the dictionary
+ correspond to mapped attribute names on the class.
+
+ A value may also be a tuple in which case that particular attribute
+ will be compared to a tuple using IN. The scalar value or
+ tuple can also contain None which translates to an IS NULL, that is
+ properly joined with OR against an IN expression if appropriate.
+
+ :param cls: a mapped class, or actual :class:`.Mapper` object.
+
+ :param values: dictionary of values.
+
+ """
+
+ mapper = inspect(mapped)
+
+ # organize keys using mapped attribute ordering, which is deterministic
+ value_keys = set(values)
+ keys = [k for k in mapper.column_attrs.keys() if k in value_keys]
+ return sql.and_(*[
+ _sql_crit(mapper.column_attrs[key].expression, values[key])
+ for key in keys
+ ])
+
+
+def _sql_crit(expression, value):
+ """Produce an equality expression against the given value.
+
+ This takes into account a value that is actually a collection
+ of values, as well as a value of None or collection that contains
+ None.
+
+ """
+
+ values = utils.to_list(value, default=(None, ))
+ if len(values) == 1:
+ if values[0] is None:
+ return expression == sql.null()
+ else:
+ return expression == values[0]
+ elif _none_set.intersection(values):
+ return sql.or_(
+ expression == sql.null(),
+ _sql_crit(expression, set(values).difference(_none_set))
+ )
+ else:
+ return expression.in_(values)
+
+
+def update_returning_pk(query, values, surrogate_key):
+ """Perform an UPDATE, returning the primary key of the matched row.
+
+ The primary key is returned using a selection of strategies:
+
+ * if the database supports RETURNING, RETURNING is used to retrieve
+ the primary key values inline.
+
+ * If the database is MySQL and the entity is mapped to a single integer
+ primary key column, MySQL's last_insert_id() function is used
+ inline within the UPDATE and then upon a second SELECT to get the
+ value.
+
+ * Otherwise, a "refetch" strategy is used, where a given "surrogate"
+ key value (typically a UUID column on the entity) is used to run
+ a new SELECT against that UUID. This UUID is also placed into
+ the UPDATE query to ensure the row matches.
+
+ :param query: a Query object with existing criterion, against a single
+ entity.
+
+ :param values: a dictionary of values to be updated on the row.
+
+ :param surrogate_key: a tuple of (attrname, value), referring to a
+ UNIQUE attribute that will also match the row. This attribute is used
+ to retrieve the row via a SELECT when no optimized strategy exists.
+
+ :return: the primary key, returned as a tuple.
+ Is only returned if rows matched is one. Otherwise, CantUpdateException
+ is raised.
+
+ """
+
+ entity = query.column_descriptions[0]['type']
+ mapper = inspect(entity).mapper
+ session = query.session
+
+ bind = session.connection(mapper=mapper)
+ if bind.dialect.implicit_returning:
+ pk_strategy = _pk_strategy_returning
+ elif bind.dialect.name == 'mysql' and \
+ len(mapper.primary_key) == 1 and \
+ isinstance(
+ mapper.primary_key[0].type, sqltypes.Integer):
+ pk_strategy = _pk_strategy_mysql_last_insert_id
+ else:
+ pk_strategy = _pk_strategy_refetch
+
+ return pk_strategy(query, mapper, values, surrogate_key)
+
+
+def _assert_single_row(rows_updated):
+ if rows_updated == 1:
+ return rows_updated
+ elif rows_updated > 1:
+ raise MultiRowsMatched("%d rows matched; expected one" % rows_updated)
+ else:
+ raise NoRowsMatched("No rows matched the UPDATE")
+
+
+def _pk_strategy_refetch(query, mapper, values, surrogate_key):
+
+ surrogate_key_name, surrogate_key_value = surrogate_key
+ surrogate_key_col = mapper.attrs[surrogate_key_name].expression
+
+ rowcount = query.\
+ filter(surrogate_key_col == surrogate_key_value).\
+ update(values, synchronize_session=False)
+
+ _assert_single_row(rowcount)
+ # SELECT my_table.id AS my_table_id FROM my_table
+ # WHERE my_table.y = ? AND my_table.z = ?
+ # LIMIT ? OFFSET ?
+ fetch_query = query.session.query(
+ *mapper.primary_key).filter(
+ surrogate_key_col == surrogate_key_value)
+
+ primary_key = fetch_query.one()
+
+ return primary_key
+
+
+def _pk_strategy_returning(query, mapper, values, surrogate_key):
+ surrogate_key_name, surrogate_key_value = surrogate_key
+ surrogate_key_col = mapper.attrs[surrogate_key_name].expression
+
+ update_stmt = _update_stmt_from_query(mapper, query, values)
+ update_stmt = update_stmt.where(surrogate_key_col == surrogate_key_value)
+ update_stmt = update_stmt.returning(*mapper.primary_key)
+
+ # UPDATE my_table SET x=%(x)s, z=%(z)s WHERE my_table.y = %(y_1)s
+ # AND my_table.z = %(z_1)s RETURNING my_table.id
+ result = query.session.execute(update_stmt)
+ rowcount = result.rowcount
+ _assert_single_row(rowcount)
+ primary_key = tuple(result.first())
+
+ return primary_key
+
+
+def _pk_strategy_mysql_last_insert_id(query, mapper, values, surrogate_key):
+
+ surrogate_key_name, surrogate_key_value = surrogate_key
+ surrogate_key_col = mapper.attrs[surrogate_key_name].expression
+
+ surrogate_pk_col = mapper.primary_key[0]
+ update_stmt = _update_stmt_from_query(mapper, query, values)
+ update_stmt = update_stmt.where(surrogate_key_col == surrogate_key_value)
+ update_stmt = update_stmt.values(
+ {surrogate_pk_col: sql.func.last_insert_id(surrogate_pk_col)})
+
+ # UPDATE my_table SET id=last_insert_id(my_table.id),
+ # x=%s, z=%s WHERE my_table.y = %s AND my_table.z = %s
+ result = query.session.execute(update_stmt)
+ rowcount = result.rowcount
+ _assert_single_row(rowcount)
+ # SELECT last_insert_id() AS last_insert_id_1
+ primary_key = query.session.scalar(sql.func.last_insert_id()),
+
+ return primary_key
+
+
+def _update_stmt_from_query(mapper, query, values):
+ upd_values = dict(
+ (
+ mapper.column_attrs[key], value
+ ) for key, value in values.items()
+ )
+ query = query.enable_eagerloads(False)
+ context = query._compile_context()
+ primary_table = context.statement.froms[0]
+ update_stmt = sql.update(primary_table,
+ context.whereclause,
+ upd_values)
+ return update_stmt
+
+
+_none_set = frozenset([None])
+
+
+class CantUpdateException(Exception):
+ pass
+
+
+class NoRowsMatched(CantUpdateException):
+ pass
+
+
+class MultiRowsMatched(CantUpdateException):
+ pass
diff --git a/oslo_db/sqlalchemy/utils.py b/oslo_db/sqlalchemy/utils.py
index 932d409..e7c5b4a 100644
--- a/oslo_db/sqlalchemy/utils.py
+++ b/oslo_db/sqlalchemy/utils.py
@@ -188,6 +188,18 @@ def paginate_query(query, model, limit, sort_keys, marker=None,
return query
+def to_list(x, default=None):
+ if x is None:
+ return default
+ if not isinstance(x, collections.Iterable) or \
+ isinstance(x, six.string_types):
+ return [x]
+ elif isinstance(x, list):
+ return x
+ else:
+ return list(x)
+
+
def _read_deleted_filter(query, db_model, deleted):
if 'deleted' not in db_model.__table__.columns:
raise ValueError(_("There is no `deleted` column in `%s` table. "
diff --git a/oslo_db/tests/sqlalchemy/test_update_match.py b/oslo_db/tests/sqlalchemy/test_update_match.py
new file mode 100644
index 0000000..ecc7af7
--- /dev/null
+++ b/oslo_db/tests/sqlalchemy/test_update_match.py
@@ -0,0 +1,445 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+from oslotest import base as oslo_test_base
+from sqlalchemy.ext import declarative
+from sqlalchemy import schema
+from sqlalchemy import sql
+from sqlalchemy import types as sqltypes
+
+from oslo_db.sqlalchemy import test_base
+from oslo_db.sqlalchemy import update_match
+
+Base = declarative.declarative_base()
+
+
+class MyModel(Base):
+ __tablename__ = 'my_table'
+
+ id = schema.Column(sqltypes.Integer, primary_key=True)
+ uuid = schema.Column(sqltypes.String(36), nullable=False, unique=True)
+ x = schema.Column(sqltypes.Integer)
+ y = schema.Column(sqltypes.String(40))
+ z = schema.Column(sqltypes.String(40))
+
+
+class ManufactureCriteriaTest(oslo_test_base.BaseTestCase):
+ def test_instance_criteria_basic(self):
+ specimen = MyModel(
+ y='y1', z='z3',
+ uuid='136254d5-3869-408f-9da7-190e0072641a'
+ )
+ self.assertEqual(
+ "my_table.uuid = :uuid_1 AND my_table.y = :y_1 "
+ "AND my_table.z = :z_1",
+ str(update_match.manufacture_entity_criteria(specimen).compile())
+ )
+
+ def test_instance_criteria_basic_wnone(self):
+ specimen = MyModel(
+ y='y1', z=None,
+ uuid='136254d5-3869-408f-9da7-190e0072641a'
+ )
+ self.assertEqual(
+ "my_table.uuid = :uuid_1 AND my_table.y = :y_1 "
+ "AND my_table.z IS NULL",
+ str(update_match.manufacture_entity_criteria(specimen).compile())
+ )
+
+ def test_instance_criteria_tuples(self):
+ specimen = MyModel(
+ y='y1', z=('z1', 'z2'),
+ )
+ self.assertEqual(
+ "my_table.y = :y_1 AND my_table.z IN (:z_1, :z_2)",
+ str(update_match.manufacture_entity_criteria(specimen).compile())
+ )
+
+ def test_instance_criteria_tuples_wnone(self):
+ specimen = MyModel(
+ y='y1', z=('z1', 'z2', None),
+ )
+ self.assertEqual(
+ "my_table.y = :y_1 AND (my_table.z IS NULL OR "
+ "my_table.z IN (:z_1, :z_2))",
+ str(update_match.manufacture_entity_criteria(specimen).compile())
+ )
+
+ def test_instance_criteria_none_list(self):
+ specimen = MyModel(
+ y='y1', z=[None],
+ )
+ self.assertEqual(
+ "my_table.y = :y_1 AND my_table.z IS NULL",
+ str(update_match.manufacture_entity_criteria(specimen).compile())
+ )
+
+
+class UpdateMatchTest(test_base.DbTestCase):
+ def setUp(self):
+ super(UpdateMatchTest, self).setUp()
+ Base.metadata.create_all(self.engine)
+ self.addCleanup(Base.metadata.drop_all, self.engine)
+ # self.engine.echo = 'debug'
+ self.session = self.sessionmaker(autocommit=False)
+ self.addCleanup(self.session.close)
+ self.session.add_all([
+ MyModel(
+ id=1,
+ uuid='23cb9224-9f8e-40fe-bd3c-e7577b7af37d',
+ x=5, y='y1', z='z1'),
+ MyModel(
+ id=2,
+ uuid='136254d5-3869-408f-9da7-190e0072641a',
+ x=6, y='y1', z='z2'),
+ MyModel(
+ id=3,
+ uuid='094eb162-d5df-494b-a458-a91a1b2d2c65',
+ x=7, y='y1', z='z1'),
+ MyModel(
+ id=4,
+ uuid='94659b3f-ea1f-4ffd-998d-93b28f7f5b70',
+ x=8, y='y2', z='z2'),
+ MyModel(
+ id=5,
+ uuid='bdf3893c-ee3c-40a0-bc79-960adb6cd1d4',
+ x=8, y='y2', z=None),
+ ])
+
+ self.session.commit()
+
+ def _assert_row(self, pk, values):
+ row = self.session.execute(
+ sql.select([MyModel.__table__]).where(MyModel.__table__.c.id == pk)
+ ).first()
+ values['id'] = pk
+ self.assertEqual(values, dict(row))
+
+ def test_update_specimen_successful(self):
+ uuid = '136254d5-3869-408f-9da7-190e0072641a'
+
+ specimen = MyModel(
+ y='y1', z='z2', uuid=uuid
+ )
+
+ result = self.session.query(MyModel).update_on_match(
+ specimen,
+ 'uuid',
+ values={'x': 9, 'z': 'z3'}
+ )
+
+ self.assertEqual(uuid, result.uuid)
+ self.assertEqual(2, result.id)
+ self.assertEqual('z3', result.z)
+ self.assertIn(result, self.session)
+
+ self._assert_row(
+ 2,
+ {
+ 'uuid': '136254d5-3869-408f-9da7-190e0072641a',
+ 'x': 9, 'y': 'y1', 'z': 'z3'
+ }
+ )
+
+ def test_update_specimen_include_only(self):
+ uuid = '136254d5-3869-408f-9da7-190e0072641a'
+
+ specimen = MyModel(
+ y='y9', z='z5', x=6, uuid=uuid
+ )
+
+ # Query the object first to test that we merge when the object is
+ # already cached in the session.
+ self.session.query(MyModel).filter(MyModel.uuid == uuid).one()
+
+ result = self.session.query(MyModel).update_on_match(
+ specimen,
+ 'uuid',
+ values={'x': 9, 'z': 'z3'},
+ include_only=('x', )
+ )
+
+ self.assertEqual(uuid, result.uuid)
+ self.assertEqual(2, result.id)
+ self.assertEqual('z3', result.z)
+ self.assertIn(result, self.session)
+ self.assertNotIn(result, self.session.dirty)
+
+ self._assert_row(
+ 2,
+ {
+ 'uuid': '136254d5-3869-408f-9da7-190e0072641a',
+ 'x': 9, 'y': 'y1', 'z': 'z3'
+ }
+ )
+
+ def test_update_specimen_no_rows(self):
+ specimen = MyModel(
+ y='y1', z='z3',
+ uuid='136254d5-3869-408f-9da7-190e0072641a'
+ )
+
+ exc = self.assertRaises(
+ update_match.NoRowsMatched,
+ self.session.query(MyModel).update_on_match,
+ specimen, 'uuid', values={'x': 9, 'z': 'z3'}
+ )
+
+ self.assertEqual("Zero rows matched for 3 attempts", exc.args[0])
+
+ def test_update_specimen_process_query_no_rows(self):
+ specimen = MyModel(
+ y='y1', z='z2',
+ uuid='136254d5-3869-408f-9da7-190e0072641a'
+ )
+
+ def process_query(query):
+ return query.filter_by(x=10)
+
+ exc = self.assertRaises(
+ update_match.NoRowsMatched,
+ self.session.query(MyModel).update_on_match,
+ specimen, 'uuid', values={'x': 9, 'z': 'z3'},
+ process_query=process_query
+ )
+
+ self.assertEqual("Zero rows matched for 3 attempts", exc.args[0])
+
+ def test_update_specimen_given_query_no_rows(self):
+ specimen = MyModel(
+ y='y1', z='z2',
+ uuid='136254d5-3869-408f-9da7-190e0072641a'
+ )
+
+ query = self.session.query(MyModel).filter_by(x=10)
+
+ exc = self.assertRaises(
+ update_match.NoRowsMatched,
+ query.update_on_match,
+ specimen, 'uuid', values={'x': 9, 'z': 'z3'},
+ )
+
+ self.assertEqual("Zero rows matched for 3 attempts", exc.args[0])
+
+ def test_update_specimen_multi_rows(self):
+ specimen = MyModel(
+ y='y1', z='z1',
+ )
+
+ exc = self.assertRaises(
+ update_match.MultiRowsMatched,
+ self.session.query(MyModel).update_on_match,
+ specimen, 'y', values={'x': 9, 'z': 'z3'}
+ )
+
+ self.assertEqual("2 rows matched; expected one", exc.args[0])
+
+ def test_update_specimen_query_mismatch_error(self):
+ specimen = MyModel(
+ y='y1'
+ )
+ q = self.session.query(MyModel.x, MyModel.y)
+ exc = self.assertRaises(
+ AssertionError,
+ q.update_on_match,
+ specimen, 'y', values={'x': 9, 'z': 'z3'},
+ )
+
+ self.assertEqual("Query does not match given specimen", exc.args[0])
+
+ def test_custom_handle_failure_raise_new(self):
+ class MyException(Exception):
+ pass
+
+ def handle_failure(query):
+ # ensure the query is usable
+ result = query.count()
+ self.assertEqual(0, result)
+
+ raise MyException("test: %d" % result)
+
+ specimen = MyModel(
+ y='y1', z='z3',
+ uuid='136254d5-3869-408f-9da7-190e0072641a'
+ )
+
+ exc = self.assertRaises(
+ MyException,
+ self.session.query(MyModel).update_on_match,
+ specimen, 'uuid', values={'x': 9, 'z': 'z3'},
+ handle_failure=handle_failure
+ )
+
+ self.assertEqual("test: 0", exc.args[0])
+
+ def test_custom_handle_failure_cancel_raise(self):
+ uuid = '136254d5-3869-408f-9da7-190e0072641a'
+
+ class MyException(Exception):
+ pass
+
+ def handle_failure(query):
+ # ensure the query is usable
+ result = query.count()
+ self.assertEqual(0, result)
+
+ return True
+
+ specimen = MyModel(
+ id=2, y='y1', z='z3', uuid=uuid
+ )
+
+ result = self.session.query(MyModel).update_on_match(
+ specimen, 'uuid', values={'x': 9, 'z': 'z3'},
+ handle_failure=handle_failure
+ )
+ self.assertEqual(uuid, result.uuid)
+ self.assertEqual(2, result.id)
+ self.assertEqual('z3', result.z)
+ self.assertEqual(9, result.x)
+ self.assertIn(result, self.session)
+
+ def test_update_specimen_on_none_successful(self):
+ uuid = 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4'
+
+ specimen = MyModel(
+ y='y2', z=None, uuid=uuid
+ )
+
+ result = self.session.query(MyModel).update_on_match(
+ specimen,
+ 'uuid',
+ values={'x': 9, 'z': 'z3'},
+ )
+
+ self.assertIn(result, self.session)
+ self.assertEqual(uuid, result.uuid)
+ self.assertEqual(5, result.id)
+ self.assertEqual('z3', result.z)
+ self._assert_row(
+ 5,
+ {
+ 'uuid': 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4',
+ 'x': 9, 'y': 'y2', 'z': 'z3'
+ }
+ )
+
+ def test_update_specimen_on_multiple_nonnone_successful(self):
+ uuid = '094eb162-d5df-494b-a458-a91a1b2d2c65'
+
+ specimen = MyModel(
+ y=('y1', 'y2'), x=(5, 7), uuid=uuid
+ )
+
+ result = self.session.query(MyModel).update_on_match(
+ specimen,
+ 'uuid',
+ values={'x': 9, 'z': 'z3'},
+ )
+
+ self.assertIn(result, self.session)
+ self.assertEqual(uuid, result.uuid)
+ self.assertEqual(3, result.id)
+ self.assertEqual('z3', result.z)
+ self._assert_row(
+ 3,
+ {
+ 'uuid': '094eb162-d5df-494b-a458-a91a1b2d2c65',
+ 'x': 9, 'y': 'y1', 'z': 'z3'
+ }
+ )
+
+ def test_update_specimen_on_multiple_wnone_successful(self):
+ uuid = 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4'
+ specimen = MyModel(
+ y=('y1', 'y2'), x=(8, 7), z=('z1', 'z2', None), uuid=uuid
+ )
+
+ result = self.session.query(MyModel).update_on_match(
+ specimen,
+ 'uuid',
+ values={'x': 9, 'z': 'z3'},
+ )
+
+ self.assertIn(result, self.session)
+ self.assertEqual(uuid, result.uuid)
+ self.assertEqual(5, result.id)
+ self.assertEqual('z3', result.z)
+ self._assert_row(
+ 5,
+ {
+ 'uuid': 'bdf3893c-ee3c-40a0-bc79-960adb6cd1d4',
+ 'x': 9, 'y': 'y2', 'z': 'z3'
+ }
+ )
+
+ def test_update_returning_pk_matched(self):
+ pk = self.session.query(MyModel).\
+ filter_by(y='y1', z='z2').update_returning_pk(
+ {'x': 9, 'z': 'z3'},
+ ('uuid', '136254d5-3869-408f-9da7-190e0072641a')
+ )
+
+ self.assertEqual((2,), pk)
+ self._assert_row(
+ 2,
+ {
+ 'uuid': '136254d5-3869-408f-9da7-190e0072641a',
+ 'x': 9, 'y': 'y1', 'z': 'z3'
+ }
+ )
+
+ def test_update_returning_wrong_uuid(self):
+ exc = self.assertRaises(
+ update_match.NoRowsMatched,
+ self.session.query(MyModel).
+ filter_by(y='y1', z='z2').update_returning_pk,
+ {'x': 9, 'z': 'z3'},
+ ('uuid', '23cb9224-9f8e-40fe-bd3c-e7577b7af37d')
+ )
+
+ self.assertEqual("No rows matched the UPDATE", exc.args[0])
+
+ def test_update_returning_no_rows(self):
+ exc = self.assertRaises(
+ update_match.NoRowsMatched,
+ self.session.query(MyModel).
+ filter_by(y='y1', z='z3').update_returning_pk,
+ {'x': 9, 'z': 'z3'},
+ ('uuid', '136254d5-3869-408f-9da7-190e0072641a')
+ )
+
+ self.assertEqual("No rows matched the UPDATE", exc.args[0])
+
+ def test_update_multiple_rows(self):
+ exc = self.assertRaises(
+ update_match.MultiRowsMatched,
+ self.session.query(MyModel).
+ filter_by(y='y1', z='z1').update_returning_pk,
+ {'x': 9, 'z': 'z3'},
+ ('y', 'y1')
+ )
+
+ self.assertEqual("2 rows matched; expected one", exc.args[0])
+
+
+class PGUpdateMatchTest(
+ UpdateMatchTest,
+ test_base.PostgreSQLOpportunisticTestCase):
+ pass
+
+
+class MySQLUpdateMatchTest(
+ UpdateMatchTest,
+ test_base.MySQLOpportunisticTestCase):
+ pass