summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--designate/central/service.py215
-rw-r--r--designate/tests/__init__.py26
-rw-r--r--designate/tests/test_api/test_v2/test_import_export.py6
-rw-r--r--designate/tests/test_central/test_service.py24
-rw-r--r--designate/tests/unit/test_central/test_basic.py40
5 files changed, 157 insertions, 154 deletions
diff --git a/designate/central/service.py b/designate/central/service.py
index fe3fb04c..99654e30 100644
--- a/designate/central/service.py
+++ b/designate/central/service.py
@@ -26,9 +26,8 @@ import random
from random import SystemRandom
import time
-from eventlet import tpool
-from dns import zone as dnszone
from dns import exception as dnsexception
+from dns import zone as dnszone
from oslo_config import cfg
import oslo_messaging as messaging
from oslo_log import log as logging
@@ -561,51 +560,55 @@ class Service(service.RPCService):
objects.Record(data=r, managed=True) for r in ns_records])
values = {
'name': zone['name'],
- 'type': "NS",
+ 'type': 'NS',
'records': recordlist
}
ns, zone = self._create_recordset_in_storage(
context, zone, objects.RecordSet(**values),
- increment_serial=False)
+ increment_serial=False
+ )
return ns
def _add_ns(self, context, zone, ns_record):
# Get NS recordset
# If the zone doesn't have an NS recordset yet, create one
- recordsets = self.find_recordsets(
- context, criterion={'zone_id': zone['id'], 'type': "NS"}
- )
-
- managed = []
- for rs in recordsets:
- if rs.managed:
- managed.append(rs)
-
- if len(managed) == 0:
+ try:
+ recordset = self.find_recordset(
+ context,
+ criterion={
+ 'zone_id': zone['id'],
+ 'name': zone['name'],
+ 'type': 'NS'
+ }
+ )
+ except exceptions.RecordSetNotFound:
self._create_ns(context, zone, [ns_record])
return
- elif len(managed) != 1:
- raise exceptions.RecordSetNotFound("No valid recordset found")
-
- ns_recordset = managed[0]
# Add new record to recordset based on the new nameserver
- ns_recordset.records.append(
- objects.Record(data=ns_record, managed=True))
+ recordset.records.append(
+ objects.Record(data=ns_record, managed=True)
+ )
- self._update_recordset_in_storage(context, zone, ns_recordset,
+ self._update_recordset_in_storage(context, zone, recordset,
set_delayed_notify=True)
def _delete_ns(self, context, zone, ns_record):
- ns_recordset = self.find_recordset(
- context, criterion={'zone_id': zone['id'], 'type': "NS"})
+ recordset = self.find_recordset(
+ context,
+ criterion={
+ 'zone_id': zone['id'],
+ 'name': zone['name'],
+ 'type': 'NS'
+ }
+ )
- for record in copy.deepcopy(ns_recordset.records):
+ for record in list(recordset.records):
if record.data == ns_record:
- ns_recordset.records.remove(record)
+ recordset.records.remove(record)
- self._update_recordset_in_storage(context, zone, ns_recordset,
+ self._update_recordset_in_storage(context, zone, recordset,
set_delayed_notify=True)
# Quota Enforcement Methods
@@ -2505,46 +2508,49 @@ class Service(service.RPCService):
@notification('dns.pool.update')
@transaction
def update_pool(self, context, pool):
-
policy.check('update_pool', context)
# If there is a nameserver, then additional steps need to be done
# Since these are treated as mutable objects, we're only going to
# be comparing the nameserver.value which is the FQDN
- if pool.obj_attr_is_set('ns_records'):
- elevated_context = context.elevated(all_tenants=True)
+ elevated_context = context.elevated(all_tenants=True)
- # TODO(kiall): ListObjects should be able to give you their
- # original set of values.
- original_pool_ns_records = self._get_pool_ns_records(context,
- pool.id)
- # Find the current NS hostnames
- existing_ns = set([n.hostname for n in original_pool_ns_records])
+ # TODO(kiall): ListObjects should be able to give you their
+ # original set of values.
+ original_pool_ns_records = self._get_pool_ns_records(
+ context, pool.id
+ )
- # Find the desired NS hostnames
- request_ns = set([n.hostname for n in pool.ns_records])
+ updated_pool = self.storage.update_pool(context, pool)
- # Get the NS's to be created and deleted, ignoring the ones that
- # are in both sets, as those haven't changed.
- # TODO(kiall): Factor in priority
- create_ns = request_ns.difference(existing_ns)
- delete_ns = existing_ns.difference(request_ns)
+ if not pool.obj_attr_is_set('ns_records'):
+ return updated_pool
- updated_pool = self.storage.update_pool(context, pool)
+ # Find the current NS hostnames
+ existing_ns = set([n.hostname for n in original_pool_ns_records])
+
+ # Find the desired NS hostnames
+ request_ns = set([n.hostname for n in pool.ns_records])
+
+ # Get the NS's to be created and deleted, ignoring the ones that
+ # are in both sets, as those haven't changed.
+ # TODO(kiall): Factor in priority
+ create_ns = request_ns.difference(existing_ns)
+ delete_ns = existing_ns.difference(request_ns)
# After the update, handle new ns_records
- for ns in create_ns:
+ for ns_record in create_ns:
# Create new NS recordsets for every zone
zones = self.find_zones(
context=elevated_context,
criterion={'pool_id': pool.id, 'action': '!DELETE'})
- for z in zones:
- self._add_ns(elevated_context, z, ns)
+ for zone in zones:
+ self._add_ns(elevated_context, zone, ns_record)
# Then handle the ns_records to delete
- for ns in delete_ns:
+ for ns_record in delete_ns:
# Cannot delete the last nameserver, so verify that first.
- if len(pool.ns_records) == 0:
+ if not pool.ns_records:
raise exceptions.LastServerDeleteNotAllowed(
"Not allowed to delete last of servers"
)
@@ -2552,9 +2558,10 @@ class Service(service.RPCService):
# Delete the NS record for every zone
zones = self.find_zones(
context=elevated_context,
- criterion={'pool_id': pool.id})
- for z in zones:
- self._delete_ns(elevated_context, z, ns)
+ criterion={'pool_id': pool.id}
+ )
+ for zone in zones:
+ self._delete_ns(elevated_context, zone, ns_record)
return updated_pool
@@ -2990,7 +2997,6 @@ class Service(service.RPCService):
@rpc.expected_exceptions()
@notification('dns.zone_import.create')
def create_zone_import(self, context, request_body):
-
if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: context.project_id}
else:
@@ -3010,59 +3016,49 @@ class Service(service.RPCService):
zone_import = objects.ZoneImport(**values)
created_zone_import = self.storage.create_zone_import(context,
- zone_import)
+ zone_import)
self.tg.add_thread(self._import_zone, context, created_zone_import,
- request_body)
+ request_body)
return created_zone_import
def _import_zone(self, context, zone_import, request_body):
-
- def _import(self, context, zone_import, request_body):
- # Dnspython needs a str instead of a unicode object
- zone = None
- try:
- dnspython_zone = dnszone.from_text(
- request_body,
- # Don't relativize, or we end up with '@' record names.
- relativize=False,
- # Don't check origin, we allow missing NS records
- # (missing SOA records are taken care of in _create_zone).
- check_origin=False)
- zone = dnsutils.from_dnspython_zone(dnspython_zone)
- zone.type = 'PRIMARY'
-
- for rrset in list(zone.recordsets):
- if rrset.type == 'SOA':
- zone.recordsets.remove(rrset)
- # subdomain NS records should be kept
- elif rrset.type == 'NS' and rrset.name == zone.name:
- zone.recordsets.remove(rrset)
-
- except dnszone.UnknownOrigin:
- zone_import.message = ('The $ORIGIN statement is required and'
- ' must be the first statement in the'
- ' zonefile.')
- zone_import.status = 'ERROR'
- except dnsexception.SyntaxError:
- zone_import.message = 'Malformed zonefile.'
- zone_import.status = 'ERROR'
- except exceptions.BadRequest:
- zone_import.message = 'An SOA record is required.'
- zone_import.status = 'ERROR'
- except Exception as e:
- LOG.exception('An undefined error occurred during zone import')
- msg = 'An undefined error occurred. %s'\
- % str(e)[:130]
- zone_import.message = msg
- zone_import.status = 'ERROR'
-
- return zone, zone_import
-
- # Execute the import in a real Python thread
- zone, zone_import = tpool.execute(_import, self, context,
- zone_import, request_body)
+ zone = None
+ try:
+ dnspython_zone = dnszone.from_text(
+ request_body,
+ # Don't relativize, or we end up with '@' record names.
+ relativize=False,
+ # Don't check origin, we allow missing NS records
+ # (missing SOA records are taken care of in _create_zone).
+ check_origin=False)
+ zone = dnsutils.from_dnspython_zone(dnspython_zone)
+ zone.type = 'PRIMARY'
+ for rrset in list(zone.recordsets):
+ if rrset.type == 'SOA':
+ zone.recordsets.remove(rrset)
+ # subdomain NS records should be kept
+ elif rrset.type == 'NS' and rrset.name == zone.name:
+ zone.recordsets.remove(rrset)
+ except dnszone.UnknownOrigin:
+ zone_import.message = (
+ 'The $ORIGIN statement is required and must be the first '
+ 'statement in the zonefile.'
+ )
+ zone_import.status = 'ERROR'
+ except dnsexception.SyntaxError:
+ zone_import.message = 'Malformed zonefile.'
+ zone_import.status = 'ERROR'
+ except exceptions.BadRequest:
+ zone_import.message = 'An SOA record is required.'
+ zone_import.status = 'ERROR'
+ except Exception as e:
+ LOG.exception('An undefined error occurred during zone import')
+ zone_import.message = (
+ 'An undefined error occurred. %s' % str(e)[:130]
+ )
+ zone_import.status = 'ERROR'
# If the zone import was valid, create the zone
if zone_import.status != 'ERROR':
@@ -3070,27 +3066,32 @@ class Service(service.RPCService):
zone = self.create_zone(context, zone)
zone_import.status = 'COMPLETE'
zone_import.zone_id = zone.id
- zone_import.message = '%(name)s imported' % {'name':
- zone.name}
+ zone_import.message = (
+ '%(name)s imported' % {'name': zone.name}
+ )
except exceptions.DuplicateZone:
zone_import.status = 'ERROR'
zone_import.message = 'Duplicate zone.'
except exceptions.InvalidTTL as e:
zone_import.status = 'ERROR'
zone_import.message = str(e)
+ except exceptions.OverQuota:
+ zone_import.status = 'ERROR'
+ zone_import.message = 'Quota exceeded during zone import.'
except Exception as e:
- LOG.exception('An undefined error occurred during zone '
- 'import creation')
- msg = 'An undefined error occurred. %s'\
- % str(e)[:130]
- zone_import.message = msg
+ LOG.exception(
+ 'An undefined error occurred during zone import creation'
+ )
+ zone_import.message = (
+ 'An undefined error occurred. %s' % str(e)[:130]
+ )
zone_import.status = 'ERROR'
self.update_zone_import(context, zone_import)
@rpc.expected_exceptions()
def find_zone_imports(self, context, criterion=None, marker=None,
- limit=None, sort_key=None, sort_dir=None):
+ limit=None, sort_key=None, sort_dir=None):
if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: context.project_id}
diff --git a/designate/tests/__init__.py b/designate/tests/__init__.py
index 07bf510d..24c7beaa 100644
--- a/designate/tests/__init__.py
+++ b/designate/tests/__init__.py
@@ -786,34 +786,36 @@ class TestCase(base.BaseTestCase):
return self.storage.create_zone_export(
context, objects.ZoneExport.from_dict(zone_export))
- def wait_for_import(self, zone_import_id, errorok=False):
+ def wait_for_import(self, zone_import_id, error_is_ok=False, max_wait=10):
"""
Zone imports spawn a thread to parse the zone file and
insert the data. This waits for this process before continuing
"""
- attempts = 0
- while attempts < 20:
- # Give the import a half second to complete
- time.sleep(.5)
-
+ start_time = time.time()
+ while True:
# Retrieve it, and ensure it's the same
zone_import = self.central_service.get_zone_import(
- self.admin_context_all_tenants, zone_import_id)
+ self.admin_context_all_tenants, zone_import_id
+ )
# If the import is done, we're done
if zone_import.status == 'COMPLETE':
break
# If errors are allowed, just make sure that something completed
- if errorok:
- if zone_import.status != 'PENDING':
- break
+ if error_is_ok and zone_import.status != 'PENDING':
+ break
- attempts += 1
+ if (time.time() - start_time) > max_wait:
+ break
- if not errorok:
+ time.sleep(0.5)
+
+ if not error_is_ok:
self.assertEqual('COMPLETE', zone_import.status)
+ return zone_import
+
def _ensure_interface(self, interface, implementation):
for name in interface.__abstractmethods__:
in_arginfo = inspect.getfullargspec(getattr(interface, name))
diff --git a/designate/tests/test_api/test_v2/test_import_export.py b/designate/tests/test_api/test_v2/test_import_export.py
index aefc6e1d..b9e491e0 100644
--- a/designate/tests/test_api/test_v2/test_import_export.py
+++ b/designate/tests/test_api/test_v2/test_import_export.py
@@ -53,7 +53,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase):
headers={'Content-type': 'text/dns'})
import_id = response.json_body['id']
- self.wait_for_import(import_id, errorok=True)
+ self.wait_for_import(import_id, error_is_ok=True)
url = '/zones/tasks/imports/%s' % import_id
@@ -70,7 +70,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase):
headers={'Content-type': 'text/dns'})
import_id = response.json_body['id']
- self.wait_for_import(import_id, errorok=True)
+ self.wait_for_import(import_id, error_is_ok=True)
url = '/zones/tasks/imports/%s' % import_id
@@ -86,7 +86,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase):
headers={'Content-type': 'text/dns'})
import_id = response.json_body['id']
- self.wait_for_import(import_id, errorok=True)
+ self.wait_for_import(import_id, error_is_ok=True)
url = '/zones/tasks/imports/%s' % import_id
diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py
index 5db47b4f..7ff36c1f 100644
--- a/designate/tests/test_central/test_service.py
+++ b/designate/tests/test_central/test_service.py
@@ -3548,6 +3548,30 @@ class CentralServiceTest(CentralTestCase):
self.wait_for_import(zone_import.id)
+ def test_create_zone_import_overquota(self):
+ self.config(
+ quota_zone_records=5,
+ quota_zone_recordsets=5,
+ )
+
+ # Create a Zone Import
+ context = self.get_context(project_id=utils.generate_uuid())
+ request_body = self.get_zonefile_fixture()
+ zone_import = self.central_service.create_zone_import(context,
+ request_body)
+
+ # Ensure all values have been set correctly
+ self.assertIsNotNone(zone_import['id'])
+ self.assertEqual('PENDING', zone_import.status)
+ self.assertIsNone(zone_import.message)
+ self.assertIsNone(zone_import.zone_id)
+
+ zone_import = self.wait_for_import(zone_import.id, error_is_ok=True)
+
+ self.assertEqual('Quota exceeded during zone import.',
+ zone_import.message)
+ self.assertEqual('ERROR', zone_import.status)
+
def test_find_zone_imports(self):
context = self.get_context(project_id=utils.generate_uuid())
diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py
index e75922e3..d3bb6b41 100644
--- a/designate/tests/unit/test_central/test_basic.py
+++ b/designate/tests/unit/test_central/test_basic.py
@@ -789,13 +789,13 @@ class CentralZoneTestCase(CentralBasic):
def test_add_ns_creation(self):
self.service._create_ns = mock.Mock()
- self.service.find_recordsets = mock.Mock(
- return_value=[]
+ self.service.find_recordset = mock.Mock(
+ side_effect=exceptions.RecordSetNotFound()
)
self.service._add_ns(
self.context,
- RoObject(id=CentralZoneTestCase.zone__id),
+ RoObject(name='foo', id=CentralZoneTestCase.zone__id),
RoObject(name='bar')
)
ctx, zone, records = self.service._create_ns.call_args[0]
@@ -804,16 +804,15 @@ class CentralZoneTestCase(CentralBasic):
def test_add_ns(self):
self.service._update_recordset_in_storage = mock.Mock()
- recordsets = [
- RoObject(records=objects.RecordList.from_list([]), managed=True)
- ]
- self.service.find_recordsets = mock.Mock(
- return_value=recordsets
+ self.service.find_recordset = mock.Mock(
+ return_value=RoObject(
+ records=objects.RecordList.from_list([]), managed=True
+ )
)
self.service._add_ns(
self.context,
- RoObject(id=CentralZoneTestCase.zone__id),
+ RoObject(name='foo', id=CentralZoneTestCase.zone__id),
RoObject(name='bar')
)
ctx, zone, rset = \
@@ -822,29 +821,6 @@ class CentralZoneTestCase(CentralBasic):
self.assertTrue(rset.records[0].managed)
self.assertEqual('bar', rset.records[0].data.name)
- def test_add_ns_with_other_ns_rs(self):
- self.service._update_recordset_in_storage = mock.Mock()
-
- recordsets = [
- RoObject(records=objects.RecordList.from_list([]), managed=True),
- RoObject(records=objects.RecordList.from_list([]), managed=False)
- ]
-
- self.service.find_recordsets = mock.Mock(
- return_value=recordsets
- )
-
- self.service._add_ns(
- self.context,
- RoObject(id=CentralZoneTestCase.zone__id),
- RoObject(name='bar')
- )
- ctx, zone, rset = \
- self.service._update_recordset_in_storage.call_args[0]
- self.assertEqual(1, len(rset.records))
- self.assertTrue(rset.records[0].managed)
- self.assertEqual('bar', rset.records[0].data.name)
-
def test_create_zone_no_servers(self):
self.service._enforce_zone_quota = mock.Mock()
self.service._is_valid_zone_name = mock.Mock()