summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBar Shaul <88437685+barshaul@users.noreply.github.com>2021-12-23 11:42:30 +0200
committerGitHub <noreply@github.com>2021-12-23 11:42:30 +0200
commitddc51c4ace0caa0787715801b9df42e65c790d46 (patch)
tree50e3a20a53e68ca6eacd73e7c8e93cbcea850d2f
parent940d9fc428c3dbe320af003befabe812a8d8537b (diff)
downloadredis-py-ddc51c4ace0caa0787715801b9df42e65c790d46.tar.gz
Support for specifying error types with retry (#1817)
-rwxr-xr-xredis/client.py21
-rwxr-xr-xredis/connection.py27
-rw-r--r--redis/retry.py8
-rw-r--r--tests/test_retry.py125
4 files changed, 169 insertions, 12 deletions
diff --git a/redis/client.py b/redis/client.py
index c7aa17b..0236f20 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -869,6 +869,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
errors=None,
decode_responses=False,
retry_on_timeout=False,
+ retry_on_error=[],
ssl=False,
ssl_keyfile=None,
ssl_certfile=None,
@@ -887,8 +888,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
):
"""
Initialize a new Redis client.
- To specify a retry policy, first set `retry_on_timeout` to `True`
- then set `retry` to a valid `Retry` object
+ To specify a retry policy for specific errors, first set
+ `retry_on_error` to a list of the error/s to retry on, then set
+ `retry` to a valid `Retry` object.
+ To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
if not connection_pool:
if charset is not None:
@@ -905,7 +908,8 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
)
)
encoding_errors = errors
-
+ if retry_on_timeout is True:
+ retry_on_error.append(TimeoutError)
kwargs = {
"db": db,
"username": username,
@@ -914,7 +918,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
"encoding": encoding,
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
- "retry_on_timeout": retry_on_timeout,
+ "retry_on_error": retry_on_error,
"retry": copy.deepcopy(retry),
"max_connections": max_connections,
"health_check_interval": health_check_interval,
@@ -1146,11 +1150,14 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
def _disconnect_raise(self, conn, error):
"""
Close the connection and raise an exception
- if retry_on_timeout is not set or the error
- is not a TimeoutError
+ if retry_on_error is not set or the error
+ is not one of the specified error types
"""
conn.disconnect()
- if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+ if (
+ conn.retry_on_error is None
+ or isinstance(error, tuple(conn.retry_on_error)) is False
+ ):
raise error
# COMMAND EXECUTION AND PROTOCOL PARSING
diff --git a/redis/connection.py b/redis/connection.py
index 3fe8543..a349a0f 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -513,6 +513,7 @@ class Connection:
socket_keepalive_options=None,
socket_type=0,
retry_on_timeout=False,
+ retry_on_error=[],
encoding="utf-8",
encoding_errors="strict",
decode_responses=False,
@@ -526,8 +527,10 @@ class Connection:
):
"""
Initialize a new Connection.
- To specify a retry policy, first set `retry_on_timeout` to `True`
- then set `retry` to a valid `Retry` object
+ To specify a retry policy for specific errors, first set
+ `retry_on_error` to a list of the error/s to retry on, then set
+ `retry` to a valid `Retry` object.
+ To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
self.pid = os.getpid()
self.host = host
@@ -543,11 +546,17 @@ class Connection:
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_timeout:
+ # Add TimeoutError to the errors list to retry on
+ retry_on_error.append(TimeoutError)
+ self.retry_on_error = retry_on_error
+ if retry_on_error:
if retry is None:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
+ # Update the retry's supported errors with the specified errors
+ self.retry.update_supported_erros(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
@@ -969,6 +978,7 @@ class UnixDomainSocketConnection(Connection):
encoding_errors="strict",
decode_responses=False,
retry_on_timeout=False,
+ retry_on_error=[],
parser_class=DefaultParser,
socket_read_size=65536,
health_check_interval=0,
@@ -978,8 +988,10 @@ class UnixDomainSocketConnection(Connection):
):
"""
Initialize a new UnixDomainSocketConnection.
- To specify a retry policy, first set `retry_on_timeout` to `True`
- then set `retry` to a valid `Retry` object
+ To specify a retry policy for specific errors, first set
+ `retry_on_error` to a list of the error/s to retry on, then set
+ `retry` to a valid `Retry` object.
+ To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
self.pid = os.getpid()
self.path = path
@@ -990,11 +1002,17 @@ class UnixDomainSocketConnection(Connection):
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
if retry_on_timeout:
+ # Add TimeoutError to the errors list to retry on
+ retry_on_error.append(TimeoutError)
+ self.retry_on_error = retry_on_error
+ if self.retry_on_error:
if retry is None:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
self.retry = copy.deepcopy(retry)
+ # Update the retry's supported errors with the specified errors
+ self.retry.update_supported_erros(retry_on_error)
else:
self.retry = Retry(NoBackoff(), 0)
self.health_check_interval = health_check_interval
@@ -1052,6 +1070,7 @@ URL_QUERY_ARGUMENT_PARSERS = {
"socket_connect_timeout": float,
"socket_keepalive": to_bool,
"retry_on_timeout": to_bool,
+ "retry_on_error": list,
"max_connections": int,
"health_check_interval": int,
"ssl_check_hostname": to_bool,
diff --git a/redis/retry.py b/redis/retry.py
index 75504c7..6147fbd 100644
--- a/redis/retry.py
+++ b/redis/retry.py
@@ -19,6 +19,14 @@ class Retry:
self._retries = retries
self._supported_errors = supported_errors
+ def update_supported_erros(self, specified_errors: list):
+ """
+ Updates the supported errors with the specified error types
+ """
+ self._supported_errors = tuple(
+ set(self._supported_errors + tuple(specified_errors))
+ )
+
def call_with_retry(self, do, fail):
"""
Execute an operation that might fail and returns its result, or
diff --git a/tests/test_retry.py b/tests/test_retry.py
index c4650bc..0094787 100644
--- a/tests/test_retry.py
+++ b/tests/test_retry.py
@@ -1,10 +1,20 @@
+from unittest.mock import patch
+
import pytest
from redis.backoff import NoBackoff
+from redis.client import Redis
from redis.connection import Connection, UnixDomainSocketConnection
-from redis.exceptions import ConnectionError
+from redis.exceptions import (
+ BusyLoadingError,
+ ConnectionError,
+ ReadOnlyError,
+ TimeoutError,
+)
from redis.retry import Retry
+from .conftest import _get_client
+
class BackoffMock:
def __init__(self):
@@ -39,6 +49,37 @@ class TestConnectionConstructorWithRetry:
assert isinstance(c.retry, Retry)
assert c.retry._retries == retries
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_error(self, Class):
+ c = Class(retry_on_error=[ReadOnlyError])
+ assert c.retry_on_error == [ReadOnlyError]
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == 1
+
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_error_empty_value(self, Class):
+ c = Class(retry_on_error=[])
+ assert c.retry_on_error == []
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == 0
+
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_error_and_timeout(self, Class):
+ c = Class(
+ retry_on_error=[ReadOnlyError, BusyLoadingError], retry_on_timeout=True
+ )
+ assert c.retry_on_error == [ReadOnlyError, BusyLoadingError, TimeoutError]
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == 1
+
+ @pytest.mark.parametrize("retries", range(10))
+ @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
+ def test_retry_on_error_retry(self, Class, retries):
+ c = Class(retry_on_error=[ReadOnlyError], retry=Retry(NoBackoff(), retries))
+ assert c.retry_on_error == [ReadOnlyError]
+ assert isinstance(c.retry, Retry)
+ assert c.retry._retries == retries
+
class TestRetry:
"Test that Retry calls backoff and retries the expected number of times"
@@ -65,3 +106,85 @@ class TestRetry:
assert self.actual_failures == 1 + retries
assert backoff.reset_calls == 1
assert backoff.calls == retries
+
+
+@pytest.mark.onlynoncluster
+class TestRedisClientRetry:
+ "Test the standalone Redis client behavior with retries"
+
+ def test_client_retry_on_error_with_success(self, request):
+ with patch.object(Redis, "parse_response") as parse_response:
+
+ def mock_parse_response(connection, *args, **options):
+ def ok_response(connection, *args, **options):
+ return "MOCK_OK"
+
+ parse_response.side_effect = ok_response
+ raise ReadOnlyError()
+
+ parse_response.side_effect = mock_parse_response
+ r = _get_client(Redis, request, retry_on_error=[ReadOnlyError])
+ assert r.get("foo") == "MOCK_OK"
+ assert parse_response.call_count == 2
+
+ def test_client_retry_on_error_raise(self, request):
+ with patch.object(Redis, "parse_response") as parse_response:
+ parse_response.side_effect = BusyLoadingError()
+ retries = 3
+ r = _get_client(
+ Redis,
+ request,
+ retry_on_error=[ReadOnlyError, BusyLoadingError],
+ retry=Retry(NoBackoff(), retries),
+ )
+ with pytest.raises(BusyLoadingError):
+ try:
+ r.get("foo")
+ finally:
+ assert parse_response.call_count == retries + 1
+
+ def test_client_retry_on_error_different_error_raised(self, request):
+ with patch.object(Redis, "parse_response") as parse_response:
+ parse_response.side_effect = TimeoutError()
+ retries = 3
+ r = _get_client(
+ Redis,
+ request,
+ retry_on_error=[ReadOnlyError],
+ retry=Retry(NoBackoff(), retries),
+ )
+ with pytest.raises(TimeoutError):
+ try:
+ r.get("foo")
+ finally:
+ assert parse_response.call_count == 1
+
+ def test_client_retry_on_error_and_timeout(self, request):
+ with patch.object(Redis, "parse_response") as parse_response:
+ parse_response.side_effect = TimeoutError()
+ retries = 3
+ r = _get_client(
+ Redis,
+ request,
+ retry_on_error=[ReadOnlyError],
+ retry_on_timeout=True,
+ retry=Retry(NoBackoff(), retries),
+ )
+ with pytest.raises(TimeoutError):
+ try:
+ r.get("foo")
+ finally:
+ assert parse_response.call_count == retries + 1
+
+ def test_client_retry_on_timeout(self, request):
+ with patch.object(Redis, "parse_response") as parse_response:
+ parse_response.side_effect = TimeoutError()
+ retries = 3
+ r = _get_client(
+ Redis, request, retry_on_timeout=True, retry=Retry(NoBackoff(), retries)
+ )
+ with pytest.raises(TimeoutError):
+ try:
+ r.get("foo")
+ finally:
+ assert parse_response.call_count == retries + 1