diff options
-rw-r--r-- | designate/central/service.py | 215 | ||||
-rw-r--r-- | designate/tests/__init__.py | 26 | ||||
-rw-r--r-- | designate/tests/test_api/test_v2/test_import_export.py | 6 | ||||
-rw-r--r-- | designate/tests/test_central/test_service.py | 24 | ||||
-rw-r--r-- | designate/tests/unit/test_central/test_basic.py | 40 |
5 files changed, 157 insertions, 154 deletions
diff --git a/designate/central/service.py b/designate/central/service.py index fe3fb04c..99654e30 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -26,9 +26,8 @@ import random from random import SystemRandom import time -from eventlet import tpool -from dns import zone as dnszone from dns import exception as dnsexception +from dns import zone as dnszone from oslo_config import cfg import oslo_messaging as messaging from oslo_log import log as logging @@ -561,51 +560,55 @@ class Service(service.RPCService): objects.Record(data=r, managed=True) for r in ns_records]) values = { 'name': zone['name'], - 'type': "NS", + 'type': 'NS', 'records': recordlist } ns, zone = self._create_recordset_in_storage( context, zone, objects.RecordSet(**values), - increment_serial=False) + increment_serial=False + ) return ns def _add_ns(self, context, zone, ns_record): # Get NS recordset # If the zone doesn't have an NS recordset yet, create one - recordsets = self.find_recordsets( - context, criterion={'zone_id': zone['id'], 'type': "NS"} - ) - - managed = [] - for rs in recordsets: - if rs.managed: - managed.append(rs) - - if len(managed) == 0: + try: + recordset = self.find_recordset( + context, + criterion={ + 'zone_id': zone['id'], + 'name': zone['name'], + 'type': 'NS' + } + ) + except exceptions.RecordSetNotFound: self._create_ns(context, zone, [ns_record]) return - elif len(managed) != 1: - raise exceptions.RecordSetNotFound("No valid recordset found") - - ns_recordset = managed[0] # Add new record to recordset based on the new nameserver - ns_recordset.records.append( - objects.Record(data=ns_record, managed=True)) + recordset.records.append( + objects.Record(data=ns_record, managed=True) + ) - self._update_recordset_in_storage(context, zone, ns_recordset, + self._update_recordset_in_storage(context, zone, recordset, set_delayed_notify=True) def _delete_ns(self, context, zone, ns_record): - ns_recordset = self.find_recordset( - context, criterion={'zone_id': zone['id'], 'type': "NS"}) + recordset = self.find_recordset( + context, + criterion={ + 'zone_id': zone['id'], + 'name': zone['name'], + 'type': 'NS' + } + ) - for record in copy.deepcopy(ns_recordset.records): + for record in list(recordset.records): if record.data == ns_record: - ns_recordset.records.remove(record) + recordset.records.remove(record) - self._update_recordset_in_storage(context, zone, ns_recordset, + self._update_recordset_in_storage(context, zone, recordset, set_delayed_notify=True) # Quota Enforcement Methods @@ -2505,46 +2508,49 @@ class Service(service.RPCService): @notification('dns.pool.update') @transaction def update_pool(self, context, pool): - policy.check('update_pool', context) # If there is a nameserver, then additional steps need to be done # Since these are treated as mutable objects, we're only going to # be comparing the nameserver.value which is the FQDN - if pool.obj_attr_is_set('ns_records'): - elevated_context = context.elevated(all_tenants=True) + elevated_context = context.elevated(all_tenants=True) - # TODO(kiall): ListObjects should be able to give you their - # original set of values. - original_pool_ns_records = self._get_pool_ns_records(context, - pool.id) - # Find the current NS hostnames - existing_ns = set([n.hostname for n in original_pool_ns_records]) + # TODO(kiall): ListObjects should be able to give you their + # original set of values. + original_pool_ns_records = self._get_pool_ns_records( + context, pool.id + ) - # Find the desired NS hostnames - request_ns = set([n.hostname for n in pool.ns_records]) + updated_pool = self.storage.update_pool(context, pool) - # Get the NS's to be created and deleted, ignoring the ones that - # are in both sets, as those haven't changed. - # TODO(kiall): Factor in priority - create_ns = request_ns.difference(existing_ns) - delete_ns = existing_ns.difference(request_ns) + if not pool.obj_attr_is_set('ns_records'): + return updated_pool - updated_pool = self.storage.update_pool(context, pool) + # Find the current NS hostnames + existing_ns = set([n.hostname for n in original_pool_ns_records]) + + # Find the desired NS hostnames + request_ns = set([n.hostname for n in pool.ns_records]) + + # Get the NS's to be created and deleted, ignoring the ones that + # are in both sets, as those haven't changed. + # TODO(kiall): Factor in priority + create_ns = request_ns.difference(existing_ns) + delete_ns = existing_ns.difference(request_ns) # After the update, handle new ns_records - for ns in create_ns: + for ns_record in create_ns: # Create new NS recordsets for every zone zones = self.find_zones( context=elevated_context, criterion={'pool_id': pool.id, 'action': '!DELETE'}) - for z in zones: - self._add_ns(elevated_context, z, ns) + for zone in zones: + self._add_ns(elevated_context, zone, ns_record) # Then handle the ns_records to delete - for ns in delete_ns: + for ns_record in delete_ns: # Cannot delete the last nameserver, so verify that first. - if len(pool.ns_records) == 0: + if not pool.ns_records: raise exceptions.LastServerDeleteNotAllowed( "Not allowed to delete last of servers" ) @@ -2552,9 +2558,10 @@ class Service(service.RPCService): # Delete the NS record for every zone zones = self.find_zones( context=elevated_context, - criterion={'pool_id': pool.id}) - for z in zones: - self._delete_ns(elevated_context, z, ns) + criterion={'pool_id': pool.id} + ) + for zone in zones: + self._delete_ns(elevated_context, zone, ns_record) return updated_pool @@ -2990,7 +2997,6 @@ class Service(service.RPCService): @rpc.expected_exceptions() @notification('dns.zone_import.create') def create_zone_import(self, context, request_body): - if policy.enforce_new_defaults(): target = {constants.RBAC_PROJECT_ID: context.project_id} else: @@ -3010,59 +3016,49 @@ class Service(service.RPCService): zone_import = objects.ZoneImport(**values) created_zone_import = self.storage.create_zone_import(context, - zone_import) + zone_import) self.tg.add_thread(self._import_zone, context, created_zone_import, - request_body) + request_body) return created_zone_import def _import_zone(self, context, zone_import, request_body): - - def _import(self, context, zone_import, request_body): - # Dnspython needs a str instead of a unicode object - zone = None - try: - dnspython_zone = dnszone.from_text( - request_body, - # Don't relativize, or we end up with '@' record names. - relativize=False, - # Don't check origin, we allow missing NS records - # (missing SOA records are taken care of in _create_zone). - check_origin=False) - zone = dnsutils.from_dnspython_zone(dnspython_zone) - zone.type = 'PRIMARY' - - for rrset in list(zone.recordsets): - if rrset.type == 'SOA': - zone.recordsets.remove(rrset) - # subdomain NS records should be kept - elif rrset.type == 'NS' and rrset.name == zone.name: - zone.recordsets.remove(rrset) - - except dnszone.UnknownOrigin: - zone_import.message = ('The $ORIGIN statement is required and' - ' must be the first statement in the' - ' zonefile.') - zone_import.status = 'ERROR' - except dnsexception.SyntaxError: - zone_import.message = 'Malformed zonefile.' - zone_import.status = 'ERROR' - except exceptions.BadRequest: - zone_import.message = 'An SOA record is required.' - zone_import.status = 'ERROR' - except Exception as e: - LOG.exception('An undefined error occurred during zone import') - msg = 'An undefined error occurred. %s'\ - % str(e)[:130] - zone_import.message = msg - zone_import.status = 'ERROR' - - return zone, zone_import - - # Execute the import in a real Python thread - zone, zone_import = tpool.execute(_import, self, context, - zone_import, request_body) + zone = None + try: + dnspython_zone = dnszone.from_text( + request_body, + # Don't relativize, or we end up with '@' record names. + relativize=False, + # Don't check origin, we allow missing NS records + # (missing SOA records are taken care of in _create_zone). + check_origin=False) + zone = dnsutils.from_dnspython_zone(dnspython_zone) + zone.type = 'PRIMARY' + for rrset in list(zone.recordsets): + if rrset.type == 'SOA': + zone.recordsets.remove(rrset) + # subdomain NS records should be kept + elif rrset.type == 'NS' and rrset.name == zone.name: + zone.recordsets.remove(rrset) + except dnszone.UnknownOrigin: + zone_import.message = ( + 'The $ORIGIN statement is required and must be the first ' + 'statement in the zonefile.' + ) + zone_import.status = 'ERROR' + except dnsexception.SyntaxError: + zone_import.message = 'Malformed zonefile.' + zone_import.status = 'ERROR' + except exceptions.BadRequest: + zone_import.message = 'An SOA record is required.' + zone_import.status = 'ERROR' + except Exception as e: + LOG.exception('An undefined error occurred during zone import') + zone_import.message = ( + 'An undefined error occurred. %s' % str(e)[:130] + ) + zone_import.status = 'ERROR' # If the zone import was valid, create the zone if zone_import.status != 'ERROR': @@ -3070,27 +3066,32 @@ class Service(service.RPCService): zone = self.create_zone(context, zone) zone_import.status = 'COMPLETE' zone_import.zone_id = zone.id - zone_import.message = '%(name)s imported' % {'name': - zone.name} + zone_import.message = ( + '%(name)s imported' % {'name': zone.name} + ) except exceptions.DuplicateZone: zone_import.status = 'ERROR' zone_import.message = 'Duplicate zone.' except exceptions.InvalidTTL as e: zone_import.status = 'ERROR' zone_import.message = str(e) + except exceptions.OverQuota: + zone_import.status = 'ERROR' + zone_import.message = 'Quota exceeded during zone import.' except Exception as e: - LOG.exception('An undefined error occurred during zone ' - 'import creation') - msg = 'An undefined error occurred. %s'\ - % str(e)[:130] - zone_import.message = msg + LOG.exception( + 'An undefined error occurred during zone import creation' + ) + zone_import.message = ( + 'An undefined error occurred. %s' % str(e)[:130] + ) zone_import.status = 'ERROR' self.update_zone_import(context, zone_import) @rpc.expected_exceptions() def find_zone_imports(self, context, criterion=None, marker=None, - limit=None, sort_key=None, sort_dir=None): + limit=None, sort_key=None, sort_dir=None): if policy.enforce_new_defaults(): target = {constants.RBAC_PROJECT_ID: context.project_id} diff --git a/designate/tests/__init__.py b/designate/tests/__init__.py index 07bf510d..24c7beaa 100644 --- a/designate/tests/__init__.py +++ b/designate/tests/__init__.py @@ -786,34 +786,36 @@ class TestCase(base.BaseTestCase): return self.storage.create_zone_export( context, objects.ZoneExport.from_dict(zone_export)) - def wait_for_import(self, zone_import_id, errorok=False): + def wait_for_import(self, zone_import_id, error_is_ok=False, max_wait=10): """ Zone imports spawn a thread to parse the zone file and insert the data. This waits for this process before continuing """ - attempts = 0 - while attempts < 20: - # Give the import a half second to complete - time.sleep(.5) - + start_time = time.time() + while True: # Retrieve it, and ensure it's the same zone_import = self.central_service.get_zone_import( - self.admin_context_all_tenants, zone_import_id) + self.admin_context_all_tenants, zone_import_id + ) # If the import is done, we're done if zone_import.status == 'COMPLETE': break # If errors are allowed, just make sure that something completed - if errorok: - if zone_import.status != 'PENDING': - break + if error_is_ok and zone_import.status != 'PENDING': + break - attempts += 1 + if (time.time() - start_time) > max_wait: + break - if not errorok: + time.sleep(0.5) + + if not error_is_ok: self.assertEqual('COMPLETE', zone_import.status) + return zone_import + def _ensure_interface(self, interface, implementation): for name in interface.__abstractmethods__: in_arginfo = inspect.getfullargspec(getattr(interface, name)) diff --git a/designate/tests/test_api/test_v2/test_import_export.py b/designate/tests/test_api/test_v2/test_import_export.py index aefc6e1d..b9e491e0 100644 --- a/designate/tests/test_api/test_v2/test_import_export.py +++ b/designate/tests/test_api/test_v2/test_import_export.py @@ -53,7 +53,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase): headers={'Content-type': 'text/dns'}) import_id = response.json_body['id'] - self.wait_for_import(import_id, errorok=True) + self.wait_for_import(import_id, error_is_ok=True) url = '/zones/tasks/imports/%s' % import_id @@ -70,7 +70,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase): headers={'Content-type': 'text/dns'}) import_id = response.json_body['id'] - self.wait_for_import(import_id, errorok=True) + self.wait_for_import(import_id, error_is_ok=True) url = '/zones/tasks/imports/%s' % import_id @@ -86,7 +86,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase): headers={'Content-type': 'text/dns'}) import_id = response.json_body['id'] - self.wait_for_import(import_id, errorok=True) + self.wait_for_import(import_id, error_is_ok=True) url = '/zones/tasks/imports/%s' % import_id diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py index 5db47b4f..7ff36c1f 100644 --- a/designate/tests/test_central/test_service.py +++ b/designate/tests/test_central/test_service.py @@ -3548,6 +3548,30 @@ class CentralServiceTest(CentralTestCase): self.wait_for_import(zone_import.id) + def test_create_zone_import_overquota(self): + self.config( + quota_zone_records=5, + quota_zone_recordsets=5, + ) + + # Create a Zone Import + context = self.get_context(project_id=utils.generate_uuid()) + request_body = self.get_zonefile_fixture() + zone_import = self.central_service.create_zone_import(context, + request_body) + + # Ensure all values have been set correctly + self.assertIsNotNone(zone_import['id']) + self.assertEqual('PENDING', zone_import.status) + self.assertIsNone(zone_import.message) + self.assertIsNone(zone_import.zone_id) + + zone_import = self.wait_for_import(zone_import.id, error_is_ok=True) + + self.assertEqual('Quota exceeded during zone import.', + zone_import.message) + self.assertEqual('ERROR', zone_import.status) + def test_find_zone_imports(self): context = self.get_context(project_id=utils.generate_uuid()) diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py index e75922e3..d3bb6b41 100644 --- a/designate/tests/unit/test_central/test_basic.py +++ b/designate/tests/unit/test_central/test_basic.py @@ -789,13 +789,13 @@ class CentralZoneTestCase(CentralBasic): def test_add_ns_creation(self): self.service._create_ns = mock.Mock() - self.service.find_recordsets = mock.Mock( - return_value=[] + self.service.find_recordset = mock.Mock( + side_effect=exceptions.RecordSetNotFound() ) self.service._add_ns( self.context, - RoObject(id=CentralZoneTestCase.zone__id), + RoObject(name='foo', id=CentralZoneTestCase.zone__id), RoObject(name='bar') ) ctx, zone, records = self.service._create_ns.call_args[0] @@ -804,16 +804,15 @@ class CentralZoneTestCase(CentralBasic): def test_add_ns(self): self.service._update_recordset_in_storage = mock.Mock() - recordsets = [ - RoObject(records=objects.RecordList.from_list([]), managed=True) - ] - self.service.find_recordsets = mock.Mock( - return_value=recordsets + self.service.find_recordset = mock.Mock( + return_value=RoObject( + records=objects.RecordList.from_list([]), managed=True + ) ) self.service._add_ns( self.context, - RoObject(id=CentralZoneTestCase.zone__id), + RoObject(name='foo', id=CentralZoneTestCase.zone__id), RoObject(name='bar') ) ctx, zone, rset = \ @@ -822,29 +821,6 @@ class CentralZoneTestCase(CentralBasic): self.assertTrue(rset.records[0].managed) self.assertEqual('bar', rset.records[0].data.name) - def test_add_ns_with_other_ns_rs(self): - self.service._update_recordset_in_storage = mock.Mock() - - recordsets = [ - RoObject(records=objects.RecordList.from_list([]), managed=True), - RoObject(records=objects.RecordList.from_list([]), managed=False) - ] - - self.service.find_recordsets = mock.Mock( - return_value=recordsets - ) - - self.service._add_ns( - self.context, - RoObject(id=CentralZoneTestCase.zone__id), - RoObject(name='bar') - ) - ctx, zone, rset = \ - self.service._update_recordset_in_storage.call_args[0] - self.assertEqual(1, len(rset.records)) - self.assertTrue(rset.records[0].managed) - self.assertEqual('bar', rset.records[0].data.name) - def test_create_zone_no_servers(self): self.service._enforce_zone_quota = mock.Mock() self.service._is_valid_zone_name = mock.Mock() |