diff options
author | Utkarsh Gupta <utkarshgupta137@gmail.com> | 2022-05-30 19:04:06 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-30 16:34:06 +0300 |
commit | f704281cf4c1f735c06a13946fcea42fa939e3a5 (patch) | |
tree | 85a7affc680058b54d1df30d65e1f97c44c08847 /redis/asyncio/cluster.py | |
parent | 48079083a7f6ac1bdd948c03175f9ffd42aa1f6b (diff) | |
download | redis-py-f704281cf4c1f735c06a13946fcea42fa939e3a5.tar.gz |
async_cluster: add/update typing (#2195)
* async_cluster: add/update typing
* async_cluster: update cleanup_kwargs with kwargs from async Connection
* async_cluster: properly remove old nodes
Diffstat (limited to 'redis/asyncio/cluster.py')
-rw-r--r-- | redis/asyncio/cluster.py | 162 |
1 files changed, 91 insertions, 71 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index fa91872..c4e01d6 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,11 +3,12 @@ import collections import random import socket import warnings -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union -from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis +from redis.asyncio.client import ResponseCallbackT from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url from redis.asyncio.parser import CommandsParser +from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( PRIMARY, READ_COMMANDS, @@ -15,7 +16,6 @@ from redis.cluster import ( SLOT_ID, AbstractRedisCluster, LoadBalancer, - cleanup_kwargs, get_node_name, parse_cluster_slots, ) @@ -41,9 +41,36 @@ from redis.typing import EncodableT, KeyT from redis.utils import dict_merge, str_if_bytes TargetNodesT = TypeVar( - "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] + "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) +CONNECTION_ALLOWED_KEYS = ( + "client_name", + "db", + "decode_responses", + "encoder_class", + "encoding", + "encoding_errors", + "health_check_interval", + "parser_class", + "password", + "redis_connect_func", + "retry", + "retry_on_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_read_size", + "socket_timeout", + "socket_type", + "username", +) + + +def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]: + """Remove unsupported or disabled keys from kwargs.""" + return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS} + class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( @@ -131,7 +158,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand """ @classmethod - def from_url(cls, url: str, **kwargs) -> "RedisCluster": + def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": """ Return a Redis client object configured from the given URL. @@ -201,7 +228,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 10, url: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> None: if not startup_nodes: startup_nodes = [] @@ -212,7 +239,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "Argument 'db' is not possible to use in cluster mode" ) - # Get the startup node/s + # Get the startup node(s) if url: url_options = parse_url(url) if "path" in url_options: @@ -247,34 +274,34 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) self.response_callbacks = kwargs[ "response_callbacks" - ] = self.__class__.RESPONSE_CALLBACKS + ] = self.__class__.RESPONSE_CALLBACKS.copy() if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + require_full_coverage=require_full_coverage, + **self.connection_kwargs, + ) self.encoder = Encoder( kwargs.get("encoding", "utf-8"), kwargs.get("encoding_errors", "strict"), kwargs.get("decode_responses", False), ) self.cluster_error_retry_attempts = cluster_error_retry_attempts - self.command_flags = self.__class__.COMMAND_FLAGS.copy() - self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas - self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps - self.nodes_manager = NodesManager( - startup_nodes=startup_nodes, - require_full_coverage=require_full_coverage, - **self.connection_kwargs, - ) - self.result_callbacks = self.__class__.RESULT_CALLBACKS + self.reinitialize_counter = 0 + self.commands_parser = CommandsParser() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() self.result_callbacks[ "CLUSTER SLOTS" ] = lambda cmd, res, **kwargs: parse_cluster_slots( list(res.values())[0], **kwargs ) - self.commands_parser = CommandsParser() self._initialize = True self._lock = asyncio.Lock() @@ -310,16 +337,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: await self.close() - def __await__(self): + def __await__(self) -> Generator[Any, None, "RedisCluster"]: return self.initialize().__await__() _DEL_MESSAGE = "Unclosed RedisCluster client" - def __del__(self, _warnings=warnings): + def __del__(self) -> None: if hasattr(self, "_initialize") and not self._initialize: - _warnings.warn( - f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self - ) + warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) try: context = {"client": self, "message": self._DEL_MESSAGE} # TODO: Change to get_running_loop() when dropping support for py3.6 @@ -408,7 +433,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand self.nodes_manager.default_node = node - def set_response_callback(self, command: KeyT, callback: Callable) -> None: + def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: """Set a custom response callback.""" self.response_callbacks[command] = callback @@ -430,7 +455,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand return key_slot(k) async def _determine_nodes( - self, *args, node_flag: Optional[str] = None + self, *args: Any, node_flag: Optional[str] = None ) -> List["ClusterNode"]: command = args[0] if not node_flag: @@ -462,11 +487,11 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ) ] - async def _determine_slot(self, *args) -> int: + async def _determine_slot(self, *args: Any) -> int: command = args[0] if self.command_flags.get(command) == SLOT_ID: # The command contains the slot ID - return args[1] + return int(args[1]) # Get the keys in the command @@ -516,14 +541,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand return slots.pop() - def _is_node_flag( - self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] - ) -> bool: + def _is_node_flag(self, target_nodes: Any) -> bool: return isinstance(target_nodes, str) and target_nodes in self.node_flags - def _parse_target_nodes( - self, target_nodes: Union[List["ClusterNode"], "ClusterNode"] - ) -> List["ClusterNode"]: + def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]: if isinstance(target_nodes, list): nodes = target_nodes elif isinstance(target_nodes, ClusterNode): @@ -533,7 +554,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand # Supports dictionaries of the format {node_name: node}. # It enables to execute commands with multi nodes as follows: # rc.cluster_save_config(rc.get_primaries()) - nodes = target_nodes.values() + nodes = list(target_nodes.values()) else: raise TypeError( "target_nodes type can be one of the following: " @@ -543,7 +564,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ) return nodes - async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any: + async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: """ Execute a raw command on the appropriate cluster node or target_nodes. @@ -562,7 +583,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand can't be mapped to a slot """ command = args[0] - target_nodes_specified = target_nodes = exception = None + target_nodes = [] + target_nodes_specified = False retry_attempts = self.cluster_error_retry_attempts passed_targets = kwargs.pop("target_nodes", None) @@ -571,7 +593,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand target_nodes_specified = True retry_attempts = 1 - for _ in range(0, retry_attempts): + for _ in range(retry_attempts): if self._initialize: await self.initialize() try: @@ -622,9 +644,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand raise exception async def _execute_command( - self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs + self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any ) -> Any: - redirect_addr = asking = moved = None + asking = moved = False + redirect_addr = None ttl = self.RedisClusterRequestTTL connection_error_retry_counter = 0 @@ -725,8 +748,8 @@ class ClusterNode: server_type: Optional[str] = None, max_connections: int = 2 ** 31, connection_class: Type[Connection] = Connection, - response_callbacks: Dict = RedisCluster.RESPONSE_CALLBACKS, - **connection_kwargs, + response_callbacks: Dict[str, Any] = RedisCluster.RESPONSE_CALLBACKS, + **connection_kwargs: Any, ) -> None: if host == "localhost": host = socket.gethostbyname(host) @@ -743,8 +766,8 @@ class ClusterNode: self.connection_kwargs = connection_kwargs self.response_callbacks = response_callbacks - self._connections = [] - self._free = collections.deque(maxlen=self.max_connections) + self._connections: List[Connection] = [] + self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) def __repr__(self) -> str: return ( @@ -752,15 +775,15 @@ class ClusterNode: f"name={self.name}, server_type={self.server_type}]" ) - def __eq__(self, obj: "ClusterNode") -> bool: + def __eq__(self, obj: Any) -> bool: return isinstance(obj, ClusterNode) and obj.name == self.name _DEL_MESSAGE = "Unclosed ClusterNode object" - def __del__(self, _warnings=warnings): + def __del__(self) -> None: for connection in self._connections: if connection.is_connected: - _warnings.warn( + warnings.warn( f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self ) try: @@ -783,7 +806,7 @@ class ClusterNode: if exc: raise exc - async def execute_command(self, *args, **kwargs) -> Any: + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = None if self._free: @@ -829,11 +852,11 @@ class ClusterNode: class NodesManager: __slots__ = ( "_moved_exception", - "_require_full_coverage", "connection_kwargs", "default_node", "nodes_cache", "read_load_balancer", + "require_full_coverage", "slots_cache", "startup_nodes", ) @@ -842,23 +865,24 @@ class NodesManager: self, startup_nodes: List["ClusterNode"], require_full_coverage: bool = False, - **kwargs, + **kwargs: Any, ) -> None: - self.nodes_cache = {} - self.slots_cache = {} self.startup_nodes = {node.name: node for node in startup_nodes} - self.default_node = None - self._require_full_coverage = require_full_coverage - self._moved_exception = None + self.require_full_coverage = require_full_coverage self.connection_kwargs = kwargs + + self.default_node: "ClusterNode" = None + self.nodes_cache: Dict[str, "ClusterNode"] = {} + self.slots_cache: Dict[int, List["ClusterNode"]] = {} self.read_load_balancer = LoadBalancer() + self._moved_exception: MovedError = None def get_node( self, host: Optional[str] = None, port: Optional[int] = None, node_name: Optional[str] = None, - ) -> "ClusterNode": + ) -> Optional["ClusterNode"]: if host and port: # the user passed host and port if host == "localhost": @@ -877,20 +901,18 @@ class NodesManager: self, old: Dict[str, "ClusterNode"], new: Dict[str, "ClusterNode"], - remove_old=False, + remove_old: bool = False, ) -> None: - tasks = [] if remove_old: - tasks = [ - asyncio.ensure_future(node.disconnect()) - for name, node in old.items() - if name not in new - ] + for name in list(old.keys()): + if name not in new: + asyncio.ensure_future(old.pop(name).disconnect()) + for name, node in new.items(): if name in old: if old[name] is node: continue - tasks.append(asyncio.ensure_future(old[name].disconnect())) + asyncio.ensure_future(old[name].disconnect()) old[name] = node def _update_moved_slots(self) -> None: @@ -949,7 +971,7 @@ class NodesManager: except (IndexError, TypeError): raise SlotNotCoveredError( f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self._require_full_coverage}"' + f'"require_full_coverage={self.require_full_coverage}"' ) def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: @@ -961,8 +983,8 @@ class NodesManager: async def initialize(self) -> None: self.read_load_balancer.reset() - tmp_nodes_cache = {} - tmp_slots = {} + tmp_nodes_cache: Dict[str, "ClusterNode"] = {} + tmp_slots: Dict[int, List["ClusterNode"]] = {} disagreements = [] startup_nodes_reachable = False fully_covered = False @@ -975,9 +997,7 @@ class NodesManager: raise RedisClusterException( "Cluster mode is not enabled on this node" ) - cluster_slots = str_if_bytes( - await startup_node.execute_command("CLUSTER SLOTS") - ) + cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True except (ConnectionError, TimeoutError): continue @@ -1069,7 +1089,7 @@ class NodesManager: # Validate if all slots are covered or if we should try next startup node fully_covered = True - for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + for i in range(REDIS_CLUSTER_HASH_SLOTS): if i not in tmp_slots: fully_covered = False break @@ -1083,7 +1103,7 @@ class NodesManager: ) # Check if the slots are not fully covered - if not fully_covered and self._require_full_coverage: + if not fully_covered and self.require_full_coverage: # Despite the requirement that the slots be covered, there # isn't a full coverage raise RedisClusterException( |