summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZuul <zuul@review.opendev.org>2022-08-04 00:08:54 +0000
committerGerrit Code Review <review@openstack.org>2022-08-04 00:08:54 +0000
commitef187ddec8599c2833909049f77e53d90cf4eea4 (patch)
treed266f44b332648d2c743c250856285f4992d341c
parent9876b7b9357765e430833b149fe7480ddb005550 (diff)
parent857b4c4e63d9ff8b17174f429b7fd285fabc903a (diff)
downloaddesignate-ef187ddec8599c2833909049f77e53d90cf4eea4.tar.gz
Merge "Re-factored central and rpc decorators"
-rw-r--r--designate/central/service.py223
-rw-r--r--designate/common/decorators/__init__.py0
-rw-r--r--designate/common/decorators/lock.py107
-rw-r--r--designate/common/decorators/notification.py90
-rw-r--r--designate/common/decorators/rpc.py49
-rw-r--r--designate/context.py16
-rw-r--r--designate/rpc.py27
-rw-r--r--designate/service.py2
-rw-r--r--designate/tests/test_central/test_decorator.py75
-rw-r--r--designate/tests/unit/test_central/test_basic.py25
-rw-r--r--designate/tests/unit/test_central/test_lock_decorator.py111
-rw-r--r--designate/worker/service.py2
12 files changed, 419 insertions, 308 deletions
diff --git a/designate/central/service.py b/designate/central/service.py
index 05173539..34e39338 100644
--- a/designate/central/service.py
+++ b/designate/central/service.py
@@ -14,16 +14,12 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
-import collections
import copy
-import functools
-import itertools
import random
from random import SystemRandom
import re
import signal
import string
-import threading
import time
from dns import exception as dnsexception
@@ -33,16 +29,16 @@ from oslo_log import log as logging
import oslo_messaging as messaging
from designate.common import constants
-from designate import context as dcontext
+from designate.common.decorators import lock
+from designate.common.decorators import notification
+from designate.common.decorators import rpc
from designate import coordination
from designate import dnsutils
from designate import exceptions
from designate import network_api
-from designate import notifications
from designate import objects
from designate import policy
from designate import quota
-from designate import rpc
from designate import scheduler
from designate import service
from designate import storage
@@ -51,135 +47,7 @@ from designate.storage import transaction_shallow_copy
from designate import utils
from designate.worker import rpcapi as worker_rpcapi
-
LOG = logging.getLogger(__name__)
-ZONE_LOCKS = threading.local()
-NOTIFICATION_BUFFER = threading.local()
-
-
-def synchronized_zone(zone_arg=1, new_zone=False):
- """Ensures only a single operation is in progress for each zone
-
- A Decorator which ensures only a single operation can be happening
- on a single zone at once, within the current designate-central instance
- """
- def outer(f):
- @functools.wraps(f)
- def sync_wrapper(self, *args, **kwargs):
- if not hasattr(ZONE_LOCKS, 'held'):
- # Create the held set if necessary
- ZONE_LOCKS.held = set()
-
- zone_id = None
-
- if 'zone_id' in kwargs:
- zone_id = kwargs['zone_id']
- elif 'zone' in kwargs:
- zone_id = kwargs['zone'].id
- elif 'recordset' in kwargs:
- zone_id = kwargs['recordset'].zone_id
- elif 'record' in kwargs:
- zone_id = kwargs['record'].zone_id
-
- # The various objects won't always have an ID set, we should
- # attempt to locate an Object containing the ID.
- if zone_id is None:
- for arg in itertools.chain(kwargs.values(), args):
- if isinstance(arg, objects.Zone):
- zone_id = arg.id
- if zone_id:
- break
- elif (isinstance(arg, objects.RecordSet) or
- isinstance(arg, objects.Record) or
- isinstance(arg, objects.ZoneTransferRequest) or
- isinstance(arg, objects.ZoneTransferAccept)):
- zone_id = arg.zone_id
- if zone_id:
- break
-
- # If we still don't have an ID, find the Nth argument as
- # defined by the zone_arg decorator option.
- if not zone_id and len(args) > zone_arg:
- zone_id = args[zone_arg]
- if isinstance(zone_id, objects.Zone):
- # If the value is a Zone object, extract it's ID.
- zone_id = zone_id.id
-
- if new_zone and not zone_id:
- lock_name = 'create-new-zone'
- elif not new_zone and zone_id:
- lock_name = 'zone-%s' % zone_id
- else:
- raise Exception('Failed to determine zone id for '
- 'synchronized operation')
-
- if zone_id in ZONE_LOCKS.held:
- return f(self, *args, **kwargs)
-
- with self.coordination.get_lock(lock_name):
- try:
- ZONE_LOCKS.held.add(zone_id)
- return f(self, *args, **kwargs)
- finally:
- ZONE_LOCKS.held.remove(zone_id)
-
- sync_wrapper.__wrapped_function = f
- sync_wrapper.__wrapper_name = 'synchronized_zone'
- return sync_wrapper
-
- return outer
-
-
-def notification(notification_type):
- def outer(f):
- @functools.wraps(f)
- def notification_wrapper(self, *args, **kwargs):
- if not hasattr(NOTIFICATION_BUFFER, 'queue'):
- # Create the notifications queue if necessary
- NOTIFICATION_BUFFER.stack = 0
- NOTIFICATION_BUFFER.queue = collections.deque()
-
- NOTIFICATION_BUFFER.stack += 1
-
- try:
- # Find the context argument
- context = dcontext.DesignateContext.\
- get_context_from_function_and_args(f, args, kwargs)
-
- # Call the wrapped function
- result = f(self, *args, **kwargs)
-
- # Feed the args/result to a notification plugin
- # to determine what is emitted
- payloads = notifications.get_plugin().emit(
- notification_type, context, result, args, kwargs)
-
- # Enqueue the notification
- for payload in payloads:
- LOG.debug('Queueing notification for %(type)s ',
- {'type': notification_type})
- NOTIFICATION_BUFFER.queue.appendleft(
- (context, notification_type, payload,))
-
- return result
-
- finally:
- NOTIFICATION_BUFFER.stack -= 1
-
- if NOTIFICATION_BUFFER.stack == 0:
- LOG.debug('Emitting %(count)d notifications',
- {'count': len(NOTIFICATION_BUFFER.queue)})
- # Send the queued notifications, in order.
- for value in NOTIFICATION_BUFFER.queue:
- LOG.debug('Emitting %(type)s notification',
- {'type': value[1]})
- self.notifier.info(value[0], value[1], value[2])
-
- # Reset the queue
- NOTIFICATION_BUFFER.queue.clear()
-
- return notification_wrapper
- return outer
class Service(service.RPCService):
@@ -188,6 +56,9 @@ class Service(service.RPCService):
target = messaging.Target(version=RPC_API_VERSION)
def __init__(self):
+ self.zone_lock_local = lock.ZoneLockLocal()
+ self.notification_thread_local = notification.NotificationThreadLocal()
+
self._scheduler = None
self._storage = None
self._quota = None
@@ -196,11 +67,9 @@ class Service(service.RPCService):
self.service_name, cfg.CONF['service:central'].topic,
threads=cfg.CONF['service:central'].threads,
)
-
self.coordination = coordination.Coordination(
self.service_name, self.tg, grouping_enabled=False
)
-
self.network_api = network_api.get_network_api(cfg.CONF.network_api)
@property
@@ -713,7 +582,7 @@ class Service(service.RPCService):
# TLD Methods
@rpc.expected_exceptions()
- @notification('dns.tld.create')
+ @notification.notify_type('dns.tld.create')
@transaction
def create_tld(self, context, tld):
policy.check('create_tld', context)
@@ -738,7 +607,7 @@ class Service(service.RPCService):
return self.storage.get_tld(context, tld_id)
@rpc.expected_exceptions()
- @notification('dns.tld.update')
+ @notification.notify_type('dns.tld.update')
@transaction
def update_tld(self, context, tld):
target = {
@@ -751,7 +620,7 @@ class Service(service.RPCService):
return tld
@rpc.expected_exceptions()
- @notification('dns.tld.delete')
+ @notification.notify_type('dns.tld.delete')
@transaction
def delete_tld(self, context, tld_id):
policy.check('delete_tld', context, {'tld_id': tld_id})
@@ -762,7 +631,7 @@ class Service(service.RPCService):
# TSIG Key Methods
@rpc.expected_exceptions()
- @notification('dns.tsigkey.create')
+ @notification.notify_type('dns.tsigkey.create')
@transaction
def create_tsigkey(self, context, tsigkey):
policy.check('create_tsigkey', context)
@@ -788,7 +657,7 @@ class Service(service.RPCService):
return self.storage.get_tsigkey(context, tsigkey_id)
@rpc.expected_exceptions()
- @notification('dns.tsigkey.update')
+ @notification.notify_type('dns.tsigkey.update')
@transaction
def update_tsigkey(self, context, tsigkey):
target = {
@@ -803,7 +672,7 @@ class Service(service.RPCService):
return tsigkey
@rpc.expected_exceptions()
- @notification('dns.tsigkey.delete')
+ @notification.notify_type('dns.tsigkey.delete')
@transaction
def delete_tsigkey(self, context, tsigkey_id):
policy.check('delete_tsigkey', context, {'tsigkey_id': tsigkey_id})
@@ -862,9 +731,9 @@ class Service(service.RPCService):
return pool.ns_records
@rpc.expected_exceptions()
- @notification('dns.domain.create')
- @notification('dns.zone.create')
- @synchronized_zone(new_zone=True)
+ @notification.notify_type('dns.domain.create')
+ @notification.notify_type('dns.zone.create')
+ @lock.synchronized_zone(new_zone=True)
def create_zone(self, context, zone):
"""Create zone: perform checks and then call _create_zone()
"""
@@ -1060,9 +929,9 @@ class Service(service.RPCService):
sort_key, sort_dir)
@rpc.expected_exceptions()
- @notification('dns.domain.update')
- @notification('dns.zone.update')
- @synchronized_zone()
+ @notification.notify_type('dns.domain.update')
+ @notification.notify_type('dns.zone.update')
+ @lock.synchronized_zone()
def update_zone(self, context, zone, increment_serial=True):
"""Update zone. Perform checks and then call _update_zone()
@@ -1134,9 +1003,9 @@ class Service(service.RPCService):
return zone
@rpc.expected_exceptions()
- @notification('dns.domain.delete')
- @notification('dns.zone.delete')
- @synchronized_zone()
+ @notification.notify_type('dns.domain.delete')
+ @notification.notify_type('dns.zone.delete')
+ @lock.synchronized_zone()
def delete_zone(self, context, zone_id):
"""Delete or abandon a zone
On abandon, delete the zone from the DB immediately.
@@ -1294,8 +1163,8 @@ class Service(service.RPCService):
# RecordSet Methods
@rpc.expected_exceptions()
- @notification('dns.recordset.create')
- @synchronized_zone()
+ @notification.notify_type('dns.recordset.create')
+ @lock.synchronized_zone()
def create_recordset(self, context, zone_id, recordset,
increment_serial=True):
zone = self.storage.get_zone(context, zone_id)
@@ -1467,8 +1336,8 @@ class Service(service.RPCService):
recordsets=recordsets)
@rpc.expected_exceptions()
- @notification('dns.recordset.update')
- @synchronized_zone()
+ @notification.notify_type('dns.recordset.update')
+ @lock.synchronized_zone()
def update_recordset(self, context, recordset, increment_serial=True):
zone_id = recordset.obj_get_original_value('zone_id')
zone = self.storage.get_zone(context, zone_id)
@@ -1550,8 +1419,8 @@ class Service(service.RPCService):
return recordset, zone
@rpc.expected_exceptions()
- @notification('dns.recordset.delete')
- @synchronized_zone()
+ @notification.notify_type('dns.recordset.delete')
+ @lock.synchronized_zone()
def delete_recordset(self, context, zone_id, recordset_id,
increment_serial=True):
zone = self.storage.get_zone(context, zone_id)
@@ -2049,7 +1918,7 @@ class Service(service.RPCService):
# Blacklisted zones
@rpc.expected_exceptions()
- @notification('dns.blacklist.create')
+ @notification.notify_type('dns.blacklist.create')
@transaction
def create_blacklist(self, context, blacklist):
policy.check('create_blacklist', context)
@@ -2078,7 +1947,7 @@ class Service(service.RPCService):
return blacklists
@rpc.expected_exceptions()
- @notification('dns.blacklist.update')
+ @notification.notify_type('dns.blacklist.update')
@transaction
def update_blacklist(self, context, blacklist):
target = {
@@ -2091,7 +1960,7 @@ class Service(service.RPCService):
return blacklist
@rpc.expected_exceptions()
- @notification('dns.blacklist.delete')
+ @notification.notify_type('dns.blacklist.delete')
@transaction
def delete_blacklist(self, context, blacklist_id):
policy.check('delete_blacklist', context)
@@ -2102,7 +1971,7 @@ class Service(service.RPCService):
# Server Pools
@rpc.expected_exceptions()
- @notification('dns.pool.create')
+ @notification.notify_type('dns.pool.create')
@transaction
def create_pool(self, context, pool):
# Verify that there is a tenant_id
@@ -2141,7 +2010,7 @@ class Service(service.RPCService):
return self.storage.get_pool(context, pool_id)
@rpc.expected_exceptions()
- @notification('dns.pool.update')
+ @notification.notify_type('dns.pool.update')
@transaction
def update_pool(self, context, pool):
policy.check('update_pool', context)
@@ -2202,7 +2071,7 @@ class Service(service.RPCService):
return updated_pool
@rpc.expected_exceptions()
- @notification('dns.pool.delete')
+ @notification.notify_type('dns.pool.delete')
@transaction
def delete_pool(self, context, pool_id):
@@ -2225,10 +2094,10 @@ class Service(service.RPCService):
# Pool Manager Integration
@rpc.expected_exceptions()
- @notification('dns.domain.update')
- @notification('dns.zone.update')
+ @notification.notify_type('dns.domain.update')
+ @notification.notify_type('dns.zone.update')
@transaction
- @synchronized_zone()
+ @lock.synchronized_zone()
def update_status(self, context, zone_id, status, serial, action=None):
"""
:param context: Security context information.
@@ -2356,7 +2225,7 @@ class Service(service.RPCService):
return ''.join(sysrand.choice(chars) for _ in range(size))
@rpc.expected_exceptions()
- @notification('dns.zone_transfer_request.create')
+ @notification.notify_type('dns.zone_transfer_request.create')
@transaction
def create_zone_transfer_request(self, context, zone_transfer_request):
@@ -2427,7 +2296,7 @@ class Service(service.RPCService):
return requests
@rpc.expected_exceptions()
- @notification('dns.zone_transfer_request.update')
+ @notification.notify_type('dns.zone_transfer_request.update')
@transaction
def update_zone_transfer_request(self, context, zone_transfer_request):
@@ -2449,7 +2318,7 @@ class Service(service.RPCService):
return request
@rpc.expected_exceptions()
- @notification('dns.zone_transfer_request.delete')
+ @notification.notify_type('dns.zone_transfer_request.delete')
@transaction
def delete_zone_transfer_request(self, context, zone_transfer_request_id):
# Get zone transfer request
@@ -2469,7 +2338,7 @@ class Service(service.RPCService):
zone_transfer_request_id)
@rpc.expected_exceptions()
- @notification('dns.zone_transfer_accept.create')
+ @notification.notify_type('dns.zone_transfer_accept.create')
@transaction
def create_zone_transfer_accept(self, context, zone_transfer_accept):
elevated_context = context.elevated(all_tenants=True)
@@ -2571,7 +2440,7 @@ class Service(service.RPCService):
# Zone Import Methods
@rpc.expected_exceptions()
- @notification('dns.zone_import.create')
+ @notification.notify_type('dns.zone_import.create')
def create_zone_import(self, context, request_body):
if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: context.project_id}
@@ -2667,7 +2536,7 @@ class Service(service.RPCService):
self.update_zone_import(context, zone_import)
@rpc.expected_exceptions()
- @notification('dns.zone_import.update')
+ @notification.notify_type('dns.zone_import.update')
def update_zone_import(self, context, zone_import):
if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: zone_import.tenant_id}
@@ -2710,7 +2579,7 @@ class Service(service.RPCService):
return self.storage.get_zone_import(context, zone_import_id)
@rpc.expected_exceptions()
- @notification('dns.zone_import.delete')
+ @notification.notify_type('dns.zone_import.delete')
@transaction
def delete_zone_import(self, context, zone_import_id):
@@ -2733,7 +2602,7 @@ class Service(service.RPCService):
# Zone Export Methods
@rpc.expected_exceptions()
- @notification('dns.zone_export.create')
+ @notification.notify_type('dns.zone_export.create')
def create_zone_export(self, context, zone_id):
# Try getting the zone to ensure it exists
zone = self.storage.get_zone(context, zone_id)
@@ -2797,7 +2666,7 @@ class Service(service.RPCService):
return self.storage.get_zone_export(context, zone_export_id)
@rpc.expected_exceptions()
- @notification('dns.zone_export.update')
+ @notification.notify_type('dns.zone_export.update')
def update_zone_export(self, context, zone_export):
if policy.enforce_new_defaults():
@@ -2810,7 +2679,7 @@ class Service(service.RPCService):
return self.storage.update_zone_export(context, zone_export)
@rpc.expected_exceptions()
- @notification('dns.zone_export.delete')
+ @notification.notify_type('dns.zone_export.delete')
@transaction
def delete_zone_export(self, context, zone_export_id):
diff --git a/designate/common/decorators/__init__.py b/designate/common/decorators/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/designate/common/decorators/__init__.py
diff --git a/designate/common/decorators/lock.py b/designate/common/decorators/lock.py
new file mode 100644
index 00000000..f633fa4d
--- /dev/null
+++ b/designate/common/decorators/lock.py
@@ -0,0 +1,107 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import itertools
+import threading
+
+from oslo_log import log as logging
+
+from designate import objects
+
+LOG = logging.getLogger(__name__)
+
+
+class ZoneLockLocal(threading.local):
+ def __init__(self):
+ super(ZoneLockLocal, self).__init__()
+ self._held = set()
+
+ def hold(self, name):
+ self._held.add(name)
+
+ def release(self, name):
+ self._held.remove(name)
+
+ def has_lock(self, name):
+ return name in self._held
+
+
+def extract_zone_id(args, kwargs):
+ zone_id = None
+
+ if 'zone_id' in kwargs:
+ zone_id = kwargs['zone_id']
+ elif 'zone' in kwargs:
+ zone_id = kwargs['zone'].id
+ elif 'recordset' in kwargs:
+ zone_id = kwargs['recordset'].zone_id
+ elif 'record' in kwargs:
+ zone_id = kwargs['record'].zone_id
+
+ if not zone_id:
+ for arg in itertools.chain(args, kwargs.values()):
+ if not isinstance(arg, objects.DesignateObject):
+ continue
+ if isinstance(arg, objects.Zone):
+ zone_id = arg.id
+ if zone_id:
+ break
+ elif isinstance(arg, (objects.RecordSet,
+ objects.Record,
+ objects.ZoneTransferRequest,
+ objects.ZoneTransferAccept)):
+ zone_id = arg.zone_id
+ if zone_id:
+ break
+
+ if not zone_id and len(args) > 1:
+ arg = args[1]
+ if isinstance(arg, str):
+ zone_id = arg
+ elif isinstance(zone_id, objects.Zone):
+ zone_id = arg.id
+
+ return zone_id
+
+
+def synchronized_zone(new_zone=False):
+ """Ensures only a single operation is in progress for each zone
+
+ A Decorator which ensures only a single operation can be happening
+ on a single zone at once, within the current designate-central instance
+ """
+ def outer(f):
+ @functools.wraps(f)
+ def sync_wrapper(cls, *args, **kwargs):
+ if new_zone is True:
+ lock_name = 'create-new-zone'
+ else:
+ zone_id = extract_zone_id(args, kwargs)
+ if zone_id:
+ lock_name = 'zone-%s' % zone_id
+ else:
+ raise Exception('Failed to determine zone id for '
+ 'synchronized operation')
+
+ if cls.zone_lock_local.has_lock(lock_name):
+ return f(cls, *args, **kwargs)
+
+ with cls.coordination.get_lock(lock_name):
+ try:
+ cls.zone_lock_local.hold(lock_name)
+ return f(cls, *args, **kwargs)
+ finally:
+ cls.zone_lock_local.release(lock_name)
+
+ return sync_wrapper
+ return outer
diff --git a/designate/common/decorators/notification.py b/designate/common/decorators/notification.py
new file mode 100644
index 00000000..c43a92bd
--- /dev/null
+++ b/designate/common/decorators/notification.py
@@ -0,0 +1,90 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import collections
+import functools
+import itertools
+import threading
+
+from oslo_log import log as logging
+
+from designate import context as designate_context
+from designate import notifications
+
+LOG = logging.getLogger(__name__)
+
+
+class NotificationThreadLocal(threading.local):
+ def __init__(self):
+ super(NotificationThreadLocal, self).__init__()
+ self.stack = 0
+ self.queue = collections.deque()
+
+ def reset_queue(self):
+ self.queue.clear()
+
+
+def notify_type(notification_type):
+ def outer(f):
+ @functools.wraps(f)
+ def notification_wrapper(cls, *args, **kwargs):
+ cls.notification_thread_local.stack += 1
+
+ context = None
+ for arg in itertools.chain(args, kwargs.values()):
+ if isinstance(arg, designate_context.DesignateContext):
+ context = arg
+ break
+
+ try:
+ result = f(cls, *args, **kwargs)
+
+ payloads = notifications.get_plugin().emit(
+ notification_type, context, result, args, kwargs
+ )
+ for payload in payloads:
+ LOG.debug(
+ 'Queueing notification for %(type)s',
+ {
+ 'type': notification_type
+ }
+ )
+ cls.notification_thread_local.queue.appendleft(
+ (context, notification_type, payload,)
+ )
+
+ return result
+
+ finally:
+ cls.notification_thread_local.stack -= 1
+
+ if cls.notification_thread_local.stack == 0:
+ LOG.debug(
+ 'Emitting %(count)d notifications',
+ {
+ 'count': len(cls.notification_thread_local.queue)
+ }
+ )
+
+ for message in cls.notification_thread_local.queue:
+ LOG.debug(
+ 'Emitting %(type)s notification',
+ {
+ 'type': message[1]
+ }
+ )
+ cls.notifier.info(message[0], message[1], message[2])
+
+ cls.notification_thread_local.reset_queue()
+
+ return notification_wrapper
+ return outer
diff --git a/designate/common/decorators/rpc.py b/designate/common/decorators/rpc.py
new file mode 100644
index 00000000..69cad608
--- /dev/null
+++ b/designate/common/decorators/rpc.py
@@ -0,0 +1,49 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import threading
+
+from oslo_messaging.rpc import dispatcher as rpc_dispatcher
+
+import designate.exceptions
+
+
+class ExceptionThreadLocal(threading.local):
+ def __init__(self):
+ super(ExceptionThreadLocal, self).__init__()
+ self.depth = 0
+
+ def reset_depth(self):
+ self.depth = 0
+
+
+def expected_exceptions():
+ def outer(f):
+ @functools.wraps(f)
+ def exception_wrapper(cls, *args, **kwargs):
+ cls.exception_thread_local.depth += 1
+
+ # We only want to wrap the first function wrapped.
+ if cls.exception_thread_local.depth > 1:
+ return f(cls, *args, **kwargs)
+
+ try:
+ return f(cls, *args, **kwargs)
+ except designate.exceptions.DesignateException as e:
+ if e.expected:
+ raise rpc_dispatcher.ExpectedException()
+ raise
+ finally:
+ cls.exception_thread_local.reset_depth()
+ return exception_wrapper
+ return outer
diff --git a/designate/context.py b/designate/context.py
index 01ea0ce8..5e033446 100644
--- a/designate/context.py
+++ b/designate/context.py
@@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import copy
-import itertools
from keystoneauth1.access import service_catalog as ksa_service_catalog
from keystoneauth1 import plugin
@@ -145,21 +144,6 @@ class DesignateContext(context.RequestContext):
return cls(None, **kwargs)
- @classmethod
- def get_context_from_function_and_args(cls, function, args, kwargs):
- """
- Find an arg of type DesignateContext and return it.
-
- This is useful in a couple of decorators where we don't
- know much about the function we're wrapping.
- """
-
- for arg in itertools.chain(kwargs.values(), args):
- if isinstance(arg, cls):
- return arg
-
- return None
-
@property
def all_tenants(self):
return self._all_tenants
diff --git a/designate/rpc.py b/designate/rpc.py
index 51efeb71..48636ab9 100644
--- a/designate/rpc.py
+++ b/designate/rpc.py
@@ -11,8 +11,6 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
-import functools
-import threading
from oslo_config import cfg
import oslo_messaging as messaging
@@ -40,7 +38,6 @@ __all__ = [
]
CONF = cfg.CONF
-EXPECTED_EXCEPTION = threading.local()
NOTIFICATION_TRANSPORT = None
NOTIFIER = None
TRANSPORT = None
@@ -237,27 +234,3 @@ def create_transport(url):
return messaging.get_rpc_transport(CONF,
url=url,
allowed_remote_exmods=exmods)
-
-
-def expected_exceptions():
- def outer(f):
- @functools.wraps(f)
- def exception_wrapper(self, *args, **kwargs):
- if not hasattr(EXPECTED_EXCEPTION, 'depth'):
- EXPECTED_EXCEPTION.depth = 0
- EXPECTED_EXCEPTION.depth += 1
-
- # We only want to wrap the first function wrapped.
- if EXPECTED_EXCEPTION.depth > 1:
- return f(self, *args, **kwargs)
-
- try:
- return f(self, *args, **kwargs)
- except designate.exceptions.DesignateException as e:
- if e.expected:
- raise rpc_dispatcher.ExpectedException()
- raise
- finally:
- EXPECTED_EXCEPTION.depth = 0
- return exception_wrapper
- return outer
diff --git a/designate/service.py b/designate/service.py
index ce84cbdb..0fec3e02 100644
--- a/designate/service.py
+++ b/designate/service.py
@@ -30,6 +30,7 @@ from oslo_service import sslutils
from oslo_service import wsgi
from oslo_utils import netutils
+from designate.common.decorators import rpc as rpc_decorator
from designate.common import profiler
import designate.conf
from designate.i18n import _
@@ -77,6 +78,7 @@ class RPCService(Service):
rpc_topic, self.name)
self.endpoints = [self]
+ self.exception_thread_local = rpc_decorator.ExceptionThreadLocal()
self.notifier = None
self.rpc_server = None
self.rpc_topic = rpc_topic
diff --git a/designate/tests/test_central/test_decorator.py b/designate/tests/test_central/test_decorator.py
deleted file mode 100644
index 66472cca..00000000
--- a/designate/tests/test_central/test_decorator.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-from unittest import mock
-
-from oslo_concurrency import lockutils
-from oslo_log import log as logging
-
-from designate.central import service
-from designate import exceptions
-from designate.objects import record
-from designate.objects import zone
-from designate.tests.test_central import CentralTestCase
-from designate import utils
-
-LOG = logging.getLogger(__name__)
-
-
-class FakeCoordination(object):
- def get_lock(self, name):
- return lockutils.lock(name)
-
-
-class CentralDecoratorTests(CentralTestCase):
- def test_synchronized_zone_exception_raised(self):
- @service.synchronized_zone()
- def mock_get_zone(cls, index, zone):
- self.assertEqual(service.ZONE_LOCKS.held, {zone.id})
- if index % 3 == 0:
- raise exceptions.ZoneNotFound()
-
- for index in range(9):
- try:
- mock_get_zone(mock.Mock(coordination=FakeCoordination()),
- index,
- zone.Zone(id=utils.generate_uuid()))
- except exceptions.ZoneNotFound:
- pass
-
- def test_synchronized_zone_recursive_decorator_call(self):
- @service.synchronized_zone()
- def mock_create_record(cls, context, record):
- self.assertEqual(service.ZONE_LOCKS.held, {record.zone_id})
- mock_get_zone(cls, context, zone.Zone(id=record.zone_id))
-
- @service.synchronized_zone()
- def mock_get_zone(cls, context, zone):
- self.assertEqual(service.ZONE_LOCKS.held, {zone.id})
-
- mock_create_record(mock.Mock(coordination=FakeCoordination()),
- self.get_context(),
- record=record.Record(zone_id=utils.generate_uuid()))
- mock_get_zone(mock.Mock(coordination=FakeCoordination()),
- self.get_context(),
- zone=zone.Zone(id=utils.generate_uuid()))
-
- def test_synchronized_zone_raises_exception_when_no_zone_provided(self):
- @service.synchronized_zone(new_zone=False)
- def mock_not_creating_new_zone(cls, context, record):
- pass
-
- self.assertRaisesRegex(
- Exception,
- 'Failed to determine zone id for '
- 'synchronized operation',
- mock_not_creating_new_zone, self.get_context(), None
- )
diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py
index e4ec18ba..e9e7a6b9 100644
--- a/designate/tests/unit/test_central/test_basic.py
+++ b/designate/tests/unit/test_central/test_basic.py
@@ -392,7 +392,7 @@ class CentralServiceTestCase(CentralBasic):
def test_create_recordset_in_storage(self):
self.service._enforce_recordset_quota = mock.Mock()
- self.service._validate_recordset = mock.Mock()
+ self.service._validate_recordset = mock.Mock(spec=objects.RecordSet)
self.service.storage.create_recordset = mock.Mock(return_value='rs')
self.service._update_zone_in_storage = mock.Mock()
@@ -416,7 +416,7 @@ class CentralServiceTestCase(CentralBasic):
central_service.storage.create_recordset = mock.Mock(return_value='rs')
central_service._update_zone_in_storage = mock.Mock()
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()]
@@ -441,7 +441,7 @@ class CentralServiceTestCase(CentralBasic):
# NOTE(thirose): Since this is a race condition we assume that
# we will hit it if we try to do the operations in a loop 100 times.
for num in range(100):
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.name = "b{}".format(num)
recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()]
@@ -1148,7 +1148,7 @@ class CentralZoneTestCase(CentralBasic):
def test_update_recordset_fail_on_changes(self):
self.service.storage.get_zone.return_value = RoObject()
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_original_value.return_value = '1'
recordset.obj_get_changes.return_value = ['tenant_id', 'foo']
@@ -1179,7 +1179,7 @@ class CentralZoneTestCase(CentralBasic):
self.service.storage.get_zone.return_value = RoObject(
action='DELETE',
)
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo']
exc = self.assertRaises(rpc_dispatcher.ExpectedException,
@@ -1196,7 +1196,7 @@ class CentralZoneTestCase(CentralBasic):
tenant_id='2',
action='bogus',
)
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo']
recordset.managed = True
self.context = mock.Mock()
@@ -1216,10 +1216,11 @@ class CentralZoneTestCase(CentralBasic):
tenant_id='2',
action='bogus',
)
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo']
- recordset.obj_get_original_value.return_value =\
+ recordset.obj_get_original_value.return_value = (
'9c85d9b0-1e9d-4e99-aede-a06664f1af2e'
+ )
recordset.managed = False
self.service._update_recordset_in_storage = mock.Mock(
return_value=('x', 'y')
@@ -1239,7 +1240,7 @@ class CentralZoneTestCase(CentralBasic):
'recordset_id': '9c85d9b0-1e9d-4e99-aede-a06664f1af2e',
'project_id': '2'}, target)
- def test__update_recordset_in_storage(self):
+ def test_update_recordset_in_storage(self):
recordset = mock.Mock()
recordset.name = 'n'
recordset.type = 't'
@@ -1426,7 +1427,7 @@ class CentralZoneTestCase(CentralBasic):
self.assertTrue(
self.service._delete_recordset_in_storage.called)
- def test__delete_recordset_in_storage(self):
+ def test_delete_recordset_in_storage(self):
def mock_uds(c, zone, inc):
return zone
self.service._update_zone_in_storage = mock_uds
@@ -1730,7 +1731,7 @@ class CentralQuotaTest(unittest.TestCase):
service = Service()
service.storage.count_records.return_value = 10
- recordset = mock.Mock()
+ recordset = mock.Mock(spec=objects.RecordSet)
recordset.managed = False
recordset.records = ['1.1.1.%i' % (i + 1) for i in range(5)]
@@ -1801,7 +1802,7 @@ class CentralQuotaTest(unittest.TestCase):
1, 1,
]
- managed_recordset = mock.Mock()
+ managed_recordset = mock.Mock(spec=objects.RecordSet)
managed_recordset.managed = True
recordset_one_record = mock.Mock()
diff --git a/designate/tests/unit/test_central/test_lock_decorator.py b/designate/tests/unit/test_central/test_lock_decorator.py
new file mode 100644
index 00000000..c8d8058d
--- /dev/null
+++ b/designate/tests/unit/test_central/test_lock_decorator.py
@@ -0,0 +1,111 @@
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from unittest import mock
+
+from oslo_concurrency import lockutils
+from oslo_log import log as logging
+import oslotest.base
+
+from designate.common.decorators import lock
+from designate import exceptions
+from designate.objects import record
+from designate.objects import zone
+from designate import utils
+
+LOG = logging.getLogger(__name__)
+
+
+class FakeCoordination:
+ def get_lock(self, name):
+ return lockutils.lock(name)
+
+
+class FakeService:
+ def __init__(self):
+ self.zone_lock_local = lock.ZoneLockLocal()
+ self.coordination = FakeCoordination()
+
+
+class CentralDecoratorTests(oslotest.base.BaseTestCase):
+ def setUp(self):
+ super().setUp()
+ self.context = mock.Mock()
+ self.service = FakeService()
+
+ def test_synchronized_zone_exception_raised(self):
+ @lock.synchronized_zone()
+ def mock_get_zone(cls, current_index, zone_obj):
+ self.assertEqual(
+ {'zone-%s' % zone_obj.id}, cls.zone_lock_local._held
+ )
+ if current_index % 3 == 0:
+ raise exceptions.ZoneNotFound()
+
+ for index in range(9):
+ try:
+ mock_get_zone(
+ self.service, index, zone.Zone(id=utils.generate_uuid())
+ )
+ except exceptions.ZoneNotFound:
+ pass
+
+ def test_synchronized_new_zone_with_recursion(self):
+ @lock.synchronized_zone(new_zone=True)
+ def mock_create_zone(cls, context):
+ self.assertEqual({'create-new-zone'}, cls.zone_lock_local._held)
+ mock_create_record(
+ cls, context, zone.Zone(id=utils.generate_uuid())
+ )
+
+ @lock.synchronized_zone()
+ def mock_create_record(cls, context, zone_obj):
+ self.assertIn('zone-%s' % zone_obj.id, cls.zone_lock_local._held)
+ self.assertIn('create-new-zone', cls.zone_lock_local._held)
+
+ mock_create_zone(
+ self.service, self.context
+ )
+
+ def test_synchronized_zone_recursive_decorator_call(self):
+ @lock.synchronized_zone()
+ def mock_create_record(cls, context, record_obj):
+ self.assertEqual(
+ {'zone-%s' % record_obj.zone_id}, cls.zone_lock_local._held
+ )
+ mock_get_zone(cls, context, zone.Zone(id=record_obj.zone_id))
+
+ @lock.synchronized_zone()
+ def mock_get_zone(cls, context, zone_obj):
+ self.assertEqual(
+ {'zone-%s' % zone_obj.id}, cls.zone_lock_local._held
+ )
+
+ mock_create_record(
+ self.service, self.context,
+ record_obj=record.Record(zone_id=utils.generate_uuid())
+ )
+ mock_get_zone(
+ self.service, self.context,
+ zone_obj=zone.Zone(id=utils.generate_uuid())
+ )
+
+ def test_synchronized_zone_raises_exception_when_no_zone_provided(self):
+ @lock.synchronized_zone(new_zone=False)
+ def mock_not_creating_new_zone(cls, context, record_obj):
+ pass
+
+ self.assertRaisesRegex(
+ Exception,
+ 'Failed to determine zone id for synchronized operation',
+ mock_not_creating_new_zone, self.service, mock.Mock(), None
+ )
diff --git a/designate/worker/service.py b/designate/worker/service.py
index b5d8a622..cf5a1a5d 100644
--- a/designate/worker/service.py
+++ b/designate/worker/service.py
@@ -21,9 +21,9 @@ import oslo_messaging as messaging
from designate import backend
from designate.central import rpcapi as central_api
+from designate.common.decorators import rpc
from designate.context import DesignateContext
from designate import exceptions
-from designate import rpc
from designate import service
from designate import storage
from designate.worker import processing