summaryrefslogtreecommitdiff
path: root/tests/unittests/sources/azure/test_imds.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittests/sources/azure/test_imds.py')
-rw-r--r--tests/unittests/sources/azure/test_imds.py491
1 files changed, 491 insertions, 0 deletions
diff --git a/tests/unittests/sources/azure/test_imds.py b/tests/unittests/sources/azure/test_imds.py
new file mode 100644
index 00000000..b5a72645
--- /dev/null
+++ b/tests/unittests/sources/azure/test_imds.py
@@ -0,0 +1,491 @@
+# This file is part of cloud-init. See LICENSE file for license information.
+
+import json
+import logging
+import math
+from unittest import mock
+
+import pytest
+import requests
+
+from cloudinit.sources.azure import imds
+from cloudinit.url_helper import UrlError
+
+MOCKPATH = "cloudinit.sources.azure.imds."
+
+
+@pytest.fixture
+def mock_readurl():
+ with mock.patch(MOCKPATH + "readurl", autospec=True) as m:
+ yield m
+
+
+@pytest.fixture
+def mock_requests_session_request():
+ with mock.patch("requests.Session.request", autospec=True) as m:
+ yield m
+
+
+@pytest.fixture
+def mock_url_helper_time_sleep():
+ with mock.patch("cloudinit.url_helper.time.sleep", autospec=True) as m:
+ yield m
+
+
+def fake_http_error_for_code(status_code: int):
+ response_failure = requests.Response()
+ response_failure.status_code = status_code
+ return requests.exceptions.HTTPError(
+ "fake error",
+ response=response_failure,
+ )
+
+
+class TestFetchMetadataWithApiFallback:
+ default_url = (
+ "http://169.254.169.254/metadata/instance?"
+ "api-version=2021-08-01&extended=true"
+ )
+ fallback_url = (
+ "http://169.254.169.254/metadata/instance?api-version=2019-06-01"
+ )
+ headers = {"Metadata": "true"}
+ retries = 10
+ timeout = 2
+
+ def test_basic(
+ self,
+ caplog,
+ mock_readurl,
+ ):
+ fake_md = {"foo": {"bar": []}}
+ mock_readurl.side_effect = [
+ mock.Mock(contents=json.dumps(fake_md).encode()),
+ ]
+
+ md = imds.fetch_metadata_with_api_fallback()
+
+ assert md == fake_md
+ assert mock_readurl.mock_calls == [
+ mock.call(
+ self.default_url,
+ timeout=self.timeout,
+ headers=self.headers,
+ retries=self.retries,
+ exception_cb=imds._readurl_exception_callback,
+ infinite=False,
+ log_req_resp=True,
+ ),
+ ]
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == []
+
+ def test_basic_fallback(
+ self,
+ caplog,
+ mock_readurl,
+ ):
+ fake_md = {"foo": {"bar": []}}
+ mock_readurl.side_effect = [
+ UrlError("No IMDS version", code=400),
+ mock.Mock(contents=json.dumps(fake_md).encode()),
+ ]
+
+ md = imds.fetch_metadata_with_api_fallback()
+
+ assert md == fake_md
+ assert mock_readurl.mock_calls == [
+ mock.call(
+ self.default_url,
+ timeout=self.timeout,
+ headers=self.headers,
+ retries=self.retries,
+ exception_cb=imds._readurl_exception_callback,
+ infinite=False,
+ log_req_resp=True,
+ ),
+ mock.call(
+ self.fallback_url,
+ timeout=self.timeout,
+ headers=self.headers,
+ retries=self.retries,
+ exception_cb=imds._readurl_exception_callback,
+ infinite=False,
+ log_req_resp=True,
+ ),
+ ]
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == [
+ "Failed to fetch metadata from IMDS: No IMDS version",
+ "Falling back to IMDS api-version: 2019-06-01",
+ ]
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ fake_http_error_for_code(404),
+ fake_http_error_for_code(410),
+ fake_http_error_for_code(429),
+ fake_http_error_for_code(500),
+ requests.ConnectionError("Fake connection error"),
+ requests.Timeout("Fake connection timeout"),
+ ],
+ )
+ def test_will_retry_errors(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ error,
+ ):
+ fake_md = {"foo": {"bar": []}}
+ mock_requests_session_request.side_effect = [
+ error,
+ mock.Mock(content=json.dumps(fake_md)),
+ ]
+
+ md = imds.fetch_metadata_with_api_fallback()
+
+ assert md == fake_md
+ assert len(mock_requests_session_request.mock_calls) == 2
+ assert mock_url_helper_time_sleep.mock_calls == [mock.call(1)]
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == []
+
+ def test_will_retry_errors_on_fallback(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ ):
+ error = fake_http_error_for_code(400)
+ fake_md = {"foo": {"bar": []}}
+ mock_requests_session_request.side_effect = [
+ error,
+ fake_http_error_for_code(429),
+ mock.Mock(content=json.dumps(fake_md)),
+ ]
+
+ md = imds.fetch_metadata_with_api_fallback()
+
+ assert md == fake_md
+ assert len(mock_requests_session_request.mock_calls) == 3
+ assert mock_url_helper_time_sleep.mock_calls == [mock.call(1)]
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == [
+ "Failed to fetch metadata from IMDS: fake error",
+ "Falling back to IMDS api-version: 2019-06-01",
+ ]
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ fake_http_error_for_code(404),
+ fake_http_error_for_code(410),
+ fake_http_error_for_code(429),
+ fake_http_error_for_code(500),
+ requests.ConnectionError("Fake connection error"),
+ requests.Timeout("Fake connection timeout"),
+ ],
+ )
+ def test_retry_until_failure(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ error,
+ ):
+ mock_requests_session_request.side_effect = [error] * (11)
+
+ with pytest.raises(UrlError) as exc_info:
+ imds.fetch_metadata_with_api_fallback()
+
+ assert exc_info.value.cause == error
+ assert len(mock_requests_session_request.mock_calls) == (
+ self.retries + 1
+ )
+ assert (
+ mock_url_helper_time_sleep.mock_calls
+ == [mock.call(1)] * self.retries
+ )
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == [f"Failed to fetch metadata from IMDS: {error!s}"]
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ fake_http_error_for_code(403),
+ fake_http_error_for_code(501),
+ ],
+ )
+ def test_will_not_retry_errors(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ error,
+ ):
+ fake_md = {"foo": {"bar": []}}
+ mock_requests_session_request.side_effect = [
+ error,
+ mock.Mock(content=json.dumps(fake_md)),
+ ]
+
+ with pytest.raises(UrlError) as exc_info:
+ imds.fetch_metadata_with_api_fallback()
+
+ assert exc_info.value.cause == error
+ assert len(mock_requests_session_request.mock_calls) == 1
+ assert mock_url_helper_time_sleep.mock_calls == []
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == [f"Failed to fetch metadata from IMDS: {error!s}"]
+
+ def test_non_json_repsonse(
+ self,
+ caplog,
+ mock_readurl,
+ ):
+ mock_readurl.side_effect = [
+ mock.Mock(contents=b"bad data"),
+ ]
+
+ with pytest.raises(ValueError):
+ imds.fetch_metadata_with_api_fallback()
+
+ assert mock_readurl.mock_calls == [
+ mock.call(
+ self.default_url,
+ timeout=self.timeout,
+ headers=self.headers,
+ retries=self.retries,
+ exception_cb=imds._readurl_exception_callback,
+ infinite=False,
+ log_req_resp=True,
+ ),
+ ]
+
+ warnings = [
+ x.message for x in caplog.records if x.levelno == logging.WARNING
+ ]
+ assert warnings == [
+ (
+ "Failed to parse metadata from IMDS: "
+ "Expecting value: line 1 column 1 (char 0)"
+ )
+ ]
+
+
+class TestFetchReprovisionData:
+ url = (
+ "http://169.254.169.254/metadata/"
+ "reprovisiondata?api-version=2019-06-01"
+ )
+ headers = {"Metadata": "true"}
+ timeout = 2
+
+ def test_basic(
+ self,
+ caplog,
+ mock_readurl,
+ ):
+ content = b"ovf content"
+ mock_readurl.side_effect = [
+ mock.Mock(contents=content),
+ ]
+
+ ovf = imds.fetch_reprovision_data()
+
+ assert ovf == content
+ assert mock_readurl.mock_calls == [
+ mock.call(
+ self.url,
+ timeout=self.timeout,
+ headers=self.headers,
+ exception_cb=mock.ANY,
+ infinite=True,
+ log_req_resp=False,
+ ),
+ ]
+
+ assert caplog.record_tuples == [
+ (
+ "cloudinit.sources.azure.imds",
+ logging.DEBUG,
+ "Polled IMDS 1 time(s)",
+ )
+ ]
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ fake_http_error_for_code(404),
+ fake_http_error_for_code(410),
+ ],
+ )
+ @pytest.mark.parametrize("failures", [1, 5, 100, 1000])
+ def test_will_retry_errors(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ error,
+ failures,
+ ):
+ content = b"ovf content"
+ mock_requests_session_request.side_effect = [error] * failures + [
+ mock.Mock(content=content),
+ ]
+
+ ovf = imds.fetch_reprovision_data()
+
+ assert ovf == content
+ assert len(mock_requests_session_request.mock_calls) == failures + 1
+ assert (
+ mock_url_helper_time_sleep.mock_calls == [mock.call(1)] * failures
+ )
+
+ wrapped_error = UrlError(
+ error,
+ code=error.response.status_code,
+ headers=error.response.headers,
+ url=self.url,
+ )
+ backoff_logs = [
+ (
+ "cloudinit.sources.azure.imds",
+ logging.INFO,
+ "Polling IMDS failed with exception: "
+ f"{wrapped_error!r} count: {i}",
+ )
+ for i in range(1, failures + 1)
+ if i == 1 or math.log2(i).is_integer()
+ ]
+ assert caplog.record_tuples == backoff_logs + [
+ (
+ "cloudinit.url_helper",
+ logging.DEBUG,
+ mock.ANY,
+ ),
+ (
+ "cloudinit.sources.azure.imds",
+ logging.DEBUG,
+ f"Polled IMDS {failures+1} time(s)",
+ ),
+ ]
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ fake_http_error_for_code(404),
+ fake_http_error_for_code(410),
+ ],
+ )
+ @pytest.mark.parametrize("failures", [1, 5, 100, 1000])
+ @pytest.mark.parametrize(
+ "terminal_error",
+ [
+ requests.ConnectionError("Fake connection error"),
+ requests.Timeout("Fake connection timeout"),
+ ],
+ )
+ def test_retry_until_failure(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ error,
+ failures,
+ terminal_error,
+ ):
+ mock_requests_session_request.side_effect = [error] * failures + [
+ terminal_error
+ ]
+
+ with pytest.raises(UrlError) as exc_info:
+ imds.fetch_reprovision_data()
+
+ assert exc_info.value.cause == terminal_error
+ assert len(mock_requests_session_request.mock_calls) == (failures + 1)
+ assert (
+ mock_url_helper_time_sleep.mock_calls == [mock.call(1)] * failures
+ )
+
+ wrapped_error = UrlError(
+ error,
+ code=error.response.status_code,
+ headers=error.response.headers,
+ url=self.url,
+ )
+
+ backoff_logs = [
+ (
+ "cloudinit.sources.azure.imds",
+ logging.INFO,
+ "Polling IMDS failed with exception: "
+ f"{wrapped_error!r} count: {i}",
+ )
+ for i in range(1, failures + 1)
+ if i == 1 or math.log2(i).is_integer()
+ ]
+ assert caplog.record_tuples == backoff_logs + [
+ (
+ "cloudinit.sources.azure.imds",
+ logging.INFO,
+ "Polling IMDS failed with exception: "
+ f"{exc_info.value!r} count: {failures+1}",
+ ),
+ ]
+
+ @pytest.mark.parametrize(
+ "error",
+ [
+ fake_http_error_for_code(403),
+ fake_http_error_for_code(501),
+ ],
+ )
+ def test_will_not_retry_errors(
+ self,
+ caplog,
+ mock_requests_session_request,
+ mock_url_helper_time_sleep,
+ error,
+ ):
+ fake_md = {"foo": {"bar": []}}
+ mock_requests_session_request.side_effect = [
+ error,
+ mock.Mock(content=json.dumps(fake_md)),
+ ]
+
+ with pytest.raises(UrlError) as exc_info:
+ imds.fetch_reprovision_data()
+
+ assert exc_info.value.cause == error
+ assert len(mock_requests_session_request.mock_calls) == 1
+ assert mock_url_helper_time_sleep.mock_calls == []
+
+ assert caplog.record_tuples == [
+ (
+ "cloudinit.sources.azure.imds",
+ logging.INFO,
+ "Polling IMDS failed with exception: "
+ f"{exc_info.value!r} count: 1",
+ ),
+ ]