summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--heat/engine/clients/os/sahara.py23
-rw-r--r--heat/tests/clients/test_sahara_client.py23
-rw-r--r--heat/tests/openstack/sahara/test_job.py2
3 files changed, 43 insertions, 5 deletions
diff --git a/heat/engine/clients/os/sahara.py b/heat/engine/clients/os/sahara.py
index 7f82b2e29..390f00f7f 100644
--- a/heat/engine/clients/os/sahara.py
+++ b/heat/engine/clients/os/sahara.py
@@ -139,6 +139,20 @@ class SaharaClientPlugin(client_plugin.ClientPlugin):
raise exception.EntityNotFound(entity='Plugin',
name=plugin_name)
+ def get_job_type(self, job_type):
+ """Find the job type
+
+ :param job_type: the name of sahara job type to find
+ :returns: the name of :job_type:
+ :raises: exception.EntityNotFound
+ """
+ try:
+ filters = {'name': job_type}
+ return self.client().job_types.find_unique(**filters)
+ except sahara_base.APIException:
+ raise exception.EntityNotFound(entity='Job Type',
+ name=job_type)
+
class SaharaBaseConstraint(constraints.BaseCustomConstraint):
expected_exceptions = (exception.EntityNotFound,
@@ -157,6 +171,11 @@ class PluginConstraint(constraints.BaseCustomConstraint):
resource_getter_name = 'get_plugin_id'
+class JobTypeConstraint(constraints.BaseCustomConstraint):
+ resource_client_name = CLIENT_NAME
+ resource_getter_name = 'get_job_type'
+
+
class ImageConstraint(SaharaBaseConstraint):
resource_name = 'images'
@@ -175,7 +194,3 @@ class DataSourceConstraint(SaharaBaseConstraint):
class ClusterTemplateConstraint(SaharaBaseConstraint):
resource_name = 'cluster_templates'
-
-
-class JobTypeConstraint(SaharaBaseConstraint):
- resource_name = 'job_types'
diff --git a/heat/tests/clients/test_sahara_client.py b/heat/tests/clients/test_sahara_client.py
index 33189a459..a69450b7c 100644
--- a/heat/tests/clients/test_sahara_client.py
+++ b/heat/tests/clients/test_sahara_client.py
@@ -35,6 +35,7 @@ class SaharaUtilsTest(common.HeatTestCase):
self.sahara_plugin.client = lambda: self.sahara_client
self.my_image = mock.MagicMock()
self.my_plugin = mock.MagicMock()
+ self.my_jobtype = mock.MagicMock()
def test_get_image_id(self):
"""Tests the get_image_id function."""
@@ -155,12 +156,31 @@ class SaharaUtilsTest(common.HeatTestCase):
calls = [mock.call(plugin_name), mock.call(plugin_name)]
self.sahara_client.plugins.get.assert_has_calls(calls)
+ def test_get_job_type(self):
+ """Tests the get_job_type function."""
+ job_type = 'myfakejobtype'
+ self.my_jobtype = job_type
+
+ def side_effect(name):
+ if name == job_type:
+ return self.my_jobtype
+ else:
+ raise sahara_base.APIException(error_code=404,
+ error_name='NOT_FOUND')
+
+ self.sahara_client.job_types.find_unique.side_effect = side_effect
+ self.assertEqual(self.sahara_plugin.get_job_type(job_type), job_type)
+ self.assertRaises(exception.EntityNotFound,
+ self.sahara_plugin.get_job_type, 'nojobtype')
+ calls = [mock.call(name=job_type), mock.call(name='nojobtype')]
+ self.sahara_client.job_types.find_unique.assert_has_calls(calls)
+
class SaharaConstraintsTest(common.HeatTestCase):
scenarios = [
('JobType', dict(
constraint=sahara.JobTypeConstraint(),
- resource_name='job_types'
+ resource_name=None
)),
('ClusterTemplate', dict(
constraint=sahara.ClusterTemplateConstraint(),
@@ -196,6 +216,7 @@ class SaharaConstraintsTest(common.HeatTestCase):
cl_plgn.find_resource_by_name_or_id = self.mock_get
cl_plgn.get_image_id = self.mock_get
cl_plgn.get_plugin_id = self.mock_get
+ cl_plgn.get_job_type = self.mock_get
def test_validation(self):
self.mock_get.return_value = "fake_val"
diff --git a/heat/tests/openstack/sahara/test_job.py b/heat/tests/openstack/sahara/test_job.py
index 5cdc193a4..0566b8f45 100644
--- a/heat/tests/openstack/sahara/test_job.py
+++ b/heat/tests/openstack/sahara/test_job.py
@@ -70,7 +70,9 @@ class SaharaJobTest(common.HeatTestCase):
value = mock.MagicMock(id='fake-resource-id')
self.client.jobs.create.return_value = value
mock_get_res = mock.Mock(return_value='some res id')
+ mock_get_type = mock.Mock(return_value='MapReduce')
jb.client_plugin().find_resource_by_name_or_id = mock_get_res
+ jb.client_plugin().get_job_type = mock_get_type
scheduler.TaskRunner(jb.create)()
return jb