diff options
author | Utkarsh Gupta <utkarshgupta137@gmail.com> | 2022-05-08 17:34:20 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-08 15:04:20 +0300 |
commit | 061d97abe21d3a8ce9738330cabf771dd05c8dc1 (patch) | |
tree | e64d64c5917312a65304f4d630775d08adb38bfa /redis/asyncio | |
parent | c25be04d6468163d31908774ed358d3fd6bc0a39 (diff) | |
download | redis-py-061d97abe21d3a8ce9738330cabf771dd05c8dc1.tar.gz |
Add Async RedisCluster (#2099)
* Copy Cluster Client, Commands, Commands Parser, Tests for asyncio
* Async Cluster Tests: Async/Await
* Add Async RedisCluster
* cluster: use ERRORS_ALLOW_RETRY from self.__class__
* async_cluster: rework redis_connection, initialize, & close
- move redis_connection from NodesManager to ClusterNode & handle all related logic in ClusterNode class
- use Locks while initializing or closing
- in case of error, close connections instead of instantly reinitializing
- create ResourceWarning instead of manually deleting client object
- use asyncio.gather to run commands/initialize/close in parallel
- inline single use functions
- fix test_acl_log for py3.6
* async_cluster: add types
* async_cluster: add docs
* docs: update sphinx & add sphinx_autodoc_typehints
* async_cluster: move TargetNodesT to cluster module
* async_cluster/commands: inherit commands from sync class if possible
* async_cluster: add benchmark script with aredis & aioredis-cluster
* async_cluster: remove logging
* async_cluster: inline functions
* async_cluster: manage Connection instead of Redis Client
* async_cluster/commands: optimize parser
* async_cluster: use ensure_future & generators for gather
* async_conn: optimize
* async_cluster: optimize determine_slot
* async_cluster: optimize determine_nodes
* async_cluster/parser: optimize _get_moveable_keys
* async_cluster: inlined check_slots_coverage
* async_cluster: update docstrings
* async_cluster: add concurrent test & use read_response/_update_moved_slots without lock
Co-authored-by: Chayim <chayim@users.noreply.github.com>
Diffstat (limited to 'redis/asyncio')
-rw-r--r-- | redis/asyncio/__init__.py | 4 | ||||
-rw-r--r-- | redis/asyncio/client.py | 12 | ||||
-rw-r--r-- | redis/asyncio/cluster.py | 1113 | ||||
-rw-r--r-- | redis/asyncio/connection.py | 140 | ||||
-rw-r--r-- | redis/asyncio/parser.py | 95 |
5 files changed, 1299 insertions, 65 deletions
diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index c655c7d..598791a 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -1,4 +1,5 @@ from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.cluster import RedisCluster from redis.asyncio.connection import ( BlockingConnectionPool, Connection, @@ -6,6 +7,7 @@ from redis.asyncio.connection import ( SSLConnection, UnixDomainSocketConnection, ) +from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -35,6 +37,7 @@ __all__ = [ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", + "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", @@ -44,6 +47,7 @@ __all__ = [ "PubSubError", "ReadOnlyError", "Redis", + "RedisCluster", "RedisError", "ResponseError", "Sentinel", diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 8dde96e..6db5489 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -172,6 +172,7 @@ class Redis( username: Optional[str] = None, retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, + redis_connect_func=None, ): """ Initialize a new Redis client. @@ -200,6 +201,7 @@ class Redis( "max_connections": max_connections, "health_check_interval": health_check_interval, "client_name": client_name, + "redis_connect_func": redis_connect_func, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -263,11 +265,7 @@ class Redis( """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def load_external_module( - self, - funcname, - func, - ): + def load_external_module(self, funcname, func): """ This function can be used to add externally defined redis modules, and their namespaces to the redis client. @@ -426,9 +424,7 @@ class Redis( def __del__(self, _warnings: Any = warnings) -> None: if self.connection is not None: _warnings.warn( - f"Unclosed client session {self!r}", - ResourceWarning, - source=self, + f"Unclosed client session {self!r}", ResourceWarning, source=self ) context = {"client": self, "message": self._DEL_MESSAGE} asyncio.get_event_loop().call_exception_handler(context) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py new file mode 100644 index 0000000..10a5675 --- /dev/null +++ b/redis/asyncio/cluster.py @@ -0,0 +1,1113 @@ +import asyncio +import collections +import random +import socket +import warnings +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union + +from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis +from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url +from redis.asyncio.parser import CommandsParser +from redis.cluster import ( + PRIMARY, + READ_COMMANDS, + REPLICA, + SLOT_ID, + AbstractRedisCluster, + LoadBalancer, + cleanup_kwargs, + get_node_name, + parse_cluster_slots, +) +from redis.commands import AsyncRedisClusterCommands +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.exceptions import ( + AskError, + BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + ConnectionError, + DataError, + MasterDownError, + MovedError, + RedisClusterException, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.typing import EncodableT, KeyT +from redis.utils import dict_merge, str_if_bytes + +TargetNodesT = TypeVar( + "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] +) + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + }, + ) + + +class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new RedisCluster client. + + Pass one of parameters: + + - `url` + - `host` & `port` + - `startup_nodes` + + | Use ``await`` :meth:`initialize` to find cluster nodes & create connections. + | Use ``await`` :meth:`close` to disconnect connections & close client. + + Many commands support the target_nodes kwarg. It can be one of the + :attr:`NODE_FLAGS`: + + - :attr:`PRIMARIES` + - :attr:`REPLICAS` + - :attr:`ALL_NODES` + - :attr:`RANDOM` + - :attr:`DEFAULT_NODE` + + Note: This client is not thread/process/fork safe. + + :param host: + | Can be used to point to a startup node + :param port: + | Port used if **host** is provided + :param startup_nodes: + | :class:`~.ClusterNode` to used as a startup node + :param cluster_error_retry_attempts: + | Retry command execution attempts when encountering :class:`~.ClusterDownError` + or :class:`~.ConnectionError` + :param require_full_coverage: + | When set to ``False``: the client will not require a full coverage of the + slots. However, if not all slots are covered, and at least one node has + ``cluster-require-full-coverage`` set to ``yes``, the server will throw a + :class:`~.ClusterDownError` for some key-based commands. + | When set to ``True``: all slots must be covered to construct the cluster + client. If not all slots are covered, :class:`~.RedisClusterException` will be + thrown. + | See: + https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters + :param reinitialize_steps: + | Specifies the number of MOVED errors that need to occur before reinitializing + the whole cluster topology. If a MOVED error occurs and the cluster does not + need to be reinitialized on this current error handling, only the MOVED slot + will be patched with the redirected node. + To reinitialize the cluster on every MOVED error, set reinitialize_steps to 1. + To avoid reinitializing the cluster on moved errors, set reinitialize_steps to + 0. + :param read_from_replicas: + | Enable read from replicas in READONLY mode. You can read possibly stale data. + When set to true, read commands will be assigned between the primary and + its replications in a Round-Robin manner. + :param url: + | See :meth:`.from_url` + :param kwargs: + | Extra arguments that will be passed to the + :class:`~redis.asyncio.connection.Connection` instances when created + + :raises RedisClusterException: + if any arguments are invalid. Eg: + + - db kwarg + - db != 0 in url + - unix socket connection + - none of host & url & startup_nodes were provided + + """ + + @classmethod + def from_url(cls, url: str, **kwargs) -> "RedisCluster": + """ + Return a Redis client object configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments and + keyword arguments are passed to :class:`~redis.asyncio.connection.Connection` + when created. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) + + __slots__ = ( + "_initialize", + "_lock", + "cluster_error_retry_attempts", + "command_flags", + "commands_parser", + "connection_kwargs", + "encoder", + "node_flags", + "nodes_manager", + "read_from_replicas", + "reinitialize_counter", + "reinitialize_steps", + "response_callbacks", + "result_callbacks", + ) + + def __init__( + self, + host: Optional[str] = None, + port: int = 6379, + startup_nodes: Optional[List["ClusterNode"]] = None, + require_full_coverage: bool = False, + read_from_replicas: bool = False, + cluster_error_retry_attempts: int = 3, + reinitialize_steps: int = 10, + url: Optional[str] = None, + **kwargs, + ) -> None: + if not startup_nodes: + startup_nodes = [] + + if "db" in kwargs: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "Argument 'db' is not possible to use in cluster mode" + ) + + # Get the startup node/s + if url: + url_options = parse_url(url) + if "path" in url_options: + raise RedisClusterException( + "RedisCluster does not currently support Unix Domain " + "Socket connections" + ) + if "db" in url_options and url_options["db"] != 0: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "A ``db`` querystring option can only be 0 in cluster mode" + ) + kwargs.update(url_options) + host = kwargs.get("host") + port = kwargs.get("port", port) + elif (not host or not port) and not startup_nodes: + # No startup node was provided + raise RedisClusterException( + "RedisCluster requires at least one node to discover the " + "cluster. Please provide one of the followings:\n" + "1. host and port, for example:\n" + " RedisCluster(host='localhost', port=6379)\n" + "2. list of startup nodes, for example:\n" + " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," + " ClusterNode('localhost', 6378)])" + ) + + # Update the connection arguments + # Whenever a new connection is established, RedisCluster's on_connect + # method should be run + kwargs["redis_connect_func"] = self.on_connect + self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) + self.response_callbacks = kwargs[ + "response_callbacks" + ] = self.__class__.RESPONSE_CALLBACKS + if host and port: + startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.read_from_replicas = read_from_replicas + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + require_full_coverage=require_full_coverage, + **self.connection_kwargs, + ) + + self.result_callbacks = self.__class__.RESULT_CALLBACKS + self.result_callbacks[ + "CLUSTER SLOTS" + ] = lambda cmd, res, **kwargs: parse_cluster_slots( + list(res.values())[0], **kwargs + ) + self.commands_parser = CommandsParser() + self._initialize = True + self._lock = asyncio.Lock() + + async def initialize(self) -> "RedisCluster": + """Get all nodes from startup nodes & creates connections if not initialized.""" + if self._initialize: + async with self._lock: + if self._initialize: + self._initialize = False + try: + await self.nodes_manager.initialize() + await self.commands_parser.initialize( + self.nodes_manager.default_node + ) + except BaseException: + self._initialize = True + await self.nodes_manager.close() + await self.nodes_manager.close("startup_nodes") + raise + return self + + async def close(self) -> None: + """Close all connections & client if initialized.""" + if not self._initialize: + async with self._lock: + if not self._initialize: + self._initialize = True + await self.nodes_manager.close() + + async def __aenter__(self) -> "RedisCluster": + return await self.initialize() + + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + await self.close() + + def __await__(self): + return self.initialize().__await__() + + _DEL_MESSAGE = "Unclosed RedisCluster client" + + def __del__(self, _warnings=warnings): + if hasattr(self, "_initialize") and not self._initialize: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... + + async def on_connect(self, connection: Connection) -> None: + connection.set_parser(ClusterParser) + await connection.on_connect() + + if self.read_from_replicas: + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. + await connection.send_command("READONLY") + if str_if_bytes(await connection.read_response_without_lock()) != "OK": + raise ConnectionError("READONLY command failed") + + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> Optional["ClusterNode"]: + """Get node by (host, port) or node_name.""" + return self.nodes_manager.get_node(host, port, node_name) + + def get_primaries(self) -> List["ClusterNode"]: + """Get the primary nodes of the cluster.""" + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self) -> List["ClusterNode"]: + """Get the replica nodes of the cluster.""" + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self) -> "ClusterNode": + """Get a random node of the cluster.""" + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_nodes(self) -> List["ClusterNode"]: + """Get all nodes of the cluster.""" + return list(self.nodes_manager.nodes_cache.values()) + + def get_node_from_key( + self, key: str, replica: bool = False + ) -> Optional["ClusterNode"]: + """ + Get the cluster node corresponding to the provided key. + + :param key: + :param replica: + | Indicates if a replica should be returned + None will returned if no replica holds this key + + :raises SlotNotCoveredError: if the key is not covered by any slot. + """ + slot = self.keyslot(key) + slot_cache = self.nodes_manager.slots_cache.get(slot) + if not slot_cache: + raise SlotNotCoveredError(f'Slot "{slot}" is not covered by the cluster.') + if replica and len(self.nodes_manager.slots_cache[slot]) < 2: + return None + elif replica: + node_idx = 1 + else: + # primary + node_idx = 0 + + return slot_cache[node_idx] + + def get_default_node(self) -> "ClusterNode": + """Get the default node of the client.""" + return self.nodes_manager.default_node + + def set_default_node(self, node: "ClusterNode") -> None: + """ + Set the default node of the client. + + :raises DataError: if None is passed or node does not exist in cluster. + """ + if not node or not self.get_node(node_name=node.name): + raise DataError("The requested node does not exist in the cluster.") + + self.nodes_manager.default_node = node + + def set_response_callback(self, command: KeyT, callback: Callable) -> None: + """Set a custom response callback.""" + self.response_callbacks[command] = callback + + def get_encoder(self) -> Encoder: + """Get the encoder object of the client.""" + return self.encoder + + def get_connection_kwargs(self) -> Dict[str, Optional[Any]]: + """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" + return self.connection_kwargs + + def keyslot(self, key: EncodableT) -> int: + """ + Find the keyslot for a given key. + + See: https://redis.io/docs/manual/scaling/#redis-cluster-data-sharding + """ + k = self.encoder.encode(key) + return key_slot(k) + + async def _determine_nodes( + self, *args, node_flag: Optional[str] = None + ) -> List["ClusterNode"]: + command = args[0] + if not node_flag: + # get the nodes group for this command if it was predefined + node_flag = self.command_flags.get(command) + + if node_flag in self.node_flags: + if node_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] + if node_flag == self.__class__.PRIMARIES: + # return all primaries + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + if node_flag == self.__class__.REPLICAS: + # return all replicas + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + if node_flag == self.__class__.ALL_NODES: + # return all nodes + return list(self.nodes_manager.nodes_cache.values()) + if node_flag == self.__class__.RANDOM: + # return a random node + return [random.choice(list(self.nodes_manager.nodes_cache.values()))] + + # get the node that holds the key's slot + return [ + self.nodes_manager.get_node_from_slot( + await self._determine_slot(*args), + self.read_from_replicas and command in READ_COMMANDS, + ) + ] + + async def _determine_slot(self, *args) -> int: + command = args[0] + if self.command_flags.get(command) == SLOT_ID: + # The command contains the slot ID + return args[1] + + # Get the keys in the command + + # EVAL and EVALSHA are common enough that it's wasteful to go to the + # redis server to parse the keys. Besides, there is a bug in redis<7.0 + # where `self._get_command_keys()` fails anyway. So, we special case + # EVAL/EVALSHA. + # - issue: https://github.com/redis/redis/issues/9493 + # - fix: https://github.com/redis/redis/pull/9733 + if command in ("EVAL", "EVALSHA"): + # command syntax: EVAL "script body" num_keys ... + if len(args) <= 2: + raise RedisClusterException(f"Invalid args in command: {args}") + num_actual_keys = args[2] + eval_keys = args[3 : 3 + num_actual_keys] + # if there are 0 keys, that means the script can be run on any node + # so we can just return a random slot + if not eval_keys: + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + keys = eval_keys + else: + keys = await self.commands_parser.get_keys( + self.nodes_manager.default_node, *args + ) + if not keys: + # FCALL can call a function with 0 keys, that means the function + # can be run on any node so we can just return a random slot + if command in ("FCALL", "FCALL_RO"): + return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS) + raise RedisClusterException( + "No way to dispatch this command to Redis Cluster. " + "Missing key.\nYou can execute the command by specifying " + f"target nodes.\nCommand: {args}" + ) + + # single key command + if len(keys) == 1: + return self.keyslot(keys[0]) + + # multi-key command; we need to make sure all keys are mapped to + # the same slot + slots = {self.keyslot(key) for key in keys} + if len(slots) != 1: + raise RedisClusterException( + f"{command} - all keys must map to the same key slot" + ) + + return slots.pop() + + def _is_node_flag( + self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] + ) -> bool: + return isinstance(target_nodes, str) and target_nodes in self.node_flags + + def _parse_target_nodes( + self, target_nodes: Union[List["ClusterNode"], "ClusterNode"] + ) -> List["ClusterNode"]: + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError( + "target_nodes type can be one of the following: " + "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list<ClusterNode>, or dict<any, ClusterNode>. " + f"The passed type is {type(target_nodes)}" + ) + return nodes + + async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any: + """ + Execute a raw command on the appropriate cluster node or target_nodes. + + It will retry the command as specified by :attr:`cluster_error_retry_attempts` & + then raise an exception. + + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot + """ + command = args[0] + target_nodes_specified = target_nodes = exception = None + retry_attempts = self.cluster_error_retry_attempts + + passed_targets = kwargs.pop("target_nodes", None) + if passed_targets and not self._is_node_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + target_nodes_specified = True + retry_attempts = 1 + + for _ in range(0, retry_attempts): + if self._initialize: + await self.initialize() + try: + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = await self._determine_nodes( + *args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {args} command on" + ) + + if len(target_nodes) == 1: + # Return the processed result + ret = await self._execute_command(target_nodes[0], *args, **kwargs) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, {target_nodes[0].name: ret}, **kwargs + ) + return ret + else: + keys = [node.name for node in target_nodes] + values = await asyncio.gather( + *( + asyncio.ensure_future( + self._execute_command(node, *args, **kwargs) + ) + for node in target_nodes + ) + ) + if command in self.result_callbacks: + return self.result_callbacks[command]( + command, dict(zip(keys, values)), **kwargs + ) + return dict(zip(keys, values)) + except BaseException as e: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # The nodes and slots cache were reinitialized. + # Try again with the new cluster setup. + exception = e + else: + # All other errors should be raised. + raise e + + # If it fails the configured number of times then raise exception back + # to caller of this method + raise exception + + async def _execute_command( + self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs + ) -> Any: + redirect_addr = asking = moved = None + ttl = self.RedisClusterRequestTTL + connection_error_retry_counter = 0 + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + await target_node.execute_command("ASKING") + asking = False + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = await self._determine_slot(*args) + target_node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and args[0] in READ_COMMANDS + ) + moved = False + + return await target_node.execute_command(*args, **kwargs) + except BusyLoadingError: + raise + except (ConnectionError, TimeoutError): + # Give the node 0.25 seconds to get back up and retry again + # with same node and configuration. After 5 attempts then try + # to reinitialize the cluster and see if the nodes + # configuration has changed or not + connection_error_retry_counter += 1 + if connection_error_retry_counter < 5: + await asyncio.sleep(0.25) + else: + # Hard force of reinitialize of the node/slots setup + # and try again with the new setup + await self.close() + raise + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when + # the same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + self.reinitialize_counter += 1 + if ( + self.reinitialize_steps + and self.reinitialize_counter % self.reinitialize_steps == 0 + ): + await self.close() + # Reset the counter + self.reinitialize_counter = 0 + else: + self.nodes_manager._moved_exception = e + moved = True + except TryAgainError: + if ttl < self.RedisClusterRequestTTL / 2: + await asyncio.sleep(0.05) + except AskError as e: + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except ClusterDownError: + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + await asyncio.sleep(0.25) + await self.close() + raise + + raise ClusterError("TTL exhausted.") + + +class ClusterNode: + """ + Create a new ClusterNode. + + Each ClusterNode manages multiple :class:`~redis.asyncio.connection.Connection` + objects for the (host, port). + """ + + __slots__ = ( + "_connections", + "_free", + "connection_class", + "connection_kwargs", + "host", + "max_connections", + "name", + "port", + "response_callbacks", + "server_type", + ) + + def __init__( + self, + host: str, + port: int, + server_type: Optional[str] = None, + max_connections: int = 2 ** 31, + connection_class: Type[Connection] = Connection, + response_callbacks: Dict = None, + **connection_kwargs, + ) -> None: + if host == "localhost": + host = socket.gethostbyname(host) + + connection_kwargs["host"] = host + connection_kwargs["port"] = port + self.host = host + self.port = port + self.name = get_node_name(host, port) + self.server_type = server_type + + self.max_connections = max_connections + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.response_callbacks = response_callbacks + + self._connections = [] + self._free = collections.deque(maxlen=self.max_connections) + + def __repr__(self) -> str: + return ( + f"[host={self.host}, port={self.port}, " + f"name={self.name}, server_type={self.server_type}]" + ) + + def __eq__(self, obj: "ClusterNode") -> bool: + return isinstance(obj, ClusterNode) and obj.name == self.name + + _DEL_MESSAGE = "Unclosed ClusterNode object" + + def __del__(self, _warnings=warnings): + for connection in self._connections: + if connection.is_connected: + _warnings.warn( + f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self + ) + try: + context = {"client": self, "message": self._DEL_MESSAGE} + # TODO: Change to get_running_loop() when dropping support for py3.6 + asyncio.get_event_loop().call_exception_handler(context) + except RuntimeError: + ... + break + + async def disconnect(self) -> None: + ret = await asyncio.gather( + *( + asyncio.ensure_future(connection.disconnect()) + for connection in self._connections + ), + return_exceptions=True, + ) + exc = next((res for res in ret if isinstance(res, Exception)), None) + if exc: + raise exc + + async def execute_command(self, *args, **kwargs) -> Any: + # Acquire connection + connection = None + if self._free: + for _ in range(len(self._free)): + connection = self._free.popleft() + if connection.is_connected: + break + self._free.append(connection) + else: + connection = self._free.popleft() + else: + if len(self._connections) < self.max_connections: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + else: + raise ConnectionError("Too many connections") + + # Execute command + command = connection.pack_command(*args) + await connection.send_packed_command(command, False) + try: + if NEVER_DECODE in kwargs: + response = await connection.read_response_without_lock( + disable_decoding=True + ) + else: + response = await connection.read_response_without_lock() + except ResponseError: + if EMPTY_RESPONSE in kwargs: + return kwargs[EMPTY_RESPONSE] + raise + finally: + # Release connection + self._free.append(connection) + + # Return response + try: + return self.response_callbacks[args[0]](response, **kwargs) + except KeyError: + return response + + +class NodesManager: + __slots__ = ( + "_moved_exception", + "_require_full_coverage", + "connection_kwargs", + "default_node", + "nodes_cache", + "read_load_balancer", + "slots_cache", + "startup_nodes", + ) + + def __init__( + self, + startup_nodes: List["ClusterNode"], + require_full_coverage: bool = False, + **kwargs, + ) -> None: + self.nodes_cache = {} + self.slots_cache = {} + self.startup_nodes = {node.name: node for node in startup_nodes} + self.default_node = None + self._require_full_coverage = require_full_coverage + self._moved_exception = None + self.connection_kwargs = kwargs + self.read_load_balancer = LoadBalancer() + + def get_node( + self, + host: Optional[str] = None, + port: Optional[int] = None, + node_name: Optional[str] = None, + ) -> "ClusterNode": + if host and port: + # the user passed host and port + if host == "localhost": + host = socket.gethostbyname(host) + return self.nodes_cache.get(get_node_name(host=host, port=port)) + elif node_name: + return self.nodes_cache.get(node_name) + else: + raise DataError( + "get_node requires one of the following: " + "1. node name " + "2. host and port" + ) + + def set_nodes( + self, + old: Dict[str, "ClusterNode"], + new: Dict[str, "ClusterNode"], + remove_old=False, + ) -> None: + tasks = [] + if remove_old: + tasks = [ + asyncio.ensure_future(node.disconnect()) + for name, node in old.items() + if name not in new + ] + for name, node in new.items(): + if name in old: + if old[name] is node: + continue + tasks.append(asyncio.ensure_future(old[name].disconnect())) + old[name] = node + + def _update_moved_slots(self) -> None: + e = self._moved_exception + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node: + # The node already exists + if redirected_node.server_type != PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode( + e.host, e.port, PRIMARY, **self.connection_kwargs + ) + self.set_nodes(self.nodes_cache, {redirected_node.name: redirected_node}) + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. + old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] + # Reset moved_exception + self._moved_exception = None + + def get_node_from_slot( + self, slot: int, read_from_replicas: bool = False + ) -> "ClusterNode": + if self._moved_exception: + self._update_moved_slots() + + try: + if read_from_replicas: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot]) + ) + return self.slots_cache[slot][node_idx] + return self.slots_cache[slot][0] + except (IndexError, TypeError): + raise SlotNotCoveredError( + f'Slot "{slot}" not covered by the cluster. ' + f'"require_full_coverage={self._require_full_coverage}"' + ) + + def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] + + async def initialize(self) -> None: + self.read_load_balancer.reset() + tmp_nodes_cache = {} + tmp_slots = {} + disagreements = [] + startup_nodes_reachable = False + fully_covered = False + for startup_node in self.startup_nodes.values(): + try: + # Make sure cluster mode is enabled on this node + if not (await startup_node.execute_command("INFO")).get( + "cluster_enabled" + ): + raise RedisClusterException( + "Cluster mode is not enabled on this node" + ) + cluster_slots = str_if_bytes( + await startup_node.execute_command("CLUSTER SLOTS") + ) + startup_nodes_reachable = True + except (ConnectionError, TimeoutError): + continue + except ResponseError as e: + # Isn't a cluster connection, so it won't parse these + # exceptions automatically + message = e.__str__() + if "CLUSTERDOWN" in message or "MASTERDOWN" in message: + continue + else: + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + f"server: {startup_node}. error: {message}" + ) + except Exception as e: + message = e.__str__() + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + f"server {startup_node.name}. error: {message}" + ) + + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary node of the first slot section. + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if ( + len(cluster_slots) == 1 + and not cluster_slots[0][2][0] + and len(self.startup_nodes) == 1 + ): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + for i in range(2, len(slot)): + slot[i] = [str_if_bytes(val) for val in slot[i]] + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if not target_node: + target_node = ClusterNode( + host, port, PRIMARY, **self.connection_kwargs + ) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port) + ) + if not target_replica_node: + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + tmp_slot = tmp_slots[i][0] + if tmp_slot.name != target_node.name: + disagreements.append( + f"{tmp_slot.name} vs {target_node.name} on slot: {i}" + ) + + if len(disagreements) > 5: + raise RedisClusterException( + f"startup_nodes could not agree on a valid " + f'slots cache: {", ".join(disagreements)}' + ) + + # Validate if all slots are covered or if we should try next startup node + fully_covered = True + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + if i not in tmp_slots: + fully_covered = False + break + if fully_covered: + break + + if not startup_nodes_reachable: + raise RedisClusterException( + "Redis Cluster cannot be connected. Please provide at least " + "one reachable node. " + ) + + # Check if the slots are not fully covered + if not fully_covered and self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + f"All slots are not covered after query all startup_nodes. " + f"{len(tmp_slots)} of {REDIS_CLUSTER_HASH_SLOTS} " + f"covered..." + ) + + # Set the tmp variables to the real variables + self.slots_cache = tmp_slots + self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) + # Populate the startup nodes with all discovered nodes + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) + + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] + # If initialize was called after a MovedError, clear it + self._moved_exception = None + + async def close(self, attr: str = "nodes_cache") -> None: + self.default_node = None + await asyncio.gather( + *( + asyncio.ensure_future(node.disconnect()) + for node in getattr(self, attr).values() + ) + ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 91961ba..9de2d46 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -24,7 +24,6 @@ from typing import ( Type, TypeVar, Union, - cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -110,32 +109,32 @@ class Encoder: def encode(self, value: EncodableT) -> EncodedT: """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, str): + return value.encode(self.encoding, self.encoding_errors) if isinstance(value, (bytes, memoryview)): return value - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) if isinstance(value, (int, float)): + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) return repr(value).encode() - if not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ # type: ignore[unreachable] - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - return value.encode(self.encoding, self.encoding_errors) + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) def decode(self, value: EncodableT, force=False) -> EncodableT: """Return a unicode string from the bytes-like representation""" if self.decode_responses or force: - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) if isinstance(value, bytes): return value.decode(self.encoding, self.encoding_errors) + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) return value @@ -336,7 +335,7 @@ class SocketBuffer: def close(self): try: self.purge() - self._buffer.close() # type: ignore[union-attr] + self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -466,7 +465,7 @@ class HiredisParser(BaseParser): self._next_response = False async def can_read(self, timeout: float): - if not self._reader: + if not self._stream or not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._next_response is False: @@ -480,14 +479,14 @@ class HiredisParser(BaseParser): timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ): - if self._stream is None or self._reader is None: - raise RedisError("Parser already closed.") - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - async with async_timeout.timeout(timeout): + if timeout is None: buffer = await self._stream.read(self._read_size) - if not isinstance(buffer, bytes) or len(buffer) == 0: + else: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None self._reader.feed(buffer) # data was read from the socket and added to the buffer. @@ -516,9 +515,6 @@ class HiredisParser(BaseParser): self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - response: Union[ - EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] - ] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -541,8 +537,7 @@ class HiredisParser(BaseParser): and isinstance(response[0], ConnectionError) ): raise response[0] - # cast as there won't be a ConnectionError here. - return cast(Union[EncodableT, List[EncodableT]], response) + return response DefaultParser: Type[Union[PythonParser, HiredisParser]] @@ -637,7 +632,7 @@ class Connection: self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout if retry_on_timeout: - if retry is None: + if not retry: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable @@ -681,7 +676,7 @@ class Connection: @property def is_connected(self): - return bool(self._reader and self._writer) + return self._reader and self._writer def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -713,7 +708,7 @@ class Connection: raise ConnectionError(exc) from exc try: - if self.redis_connect_func is None: + if not self.redis_connect_func: # Use the default on_connect function await self.on_connect() else: @@ -745,7 +740,7 @@ class Connection: self._reader = reader self._writer = writer sock = writer.transport.get_extra_info("socket") - if sock is not None: + if sock: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: # TCP_KEEPALIVE @@ -856,32 +851,29 @@ class Connection: await self.retry.call_with_retry(self._send_ping, self._ping_failed) async def _send_packed_command(self, command: Iterable[bytes]) -> None: - if self._writer is None: - raise RedisError("Connection already closed.") - self._writer.writelines(command) await self._writer.drain() async def send_packed_command( - self, - command: Union[bytes, str, Iterable[bytes]], - check_health: bool = True, - ): - """Send an already packed command to the Redis server""" - if not self._writer: + self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True + ) -> None: + if not self.is_connected: await self.connect() - # guard against health check recursion - if check_health: + elif check_health: await self.check_health() + try: if isinstance(command, str): command = command.encode() if isinstance(command, bytes): command = [command] - await asyncio.wait_for( - self._send_packed_command(command), - self.socket_timeout, - ) + if self.socket_timeout: + await asyncio.wait_for( + self._send_packed_command(command), self.socket_timeout + ) + else: + self._writer.writelines(command) + await self._writer.drain() except asyncio.TimeoutError: await self.disconnect() raise TimeoutError("Timeout writing to socket") from None @@ -901,8 +893,6 @@ class Connection: async def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" - if not self.is_connected: - await self.connect() await self.send_packed_command( self.pack_command(*args), check_health=kwargs.get("check_health", True) ) @@ -923,10 +913,50 @@ class Connection: """Read the response from a previously sent command""" try: async with self._lock: + if self.socket_timeout: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + + async def read_response_without_lock(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + if self.socket_timeout: async with async_timeout.timeout(self.socket_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) except asyncio.TimeoutError: await self.disconnect() raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") @@ -1182,10 +1212,7 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] self._lock = asyncio.Lock() def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: - pieces = [ - ("path", self.path), - ("db", self.db), - ] + pieces = [("path", self.path), ("db", self.db)] if self.client_name: pieces.append(("client_name", self.client_name)) return pieces @@ -1254,12 +1281,11 @@ def parse_url(url: str) -> ConnectKwargs: parser = URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: - # We can't type this. - kwargs[name] = parser(value) # type: ignore[misc] + kwargs[name] = parser(value) except (TypeError, ValueError): raise ValueError(f"Invalid value for `{name}` in connection URL.") else: - kwargs[name] = value # type: ignore[misc] + kwargs[name] = value if parsed.username: kwargs["username"] = unquote(parsed.username) diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py new file mode 100644 index 0000000..273fe03 --- /dev/null +++ b/redis/asyncio/parser.py @@ -0,0 +1,95 @@ +from typing import TYPE_CHECKING, List, Optional, Union + +from redis.exceptions import RedisError, ResponseError + +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + + +class CommandsParser: + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. + + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands",) + + def __init__(self) -> None: + self.commands = {} + + async def initialize(self, r: "ClusterNode") -> None: + commands = await r.execute_command("COMMAND") + for cmd, command in commands.items(): + if "movablekeys" in command["flags"]: + commands[cmd] = -1 + elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: + commands[cmd] = 0 + elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: + commands[cmd] = 1 + self.commands = {cmd.upper(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys( + self, redis_conn: "ClusterNode", *args + ) -> Optional[Union[List[str], List[bytes]]]: + if len(args) < 2: + # The command has no keys in it + return None + + try: + command = self.commands[args[0]] + except KeyError: + # try to split the command name and to take only the main command + # e.g. 'memory' for 'memory usage' + args = args[0].split() + list(args[1:]) + cmd_name = args[0] + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name.upper()} command doesn't exist in Redis commands" + ) + + command = self.commands[cmd_name] + + if command == 1: + return [args[1]] + if command == 0: + return None + if command == -1: + return await self._get_moveable_keys(redis_conn, *args) + + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) + last_key_pos + return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] + + async def _get_moveable_keys( + self, redis_conn: "ClusterNode", *args + ) -> Optional[List[str]]: + try: + keys = await redis_conn.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys |