diff options
author | Utkarsh Gupta <utkarshgupta137@gmail.com> | 2022-07-27 16:35:35 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-07-27 14:05:35 +0300 |
commit | f665bd306dc843cec3e8fa01d6f4061385d1812e (patch) | |
tree | dcbc11c7f5474a08f876fa265fdf069480b7e3d2 /tests/test_asyncio/test_cluster.py | |
parent | 3c4d96bcfa1758a2ffd7b1d913166f6f7ca107a5 (diff) | |
download | redis-py-f665bd306dc843cec3e8fa01d6f4061385d1812e.tar.gz |
async_cluster: fix max_connections/ssl & improve args (#2217)
* async_cluster: fix max_connections/ssl & improve args
- set proper connection_class if ssl = True
- pass max_connections/connection_class to ClusterNode
- recreate startup_nodes to properly initialize
- pass parser_class to Connection instead of changing it in on_connect
- only pass redis_connect_func if read_from_replicas = True
- add connection_error_retry_attempts parameter
- skip is_connected check in acquire_connection as it is already checked in send_packed_command
BREAKING:
- RedisCluster args except host & port are kw-only now
- RedisCluster will no longer accept unknown arguments
- RedisCluster will no longer accept url as an argument. Use RedisCluster.from_url
- RedisCluster.require_full_coverage defaults to True
- ClusterNode args except host, port, & server_type are kw-only now
* async_cluster: remove kw-only requirement from client
Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
Diffstat (limited to 'tests/test_asyncio/test_cluster.py')
-rw-r--r-- | tests/test_asyncio/test_cluster.py | 210 |
1 files changed, 178 insertions, 32 deletions
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 8766cbf..1365e4d 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -1,9 +1,11 @@ import asyncio import binascii import datetime +import os import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union +from urllib.parse import urlparse import pytest @@ -16,10 +18,10 @@ if sys.version_info[0:2] == (3, 6): else: import pytest_asyncio -from _pytest.fixtures import FixtureRequest, SubRequest +from _pytest.fixtures import FixtureRequest -from redis.asyncio import Connection, RedisCluster -from redis.asyncio.cluster import ClusterNode, NodesManager +from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster +from redis.asyncio.connection import Connection, SSLConnection from redis.asyncio.parser import CommandsParser from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot @@ -28,6 +30,7 @@ from redis.exceptions import ( ClusterDownError, ConnectionError, DataError, + MaxConnectionsError, MovedError, NoPermissionError, RedisClusterException, @@ -41,6 +44,9 @@ from tests.conftest import ( skip_unless_arch_bits, ) +pytestmark = pytest.mark.onlycluster + + default_host = "127.0.0.1" default_port = 7000 default_cluster_slots = [ @@ -50,7 +56,7 @@ default_cluster_slots = [ @pytest_asyncio.fixture() -async def slowlog(request: SubRequest, r: RedisCluster) -> None: +async def slowlog(r: RedisCluster) -> None: """ Set the slowlog threshold to 0, and the max length to 128. This will force every @@ -146,7 +152,7 @@ def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster: async def moved_redirection_helper( - request: FixtureRequest, create_redis: Callable, failover: bool = False + create_redis: Callable[..., RedisCluster], failover: bool = False ) -> None: """ Test that the client handles MOVED response after a failover. @@ -202,7 +208,6 @@ async def moved_redirection_helper( assert prev_primary.server_type == REPLICA -@pytest.mark.onlycluster class TestRedisClusterObj: """ Tests for the RedisCluster class @@ -237,10 +242,18 @@ class TestRedisClusterObj: await cluster.close() - startup_nodes = [ClusterNode("127.0.0.1", 16379)] - async with RedisCluster(startup_nodes=startup_nodes) as rc: + startup_node = ClusterNode("127.0.0.1", 16379) + async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc: assert await rc.set("A", 1) assert await rc.get("A") == b"1" + assert all( + [ + name == "test" + for name in ( + await rc.client_getname(target_nodes=rc.ALL_NODES) + ).values() + ] + ) async def test_empty_startup_nodes(self) -> None: """ @@ -253,18 +266,43 @@ class TestRedisClusterObj: "RedisCluster requires at least one node to discover the " "cluster" ), str_if_bytes(ex.value) - async def test_from_url(self, r: RedisCluster) -> None: - redis_url = f"redis://{default_host}:{default_port}/0" - with mock.patch.object(RedisCluster, "from_url") as from_url: + async def test_from_url(self, request: FixtureRequest) -> None: + url = request.config.getoption("--redis-url") - async def from_url_mocked(_url, **_kwargs): - return await get_mocked_redis_client(url=_url, **_kwargs) + async with RedisCluster.from_url(url) as rc: + await rc.set("a", 1) + await rc.get("a") == 1 - from_url.side_effect = from_url_mocked - cluster = await RedisCluster.from_url(redis_url) - assert cluster.get_node(host=default_host, port=default_port) is not None + rc = RedisCluster.from_url("rediss://localhost:16379") + assert rc.connection_kwargs["connection_class"] is SSLConnection - await cluster.close() + async def test_max_connections( + self, create_redis: Callable[..., RedisCluster] + ) -> None: + rc = await create_redis(cls=RedisCluster, max_connections=10) + for node in rc.get_nodes(): + assert node.max_connections == 10 + + with mock.patch.object( + Connection, "read_response_without_lock" + ) as read_response_without_lock: + + async def read_response_without_lock_mocked( + *args: Any, **kwargs: Any + ) -> None: + await asyncio.sleep(10) + + read_response_without_lock.side_effect = read_response_without_lock_mocked + + with pytest.raises(MaxConnectionsError): + await asyncio.gather( + *( + rc.ping(target_nodes=RedisCluster.DEFAULT_NODE) + for _ in range(11) + ) + ) + + await rc.close() async def test_execute_command_errors(self, r: RedisCluster) -> None: """ @@ -373,23 +411,23 @@ class TestRedisClusterObj: assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK" async def test_moved_redirection( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test that the client handles MOVED response. """ - await moved_redirection_helper(request, create_redis, failover=False) + await moved_redirection_helper(create_redis, failover=False) async def test_moved_redirection_after_failover( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test that the client handles MOVED response after a failover. """ - await moved_redirection_helper(request, create_redis, failover=True) + await moved_redirection_helper(create_redis, failover=True) async def test_refresh_using_specific_nodes( - self, request: FixtureRequest, create_redis: Callable + self, create_redis: Callable[..., RedisCluster] ) -> None: """ Test making calls on specific nodes when the cluster has failed over to @@ -691,7 +729,6 @@ class TestRedisClusterObj: await rc.close() -@pytest.mark.onlycluster class TestClusterRedisCommands: """ Tests for RedisCluster unique commands @@ -1918,7 +1955,7 @@ class TestClusterRedisCommands: @skip_if_server_version_lt("6.0.0") @skip_if_redis_enterprise() async def test_acl_log( - self, r: RedisCluster, request: FixtureRequest, create_redis: Callable + self, r: RedisCluster, create_redis: Callable[..., RedisCluster] ) -> None: key = "{cache}:" node = r.get_node_from_key(key) @@ -1963,7 +2000,6 @@ class TestClusterRedisCommands: await user_client.close() -@pytest.mark.onlycluster class TestNodesManager: """ Tests for the NodesManager class @@ -2095,7 +2131,7 @@ class TestNodesManager: specified """ with pytest.raises(RedisClusterException): - await NodesManager([]).initialize() + await NodesManager([], False, {}).initialize() async def test_wrong_startup_nodes_type(self) -> None: """ @@ -2103,11 +2139,9 @@ class TestNodesManager: fail """ with pytest.raises(RedisClusterException): - await NodesManager({}).initialize() + await NodesManager({}, False, {}).initialize() - async def test_init_slots_cache_slots_collision( - self, request: FixtureRequest - ) -> None: + async def test_init_slots_cache_slots_collision(self) -> None: """ Test that if 2 nodes do not agree on the same slots setup it should raise an error. In this test both nodes will say that the first @@ -2236,7 +2270,6 @@ class TestNodesManager: assert rc.get_node(host=default_host, port=7002) is not None -@pytest.mark.onlycluster class TestClusterPipeline: """Tests for the ClusterPipeline class.""" @@ -2484,3 +2517,116 @@ class TestClusterPipeline: *(self.test_multi_key_operation_with_a_single_slot(r) for i in range(100)), *(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)), ) + + +@pytest.mark.ssl +class TestSSL: + """ + Tests for SSL connections. + + This relies on the --redis-ssl-url for building the client and connecting to the + appropriate port. + """ + + ROOT = os.path.join(os.path.dirname(__file__), "../..") + CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys")) + if not os.path.isdir(CERT_DIR): # github actions package validation case + CERT_DIR = os.path.abspath( + os.path.join(ROOT, "..", "docker", "stunnel", "keys") + ) + if not os.path.isdir(CERT_DIR): + raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}") + + SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem") + SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem") + + @pytest_asyncio.fixture() + def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]: + ssl_url = request.config.option.redis_ssl_url + ssl_host, ssl_port = urlparse(ssl_url)[1].split(":") + + async def _create_client(mocked: bool = True, **kwargs: Any) -> RedisCluster: + if mocked: + with mock.patch.object( + ClusterNode, "execute_command", autospec=True + ) as execute_command_mock: + + async def execute_command(self, *args, **kwargs): + if args[0] == "INFO": + return {"cluster_enabled": True} + if args[0] == "CLUSTER SLOTS": + return [[0, 16383, [ssl_host, ssl_port, "ssl_node"]]] + if args[0] == "COMMAND": + return { + "ping": { + "name": "ping", + "arity": -1, + "flags": ["stale", "fast"], + "first_key_pos": 0, + "last_key_pos": 0, + "step_count": 0, + } + } + raise NotImplementedError() + + execute_command_mock.side_effect = execute_command + + rc = await RedisCluster(host=ssl_host, port=ssl_port, **kwargs) + + assert len(rc.get_nodes()) == 1 + node = rc.get_default_node() + assert node.port == int(ssl_port) + return rc + + return await RedisCluster(host=ssl_host, port=ssl_port, **kwargs) + + return _create_client + + async def test_ssl_connection_without_ssl( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with pytest.raises(RedisClusterException) as e: + await create_client(mocked=False, ssl=False) + e = e.value.__cause__ + assert "Connection closed by server" in str(e) + + async def test_ssl_with_invalid_cert( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with pytest.raises(RedisClusterException) as e: + await create_client(mocked=False, ssl=True) + e = e.value.__cause__.__context__ + assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e) + + async def test_ssl_connection( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client(ssl=True, ssl_cert_reqs="none") as rc: + assert await rc.ping() + + async def test_validating_self_signed_certificate( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + async with await create_client( + ssl=True, + ssl_ca_certs=self.SERVER_CERT, + ssl_cert_reqs="required", + ssl_certfile=self.SERVER_CERT, + ssl_keyfile=self.SERVER_KEY, + ) as rc: + assert await rc.ping() + + async def test_validating_self_signed_string_certificate( + self, create_client: Callable[..., Awaitable[RedisCluster]] + ) -> None: + with open(self.SERVER_CERT) as f: + cert_data = f.read() + + async with await create_client( + ssl=True, + ssl_ca_data=cert_data, + ssl_cert_reqs="required", + ssl_certfile=self.SERVER_CERT, + ssl_keyfile=self.SERVER_KEY, + ) as rc: + assert await rc.ping() |