diff options
-rw-r--r-- | heat/engine/clients/os/sahara.py | 23 | ||||
-rw-r--r-- | heat/tests/clients/test_sahara_client.py | 23 | ||||
-rw-r--r-- | heat/tests/openstack/sahara/test_job.py | 2 |
3 files changed, 43 insertions, 5 deletions
diff --git a/heat/engine/clients/os/sahara.py b/heat/engine/clients/os/sahara.py index 7f82b2e29..390f00f7f 100644 --- a/heat/engine/clients/os/sahara.py +++ b/heat/engine/clients/os/sahara.py @@ -139,6 +139,20 @@ class SaharaClientPlugin(client_plugin.ClientPlugin): raise exception.EntityNotFound(entity='Plugin', name=plugin_name) + def get_job_type(self, job_type): + """Find the job type + + :param job_type: the name of sahara job type to find + :returns: the name of :job_type: + :raises: exception.EntityNotFound + """ + try: + filters = {'name': job_type} + return self.client().job_types.find_unique(**filters) + except sahara_base.APIException: + raise exception.EntityNotFound(entity='Job Type', + name=job_type) + class SaharaBaseConstraint(constraints.BaseCustomConstraint): expected_exceptions = (exception.EntityNotFound, @@ -157,6 +171,11 @@ class PluginConstraint(constraints.BaseCustomConstraint): resource_getter_name = 'get_plugin_id' +class JobTypeConstraint(constraints.BaseCustomConstraint): + resource_client_name = CLIENT_NAME + resource_getter_name = 'get_job_type' + + class ImageConstraint(SaharaBaseConstraint): resource_name = 'images' @@ -175,7 +194,3 @@ class DataSourceConstraint(SaharaBaseConstraint): class ClusterTemplateConstraint(SaharaBaseConstraint): resource_name = 'cluster_templates' - - -class JobTypeConstraint(SaharaBaseConstraint): - resource_name = 'job_types' diff --git a/heat/tests/clients/test_sahara_client.py b/heat/tests/clients/test_sahara_client.py index 33189a459..a69450b7c 100644 --- a/heat/tests/clients/test_sahara_client.py +++ b/heat/tests/clients/test_sahara_client.py @@ -35,6 +35,7 @@ class SaharaUtilsTest(common.HeatTestCase): self.sahara_plugin.client = lambda: self.sahara_client self.my_image = mock.MagicMock() self.my_plugin = mock.MagicMock() + self.my_jobtype = mock.MagicMock() def test_get_image_id(self): """Tests the get_image_id function.""" @@ -155,12 +156,31 @@ class SaharaUtilsTest(common.HeatTestCase): calls = [mock.call(plugin_name), mock.call(plugin_name)] self.sahara_client.plugins.get.assert_has_calls(calls) + def test_get_job_type(self): + """Tests the get_job_type function.""" + job_type = 'myfakejobtype' + self.my_jobtype = job_type + + def side_effect(name): + if name == job_type: + return self.my_jobtype + else: + raise sahara_base.APIException(error_code=404, + error_name='NOT_FOUND') + + self.sahara_client.job_types.find_unique.side_effect = side_effect + self.assertEqual(self.sahara_plugin.get_job_type(job_type), job_type) + self.assertRaises(exception.EntityNotFound, + self.sahara_plugin.get_job_type, 'nojobtype') + calls = [mock.call(name=job_type), mock.call(name='nojobtype')] + self.sahara_client.job_types.find_unique.assert_has_calls(calls) + class SaharaConstraintsTest(common.HeatTestCase): scenarios = [ ('JobType', dict( constraint=sahara.JobTypeConstraint(), - resource_name='job_types' + resource_name=None )), ('ClusterTemplate', dict( constraint=sahara.ClusterTemplateConstraint(), @@ -196,6 +216,7 @@ class SaharaConstraintsTest(common.HeatTestCase): cl_plgn.find_resource_by_name_or_id = self.mock_get cl_plgn.get_image_id = self.mock_get cl_plgn.get_plugin_id = self.mock_get + cl_plgn.get_job_type = self.mock_get def test_validation(self): self.mock_get.return_value = "fake_val" diff --git a/heat/tests/openstack/sahara/test_job.py b/heat/tests/openstack/sahara/test_job.py index 5cdc193a4..0566b8f45 100644 --- a/heat/tests/openstack/sahara/test_job.py +++ b/heat/tests/openstack/sahara/test_job.py @@ -70,7 +70,9 @@ class SaharaJobTest(common.HeatTestCase): value = mock.MagicMock(id='fake-resource-id') self.client.jobs.create.return_value = value mock_get_res = mock.Mock(return_value='some res id') + mock_get_type = mock.Mock(return_value='MapReduce') jb.client_plugin().find_resource_by_name_or_id = mock_get_res + jb.client_plugin().get_job_type = mock_get_type scheduler.TaskRunner(jb.create)() return jb |