summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--designate/central/service.py106
-rw-r--r--designate/tests/unit/test_central/test_basic.py40
2 files changed, 65 insertions, 81 deletions
diff --git a/designate/central/service.py b/designate/central/service.py
index fe3fb04c..cc1539fa 100644
--- a/designate/central/service.py
+++ b/designate/central/service.py
@@ -561,51 +561,55 @@ class Service(service.RPCService):
objects.Record(data=r, managed=True) for r in ns_records])
values = {
'name': zone['name'],
- 'type': "NS",
+ 'type': 'NS',
'records': recordlist
}
ns, zone = self._create_recordset_in_storage(
context, zone, objects.RecordSet(**values),
- increment_serial=False)
+ increment_serial=False
+ )
return ns
def _add_ns(self, context, zone, ns_record):
# Get NS recordset
# If the zone doesn't have an NS recordset yet, create one
- recordsets = self.find_recordsets(
- context, criterion={'zone_id': zone['id'], 'type': "NS"}
- )
-
- managed = []
- for rs in recordsets:
- if rs.managed:
- managed.append(rs)
-
- if len(managed) == 0:
+ try:
+ recordset = self.find_recordset(
+ context,
+ criterion={
+ 'zone_id': zone['id'],
+ 'name': zone['name'],
+ 'type': 'NS'
+ }
+ )
+ except exceptions.RecordSetNotFound:
self._create_ns(context, zone, [ns_record])
return
- elif len(managed) != 1:
- raise exceptions.RecordSetNotFound("No valid recordset found")
-
- ns_recordset = managed[0]
# Add new record to recordset based on the new nameserver
- ns_recordset.records.append(
- objects.Record(data=ns_record, managed=True))
+ recordset.records.append(
+ objects.Record(data=ns_record, managed=True)
+ )
- self._update_recordset_in_storage(context, zone, ns_recordset,
+ self._update_recordset_in_storage(context, zone, recordset,
set_delayed_notify=True)
def _delete_ns(self, context, zone, ns_record):
- ns_recordset = self.find_recordset(
- context, criterion={'zone_id': zone['id'], 'type': "NS"})
+ recordset = self.find_recordset(
+ context,
+ criterion={
+ 'zone_id': zone['id'],
+ 'name': zone['name'],
+ 'type': 'NS'
+ }
+ )
- for record in copy.deepcopy(ns_recordset.records):
+ for record in list(recordset.records):
if record.data == ns_record:
- ns_recordset.records.remove(record)
+ recordset.records.remove(record)
- self._update_recordset_in_storage(context, zone, ns_recordset,
+ self._update_recordset_in_storage(context, zone, recordset,
set_delayed_notify=True)
# Quota Enforcement Methods
@@ -2505,46 +2509,49 @@ class Service(service.RPCService):
@notification('dns.pool.update')
@transaction
def update_pool(self, context, pool):
-
policy.check('update_pool', context)
# If there is a nameserver, then additional steps need to be done
# Since these are treated as mutable objects, we're only going to
# be comparing the nameserver.value which is the FQDN
- if pool.obj_attr_is_set('ns_records'):
- elevated_context = context.elevated(all_tenants=True)
+ elevated_context = context.elevated(all_tenants=True)
- # TODO(kiall): ListObjects should be able to give you their
- # original set of values.
- original_pool_ns_records = self._get_pool_ns_records(context,
- pool.id)
- # Find the current NS hostnames
- existing_ns = set([n.hostname for n in original_pool_ns_records])
+ # TODO(kiall): ListObjects should be able to give you their
+ # original set of values.
+ original_pool_ns_records = self._get_pool_ns_records(
+ context, pool.id
+ )
- # Find the desired NS hostnames
- request_ns = set([n.hostname for n in pool.ns_records])
+ updated_pool = self.storage.update_pool(context, pool)
- # Get the NS's to be created and deleted, ignoring the ones that
- # are in both sets, as those haven't changed.
- # TODO(kiall): Factor in priority
- create_ns = request_ns.difference(existing_ns)
- delete_ns = existing_ns.difference(request_ns)
+ if not pool.obj_attr_is_set('ns_records'):
+ return updated_pool
- updated_pool = self.storage.update_pool(context, pool)
+ # Find the current NS hostnames
+ existing_ns = set([n.hostname for n in original_pool_ns_records])
+
+ # Find the desired NS hostnames
+ request_ns = set([n.hostname for n in pool.ns_records])
+
+ # Get the NS's to be created and deleted, ignoring the ones that
+ # are in both sets, as those haven't changed.
+ # TODO(kiall): Factor in priority
+ create_ns = request_ns.difference(existing_ns)
+ delete_ns = existing_ns.difference(request_ns)
# After the update, handle new ns_records
- for ns in create_ns:
+ for ns_record in create_ns:
# Create new NS recordsets for every zone
zones = self.find_zones(
context=elevated_context,
criterion={'pool_id': pool.id, 'action': '!DELETE'})
- for z in zones:
- self._add_ns(elevated_context, z, ns)
+ for zone in zones:
+ self._add_ns(elevated_context, zone, ns_record)
# Then handle the ns_records to delete
- for ns in delete_ns:
+ for ns_record in delete_ns:
# Cannot delete the last nameserver, so verify that first.
- if len(pool.ns_records) == 0:
+ if not pool.ns_records:
raise exceptions.LastServerDeleteNotAllowed(
"Not allowed to delete last of servers"
)
@@ -2552,9 +2559,10 @@ class Service(service.RPCService):
# Delete the NS record for every zone
zones = self.find_zones(
context=elevated_context,
- criterion={'pool_id': pool.id})
- for z in zones:
- self._delete_ns(elevated_context, z, ns)
+ criterion={'pool_id': pool.id}
+ )
+ for zone in zones:
+ self._delete_ns(elevated_context, zone, ns_record)
return updated_pool
diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py
index e75922e3..d3bb6b41 100644
--- a/designate/tests/unit/test_central/test_basic.py
+++ b/designate/tests/unit/test_central/test_basic.py
@@ -789,13 +789,13 @@ class CentralZoneTestCase(CentralBasic):
def test_add_ns_creation(self):
self.service._create_ns = mock.Mock()
- self.service.find_recordsets = mock.Mock(
- return_value=[]
+ self.service.find_recordset = mock.Mock(
+ side_effect=exceptions.RecordSetNotFound()
)
self.service._add_ns(
self.context,
- RoObject(id=CentralZoneTestCase.zone__id),
+ RoObject(name='foo', id=CentralZoneTestCase.zone__id),
RoObject(name='bar')
)
ctx, zone, records = self.service._create_ns.call_args[0]
@@ -804,16 +804,15 @@ class CentralZoneTestCase(CentralBasic):
def test_add_ns(self):
self.service._update_recordset_in_storage = mock.Mock()
- recordsets = [
- RoObject(records=objects.RecordList.from_list([]), managed=True)
- ]
- self.service.find_recordsets = mock.Mock(
- return_value=recordsets
+ self.service.find_recordset = mock.Mock(
+ return_value=RoObject(
+ records=objects.RecordList.from_list([]), managed=True
+ )
)
self.service._add_ns(
self.context,
- RoObject(id=CentralZoneTestCase.zone__id),
+ RoObject(name='foo', id=CentralZoneTestCase.zone__id),
RoObject(name='bar')
)
ctx, zone, rset = \
@@ -822,29 +821,6 @@ class CentralZoneTestCase(CentralBasic):
self.assertTrue(rset.records[0].managed)
self.assertEqual('bar', rset.records[0].data.name)
- def test_add_ns_with_other_ns_rs(self):
- self.service._update_recordset_in_storage = mock.Mock()
-
- recordsets = [
- RoObject(records=objects.RecordList.from_list([]), managed=True),
- RoObject(records=objects.RecordList.from_list([]), managed=False)
- ]
-
- self.service.find_recordsets = mock.Mock(
- return_value=recordsets
- )
-
- self.service._add_ns(
- self.context,
- RoObject(id=CentralZoneTestCase.zone__id),
- RoObject(name='bar')
- )
- ctx, zone, rset = \
- self.service._update_recordset_in_storage.call_args[0]
- self.assertEqual(1, len(rset.records))
- self.assertTrue(rset.records[0].managed)
- self.assertEqual('bar', rset.records[0].data.name)
-
def test_create_zone_no_servers(self):
self.service._enforce_zone_quota = mock.Mock()
self.service._is_valid_zone_name = mock.Mock()