From ddc51c4ace0caa0787715801b9df42e65c790d46 Mon Sep 17 00:00:00 2001 From: Bar Shaul <88437685+barshaul@users.noreply.github.com> Date: Thu, 23 Dec 2021 11:42:30 +0200 Subject: Support for specifying error types with retry (#1817) --- redis/client.py | 21 ++++++--- redis/connection.py | 27 ++++++++++-- redis/retry.py | 8 ++++ tests/test_retry.py | 125 +++++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 169 insertions(+), 12 deletions(-) diff --git a/redis/client.py b/redis/client.py index c7aa17b..0236f20 100755 --- a/redis/client.py +++ b/redis/client.py @@ -869,6 +869,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): errors=None, decode_responses=False, retry_on_timeout=False, + retry_on_error=[], ssl=False, ssl_keyfile=None, ssl_certfile=None, @@ -887,8 +888,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): ): """ Initialize a new Redis client. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ if not connection_pool: if charset is not None: @@ -905,7 +908,8 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): ) ) encoding_errors = errors - + if retry_on_timeout is True: + retry_on_error.append(TimeoutError) kwargs = { "db": db, "username": username, @@ -914,7 +918,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): "encoding": encoding, "encoding_errors": encoding_errors, "decode_responses": decode_responses, - "retry_on_timeout": retry_on_timeout, + "retry_on_error": retry_on_error, "retry": copy.deepcopy(retry), "max_connections": max_connections, "health_check_interval": health_check_interval, @@ -1146,11 +1150,14 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): def _disconnect_raise(self, conn, error): """ Close the connection and raise an exception - if retry_on_timeout is not set or the error - is not a TimeoutError + if retry_on_error is not set or the error + is not one of the specified error types """ conn.disconnect() - if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): raise error # COMMAND EXECUTION AND PROTOCOL PARSING diff --git a/redis/connection.py b/redis/connection.py index 3fe8543..a349a0f 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -513,6 +513,7 @@ class Connection: socket_keepalive_options=None, socket_type=0, retry_on_timeout=False, + retry_on_error=[], encoding="utf-8", encoding_errors="strict", decode_responses=False, @@ -526,8 +527,10 @@ class Connection: ): """ Initialize a new Connection. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ self.pid = os.getpid() self.host = host @@ -543,11 +546,17 @@ class Connection: self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout if retry_on_timeout: + # Add TimeoutError to the errors list to retry on + retry_on_error.append(TimeoutError) + self.retry_on_error = retry_on_error + if retry_on_error: if retry is None: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors + self.retry.update_supported_erros(retry_on_error) else: self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval @@ -969,6 +978,7 @@ class UnixDomainSocketConnection(Connection): encoding_errors="strict", decode_responses=False, retry_on_timeout=False, + retry_on_error=[], parser_class=DefaultParser, socket_read_size=65536, health_check_interval=0, @@ -978,8 +988,10 @@ class UnixDomainSocketConnection(Connection): ): """ Initialize a new UnixDomainSocketConnection. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ self.pid = os.getpid() self.path = path @@ -990,11 +1002,17 @@ class UnixDomainSocketConnection(Connection): self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout if retry_on_timeout: + # Add TimeoutError to the errors list to retry on + retry_on_error.append(TimeoutError) + self.retry_on_error = retry_on_error + if self.retry_on_error: if retry is None: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors + self.retry.update_supported_erros(retry_on_error) else: self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval @@ -1052,6 +1070,7 @@ URL_QUERY_ARGUMENT_PARSERS = { "socket_connect_timeout": float, "socket_keepalive": to_bool, "retry_on_timeout": to_bool, + "retry_on_error": list, "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, diff --git a/redis/retry.py b/redis/retry.py index 75504c7..6147fbd 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -19,6 +19,14 @@ class Retry: self._retries = retries self._supported_errors = supported_errors + def update_supported_erros(self, specified_errors: list): + """ + Updates the supported errors with the specified error types + """ + self._supported_errors = tuple( + set(self._supported_errors + tuple(specified_errors)) + ) + def call_with_retry(self, do, fail): """ Execute an operation that might fail and returns its result, or diff --git a/tests/test_retry.py b/tests/test_retry.py index c4650bc..0094787 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,10 +1,20 @@ +from unittest.mock import patch + import pytest from redis.backoff import NoBackoff +from redis.client import Redis from redis.connection import Connection, UnixDomainSocketConnection -from redis.exceptions import ConnectionError +from redis.exceptions import ( + BusyLoadingError, + ConnectionError, + ReadOnlyError, + TimeoutError, +) from redis.retry import Retry +from .conftest import _get_client + class BackoffMock: def __init__(self): @@ -39,6 +49,37 @@ class TestConnectionConstructorWithRetry: assert isinstance(c.retry, Retry) assert c.retry._retries == retries + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_error(self, Class): + c = Class(retry_on_error=[ReadOnlyError]) + assert c.retry_on_error == [ReadOnlyError] + assert isinstance(c.retry, Retry) + assert c.retry._retries == 1 + + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_error_empty_value(self, Class): + c = Class(retry_on_error=[]) + assert c.retry_on_error == [] + assert isinstance(c.retry, Retry) + assert c.retry._retries == 0 + + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_error_and_timeout(self, Class): + c = Class( + retry_on_error=[ReadOnlyError, BusyLoadingError], retry_on_timeout=True + ) + assert c.retry_on_error == [ReadOnlyError, BusyLoadingError, TimeoutError] + assert isinstance(c.retry, Retry) + assert c.retry._retries == 1 + + @pytest.mark.parametrize("retries", range(10)) + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_error_retry(self, Class, retries): + c = Class(retry_on_error=[ReadOnlyError], retry=Retry(NoBackoff(), retries)) + assert c.retry_on_error == [ReadOnlyError] + assert isinstance(c.retry, Retry) + assert c.retry._retries == retries + class TestRetry: "Test that Retry calls backoff and retries the expected number of times" @@ -65,3 +106,85 @@ class TestRetry: assert self.actual_failures == 1 + retries assert backoff.reset_calls == 1 assert backoff.calls == retries + + +@pytest.mark.onlynoncluster +class TestRedisClientRetry: + "Test the standalone Redis client behavior with retries" + + def test_client_retry_on_error_with_success(self, request): + with patch.object(Redis, "parse_response") as parse_response: + + def mock_parse_response(connection, *args, **options): + def ok_response(connection, *args, **options): + return "MOCK_OK" + + parse_response.side_effect = ok_response + raise ReadOnlyError() + + parse_response.side_effect = mock_parse_response + r = _get_client(Redis, request, retry_on_error=[ReadOnlyError]) + assert r.get("foo") == "MOCK_OK" + assert parse_response.call_count == 2 + + def test_client_retry_on_error_raise(self, request): + with patch.object(Redis, "parse_response") as parse_response: + parse_response.side_effect = BusyLoadingError() + retries = 3 + r = _get_client( + Redis, + request, + retry_on_error=[ReadOnlyError, BusyLoadingError], + retry=Retry(NoBackoff(), retries), + ) + with pytest.raises(BusyLoadingError): + try: + r.get("foo") + finally: + assert parse_response.call_count == retries + 1 + + def test_client_retry_on_error_different_error_raised(self, request): + with patch.object(Redis, "parse_response") as parse_response: + parse_response.side_effect = TimeoutError() + retries = 3 + r = _get_client( + Redis, + request, + retry_on_error=[ReadOnlyError], + retry=Retry(NoBackoff(), retries), + ) + with pytest.raises(TimeoutError): + try: + r.get("foo") + finally: + assert parse_response.call_count == 1 + + def test_client_retry_on_error_and_timeout(self, request): + with patch.object(Redis, "parse_response") as parse_response: + parse_response.side_effect = TimeoutError() + retries = 3 + r = _get_client( + Redis, + request, + retry_on_error=[ReadOnlyError], + retry_on_timeout=True, + retry=Retry(NoBackoff(), retries), + ) + with pytest.raises(TimeoutError): + try: + r.get("foo") + finally: + assert parse_response.call_count == retries + 1 + + def test_client_retry_on_timeout(self, request): + with patch.object(Redis, "parse_response") as parse_response: + parse_response.side_effect = TimeoutError() + retries = 3 + r = _get_client( + Redis, request, retry_on_timeout=True, retry=Retry(NoBackoff(), retries) + ) + with pytest.raises(TimeoutError): + try: + r.get("foo") + finally: + assert parse_response.call_count == retries + 1 -- cgit v1.2.1