diff options
-rw-r--r-- | designate/central/service.py | 49 | ||||
-rw-r--r-- | designate/tests/unit/test_central/test_basic.py | 42 |
2 files changed, 65 insertions, 26 deletions
diff --git a/designate/central/service.py b/designate/central/service.py index 53c9828f..94c2fb2f 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -1301,26 +1301,44 @@ class Service(service.RPCService, service.Service): return recordset - @transaction - def _create_recordset_in_storage(self, context, zone, recordset, - increment_serial=True): + def _validate_recordset(self, context, zone, recordset): - # Ensure the tenant has enough quota to continue - self._enforce_recordset_quota(context, zone) + # See if we're validating an existing or new recordset + recordset_id = None + if hasattr(recordset, 'id'): + recordset_id = recordset.id # Ensure TTL is above the minimum - ttl = getattr(recordset, 'ttl', None) + if not recordset_id: + ttl = getattr(recordset, 'ttl', None) + else: + changes = recordset.obj_get_changes() + ttl = changes.get('ttl', None) + if ttl is not None: self._is_valid_ttl(context, ttl) # Ensure the recordset name and placement is valid self._is_valid_recordset_name(context, zone, recordset.name) - self._is_valid_recordset_placement(context, zone, recordset.name, - recordset.type) + + self._is_valid_recordset_placement( + context, zone, recordset.name, recordset.type, recordset_id) + self._is_valid_recordset_placement_subzone( context, zone, recordset.name) + + # Validate the records self._is_valid_recordset_records(recordset) + @transaction + def _create_recordset_in_storage(self, context, zone, recordset, + increment_serial=True): + + # Ensure the tenant has enough quota to continue + self._enforce_recordset_quota(context, zone) + + self._validate_recordset(context, zone, recordset) + if recordset.obj_attr_is_set('records') and len(recordset.records) > 0: # Ensure the tenant has enough zone record quotas to @@ -1441,20 +1459,7 @@ class Service(service.RPCService, service.Service): def _update_recordset_in_storage(self, context, zone, recordset, increment_serial=True, set_delayed_notify=False): - changes = recordset.obj_get_changes() - - # Ensure the record name is valid - self._is_valid_recordset_name(context, zone, recordset.name) - self._is_valid_recordset_placement(context, zone, recordset.name, - recordset.type, recordset.id) - self._is_valid_recordset_placement_subzone( - context, zone, recordset.name) - self._is_valid_recordset_records(recordset) - - # Ensure TTL is above the minimum - ttl = changes.get('ttl', None) - if ttl is not None: - self._is_valid_ttl(context, ttl) + self._validate_recordset(context, zone, recordset) if increment_serial: # update the zone's status and increment the serial diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py index 58c13f4b..97806c1a 100644 --- a/designate/tests/unit/test_central/test_basic.py +++ b/designate/tests/unit/test_central/test_basic.py @@ -384,14 +384,46 @@ class CentralServiceTestCase(CentralBasic): self.context, {'a': 1} ) - def test_create_recordset_in_storage(self): - self.service._enforce_recordset_quota = mock.Mock() + def test_validate_new_recordset(self): + self.service._is_valid_recordset_name = mock.Mock() + self.service._is_valid_recordset_placement = mock.Mock() + self.service._is_valid_recordset_placement_subzone = mock.Mock() self.service._is_valid_ttl = mock.Mock() + + MockRecordSet.id = None + + self.service._validate_recordset( + self.context, Mockzone, MockRecordSet + ) + + assert self.service._is_valid_recordset_name.called + assert self.service._is_valid_recordset_placement.called + assert self.service._is_valid_recordset_placement_subzone.called + assert self.service._is_valid_ttl.called + + def test_validate_existing_recordset(self): self.service._is_valid_recordset_name = mock.Mock() self.service._is_valid_recordset_placement = mock.Mock() self.service._is_valid_recordset_placement_subzone = mock.Mock() - self.service.storage.create_recordset = mock.Mock(return_value='rs') - self.service._update_zone_in_storage = mock.Mock() + self.service._is_valid_ttl = mock.Mock() + + MockRecordSet.obj_get_changes = Mock(return_value={'ttl': 3600}) + + self.service._validate_recordset( + self.context, Mockzone, MockRecordSet + ) + + assert self.service._is_valid_recordset_name.called + assert self.service._is_valid_recordset_placement.called + assert self.service._is_valid_recordset_placement_subzone.called + assert self.service._is_valid_ttl.called + + def test_create_recordset_in_storage(self): + self.service._enforce_recordset_quota = Mock() + self.service._validate_recordset = mock.Mock() + + self.service.storage.create_recordset = Mock(return_value='rs') + self.service._update_zone_in_storage = Mock() rs, zone = self.service._create_recordset_in_storage( self.context, Mockzone(), MockRecordSet() @@ -1166,6 +1198,7 @@ class CentralzoneTestCase(CentralBasic): recordset.type = 't' recordset.id = 'i' recordset.obj_get_changes.return_value = {'ttl': 90} + recordset.ttl = 90 recordset.records = [] self.service._is_valid_recordset_name = Mock() self.service._is_valid_recordset_placement = Mock() @@ -1204,6 +1237,7 @@ class CentralzoneTestCase(CentralBasic): recordset.name = 'n' recordset.type = 't' recordset.id = 'i' + recordset.ttl = None recordset.obj_get_changes.return_value = {'ttl': None} recordset.records = [RwObject( action='a', |