summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--designate/central/service.py49
-rw-r--r--designate/tests/unit/test_central/test_basic.py42
2 files changed, 65 insertions, 26 deletions
diff --git a/designate/central/service.py b/designate/central/service.py
index 53c9828f..94c2fb2f 100644
--- a/designate/central/service.py
+++ b/designate/central/service.py
@@ -1301,26 +1301,44 @@ class Service(service.RPCService, service.Service):
return recordset
- @transaction
- def _create_recordset_in_storage(self, context, zone, recordset,
- increment_serial=True):
+ def _validate_recordset(self, context, zone, recordset):
- # Ensure the tenant has enough quota to continue
- self._enforce_recordset_quota(context, zone)
+ # See if we're validating an existing or new recordset
+ recordset_id = None
+ if hasattr(recordset, 'id'):
+ recordset_id = recordset.id
# Ensure TTL is above the minimum
- ttl = getattr(recordset, 'ttl', None)
+ if not recordset_id:
+ ttl = getattr(recordset, 'ttl', None)
+ else:
+ changes = recordset.obj_get_changes()
+ ttl = changes.get('ttl', None)
+
if ttl is not None:
self._is_valid_ttl(context, ttl)
# Ensure the recordset name and placement is valid
self._is_valid_recordset_name(context, zone, recordset.name)
- self._is_valid_recordset_placement(context, zone, recordset.name,
- recordset.type)
+
+ self._is_valid_recordset_placement(
+ context, zone, recordset.name, recordset.type, recordset_id)
+
self._is_valid_recordset_placement_subzone(
context, zone, recordset.name)
+
+ # Validate the records
self._is_valid_recordset_records(recordset)
+ @transaction
+ def _create_recordset_in_storage(self, context, zone, recordset,
+ increment_serial=True):
+
+ # Ensure the tenant has enough quota to continue
+ self._enforce_recordset_quota(context, zone)
+
+ self._validate_recordset(context, zone, recordset)
+
if recordset.obj_attr_is_set('records') and len(recordset.records) > 0:
# Ensure the tenant has enough zone record quotas to
@@ -1441,20 +1459,7 @@ class Service(service.RPCService, service.Service):
def _update_recordset_in_storage(self, context, zone, recordset,
increment_serial=True, set_delayed_notify=False):
- changes = recordset.obj_get_changes()
-
- # Ensure the record name is valid
- self._is_valid_recordset_name(context, zone, recordset.name)
- self._is_valid_recordset_placement(context, zone, recordset.name,
- recordset.type, recordset.id)
- self._is_valid_recordset_placement_subzone(
- context, zone, recordset.name)
- self._is_valid_recordset_records(recordset)
-
- # Ensure TTL is above the minimum
- ttl = changes.get('ttl', None)
- if ttl is not None:
- self._is_valid_ttl(context, ttl)
+ self._validate_recordset(context, zone, recordset)
if increment_serial:
# update the zone's status and increment the serial
diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py
index 58c13f4b..97806c1a 100644
--- a/designate/tests/unit/test_central/test_basic.py
+++ b/designate/tests/unit/test_central/test_basic.py
@@ -384,14 +384,46 @@ class CentralServiceTestCase(CentralBasic):
self.context, {'a': 1}
)
- def test_create_recordset_in_storage(self):
- self.service._enforce_recordset_quota = mock.Mock()
+ def test_validate_new_recordset(self):
+ self.service._is_valid_recordset_name = mock.Mock()
+ self.service._is_valid_recordset_placement = mock.Mock()
+ self.service._is_valid_recordset_placement_subzone = mock.Mock()
self.service._is_valid_ttl = mock.Mock()
+
+ MockRecordSet.id = None
+
+ self.service._validate_recordset(
+ self.context, Mockzone, MockRecordSet
+ )
+
+ assert self.service._is_valid_recordset_name.called
+ assert self.service._is_valid_recordset_placement.called
+ assert self.service._is_valid_recordset_placement_subzone.called
+ assert self.service._is_valid_ttl.called
+
+ def test_validate_existing_recordset(self):
self.service._is_valid_recordset_name = mock.Mock()
self.service._is_valid_recordset_placement = mock.Mock()
self.service._is_valid_recordset_placement_subzone = mock.Mock()
- self.service.storage.create_recordset = mock.Mock(return_value='rs')
- self.service._update_zone_in_storage = mock.Mock()
+ self.service._is_valid_ttl = mock.Mock()
+
+ MockRecordSet.obj_get_changes = Mock(return_value={'ttl': 3600})
+
+ self.service._validate_recordset(
+ self.context, Mockzone, MockRecordSet
+ )
+
+ assert self.service._is_valid_recordset_name.called
+ assert self.service._is_valid_recordset_placement.called
+ assert self.service._is_valid_recordset_placement_subzone.called
+ assert self.service._is_valid_ttl.called
+
+ def test_create_recordset_in_storage(self):
+ self.service._enforce_recordset_quota = Mock()
+ self.service._validate_recordset = mock.Mock()
+
+ self.service.storage.create_recordset = Mock(return_value='rs')
+ self.service._update_zone_in_storage = Mock()
rs, zone = self.service._create_recordset_in_storage(
self.context, Mockzone(), MockRecordSet()
@@ -1166,6 +1198,7 @@ class CentralzoneTestCase(CentralBasic):
recordset.type = 't'
recordset.id = 'i'
recordset.obj_get_changes.return_value = {'ttl': 90}
+ recordset.ttl = 90
recordset.records = []
self.service._is_valid_recordset_name = Mock()
self.service._is_valid_recordset_placement = Mock()
@@ -1204,6 +1237,7 @@ class CentralzoneTestCase(CentralBasic):
recordset.name = 'n'
recordset.type = 't'
recordset.id = 'i'
+ recordset.ttl = None
recordset.obj_get_changes.return_value = {'ttl': None}
recordset.records = [RwObject(
action='a',