From a762b3bf8d94ab2973f119d044618012e594392c Mon Sep 17 00:00:00 2001 From: Tony Zhang <33922283+dntczdx@users.noreply.github.com> Date: Tue, 21 May 2019 10:31:07 -0700 Subject: Retry download for metadata scripts (#771) --- .../metadata_scripts/script_retriever.py | 47 ++++++++++++--- .../tests/script_retriever_test.py | 70 +++++++++++++++++++++- 2 files changed, 107 insertions(+), 10 deletions(-) (limited to 'packages') diff --git a/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/script_retriever.py b/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/script_retriever.py index c678f99..5f92de2 100644 --- a/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/script_retriever.py +++ b/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/script_retriever.py @@ -15,11 +15,11 @@ """Retrieve and store user provided metadata scripts.""" -import ast +import functools import re import socket -import subprocess import tempfile +import time from google_compute_engine import metadata_watcher from google_compute_engine.compat import httpclient @@ -28,6 +28,37 @@ from google_compute_engine.compat import urlrequest from google_compute_engine.compat import urlretrieve +def _RetryOnUnavailable(func): + """Function decorator template to retry on a service unavailable exception.""" + + @functools.wraps(func) + def Wrapper(*args, **kwargs): + final_exception = None + for _ in range(3): + try: + response = func(*args, **kwargs) + except (httpclient.HTTPException, socket.error, urlerror.URLError) as e: + final_exception = e + time.sleep(5) + continue + else: + return response + raise final_exception + return Wrapper + + +@_RetryOnUnavailable +def _UrlOpenWithRetry(request): + """Call urlopen with retry.""" + return urlrequest.urlopen(request) + + +@_RetryOnUnavailable +def _UrlRetrieveWithRetry(url, dest): + """Call urlretrieve with retry.""" + return urlretrieve.urlretrieve(url, dest) + + class ScriptRetriever(object): """A class for retrieving and storing user provided metadata scripts.""" token_metadata_key = 'instance/service-accounts/default/token' @@ -81,8 +112,8 @@ class ScriptRetriever(object): request = urlrequest.Request(url) request.add_unredirected_header('Metadata-Flavor', 'Google') request.add_unredirected_header('Authorization', self.token) - content = urlrequest.urlopen(request).read().decode('utf-8') - except (httpclient.HTTPException, socket.error, urlerror.URLError) as e: + content = _UrlOpenWithRetry(request).read().decode('utf-8') + except Exception as e: self.logger.warning('Could not download %s. %s.', url, str(e)) return None @@ -107,7 +138,7 @@ class ScriptRetriever(object): self.logger.info('Downloading url from %s to %s.', url, dest) try: - urlretrieve.urlretrieve(url, dest) + _UrlRetrieveWithRetry(url, dest) return dest except (httpclient.HTTPException, socket.error, urlerror.URLError) as e: self.logger.warning('Could not download %s. %s.', url, str(e)) @@ -192,8 +223,10 @@ class ScriptRetriever(object): metadata_value = attribute_data.get(metadata_key) if metadata_value: self.logger.info('Found %s in metadata.', metadata_key) - script_dict[metadata_key] = self._DownloadScript( - metadata_value, dest_dir) + downloaded_dest = self._DownloadScript(metadata_value, dest_dir) + if downloaded_dest is None: + self.logger.warning('Failed to download metadata script.') + script_dict[metadata_key] = downloaded_dest return script_dict diff --git a/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/tests/script_retriever_test.py b/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/tests/script_retriever_test.py index f3f520a..6048f68 100644 --- a/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/tests/script_retriever_test.py +++ b/packages/python-google-compute-engine/google_compute_engine/metadata_scripts/tests/script_retriever_test.py @@ -19,7 +19,6 @@ import subprocess from google_compute_engine.compat import urlerror from google_compute_engine.metadata_scripts import script_retriever -from google_compute_engine.metadata_watcher import MetadataWatcher from google_compute_engine.test_compat import builtin from google_compute_engine.test_compat import mock from google_compute_engine.test_compat import unittest @@ -145,16 +144,42 @@ class ScriptRetrieverTest(unittest.TestCase): mock_retrieve.assert_called_once_with(url, self.dest) self.mock_logger.warning.assert_not_called() + @mock.patch('google_compute_engine.metadata_scripts.script_retriever.time') @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile') @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve') - def testDownloadUrlProcessError(self, mock_retrieve, mock_tempfile): + def testDownloadUrlProcessError(self, mock_retrieve, mock_tempfile, mock_time): url = 'http://www.google.com/fake/url' mock_tempfile.return_value = mock_tempfile mock_tempfile.name = self.dest - mock_retrieve.side_effect = script_retriever.socket.timeout() + mock_success = mock.Mock() + mock_success.getcode.return_value = script_retriever.httpclient.OK + # Success after 3 timeout. Since max_retry = 3, the final result is fail. + mock_retrieve.side_effect = [ + script_retriever.socket.timeout(), + script_retriever.socket.timeout(), + script_retriever.socket.timeout(), + mock_success, + ] self.assertIsNone(self.retriever._DownloadUrl(url, self.dest_dir)) self.assertEqual(self.mock_logger.warning.call_count, 1) + @mock.patch('google_compute_engine.metadata_scripts.script_retriever.time') + @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile') + @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve') + def testDownloadUrlWithRetry(self, mock_retrieve, mock_tempfile, mock_time): + url = 'http://www.google.com/fake/url' + mock_tempfile.return_value = mock_tempfile + mock_tempfile.name = self.dest + mock_success = mock.Mock() + mock_success.getcode.return_value = script_retriever.httpclient.OK + # Success after 2 timeout. Since max_retry = 3, the final result is success. + mock_retrieve.side_effect = [ + script_retriever.socket.timeout(), + script_retriever.socket.timeout(), + mock_success, + ] + self.assertIsNotNone(self.retriever._DownloadUrl(url, self.dest_dir)) + @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile') @mock.patch('google_compute_engine.metadata_scripts.script_retriever.urlretrieve.urlretrieve') def testDownloadUrlException(self, mock_retrieve, mock_tempfile): @@ -325,6 +350,7 @@ class ScriptRetrieverTest(unittest.TestCase): self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data) self.assertEqual(self.mock_logger.info.call_count, 2) + self.assertEqual(self.mock_logger.warning.call_count, 0) mock_dest.write.assert_called_once_with('a') mock_download.assert_called_once_with('b', self.dest_dir) @@ -352,6 +378,44 @@ class ScriptRetrieverTest(unittest.TestCase): self.mock_logger.info.assert_not_called() self.assertEqual(self.mock_logger.warning.call_count, 2) + @mock.patch('google_compute_engine.metadata_scripts.script_retriever.tempfile.NamedTemporaryFile') + def testGetScriptsFailed(self, mock_tempfile): + script_dest = '/tmp/script' + script_url_dest = None + metadata = { + 'instance': { + 'attributes': { + '%s-script' % self.script_type: 'a', + '%s-script-url' % self.script_type: 'b', + }, + }, + 'project': { + 'attributes': { + '%s-script' % self.script_type: 'c', + '%s-script-url' % self.script_type: 'd', + }, + }, + } + expected_data = { + '%s-script' % self.script_type: script_dest, + '%s-script-url' % self.script_type: script_url_dest, + } + self.mock_watcher.GetMetadata.return_value = metadata + self.retriever.watcher = self.mock_watcher + # Mock saving a script to a file. + mock_dest = mock.Mock() + mock_dest.name = script_dest + mock_tempfile.__enter__.return_value = mock_dest + mock_tempfile.return_value = mock_tempfile + # Mock downloading a script from a URL. + mock_download = mock.Mock() + mock_download.return_value = None + self.retriever._DownloadScript = mock_download + + self.assertEqual(self.retriever.GetScripts(self.dest_dir), expected_data) + self.assertEqual(self.mock_logger.info.call_count, 2) + self.assertEqual(self.mock_logger.warning.call_count, 1) + if __name__ == '__main__': unittest.main() -- cgit v1.2.1