From e19a76c58f2a998d86e51c5a2a0f1db37563efce Mon Sep 17 00:00:00 2001
From: nbraun-amazon <85549956+nbraun-amazon@users.noreply.github.com>
Date: Wed, 18 Aug 2021 12:06:09 +0300
Subject: Add retry mechanism with backoff (#1494)

---
 redis/backoff.py    | 105 +++++++++++++++++++++++++++++++
 redis/client.py     | 177 ++++++++++++++++++++++++++++++++--------------------
 redis/connection.py |  55 ++++++++++------
 redis/retry.py      |  40 ++++++++++++
 tests/conftest.py   |   3 +
 tests/test_retry.py |  66 ++++++++++++++++++++
 6 files changed, 360 insertions(+), 86 deletions(-)
 create mode 100644 redis/backoff.py
 create mode 100644 redis/retry.py
 create mode 100644 tests/test_retry.py

diff --git a/redis/backoff.py b/redis/backoff.py
new file mode 100644
index 0000000..9162778
--- /dev/null
+++ b/redis/backoff.py
@@ -0,0 +1,105 @@
+from abc import ABC, abstractmethod
+import random
+
+
+class AbstractBackoff(ABC):
+    """Backoff interface"""
+
+    def reset(self):
+        """
+        Reset internal state before an operation.
+        `reset` is called once at the beginning of
+        every call to `Retry.call_with_retry`
+        """
+        pass
+
+    @abstractmethod
+    def compute(self, failures):
+        """Compute backoff in seconds upon failure"""
+        pass
+
+
+class ConstantBackoff(AbstractBackoff):
+    """Constant backoff upon failure"""
+
+    def __init__(self, backoff):
+        """`backoff`: backoff time in seconds"""
+        self._backoff = backoff
+
+    def compute(self, failures):
+        return self._backoff
+
+
+class NoBackoff(ConstantBackoff):
+    """No backoff upon failure"""
+
+    def __init__(self):
+        super().__init__(0)
+
+
+class ExponentialBackoff(AbstractBackoff):
+    """Exponential backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+
+    def compute(self, failures):
+        return min(self._cap, self._base * 2 ** failures)
+
+
+class FullJitterBackoff(AbstractBackoff):
+    """Full jitter backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+
+    def compute(self, failures):
+        return random.uniform(0, min(self._cap, self._base * 2 ** failures))
+
+
+class EqualJitterBackoff(AbstractBackoff):
+    """Equal jitter backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+
+    def compute(self, failures):
+        temp = min(self._cap, self._base * 2 ** failures) / 2
+        return temp + random.uniform(0, temp)
+
+
+class DecorrelatedJitterBackoff(AbstractBackoff):
+    """Decorrelated jitter backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+        self._previous_backoff = 0
+
+    def reset(self):
+        self._previous_backoff = 0
+
+    def compute(self, failures):
+        max_backoff = max(self._base, self._previous_backoff * 3)
+        temp = random.uniform(self._base, max_backoff)
+        self._previous_backoff = min(self._cap, temp)
+        return self._previous_backoff
diff --git a/redis/client.py b/redis/client.py
index 741c2d0..ab9246d 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1,4 +1,5 @@
 from itertools import chain
+import copy
 import datetime
 import hashlib
 import re
@@ -758,7 +759,13 @@ class Redis(Commands, object):
                  ssl_cert_reqs='required', ssl_ca_certs=None,
                  ssl_check_hostname=False,
                 max_connections=None, single_connection_client=False,
-                health_check_interval=0, client_name=None, username=None):
+                health_check_interval=0, client_name=None, username=None,
+                retry=None):
+        """
+        Initialize a new Redis client.
+        To specify a retry policy, first set `retry_on_timeout` to `True`,
+        then set `retry` to a valid `Retry` object
+        """
         if not connection_pool:
             if charset is not None:
                 warnings.warn(DeprecationWarning(
@@ -778,6 +785,7 @@ class Redis(Commands, object):
                 'encoding_errors': encoding_errors,
                 'decode_responses': decode_responses,
                 'retry_on_timeout': retry_on_timeout,
+                'retry': copy.deepcopy(retry),
                 'max_connections': max_connections,
                 'health_check_interval': health_check_interval,
                 'client_name': client_name
@@ -940,21 +948,41 @@ class Redis(Commands, object):
             self.connection = None
             self.connection_pool.release(conn)
 
+    def _send_command_parse_response(self,
+                                     conn,
+                                     command_name,
+                                     *args,
+                                     **options):
+        """
+        Send a command and parse the response
+        """
+        conn.send_command(*args)
+        return self.parse_response(conn, command_name, **options)
+
+    def _disconnect_raise(self, conn, error):
+        """
+        Close the connection and raise an exception
+        if retry_on_timeout is not set or the error
+        is not a TimeoutError
+        """
+        conn.disconnect()
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            raise error
+
     # COMMAND EXECUTION AND PROTOCOL PARSING
     def execute_command(self, *args, **options):
         "Execute a command and return a parsed response"
         pool = self.connection_pool
         command_name = args[0]
         conn = self.connection or pool.get_connection(command_name, **options)
+
         try:
-            conn.send_command(*args)
-            return self.parse_response(conn, command_name, **options)
-        except (ConnectionError, TimeoutError) as e:
-            conn.disconnect()
-            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
-                raise
-            conn.send_command(*args)
-            return self.parse_response(conn, command_name, **options)
+            return conn.retry.call_with_retry(
+                lambda: self._send_command_parse_response(conn,
+                                                          command_name,
+                                                          *args,
+                                                          **options),
+                lambda error: self._disconnect_raise(conn, error))
         finally:
             if not self.connection:
                 pool.release(conn)
@@ -1142,24 +1170,31 @@ class PubSub:
         kwargs = {'check_health': not self.subscribed}
         self._execute(connection, connection.send_command, *args, **kwargs)
 
-    def _execute(self, connection, command, *args, **kwargs):
-        try:
-            return command(*args, **kwargs)
-        except (ConnectionError, TimeoutError) as e:
-            connection.disconnect()
-            if not (connection.retry_on_timeout and
-                    isinstance(e, TimeoutError)):
-                raise
-            # Connect manually here. If the Redis server is down, this will
-            # fail and raise a ConnectionError as desired.
-            connection.connect()
-            # the ``on_connect`` callback should haven been called by the
-            # connection to resubscribe us to any channels and patterns we were
-            # previously listening to
-            return command(*args, **kwargs)
+    def _disconnect_raise_connect(self, conn, error):
+        """
+        Close the connection and raise an exception
+        if retry_on_timeout is not set or the error
+        is not a TimeoutError. Otherwise, try to reconnect
+        """
+        conn.disconnect()
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            raise error
+        conn.connect()
+
+    def _execute(self, conn, command, *args, **kwargs):
+        """
+        Connect manually upon disconnection. If the Redis server is down,
+        this will fail and raise a ConnectionError as desired.
+        After reconnection, the ``on_connect`` callback should have been
+        called by the connection to resubscribe us to any channels and
+        patterns we were previously listening to
+        """
+        return conn.retry.call_with_retry(
+            lambda: command(*args, **kwargs),
+            lambda error: self._disconnect_raise_connect(conn, error))
 
     def parse_response(self, block=True, timeout=0):
-        "Parse the response from a publish/subscribe command"
+        """Parse the response from a publish/subscribe command"""
         conn = self.connection
         if conn is None:
             raise RuntimeError(
@@ -1499,6 +1534,27 @@ class Pipeline(Redis):
             return self.immediate_execute_command(*args, **kwargs)
         return self.pipeline_execute_command(*args, **kwargs)
 
+    def _disconnect_reset_raise(self, conn, error):
+        """
+        Close the connection, reset watching state and
+        raise an exception if we were watching,
+        retry_on_timeout is not set,
+        or the error is not a TimeoutError
+        """
+        conn.disconnect()
+        # if we were already watching a variable, the watch is no longer
+        # valid since this connection has died. raise a WatchError, which
+        # indicates the user should retry this transaction.
+        if self.watching:
+            self.reset()
+            raise WatchError("A ConnectionError occurred while "
+                             "watching one or more keys")
+        # if retry_on_timeout is not set, or the error is not
+        # a TimeoutError, raise it
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            self.reset()
+            raise
+
     def immediate_execute_command(self, *args, **options):
         """
         Execute a command immediately, but don't auto-retry on a
@@ -1513,33 +1569,13 @@ class Pipeline(Redis):
             conn = self.connection_pool.get_connection(command_name,
                                                        self.shard_hint)
             self.connection = conn
-        try:
-            conn.send_command(*args)
-            return self.parse_response(conn, command_name, **options)
-        except (ConnectionError, TimeoutError) as e:
-            conn.disconnect()
-            # if we were already watching a variable, the watch is no longer
-            # valid since this connection has died. raise a WatchError, which
-            # indicates the user should retry this transaction.
-            if self.watching:
-                self.reset()
-                raise WatchError("A ConnectionError occurred on while "
-                                 "watching one or more keys")
-            # if retry_on_timeout is not set, or the error is not
-            # a TimeoutError, raise it
-            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
-                self.reset()
-                raise
-
-            # retry_on_timeout is set, this is a TimeoutError and we are not
-            # already WATCHing any variables. retry the command.
-            try:
-                conn.send_command(*args)
-                return self.parse_response(conn, command_name, **options)
-            except (ConnectionError, TimeoutError):
-                # a subsequent failure should simply be raised
-                self.reset()
-                raise
+
+        return conn.retry.call_with_retry(
+            lambda: self._send_command_parse_response(conn,
+                                                      command_name,
+                                                      *args,
+                                                      **options),
+            lambda error: self._disconnect_reset_raise(conn, error))
 
     def pipeline_execute_command(self, *args, **options):
         """
@@ -1672,6 +1708,25 @@ class Pipeline(Redis):
                 if not exist:
                     s.sha = immediate('SCRIPT LOAD', s.script)
 
+    def _disconnect_raise_reset(self, conn, error):
+        """
+        Close the connection, raise an exception if we were watching,
+        and raise an exception if retry_on_timeout is not set,
+        or the error is not a TimeoutError
+        """
+        conn.disconnect()
+        # if we were watching a variable, the watch is no longer valid
+        # since this connection has died. raise a WatchError, which
+        # indicates the user should retry this transaction.
+        if self.watching:
+            raise WatchError("A ConnectionError occurred while "
+                             "watching one or more keys")
+        # if retry_on_timeout is not set, or the error is not
+        # a TimeoutError, raise it
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            self.reset()
+            raise
+
     def execute(self, raise_on_error=True):
         "Execute all the commands in the current pipeline"
         stack = self.command_stack
@@ -1693,21 +1748,9 @@ class Pipeline(Redis):
             self.connection = conn
 
         try:
-            return execute(conn, stack, raise_on_error)
-        except (ConnectionError, TimeoutError) as e:
-            conn.disconnect()
-            # if we were watching a variable, the watch is no longer valid
-            # since this connection has died. raise a WatchError, which
-            # indicates the user should retry this transaction.
-            if self.watching:
-                raise WatchError("A ConnectionError occurred on while "
-                                 "watching one or more keys")
-            # if retry_on_timeout is not set, or the error is not
-            # a TimeoutError, raise it
-            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
-                raise
-            # retry a TimeoutError when retry_on_timeout is set
-            return execute(conn, stack, raise_on_error)
+            return conn.retry.call_with_retry(
+                lambda: execute(conn, stack, raise_on_error),
+                lambda error: self._disconnect_raise_reset(conn, error))
         finally:
             self.reset()
 
diff --git a/redis/connection.py b/redis/connection.py
index 4a855b3..e47e3c7 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -3,6 +3,7 @@ from itertools import chain
 from time import time
 from queue import LifoQueue, Empty, Full
 from urllib.parse import parse_qs, unquote, urlparse
+import copy
 import errno
 import io
 import os
@@ -28,6 +29,8 @@ from redis.exceptions import (
     ModuleError,
 )
 from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
+from redis.backoff import NoBackoff
+from redis.retry import Retry
 
 try:
     import ssl
@@ -499,7 +502,13 @@ class Connection:
                  socket_type=0, retry_on_timeout=False, encoding='utf-8',
                  encoding_errors='strict', decode_responses=False,
                  parser_class=DefaultParser, socket_read_size=65536,
-                 health_check_interval=0, client_name=None, username=None):
+                 health_check_interval=0, client_name=None, username=None,
+                 retry=None):
+        """
+        Initialize a new Connection.
+        To specify a retry policy, first set `retry_on_timeout` to `True`,
+        then set `retry` to a valid `Retry` object
+        """
         self.pid = os.getpid()
         self.host = host
         self.port = int(port)
@@ -513,6 +522,14 @@ class Connection:
         self.socket_keepalive_options = socket_keepalive_options or {}
         self.socket_type = socket_type
         self.retry_on_timeout = retry_on_timeout
+        if retry_on_timeout:
+            if retry is None:
+                self.retry = Retry(NoBackoff(), 1)
+            else:
+                # deep-copy the Retry object as it is mutable
+                self.retry = copy.deepcopy(retry)
+        else:
+            self.retry = Retry(NoBackoff(), 0)
         self.health_check_interval = health_check_interval
         self.next_health_check = 0
         self.encoder = Encoder(encoding, encoding_errors, decode_responses)
@@ -673,23 +690,23 @@ class Connection:
             pass
         self._sock = None
 
+    def _send_ping(self):
+        """Send PING, expect PONG in return"""
+        self.send_command('PING', check_health=False)
+        if str_if_bytes(self.read_response()) != 'PONG':
+            raise ConnectionError('Bad response from PING health check')
+
+    def _ping_failed(self, error):
+        """Function to call when PING fails"""
+        self.disconnect()
+
     def check_health(self):
-        "Check the health of the connection with a PING/PONG"
+        """Check the health of the connection with a PING/PONG"""
         if self.health_check_interval and time() > self.next_health_check:
-            try:
-                self.send_command('PING', check_health=False)
-                if str_if_bytes(self.read_response()) != 'PONG':
-                    raise ConnectionError(
-                        'Bad response from PING health check')
-            except (ConnectionError, TimeoutError):
-                self.disconnect()
-                self.send_command('PING', check_health=False)
-                if str_if_bytes(self.read_response()) != 'PONG':
-                    raise ConnectionError(
-                        'Bad response from PING health check')
+            self.retry.call_with_retry(self._send_ping, self._ping_failed)
 
     def send_packed_command(self, command, check_health=True):
-        "Send an already packed command to the Redis server"
+        """Send an already packed command to the Redis server"""
         if not self._sock:
             self.connect()
         # guard against health check recursion
@@ -717,12 +734,12 @@ class Connection:
             raise
 
     def send_command(self, *args, **kwargs):
-        "Pack and send a command to the Redis server"
+        """Pack and send a command to the Redis server"""
         self.send_packed_command(self.pack_command(*args),
                                  check_health=kwargs.get('check_health', True))
 
     def can_read(self, timeout=0):
-        "Poll the socket to see if there's data that can be read."
+        """Poll the socket to see if there's data that can be read."""
         sock = self._sock
         if not sock:
             self.connect()
@@ -730,7 +747,7 @@ class Connection:
         return self._parser.can_read(timeout)
 
     def read_response(self):
-        "Read the response from a previously sent command"
+        """Read the response from a previously sent command"""
         try:
             response = self._parser.read_response()
         except socket.timeout:
@@ -753,7 +770,7 @@ class Connection:
         return response
 
     def pack_command(self, *args):
-        "Pack a series of arguments into the Redis protocol"
+        """Pack a series of arguments into the Redis protocol"""
         output = []
         # the client might have included 1 or more literal arguments in
         # the command name, e.g., 'CONFIG GET'. The Redis server expects these
@@ -787,7 +804,7 @@ class Connection:
         return output
 
     def pack_commands(self, commands):
-        "Pack multiple commands into the Redis protocol"
+        """Pack multiple commands into the Redis protocol"""
         output = []
         pieces = []
         buffer_length = 0
diff --git a/redis/retry.py b/redis/retry.py
new file mode 100644
index 0000000..cd06a23
--- /dev/null
+++ b/redis/retry.py
@@ -0,0 +1,40 @@
+from time import sleep
+
+from redis.exceptions import ConnectionError, TimeoutError
+
+
+class Retry:
+    """Retry a specific number of times after a failure"""
+
+    def __init__(self, backoff, retries,
+                 supported_errors=(ConnectionError, TimeoutError)):
+        """
+        Initialize a `Retry` object with a `Backoff` object
+        that retries a maximum of `retries` times.
+        You can specify the types of supported errors which trigger
+        a retry with the `supported_errors` parameter.
+        """
+        self._backoff = backoff
+        self._retries = retries
+        self._supported_errors = supported_errors
+
+    def call_with_retry(self, do, fail):
+        """
+        Execute an operation that might fail and return its result, or
+        raise the last error; sleep between retries per the `Backoff` object.
+        `do`: the operation to call. Expects no argument.
+        `fail`: the failure handler, expects the last error that was thrown
+        """
+        self._backoff.reset()
+        failures = 0
+        while True:
+            try:
+                return do()
+            except self._supported_errors as error:
+                failures += 1
+                fail(error)
+                if failures > self._retries:
+                    raise error
+                backoff = self._backoff.compute(failures)
+                if backoff > 0:
+                    sleep(backoff)
diff --git a/tests/conftest.py b/tests/conftest.py
index cd4d489..711f9e5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,5 @@
+from redis.backoff import NoBackoff
+from redis.retry import Retry
 import pytest
 import random
 import redis
@@ -107,6 +109,7 @@ def r2(request):
 
 def _gen_cluster_mock_resp(r, response):
     connection = Mock()
+    connection.retry = Retry(NoBackoff(), 0)
     connection.read_response.return_value = response
     r.connection = connection
     return r
diff --git a/tests/test_retry.py b/tests/test_retry.py
new file mode 100644
index 0000000..24d9683
--- /dev/null
+++ b/tests/test_retry.py
@@ -0,0 +1,66 @@
+from redis.backoff import NoBackoff
+import pytest
+
+from redis.exceptions import ConnectionError
+from redis.connection import Connection
+from redis.retry import Retry
+
+
+class BackoffMock:
+    def __init__(self):
+        self.reset_calls = 0
+        self.calls = 0
+
+    def reset(self):
+        self.reset_calls += 1
+
+    def compute(self, failures):
+        self.calls += 1
+        return 0
+
+
+class TestConnectionConstructorWithRetry:
+    "Test that the Connection constructor properly handles Retry objects"
+
+    @pytest.mark.parametrize("retry_on_timeout", [False, True])
+    def test_retry_on_timeout_boolean(self, retry_on_timeout):
+        c = Connection(retry_on_timeout=retry_on_timeout)
+        assert c.retry_on_timeout == retry_on_timeout
+        assert isinstance(c.retry, Retry)
+        assert c.retry._retries == (1 if retry_on_timeout else 0)
+
+    @pytest.mark.parametrize("retries", range(10))
+    def test_retry_on_timeout_retry(self, retries):
+        retry_on_timeout = retries > 0
+        c = Connection(retry_on_timeout=retry_on_timeout,
+                       retry=Retry(NoBackoff(), retries))
+        assert c.retry_on_timeout == retry_on_timeout
+        assert isinstance(c.retry, Retry)
+        assert c.retry._retries == retries
+
+
+class TestRetry:
+    "Test that Retry calls backoff and retries the expected number of times"
+
+    def setup_method(self, test_method):
+        self.actual_attempts = 0
+        self.actual_failures = 0
+
+    def _do(self):
+        self.actual_attempts += 1
+        raise ConnectionError()
+
+    def _fail(self, error):
+        self.actual_failures += 1
+
+    @pytest.mark.parametrize("retries", range(10))
+    def test_retry(self, retries):
+        backoff = BackoffMock()
+        retry = Retry(backoff, retries)
+        with pytest.raises(ConnectionError):
+            retry.call_with_retry(self._do, self._fail)
+
+        assert self.actual_attempts == 1 + retries
+        assert self.actual_failures == 1 + retries
+        assert backoff.reset_calls == 1
+        assert backoff.calls == retries
--
cgit v1.2.1
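
A minimal usage sketch (not part of the commit itself), assuming this patch
is applied and a Redis server is reachable on localhost; the key name and
retry budget below are illustrative:

    from redis import Redis
    from redis.backoff import ExponentialBackoff
    from redis.retry import Retry

    # Opt into retries: retry_on_timeout=True is required, otherwise
    # Connection.__init__ ignores `retry` and falls back to
    # Retry(NoBackoff(), 0). The client's failure handler re-raises anything
    # that is not a TimeoutError, so in practice this retries timeouts only,
    # sleeping min(10, 1 * 2 ** failures) seconds between attempts.
    client = Redis(host='localhost', port=6379,
                   retry_on_timeout=True,
                   retry=Retry(ExponentialBackoff(cap=10, base=1), retries=3))
    client.ping()

    # Retry can also be driven directly: `do` is called until it succeeds
    # or the budget is exhausted, and `fail` runs after every failure.
    retry = Retry(ExponentialBackoff(cap=10, base=1), 2)
    value = retry.call_with_retry(
        lambda: client.get('some-key'),
        lambda error: print('attempt failed:', error))

Note that `Connection.__init__` deep-copies the `Retry` object, so a single
policy instance can be shared across connections without also sharing mutable
backoff state such as `DecorrelatedJitterBackoff._previous_backoff`.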