From f704281cf4c1f735c06a13946fcea42fa939e3a5 Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Mon, 30 May 2022 19:04:06 +0530 Subject: async_cluster: add/update typing (#2195) * async_cluster: add/update typing * async_cluster: update cleanup_kwargs with kwargs from async Connection * async_cluster: properly remove old nodes --- redis/asyncio/cluster.py | 162 +++++++++++++++++++++---------------- redis/asyncio/connection.py | 8 +- redis/asyncio/parser.py | 14 ++-- redis/cluster.py | 15 ++-- redis/utils.py | 5 +- tests/test_asyncio/test_cluster.py | 35 +++----- 6 files changed, 123 insertions(+), 116 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index fa91872..c4e01d6 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,11 +3,12 @@ import collections import random import socket import warnings -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union -from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis +from redis.asyncio.client import ResponseCallbackT from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url from redis.asyncio.parser import CommandsParser +from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( PRIMARY, READ_COMMANDS, @@ -15,7 +16,6 @@ from redis.cluster import ( SLOT_ID, AbstractRedisCluster, LoadBalancer, - cleanup_kwargs, get_node_name, parse_cluster_slots, ) @@ -41,9 +41,36 @@ from redis.typing import EncodableT, KeyT from redis.utils import dict_merge, str_if_bytes TargetNodesT = TypeVar( - "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] + "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) +CONNECTION_ALLOWED_KEYS = ( + "client_name", + "db", + "decode_responses", + "encoder_class", + "encoding", + "encoding_errors", + "health_check_interval", + "parser_class", + "password", + "redis_connect_func", + "retry", + "retry_on_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_read_size", + "socket_timeout", + "socket_type", + "username", +) + + +def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]: + """Remove unsupported or disabled keys from kwargs.""" + return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS} + class ClusterParser(DefaultParser): EXCEPTION_CLASSES = dict_merge( @@ -131,7 +158,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand """ @classmethod - def from_url(cls, url: str, **kwargs) -> "RedisCluster": + def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster": """ Return a Redis client object configured from the given URL. @@ -201,7 +228,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 10, url: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> None: if not startup_nodes: startup_nodes = [] @@ -212,7 +239,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "Argument 'db' is not possible to use in cluster mode" ) - # Get the startup node/s + # Get the startup node(s) if url: url_options = parse_url(url) if "path" in url_options: @@ -247,34 +274,34 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs) self.response_callbacks = kwargs[ "response_callbacks" - ] = self.__class__.RESPONSE_CALLBACKS + ] = self.__class__.RESPONSE_CALLBACKS.copy() if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + require_full_coverage=require_full_coverage, + **self.connection_kwargs, + ) self.encoder = Encoder( kwargs.get("encoding", "utf-8"), kwargs.get("encoding_errors", "strict"), kwargs.get("decode_responses", False), ) self.cluster_error_retry_attempts = cluster_error_retry_attempts - self.command_flags = self.__class__.COMMAND_FLAGS.copy() - self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas - self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps - self.nodes_manager = NodesManager( - startup_nodes=startup_nodes, - require_full_coverage=require_full_coverage, - **self.connection_kwargs, - ) - self.result_callbacks = self.__class__.RESULT_CALLBACKS + self.reinitialize_counter = 0 + self.commands_parser = CommandsParser() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() self.result_callbacks[ "CLUSTER SLOTS" ] = lambda cmd, res, **kwargs: parse_cluster_slots( list(res.values())[0], **kwargs ) - self.commands_parser = CommandsParser() self._initialize = True self._lock = asyncio.Lock() @@ -310,16 +337,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: await self.close() - def __await__(self): + def __await__(self) -> Generator[Any, None, "RedisCluster"]: return self.initialize().__await__() _DEL_MESSAGE = "Unclosed RedisCluster client" - def __del__(self, _warnings=warnings): + def __del__(self) -> None: if hasattr(self, "_initialize") and not self._initialize: - _warnings.warn( - f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self - ) + warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self) try: context = {"client": self, "message": self._DEL_MESSAGE} # TODO: Change to get_running_loop() when dropping support for py3.6 @@ -408,7 +433,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand self.nodes_manager.default_node = node - def set_response_callback(self, command: KeyT, callback: Callable) -> None: + def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: """Set a custom response callback.""" self.response_callbacks[command] = callback @@ -430,7 +455,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand return key_slot(k) async def _determine_nodes( - self, *args, node_flag: Optional[str] = None + self, *args: Any, node_flag: Optional[str] = None ) -> List["ClusterNode"]: command = args[0] if not node_flag: @@ -462,11 +487,11 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ) ] - async def _determine_slot(self, *args) -> int: + async def _determine_slot(self, *args: Any) -> int: command = args[0] if self.command_flags.get(command) == SLOT_ID: # The command contains the slot ID - return args[1] + return int(args[1]) # Get the keys in the command @@ -516,14 +541,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand return slots.pop() - def _is_node_flag( - self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str] - ) -> bool: + def _is_node_flag(self, target_nodes: Any) -> bool: return isinstance(target_nodes, str) and target_nodes in self.node_flags - def _parse_target_nodes( - self, target_nodes: Union[List["ClusterNode"], "ClusterNode"] - ) -> List["ClusterNode"]: + def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]: if isinstance(target_nodes, list): nodes = target_nodes elif isinstance(target_nodes, ClusterNode): @@ -533,7 +554,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand # Supports dictionaries of the format {node_name: node}. # It enables to execute commands with multi nodes as follows: # rc.cluster_save_config(rc.get_primaries()) - nodes = target_nodes.values() + nodes = list(target_nodes.values()) else: raise TypeError( "target_nodes type can be one of the following: " @@ -543,7 +564,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ) return nodes - async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any: + async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any: """ Execute a raw command on the appropriate cluster node or target_nodes. @@ -562,7 +583,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand can't be mapped to a slot """ command = args[0] - target_nodes_specified = target_nodes = exception = None + target_nodes = [] + target_nodes_specified = False retry_attempts = self.cluster_error_retry_attempts passed_targets = kwargs.pop("target_nodes", None) @@ -571,7 +593,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand target_nodes_specified = True retry_attempts = 1 - for _ in range(0, retry_attempts): + for _ in range(retry_attempts): if self._initialize: await self.initialize() try: @@ -622,9 +644,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand raise exception async def _execute_command( - self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs + self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any ) -> Any: - redirect_addr = asking = moved = None + asking = moved = False + redirect_addr = None ttl = self.RedisClusterRequestTTL connection_error_retry_counter = 0 @@ -725,8 +748,8 @@ class ClusterNode: server_type: Optional[str] = None, max_connections: int = 2 ** 31, connection_class: Type[Connection] = Connection, - response_callbacks: Dict = RedisCluster.RESPONSE_CALLBACKS, - **connection_kwargs, + response_callbacks: Dict[str, Any] = RedisCluster.RESPONSE_CALLBACKS, + **connection_kwargs: Any, ) -> None: if host == "localhost": host = socket.gethostbyname(host) @@ -743,8 +766,8 @@ class ClusterNode: self.connection_kwargs = connection_kwargs self.response_callbacks = response_callbacks - self._connections = [] - self._free = collections.deque(maxlen=self.max_connections) + self._connections: List[Connection] = [] + self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) def __repr__(self) -> str: return ( @@ -752,15 +775,15 @@ class ClusterNode: f"name={self.name}, server_type={self.server_type}]" ) - def __eq__(self, obj: "ClusterNode") -> bool: + def __eq__(self, obj: Any) -> bool: return isinstance(obj, ClusterNode) and obj.name == self.name _DEL_MESSAGE = "Unclosed ClusterNode object" - def __del__(self, _warnings=warnings): + def __del__(self) -> None: for connection in self._connections: if connection.is_connected: - _warnings.warn( + warnings.warn( f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self ) try: @@ -783,7 +806,7 @@ class ClusterNode: if exc: raise exc - async def execute_command(self, *args, **kwargs) -> Any: + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = None if self._free: @@ -829,11 +852,11 @@ class ClusterNode: class NodesManager: __slots__ = ( "_moved_exception", - "_require_full_coverage", "connection_kwargs", "default_node", "nodes_cache", "read_load_balancer", + "require_full_coverage", "slots_cache", "startup_nodes", ) @@ -842,23 +865,24 @@ class NodesManager: self, startup_nodes: List["ClusterNode"], require_full_coverage: bool = False, - **kwargs, + **kwargs: Any, ) -> None: - self.nodes_cache = {} - self.slots_cache = {} self.startup_nodes = {node.name: node for node in startup_nodes} - self.default_node = None - self._require_full_coverage = require_full_coverage - self._moved_exception = None + self.require_full_coverage = require_full_coverage self.connection_kwargs = kwargs + + self.default_node: "ClusterNode" = None + self.nodes_cache: Dict[str, "ClusterNode"] = {} + self.slots_cache: Dict[int, List["ClusterNode"]] = {} self.read_load_balancer = LoadBalancer() + self._moved_exception: MovedError = None def get_node( self, host: Optional[str] = None, port: Optional[int] = None, node_name: Optional[str] = None, - ) -> "ClusterNode": + ) -> Optional["ClusterNode"]: if host and port: # the user passed host and port if host == "localhost": @@ -877,20 +901,18 @@ class NodesManager: self, old: Dict[str, "ClusterNode"], new: Dict[str, "ClusterNode"], - remove_old=False, + remove_old: bool = False, ) -> None: - tasks = [] if remove_old: - tasks = [ - asyncio.ensure_future(node.disconnect()) - for name, node in old.items() - if name not in new - ] + for name in list(old.keys()): + if name not in new: + asyncio.ensure_future(old.pop(name).disconnect()) + for name, node in new.items(): if name in old: if old[name] is node: continue - tasks.append(asyncio.ensure_future(old[name].disconnect())) + asyncio.ensure_future(old[name].disconnect()) old[name] = node def _update_moved_slots(self) -> None: @@ -949,7 +971,7 @@ class NodesManager: except (IndexError, TypeError): raise SlotNotCoveredError( f'Slot "{slot}" not covered by the cluster. ' - f'"require_full_coverage={self._require_full_coverage}"' + f'"require_full_coverage={self.require_full_coverage}"' ) def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]: @@ -961,8 +983,8 @@ class NodesManager: async def initialize(self) -> None: self.read_load_balancer.reset() - tmp_nodes_cache = {} - tmp_slots = {} + tmp_nodes_cache: Dict[str, "ClusterNode"] = {} + tmp_slots: Dict[int, List["ClusterNode"]] = {} disagreements = [] startup_nodes_reachable = False fully_covered = False @@ -975,9 +997,7 @@ class NodesManager: raise RedisClusterException( "Cluster mode is not enabled on this node" ) - cluster_slots = str_if_bytes( - await startup_node.execute_command("CLUSTER SLOTS") - ) + cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True except (ConnectionError, TimeoutError): continue @@ -1069,7 +1089,7 @@ class NodesManager: # Validate if all slots are covered or if we should try next startup node fully_covered = True - for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + for i in range(REDIS_CLUSTER_HASH_SLOTS): if i not in tmp_slots: fully_covered = False break @@ -1083,7 +1103,7 @@ class NodesManager: ) # Check if the slots are not fully covered - if not fully_covered and self._require_full_coverage: + if not fully_covered and self.require_full_coverage: # Despite the requirement that the slots be covered, there # isn't a full coverage raise RedisClusterException( diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 9de2d46..d9b0974 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -684,7 +684,7 @@ class Connection: def clear_connect_callbacks(self): self._connect_callbacks = [] - def set_parser(self, parser_class): + def set_parser(self, parser_class: Type[BaseParser]) -> None: """ Creates a new instance of parser_class with socket size: _socket_read_size and assigns it to the parser for the connection @@ -766,7 +766,7 @@ class Connection: f"{exception.args[0]}." ) - async def on_connect(self): + async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) @@ -807,7 +807,7 @@ class Connection: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") - async def disconnect(self): + async def disconnect(self) -> None: """Disconnects from the Redis server""" try: async with async_timeout.timeout(self.socket_connect_timeout): @@ -891,7 +891,7 @@ class Connection: await self.disconnect() raise - async def send_command(self, *args, **kwargs): + async def send_command(self, *args: Any, **kwargs: Any) -> None: """Pack and send a command to the Redis server""" await self.send_packed_command( self.pack_command(*args), check_health=kwargs.get("check_health", True) diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py index 273fe03..6286351 100644 --- a/redis/asyncio/parser.py +++ b/redis/asyncio/parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from redis.exceptions import RedisError, ResponseError @@ -25,7 +25,7 @@ class CommandsParser: __slots__ = ("commands",) def __init__(self) -> None: - self.commands = {} + self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} async def initialize(self, r: "ClusterNode") -> None: commands = await r.execute_command("COMMAND") @@ -42,8 +42,8 @@ class CommandsParser: # our logic to use COMMAND INFO changes to determine the key positions # https://github.com/redis/redis/pull/8324 async def get_keys( - self, redis_conn: "ClusterNode", *args - ) -> Optional[Union[List[str], List[bytes]]]: + self, redis_conn: "ClusterNode", *args: Any + ) -> Optional[Tuple[str, ...]]: if len(args) < 2: # The command has no keys in it return None @@ -67,7 +67,7 @@ class CommandsParser: command = self.commands[cmd_name] if command == 1: - return [args[1]] + return (args[1],) if command == 0: return None if command == -1: @@ -79,8 +79,8 @@ class CommandsParser: return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] async def _get_moveable_keys( - self, redis_conn: "ClusterNode", *args - ) -> Optional[List[str]]: + self, redis_conn: "ClusterNode", *args: Any + ) -> Optional[Tuple[str, ...]]: try: keys = await redis_conn.execute_command("COMMAND GETKEYS", *args) except ResponseError as e: diff --git a/redis/cluster.py b/redis/cluster.py index 0b9c543..46a96a6 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,6 +6,7 @@ import sys import threading import time from collections import OrderedDict +from typing import Any, Dict, Tuple from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan from redis.commands import CommandsParser, RedisClusterCommands @@ -40,7 +41,7 @@ from redis.utils import ( log = logging.getLogger(__name__) -def get_node_name(host, port): +def get_node_name(host: str, port: int) -> str: return f"{host}:{port}" @@ -74,10 +75,12 @@ def parse_pubsub_numsub(command, res, **options): return ret_numsub -def parse_cluster_slots(resp, **options): +def parse_cluster_slots( + resp: Any, **options: Any +) -> Dict[Tuple[int, int], Dict[str, Any]]: current_host = options.get("current_host", "") - def fix_server(*args): + def fix_server(*args: Any) -> Tuple[str, Any]: return str_if_bytes(args[0]) or current_host, args[1] slots = {} @@ -1248,17 +1251,17 @@ class LoadBalancer: Round-Robin Load Balancing """ - def __init__(self, start_index=0): + def __init__(self, start_index: int = 0) -> None: self.primary_to_idx = {} self.start_index = start_index - def get_server_index(self, primary, list_size): + def get_server_index(self, primary: str, list_size: int) -> int: server_index = self.primary_to_idx.setdefault(primary, self.start_index) # Update the index self.primary_to_idx[primary] = (server_index + 1) % list_size return server_index - def reset(self): + def reset(self) -> None: self.primary_to_idx.clear() diff --git a/redis/utils.py b/redis/utils.py index 9ab75f2..0c34e1e 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from typing import Any, Dict, Mapping, Union try: import hiredis # noqa @@ -34,7 +35,7 @@ def pipeline(redis_obj): p.execute() -def str_if_bytes(value): +def str_if_bytes(value: Union[str, bytes]) -> str: return ( value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value ) @@ -44,7 +45,7 @@ def safe_str(value): return str(str_if_bytes(value)) -def dict_merge(*dicts): +def dict_merge(*dicts: Mapping[str, Any]) -> Dict[str, Any]: """ Merge all provided dicts into 1 dict. *dicts : `dict` diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 6c28ce3..123adc8 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -3,6 +3,7 @@ import binascii import datetime import sys import warnings +from typing import Any, Callable, Dict, List, Optional, Type, Union import pytest @@ -13,21 +14,13 @@ if sys.version_info[0:2] == (3, 6): else: import pytest_asyncio -from typing import Callable, Dict, List, Optional, Type, Union - from _pytest.fixtures import FixtureRequest, SubRequest from redis.asyncio import Connection, RedisCluster -from redis.asyncio.cluster import ( - PRIMARY, - REDIS_CLUSTER_HASH_SLOTS, - REPLICA, - ClusterNode, - NodesManager, - get_node_name, -) +from redis.asyncio.cluster import ClusterNode, NodesManager from redis.asyncio.parser import CommandsParser -from redis.crc import key_slot +from redis.cluster import PRIMARY, REPLICA, get_node_name +from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, ClusterDownError, @@ -109,7 +102,7 @@ async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: CommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: - def cmd_init_mock(self, r): + def cmd_init_mock(self, r: ClusterNode) -> None: self.commands = { "GET": { "name": "get", @@ -126,12 +119,7 @@ async def get_mocked_redis_client(*args, **kwargs) -> RedisCluster: return await RedisCluster(*args, **kwargs) -def mock_node_resp( - node: ClusterNode, - response: Union[ - List[List[Union[int, List[Union[str, int]]]]], List[bytes], str, int - ], -) -> ClusterNode: +def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: connection = mock.AsyncMock() connection.is_connected = True connection.read_response_without_lock.return_value = response @@ -141,12 +129,7 @@ def mock_node_resp( return node -def mock_all_nodes_resp( - rc: RedisCluster, - response: Union[ - List[List[Union[int, List[Union[str, int]]]]], List[bytes], int, str - ], -) -> RedisCluster: +def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster: for node in rc.get_nodes(): mock_node_resp(node, response) return rc @@ -461,7 +444,7 @@ class TestRedisClusterObj: CommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: - def cmd_init_mock(self, r): + def cmd_init_mock(self, r: ClusterNode) -> None: self.commands = { "GET": { "name": "get", @@ -2217,7 +2200,7 @@ class TestNodesManager: CommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: - def cmd_init_mock(self, r): + def cmd_init_mock(self, r: ClusterNode) -> None: self.commands = { "GET": { "name": "get", -- cgit v1.2.1