summaryrefslogtreecommitdiff
path: root/tests/test_cluster.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cluster.py')
-rw-r--r--tests/test_cluster.py100
1 files changed, 100 insertions, 0 deletions
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index 5652673..d18fbbb 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -7,6 +7,7 @@ from unittest.mock import DEFAULT, Mock, call, patch
import pytest
from redis import Redis
+from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
from redis.cluster import (
PRIMARY,
REDIS_CLUSTER_HASH_SLOTS,
@@ -31,6 +32,7 @@ from redis.exceptions import (
ResponseError,
TimeoutError,
)
+from redis.retry import Retry
from redis.utils import str_if_bytes
from tests.test_pubsub import wait_for_message
@@ -358,6 +360,60 @@ class TestRedisClusterObj:
assert r.execute_command("SET", "foo", "bar") == "MOCK_OK"
+ def test_handling_cluster_failover_to_a_replica(self, r):
+ # Set the key we'll test for
+ key = "key"
+ r.set("key", "value")
+ primary = r.get_node_from_key(key, replica=False)
+ assert str_if_bytes(r.get("key")) == "value"
+ # Get the current output of cluster slots
+ cluster_slots = primary.redis_connection.execute_command("CLUSTER SLOTS")
+ replica_host = ""
+ replica_port = 0
+ # Replace one of the replicas to be the new primary based on the
+ # cluster slots output
+ for slot_range in cluster_slots:
+ primary_port = slot_range[2][1]
+ if primary_port == primary.port:
+ if len(slot_range) <= 3:
+ # cluster doesn't have a replica, return
+ return
+ replica_host = str_if_bytes(slot_range[3][0])
+ replica_port = slot_range[3][1]
+ # replace replica and primary in the cluster slots output
+ tmp_node = slot_range[2]
+ slot_range[2] = slot_range[3]
+ slot_range[3] = tmp_node
+ break
+
+ def raise_connection_error():
+ raise ConnectionError("error")
+
+ def mock_execute_command(*_args, **_kwargs):
+ if _args[0] == "CLUSTER SLOTS":
+ return cluster_slots
+ else:
+ raise Exception("Failed to mock cluster slots")
+
+ # Mock connection error for the current primary
+ mock_node_resp_func(primary, raise_connection_error)
+ primary.redis_connection.set_retry(Retry(NoBackoff(), 1))
+
+ # Mock the cluster slots response for all other nodes
+ redis_mock_node = Mock()
+ redis_mock_node.execute_command.side_effect = mock_execute_command
+ # Mock response value for all other commands
+ redis_mock_node.parse_response.return_value = "MOCK_OK"
+ for node in r.get_nodes():
+ if node.port != primary.port:
+ node.redis_connection = redis_mock_node
+
+ assert r.get(key) == "MOCK_OK"
+ new_primary = r.get_node_from_key(key, replica=False)
+ assert new_primary.host == replica_host
+ assert new_primary.port == replica_port
+ assert r.get_node(primary.host, primary.port).server_type == REPLICA
+
def test_moved_redirection(self, request):
"""
Test that the client handles MOVED response.
@@ -691,6 +747,50 @@ class TestRedisClusterObj:
cur_node = r.get_node(node_name=node_name)
assert conn == r.get_redis_connection(cur_node)
+ def test_cluster_get_set_retry_object(self, request):
+ retry = Retry(NoBackoff(), 2)
+ r = _get_client(RedisCluster, request, retry=retry)
+ assert r.get_retry()._retries == retry._retries
+ assert isinstance(r.get_retry()._backoff, NoBackoff)
+ for node in r.get_nodes():
+ assert node.redis_connection.get_retry()._retries == retry._retries
+ assert isinstance(node.redis_connection.get_retry()._backoff, NoBackoff)
+ rand_node = r.get_random_node()
+ existing_conn = rand_node.redis_connection.connection_pool.get_connection("_")
+ # Change retry policy
+ new_retry = Retry(ExponentialBackoff(), 3)
+ r.set_retry(new_retry)
+ assert r.get_retry()._retries == new_retry._retries
+ assert isinstance(r.get_retry()._backoff, ExponentialBackoff)
+ for node in r.get_nodes():
+ assert node.redis_connection.get_retry()._retries == new_retry._retries
+ assert isinstance(
+ node.redis_connection.get_retry()._backoff, ExponentialBackoff
+ )
+ assert existing_conn.retry._retries == new_retry._retries
+ new_conn = rand_node.redis_connection.connection_pool.get_connection("_")
+ assert new_conn.retry._retries == new_retry._retries
+
+ def test_cluster_retry_object(self, r) -> None:
+ # Test default retry
+ retry = r.get_connection_kwargs().get("retry")
+ assert isinstance(retry, Retry)
+ assert retry._retries == 0
+ assert isinstance(retry._backoff, type(default_backoff()))
+ node1 = r.get_node("127.0.0.1", 16379).redis_connection
+ node2 = r.get_node("127.0.0.1", 16380).redis_connection
+ assert node1.get_retry()._retries == node2.get_retry()._retries
+
+ # Test custom retry
+ retry = Retry(ExponentialBackoff(10, 5), 5)
+ rc_custom_retry = RedisCluster("127.0.0.1", 16379, retry=retry)
+ assert (
+ rc_custom_retry.get_node("127.0.0.1", 16379)
+ .redis_connection.get_retry()
+ ._retries
+ == retry._retries
+ )
+
@pytest.mark.onlycluster
class TestClusterRedisCommands: