summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorUtkarsh Gupta <utkarshgupta137@gmail.com>2022-07-27 16:35:35 +0530
committerGitHub <noreply@github.com>2022-07-27 14:05:35 +0300
commitf665bd306dc843cec3e8fa01d6f4061385d1812e (patch)
treedcbc11c7f5474a08f876fa265fdf069480b7e3d2
parent3c4d96bcfa1758a2ffd7b1d913166f6f7ca107a5 (diff)
downloadredis-py-f665bd306dc843cec3e8fa01d6f4061385d1812e.tar.gz
async_cluster: fix max_connections/ssl & improve args (#2217)
* async_cluster: fix max_connections/ssl & improve args - set proper connection_class if ssl = True - pass max_connections/connection_class to ClusterNode - recreate startup_nodes to properly initialize - pass parser_class to Connection instead of changing it in on_connect - only pass redis_connect_func if read_from_replicas = True - add connection_error_retry_attempts parameter - skip is_connected check in acquire_connection as it is already checked in send_packed_command BREAKING: - RedisCluster args except host & port are kw-only now - RedisCluster will no longer accept unknown arguments - RedisCluster will no longer accept url as an argument. Use RedisCluster.from_url - RedisCluster.require_full_coverage defaults to True - ClusterNode args except host, port, & server_type are kw-only now * async_cluster: remove kw-only requirement from client Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
-rw-r--r--redis/asyncio/cluster.py377
-rw-r--r--redis/cluster.py4
-rw-r--r--redis/exceptions.py4
-rw-r--r--tests/test_asyncio/test_cluster.py210
4 files changed, 380 insertions, 215 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 3fe3ebc..df0c17d 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -17,7 +17,13 @@ from typing import (
)
from redis.asyncio.client import ResponseCallbackT
-from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url
+from redis.asyncio.connection import (
+ Connection,
+ DefaultParser,
+ Encoder,
+ SSLConnection,
+ parse_url,
+)
from redis.asyncio.parser import CommandsParser
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
@@ -42,6 +48,7 @@ from redis.exceptions import (
ConnectionError,
DataError,
MasterDownError,
+ MaxConnectionsError,
MovedError,
RedisClusterException,
ResponseError,
@@ -56,44 +63,17 @@ TargetNodesT = TypeVar(
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
-CONNECTION_ALLOWED_KEYS = (
- "client_name",
- "db",
- "decode_responses",
- "encoder_class",
- "encoding",
- "encoding_errors",
- "health_check_interval",
- "parser_class",
- "password",
- "redis_connect_func",
- "retry",
- "retry_on_timeout",
- "socket_connect_timeout",
- "socket_keepalive",
- "socket_keepalive_options",
- "socket_read_size",
- "socket_timeout",
- "socket_type",
- "username",
-)
-
-
-def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]:
- """Remove unsupported or disabled keys from kwargs."""
- return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS}
-
class ClusterParser(DefaultParser):
EXCEPTION_CLASSES = dict_merge(
DefaultParser.EXCEPTION_CLASSES,
{
"ASK": AskError,
- "TRYAGAIN": TryAgainError,
- "MOVED": MovedError,
"CLUSTERDOWN": ClusterDownError,
"CROSSSLOT": ClusterCrossSlotError,
"MASTERDOWN": MasterDownError,
+ "MOVED": MovedError,
+ "TRYAGAIN": TryAgainError,
},
)
@@ -104,7 +84,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
Pass one of parameters:
- - `url`
- `host` & `port`
- `startup_nodes`
@@ -128,9 +107,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
| Port used if **host** is provided
:param startup_nodes:
| :class:`~.ClusterNode` to used as a startup node
- :param cluster_error_retry_attempts:
- | Retry command execution attempts when encountering :class:`~.ClusterDownError`
- or :class:`~.ConnectionError`
:param require_full_coverage:
| When set to ``False``: the client will not require a full coverage of the
slots. However, if not all slots are covered, and at least one node has
@@ -141,6 +117,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
thrown.
| See:
https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters
+ :param read_from_replicas:
+ | Enable read from replicas in READONLY mode. You can read possibly stale data.
+ When set to true, read commands will be assigned between the primary and
+ its replications in a Round-Robin manner.
:param reinitialize_steps:
| Specifies the number of MOVED errors that need to occur before reinitializing
the whole cluster topology. If a MOVED error occurs and the cluster does not
@@ -149,23 +129,27 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1.
To avoid reinitializing the cluster on moved errors, set reinitialize_steps to
0.
- :param read_from_replicas:
- | Enable read from replicas in READONLY mode. You can read possibly stale data.
- When set to true, read commands will be assigned between the primary and
- its replications in a Round-Robin manner.
- :param url:
- | See :meth:`.from_url`
- :param kwargs:
- | Extra arguments that will be passed to the
- :class:`~redis.asyncio.connection.Connection` instances when created
+ :param cluster_error_retry_attempts:
+ | Number of times to retry before raising an error when :class:`~.TimeoutError`
+ or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered
+ :param connection_error_retry_attempts:
+ | Number of times to retry before reinitializing when :class:`~.TimeoutError`
+ or :class:`~.ConnectionError` are encountered
+ :param max_connections:
+ | Maximum number of connections per node. If there are no free connections & the
+ maximum number of connections are already created, a
+ :class:`~.MaxConnectionsError` is raised. This error may be retried as defined
+ by :attr:`connection_error_retry_attempts`
+
+ | Rest of the arguments will be passed to the
+ :class:`~redis.asyncio.connection.Connection` instances when created
:raises RedisClusterException:
- if any arguments are invalid. Eg:
+ if any arguments are invalid or unknown. Eg:
- - db kwarg
- - db != 0 in url
- - unix socket connection
- - none of host & url & startup_nodes were provided
+ - `db` != 0 or None
+ - `path` argument for unix socket connection
+ - none of the `host`/`port` & `startup_nodes` were provided
"""
@@ -178,7 +162,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
redis://[[username]:[password]]@localhost:6379/0
rediss://[[username]:[password]]@localhost:6379/0
- unix://[[username]:[password]]@/path/to/socket.sock?db=0
Three URL schemes are supported:
@@ -186,32 +169,22 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
<https://www.iana.org/assignments/uri-schemes/prov/redis>
- `rediss://` creates a SSL wrapped TCP socket connection. See more at:
<https://www.iana.org/assignments/uri-schemes/prov/rediss>
- - ``unix://``: creates a Unix Domain Socket connection.
-
- The username, password, hostname, path and all querystring values
- are passed through urllib.parse.unquote in order to replace any
- percent-encoded values with their corresponding characters.
- There are several ways to specify a database number. The first value
- found will be used:
-
- 1. A ``db`` querystring option, e.g. redis://localhost?db=0
- 2. If using the redis:// or rediss:// schemes, the path argument
- of the url, e.g. redis://localhost/0
- 3. A ``db`` keyword argument to this function.
-
- If none of these options are specified, the default db=0 is used.
-
- All querystring options are cast to their appropriate Python types.
- Boolean arguments can be specified with string values "True"/"False"
- or "Yes"/"No". Values that cannot be properly cast cause a
- ``ValueError`` to be raised. Once parsed, the querystring arguments and
- keyword arguments are passed to :class:`~redis.asyncio.connection.Connection`
- when created. In the case of conflicting arguments, querystring
- arguments always win.
+ The username, password, hostname, path and all querystring values are passed
+ through ``urllib.parse.unquote`` in order to replace any percent-encoded values
+ with their corresponding characters.
+ All querystring options are cast to their appropriate Python types. Boolean
+ arguments can be specified with string values "True"/"False" or "Yes"/"No".
+ Values that cannot be properly cast cause a ``ValueError`` to be raised. Once
+ parsed, the querystring arguments and keyword arguments are passed to
+ :class:`~redis.asyncio.connection.Connection` when created.
+ In the case of conflicting arguments, querystring arguments are used.
"""
- return cls(url=url, **kwargs)
+ kwargs.update(parse_url(url))
+ if kwargs.pop("connection_class", None) is SSLConnection:
+ kwargs["ssl"] = True
+ return cls(**kwargs)
__slots__ = (
"_initialize",
@@ -219,6 +192,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
"cluster_error_retry_attempts",
"command_flags",
"commands_parser",
+ "connection_error_retry_attempts",
"connection_kwargs",
"encoder",
"node_flags",
@@ -233,87 +207,131 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
def __init__(
self,
host: Optional[str] = None,
- port: int = 6379,
+ port: Union[str, int] = 6379,
+ # Cluster related kwargs
startup_nodes: Optional[List["ClusterNode"]] = None,
- require_full_coverage: bool = False,
+ require_full_coverage: bool = True,
read_from_replicas: bool = False,
- cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 10,
- url: Optional[str] = None,
- **kwargs: Any,
+ cluster_error_retry_attempts: int = 3,
+ connection_error_retry_attempts: int = 5,
+ max_connections: int = 2**31,
+ # Client related kwargs
+ db: Union[str, int] = 0,
+ path: Optional[str] = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ client_name: Optional[str] = None,
+ # Encoding related kwargs
+ encoding: str = "utf-8",
+ encoding_errors: str = "strict",
+ decode_responses: bool = False,
+ # Connection related kwargs
+ health_check_interval: float = 0,
+ socket_connect_timeout: Optional[float] = None,
+ socket_keepalive: bool = False,
+ socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
+ socket_timeout: Optional[float] = None,
+ # SSL related kwargs
+ ssl: bool = False,
+ ssl_ca_certs: Optional[str] = None,
+ ssl_ca_data: Optional[str] = None,
+ ssl_cert_reqs: str = "required",
+ ssl_certfile: Optional[str] = None,
+ ssl_check_hostname: bool = False,
+ ssl_keyfile: Optional[str] = None,
) -> None:
- if not startup_nodes:
- startup_nodes = []
+ if db:
+ raise RedisClusterException(
+ "Argument 'db' must be 0 or None in cluster mode"
+ )
- if "db" in kwargs:
- # Argument 'db' is not possible to use in cluster mode
+ if path:
raise RedisClusterException(
- "Argument 'db' is not possible to use in cluster mode"
+ "Unix domain socket is not supported in cluster mode"
)
- # Get the startup node(s)
- if url:
- url_options = parse_url(url)
- if "path" in url_options:
- raise RedisClusterException(
- "RedisCluster does not currently support Unix Domain "
- "Socket connections"
- )
- if "db" in url_options and url_options["db"] != 0:
- # Argument 'db' is not possible to use in cluster mode
- raise RedisClusterException(
- "A ``db`` querystring option can only be 0 in cluster mode"
- )
- kwargs.update(url_options)
- host = kwargs.get("host")
- port = kwargs.get("port", port)
- elif (not host or not port) and not startup_nodes:
- # No startup node was provided
+ if (not host or not port) and not startup_nodes:
raise RedisClusterException(
- "RedisCluster requires at least one node to discover the "
- "cluster. Please provide one of the followings:\n"
- "1. host and port, for example:\n"
- " RedisCluster(host='localhost', port=6379)\n"
- "2. list of startup nodes, for example:\n"
- " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379),"
- " ClusterNode('localhost', 6378)])"
+ "RedisCluster requires at least one node to discover the cluster.\n"
+ "Please provide one of the following or use RedisCluster.from_url:\n"
+ ' - host and port: RedisCluster(host="localhost", port=6379)\n'
+ " - startup_nodes: RedisCluster(startup_nodes=["
+ 'ClusterNode("localhost", 6379), ClusterNode("localhost", 6380)])'
+ )
+
+ kwargs: Dict[str, Any] = {
+ "max_connections": max_connections,
+ "connection_class": Connection,
+ "parser_class": ClusterParser,
+ # Client related kwargs
+ "username": username,
+ "password": password,
+ "client_name": client_name,
+ # Encoding related kwargs
+ "encoding": encoding,
+ "encoding_errors": encoding_errors,
+ "decode_responses": decode_responses,
+ # Connection related kwargs
+ "health_check_interval": health_check_interval,
+ "socket_connect_timeout": socket_connect_timeout,
+ "socket_keepalive": socket_keepalive,
+ "socket_keepalive_options": socket_keepalive_options,
+ "socket_timeout": socket_timeout,
+ }
+
+ if ssl:
+ # SSL related kwargs
+ kwargs.update(
+ {
+ "connection_class": SSLConnection,
+ "ssl_ca_certs": ssl_ca_certs,
+ "ssl_ca_data": ssl_ca_data,
+ "ssl_cert_reqs": ssl_cert_reqs,
+ "ssl_certfile": ssl_certfile,
+ "ssl_check_hostname": ssl_check_hostname,
+ "ssl_keyfile": ssl_keyfile,
+ }
)
- # Update the connection arguments
- # Whenever a new connection is established, RedisCluster's on_connect
- # method should be run
- kwargs["redis_connect_func"] = self.on_connect
- self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs)
- self.response_callbacks = kwargs[
- "response_callbacks"
- ] = self.__class__.RESPONSE_CALLBACKS.copy()
+ if read_from_replicas:
+ # Call our on_connect function to configure READONLY mode
+ kwargs["redis_connect_func"] = self.on_connect
+
+ kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy()
+ self.connection_kwargs = kwargs
+
+ if startup_nodes:
+ passed_nodes = []
+ for node in startup_nodes:
+ passed_nodes.append(
+ ClusterNode(node.host, node.port, **self.connection_kwargs)
+ )
+ startup_nodes = passed_nodes
+ else:
+ startup_nodes = []
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
- self.nodes_manager = NodesManager(
- startup_nodes=startup_nodes,
- require_full_coverage=require_full_coverage,
- **self.connection_kwargs,
- )
- self.encoder = Encoder(
- kwargs.get("encoding", "utf-8"),
- kwargs.get("encoding_errors", "strict"),
- kwargs.get("decode_responses", False),
- )
- self.cluster_error_retry_attempts = cluster_error_retry_attempts
+ self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
+ self.cluster_error_retry_attempts = cluster_error_retry_attempts
+ self.connection_error_retry_attempts = connection_error_retry_attempts
self.reinitialize_counter = 0
self.commands_parser = CommandsParser()
self.node_flags = self.__class__.NODE_FLAGS.copy()
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
+ self.response_callbacks = kwargs["response_callbacks"]
self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
self.result_callbacks[
"CLUSTER SLOTS"
] = lambda cmd, res, **kwargs: parse_cluster_slots(
list(res.values())[0], **kwargs
)
+
self._initialize = True
self._lock = asyncio.Lock()
@@ -365,18 +383,16 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
...
async def on_connect(self, connection: Connection) -> None:
- connection.set_parser(ClusterParser)
await connection.on_connect()
- if self.read_from_replicas:
- # Sending READONLY command to server to configure connection as
- # readonly. Since each cluster node may change its server type due
- # to a failover, we should establish a READONLY connection
- # regardless of the server type. If this is a primary connection,
- # READONLY would not affect executing write commands.
- await connection.send_command("READONLY")
- if str_if_bytes(await connection.read_response_without_lock()) != "OK":
- raise ConnectionError("READONLY command failed")
+ # Sending READONLY command to server to configure connection as
+ # readonly. Since each cluster node may change its server type due
+ # to a failover, we should establish a READONLY connection
+ # regardless of the server type. If this is a primary connection,
+ # READONLY would not affect executing write commands.
+ await connection.send_command("READONLY")
+ if str_if_bytes(await connection.read_response_without_lock()) != "OK":
+ raise ConnectionError("READONLY command failed")
def get_nodes(self) -> List["ClusterNode"]:
"""Get all nodes of the cluster."""
@@ -436,12 +452,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
slot_cache = self.nodes_manager.slots_cache.get(slot)
if not slot_cache:
raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.')
- if replica and len(self.nodes_manager.slots_cache[slot]) < 2:
- return None
- elif replica:
+
+ if replica:
+ if len(self.nodes_manager.slots_cache[slot]) < 2:
+ return None
node_idx = 1
else:
- # primary
node_idx = 0
return slot_cache[node_idx]
@@ -638,14 +654,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
command, dict(zip(keys, values)), **kwargs
)
return dict(zip(keys, values))
- except BaseException as e:
+ except Exception as e:
if type(e) in self.__class__.ERRORS_ALLOW_RETRY:
# The nodes and slots cache were reinitialized.
# Try again with the new cluster setup.
exception = e
else:
# All other errors should be raised.
- raise e
+ raise
# If it fails the configured number of times then raise exception back
# to caller of this method
@@ -678,19 +694,30 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
return await target_node.execute_command(*args, **kwargs)
except BusyLoadingError:
raise
- except (ConnectionError, TimeoutError):
- # Give the node 0.25 seconds to get back up and retry again
- # with same node and configuration. After 5 attempts then try
- # to reinitialize the cluster and see if the nodes
- # configuration has changed or not
+ except (ConnectionError, TimeoutError) as e:
+ # Give the node 0.25 seconds to get back up and retry again with the
+ # same node and configuration. After the defined number of attempts, try
+ # to reinitialize the cluster and try again.
connection_error_retry_counter += 1
- if connection_error_retry_counter < 5:
+ if (
+ connection_error_retry_counter
+ < self.connection_error_retry_attempts
+ ):
await asyncio.sleep(0.25)
else:
+ if isinstance(e, MaxConnectionsError):
+ raise
# Hard force of reinitialize of the node/slots setup
# and try again with the new setup
await self.close()
raise
+ except ClusterDownError:
+ # ClusterDownError can occur during a failover and to get
+ # self-healed, we will try to reinitialize the cluster layout
+ # and retry executing the command
+ await self.close()
+ await asyncio.sleep(0.25)
+ raise
except MovedError as e:
# First, we will try to patch the slots/nodes cache with the
# redirected node output and try again. If MovedError exceeds
@@ -711,19 +738,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
else:
self.nodes_manager._moved_exception = e
moved = True
- except TryAgainError:
- if ttl < self.RedisClusterRequestTTL / 2:
- await asyncio.sleep(0.05)
except AskError as e:
redirect_addr = get_node_name(host=e.host, port=e.port)
asking = True
- except ClusterDownError:
- # ClusterDownError can occur during a failover and to get
- # self-healed, we will try to reinitialize the cluster layout
- # and retry executing the command
- await asyncio.sleep(0.25)
- await self.close()
- raise
+ except TryAgainError:
+ if ttl < self.RedisClusterRequestTTL / 2:
+ await asyncio.sleep(0.05)
raise ClusterError("TTL exhausted.")
@@ -770,8 +790,9 @@ class ClusterNode:
def __init__(
self,
host: str,
- port: int,
+ port: Union[str, int],
server_type: Optional[str] = None,
+ *,
max_connections: int = 2**31,
connection_class: Type[Connection] = Connection,
**connection_kwargs: Any,
@@ -789,9 +810,7 @@ class ClusterNode:
self.max_connections = max_connections
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
- self.response_callbacks = connection_kwargs.pop(
- "response_callbacks", RedisCluster.RESPONSE_CALLBACKS
- )
+ self.response_callbacks = connection_kwargs.pop("response_callbacks", {})
self._connections: List[Connection] = []
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
@@ -834,21 +853,15 @@ class ClusterNode:
raise exc
def acquire_connection(self) -> Connection:
- if self._free:
- for _ in range(len(self._free)):
- connection = self._free.popleft()
- if connection.is_connected:
- return connection
- self._free.append(connection)
-
+ try:
return self._free.popleft()
+ except IndexError:
+ if len(self._connections) < self.max_connections:
+ connection = self.connection_class(**self.connection_kwargs)
+ self._connections.append(connection)
+ return connection
- if len(self._connections) < self.max_connections:
- connection = self.connection_class(**self.connection_kwargs)
- self._connections.append(connection)
- return connection
-
- raise ConnectionError("Too many connections")
+ raise MaxConnectionsError()
async def parse_response(
self, connection: Connection, command: str, **kwargs: Any
@@ -926,12 +939,12 @@ class NodesManager:
def __init__(
self,
startup_nodes: List["ClusterNode"],
- require_full_coverage: bool = False,
- **kwargs: Any,
+ require_full_coverage: bool,
+ connection_kwargs: Dict[str, Any],
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
- self.connection_kwargs = kwargs
+ self.connection_kwargs = connection_kwargs
self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1050,6 +1063,7 @@ class NodesManager:
disagreements = []
startup_nodes_reachable = False
fully_covered = False
+ exception = None
for startup_node in self.startup_nodes.values():
try:
# Make sure cluster mode is enabled on this node
@@ -1061,7 +1075,8 @@ class NodesManager:
)
cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
startup_nodes_reachable = True
- except (ConnectionError, TimeoutError):
+ except (ConnectionError, TimeoutError) as e:
+ exception = e
continue
except ResponseError as e:
# Isn't a cluster connection, so it won't parse these
@@ -1162,7 +1177,7 @@ class NodesManager:
raise RedisClusterException(
"Redis Cluster cannot be connected. Please provide at least "
"one reachable node. "
- )
+ ) from exception
# Check if the slots are not fully covered
if not fully_covered and self.require_full_coverage:
@@ -1327,7 +1342,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
await asyncio.sleep(0.25)
else:
# All other errors should be raised.
- raise e
+ raise
# If it fails the configured number of times then raise an exception
raise exception
diff --git a/redis/cluster.py b/redis/cluster.py
index b2d4f3b..b05cf30 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -5,7 +5,7 @@ import sys
import threading
import time
from collections import OrderedDict
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple, Union
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands
@@ -38,7 +38,7 @@ from redis.utils import (
)
-def get_node_name(host: str, port: int) -> str:
+def get_node_name(host: str, port: Union[str, int]) -> str:
return f"{host}:{port}"
diff --git a/redis/exceptions.py b/redis/exceptions.py
index d18b354..8a8bf42 100644
--- a/redis/exceptions.py
+++ b/redis/exceptions.py
@@ -199,3 +199,7 @@ class SlotNotCoveredError(RedisClusterException):
"""
pass
+
+
+class MaxConnectionsError(ConnectionError):
+ ...
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 8766cbf..1365e4d 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -1,9 +1,11 @@
import asyncio
import binascii
import datetime
+import os
import sys
import warnings
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union
+from urllib.parse import urlparse
import pytest
@@ -16,10 +18,10 @@ if sys.version_info[0:2] == (3, 6):
else:
import pytest_asyncio
-from _pytest.fixtures import FixtureRequest, SubRequest
+from _pytest.fixtures import FixtureRequest
-from redis.asyncio import Connection, RedisCluster
-from redis.asyncio.cluster import ClusterNode, NodesManager
+from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
+from redis.asyncio.connection import Connection, SSLConnection
from redis.asyncio.parser import CommandsParser
from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
@@ -28,6 +30,7 @@ from redis.exceptions import (
ClusterDownError,
ConnectionError,
DataError,
+ MaxConnectionsError,
MovedError,
NoPermissionError,
RedisClusterException,
@@ -41,6 +44,9 @@ from tests.conftest import (
skip_unless_arch_bits,
)
+pytestmark = pytest.mark.onlycluster
+
+
default_host = "127.0.0.1"
default_port = 7000
default_cluster_slots = [
@@ -50,7 +56,7 @@ default_cluster_slots = [
@pytest_asyncio.fixture()
-async def slowlog(request: SubRequest, r: RedisCluster) -> None:
+async def slowlog(r: RedisCluster) -> None:
"""
Set the slowlog threshold to 0, and the
max length to 128. This will force every
@@ -146,7 +152,7 @@ def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster:
async def moved_redirection_helper(
- request: FixtureRequest, create_redis: Callable, failover: bool = False
+ create_redis: Callable[..., RedisCluster], failover: bool = False
) -> None:
"""
Test that the client handles MOVED response after a failover.
@@ -202,7 +208,6 @@ async def moved_redirection_helper(
assert prev_primary.server_type == REPLICA
-@pytest.mark.onlycluster
class TestRedisClusterObj:
"""
Tests for the RedisCluster class
@@ -237,10 +242,18 @@ class TestRedisClusterObj:
await cluster.close()
- startup_nodes = [ClusterNode("127.0.0.1", 16379)]
- async with RedisCluster(startup_nodes=startup_nodes) as rc:
+ startup_node = ClusterNode("127.0.0.1", 16379)
+ async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc:
assert await rc.set("A", 1)
assert await rc.get("A") == b"1"
+ assert all(
+ [
+ name == "test"
+ for name in (
+ await rc.client_getname(target_nodes=rc.ALL_NODES)
+ ).values()
+ ]
+ )
async def test_empty_startup_nodes(self) -> None:
"""
@@ -253,18 +266,43 @@ class TestRedisClusterObj:
"RedisCluster requires at least one node to discover the " "cluster"
), str_if_bytes(ex.value)
- async def test_from_url(self, r: RedisCluster) -> None:
- redis_url = f"redis://{default_host}:{default_port}/0"
- with mock.patch.object(RedisCluster, "from_url") as from_url:
+ async def test_from_url(self, request: FixtureRequest) -> None:
+ url = request.config.getoption("--redis-url")
- async def from_url_mocked(_url, **_kwargs):
- return await get_mocked_redis_client(url=_url, **_kwargs)
+ async with RedisCluster.from_url(url) as rc:
+ await rc.set("a", 1)
+ await rc.get("a") == 1
- from_url.side_effect = from_url_mocked
- cluster = await RedisCluster.from_url(redis_url)
- assert cluster.get_node(host=default_host, port=default_port) is not None
+ rc = RedisCluster.from_url("rediss://localhost:16379")
+ assert rc.connection_kwargs["connection_class"] is SSLConnection
- await cluster.close()
+ async def test_max_connections(
+ self, create_redis: Callable[..., RedisCluster]
+ ) -> None:
+ rc = await create_redis(cls=RedisCluster, max_connections=10)
+ for node in rc.get_nodes():
+ assert node.max_connections == 10
+
+ with mock.patch.object(
+ Connection, "read_response_without_lock"
+ ) as read_response_without_lock:
+
+ async def read_response_without_lock_mocked(
+ *args: Any, **kwargs: Any
+ ) -> None:
+ await asyncio.sleep(10)
+
+ read_response_without_lock.side_effect = read_response_without_lock_mocked
+
+ with pytest.raises(MaxConnectionsError):
+ await asyncio.gather(
+ *(
+ rc.ping(target_nodes=RedisCluster.DEFAULT_NODE)
+ for _ in range(11)
+ )
+ )
+
+ await rc.close()
async def test_execute_command_errors(self, r: RedisCluster) -> None:
"""
@@ -373,23 +411,23 @@ class TestRedisClusterObj:
assert await r.execute_command("SET", "foo", "bar") == "MOCK_OK"
async def test_moved_redirection(
- self, request: FixtureRequest, create_redis: Callable
+ self, create_redis: Callable[..., RedisCluster]
) -> None:
"""
Test that the client handles MOVED response.
"""
- await moved_redirection_helper(request, create_redis, failover=False)
+ await moved_redirection_helper(create_redis, failover=False)
async def test_moved_redirection_after_failover(
- self, request: FixtureRequest, create_redis: Callable
+ self, create_redis: Callable[..., RedisCluster]
) -> None:
"""
Test that the client handles MOVED response after a failover.
"""
- await moved_redirection_helper(request, create_redis, failover=True)
+ await moved_redirection_helper(create_redis, failover=True)
async def test_refresh_using_specific_nodes(
- self, request: FixtureRequest, create_redis: Callable
+ self, create_redis: Callable[..., RedisCluster]
) -> None:
"""
Test making calls on specific nodes when the cluster has failed over to
@@ -691,7 +729,6 @@ class TestRedisClusterObj:
await rc.close()
-@pytest.mark.onlycluster
class TestClusterRedisCommands:
"""
Tests for RedisCluster unique commands
@@ -1918,7 +1955,7 @@ class TestClusterRedisCommands:
@skip_if_server_version_lt("6.0.0")
@skip_if_redis_enterprise()
async def test_acl_log(
- self, r: RedisCluster, request: FixtureRequest, create_redis: Callable
+ self, r: RedisCluster, create_redis: Callable[..., RedisCluster]
) -> None:
key = "{cache}:"
node = r.get_node_from_key(key)
@@ -1963,7 +2000,6 @@ class TestClusterRedisCommands:
await user_client.close()
-@pytest.mark.onlycluster
class TestNodesManager:
"""
Tests for the NodesManager class
@@ -2095,7 +2131,7 @@ class TestNodesManager:
specified
"""
with pytest.raises(RedisClusterException):
- await NodesManager([]).initialize()
+ await NodesManager([], False, {}).initialize()
async def test_wrong_startup_nodes_type(self) -> None:
"""
@@ -2103,11 +2139,9 @@ class TestNodesManager:
fail
"""
with pytest.raises(RedisClusterException):
- await NodesManager({}).initialize()
+ await NodesManager({}, False, {}).initialize()
- async def test_init_slots_cache_slots_collision(
- self, request: FixtureRequest
- ) -> None:
+ async def test_init_slots_cache_slots_collision(self) -> None:
"""
Test that if 2 nodes do not agree on the same slots setup it should
raise an error. In this test both nodes will say that the first
@@ -2236,7 +2270,6 @@ class TestNodesManager:
assert rc.get_node(host=default_host, port=7002) is not None
-@pytest.mark.onlycluster
class TestClusterPipeline:
"""Tests for the ClusterPipeline class."""
@@ -2484,3 +2517,116 @@ class TestClusterPipeline:
*(self.test_multi_key_operation_with_a_single_slot(r) for i in range(100)),
*(self.test_multi_key_operation_with_multi_slots(r) for i in range(100)),
)
+
+
+@pytest.mark.ssl
+class TestSSL:
+ """
+ Tests for SSL connections.
+
+ This relies on the --redis-ssl-url for building the client and connecting to the
+ appropriate port.
+ """
+
+ ROOT = os.path.join(os.path.dirname(__file__), "../..")
+ CERT_DIR = os.path.abspath(os.path.join(ROOT, "docker", "stunnel", "keys"))
+ if not os.path.isdir(CERT_DIR): # github actions package validation case
+ CERT_DIR = os.path.abspath(
+ os.path.join(ROOT, "..", "docker", "stunnel", "keys")
+ )
+ if not os.path.isdir(CERT_DIR):
+ raise IOError(f"No SSL certificates found. They should be in {CERT_DIR}")
+
+ SERVER_CERT = os.path.join(CERT_DIR, "server-cert.pem")
+ SERVER_KEY = os.path.join(CERT_DIR, "server-key.pem")
+
+ @pytest_asyncio.fixture()
+ def create_client(self, request: FixtureRequest) -> Callable[..., RedisCluster]:
+ ssl_url = request.config.option.redis_ssl_url
+ ssl_host, ssl_port = urlparse(ssl_url)[1].split(":")
+
+ async def _create_client(mocked: bool = True, **kwargs: Any) -> RedisCluster:
+ if mocked:
+ with mock.patch.object(
+ ClusterNode, "execute_command", autospec=True
+ ) as execute_command_mock:
+
+ async def execute_command(self, *args, **kwargs):
+ if args[0] == "INFO":
+ return {"cluster_enabled": True}
+ if args[0] == "CLUSTER SLOTS":
+ return [[0, 16383, [ssl_host, ssl_port, "ssl_node"]]]
+ if args[0] == "COMMAND":
+ return {
+ "ping": {
+ "name": "ping",
+ "arity": -1,
+ "flags": ["stale", "fast"],
+ "first_key_pos": 0,
+ "last_key_pos": 0,
+ "step_count": 0,
+ }
+ }
+ raise NotImplementedError()
+
+ execute_command_mock.side_effect = execute_command
+
+ rc = await RedisCluster(host=ssl_host, port=ssl_port, **kwargs)
+
+ assert len(rc.get_nodes()) == 1
+ node = rc.get_default_node()
+ assert node.port == int(ssl_port)
+ return rc
+
+ return await RedisCluster(host=ssl_host, port=ssl_port, **kwargs)
+
+ return _create_client
+
+ async def test_ssl_connection_without_ssl(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ with pytest.raises(RedisClusterException) as e:
+ await create_client(mocked=False, ssl=False)
+ e = e.value.__cause__
+ assert "Connection closed by server" in str(e)
+
+ async def test_ssl_with_invalid_cert(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ with pytest.raises(RedisClusterException) as e:
+ await create_client(mocked=False, ssl=True)
+ e = e.value.__cause__.__context__
+ assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e)
+
+ async def test_ssl_connection(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ async with await create_client(ssl=True, ssl_cert_reqs="none") as rc:
+ assert await rc.ping()
+
+ async def test_validating_self_signed_certificate(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ async with await create_client(
+ ssl=True,
+ ssl_ca_certs=self.SERVER_CERT,
+ ssl_cert_reqs="required",
+ ssl_certfile=self.SERVER_CERT,
+ ssl_keyfile=self.SERVER_KEY,
+ ) as rc:
+ assert await rc.ping()
+
+ async def test_validating_self_signed_string_certificate(
+ self, create_client: Callable[..., Awaitable[RedisCluster]]
+ ) -> None:
+ with open(self.SERVER_CERT) as f:
+ cert_data = f.read()
+
+ async with await create_client(
+ ssl=True,
+ ssl_ca_data=cert_data,
+ ssl_cert_reqs="required",
+ ssl_certfile=self.SERVER_CERT,
+ ssl_keyfile=self.SERVER_KEY,
+ ) as rc:
+ assert await rc.ping()