From 857b4c4e63d9ff8b17174f429b7fd285fabc903a Mon Sep 17 00:00:00 2001 From: Erik Olof Gunnar Andersson Date: Mon, 25 Jul 2022 23:23:46 -0700 Subject: Re-factored central and rpc decorators - Moved central and rpc decorators to common location. - Cleaned up decorator code. Change-Id: I79d21df7d17a2f706b8747e600e79a1ef1762e2b --- designate/central/service.py | 223 +++++---------------- designate/common/decorators/__init__.py | 0 designate/common/decorators/lock.py | 107 ++++++++++ designate/common/decorators/notification.py | 90 +++++++++ designate/common/decorators/rpc.py | 49 +++++ designate/context.py | 16 -- designate/rpc.py | 27 --- designate/service.py | 2 + designate/tests/test_central/test_decorator.py | 75 ------- designate/tests/unit/test_central/test_basic.py | 25 +-- .../tests/unit/test_central/test_lock_decorator.py | 111 ++++++++++ designate/worker/service.py | 2 +- 12 files changed, 419 insertions(+), 308 deletions(-) create mode 100644 designate/common/decorators/__init__.py create mode 100644 designate/common/decorators/lock.py create mode 100644 designate/common/decorators/notification.py create mode 100644 designate/common/decorators/rpc.py delete mode 100644 designate/tests/test_central/test_decorator.py create mode 100644 designate/tests/unit/test_central/test_lock_decorator.py diff --git a/designate/central/service.py b/designate/central/service.py index 05173539..34e39338 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -14,16 +14,12 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -import collections import copy -import functools -import itertools import random from random import SystemRandom import re import signal import string -import threading import time from dns import exception as dnsexception @@ -33,16 +29,16 @@ from oslo_log import log as logging import oslo_messaging as messaging from designate.common import constants -from designate import context as dcontext +from designate.common.decorators import lock +from designate.common.decorators import notification +from designate.common.decorators import rpc from designate import coordination from designate import dnsutils from designate import exceptions from designate import network_api -from designate import notifications from designate import objects from designate import policy from designate import quota -from designate import rpc from designate import scheduler from designate import service from designate import storage @@ -51,135 +47,7 @@ from designate.storage import transaction_shallow_copy from designate import utils from designate.worker import rpcapi as worker_rpcapi - LOG = logging.getLogger(__name__) -ZONE_LOCKS = threading.local() -NOTIFICATION_BUFFER = threading.local() - - -def synchronized_zone(zone_arg=1, new_zone=False): - """Ensures only a single operation is in progress for each zone - - A Decorator which ensures only a single operation can be happening - on a single zone at once, within the current designate-central instance - """ - def outer(f): - @functools.wraps(f) - def sync_wrapper(self, *args, **kwargs): - if not hasattr(ZONE_LOCKS, 'held'): - # Create the held set if necessary - ZONE_LOCKS.held = set() - - zone_id = None - - if 'zone_id' in kwargs: - zone_id = kwargs['zone_id'] - elif 'zone' in kwargs: - zone_id = kwargs['zone'].id - elif 'recordset' in kwargs: - zone_id = kwargs['recordset'].zone_id - elif 'record' in kwargs: - zone_id = kwargs['record'].zone_id - - # The various objects won't always have an ID set, we should - # attempt to locate an Object containing the ID. - if zone_id is None: - for arg in itertools.chain(kwargs.values(), args): - if isinstance(arg, objects.Zone): - zone_id = arg.id - if zone_id: - break - elif (isinstance(arg, objects.RecordSet) or - isinstance(arg, objects.Record) or - isinstance(arg, objects.ZoneTransferRequest) or - isinstance(arg, objects.ZoneTransferAccept)): - zone_id = arg.zone_id - if zone_id: - break - - # If we still don't have an ID, find the Nth argument as - # defined by the zone_arg decorator option. - if not zone_id and len(args) > zone_arg: - zone_id = args[zone_arg] - if isinstance(zone_id, objects.Zone): - # If the value is a Zone object, extract it's ID. - zone_id = zone_id.id - - if new_zone and not zone_id: - lock_name = 'create-new-zone' - elif not new_zone and zone_id: - lock_name = 'zone-%s' % zone_id - else: - raise Exception('Failed to determine zone id for ' - 'synchronized operation') - - if zone_id in ZONE_LOCKS.held: - return f(self, *args, **kwargs) - - with self.coordination.get_lock(lock_name): - try: - ZONE_LOCKS.held.add(zone_id) - return f(self, *args, **kwargs) - finally: - ZONE_LOCKS.held.remove(zone_id) - - sync_wrapper.__wrapped_function = f - sync_wrapper.__wrapper_name = 'synchronized_zone' - return sync_wrapper - - return outer - - -def notification(notification_type): - def outer(f): - @functools.wraps(f) - def notification_wrapper(self, *args, **kwargs): - if not hasattr(NOTIFICATION_BUFFER, 'queue'): - # Create the notifications queue if necessary - NOTIFICATION_BUFFER.stack = 0 - NOTIFICATION_BUFFER.queue = collections.deque() - - NOTIFICATION_BUFFER.stack += 1 - - try: - # Find the context argument - context = dcontext.DesignateContext.\ - get_context_from_function_and_args(f, args, kwargs) - - # Call the wrapped function - result = f(self, *args, **kwargs) - - # Feed the args/result to a notification plugin - # to determine what is emitted - payloads = notifications.get_plugin().emit( - notification_type, context, result, args, kwargs) - - # Enqueue the notification - for payload in payloads: - LOG.debug('Queueing notification for %(type)s ', - {'type': notification_type}) - NOTIFICATION_BUFFER.queue.appendleft( - (context, notification_type, payload,)) - - return result - - finally: - NOTIFICATION_BUFFER.stack -= 1 - - if NOTIFICATION_BUFFER.stack == 0: - LOG.debug('Emitting %(count)d notifications', - {'count': len(NOTIFICATION_BUFFER.queue)}) - # Send the queued notifications, in order. - for value in NOTIFICATION_BUFFER.queue: - LOG.debug('Emitting %(type)s notification', - {'type': value[1]}) - self.notifier.info(value[0], value[1], value[2]) - - # Reset the queue - NOTIFICATION_BUFFER.queue.clear() - - return notification_wrapper - return outer class Service(service.RPCService): @@ -188,6 +56,9 @@ class Service(service.RPCService): target = messaging.Target(version=RPC_API_VERSION) def __init__(self): + self.zone_lock_local = lock.ZoneLockLocal() + self.notification_thread_local = notification.NotificationThreadLocal() + self._scheduler = None self._storage = None self._quota = None @@ -196,11 +67,9 @@ class Service(service.RPCService): self.service_name, cfg.CONF['service:central'].topic, threads=cfg.CONF['service:central'].threads, ) - self.coordination = coordination.Coordination( self.service_name, self.tg, grouping_enabled=False ) - self.network_api = network_api.get_network_api(cfg.CONF.network_api) @property @@ -713,7 +582,7 @@ class Service(service.RPCService): # TLD Methods @rpc.expected_exceptions() - @notification('dns.tld.create') + @notification.notify_type('dns.tld.create') @transaction def create_tld(self, context, tld): policy.check('create_tld', context) @@ -738,7 +607,7 @@ class Service(service.RPCService): return self.storage.get_tld(context, tld_id) @rpc.expected_exceptions() - @notification('dns.tld.update') + @notification.notify_type('dns.tld.update') @transaction def update_tld(self, context, tld): target = { @@ -751,7 +620,7 @@ class Service(service.RPCService): return tld @rpc.expected_exceptions() - @notification('dns.tld.delete') + @notification.notify_type('dns.tld.delete') @transaction def delete_tld(self, context, tld_id): policy.check('delete_tld', context, {'tld_id': tld_id}) @@ -762,7 +631,7 @@ class Service(service.RPCService): # TSIG Key Methods @rpc.expected_exceptions() - @notification('dns.tsigkey.create') + @notification.notify_type('dns.tsigkey.create') @transaction def create_tsigkey(self, context, tsigkey): policy.check('create_tsigkey', context) @@ -788,7 +657,7 @@ class Service(service.RPCService): return self.storage.get_tsigkey(context, tsigkey_id) @rpc.expected_exceptions() - @notification('dns.tsigkey.update') + @notification.notify_type('dns.tsigkey.update') @transaction def update_tsigkey(self, context, tsigkey): target = { @@ -803,7 +672,7 @@ class Service(service.RPCService): return tsigkey @rpc.expected_exceptions() - @notification('dns.tsigkey.delete') + @notification.notify_type('dns.tsigkey.delete') @transaction def delete_tsigkey(self, context, tsigkey_id): policy.check('delete_tsigkey', context, {'tsigkey_id': tsigkey_id}) @@ -862,9 +731,9 @@ class Service(service.RPCService): return pool.ns_records @rpc.expected_exceptions() - @notification('dns.domain.create') - @notification('dns.zone.create') - @synchronized_zone(new_zone=True) + @notification.notify_type('dns.domain.create') + @notification.notify_type('dns.zone.create') + @lock.synchronized_zone(new_zone=True) def create_zone(self, context, zone): """Create zone: perform checks and then call _create_zone() """ @@ -1060,9 +929,9 @@ class Service(service.RPCService): sort_key, sort_dir) @rpc.expected_exceptions() - @notification('dns.domain.update') - @notification('dns.zone.update') - @synchronized_zone() + @notification.notify_type('dns.domain.update') + @notification.notify_type('dns.zone.update') + @lock.synchronized_zone() def update_zone(self, context, zone, increment_serial=True): """Update zone. Perform checks and then call _update_zone() @@ -1134,9 +1003,9 @@ class Service(service.RPCService): return zone @rpc.expected_exceptions() - @notification('dns.domain.delete') - @notification('dns.zone.delete') - @synchronized_zone() + @notification.notify_type('dns.domain.delete') + @notification.notify_type('dns.zone.delete') + @lock.synchronized_zone() def delete_zone(self, context, zone_id): """Delete or abandon a zone On abandon, delete the zone from the DB immediately. @@ -1294,8 +1163,8 @@ class Service(service.RPCService): # RecordSet Methods @rpc.expected_exceptions() - @notification('dns.recordset.create') - @synchronized_zone() + @notification.notify_type('dns.recordset.create') + @lock.synchronized_zone() def create_recordset(self, context, zone_id, recordset, increment_serial=True): zone = self.storage.get_zone(context, zone_id) @@ -1467,8 +1336,8 @@ class Service(service.RPCService): recordsets=recordsets) @rpc.expected_exceptions() - @notification('dns.recordset.update') - @synchronized_zone() + @notification.notify_type('dns.recordset.update') + @lock.synchronized_zone() def update_recordset(self, context, recordset, increment_serial=True): zone_id = recordset.obj_get_original_value('zone_id') zone = self.storage.get_zone(context, zone_id) @@ -1550,8 +1419,8 @@ class Service(service.RPCService): return recordset, zone @rpc.expected_exceptions() - @notification('dns.recordset.delete') - @synchronized_zone() + @notification.notify_type('dns.recordset.delete') + @lock.synchronized_zone() def delete_recordset(self, context, zone_id, recordset_id, increment_serial=True): zone = self.storage.get_zone(context, zone_id) @@ -2049,7 +1918,7 @@ class Service(service.RPCService): # Blacklisted zones @rpc.expected_exceptions() - @notification('dns.blacklist.create') + @notification.notify_type('dns.blacklist.create') @transaction def create_blacklist(self, context, blacklist): policy.check('create_blacklist', context) @@ -2078,7 +1947,7 @@ class Service(service.RPCService): return blacklists @rpc.expected_exceptions() - @notification('dns.blacklist.update') + @notification.notify_type('dns.blacklist.update') @transaction def update_blacklist(self, context, blacklist): target = { @@ -2091,7 +1960,7 @@ class Service(service.RPCService): return blacklist @rpc.expected_exceptions() - @notification('dns.blacklist.delete') + @notification.notify_type('dns.blacklist.delete') @transaction def delete_blacklist(self, context, blacklist_id): policy.check('delete_blacklist', context) @@ -2102,7 +1971,7 @@ class Service(service.RPCService): # Server Pools @rpc.expected_exceptions() - @notification('dns.pool.create') + @notification.notify_type('dns.pool.create') @transaction def create_pool(self, context, pool): # Verify that there is a tenant_id @@ -2141,7 +2010,7 @@ class Service(service.RPCService): return self.storage.get_pool(context, pool_id) @rpc.expected_exceptions() - @notification('dns.pool.update') + @notification.notify_type('dns.pool.update') @transaction def update_pool(self, context, pool): policy.check('update_pool', context) @@ -2202,7 +2071,7 @@ class Service(service.RPCService): return updated_pool @rpc.expected_exceptions() - @notification('dns.pool.delete') + @notification.notify_type('dns.pool.delete') @transaction def delete_pool(self, context, pool_id): @@ -2225,10 +2094,10 @@ class Service(service.RPCService): # Pool Manager Integration @rpc.expected_exceptions() - @notification('dns.domain.update') - @notification('dns.zone.update') + @notification.notify_type('dns.domain.update') + @notification.notify_type('dns.zone.update') @transaction - @synchronized_zone() + @lock.synchronized_zone() def update_status(self, context, zone_id, status, serial, action=None): """ :param context: Security context information. @@ -2356,7 +2225,7 @@ class Service(service.RPCService): return ''.join(sysrand.choice(chars) for _ in range(size)) @rpc.expected_exceptions() - @notification('dns.zone_transfer_request.create') + @notification.notify_type('dns.zone_transfer_request.create') @transaction def create_zone_transfer_request(self, context, zone_transfer_request): @@ -2427,7 +2296,7 @@ class Service(service.RPCService): return requests @rpc.expected_exceptions() - @notification('dns.zone_transfer_request.update') + @notification.notify_type('dns.zone_transfer_request.update') @transaction def update_zone_transfer_request(self, context, zone_transfer_request): @@ -2449,7 +2318,7 @@ class Service(service.RPCService): return request @rpc.expected_exceptions() - @notification('dns.zone_transfer_request.delete') + @notification.notify_type('dns.zone_transfer_request.delete') @transaction def delete_zone_transfer_request(self, context, zone_transfer_request_id): # Get zone transfer request @@ -2469,7 +2338,7 @@ class Service(service.RPCService): zone_transfer_request_id) @rpc.expected_exceptions() - @notification('dns.zone_transfer_accept.create') + @notification.notify_type('dns.zone_transfer_accept.create') @transaction def create_zone_transfer_accept(self, context, zone_transfer_accept): elevated_context = context.elevated(all_tenants=True) @@ -2571,7 +2440,7 @@ class Service(service.RPCService): # Zone Import Methods @rpc.expected_exceptions() - @notification('dns.zone_import.create') + @notification.notify_type('dns.zone_import.create') def create_zone_import(self, context, request_body): if policy.enforce_new_defaults(): target = {constants.RBAC_PROJECT_ID: context.project_id} @@ -2667,7 +2536,7 @@ class Service(service.RPCService): self.update_zone_import(context, zone_import) @rpc.expected_exceptions() - @notification('dns.zone_import.update') + @notification.notify_type('dns.zone_import.update') def update_zone_import(self, context, zone_import): if policy.enforce_new_defaults(): target = {constants.RBAC_PROJECT_ID: zone_import.tenant_id} @@ -2710,7 +2579,7 @@ class Service(service.RPCService): return self.storage.get_zone_import(context, zone_import_id) @rpc.expected_exceptions() - @notification('dns.zone_import.delete') + @notification.notify_type('dns.zone_import.delete') @transaction def delete_zone_import(self, context, zone_import_id): @@ -2733,7 +2602,7 @@ class Service(service.RPCService): # Zone Export Methods @rpc.expected_exceptions() - @notification('dns.zone_export.create') + @notification.notify_type('dns.zone_export.create') def create_zone_export(self, context, zone_id): # Try getting the zone to ensure it exists zone = self.storage.get_zone(context, zone_id) @@ -2797,7 +2666,7 @@ class Service(service.RPCService): return self.storage.get_zone_export(context, zone_export_id) @rpc.expected_exceptions() - @notification('dns.zone_export.update') + @notification.notify_type('dns.zone_export.update') def update_zone_export(self, context, zone_export): if policy.enforce_new_defaults(): @@ -2810,7 +2679,7 @@ class Service(service.RPCService): return self.storage.update_zone_export(context, zone_export) @rpc.expected_exceptions() - @notification('dns.zone_export.delete') + @notification.notify_type('dns.zone_export.delete') @transaction def delete_zone_export(self, context, zone_export_id): diff --git a/designate/common/decorators/__init__.py b/designate/common/decorators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/designate/common/decorators/lock.py b/designate/common/decorators/lock.py new file mode 100644 index 00000000..f633fa4d --- /dev/null +++ b/designate/common/decorators/lock.py @@ -0,0 +1,107 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import threading + +from oslo_log import log as logging + +from designate import objects + +LOG = logging.getLogger(__name__) + + +class ZoneLockLocal(threading.local): + def __init__(self): + super(ZoneLockLocal, self).__init__() + self._held = set() + + def hold(self, name): + self._held.add(name) + + def release(self, name): + self._held.remove(name) + + def has_lock(self, name): + return name in self._held + + +def extract_zone_id(args, kwargs): + zone_id = None + + if 'zone_id' in kwargs: + zone_id = kwargs['zone_id'] + elif 'zone' in kwargs: + zone_id = kwargs['zone'].id + elif 'recordset' in kwargs: + zone_id = kwargs['recordset'].zone_id + elif 'record' in kwargs: + zone_id = kwargs['record'].zone_id + + if not zone_id: + for arg in itertools.chain(args, kwargs.values()): + if not isinstance(arg, objects.DesignateObject): + continue + if isinstance(arg, objects.Zone): + zone_id = arg.id + if zone_id: + break + elif isinstance(arg, (objects.RecordSet, + objects.Record, + objects.ZoneTransferRequest, + objects.ZoneTransferAccept)): + zone_id = arg.zone_id + if zone_id: + break + + if not zone_id and len(args) > 1: + arg = args[1] + if isinstance(arg, str): + zone_id = arg + elif isinstance(zone_id, objects.Zone): + zone_id = arg.id + + return zone_id + + +def synchronized_zone(new_zone=False): + """Ensures only a single operation is in progress for each zone + + A Decorator which ensures only a single operation can be happening + on a single zone at once, within the current designate-central instance + """ + def outer(f): + @functools.wraps(f) + def sync_wrapper(cls, *args, **kwargs): + if new_zone is True: + lock_name = 'create-new-zone' + else: + zone_id = extract_zone_id(args, kwargs) + if zone_id: + lock_name = 'zone-%s' % zone_id + else: + raise Exception('Failed to determine zone id for ' + 'synchronized operation') + + if cls.zone_lock_local.has_lock(lock_name): + return f(cls, *args, **kwargs) + + with cls.coordination.get_lock(lock_name): + try: + cls.zone_lock_local.hold(lock_name) + return f(cls, *args, **kwargs) + finally: + cls.zone_lock_local.release(lock_name) + + return sync_wrapper + return outer diff --git a/designate/common/decorators/notification.py b/designate/common/decorators/notification.py new file mode 100644 index 00000000..c43a92bd --- /dev/null +++ b/designate/common/decorators/notification.py @@ -0,0 +1,90 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import functools +import itertools +import threading + +from oslo_log import log as logging + +from designate import context as designate_context +from designate import notifications + +LOG = logging.getLogger(__name__) + + +class NotificationThreadLocal(threading.local): + def __init__(self): + super(NotificationThreadLocal, self).__init__() + self.stack = 0 + self.queue = collections.deque() + + def reset_queue(self): + self.queue.clear() + + +def notify_type(notification_type): + def outer(f): + @functools.wraps(f) + def notification_wrapper(cls, *args, **kwargs): + cls.notification_thread_local.stack += 1 + + context = None + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, designate_context.DesignateContext): + context = arg + break + + try: + result = f(cls, *args, **kwargs) + + payloads = notifications.get_plugin().emit( + notification_type, context, result, args, kwargs + ) + for payload in payloads: + LOG.debug( + 'Queueing notification for %(type)s', + { + 'type': notification_type + } + ) + cls.notification_thread_local.queue.appendleft( + (context, notification_type, payload,) + ) + + return result + + finally: + cls.notification_thread_local.stack -= 1 + + if cls.notification_thread_local.stack == 0: + LOG.debug( + 'Emitting %(count)d notifications', + { + 'count': len(cls.notification_thread_local.queue) + } + ) + + for message in cls.notification_thread_local.queue: + LOG.debug( + 'Emitting %(type)s notification', + { + 'type': message[1] + } + ) + cls.notifier.info(message[0], message[1], message[2]) + + cls.notification_thread_local.reset_queue() + + return notification_wrapper + return outer diff --git a/designate/common/decorators/rpc.py b/designate/common/decorators/rpc.py new file mode 100644 index 00000000..69cad608 --- /dev/null +++ b/designate/common/decorators/rpc.py @@ -0,0 +1,49 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import threading + +from oslo_messaging.rpc import dispatcher as rpc_dispatcher + +import designate.exceptions + + +class ExceptionThreadLocal(threading.local): + def __init__(self): + super(ExceptionThreadLocal, self).__init__() + self.depth = 0 + + def reset_depth(self): + self.depth = 0 + + +def expected_exceptions(): + def outer(f): + @functools.wraps(f) + def exception_wrapper(cls, *args, **kwargs): + cls.exception_thread_local.depth += 1 + + # We only want to wrap the first function wrapped. + if cls.exception_thread_local.depth > 1: + return f(cls, *args, **kwargs) + + try: + return f(cls, *args, **kwargs) + except designate.exceptions.DesignateException as e: + if e.expected: + raise rpc_dispatcher.ExpectedException() + raise + finally: + cls.exception_thread_local.reset_depth() + return exception_wrapper + return outer diff --git a/designate/context.py b/designate/context.py index 01ea0ce8..5e033446 100644 --- a/designate/context.py +++ b/designate/context.py @@ -14,7 +14,6 @@ # License for the specific language governing permissions and limitations # under the License. import copy -import itertools from keystoneauth1.access import service_catalog as ksa_service_catalog from keystoneauth1 import plugin @@ -145,21 +144,6 @@ class DesignateContext(context.RequestContext): return cls(None, **kwargs) - @classmethod - def get_context_from_function_and_args(cls, function, args, kwargs): - """ - Find an arg of type DesignateContext and return it. - - This is useful in a couple of decorators where we don't - know much about the function we're wrapping. - """ - - for arg in itertools.chain(kwargs.values(), args): - if isinstance(arg, cls): - return arg - - return None - @property def all_tenants(self): return self._all_tenants diff --git a/designate/rpc.py b/designate/rpc.py index 51efeb71..48636ab9 100644 --- a/designate/rpc.py +++ b/designate/rpc.py @@ -11,8 +11,6 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. -import functools -import threading from oslo_config import cfg import oslo_messaging as messaging @@ -40,7 +38,6 @@ __all__ = [ ] CONF = cfg.CONF -EXPECTED_EXCEPTION = threading.local() NOTIFICATION_TRANSPORT = None NOTIFIER = None TRANSPORT = None @@ -237,27 +234,3 @@ def create_transport(url): return messaging.get_rpc_transport(CONF, url=url, allowed_remote_exmods=exmods) - - -def expected_exceptions(): - def outer(f): - @functools.wraps(f) - def exception_wrapper(self, *args, **kwargs): - if not hasattr(EXPECTED_EXCEPTION, 'depth'): - EXPECTED_EXCEPTION.depth = 0 - EXPECTED_EXCEPTION.depth += 1 - - # We only want to wrap the first function wrapped. - if EXPECTED_EXCEPTION.depth > 1: - return f(self, *args, **kwargs) - - try: - return f(self, *args, **kwargs) - except designate.exceptions.DesignateException as e: - if e.expected: - raise rpc_dispatcher.ExpectedException() - raise - finally: - EXPECTED_EXCEPTION.depth = 0 - return exception_wrapper - return outer diff --git a/designate/service.py b/designate/service.py index ce84cbdb..0fec3e02 100644 --- a/designate/service.py +++ b/designate/service.py @@ -30,6 +30,7 @@ from oslo_service import sslutils from oslo_service import wsgi from oslo_utils import netutils +from designate.common.decorators import rpc as rpc_decorator from designate.common import profiler import designate.conf from designate.i18n import _ @@ -77,6 +78,7 @@ class RPCService(Service): rpc_topic, self.name) self.endpoints = [self] + self.exception_thread_local = rpc_decorator.ExceptionThreadLocal() self.notifier = None self.rpc_server = None self.rpc_topic = rpc_topic diff --git a/designate/tests/test_central/test_decorator.py b/designate/tests/test_central/test_decorator.py deleted file mode 100644 index 66472cca..00000000 --- a/designate/tests/test_central/test_decorator.py +++ /dev/null @@ -1,75 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -from unittest import mock - -from oslo_concurrency import lockutils -from oslo_log import log as logging - -from designate.central import service -from designate import exceptions -from designate.objects import record -from designate.objects import zone -from designate.tests.test_central import CentralTestCase -from designate import utils - -LOG = logging.getLogger(__name__) - - -class FakeCoordination(object): - def get_lock(self, name): - return lockutils.lock(name) - - -class CentralDecoratorTests(CentralTestCase): - def test_synchronized_zone_exception_raised(self): - @service.synchronized_zone() - def mock_get_zone(cls, index, zone): - self.assertEqual(service.ZONE_LOCKS.held, {zone.id}) - if index % 3 == 0: - raise exceptions.ZoneNotFound() - - for index in range(9): - try: - mock_get_zone(mock.Mock(coordination=FakeCoordination()), - index, - zone.Zone(id=utils.generate_uuid())) - except exceptions.ZoneNotFound: - pass - - def test_synchronized_zone_recursive_decorator_call(self): - @service.synchronized_zone() - def mock_create_record(cls, context, record): - self.assertEqual(service.ZONE_LOCKS.held, {record.zone_id}) - mock_get_zone(cls, context, zone.Zone(id=record.zone_id)) - - @service.synchronized_zone() - def mock_get_zone(cls, context, zone): - self.assertEqual(service.ZONE_LOCKS.held, {zone.id}) - - mock_create_record(mock.Mock(coordination=FakeCoordination()), - self.get_context(), - record=record.Record(zone_id=utils.generate_uuid())) - mock_get_zone(mock.Mock(coordination=FakeCoordination()), - self.get_context(), - zone=zone.Zone(id=utils.generate_uuid())) - - def test_synchronized_zone_raises_exception_when_no_zone_provided(self): - @service.synchronized_zone(new_zone=False) - def mock_not_creating_new_zone(cls, context, record): - pass - - self.assertRaisesRegex( - Exception, - 'Failed to determine zone id for ' - 'synchronized operation', - mock_not_creating_new_zone, self.get_context(), None - ) diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py index e4ec18ba..e9e7a6b9 100644 --- a/designate/tests/unit/test_central/test_basic.py +++ b/designate/tests/unit/test_central/test_basic.py @@ -392,7 +392,7 @@ class CentralServiceTestCase(CentralBasic): def test_create_recordset_in_storage(self): self.service._enforce_recordset_quota = mock.Mock() - self.service._validate_recordset = mock.Mock() + self.service._validate_recordset = mock.Mock(spec=objects.RecordSet) self.service.storage.create_recordset = mock.Mock(return_value='rs') self.service._update_zone_in_storage = mock.Mock() @@ -416,7 +416,7 @@ class CentralServiceTestCase(CentralBasic): central_service.storage.create_recordset = mock.Mock(return_value='rs') central_service._update_zone_in_storage = mock.Mock() - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.obj_attr_is_set.return_value = True recordset.records = [MockRecord()] @@ -441,7 +441,7 @@ class CentralServiceTestCase(CentralBasic): # NOTE(thirose): Since this is a race condition we assume that # we will hit it if we try to do the operations in a loop 100 times. for num in range(100): - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.name = "b{}".format(num) recordset.obj_attr_is_set.return_value = True recordset.records = [MockRecord()] @@ -1148,7 +1148,7 @@ class CentralZoneTestCase(CentralBasic): def test_update_recordset_fail_on_changes(self): self.service.storage.get_zone.return_value = RoObject() - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.obj_get_original_value.return_value = '1' recordset.obj_get_changes.return_value = ['tenant_id', 'foo'] @@ -1179,7 +1179,7 @@ class CentralZoneTestCase(CentralBasic): self.service.storage.get_zone.return_value = RoObject( action='DELETE', ) - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.obj_get_changes.return_value = ['foo'] exc = self.assertRaises(rpc_dispatcher.ExpectedException, @@ -1196,7 +1196,7 @@ class CentralZoneTestCase(CentralBasic): tenant_id='2', action='bogus', ) - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.obj_get_changes.return_value = ['foo'] recordset.managed = True self.context = mock.Mock() @@ -1216,10 +1216,11 @@ class CentralZoneTestCase(CentralBasic): tenant_id='2', action='bogus', ) - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.obj_get_changes.return_value = ['foo'] - recordset.obj_get_original_value.return_value =\ + recordset.obj_get_original_value.return_value = ( '9c85d9b0-1e9d-4e99-aede-a06664f1af2e' + ) recordset.managed = False self.service._update_recordset_in_storage = mock.Mock( return_value=('x', 'y') @@ -1239,7 +1240,7 @@ class CentralZoneTestCase(CentralBasic): 'recordset_id': '9c85d9b0-1e9d-4e99-aede-a06664f1af2e', 'project_id': '2'}, target) - def test__update_recordset_in_storage(self): + def test_update_recordset_in_storage(self): recordset = mock.Mock() recordset.name = 'n' recordset.type = 't' @@ -1426,7 +1427,7 @@ class CentralZoneTestCase(CentralBasic): self.assertTrue( self.service._delete_recordset_in_storage.called) - def test__delete_recordset_in_storage(self): + def test_delete_recordset_in_storage(self): def mock_uds(c, zone, inc): return zone self.service._update_zone_in_storage = mock_uds @@ -1730,7 +1731,7 @@ class CentralQuotaTest(unittest.TestCase): service = Service() service.storage.count_records.return_value = 10 - recordset = mock.Mock() + recordset = mock.Mock(spec=objects.RecordSet) recordset.managed = False recordset.records = ['1.1.1.%i' % (i + 1) for i in range(5)] @@ -1801,7 +1802,7 @@ class CentralQuotaTest(unittest.TestCase): 1, 1, ] - managed_recordset = mock.Mock() + managed_recordset = mock.Mock(spec=objects.RecordSet) managed_recordset.managed = True recordset_one_record = mock.Mock() diff --git a/designate/tests/unit/test_central/test_lock_decorator.py b/designate/tests/unit/test_central/test_lock_decorator.py new file mode 100644 index 00000000..c8d8058d --- /dev/null +++ b/designate/tests/unit/test_central/test_lock_decorator.py @@ -0,0 +1,111 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from unittest import mock + +from oslo_concurrency import lockutils +from oslo_log import log as logging +import oslotest.base + +from designate.common.decorators import lock +from designate import exceptions +from designate.objects import record +from designate.objects import zone +from designate import utils + +LOG = logging.getLogger(__name__) + + +class FakeCoordination: + def get_lock(self, name): + return lockutils.lock(name) + + +class FakeService: + def __init__(self): + self.zone_lock_local = lock.ZoneLockLocal() + self.coordination = FakeCoordination() + + +class CentralDecoratorTests(oslotest.base.BaseTestCase): + def setUp(self): + super().setUp() + self.context = mock.Mock() + self.service = FakeService() + + def test_synchronized_zone_exception_raised(self): + @lock.synchronized_zone() + def mock_get_zone(cls, current_index, zone_obj): + self.assertEqual( + {'zone-%s' % zone_obj.id}, cls.zone_lock_local._held + ) + if current_index % 3 == 0: + raise exceptions.ZoneNotFound() + + for index in range(9): + try: + mock_get_zone( + self.service, index, zone.Zone(id=utils.generate_uuid()) + ) + except exceptions.ZoneNotFound: + pass + + def test_synchronized_new_zone_with_recursion(self): + @lock.synchronized_zone(new_zone=True) + def mock_create_zone(cls, context): + self.assertEqual({'create-new-zone'}, cls.zone_lock_local._held) + mock_create_record( + cls, context, zone.Zone(id=utils.generate_uuid()) + ) + + @lock.synchronized_zone() + def mock_create_record(cls, context, zone_obj): + self.assertIn('zone-%s' % zone_obj.id, cls.zone_lock_local._held) + self.assertIn('create-new-zone', cls.zone_lock_local._held) + + mock_create_zone( + self.service, self.context + ) + + def test_synchronized_zone_recursive_decorator_call(self): + @lock.synchronized_zone() + def mock_create_record(cls, context, record_obj): + self.assertEqual( + {'zone-%s' % record_obj.zone_id}, cls.zone_lock_local._held + ) + mock_get_zone(cls, context, zone.Zone(id=record_obj.zone_id)) + + @lock.synchronized_zone() + def mock_get_zone(cls, context, zone_obj): + self.assertEqual( + {'zone-%s' % zone_obj.id}, cls.zone_lock_local._held + ) + + mock_create_record( + self.service, self.context, + record_obj=record.Record(zone_id=utils.generate_uuid()) + ) + mock_get_zone( + self.service, self.context, + zone_obj=zone.Zone(id=utils.generate_uuid()) + ) + + def test_synchronized_zone_raises_exception_when_no_zone_provided(self): + @lock.synchronized_zone(new_zone=False) + def mock_not_creating_new_zone(cls, context, record_obj): + pass + + self.assertRaisesRegex( + Exception, + 'Failed to determine zone id for synchronized operation', + mock_not_creating_new_zone, self.service, mock.Mock(), None + ) diff --git a/designate/worker/service.py b/designate/worker/service.py index b5d8a622..cf5a1a5d 100644 --- a/designate/worker/service.py +++ b/designate/worker/service.py @@ -21,9 +21,9 @@ import oslo_messaging as messaging from designate import backend from designate.central import rpcapi as central_api +from designate.common.decorators import rpc from designate.context import DesignateContext from designate import exceptions -from designate import rpc from designate import service from designate import storage from designate.worker import processing -- cgit v1.2.1