From 2af230f28fd1150e144b1524be24c0a04769d4ae Mon Sep 17 00:00:00 2001 From: Chris Patterson Date: Thu, 16 Feb 2023 13:16:31 -0700 Subject: sources/azure: refactor imds handler into own module (#1977) Create new azure package for better organization and move IMDS logic for fetching into it. Future work will clean up the test_azure.py tests a little further thanks to these changes, but wanted to minimize churn here to make changes fairly visible. Signed-off-by: Chris Patterson --- cloudinit/sources/DataSourceAzure.py | 326 ++----------- cloudinit/sources/azure/__init__.py | 0 cloudinit/sources/azure/imds.py | 156 +++++++ tests/unittests/sources/azure/test_imds.py | 491 ++++++++++++++++++++ tests/unittests/sources/test_azure.py | 720 +++++------------------------ 5 files changed, 799 insertions(+), 894 deletions(-) create mode 100644 cloudinit/sources/azure/__init__.py create mode 100644 cloudinit/sources/azure/imds.py create mode 100644 tests/unittests/sources/azure/test_imds.py diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 9dac4c6b..b7d3e5a3 100644 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -6,7 +6,6 @@ import base64 import crypt -import functools import os import os.path import re @@ -16,8 +15,6 @@ from pathlib import Path from time import sleep, time from typing import Any, Dict, List, Optional -import requests - from cloudinit import dmi from cloudinit import log as logging from cloudinit import net, sources, ssh_util, subp, util @@ -29,6 +26,7 @@ from cloudinit.net.dhcp import ( ) from cloudinit.net.ephemeral import EphemeralDHCPv4 from cloudinit.reporting import events +from cloudinit.sources.azure import imds from cloudinit.sources.helpers import netlink from cloudinit.sources.helpers.azure import ( DEFAULT_WIRESERVER_ENDPOINT, @@ -49,7 +47,7 @@ from cloudinit.sources.helpers.azure import ( report_diagnostic_event, report_failure_to_fabric, ) -from 
cloudinit.url_helper import UrlError, readurl, retry_on_url_exc +from cloudinit.url_helper import UrlError LOG = logging.getLogger(__name__) @@ -63,34 +61,6 @@ DEFAULT_FS = "ext4" AGENT_SEED_DIR = "/var/lib/waagent" DEFAULT_PROVISIONING_ISO_DEV = "/dev/sr0" -# In the event where the IMDS primary server is not -# available, it takes 1s to fallback to the secondary one -IMDS_TIMEOUT_IN_SECONDS = 2 -IMDS_URL = "http://169.254.169.254/metadata" -IMDS_VER_MIN = "2019-06-01" -IMDS_VER_WANT = "2021-08-01" -IMDS_EXTENDED_VER_MIN = "2021-03-01" -IMDS_RETRY_CODES = ( - 404, # not found (yet) - 410, # gone / unavailable (yet) - 429, # rate-limited/throttled - 500, # server error -) -imds_readurl_exception_callback = functools.partial( - retry_on_url_exc, - retry_codes=IMDS_RETRY_CODES, - retry_instances=( - requests.ConnectionError, - requests.Timeout, - ), -) - - -class MetadataType(Enum): - ALL = "{}/instance".format(IMDS_URL) - NETWORK = "{}/instance/network".format(IMDS_URL) - REPROVISION_DATA = "{}/reprovisiondata".format(IMDS_URL) - class PPSType(Enum): NONE = "None" @@ -593,10 +563,9 @@ class DataSourceAzure(sources.DataSource): except NoDHCPLeaseError: pass + imds_md = {} if self._is_ephemeral_networking_up(): - imds_md = self.get_imds_data_with_api_fallback(retries=10) - else: - imds_md = {} + imds_md = self.get_metadata_from_imds() if not imds_md and ovf_source is None: msg = "No OVF or IMDS available" @@ -619,7 +588,7 @@ class DataSourceAzure(sources.DataSource): md, userdata_raw, cfg, files = self._reprovision() # fetch metadata again as it has changed after reprovisioning - imds_md = self.get_imds_data_with_api_fallback(retries=10) + imds_md = self.get_metadata_from_imds() # Report errors if IMDS network configuration is missing data. 
self.validate_imds_network_metadata(imds_md=imds_md) @@ -710,6 +679,17 @@ class DataSourceAzure(sources.DataSource): return crawled_data + @azure_ds_telemetry_reporter + def get_metadata_from_imds(self) -> Dict: + try: + return imds.fetch_metadata_with_api_fallback() + except (UrlError, ValueError) as error: + report_diagnostic_event( + "Ignoring IMDS metadata due to: %s" % error, + logger_func=LOG.warning, + ) + return {} + def clear_cached_attrs(self, attr_defaults=()): """Reset any cached class attributes to defaults.""" super(DataSourceAzure, self).clear_cached_attrs(attr_defaults) @@ -795,54 +775,6 @@ class DataSourceAzure(sources.DataSource): ) return True - @azure_ds_telemetry_reporter - def get_imds_data_with_api_fallback( - self, - *, - retries: int, - md_type: MetadataType = MetadataType.ALL, - exc_cb=imds_readurl_exception_callback, - infinite: bool = False, - ) -> dict: - """Fetch metadata from IMDS using IMDS_VER_WANT API version. - - Falls back to IMDS_VER_MIN version if IMDS returns a 400 error code, - indicating that IMDS_VER_WANT is unsupported. - - :return: Parsed metadata dictionary or empty dict on error. - """ - LOG.info("Attempting IMDS api-version: %s", IMDS_VER_WANT) - try: - return get_metadata_from_imds( - retries=retries, - md_type=md_type, - api_version=IMDS_VER_WANT, - exc_cb=exc_cb, - infinite=infinite, - ) - except UrlError as error: - LOG.info("UrlError with IMDS api-version: %s", IMDS_VER_WANT) - # Fall back if HTTP code is 400, otherwise return empty dict. 
- if error.code != 400: - return {} - - log_msg = "Fall back to IMDS api-version: {}".format(IMDS_VER_MIN) - report_diagnostic_event(log_msg, logger_func=LOG.info) - try: - return get_metadata_from_imds( - retries=retries, - md_type=md_type, - api_version=IMDS_VER_MIN, - exc_cb=exc_cb, - infinite=infinite, - ) - except UrlError as error: - report_diagnostic_event( - "Failed to fetch IMDS metadata: %s" % error, - logger_func=LOG.error, - ) - return {} - def get_instance_id(self): if not self.metadata or "instance-id" not in self.metadata: return self._iid() @@ -1052,82 +984,18 @@ class DataSourceAzure(sources.DataSource): primary nic, then we also get the expected total nic count from IMDS. IMDS will process the request and send a response only for primary NIC. """ - is_primary = False - expected_nic_count = -1 - imds_md = None - metadata_poll_count = 0 - metadata_logging_threshold = 1 - expected_errors_count = 0 - # For now, only a VM's primary NIC can contact IMDS and WireServer. If # DHCP fails for a NIC, we have no mechanism to determine if the NIC is # primary or secondary. In this case, retry DHCP until successful. self._setup_ephemeral_networking(iface=ifname, timeout_minutes=20) - # Retry polling network metadata for a limited duration only when the - # calls fail due to network unreachable error or timeout. - # This is because the platform drops packets going towards IMDS - # when it is not a primary nic. If the calls fail due to other issues - # like 410, 503 etc, then it means we are primary but IMDS service - # is unavailable at the moment. Retry indefinitely in those cases - # since we cannot move on without the network metadata. In the future, - # all this will not be necessary, as a new dhcp option would tell - # whether the nic is primary or not. 
- def network_metadata_exc_cb(msg, exc): - nonlocal expected_errors_count, metadata_poll_count - nonlocal metadata_logging_threshold - - metadata_poll_count = metadata_poll_count + 1 - - # Log when needed but back off exponentially to avoid exploding - # the log file. - if metadata_poll_count >= metadata_logging_threshold: - metadata_logging_threshold *= 2 - report_diagnostic_event( - "Ran into exception when attempting to reach %s " - "after %d polls." % (msg, metadata_poll_count), - logger_func=LOG.error, - ) - - if isinstance(exc, UrlError): - report_diagnostic_event( - "poll IMDS with %s failed. Exception: %s and code: %s" - % (msg, exc.cause, exc.code), - logger_func=LOG.error, - ) - - # Retry up to a certain limit for both timeout and network - # unreachable errors. - if exc.cause and isinstance( - exc.cause, (requests.Timeout, requests.ConnectionError) - ): - expected_errors_count = expected_errors_count + 1 - return expected_errors_count <= 10 - return True - # Primary nic detection will be optimized in the future. The fact that # primary nic is being attached first helps here. Otherwise each nic # could add several seconds of delay. - try: - imds_md = self.get_imds_data_with_api_fallback( - retries=0, - md_type=MetadataType.NETWORK, - exc_cb=network_metadata_exc_cb, - infinite=True, - ) - except Exception as e: - LOG.warning( - "Failed to get network metadata using nic %s. Attempt to " - "contact IMDS failed with error %s. Assuming this is not the " - "primary nic.", - ifname, - e, - ) - + imds_md = self.get_metadata_from_imds() if imds_md: # Only primary NIC will get a response from IMDS. LOG.info("%s is the primary nic", ifname) - is_primary = True # Set the expected nic count based on the response received. 
expected_nic_count = len(imds_md["interface"]) @@ -1135,11 +1003,16 @@ class DataSourceAzure(sources.DataSource): "Expected nic count: %d" % expected_nic_count, logger_func=LOG.info, ) - else: - # If we are not the primary nic, then clean the dhcp context. - self._teardown_ephemeral_networking() + return True, expected_nic_count - return is_primary, expected_nic_count + # If we are not the primary nic, then clean the dhcp context. + LOG.warning( + "Failed to fetch IMDS metadata using nic %s. " + "Assuming this is not the primary nic.", + ifname, + ) + self._teardown_ephemeral_networking() + return False, -1 @azure_ds_telemetry_reporter def _wait_for_hot_attached_primary_nic(self, nl_sock): @@ -1229,54 +1102,11 @@ class DataSourceAzure(sources.DataSource): def _poll_imds(self): """Poll IMDS for the new provisioning data until we get a valid response. Then return the returned JSON object.""" - url = "{}?api-version={}".format( - MetadataType.REPROVISION_DATA.value, IMDS_VER_MIN - ) - headers = {"Metadata": "true"} nl_sock = None report_ready = bool( not os.path.isfile(self._reported_ready_marker_file) ) - self.imds_logging_threshold = 1 - self.imds_poll_counter = 1 dhcp_attempts = 0 - reprovision_data = None - - def exc_cb(msg, exception): - if isinstance(exception, UrlError): - if exception.code in (404, 410): - if self.imds_poll_counter == self.imds_logging_threshold: - # Reducing the logging frequency as we are polling IMDS - self.imds_logging_threshold *= 2 - LOG.debug( - "Backing off logging threshold for the same " - "exception to %d", - self.imds_logging_threshold, - ) - report_diagnostic_event( - "poll IMDS with %s failed. " - "Exception: %s and code: %s" - % (msg, exception.cause, exception.code), - logger_func=LOG.debug, - ) - self.imds_poll_counter += 1 - return True - else: - # If we get an exception while trying to call IMDS, we call - # DHCP and setup the ephemeral network to acquire a new IP. - report_diagnostic_event( - "poll IMDS with %s failed. 
Exception: %s and code: %s" - % (msg, exception.cause, exception.code), - logger_func=LOG.warning, - ) - return False - - report_diagnostic_event( - "poll IMDS failed with an unexpected exception: %s" - % exception, - logger_func=LOG.warning, - ) - return False if report_ready: # Networking must be up for netlink to detect @@ -1338,6 +1168,7 @@ class DataSourceAzure(sources.DataSource): # Teardown old network configuration. self._teardown_ephemeral_networking() + reprovision_data = None while not reprovision_data: if not self._is_ephemeral_networking_up(): dhcp_attempts += 1 @@ -1352,14 +1183,7 @@ class DataSourceAzure(sources.DataSource): parent=azure_ds_reporter, ): try: - reprovision_data = readurl( - url, - timeout=IMDS_TIMEOUT_IN_SECONDS, - headers=headers, - exception_cb=exc_cb, - infinite=True, - log_req_resp=False, - ).contents + reprovision_data = imds.fetch_reprovision_data() except UrlError: self._teardown_ephemeral_networking() continue @@ -1368,10 +1192,6 @@ class DataSourceAzure(sources.DataSource): "attempted dhcp %d times after reuse" % dhcp_attempts, logger_func=LOG.debug, ) - report_diagnostic_event( - "polled imds %d times after reuse" % self.imds_poll_counter, - logger_func=LOG.debug, - ) return reprovision_data @@ -2088,96 +1908,6 @@ def _generate_network_config_from_fallback_config() -> dict: return cfg -@azure_ds_telemetry_reporter -def get_metadata_from_imds( - retries, - md_type=MetadataType.ALL, - api_version=IMDS_VER_MIN, - exc_cb=imds_readurl_exception_callback, - infinite=False, -): - """Query Azure's instance metadata service, returning a dictionary. - - For more info on IMDS: - https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service - - @param retries: The number of retries of the IMDS_URL. - @param md_type: Metadata type for IMDS request. - @param api_version: IMDS api-version to use in the request. - - @return: A dict of instance metadata containing compute and network - info. 
- """ - kwargs = { - "logfunc": LOG.debug, - "msg": "Crawl of Azure Instance Metadata Service (IMDS)", - "func": _get_metadata_from_imds, - "args": (retries, exc_cb, md_type, api_version, infinite), - } - try: - return util.log_time(**kwargs) - except Exception as e: - report_diagnostic_event( - "exception while getting metadata: %s" % e, - logger_func=LOG.warning, - ) - raise - - -@azure_ds_telemetry_reporter -def _get_metadata_from_imds( - retries, - exc_cb, - md_type=MetadataType.ALL, - api_version=IMDS_VER_MIN, - infinite=False, -): - url = "{}?api-version={}".format(md_type.value, api_version) - headers = {"Metadata": "true"} - - # support for extended metadata begins with 2021-03-01 - if api_version >= IMDS_EXTENDED_VER_MIN and md_type == MetadataType.ALL: - url = url + "&extended=true" - - try: - response = readurl( - url, - timeout=IMDS_TIMEOUT_IN_SECONDS, - headers=headers, - retries=retries, - exception_cb=exc_cb, - infinite=infinite, - ) - except Exception as e: - # pylint:disable=no-member - if isinstance(e, UrlError) and e.code == 400: - raise - else: - report_diagnostic_event( - "Ignoring IMDS instance metadata. " - "Get metadata from IMDS failed: %s" % e, - logger_func=LOG.warning, - ) - return {} - try: - from json.decoder import JSONDecodeError - - json_decode_error = JSONDecodeError - except ImportError: - json_decode_error = ValueError - - try: - return util.load_json(response.contents) - except json_decode_error as e: - report_diagnostic_event( - "Ignoring non-json IMDS instance metadata response: %s. " - "Loading non-json IMDS response failed: %s" - % (response.contents, e), - logger_func=LOG.warning, - ) - return {} - - @azure_ds_telemetry_reporter def maybe_remove_ubuntu_network_config_scripts(paths=None): """Remove Azure-specific ubuntu network config for non-primary nics. 
diff --git a/cloudinit/sources/azure/__init__.py b/cloudinit/sources/azure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cloudinit/sources/azure/imds.py b/cloudinit/sources/azure/imds.py new file mode 100644 index 00000000..54fc9a05 --- /dev/null +++ b/cloudinit/sources/azure/imds.py @@ -0,0 +1,156 @@ +# Copyright (C) 2022 Microsoft Corporation. +# +# This file is part of cloud-init. See LICENSE file for license information. + +import functools +from typing import Dict + +import requests + +from cloudinit import log as logging +from cloudinit import util +from cloudinit.sources.helpers.azure import report_diagnostic_event +from cloudinit.url_helper import UrlError, readurl, retry_on_url_exc + +LOG = logging.getLogger(__name__) + +IMDS_URL = "http://169.254.169.254/metadata" + +_readurl_exception_callback = functools.partial( + retry_on_url_exc, + retry_codes=( + 404, # not found (yet) + 410, # gone / unavailable (yet) + 429, # rate-limited/throttled + 500, # server error + ), + retry_instances=( + requests.ConnectionError, + requests.Timeout, + ), +) + + +def _fetch_url( + url: str, *, log_response: bool = True, retries: int = 10, timeout: int = 2 +) -> bytes: + """Fetch URL from IMDS. + + :raises UrlError: on error fetching metadata. + """ + + try: + response = readurl( + url, + exception_cb=_readurl_exception_callback, + headers={"Metadata": "true"}, + infinite=False, + log_req_resp=log_response, + retries=retries, + timeout=timeout, + ) + except UrlError as error: + report_diagnostic_event( + "Failed to fetch metadata from IMDS: %s" % error, + logger_func=LOG.warning, + ) + raise + + return response.contents + + +def _fetch_metadata( + url: str, +) -> Dict: + """Fetch IMDS metadata. + + :raises UrlError: on error fetching metadata. + :raises ValueError: on error parsing metadata. 
+ """ + metadata = _fetch_url(url) + + try: + return util.load_json(metadata) + except ValueError as error: + report_diagnostic_event( + "Failed to parse metadata from IMDS: %s" % error, + logger_func=LOG.warning, + ) + raise + + +def fetch_metadata_with_api_fallback() -> Dict: + """Fetch extended metadata, falling back to non-extended as required. + + :raises UrlError: on error fetching metadata. + :raises ValueError: on error parsing metadata. + """ + try: + url = IMDS_URL + "/instance?api-version=2021-08-01&extended=true" + return _fetch_metadata(url) + except UrlError as error: + if error.code == 400: + report_diagnostic_event( + "Falling back to IMDS api-version: 2019-06-01", + logger_func=LOG.warning, + ) + url = IMDS_URL + "/instance?api-version=2019-06-01" + return _fetch_metadata(url) + raise + + +def fetch_reprovision_data() -> bytes: + """Fetch reprovision data from IMDS, retrying on 404 and 410 responses. + + :raises UrlError: on error. + """ + url = IMDS_URL + "/reprovisiondata?api-version=2019-06-01" + + logging_threshold = 1 + poll_counter = 0 + + def exception_callback(msg, exception): + nonlocal logging_threshold + nonlocal poll_counter + + poll_counter += 1 + if not isinstance(exception, UrlError): + report_diagnostic_event( + "Polling IMDS failed with unexpected exception: %r" + % (exception), + logger_func=LOG.warning, + ) + return False + + log = True + retry = False + if exception.code in (404, 410): + retry = True + if poll_counter >= logging_threshold: + # Exponential back-off on logging. 
+ logging_threshold *= 2 + else: + log = False + + if log: + report_diagnostic_event( + "Polling IMDS failed with exception: %r count: %d" + % (exception, poll_counter), + logger_func=LOG.info, + ) + return retry + + response = readurl( + url, + exception_cb=exception_callback, + headers={"Metadata": "true"}, + infinite=True, + log_req_resp=False, + timeout=2, + ) + + report_diagnostic_event( + f"Polled IMDS {poll_counter+1} time(s)", + logger_func=LOG.debug, + ) + return response.contents diff --git a/tests/unittests/sources/azure/test_imds.py b/tests/unittests/sources/azure/test_imds.py new file mode 100644 index 00000000..b5a72645 --- /dev/null +++ b/tests/unittests/sources/azure/test_imds.py @@ -0,0 +1,491 @@ +# This file is part of cloud-init. See LICENSE file for license information. + +import json +import logging +import math +from unittest import mock + +import pytest +import requests + +from cloudinit.sources.azure import imds +from cloudinit.url_helper import UrlError + +MOCKPATH = "cloudinit.sources.azure.imds." + + +@pytest.fixture +def mock_readurl(): + with mock.patch(MOCKPATH + "readurl", autospec=True) as m: + yield m + + +@pytest.fixture +def mock_requests_session_request(): + with mock.patch("requests.Session.request", autospec=True) as m: + yield m + + +@pytest.fixture +def mock_url_helper_time_sleep(): + with mock.patch("cloudinit.url_helper.time.sleep", autospec=True) as m: + yield m + + +def fake_http_error_for_code(status_code: int): + response_failure = requests.Response() + response_failure.status_code = status_code + return requests.exceptions.HTTPError( + "fake error", + response=response_failure, + ) + + +class TestFetchMetadataWithApiFallback: + default_url = ( + "http://169.254.169.254/metadata/instance?" 
+ "api-version=2021-08-01&extended=true" + ) + fallback_url = ( + "http://169.254.169.254/metadata/instance?api-version=2019-06-01" + ) + headers = {"Metadata": "true"} + retries = 10 + timeout = 2 + + def test_basic( + self, + caplog, + mock_readurl, + ): + fake_md = {"foo": {"bar": []}} + mock_readurl.side_effect = [ + mock.Mock(contents=json.dumps(fake_md).encode()), + ] + + md = imds.fetch_metadata_with_api_fallback() + + assert md == fake_md + assert mock_readurl.mock_calls == [ + mock.call( + self.default_url, + timeout=self.timeout, + headers=self.headers, + retries=self.retries, + exception_cb=imds._readurl_exception_callback, + infinite=False, + log_req_resp=True, + ), + ] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [] + + def test_basic_fallback( + self, + caplog, + mock_readurl, + ): + fake_md = {"foo": {"bar": []}} + mock_readurl.side_effect = [ + UrlError("No IMDS version", code=400), + mock.Mock(contents=json.dumps(fake_md).encode()), + ] + + md = imds.fetch_metadata_with_api_fallback() + + assert md == fake_md + assert mock_readurl.mock_calls == [ + mock.call( + self.default_url, + timeout=self.timeout, + headers=self.headers, + retries=self.retries, + exception_cb=imds._readurl_exception_callback, + infinite=False, + log_req_resp=True, + ), + mock.call( + self.fallback_url, + timeout=self.timeout, + headers=self.headers, + retries=self.retries, + exception_cb=imds._readurl_exception_callback, + infinite=False, + log_req_resp=True, + ), + ] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [ + "Failed to fetch metadata from IMDS: No IMDS version", + "Falling back to IMDS api-version: 2019-06-01", + ] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(404), + fake_http_error_for_code(410), + fake_http_error_for_code(429), + fake_http_error_for_code(500), + requests.ConnectionError("Fake connection error"), 
+ requests.Timeout("Fake connection timeout"), + ], + ) + def test_will_retry_errors( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + ): + fake_md = {"foo": {"bar": []}} + mock_requests_session_request.side_effect = [ + error, + mock.Mock(content=json.dumps(fake_md)), + ] + + md = imds.fetch_metadata_with_api_fallback() + + assert md == fake_md + assert len(mock_requests_session_request.mock_calls) == 2 + assert mock_url_helper_time_sleep.mock_calls == [mock.call(1)] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [] + + def test_will_retry_errors_on_fallback( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + ): + error = fake_http_error_for_code(400) + fake_md = {"foo": {"bar": []}} + mock_requests_session_request.side_effect = [ + error, + fake_http_error_for_code(429), + mock.Mock(content=json.dumps(fake_md)), + ] + + md = imds.fetch_metadata_with_api_fallback() + + assert md == fake_md + assert len(mock_requests_session_request.mock_calls) == 3 + assert mock_url_helper_time_sleep.mock_calls == [mock.call(1)] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [ + "Failed to fetch metadata from IMDS: fake error", + "Falling back to IMDS api-version: 2019-06-01", + ] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(404), + fake_http_error_for_code(410), + fake_http_error_for_code(429), + fake_http_error_for_code(500), + requests.ConnectionError("Fake connection error"), + requests.Timeout("Fake connection timeout"), + ], + ) + def test_retry_until_failure( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + ): + mock_requests_session_request.side_effect = [error] * (11) + + with pytest.raises(UrlError) as exc_info: + imds.fetch_metadata_with_api_fallback() + + assert exc_info.value.cause == error + assert 
len(mock_requests_session_request.mock_calls) == ( + self.retries + 1 + ) + assert ( + mock_url_helper_time_sleep.mock_calls + == [mock.call(1)] * self.retries + ) + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [f"Failed to fetch metadata from IMDS: {error!s}"] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(403), + fake_http_error_for_code(501), + ], + ) + def test_will_not_retry_errors( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + ): + fake_md = {"foo": {"bar": []}} + mock_requests_session_request.side_effect = [ + error, + mock.Mock(content=json.dumps(fake_md)), + ] + + with pytest.raises(UrlError) as exc_info: + imds.fetch_metadata_with_api_fallback() + + assert exc_info.value.cause == error + assert len(mock_requests_session_request.mock_calls) == 1 + assert mock_url_helper_time_sleep.mock_calls == [] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [f"Failed to fetch metadata from IMDS: {error!s}"] + + def test_non_json_repsonse( + self, + caplog, + mock_readurl, + ): + mock_readurl.side_effect = [ + mock.Mock(contents=b"bad data"), + ] + + with pytest.raises(ValueError): + imds.fetch_metadata_with_api_fallback() + + assert mock_readurl.mock_calls == [ + mock.call( + self.default_url, + timeout=self.timeout, + headers=self.headers, + retries=self.retries, + exception_cb=imds._readurl_exception_callback, + infinite=False, + log_req_resp=True, + ), + ] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [ + ( + "Failed to parse metadata from IMDS: " + "Expecting value: line 1 column 1 (char 0)" + ) + ] + + +class TestFetchReprovisionData: + url = ( + "http://169.254.169.254/metadata/" + "reprovisiondata?api-version=2019-06-01" + ) + headers = {"Metadata": "true"} + timeout = 2 + + def test_basic( + self, + 
caplog, + mock_readurl, + ): + content = b"ovf content" + mock_readurl.side_effect = [ + mock.Mock(contents=content), + ] + + ovf = imds.fetch_reprovision_data() + + assert ovf == content + assert mock_readurl.mock_calls == [ + mock.call( + self.url, + timeout=self.timeout, + headers=self.headers, + exception_cb=mock.ANY, + infinite=True, + log_req_resp=False, + ), + ] + + assert caplog.record_tuples == [ + ( + "cloudinit.sources.azure.imds", + logging.DEBUG, + "Polled IMDS 1 time(s)", + ) + ] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(404), + fake_http_error_for_code(410), + ], + ) + @pytest.mark.parametrize("failures", [1, 5, 100, 1000]) + def test_will_retry_errors( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + failures, + ): + content = b"ovf content" + mock_requests_session_request.side_effect = [error] * failures + [ + mock.Mock(content=content), + ] + + ovf = imds.fetch_reprovision_data() + + assert ovf == content + assert len(mock_requests_session_request.mock_calls) == failures + 1 + assert ( + mock_url_helper_time_sleep.mock_calls == [mock.call(1)] * failures + ) + + wrapped_error = UrlError( + error, + code=error.response.status_code, + headers=error.response.headers, + url=self.url, + ) + backoff_logs = [ + ( + "cloudinit.sources.azure.imds", + logging.INFO, + "Polling IMDS failed with exception: " + f"{wrapped_error!r} count: {i}", + ) + for i in range(1, failures + 1) + if i == 1 or math.log2(i).is_integer() + ] + assert caplog.record_tuples == backoff_logs + [ + ( + "cloudinit.url_helper", + logging.DEBUG, + mock.ANY, + ), + ( + "cloudinit.sources.azure.imds", + logging.DEBUG, + f"Polled IMDS {failures+1} time(s)", + ), + ] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(404), + fake_http_error_for_code(410), + ], + ) + @pytest.mark.parametrize("failures", [1, 5, 100, 1000]) + @pytest.mark.parametrize( + "terminal_error", + [ + 
requests.ConnectionError("Fake connection error"), + requests.Timeout("Fake connection timeout"), + ], + ) + def test_retry_until_failure( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + failures, + terminal_error, + ): + mock_requests_session_request.side_effect = [error] * failures + [ + terminal_error + ] + + with pytest.raises(UrlError) as exc_info: + imds.fetch_reprovision_data() + + assert exc_info.value.cause == terminal_error + assert len(mock_requests_session_request.mock_calls) == (failures + 1) + assert ( + mock_url_helper_time_sleep.mock_calls == [mock.call(1)] * failures + ) + + wrapped_error = UrlError( + error, + code=error.response.status_code, + headers=error.response.headers, + url=self.url, + ) + + backoff_logs = [ + ( + "cloudinit.sources.azure.imds", + logging.INFO, + "Polling IMDS failed with exception: " + f"{wrapped_error!r} count: {i}", + ) + for i in range(1, failures + 1) + if i == 1 or math.log2(i).is_integer() + ] + assert caplog.record_tuples == backoff_logs + [ + ( + "cloudinit.sources.azure.imds", + logging.INFO, + "Polling IMDS failed with exception: " + f"{exc_info.value!r} count: {failures+1}", + ), + ] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(403), + fake_http_error_for_code(501), + ], + ) + def test_will_not_retry_errors( + self, + caplog, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + ): + fake_md = {"foo": {"bar": []}} + mock_requests_session_request.side_effect = [ + error, + mock.Mock(content=json.dumps(fake_md)), + ] + + with pytest.raises(UrlError) as exc_info: + imds.fetch_reprovision_data() + + assert exc_info.value.cause == error + assert len(mock_requests_session_request.mock_calls) == 1 + assert mock_url_helper_time_sleep.mock_calls == [] + + assert caplog.record_tuples == [ + ( + "cloudinit.sources.azure.imds", + logging.INFO, + "Polling IMDS failed with exception: " + f"{exc_info.value!r} count: 1", + ), + ] diff 
--git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py index 6f98cb27..b5fe2672 100644 --- a/tests/unittests/sources/test_azure.py +++ b/tests/unittests/sources/test_azure.py @@ -3,7 +3,6 @@ import copy import crypt import json -import logging import os import stat import xml.etree.ElementTree as ET @@ -11,13 +10,13 @@ from pathlib import Path import pytest import requests -import responses from cloudinit import distros, helpers, subp, url_helper from cloudinit.net import dhcp from cloudinit.sources import UNSET from cloudinit.sources import DataSourceAzure as dsaz from cloudinit.sources import InvalidMetaDataException +from cloudinit.sources.azure import imds from cloudinit.sources.helpers import netlink from cloudinit.util import ( MountFailedError, @@ -27,11 +26,9 @@ from cloudinit.util import ( load_json, write_file, ) -from cloudinit.version import version_string as vs from tests.unittests.helpers import ( CiTestCase, ExitStack, - ResponsesTestCase, mock, populate_dir, resourceLocation, @@ -205,7 +202,7 @@ def mock_os_path_isfile(): @pytest.fixture def mock_readurl(): - with mock.patch(MOCKPATH + "readurl", autospec=True) as m: + with mock.patch(MOCKPATH + "imds.readurl", autospec=True) as m: yield m @@ -215,12 +212,6 @@ def mock_report_diagnostic_event(): yield m -@pytest.fixture -def mock_requests_session_request(): - with mock.patch("requests.Session.request", autospec=True) as m: - yield m - - @pytest.fixture def mock_sleep(): with mock.patch( @@ -236,12 +227,6 @@ def mock_subp_subp(): yield m -@pytest.fixture -def mock_url_helper_time_sleep(): - with mock.patch("cloudinit.url_helper.time.sleep", autospec=True) as m: - yield m - - @pytest.fixture def mock_util_ensure_dir(): with mock.patch( @@ -852,176 +837,6 @@ class TestNetworkConfig: assert azure_ds.network_config == self.fallback_config -class TestGetMetadataFromIMDS(ResponsesTestCase): - - with_logs = True - - def setUp(self): - super(TestGetMetadataFromIMDS, 
self).setUp() - self.network_md_url = "{}/instance?api-version=2019-06-01".format( - dsaz.IMDS_URL - ) - - @mock.patch(MOCKPATH + "readurl", autospec=True) - def test_get_metadata_uses_instance_url(self, m_readurl): - """Make sure readurl is called with the correct url when accessing - metadata""" - m_readurl.return_value = url_helper.StringResponse( - json.dumps(IMDS_NETWORK_METADATA).encode("utf-8") - ) - - dsaz.get_metadata_from_imds(retries=3, md_type=dsaz.MetadataType.ALL) - m_readurl.assert_called_with( - "http://169.254.169.254/metadata/instance?api-version=2019-06-01", - exception_cb=mock.ANY, - headers=mock.ANY, - retries=mock.ANY, - timeout=mock.ANY, - infinite=False, - ) - - @mock.patch(MOCKPATH + "readurl", autospec=True) - def test_get_network_metadata_uses_network_url(self, m_readurl): - """Make sure readurl is called with the correct url when accessing - network metadata""" - m_readurl.return_value = url_helper.StringResponse( - json.dumps(IMDS_NETWORK_METADATA).encode("utf-8") - ) - - dsaz.get_metadata_from_imds( - retries=3, md_type=dsaz.MetadataType.NETWORK - ) - m_readurl.assert_called_with( - "http://169.254.169.254/metadata/instance/network?api-version=" - "2019-06-01", - exception_cb=mock.ANY, - headers=mock.ANY, - retries=mock.ANY, - timeout=mock.ANY, - infinite=False, - ) - - @mock.patch(MOCKPATH + "readurl", autospec=True) - @mock.patch(MOCKPATH + "EphemeralDHCPv4", autospec=True) - def test_get_default_metadata_uses_instance_url(self, m_dhcp, m_readurl): - """Make sure readurl is called with the correct url when accessing - metadata""" - m_readurl.return_value = url_helper.StringResponse( - json.dumps(IMDS_NETWORK_METADATA).encode("utf-8") - ) - - dsaz.get_metadata_from_imds(retries=3) - m_readurl.assert_called_with( - "http://169.254.169.254/metadata/instance?api-version=2019-06-01", - exception_cb=mock.ANY, - headers=mock.ANY, - retries=mock.ANY, - timeout=mock.ANY, - infinite=False, - ) - - @mock.patch(MOCKPATH + "readurl", 
autospec=True) - def test_get_metadata_uses_extended_url(self, m_readurl): - """Make sure readurl is called with the correct url when accessing - metadata""" - m_readurl.return_value = url_helper.StringResponse( - json.dumps(IMDS_NETWORK_METADATA).encode("utf-8") - ) - - dsaz.get_metadata_from_imds( - retries=3, - md_type=dsaz.MetadataType.ALL, - api_version="2021-08-01", - ) - m_readurl.assert_called_with( - "http://169.254.169.254/metadata/instance?api-version=" - "2021-08-01&extended=true", - exception_cb=mock.ANY, - headers=mock.ANY, - retries=mock.ANY, - timeout=mock.ANY, - infinite=False, - ) - - @mock.patch(MOCKPATH + "readurl", autospec=True) - def test_get_metadata_performs_dhcp_when_network_is_down(self, m_readurl): - """Perform DHCP setup when nic is not up.""" - m_readurl.return_value = url_helper.StringResponse( - json.dumps(NETWORK_METADATA).encode("utf-8") - ) - - self.assertEqual( - NETWORK_METADATA, dsaz.get_metadata_from_imds(retries=2) - ) - - self.assertIn( - "Crawl of Azure Instance Metadata Service (IMDS) took", # log_time - self.logs.getvalue(), - ) - - m_readurl.assert_called_with( - self.network_md_url, - exception_cb=mock.ANY, - headers={"Metadata": "true"}, - retries=2, - timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, - infinite=False, - ) - - @mock.patch("cloudinit.url_helper.time.sleep") - def test_get_metadata_from_imds_empty_when_no_imds_present(self, m_sleep): - """Return empty dict when IMDS network metadata is absent.""" - # Workaround https://github.com/getsentry/responses/pull/166 - # url path can be reverted to "/instance?api-version=2019-12-01" - response = requests.Response() - response.status_code = 404 - self.responses.add( - responses.GET, - dsaz.IMDS_URL + "/instance", - body=requests.HTTPError("...", response=response), - status=404, - ) - - self.assertEqual( - {}, - dsaz.get_metadata_from_imds(retries=2, api_version="2019-12-01"), - ) - - self.assertEqual([mock.call(1), mock.call(1)], m_sleep.call_args_list) - self.assertIn( - 
"Crawl of Azure Instance Metadata Service (IMDS) took", # log_time - self.logs.getvalue(), - ) - - @mock.patch("requests.Session.request") - @mock.patch("cloudinit.url_helper.time.sleep") - def test_get_metadata_from_imds_retries_on_timeout( - self, m_sleep, m_request - ): - """Retry IMDS network metadata on timeout errors.""" - - self.attempt = 0 - m_request.side_effect = requests.Timeout("Fake Connection Timeout") - - def retry_callback(request, uri, headers): - self.attempt += 1 - raise requests.Timeout("Fake connection timeout") - - self.responses.add( - responses.GET, - dsaz.IMDS_URL + "instance?api-version=2017-12-01", - body=retry_callback, - ) - - self.assertEqual({}, dsaz.get_metadata_from_imds(retries=3)) - - self.assertEqual([mock.call(1)] * 3, m_sleep.call_args_list) - self.assertIn( - "Crawl of Azure Instance Metadata Service (IMDS) took", # log_time - self.logs.getvalue(), - ) - - class TestAzureDataSource(CiTestCase): with_logs = True @@ -1053,10 +868,10 @@ class TestAzureDataSource(CiTestCase): self.m_dhcp.return_value.lease = {} self.m_dhcp.return_value.iface = "eth4" - self.m_get_metadata_from_imds = self.patches.enter_context( + self.m_fetch = self.patches.enter_context( mock.patch.object( - dsaz, - "get_metadata_from_imds", + dsaz.imds, + "fetch_metadata_with_api_fallback", mock.MagicMock(return_value=NETWORK_METADATA), ) ) @@ -1369,7 +1184,7 @@ scbus-1 on xpt0 bus 0 data, write_ovf_to_data_dir=True, write_ovf_to_seed_dir=False ) - self.m_get_metadata_from_imds.return_value = {} + self.m_fetch.return_value = {} with mock.patch(MOCKPATH + "util.mount_cb") as m_mount_cb: m_mount_cb.side_effect = [ MountFailedError("fail"), @@ -1506,7 +1321,7 @@ scbus-1 on xpt0 bus 0 data = {"ovfcontent": ovfenv, "sys_cfg": {}} dsrc = self._get_ds(data) dsrc.crawl_metadata() - self.assertEqual(1, self.m_get_metadata_from_imds.call_count) + self.assertEqual(1, self.m_fetch.call_count) @mock.patch("cloudinit.sources.DataSourceAzure.util.write_file") @mock.patch( @@ 
-1523,7 +1338,7 @@ scbus-1 on xpt0 bus 0 dsrc = self._get_ds(data) poll_imds_func.return_value = ovfenv dsrc.crawl_metadata() - self.assertEqual(2, self.m_get_metadata_from_imds.call_count) + self.assertEqual(2, self.m_fetch.call_count) @mock.patch("cloudinit.sources.DataSourceAzure.util.write_file") @mock.patch( @@ -1574,9 +1389,11 @@ scbus-1 on xpt0 bus 0 "cloudinit.sources.DataSourceAzure.DataSourceAzure._report_ready", return_value=True, ) - @mock.patch("cloudinit.sources.DataSourceAzure.readurl") + @mock.patch( + "cloudinit.sources.DataSourceAzure.imds.fetch_reprovision_data" + ) def test_crawl_metadata_on_reprovision_reports_ready_using_lease( - self, m_readurl, m_report_ready, m_media_switch, m_write + self, m_fetch_reprovision_data, m_report_ready, m_media_switch, m_write ): """If reprovisioning, report ready using the obtained lease""" ovfenv = construct_ovf_env(preprovisioned_vm=True) @@ -1595,8 +1412,8 @@ scbus-1 on xpt0 bus 0 m_media_switch.return_value = None reprovision_ovfenv = construct_ovf_env() - m_readurl.return_value = url_helper.StringResponse( - reprovision_ovfenv.encode("utf-8") + m_fetch_reprovision_data.return_value = reprovision_ovfenv.encode( + "utf-8" ) dsrc.crawl_metadata() @@ -1678,7 +1495,7 @@ scbus-1 on xpt0 bus 0 third_intf["ipv4"]["ipAddress"][0]["privateIpAddress"] = "10.0.2.6" imds_data["network"]["interface"].append(third_intf) - self.m_get_metadata_from_imds.return_value = imds_data + self.m_fetch.return_value = imds_data dsrc = self._get_ds(data) dsrc.get_data() self.assertEqual(expected_network_config, dsrc.network_config) @@ -1704,7 +1521,7 @@ scbus-1 on xpt0 bus 0 } imds_data = copy.deepcopy(NETWORK_METADATA) imds_data["network"]["interface"].append(SECONDARY_INTERFACE_NO_IP) - self.m_get_metadata_from_imds.return_value = imds_data + self.m_fetch.return_value = imds_data dsrc = self._get_ds(data) dsrc.get_data() self.assertEqual(expected_network_config, dsrc.network_config) @@ -2261,13 +2078,12 @@ scbus-1 on xpt0 bus 0 
@mock.patch( "cloudinit.sources.helpers.azure.OpenSSLManager.parse_certificates" ) - @mock.patch(MOCKPATH + "get_metadata_from_imds") def test_get_public_ssh_keys_with_no_openssh_format( - self, m_get_metadata_from_imds, m_parse_certificates + self, m_parse_certificates ): imds_data = copy.deepcopy(NETWORK_METADATA) imds_data["compute"]["publicKeys"][0]["keyData"] = "no-openssh-format" - m_get_metadata_from_imds.return_value = imds_data + self.m_fetch.return_value = imds_data sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { "ovfcontent": construct_ovf_env(), @@ -2280,9 +2096,8 @@ scbus-1 on xpt0 bus 0 self.assertEqual(ssh_keys, []) self.assertEqual(m_parse_certificates.call_count, 0) - @mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_get_public_ssh_keys_without_imds(self, m_get_metadata_from_imds): - m_get_metadata_from_imds.return_value = dict() + def test_get_public_ssh_keys_without_imds(self): + self.m_fetch.return_value = dict() sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { "ovfcontent": construct_ovf_env(), @@ -2295,67 +2110,7 @@ scbus-1 on xpt0 bus 0 ssh_keys = dsrc.get_public_ssh_keys() self.assertEqual(ssh_keys, ["key2"]) - @mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_imds_api_version_wanted_nonexistent( - self, m_get_metadata_from_imds - ): - def get_metadata_from_imds_side_eff(*args, **kwargs): - if kwargs["api_version"] == dsaz.IMDS_VER_WANT: - raise url_helper.UrlError("No IMDS version", code=400) - return NETWORK_METADATA - - m_get_metadata_from_imds.side_effect = get_metadata_from_imds_side_eff - sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} - data = { - "ovfcontent": construct_ovf_env(), - "sys_cfg": sys_cfg, - } - dsrc = self._get_ds(data) - dsrc.get_data() - self.assertIsNotNone(dsrc.metadata) - - assert m_get_metadata_from_imds.mock_calls == [ - mock.call( - retries=10, - md_type=dsaz.MetadataType.ALL, - api_version="2021-08-01", - 
exc_cb=mock.ANY, - infinite=False, - ), - mock.call( - retries=10, - md_type=dsaz.MetadataType.ALL, - api_version="2019-06-01", - exc_cb=mock.ANY, - infinite=False, - ), - ] - - @mock.patch( - MOCKPATH + "get_metadata_from_imds", return_value=NETWORK_METADATA - ) - def test_imds_api_version_wanted_exists(self, m_get_metadata_from_imds): - sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} - data = { - "ovfcontent": construct_ovf_env(), - "sys_cfg": sys_cfg, - } - dsrc = self._get_ds(data) - dsrc.get_data() - self.assertIsNotNone(dsrc.metadata) - - assert m_get_metadata_from_imds.mock_calls == [ - mock.call( - retries=10, - md_type=dsaz.MetadataType.ALL, - api_version="2021-08-01", - exc_cb=mock.ANY, - infinite=False, - ) - ] - - @mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_hostname_from_imds(self, m_get_metadata_from_imds): + def test_hostname_from_imds(self): sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { "ovfcontent": construct_ovf_env(), @@ -2367,13 +2122,12 @@ scbus-1 on xpt0 bus 0 computerName="hostname1", disablePasswordAuthentication="true", ) - m_get_metadata_from_imds.return_value = imds_data_with_os_profile + self.m_fetch.return_value = imds_data_with_os_profile dsrc = self._get_ds(data) dsrc.get_data() self.assertEqual(dsrc.metadata["local-hostname"], "hostname1") - @mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_username_from_imds(self, m_get_metadata_from_imds): + def test_username_from_imds(self): sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { "ovfcontent": construct_ovf_env(), @@ -2385,15 +2139,14 @@ scbus-1 on xpt0 bus 0 computerName="hostname1", disablePasswordAuthentication="true", ) - m_get_metadata_from_imds.return_value = imds_data_with_os_profile + self.m_fetch.return_value = imds_data_with_os_profile dsrc = self._get_ds(data) dsrc.get_data() self.assertEqual( dsrc.cfg["system_info"]["default_user"]["name"], "username1" ) - 
@mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_disable_password_from_imds(self, m_get_metadata_from_imds): + def test_disable_password_from_imds(self): sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { "ovfcontent": construct_ovf_env(), @@ -2405,13 +2158,12 @@ scbus-1 on xpt0 bus 0 computerName="hostname1", disablePasswordAuthentication="true", ) - m_get_metadata_from_imds.return_value = imds_data_with_os_profile + self.m_fetch.return_value = imds_data_with_os_profile dsrc = self._get_ds(data) dsrc.get_data() self.assertTrue(dsrc.metadata["disable_password"]) - @mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_userdata_from_imds(self, m_get_metadata_from_imds): + def test_userdata_from_imds(self): sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { "ovfcontent": construct_ovf_env(), @@ -2425,16 +2177,13 @@ scbus-1 on xpt0 bus 0 disablePasswordAuthentication="true", ) imds_data["compute"]["userData"] = b64e(userdata) - m_get_metadata_from_imds.return_value = imds_data + self.m_fetch.return_value = imds_data dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) self.assertEqual(dsrc.userdata_raw, userdata.encode("utf-8")) - @mock.patch(MOCKPATH + "get_metadata_from_imds") - def test_userdata_from_imds_with_customdata_from_OVF( - self, m_get_metadata_from_imds - ): + def test_userdata_from_imds_with_customdata_from_OVF(self): userdataOVF = "userdataOVF" sys_cfg = {"datasource": {"Azure": {"apply_network_config": True}}} data = { @@ -2450,7 +2199,7 @@ scbus-1 on xpt0 bus 0 disablePasswordAuthentication="true", ) imds_data["compute"]["userData"] = b64e(userdataImds) - m_get_metadata_from_imds.return_value = imds_data + self.m_fetch.return_value = imds_data dsrc = self._get_ds(data) ret = dsrc.get_data() self.assertTrue(ret) @@ -3046,7 +2795,7 @@ class TestPreprovisioningHotAttachNics(CiTestCase): @mock.patch(MOCKPATH + "DataSourceAzure._report_ready") @mock.patch(MOCKPATH + 
"DataSourceAzure.wait_for_link_up") @mock.patch("cloudinit.sources.helpers.netlink.wait_for_nic_attach_event") - @mock.patch(MOCKPATH + "DataSourceAzure.get_imds_data_with_api_fallback") + @mock.patch(MOCKPATH + "imds.fetch_metadata_with_api_fallback") @mock.patch(MOCKPATH + "EphemeralDHCPv4", autospec=True) @mock.patch(MOCKPATH + "DataSourceAzure._wait_for_nic_detach") @mock.patch("os.path.isfile") @@ -3226,7 +2975,7 @@ class TestPreprovisioningHotAttachNics(CiTestCase): @mock.patch( "cloudinit.sources.helpers.netlink.wait_for_media_disconnect_connect" ) -@mock.patch("requests.Session.request") +@mock.patch(MOCKPATH + "imds.fetch_reprovision_data") @mock.patch(MOCKPATH + "DataSourceAzure._report_ready", return_value=True) class TestPreprovisioningPollIMDS(CiTestCase): def setUp(self): @@ -3240,13 +2989,17 @@ class TestPreprovisioningPollIMDS(CiTestCase): def test_poll_imds_re_dhcp_on_timeout( self, m_report_ready, - m_request, + m_fetch_reprovisiondata, m_media_switch, m_dhcp, m_net, m_fallback, ): """The poll_imds will retry DHCP on IMDS timeout.""" + m_fetch_reprovisiondata.side_effect = [ + url_helper.UrlError(requests.Timeout("Fake connection timeout")), + b"ovf data", + ] report_file = self.tmp_path("report_marker", self.tmp) lease = { "interface": "eth9", @@ -3260,23 +3013,6 @@ class TestPreprovisioningPollIMDS(CiTestCase): dhcp_ctx = mock.MagicMock(lease=lease) dhcp_ctx.obtain_lease.return_value = lease - self.tries = 0 - - def fake_timeout_once(**kwargs): - self.tries += 1 - if self.tries == 1: - raise requests.Timeout("Fake connection timeout") - elif self.tries in (2, 3): - response = requests.Response() - response.status_code = 404 if self.tries == 2 else 410 - raise requests.exceptions.HTTPError( - "fake {}".format(response.status_code), response=response - ) - # Third try should succeed and stop retries or redhcp - return mock.MagicMock(status_code=200, text="good", content="good") - - m_request.side_effect = fake_timeout_once - dsa = 
dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) with mock.patch.object( dsa, "_reported_ready_marker_file", report_file @@ -3286,14 +3022,14 @@ class TestPreprovisioningPollIMDS(CiTestCase): assert m_report_ready.mock_calls == [mock.call()] self.assertEqual(3, m_dhcp.call_count, "Expected 3 DHCP calls") - self.assertEqual(4, self.tries, "Expected 4 total reads from IMDS") + assert m_fetch_reprovisiondata.call_count == 2 @mock.patch("os.path.isfile") def test_poll_imds_skips_dhcp_if_ctx_present( self, m_isfile, report_ready_func, - fake_resp, + m_fetch_reprovisiondata, m_media_switch, m_dhcp, m_net, @@ -3322,7 +3058,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): m_ephemeral_dhcpv4, m_isfile, report_ready_func, - m_request, + m_fetch_reprovisiondata, m_media_switch, m_dhcp, m_net, @@ -3333,17 +3069,15 @@ class TestPreprovisioningPollIMDS(CiTestCase): polling for reprovisiondata. Note that if this ctx is set when _poll_imds is called, then it is not expected to be waiting for media_disconnect_connect either.""" - - tries = 0 - - def fake_timeout_once(**kwargs): - nonlocal tries - tries += 1 - if tries == 1: - raise requests.Timeout("Fake connection timeout") - return mock.MagicMock(status_code=200, text="good", content="good") - - m_request.side_effect = fake_timeout_once + m_fetch_reprovisiondata.side_effect = [ + url_helper.UrlError( + requests.ConnectionError( + "Failed to establish a new connection: " + "[Errno 101] Network is unreachable" + ) + ), + b"ovf data", + ] report_file = self.tmp_path("report_marker", self.tmp) m_isfile.return_value = True distro = mock.MagicMock() @@ -3358,12 +3092,12 @@ class TestPreprovisioningPollIMDS(CiTestCase): self.assertEqual(1, m_dhcp_ctx.clean_network.call_count) self.assertEqual(1, m_ephemeral_dhcpv4.call_count) self.assertEqual(0, m_media_switch.call_count) - self.assertEqual(2, m_request.call_count) + self.assertEqual(2, m_fetch_reprovisiondata.call_count) def 
test_does_not_poll_imds_report_ready_when_marker_file_exists( self, m_report_ready, - m_request, + m_fetch_reprovisiondata, m_media_switch, m_dhcp, m_net, @@ -3390,10 +3124,12 @@ class TestPreprovisioningPollIMDS(CiTestCase): dsa._poll_imds() self.assertEqual(m_report_ready.call_count, 0) + @mock.patch(MOCKPATH + "imds.fetch_metadata_with_api_fallback") def test_poll_imds_report_ready_success_writes_marker_file( self, + m_fetch, m_report_ready, - m_request, + m_fetch_reprovisiondata, m_media_switch, m_dhcp, m_net, @@ -3426,7 +3162,7 @@ class TestPreprovisioningPollIMDS(CiTestCase): def test_poll_imds_report_ready_failure_raises_exc_and_doesnt_write_marker( self, m_report_ready, - m_request, + m_fetch_reprovisiondata, m_media_switch, m_dhcp, m_net, @@ -3466,7 +3202,9 @@ class TestPreprovisioningPollIMDS(CiTestCase): ) @mock.patch("cloudinit.net.ephemeral.EphemeralIPv4Network", autospec=True) @mock.patch("cloudinit.net.ephemeral.maybe_perform_dhcp_discovery") -@mock.patch("requests.Session.request") +@mock.patch( + MOCKPATH + "imds.fetch_reprovision_data", side_effect=[b"ovf data"] +) class TestAzureDataSourcePreprovisioning(CiTestCase): def setUp(self): super(TestAzureDataSourcePreprovisioning, self).setUp() @@ -3476,7 +3214,7 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): dsaz.BUILTIN_DS_CONFIG["data_dir"] = self.waagent_d def test_poll_imds_returns_ovf_env( - self, m_request, m_dhcp, m_net, m_media_switch + self, m_fetch_reprovisiondata, m_dhcp, m_net, m_media_switch ): """The _poll_imds method should return the ovf_env.xml.""" m_media_switch.return_value = None @@ -3488,30 +3226,8 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): "subnet-mask": "255.255.255.0", } ] - url = "http://{0}/metadata/reprovisiondata?api-version=2019-06-01" - host = "169.254.169.254" - full_url = url.format(host) - m_request.return_value = mock.MagicMock( - status_code=200, text="ovf", content="ovf" - ) dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) 
self.assertTrue(len(dsa._poll_imds()) > 0) - self.assertEqual( - m_request.call_args_list, - [ - mock.call( - allow_redirects=True, - headers={ - "Metadata": "true", - "User-Agent": "Cloud-Init/%s" % vs(), - }, - method="GET", - timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, - url=full_url, - stream=False, - ) - ], - ) self.assertEqual(m_dhcp.call_count, 2) m_net.assert_any_call( broadcast="192.168.2.255", @@ -3524,7 +3240,7 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): self.assertEqual(m_net.call_count, 2) def test__reprovision_calls__poll_imds( - self, m_request, m_dhcp, m_net, m_media_switch + self, m_fetch_reprovisiondata, m_dhcp, m_net, m_media_switch ): """The _reprovision method should call poll IMDS.""" m_media_switch.return_value = None @@ -3537,33 +3253,14 @@ class TestAzureDataSourcePreprovisioning(CiTestCase): "unknown-245": "624c3620", } ] - url = "http://{0}/metadata/reprovisiondata?api-version=2019-06-01" - host = "169.254.169.254" - full_url = url.format(host) hostname = "myhost" username = "myuser" content = construct_ovf_env(username=username, hostname=hostname) - m_request.return_value = mock.MagicMock( - status_code=200, text=content, content=content - ) + m_fetch_reprovisiondata.side_effect = [content] dsa = dsaz.DataSourceAzure({}, distro=mock.Mock(), paths=self.paths) md, _ud, cfg, _d = dsa._reprovision() self.assertEqual(md["local-hostname"], hostname) self.assertEqual(cfg["system_info"]["default_user"]["name"], username) - self.assertIn( - mock.call( - allow_redirects=True, - headers={ - "Metadata": "true", - "User-Agent": "Cloud-Init/%s" % vs(), - }, - method="GET", - timeout=dsaz.IMDS_TIMEOUT_IN_SECONDS, - url=full_url, - stream=False, - ), - m_request.call_args_list, - ) self.assertEqual(m_dhcp.call_count, 2) m_net.assert_any_call( broadcast="192.168.2.255", @@ -3908,187 +3605,6 @@ def fake_http_error_for_code(status_code: int): ) -@pytest.mark.parametrize( - "md_type,expected_url", - [ - ( - dsaz.MetadataType.ALL, - 
"http://169.254.169.254/metadata/instance?" - "api-version=2021-08-01&extended=true", - ), - ( - dsaz.MetadataType.NETWORK, - "http://169.254.169.254/metadata/instance/network?" - "api-version=2021-08-01", - ), - ( - dsaz.MetadataType.REPROVISION_DATA, - "http://169.254.169.254/metadata/reprovisiondata?" - "api-version=2021-08-01", - ), - ], -) -class TestIMDS: - def test_basic_scenarios( - self, azure_ds, caplog, mock_readurl, md_type, expected_url - ): - fake_md = {"foo": {"bar": []}} - mock_readurl.side_effect = [ - mock.MagicMock(contents=json.dumps(fake_md).encode()), - ] - - md = azure_ds.get_imds_data_with_api_fallback( - retries=5, - md_type=md_type, - ) - - assert md == fake_md - assert mock_readurl.mock_calls == [ - mock.call( - expected_url, - timeout=2, - headers={"Metadata": "true"}, - retries=5, - exception_cb=dsaz.imds_readurl_exception_callback, - infinite=False, - ), - ] - - warnings = [ - x.message for x in caplog.records if x.levelno == logging.WARNING - ] - assert warnings == [] - - @pytest.mark.parametrize( - "error", - [ - fake_http_error_for_code(404), - fake_http_error_for_code(410), - fake_http_error_for_code(429), - fake_http_error_for_code(500), - requests.ConnectionError("Fake connection error"), - requests.Timeout("Fake connection timeout"), - ], - ) - def test_will_retry_errors( - self, - azure_ds, - caplog, - md_type, - expected_url, - mock_requests_session_request, - mock_url_helper_time_sleep, - error, - ): - fake_md = {"foo": {"bar": []}} - mock_requests_session_request.side_effect = [ - error, - mock.Mock(content=json.dumps(fake_md)), - ] - - md = azure_ds.get_imds_data_with_api_fallback( - retries=5, - md_type=md_type, - ) - - assert md == fake_md - assert len(mock_requests_session_request.mock_calls) == 2 - assert mock_url_helper_time_sleep.mock_calls == [mock.call(1)] - - warnings = [ - x.message for x in caplog.records if x.levelno == logging.WARNING - ] - assert warnings == [] - - @pytest.mark.parametrize("retries", [0, 1, 5, 
10]) - @pytest.mark.parametrize( - "error", - [ - fake_http_error_for_code(404), - fake_http_error_for_code(410), - fake_http_error_for_code(429), - fake_http_error_for_code(500), - requests.ConnectionError("Fake connection error"), - requests.Timeout("Fake connection timeout"), - ], - ) - def test_retry_until_failure( - self, - azure_ds, - caplog, - md_type, - expected_url, - mock_requests_session_request, - mock_url_helper_time_sleep, - error, - retries, - ): - mock_requests_session_request.side_effect = [error] * (retries + 1) - - assert ( - azure_ds.get_imds_data_with_api_fallback( - retries=retries, - md_type=md_type, - ) - == {} - ) - - assert len(mock_requests_session_request.mock_calls) == (retries + 1) - assert ( - mock_url_helper_time_sleep.mock_calls == [mock.call(1)] * retries - ) - - warnings = [ - x.message for x in caplog.records if x.levelno == logging.WARNING - ] - assert warnings == [ - "Ignoring IMDS instance metadata. " - "Get metadata from IMDS failed: %s" % error - ] - - @pytest.mark.parametrize( - "error", - [ - fake_http_error_for_code(403), - fake_http_error_for_code(501), - ], - ) - def test_will_not_retry_errors( - self, - azure_ds, - caplog, - md_type, - expected_url, - mock_requests_session_request, - mock_url_helper_time_sleep, - error, - ): - fake_md = {"foo": {"bar": []}} - mock_requests_session_request.side_effect = [ - error, - mock.Mock(content=json.dumps(fake_md)), - ] - - assert ( - azure_ds.get_imds_data_with_api_fallback( - retries=5, - md_type=md_type, - ) - == {} - ) - - assert len(mock_requests_session_request.mock_calls) == 1 - assert mock_url_helper_time_sleep.mock_calls == [] - - warnings = [ - x.message for x in caplog.records if x.levelno == logging.WARNING - ] - assert warnings == [ - "Ignoring IMDS instance metadata. 
" - "Get metadata from IMDS failed: %s" % error - ] - - class TestInstanceId: def test_metadata(self, azure_ds, mock_dmi_read_dmi_data): azure_ds.metadata = {"instance-id": "test-id"} @@ -4193,8 +3709,9 @@ class TestProvisioning: timeout=2, headers={"Metadata": "true"}, retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, + exception_cb=imds._readurl_exception_callback, infinite=False, + log_req_resp=True, ), ] @@ -4252,29 +3769,31 @@ class TestProvisioning: mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/reprovisiondata?" "api-version=2019-06-01", - timeout=2, - headers={"Metadata": "true"}, exception_cb=mock.ANY, - infinite=True, + headers={"Metadata": "true"}, log_req_resp=False, + infinite=True, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), ] @@ -4355,38 +3874,41 @@ class TestProvisioning: mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), mock.call( - "http://169.254.169.254/metadata/instance/network?" - "api-version=2021-08-01", - timeout=2, + "http://169.254.169.254/metadata/instance?" 
+ "api-version=2021-08-01&extended=true", + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=0, - exception_cb=mock.ANY, - infinite=True, + infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/reprovisiondata?" "api-version=2019-06-01", - timeout=2, - headers={"Metadata": "true"}, exception_cb=mock.ANY, - infinite=True, + headers={"Metadata": "true"}, log_req_resp=False, + infinite=True, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), ] @@ -4503,38 +4025,41 @@ class TestProvisioning: mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), mock.call( - "http://169.254.169.254/metadata/instance/network?" - "api-version=2021-08-01", - timeout=2, + "http://169.254.169.254/metadata/instance?" + "api-version=2021-08-01&extended=true", + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=0, - exception_cb=mock.ANY, - infinite=True, + infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/reprovisiondata?" "api-version=2019-06-01", - timeout=2, - headers={"Metadata": "true"}, exception_cb=mock.ANY, + headers={"Metadata": "true"}, infinite=True, log_req_resp=False, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/instance?" 
"api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), ] @@ -4609,29 +4134,31 @@ class TestProvisioning: mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/reprovisiondata?" "api-version=2019-06-01", - timeout=2, - headers={"Metadata": "true"}, exception_cb=mock.ANY, + headers={"Metadata": "true"}, infinite=True, log_req_resp=False, + timeout=2, ), mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, + log_req_resp=True, + retries=10, + timeout=2, ), ] @@ -4688,12 +4215,13 @@ class TestProvisioning: mock.call( "http://169.254.169.254/metadata/instance?" "api-version=2021-08-01&extended=true", - timeout=2, + exception_cb=imds._readurl_exception_callback, headers={"Metadata": "true"}, - retries=10, - exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, - ) + log_req_resp=True, + retries=10, + timeout=2, + ), ] assert self.mock_subp_subp.mock_calls == [] -- cgit v1.2.1