diff options
Diffstat (limited to 'oslo_db/sqlalchemy/update_match.py')
-rw-r--r-- | oslo_db/sqlalchemy/update_match.py | 508 |
1 files changed, 508 insertions, 0 deletions
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import copy

from sqlalchemy import inspect
from sqlalchemy import orm
from sqlalchemy import sql
from sqlalchemy import types as sqltypes

from oslo_db.sqlalchemy import utils


def update_on_match(
    query,
    specimen,
    surrogate_key,
    values=None,
    attempts=3,
    include_only=None,
    process_query=None,
    handle_failure=None
):
    """Emit an UPDATE statement matching the given specimen.

    E.g.::

        with enginefacade.writer() as session:
            specimen = MyInstance(
                uuid='ccea54f',
                interface_id='ad33fea',
                vm_state='SOME_VM_STATE',
            )

            values = {
                'vm_state': 'SOME_NEW_VM_STATE'
            }

            base_query = model_query(
                context, models.Instance,
                project_only=True, session=session)

            hostname_query = (
                model_query(
                    context, models.Instance, session=session,
                    read_deleted='no')
                .filter(
                    func.lower(models.Instance.hostname) == 'SOMEHOSTNAME')
            )

            surrogate_key = 'uuid'

            def process_query(query):
                return query.where(~exists(hostname_query))

            def handle_failure(query):
                try:
                    instance = base_query.one()
                except NoResultFound:
                    raise exception.InstanceNotFound(instance_id=instance_uuid)

                if session.query(hostname_query.exists()).scalar():
                    raise exception.InstanceExists(
                        name=values['hostname'].lower())

                # try again
                return False

            persistent_instance = base_query.update_on_match(
                specimen,
                surrogate_key,
                values=values,
                process_query=process_query,
                handle_failure=handle_failure
            )

    The UPDATE statement is constructed against the given specimen
    using those values which are present to construct a WHERE clause.
    If the specimen contains additional values to be ignored, the
    ``include_only`` parameter may be passed which indicates a sequence
    of attributes to use when constructing the WHERE.

    The UPDATE is performed against the ORM Query passed as ``query``,
    which must select exactly the mapped class of the specimen.

    Before the query is invoked, it is also passed through the callable
    sent as ``process_query``, if present.  This hook allows additional
    criteria to be added to the query after it is created but before
    invocation.

    The function will then invoke the UPDATE statement and check for
    "success" one or more times, up to a maximum of that passed as
    ``attempts``.

    The initial check for "success" from the UPDATE statement is that the
    number of rows matched is exactly 1.  If zero rows are matched, then
    the UPDATE statement is assumed to have "failed", and the failure handling
    phase begins.  If more than one row is matched, ``MultiRowsMatched`` is
    raised immediately and no retry takes place.

    The failure handling phase involves invoking the given ``handle_failure``
    function, if any.  This handler can perform additional queries to attempt
    to figure out why the UPDATE didn't match any rows.  The handler,
    upon detection of the exact failure condition, should throw an exception
    to exit; if it doesn't, it has the option of returning True or False,
    where False means the error was not handled (the UPDATE is attempted
    again), and True means that there was not in fact an error, and the
    function should return successfully.

    If the failure handler is not present, or returns False after ``attempts``
    number of attempts, then the function overall raises ``NoRowsMatched``
    (a subclass of ``CantUpdateException``).
    If the handler returns True, then the function returns with no error.

    The return value of the function is a persistent object of the same
    class as the given specimen.  A new instance of that class is created
    and populated with the given ``values``, the surrogate key value and the
    primary key; if an object with the same identity is already present in
    the session, that state is merged into the existing object, which is
    then returned instead.

    The object is returned as "persistent", meaning that it is
    associated with the Session of the given query and has an identity key
    (that is, a real primary key value).

    In order to produce this identity key, a strategy must be used to
    determine it as efficiently and safely as possible:

    1. If the UPDATE did not report a primary key (that is, the failure
       handler declared success), the primary key attributes of the
       specimen itself are used.

    2. If the target backend supports RETURNING, then the update() query
       is performed with a RETURNING clause so that the matching primary key
       is returned atomically.  This currently includes Postgresql, Oracle
       and others (notably not MySQL or SQLite).

    3. If the target backend is MySQL, and the given model uses a
       single-column, AUTO_INCREMENT integer primary key value (as is
       the case for Nova), MySQL's recommended approach of making use
       of ``LAST_INSERT_ID(expr)`` is used to atomically acquire the
       matching primary key value within the scope of the UPDATE
       statement, then it is fetched immediately following by using
       ``SELECT LAST_INSERT_ID()``.
       http://dev.mysql.com/doc/refman/5.0/en/information-\
       functions.html#function_last-insert-id

    4. Otherwise, for composite keys on MySQL or other backends such
       as SQLite, the row as UPDATED must be re-fetched in order to
       acquire the primary key value.  The ``surrogate_key``
       parameter is used for this in order to re-fetch the row; this
       is a column name with a known, unique value where
       the object can be fetched.

    :param query: an ORM Query against the mapped class of ``specimen``.

    :param specimen: a mapped instance whose populated attributes form the
     WHERE criteria.

    :param surrogate_key: name of a mapped attribute on the specimen whose
     value identifies the row.

    :param values: optional dictionary of values to SET on the row.

    :param attempts: maximum number of times the UPDATE is tried.

    :param include_only: optional sequence of attribute names to use for
     the WHERE criteria.

    :param process_query: optional callable receiving and returning the
     query before it is invoked.

    :param handle_failure: optional callable receiving the query when no
     rows matched.

    :raises AssertionError: if the query is not against the specimen's class.

    :raises NoRowsMatched: if no row matched within ``attempts`` attempts.

    :raises MultiRowsMatched: if more than one row matched.

    """

    if values is None:
        values = {}

    entity = inspect(specimen)
    mapper = entity.mapper

    # Raised explicitly rather than through an ``assert`` statement so that
    # the check is not stripped when Python runs with -O; the exception
    # type is unchanged.
    query_entities = [desc['type'] for desc in query.column_descriptions]
    if query_entities != [mapper.class_]:
        raise AssertionError("Query does not match given specimen")

    criteria = manufacture_entity_criteria(
        specimen, include_only=include_only, exclude=[surrogate_key])

    query = query.filter(criteria)

    if process_query:
        query = process_query(query)

    surrogate_key_arg = (
        surrogate_key, entity.attrs[surrogate_key].loaded_value)
    pk_value = None

    for _ in range(attempts):
        try:
            pk_value = query.update_returning_pk(values, surrogate_key_arg)
        except MultiRowsMatched:
            # matching several rows is never retried
            raise
        except NoRowsMatched:
            # a truthy handler result means "not actually an error";
            # anything else falls through to the next attempt
            if handle_failure and handle_failure(query):
                break
        else:
            break
    else:
        raise NoRowsMatched("Zero rows matched for %d attempts" % attempts)

    if pk_value is None:
        # only reached when the failure handler declared success
        pk_value = entity.mapper.primary_key_from_instance(specimen)

    # NOTE(mdbooth): Can't pass the original specimen object here as it might
    # have lists of multiple potential values rather than actual values.
    values = copy.copy(values)
    values[surrogate_key] = surrogate_key_arg[1]
    persistent_obj = manufacture_persistent_object(
        query.session, specimen.__class__(), values, pk_value)

    return persistent_obj


def manufacture_persistent_object(
        session, specimen, values=None, primary_key=None):
    """Make an ORM-mapped object persistent in a Session without SQL.

    The persistent object is returned.

    If a matching object is already present in the given session, the specimen
    is merged into it and the persistent object returned.  Otherwise, the
    specimen itself is made persistent and is returned.

    The object must contain a full primary key, or provide it via the values or
    primary_key parameters.  The object is persisted to the Session in a
    "clean" state with no pending changes.

    :param session: A Session object.

    :param specimen: a mapped object which is typically transient.

    :param values: optional dictionary of values to be applied to the
     specimen, in addition to the state that's already on it.  The attributes
     will be set such that no history is created; the object remains clean.

    :param primary_key: optional tuple-based primary key.  This will also
     be applied to the instance if present.

    :raises ValueError: if any primary key attribute has no value.

    """
    state = inspect(specimen)
    mapper = state.mapper

    # ``values`` defaults to None, so it must not be iterated
    # unconditionally.
    if values is not None:
        for k, v in values.items():
            orm.attributes.set_committed_value(specimen, k, v)

    pk_attrs = [
        mapper.get_property_by_column(col).key
        for col in mapper.primary_key
    ]

    if primary_key is not None:
        for key, value in zip(pk_attrs, primary_key):
            orm.attributes.set_committed_value(
                specimen,
                key,
                value
            )

    for key in pk_attrs:
        if state.attrs[key].loaded_value is orm.attributes.NO_VALUE:
            raise ValueError("full primary key must be present")

    orm.make_transient_to_detached(specimen)

    if state.key not in session.identity_map:
        session.add(specimen)
        return specimen
    else:
        return session.merge(specimen, load=False)


def manufacture_entity_criteria(entity, include_only=None, exclude=None):
    """Given a mapped instance, produce a WHERE clause.

    The attributes set upon the instance will be combined to produce
    a SQL expression using the mapped SQL expressions as the base
    of comparison.

    Values on the instance may be set as tuples in which case the
    criteria will produce an IN clause.  None is also acceptable as a
    scalar or tuple entry, which will produce IS NULL that is properly
    joined with an OR against an IN expression if appropriate.

    :param entity: a mapped entity.

    :param include_only: optional sequence of keys to limit which
     keys are included.

    :param exclude: sequence of keys to exclude

    :return: the SQL expression produced by :func:`manufacture_criteria`.

    """

    state = inspect(entity)
    exclude = set(exclude) if exclude is not None else set()

    # only attributes that actually carry a value take part
    existing = {
        attr.key: attr.loaded_value
        for attr in state.attrs
        if attr.loaded_value is not orm.attributes.NO_VALUE
        and attr.key not in exclude
    }
    if include_only:
        existing = {
            k: existing[k]
            for k in set(existing).intersection(include_only)
        }

    return manufacture_criteria(state.mapper, existing)


def manufacture_criteria(mapped, values):
    """Given a mapper/class and a namespace of values, produce a WHERE clause.

    The class should be a mapped class and the entries in the dictionary
    correspond to mapped attribute names on the class.

    A value may also be a tuple in which case that particular attribute
    will be compared to a tuple using IN.   The scalar value or
    tuple can also contain None which translates to an IS NULL, that is
    properly joined with OR against an IN expression if appropriate.

    :param mapped: a mapped class, or actual :class:`.Mapper` object.

    :param values: dictionary of values.

    :return: the per-attribute criteria joined with AND.

    """

    mapper = inspect(mapped)

    # organize keys using mapped attribute ordering, which is deterministic
    value_keys = set(values)
    keys = [k for k in mapper.column_attrs.keys() if k in value_keys]
    return sql.and_(*[
        _sql_crit(mapper.column_attrs[key].expression, values[key])
        for key in keys
    ])


def _sql_crit(expression, value):
    """Produce an equality expression against the given value.

    This takes into account a value that is actually a collection
    of values, as well as a value of None or collection that contains
    None.

    :param expression: the SQL expression to compare.

    :param value: a scalar, None, or a collection of scalars and/or None.

    :return: an ``=``, ``IS NULL``, ``IN`` or ``IS NULL OR ...`` expression.

    """

    values = utils.to_list(value, default=(None, ))
    if len(values) == 1:
        if values[0] is None:
            return expression == sql.null()
        else:
            return expression == values[0]
    elif _none_set.intersection(values):
        # De-duplicate the non-NULL values while keeping first-seen order,
        # so the rendered IN list does not depend on set iteration order.
        non_null = list(dict.fromkeys(v for v in values if v is not None))
        if not non_null:
            # a collection holding nothing but None is simply IS NULL
            return expression == sql.null()
        return sql.or_(
            expression == sql.null(),
            _sql_crit(expression, non_null)
        )
    else:
        return expression.in_(values)


def update_returning_pk(query, values, surrogate_key):
    """Perform an UPDATE, returning the primary key of the matched row.

    The primary key is returned using a selection of strategies:

    * if the database supports RETURNING, RETURNING is used to retrieve
      the primary key values inline.

    * If the database is MySQL and the entity is mapped to a single integer
      primary key column, MySQL's last_insert_id() function is used
      inline within the UPDATE and then upon a second SELECT to get the
      value.

    * Otherwise, a "refetch" strategy is used, where a given "surrogate"
      key value (typically a UUID column on the entity) is used to run
      a new SELECT against that UUID.   This UUID is also placed into
      the UPDATE query to ensure the row matches.

    :param query: a Query object with existing criterion, against a single
     entity.

    :param values: a dictionary of values to be updated on the row.

    :param surrogate_key: a tuple of (attrname, value), referring to a
     UNIQUE attribute that will also match the row.  This attribute is used
     to retrieve the row via a SELECT when no optimized strategy exists.

    :return: the primary key, returned as a tuple.
     Is only returned if rows matched is one.  Otherwise, a
     CantUpdateException subclass is raised.

    :raises NoRowsMatched: if the UPDATE matched no rows.

    :raises MultiRowsMatched: if the UPDATE matched more than one row.

    """

    entity = query.column_descriptions[0]['type']
    mapper = inspect(entity).mapper
    session = query.session

    bind = session.connection(mapper=mapper)
    if bind.dialect.implicit_returning:
        pk_strategy = _pk_strategy_returning
    elif bind.dialect.name == 'mysql' and \
            len(mapper.primary_key) == 1 and \
            isinstance(
                mapper.primary_key[0].type, sqltypes.Integer):
        pk_strategy = _pk_strategy_mysql_last_insert_id
    else:
        pk_strategy = _pk_strategy_refetch

    return pk_strategy(query, mapper, values, surrogate_key)


def _assert_single_row(rows_updated):
    """Return ``rows_updated`` if it is exactly one, otherwise raise.

    :raises MultiRowsMatched: if more than one row was updated.

    :raises NoRowsMatched: if fewer than one row was updated.

    """
    if rows_updated == 1:
        return rows_updated
    elif rows_updated > 1:
        raise MultiRowsMatched("%d rows matched; expected one" % rows_updated)
    else:
        raise NoRowsMatched("No rows matched the UPDATE")


def _pk_strategy_refetch(query, mapper, values, surrogate_key):
    """UPDATE the row, then SELECT its primary key via the surrogate key.

    :return: the primary key as a tuple.

    """

    surrogate_key_name, surrogate_key_value = surrogate_key
    surrogate_key_col = mapper.attrs[surrogate_key_name].expression

    rowcount = query.\
        filter(surrogate_key_col == surrogate_key_value).\
        update(values, synchronize_session=False)

    _assert_single_row(rowcount)
    # The refetch selects only the primary key column(s), filtered on the
    # surrogate key alone, e.g.:
    # SELECT my_table.id AS my_table_id FROM my_table
    # WHERE my_table.uuid = ?
    fetch_query = query.session.query(
        *mapper.primary_key).filter(
            surrogate_key_col == surrogate_key_value)

    # plain tuple, consistent with the other primary key strategies
    primary_key = tuple(fetch_query.one())

    return primary_key


def _pk_strategy_returning(query, mapper, values, surrogate_key):
    """UPDATE the row using RETURNING to obtain its primary key inline.

    :return: the primary key as a tuple.

    """
    surrogate_key_name, surrogate_key_value = surrogate_key
    surrogate_key_col = mapper.attrs[surrogate_key_name].expression

    update_stmt = _update_stmt_from_query(mapper, query, values)
    update_stmt = update_stmt.where(surrogate_key_col == surrogate_key_value)
    update_stmt = update_stmt.returning(*mapper.primary_key)

    # UPDATE my_table SET x=%(x)s, z=%(z)s WHERE my_table.y = %(y_1)s
    # AND my_table.z = %(z_1)s RETURNING my_table.id
    result = query.session.execute(update_stmt)
    rowcount = result.rowcount
    _assert_single_row(rowcount)
    primary_key = tuple(result.first())

    return primary_key


def _pk_strategy_mysql_last_insert_id(query, mapper, values, surrogate_key):
    """UPDATE the row, capturing its primary key with LAST_INSERT_ID().

    Only the first primary key column of the mapper is used.

    :return: the primary key as a one-element tuple.

    """

    surrogate_key_name, surrogate_key_value = surrogate_key
    surrogate_key_col = mapper.attrs[surrogate_key_name].expression

    surrogate_pk_col = mapper.primary_key[0]
    update_stmt = _update_stmt_from_query(mapper, query, values)
    update_stmt = update_stmt.where(surrogate_key_col == surrogate_key_value)
    update_stmt = update_stmt.values(
        {surrogate_pk_col: sql.func.last_insert_id(surrogate_pk_col)})

    # UPDATE my_table SET id=last_insert_id(my_table.id),
    # x=%s, z=%s WHERE my_table.y = %s AND my_table.z = %s
    result = query.session.execute(update_stmt)
    rowcount = result.rowcount
    _assert_single_row(rowcount)
    # SELECT last_insert_id() AS last_insert_id_1
    # (the trailing comma is deliberate: the result is a one-element tuple)
    primary_key = query.session.scalar(sql.func.last_insert_id()),

    return primary_key


def _update_stmt_from_query(mapper, query, values):
    """Build a Core UPDATE statement from an ORM query and SET values.

    The statement targets the first FROM element of the compiled query and
    reuses the query's WHERE clause.

    :param mapper: mapper whose ``column_attrs`` are looked up by the keys
     of ``values``.

    :param query: the ORM query supplying the table and WHERE clause.

    :param values: dictionary of attribute name to new value.

    """
    upd_values = {
        mapper.column_attrs[key]: value
        for key, value in values.items()
    }
    query = query.enable_eagerloads(False)
    context = query._compile_context()
    primary_table = context.statement.froms[0]
    update_stmt = sql.update(primary_table,
                             context.whereclause,
                             upd_values)
    return update_stmt


_none_set = frozenset([None])


class CantUpdateException(Exception):
    """Base class for failures of the UPDATE-on-match operation."""


class NoRowsMatched(CantUpdateException):
    """The UPDATE matched no rows."""


class MultiRowsMatched(CantUpdateException):
    """The UPDATE matched more than one row."""