author     nbraun-amazon <85549956+nbraun-amazon@users.noreply.github.com>  2021-08-18 12:06:09 +0300
committer  GitHub <noreply@github.com>  2021-08-18 12:06:09 +0300
commit     e19a76c58f2a998d86e51c5a2a0f1db37563efce (patch)
tree       876614bb653f6df4006ab64cece4078d0355f067
parent     b96af52e012bc002df97c4a82a5e4ad389cea3f3 (diff)
download   redis-py-e19a76c58f2a998d86e51c5a2a0f1db37563efce.tar.gz
Add retry mechanism with backoff (#1494)
-rw-r--r--  redis/backoff.py      105
-rwxr-xr-x  redis/client.py       177
-rwxr-xr-x  redis/connection.py    55
-rw-r--r--  redis/retry.py         40
-rw-r--r--  tests/conftest.py       3
-rw-r--r--  tests/test_retry.py    66
6 files changed, 360 insertions, 86 deletions
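
The patch lets a client opt in to retrying commands that time out, pacing the
attempts with a pluggable backoff policy. A minimal usage sketch of the new
API (not part of the diff; host and port values are illustrative):

    from redis import Redis
    from redis.backoff import ExponentialBackoff
    from redis.retry import Retry

    # retry_on_timeout must be True for the Retry object to take effect:
    # up to 3 retries, exponentially backed off and capped at 10 seconds
    r = Redis(host='localhost', port=6379,
              retry_on_timeout=True,
              retry=Retry(ExponentialBackoff(10, 1), 3))
    r.set('foo', 'bar')
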
diff --git a/redis/backoff.py b/redis/backoff.py
new file mode 100644
index 0000000..9162778
--- /dev/null
+++ b/redis/backoff.py
@@ -0,0 +1,105 @@
+from abc import ABC, abstractmethod
+import random
+
+
+class AbstractBackoff(ABC):
+    """Backoff interface"""
+
+    def reset(self):
+        """
+        Reset internal state before an operation.
+        `reset` is called once at the beginning of
+        every call to `Retry.call_with_retry`
+        """
+        pass
+
+    @abstractmethod
+    def compute(self, failures):
+        """Compute backoff in seconds upon failure"""
+        pass
+
+
+class ConstantBackoff(AbstractBackoff):
+    """Constant backoff upon failure"""
+
+    def __init__(self, backoff):
+        """`backoff`: backoff time in seconds"""
+        self._backoff = backoff
+
+    def compute(self, failures):
+        return self._backoff
+
+
+class NoBackoff(ConstantBackoff):
+    """No backoff upon failure"""
+
+    def __init__(self):
+        super().__init__(0)
+
+
+class ExponentialBackoff(AbstractBackoff):
+    """Exponential backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+
+    def compute(self, failures):
+        return min(self._cap, self._base * 2 ** failures)
+
+
+class FullJitterBackoff(AbstractBackoff):
+    """Full jitter backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+
+    def compute(self, failures):
+        return random.uniform(0, min(self._cap, self._base * 2 ** failures))
+
+
+class EqualJitterBackoff(AbstractBackoff):
+    """Equal jitter backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+
+    def compute(self, failures):
+        temp = min(self._cap, self._base * 2 ** failures) / 2
+        return temp + random.uniform(0, temp)
+
+
+class DecorrelatedJitterBackoff(AbstractBackoff):
+    """Decorrelated jitter backoff upon failure"""
+
+    def __init__(self, cap, base):
+        """
+        `cap`: maximum backoff time in seconds
+        `base`: base backoff time in seconds
+        """
+        self._cap = cap
+        self._base = base
+        self._previous_backoff = 0
+
+    def reset(self):
+        self._previous_backoff = 0
+
+    def compute(self, failures):
+        max_backoff = max(self._base, self._previous_backoff * 3)
+        temp = random.uniform(self._base, max_backoff)
+        self._previous_backoff = min(self._cap, temp)
+        return self._previous_backoff
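
The jitter variants above follow the widely used "exponential backoff and
jitter" strategies. A quick sketch (assumed usage, not part of the patch) of
the delays `compute` yields for successive failure counts:

    from redis.backoff import ExponentialBackoff, EqualJitterBackoff

    exp = ExponentialBackoff(cap=60, base=1)
    print([exp.compute(f) for f in range(1, 6)])  # [2, 4, 8, 16, 32]

    eq = EqualJitterBackoff(cap=60, base=1)
    eq.reset()            # a no-op for the stateless policies
    print(eq.compute(3))  # random value in [4.0, 8.0]
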
diff --git a/redis/client.py b/redis/client.py
index 741c2d0..ab9246d 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1,4 +1,5 @@
 from itertools import chain
+import copy
 import datetime
 import hashlib
 import re
@@ -758,7 +759,13 @@ class Redis(Commands, object):
                  ssl_cert_reqs='required', ssl_ca_certs=None,
                  ssl_check_hostname=False,
                  max_connections=None, single_connection_client=False,
-                 health_check_interval=0, client_name=None, username=None):
+                 health_check_interval=0, client_name=None, username=None,
+                 retry=None):
+        """
+        Initialize a new Redis client.
+        To specify a retry policy, first set `retry_on_timeout` to `True`,
+        then set `retry` to a valid `Retry` object.
+        """
         if not connection_pool:
             if charset is not None:
                 warnings.warn(DeprecationWarning(
@@ -778,6 +785,7 @@ class Redis(Commands, object):
                 'encoding_errors': encoding_errors,
                 'decode_responses': decode_responses,
                 'retry_on_timeout': retry_on_timeout,
+                'retry': copy.deepcopy(retry),
                 'max_connections': max_connections,
                 'health_check_interval': health_check_interval,
                 'client_name': client_name
@@ -940,21 +948,41 @@ class Redis(Commands, object):
             self.connection = None
             self.connection_pool.release(conn)

+    def _send_command_parse_response(self,
+                                     conn,
+                                     command_name,
+                                     *args,
+                                     **options):
+        """
+        Send a command and parse the response
+        """
+        conn.send_command(*args)
+        return self.parse_response(conn, command_name, **options)
+
+    def _disconnect_raise(self, conn, error):
+        """
+        Close the connection and raise an exception
+        if retry_on_timeout is not set or the error
+        is not a TimeoutError
+        """
+        conn.disconnect()
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            raise error
+
     # COMMAND EXECUTION AND PROTOCOL PARSING
     def execute_command(self, *args, **options):
         "Execute a command and return a parsed response"
         pool = self.connection_pool
         command_name = args[0]
         conn = self.connection or pool.get_connection(command_name, **options)
+
         try:
-            conn.send_command(*args)
-            return self.parse_response(conn, command_name, **options)
-        except (ConnectionError, TimeoutError) as e:
-            conn.disconnect()
-            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
-                raise
-            conn.send_command(*args)
-            return self.parse_response(conn, command_name, **options)
+            return conn.retry.call_with_retry(
+                lambda: self._send_command_parse_response(conn,
+                                                          command_name,
+                                                          *args,
+                                                          **options),
+                lambda error: self._disconnect_raise(conn, error))
         finally:
             if not self.connection:
                 pool.release(conn)
@@ -1142,24 +1170,31 @@ class PubSub:
             kwargs = {'check_health': not self.subscribed}
         self._execute(connection, connection.send_command, *args, **kwargs)

-    def _execute(self, connection, command, *args, **kwargs):
-        try:
-            return command(*args, **kwargs)
-        except (ConnectionError, TimeoutError) as e:
-            connection.disconnect()
-            if not (connection.retry_on_timeout and
-                    isinstance(e, TimeoutError)):
-                raise
-            # Connect manually here. If the Redis server is down, this will
-            # fail and raise a ConnectionError as desired.
-            connection.connect()
-            # the ``on_connect`` callback should haven been called by the
-            # connection to resubscribe us to any channels and patterns we were
-            # previously listening to
-            return command(*args, **kwargs)
+    def _disconnect_raise_connect(self, conn, error):
+        """
+        Close the connection and raise an exception
+        if retry_on_timeout is not set or the error
+        is not a TimeoutError. Otherwise, try to reconnect
+        """
+        conn.disconnect()
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            raise error
+        conn.connect()
+
+    def _execute(self, conn, command, *args, **kwargs):
+        """
+        Connect manually upon disconnection. If the Redis server is down,
+        this will fail and raise a ConnectionError as desired.
+        After reconnection, the ``on_connect`` callback should have been
+        called by the connection to resubscribe us to any channels and
+        patterns we were previously listening to
+        """
+        return conn.retry.call_with_retry(
+            lambda: command(*args, **kwargs),
+            lambda error: self._disconnect_raise_connect(conn, error))

     def parse_response(self, block=True, timeout=0):
-        "Parse the response from a publish/subscribe command"
+        """Parse the response from a publish/subscribe command"""
         conn = self.connection
         if conn is None:
             raise RuntimeError(
@@ -1499,6 +1534,27 @@ class Pipeline(Redis):
             return self.immediate_execute_command(*args, **kwargs)
         return self.pipeline_execute_command(*args, **kwargs)

+    def _disconnect_reset_raise(self, conn, error):
+        """
+        Close the connection, reset watching state and
+        raise an exception if we were watching,
+        if retry_on_timeout is not set,
+        or if the error is not a TimeoutError
+        """
+        conn.disconnect()
+        # if we were already watching a variable, the watch is no longer
+        # valid since this connection has died. raise a WatchError, which
+        # indicates the user should retry this transaction.
+        if self.watching:
+            self.reset()
+            raise WatchError("A ConnectionError occurred while "
+                             "watching one or more keys")
+        # if retry_on_timeout is not set, or the error is not
+        # a TimeoutError, raise it
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            self.reset()
+            raise error
+
     def immediate_execute_command(self, *args, **options):
         """
         Execute a command immediately, but don't auto-retry on a
@@ -1513,33 +1569,13 @@ class Pipeline(Redis):
             conn = self.connection_pool.get_connection(command_name,
                                                        self.shard_hint)
             self.connection = conn
-        try:
-            conn.send_command(*args)
-            return self.parse_response(conn, command_name, **options)
-        except (ConnectionError, TimeoutError) as e:
-            conn.disconnect()
-            # if we were already watching a variable, the watch is no longer
-            # valid since this connection has died. raise a WatchError, which
-            # indicates the user should retry this transaction.
-            if self.watching:
-                self.reset()
-                raise WatchError("A ConnectionError occurred on while "
-                                 "watching one or more keys")
-            # if retry_on_timeout is not set, or the error is not
-            # a TimeoutError, raise it
-            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
-                self.reset()
-                raise
-
-            # retry_on_timeout is set, this is a TimeoutError and we are not
-            # already WATCHing any variables. retry the command.
-            try:
-                conn.send_command(*args)
-                return self.parse_response(conn, command_name, **options)
-            except (ConnectionError, TimeoutError):
-                # a subsequent failure should simply be raised
-                self.reset()
-                raise
+
+        return conn.retry.call_with_retry(
+            lambda: self._send_command_parse_response(conn,
+                                                      command_name,
+                                                      *args,
+                                                      **options),
+            lambda error: self._disconnect_reset_raise(conn, error))

     def pipeline_execute_command(self, *args, **options):
         """
@@ -1672,6 +1708,25 @@ class Pipeline(Redis):
             if not exist:
                 s.sha = immediate('SCRIPT LOAD', s.script)

+    def _disconnect_raise_reset(self, conn, error):
+        """
+        Close the connection, raise a WatchError if we were watching,
+        and raise the original error if retry_on_timeout is not set
+        or the error is not a TimeoutError
+        """
+        conn.disconnect()
+        # if we were watching a variable, the watch is no longer valid
+        # since this connection has died. raise a WatchError, which
+        # indicates the user should retry this transaction.
+        if self.watching:
+            raise WatchError("A ConnectionError occurred while "
+                             "watching one or more keys")
+        # if retry_on_timeout is not set, or the error is not
+        # a TimeoutError, raise it
+        if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+            self.reset()
+            raise error
+
     def execute(self, raise_on_error=True):
         "Execute all the commands in the current pipeline"
         stack = self.command_stack
@@ -1693,21 +1748,9 @@ class Pipeline(Redis):
             self.connection = conn

         try:
-            return execute(conn, stack, raise_on_error)
-        except (ConnectionError, TimeoutError) as e:
-            conn.disconnect()
-            # if we were watching a variable, the watch is no longer valid
-            # since this connection has died. raise a WatchError, which
-            # indicates the user should retry this transaction.
-            if self.watching:
-                raise WatchError("A ConnectionError occurred on while "
-                                 "watching one or more keys")
-            # if retry_on_timeout is not set, or the error is not
-            # a TimeoutError, raise it
-            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
-                raise
-            # retry a TimeoutError when retry_on_timeout is set
-            return execute(conn, stack, raise_on_error)
+            return conn.retry.call_with_retry(
+                lambda: execute(conn, stack, raise_on_error),
+                lambda error: self._disconnect_raise_reset(conn, error))
         finally:
             self.reset()
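
`Redis.execute_command`, `PubSub._execute`, `Pipeline.immediate_execute_command`
and `Pipeline.execute` now all funnel through `conn.retry.call_with_retry`, so
they share one retry policy instead of four hand-rolled retry loops. A hedged
sketch of the pipeline path (key name and retry settings are illustrative, and
a reachable server is assumed):

    import redis
    from redis.backoff import NoBackoff
    from redis.retry import Retry

    r = redis.Redis(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
    with r.pipeline(transaction=True) as pipe:
        pipe.set('counter', 1)
        pipe.incr('counter')
        print(pipe.execute())  # [True, 2]; a timeout during execute() is
                               # retried, but WATCHed keys raise WatchError
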
diff --git a/redis/connection.py b/redis/connection.py
index 4a855b3..e47e3c7 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -3,6 +3,7 @@ from itertools import chain
 from time import time
 from queue import LifoQueue, Empty, Full
 from urllib.parse import parse_qs, unquote, urlparse
+import copy
 import errno
 import io
 import os
@@ -28,6 +29,8 @@ from redis.exceptions import (
     ModuleError,
 )
 from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
+from redis.backoff import NoBackoff
+from redis.retry import Retry

 try:
     import ssl
@@ -499,7 +502,13 @@ class Connection:
                  socket_type=0, retry_on_timeout=False, encoding='utf-8',
                  encoding_errors='strict', decode_responses=False,
                  parser_class=DefaultParser, socket_read_size=65536,
-                 health_check_interval=0, client_name=None, username=None):
+                 health_check_interval=0, client_name=None, username=None,
+                 retry=None):
+        """
+        Initialize a new Connection.
+        To specify a retry policy, first set `retry_on_timeout` to `True`,
+        then set `retry` to a valid `Retry` object.
+        """
         self.pid = os.getpid()
         self.host = host
         self.port = int(port)
@@ -513,6 +522,14 @@ class Connection:
         self.socket_keepalive_options = socket_keepalive_options or {}
         self.socket_type = socket_type
         self.retry_on_timeout = retry_on_timeout
+        if retry_on_timeout:
+            if retry is None:
+                self.retry = Retry(NoBackoff(), 1)
+            else:
+                # deep-copy the Retry object as it is mutable
+                self.retry = copy.deepcopy(retry)
+        else:
+            self.retry = Retry(NoBackoff(), 0)
         self.health_check_interval = health_check_interval
         self.next_health_check = 0
         self.encoder = Encoder(encoding, encoding_errors, decode_responses)
@@ -673,23 +690,23 @@ class Connection:
                 pass
         self._sock = None

+    def _send_ping(self):
+        """Send PING, expect PONG in return"""
+        self.send_command('PING', check_health=False)
+        if str_if_bytes(self.read_response()) != 'PONG':
+            raise ConnectionError('Bad response from PING health check')
+
+    def _ping_failed(self, error):
+        """Function to call when PING fails"""
+        self.disconnect()
+
     def check_health(self):
-        "Check the health of the connection with a PING/PONG"
+        """Check the health of the connection with a PING/PONG"""
         if self.health_check_interval and time() > self.next_health_check:
-            try:
-                self.send_command('PING', check_health=False)
-                if str_if_bytes(self.read_response()) != 'PONG':
-                    raise ConnectionError(
-                        'Bad response from PING health check')
-            except (ConnectionError, TimeoutError):
-                self.disconnect()
-                self.send_command('PING', check_health=False)
-                if str_if_bytes(self.read_response()) != 'PONG':
-                    raise ConnectionError(
-                        'Bad response from PING health check')
+            self.retry.call_with_retry(self._send_ping, self._ping_failed)

     def send_packed_command(self, command, check_health=True):
-        "Send an already packed command to the Redis server"
+        """Send an already packed command to the Redis server"""
         if not self._sock:
             self.connect()
         # guard against health check recursion
@@ -717,12 +734,12 @@ class Connection:
             raise

     def send_command(self, *args, **kwargs):
-        "Pack and send a command to the Redis server"
+        """Pack and send a command to the Redis server"""
         self.send_packed_command(self.pack_command(*args),
                                  check_health=kwargs.get('check_health', True))

     def can_read(self, timeout=0):
-        "Poll the socket to see if there's data that can be read."
+        """Poll the socket to see if there's data that can be read."""
         sock = self._sock
         if not sock:
             self.connect()
@@ -730,7 +747,7 @@ class Connection:
         return self._parser.can_read(timeout)

     def read_response(self):
-        "Read the response from a previously sent command"
+        """Read the response from a previously sent command"""
         try:
             response = self._parser.read_response()
         except socket.timeout:
@@ -753,7 +770,7 @@ class Connection:
         return response

     def pack_command(self, *args):
-        "Pack a series of arguments into the Redis protocol"
+        """Pack a series of arguments into the Redis protocol"""
         output = []
         # the client might have included 1 or more literal arguments in
         # the command name, e.g., 'CONFIG GET'. The Redis server expects these
@@ -787,7 +804,7 @@ class Connection:
         return output

     def pack_commands(self, commands):
-        "Pack multiple commands into the Redis protocol"
+        """Pack multiple commands into the Redis protocol"""
         output = []
         pieces = []
         buffer_length = 0
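
The constructor maps the legacy `retry_on_timeout` boolean onto the new
mechanism: `True` without an explicit `retry` yields one immediate retry
(`NoBackoff`), `False` disables retries, and a supplied `Retry` object is
deep-copied. A sketch of these defaults (peeking at the private `_retries`
counter the tests also use; building a `Connection` does not open a socket):

    from redis.connection import Connection
    from redis.backoff import ConstantBackoff
    from redis.retry import Retry

    print(Connection(retry_on_timeout=True).retry._retries)  # 1
    print(Connection().retry._retries)                       # 0
    c = Connection(retry_on_timeout=True,
                   retry=Retry(ConstantBackoff(0.5), 5))
    print(c.retry._retries)                                  # 5
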
diff --git a/redis/retry.py b/redis/retry.py
new file mode 100644
index 0000000..cd06a23
--- /dev/null
+++ b/redis/retry.py
@@ -0,0 +1,40 @@
+from time import sleep
+
+from redis.exceptions import ConnectionError, TimeoutError
+
+
+class Retry:
+    """Retry a specific number of times after a failure"""
+
+    def __init__(self, backoff, retries,
+                 supported_errors=(ConnectionError, TimeoutError)):
+        """
+        Initialize a `Retry` object with a `Backoff` object
+        and a maximum of `retries` retries.
+        You can specify the types of supported errors which trigger
+        a retry with the `supported_errors` parameter.
+        """
+        self._backoff = backoff
+        self._retries = retries
+        self._supported_errors = supported_errors
+
+    def call_with_retry(self, do, fail):
+        """
+        Execute an operation that might fail and return its result, or
+        raise the last exception once the allowed retries are exhausted.
+        `do`: the operation to call. Expects no argument.
+        `fail`: the failure handler, expects the last error that was thrown
+        """
+        self._backoff.reset()
+        failures = 0
+        while True:
+            try:
+                return do()
+            except self._supported_errors as error:
+                failures += 1
+                fail(error)
+                if failures > self._retries:
+                    raise error
+                backoff = self._backoff.compute(failures)
+                if backoff > 0:
+                    sleep(backoff)
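
`call_with_retry` can also be driven directly. A standalone sketch with a
flaky operation (the `attempts` counter and the `do`/`on_fail` helpers are
hypothetical):

    from redis.backoff import ConstantBackoff
    from redis.exceptions import ConnectionError
    from redis.retry import Retry

    attempts = 0

    def do():
        global attempts
        attempts += 1
        if attempts < 3:
            raise ConnectionError('transient failure')
        return 'ok'

    def on_fail(error):
        print('attempt failed:', error)

    retry = Retry(ConstantBackoff(0.1), retries=5)
    print(retry.call_with_retry(do, on_fail))  # prints 'ok' on attempt 3
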
diff --git a/tests/conftest.py b/tests/conftest.py
index cd4d489..711f9e5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,5 @@
+from redis.backoff import NoBackoff
+from redis.retry import Retry
 import pytest
 import random
 import redis
@@ -107,6 +109,7 @@ def r2(request):

 def _gen_cluster_mock_resp(r, response):
     connection = Mock()
+    connection.retry = Retry(NoBackoff(), 0)
     connection.read_response.return_value = response
     r.connection = connection
     return r
diff --git a/tests/test_retry.py b/tests/test_retry.py
new file mode 100644
index 0000000..24d9683
--- /dev/null
+++ b/tests/test_retry.py
@@ -0,0 +1,66 @@
+from redis.backoff import NoBackoff
+import pytest
+
+from redis.exceptions import ConnectionError
+from redis.connection import Connection
+from redis.retry import Retry
+
+
+class BackoffMock:
+    def __init__(self):
+        self.reset_calls = 0
+        self.calls = 0
+
+    def reset(self):
+        self.reset_calls += 1
+
+    def compute(self, failures):
+        self.calls += 1
+        return 0
+
+
+class TestConnectionConstructorWithRetry:
+    "Test that the Connection constructor properly handles Retry objects"
+
+    @pytest.mark.parametrize("retry_on_timeout", [False, True])
+    def test_retry_on_timeout_boolean(self, retry_on_timeout):
+        c = Connection(retry_on_timeout=retry_on_timeout)
+        assert c.retry_on_timeout == retry_on_timeout
+        assert isinstance(c.retry, Retry)
+        assert c.retry._retries == (1 if retry_on_timeout else 0)
+
+    @pytest.mark.parametrize("retries", range(10))
+    def test_retry_on_timeout_retry(self, retries):
+        retry_on_timeout = retries > 0
+        c = Connection(retry_on_timeout=retry_on_timeout,
+                       retry=Retry(NoBackoff(), retries))
+        assert c.retry_on_timeout == retry_on_timeout
+        assert isinstance(c.retry, Retry)
+        assert c.retry._retries == retries
+
+
+class TestRetry:
+    "Test that Retry calls backoff and retries the expected number of times"
+
+    def setup_method(self, test_method):
+        self.actual_attempts = 0
+        self.actual_failures = 0
+
+    def _do(self):
+        self.actual_attempts += 1
+        raise ConnectionError()
+
+    def _fail(self, error):
+        self.actual_failures += 1
+
+    @pytest.mark.parametrize("retries", range(10))
+    def test_retry(self, retries):
+        backoff = BackoffMock()
+        retry = Retry(backoff, retries)
+        with pytest.raises(ConnectionError):
+            retry.call_with_retry(self._do, self._fail)
+
+        assert self.actual_attempts == 1 + retries
+        assert self.actual_failures == 1 + retries
+        assert backoff.reset_calls == 1
+        assert backoff.calls == retries