summaryrefslogtreecommitdiff
path: root/tests/test_asyncio/test_cluster.py
diff options
context:
space:
mode:
authorUtkarsh Gupta <utkarshgupta137@gmail.com>2022-07-27 16:35:35 +0530
committerGitHub <noreply@github.com>2022-07-27 14:05:35 +0300
commitf665bd306dc843cec3e8fa01d6f4061385d1812e (patch)
treedcbc11c7f5474a08f876fa265fdf069480b7e3d2 /tests/test_asyncio/test_cluster.py
parent3c4d96bcfa1758a2ffd7b1d913166f6f7ca107a5 (diff)
downloadredis-py-f665bd306dc843cec3e8fa01d6f4061385d1812e.tar.gz
async_cluster: fix max_connections/ssl & improve args (#2217)
* async_cluster: fix max_connections/ssl & improve args - set proper connection_class if ssl = True - pass max_connections/connection_class to ClusterNode - recreate startup_nodes to properly initialize - pass parser_class to Connection instead of changing it in on_connect - only pass redis_connect_func if read_from_replicas = True - add connection_error_retry_attempts parameter - skip is_connected check in acquire_connection as it is already checked in send_packed_command BREAKING: - RedisCluster args except host & port are kw-only now - RedisCluster will no longer accept unknown arguments - RedisCluster will no longer accept url as an argument. Use RedisCluster.from_url - RedisCluster.require_full_coverage defaults to True - ClusterNode args except host, port, & server_type are kw-only now * async_cluster: remove kw-only requirement from client Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
Diffstat (limited to 'tests/test_asyncio/test_cluster.py')
-rw-r--r--tests/test_asyncio/test_cluster.py210
1 files changed, 178 insertions, 32 deletions
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 8766cbf..1365e4d 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -1,9 +1,11 @@
import asyncio
import binascii
import datetime
+import os
import sys
import warnings
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
+from urllib.parse import urlparse
import pytest
@@ -16,10 +18,10 @@ if sys.version_info[0:2] == (3, 6):
else:
import pytest_asyncio
-from _pytest.fixtures import FixtureRequest, SubRequest
+from _pytest.fixtures import FixtureRequest
-from redis.asyncio import Connection, RedisCluster
-from redis.asyncio.cluster import ClusterNode, NodesManager
+from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
+from redis.asyncio.connection import Connection, SSLConnection
from redis.asyncio.parser import CommandsParser
from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
@@ -28,6 +30,7 @@ from redis.exceptions import (
ClusterDownError,
ConnectionError,
DataError,
+ MaxConnectionsError,
MovedError,
NoPermissionError,
RedisClusterException,
@@ -41,6 +44,9 @@ from tests.conftest import (
skip_unless_arch_bits,
)
+pytestmark = pytest.mark.onlycluster
+
+
default_host = "127.0.0.1"
default_port = 7000
default_cluster_slots = [
@@ -50,7 +56,7 @@ default_cluster_slots = [
@pytest_asyncio.fixture()
-async def slowlog(request: SubRequest, r: RedisCluster) -> None:
+async def slowlog(r: RedisCluster) -> None:
"""
Set the slowlog threshold to 0, and the
max length to 128. This will force every
@@ -146,7 +152,7 @@ def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster:
async def moved_redirection_helper(
- request: FixtureRequest, create_redis: Callable, failover: bool = False
+ create_redis: Callable[..., RedisCluster], failover: bool = False
) -> None:
"""
Test that the client handles MOVED response after a failover.
@@ -202,7 +208,6 @@ async def moved_redirection_helper(
assert prev_primary.server_type == REPLICA
-@pytest.mark.onlycluster
class TestRedisClusterObj:
"""
Tests for the RedisCluster class
@@ -237,10 +242,18 @@ class TestRedisClusterObj:
await cluster.close()
- startup_nodes = [ClusterNode("127.0.0.1", 16379)]
- async with RedisCluster(startup_nodes=startup_nodes) as rc:
+ startup_node = ClusterNode("127.0.0.1", 16379)
+ async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc:
assert await rc.set("A", 1)
assert await rc.get("A") == b"1"
+ assert all(
+ [
+ name == "test"
+ for name in (
+ await rc.client_getname(target_nodes=rc.ALL_NODES)
+ ).values()
+ ]
+ )
async def test_empty_startup_nodes(self) -> None:
"""
@@ -253,18 +266,43 @@ class TestRedisClusterObj:
"RedisCluster requires at least one node to discover the " "cluster"
), str_if_bytes(ex.value)
- async def test_from_url(self, r: RedisCluster) -> None:
- redis_url = f"redis://{default_host}:{default_port}/0"
- with mock.patch.object(RedisCluster, "from_url") as from_url:
+ async def test_from_url(self, request: FixtureRequest) -> None:
+ url = request.config.getoption("--redis-url")
- async def from_url_mocked(_url, **_kwargs):
- return await get_mocked_redis_client(url=_url, **_kwargs)
+ async with RedisCluster.from_url(url) as rc:
+ await rc.set("a", 1)
+ await rc.get("a") == 1
- from_url.side_effect = from_url_mocked
- cluster = await RedisCluster.from_url(redis_url)
- assert cluster.get_node(host=default_host, port=default_port) is not None
+ rc = RedisCluster.from_url("rediss://localhost:16379")
+ assert rc.connection_kwargs["connection_class"] is SSLConnection
- await cluster.close()
+ async def test_max_connections(
+ self, create_redis: Callable[..., RedisCluster]
+ ) -> None:
+ rc = await create_redis(cls=RedisCluster, max_connections=10)
+ for node in rc.get_nodes():
+ assert node.max_connections == 10
+
+ with mock.patch.object(
+ Connection, "read_response_without_lock"
+ ) as read_response_without_lock:
+
+ async def read_response_without_lock_mocked(
+ *args: Any, **kwargs: Any
+ ) -> None:
+ await asyncio.sleep(10)
+
+ read_response_without_lock.side_effect = read_response_without_lock_mocked
+
+ with pytest.raises(MaxConnectionsError):
+ await asyncio.gather(
+ *(
+ rc.ping(target_nodes=RedisCluster.DEFAULT_NODE)
+ for _ in range(11)
+ )
+ )
+
+ await rc.close()
async def test_execute_command_errors(self, r: RedisCluster) -> None:
"""
@@ -373,23 +411,23 @@ class TestRedisClusterObj:
assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK"
async def test_moved_redirection(
- self, request: FixtureRequest, create_redis: Callable
+ self, create_redis: Callable[..., RedisCluster]
) -> None:
"""
Test that the client handles MOVED response.
"""
- await moved_redirection_helper(request, create_redis, failover=False)
+ await moved_redirection_helper(create_redis, failover=False)
async def test_moved_redirection_after_failover(
- self, request: FixtureRequest, create_redis: Callable
+ self, create_redis: Callable[..., RedisCluster]
) -> None:
"""
Test that the client handles MOVED response after a failover.
"""
- await moved_redirection_helper(request, create_redis, failover=True)
+ await moved_redirection_helper(create_redis, failover=True)
async def test_refresh_using_specific_nodes(
- self, request: FixtureRequest, create_redis: Callable
+ self, create_redis: Callable[..., RedisCluster]
) -> None:
"""
Test making calls on specific nodes when the cluster has failed over to
@@ -691,7 +729,6 @@ class TestRedisClusterObj:
await rc.close()
-@pytest.mark.onlycluster
class TestClusterRedisCommands:
"""
Tests for RedisCluster unique commands
@@ -1918,7 +1955,7 @@ class TestClusterRedisCommands:
@skip_if_server_version_lt("6.0.0")
@skip_if_redis_enterprise()
async def test_acl_log(
- self, r: RedisCluster, request: FixtureRequest, create_redis: Callable
+ self, r: RedisCluster, create_redis: Callable[..., RedisCluster]
) -> None:
key = "{cache}:"
node = r.get_node_from_key(key)
@@ -1963,7 +2000,6 @@ class TestClusterRedisCommands:
await user_client.close()
-@pytest.mark.onlycluster
class TestNodesManager:
"""
Tests for the NodesManager class
@@ -2095,7 +2131,7 @@ class TestNodesManager:
specified
"""
with pytest.raises(RedisClusterException):
- await NodesManager([]).initialize()
+ await NodesManager([], False, {}).initialize()
async def test_wrong_startup_nodes_type(self) -> None:
"""
@@ -2103,11 +2139,9 @@ class TestNodesManager:
fail
"""
with pytest.raises(RedisClusterException):
- await NodesManager({}).initialize()
+ await NodesManager({}, False, {}).initialize()
- async def test_init_slots_cache_slots_collision(
- self, request: FixtureRequest
- ) -> None:
+ async def test_init_slots_cache_slots_collision(self) -> None:
"""
Test that if 2 nodes do not agree on the same slots setup it should
raise an error. In this test both nodes will say that the first
@@ -2236,7 +2270,6 @@ class TestNodesManager:
assert rc.get_node(host=default_host, port=7002) is not None
-@pytest.mark.onlycluster
class TestClusterPipeline:
"""Tests for the ClusterPipeline class."""
@@ -2484,3 +2517,116 @@ class TestClusterPipeline:
*(self.test_multi_key_operation_with_a_single_slot(r) for i in range(100)),
*(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)),
)
+
+
+@pytest.mark.ssl
+class TestSSL:
+ """
+ Tests for SSL connections.
+
+ This relies on the --redis-ssl-url for building the client and connecting to the
+ appropriate port.
+ """
+
+ ROOT = os.path.join(os.path.dirname(__file__), "../..")
+ CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys"))
+ if not os.path.isdir(CERT_DIR): # github actions package validation case
+ CERT_DIR = os.path.abspath(
+ os.path.join(ROOT, "..", "docker", "stunnel", "keys")
+ )
+ if not os.path.isdir(CERT_DIR):
+ raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}")
+
+ SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem")
+ SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem")
+
+ @pytest_asyncio.fixture()
+ def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]:
+ ssl_url = request.config.option.redis_ssl_url
+ ssl_host, ssl_port = urlparse(ssl_url)[1].split(":")
+
+ async def _create_client(mocked: bool = True, **kwargs: Any) -> RedisCluster:
+ if mocked:
+ with mock.patch.object(
+ ClusterNode, "execute_command", autospec=True
+ ) as execute_command_mock:
+
+ async def execute_command(self, *args, **kwargs):
+ if args[0] == "INFO":
+ return {"cluster_enabled": True}
+ if args[0] == "CLUSTER SLOTS":
+ return [[0, 16383, [ssl_host, ssl_port, "ssl_node"]]]
+ if args[0] == "COMMAND":
+ return {
+ "ping": {
+ "name": "ping",
+ "arity": -1,
+ "flags": ["stale", "fast"],
+ "first_key_pos": 0,
+ "last_key_pos": 0,
+ "step_count": 0,
+ }
+ }
+ raise NotImplementedError()
+
+ execute_command_mock.side_effect = execute_command
+
+ rc = await RedisCluster(host=ssl_host, port=ssl_port, **kwargs)
+
+ assert len(rc.get_nodes()) == 1
+ node = rc.get_default_node()
+ assert node.port == int(ssl_port)
+ return rc
+
+ return await RedisCluster(host=ssl_host, port=ssl_port, **kwargs)
+
+ return _create_client
+
+ async def test_ssl_connection_without_ssl(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ with pytest.raises(RedisClusterException) as e:
+ await create_client(mocked=False, ssl=False)
+ e = e.value.__cause__
+ assert "Connection closed by server" in str(e)
+
+ async def test_ssl_with_invalid_cert(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ with pytest.raises(RedisClusterException) as e:
+ await create_client(mocked=False, ssl=True)
+ e = e.value.__cause__.__context__
+ assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e)
+
+ async def test_ssl_connection(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ async with await create_client(ssl=True, ssl_cert_reqs="none") as rc:
+ assert await rc.ping()
+
+ async def test_validating_self_signed_certificate(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ async with await create_client(
+ ssl=True,
+ ssl_ca_certs=self.SERVER_CERT,
+ ssl_cert_reqs="required",
+ ssl_certfile=self.SERVER_CERT,
+ ssl_keyfile=self.SERVER_KEY,
+ ) as rc:
+ assert await rc.ping()
+
+ async def test_validating_self_signed_string_certificate(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ with open(self.SERVER_CERT) as f:
+ cert_data = f.read()
+
+ async with await create_client(
+ ssl=True,
+ ssl_ca_data=cert_data,
+ ssl_cert_reqs="required",
+ ssl_certfile=self.SERVER_CERT,
+ ssl_keyfile=self.SERVER_KEY,
+ ) as rc:
+ assert await rc.ping()