summaryrefslogtreecommitdiff
path: root/redis/asyncio/cluster.py
diff options
context:
space:
mode:
authorUtkarsh Gupta <utkarshgupta137@gmail.com>2022-05-30 19:04:06 +0530
committerGitHub <noreply@github.com>2022-05-30 16:34:06 +0300
commitf704281cf4c1f735c06a13946fcea42fa939e3a5 (patch)
tree85a7affc680058b54d1df30d65e1f97c44c08847 /redis/asyncio/cluster.py
parent48079083a7f6ac1bdd948c03175f9ffd42aa1f6b (diff)
downloadredis-py-f704281cf4c1f735c06a13946fcea42fa939e3a5.tar.gz
async_cluster: add/update typing (#2195)
* async_cluster: add/update typing * async_cluster: update cleanup_kwargs with kwargs from async Connection * async_cluster: properly remove old nodes
Diffstat (limited to 'redis/asyncio/cluster.py')
-rw-r--r--redis/asyncio/cluster.py162
1 files changed, 91 insertions, 71 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index fa91872..c4e01d6 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -3,11 +3,12 @@ import collections
import random
import socket
import warnings
-from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
+from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union
-from redis.asyncio.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
+from redis.asyncio.client import ResponseCallbackT
from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url
from redis.asyncio.parser import CommandsParser
+from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
PRIMARY,
READ_COMMANDS,
@@ -15,7 +16,6 @@ from redis.cluster import (
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
- cleanup_kwargs,
get_node_name,
parse_cluster_slots,
)
@@ -41,9 +41,36 @@ from redis.typing import EncodableT, KeyT
from redis.utils import dict_merge, str_if_bytes
TargetNodesT = TypeVar(
- "TargetNodesT", "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
+ "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
)
+CONNECTION_ALLOWED_KEYS = (
+ "client_name",
+ "db",
+ "decode_responses",
+ "encoder_class",
+ "encoding",
+ "encoding_errors",
+ "health_check_interval",
+ "parser_class",
+ "password",
+ "redis_connect_func",
+ "retry",
+ "retry_on_timeout",
+ "socket_connect_timeout",
+ "socket_keepalive",
+ "socket_keepalive_options",
+ "socket_read_size",
+ "socket_timeout",
+ "socket_type",
+ "username",
+)
+
+
+def cleanup_kwargs(**kwargs: Any) -> Dict[str, Any]:
+ """Remove unsupported or disabled keys from kwargs."""
+ return {k: v for k, v in kwargs.items() if k in CONNECTION_ALLOWED_KEYS}
+
class ClusterParser(DefaultParser):
EXCEPTION_CLASSES = dict_merge(
@@ -131,7 +158,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
"""
@classmethod
- def from_url(cls, url: str, **kwargs) -> "RedisCluster":
+ def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
"""
Return a Redis client object configured from the given URL.
@@ -201,7 +228,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
cluster_error_retry_attempts: int = 3,
reinitialize_steps: int = 10,
url: Optional[str] = None,
- **kwargs,
+ **kwargs: Any,
) -> None:
if not startup_nodes:
startup_nodes = []
@@ -212,7 +239,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
"Argument 'db' is not possible to use in cluster mode"
)
- # Get the startup node/s
+ # Get the startup node(s)
if url:
url_options = parse_url(url)
if "path" in url_options:
@@ -247,34 +274,34 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
self.connection_kwargs = kwargs = cleanup_kwargs(**kwargs)
self.response_callbacks = kwargs[
"response_callbacks"
- ] = self.__class__.RESPONSE_CALLBACKS
+ ] = self.__class__.RESPONSE_CALLBACKS.copy()
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
+ self.nodes_manager = NodesManager(
+ startup_nodes=startup_nodes,
+ require_full_coverage=require_full_coverage,
+ **self.connection_kwargs,
+ )
self.encoder = Encoder(
kwargs.get("encoding", "utf-8"),
kwargs.get("encoding_errors", "strict"),
kwargs.get("decode_responses", False),
)
self.cluster_error_retry_attempts = cluster_error_retry_attempts
- self.command_flags = self.__class__.COMMAND_FLAGS.copy()
- self.node_flags = self.__class__.NODE_FLAGS.copy()
self.read_from_replicas = read_from_replicas
- self.reinitialize_counter = 0
self.reinitialize_steps = reinitialize_steps
- self.nodes_manager = NodesManager(
- startup_nodes=startup_nodes,
- require_full_coverage=require_full_coverage,
- **self.connection_kwargs,
- )
- self.result_callbacks = self.__class__.RESULT_CALLBACKS
+ self.reinitialize_counter = 0
+ self.commands_parser = CommandsParser()
+ self.node_flags = self.__class__.NODE_FLAGS.copy()
+ self.command_flags = self.__class__.COMMAND_FLAGS.copy()
+ self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
self.result_callbacks[
"CLUSTER SLOTS"
] = lambda cmd, res, **kwargs: parse_cluster_slots(
list(res.values())[0], **kwargs
)
- self.commands_parser = CommandsParser()
self._initialize = True
self._lock = asyncio.Lock()
@@ -310,16 +337,14 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
await self.close()
- def __await__(self):
+ def __await__(self) -> Generator[Any, None, "RedisCluster"]:
return self.initialize().__await__()
_DEL_MESSAGE = "Unclosed RedisCluster client"
- def __del__(self, _warnings=warnings):
+ def __del__(self) -> None:
if hasattr(self, "_initialize") and not self._initialize:
- _warnings.warn(
- f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self
- )
+ warnings.warn(f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self)
try:
context = {"client": self, "message": self._DEL_MESSAGE}
# TODO: Change to get_running_loop() when dropping support for py3.6
@@ -408,7 +433,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
self.nodes_manager.default_node = node
- def set_response_callback(self, command: KeyT, callback: Callable) -> None:
+ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
"""Set a custom response callback."""
self.response_callbacks[command] = callback
@@ -430,7 +455,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
return key_slot(k)
async def _determine_nodes(
- self, *args, node_flag: Optional[str] = None
+ self, *args: Any, node_flag: Optional[str] = None
) -> List["ClusterNode"]:
command = args[0]
if not node_flag:
@@ -462,11 +487,11 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
)
]
- async def _determine_slot(self, *args) -> int:
+ async def _determine_slot(self, *args: Any) -> int:
command = args[0]
if self.command_flags.get(command) == SLOT_ID:
# The command contains the slot ID
- return args[1]
+ return int(args[1])
# Get the keys in the command
@@ -516,14 +541,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
return slots.pop()
- def _is_node_flag(
- self, target_nodes: Union[List["ClusterNode"], "ClusterNode", str]
- ) -> bool:
+ def _is_node_flag(self, target_nodes: Any) -> bool:
return isinstance(target_nodes, str) and target_nodes in self.node_flags
- def _parse_target_nodes(
- self, target_nodes: Union[List["ClusterNode"], "ClusterNode"]
- ) -> List["ClusterNode"]:
+ def _parse_target_nodes(self, target_nodes: Any) -> List["ClusterNode"]:
if isinstance(target_nodes, list):
nodes = target_nodes
elif isinstance(target_nodes, ClusterNode):
@@ -533,7 +554,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
# Supports dictionaries of the format {node_name: node}.
# It enables to execute commands with multi nodes as follows:
# rc.cluster_save_config(rc.get_primaries())
- nodes = target_nodes.values()
+ nodes = list(target_nodes.values())
else:
raise TypeError(
"target_nodes type can be one of the following: "
@@ -543,7 +564,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
)
return nodes
- async def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs) -> Any:
+ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
"""
Execute a raw command on the appropriate cluster node or target_nodes.
@@ -562,7 +583,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
can't be mapped to a slot
"""
command = args[0]
- target_nodes_specified = target_nodes = exception = None
+ target_nodes = []
+ target_nodes_specified = False
retry_attempts = self.cluster_error_retry_attempts
passed_targets = kwargs.pop("target_nodes", None)
@@ -571,7 +593,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
target_nodes_specified = True
retry_attempts = 1
- for _ in range(0, retry_attempts):
+ for _ in range(retry_attempts):
if self._initialize:
await self.initialize()
try:
@@ -622,9 +644,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
raise exception
async def _execute_command(
- self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs
+ self, target_node: "ClusterNode", *args: Union[KeyT, EncodableT], **kwargs: Any
) -> Any:
- redirect_addr = asking = moved = None
+ asking = moved = False
+ redirect_addr = None
ttl = self.RedisClusterRequestTTL
connection_error_retry_counter = 0
@@ -725,8 +748,8 @@ class ClusterNode:
server_type: Optional[str] = None,
max_connections: int = 2 ** 31,
connection_class: Type[Connection] = Connection,
- response_callbacks: Dict = RedisCluster.RESPONSE_CALLBACKS,
- **connection_kwargs,
+ response_callbacks: Dict[str, Any] = RedisCluster.RESPONSE_CALLBACKS,
+ **connection_kwargs: Any,
) -> None:
if host == "localhost":
host = socket.gethostbyname(host)
@@ -743,8 +766,8 @@ class ClusterNode:
self.connection_kwargs = connection_kwargs
self.response_callbacks = response_callbacks
- self._connections = []
- self._free = collections.deque(maxlen=self.max_connections)
+ self._connections: List[Connection] = []
+ self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
def __repr__(self) -> str:
return (
@@ -752,15 +775,15 @@ class ClusterNode:
f"name={self.name}, server_type={self.server_type}]"
)
- def __eq__(self, obj: "ClusterNode") -> bool:
+ def __eq__(self, obj: Any) -> bool:
return isinstance(obj, ClusterNode) and obj.name == self.name
_DEL_MESSAGE = "Unclosed ClusterNode object"
- def __del__(self, _warnings=warnings):
+ def __del__(self) -> None:
for connection in self._connections:
if connection.is_connected:
- _warnings.warn(
+ warnings.warn(
f"{self._DEL_MESSAGE} {self!r}", ResourceWarning, source=self
)
try:
@@ -783,7 +806,7 @@ class ClusterNode:
if exc:
raise exc
- async def execute_command(self, *args, **kwargs) -> Any:
+ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
connection = None
if self._free:
@@ -829,11 +852,11 @@ class ClusterNode:
class NodesManager:
__slots__ = (
"_moved_exception",
- "_require_full_coverage",
"connection_kwargs",
"default_node",
"nodes_cache",
"read_load_balancer",
+ "require_full_coverage",
"slots_cache",
"startup_nodes",
)
@@ -842,23 +865,24 @@ class NodesManager:
self,
startup_nodes: List["ClusterNode"],
require_full_coverage: bool = False,
- **kwargs,
+ **kwargs: Any,
) -> None:
- self.nodes_cache = {}
- self.slots_cache = {}
self.startup_nodes = {node.name: node for node in startup_nodes}
- self.default_node = None
- self._require_full_coverage = require_full_coverage
- self._moved_exception = None
+ self.require_full_coverage = require_full_coverage
self.connection_kwargs = kwargs
+
+ self.default_node: "ClusterNode" = None
+ self.nodes_cache: Dict[str, "ClusterNode"] = {}
+ self.slots_cache: Dict[int, List["ClusterNode"]] = {}
self.read_load_balancer = LoadBalancer()
+ self._moved_exception: MovedError = None
def get_node(
self,
host: Optional[str] = None,
port: Optional[int] = None,
node_name: Optional[str] = None,
- ) -> "ClusterNode":
+ ) -> Optional["ClusterNode"]:
if host and port:
# the user passed host and port
if host == "localhost":
@@ -877,20 +901,18 @@ class NodesManager:
self,
old: Dict[str, "ClusterNode"],
new: Dict[str, "ClusterNode"],
- remove_old=False,
+ remove_old: bool = False,
) -> None:
- tasks = []
if remove_old:
- tasks = [
- asyncio.ensure_future(node.disconnect())
- for name, node in old.items()
- if name not in new
- ]
+ for name in list(old.keys()):
+ if name not in new:
+ asyncio.ensure_future(old.pop(name).disconnect())
+
for name, node in new.items():
if name in old:
if old[name] is node:
continue
- tasks.append(asyncio.ensure_future(old[name].disconnect()))
+ asyncio.ensure_future(old[name].disconnect())
old[name] = node
def _update_moved_slots(self) -> None:
@@ -949,7 +971,7 @@ class NodesManager:
except (IndexError, TypeError):
raise SlotNotCoveredError(
f'Slot "{slot}" not covered by the cluster. '
- f'"require_full_coverage={self._require_full_coverage}"'
+ f'"require_full_coverage={self.require_full_coverage}"'
)
def get_nodes_by_server_type(self, server_type: str) -> List["ClusterNode"]:
@@ -961,8 +983,8 @@ class NodesManager:
async def initialize(self) -> None:
self.read_load_balancer.reset()
- tmp_nodes_cache = {}
- tmp_slots = {}
+ tmp_nodes_cache: Dict[str, "ClusterNode"] = {}
+ tmp_slots: Dict[int, List["ClusterNode"]] = {}
disagreements = []
startup_nodes_reachable = False
fully_covered = False
@@ -975,9 +997,7 @@ class NodesManager:
raise RedisClusterException(
"Cluster mode is not enabled on this node"
)
- cluster_slots = str_if_bytes(
- await startup_node.execute_command("CLUSTER SLOTS")
- )
+ cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
startup_nodes_reachable = True
except (ConnectionError, TimeoutError):
continue
@@ -1069,7 +1089,7 @@ class NodesManager:
# Validate if all slots are covered or if we should try next startup node
fully_covered = True
- for i in range(0, REDIS_CLUSTER_HASH_SLOTS):
+ for i in range(REDIS_CLUSTER_HASH_SLOTS):
if i not in tmp_slots:
fully_covered = False
break
@@ -1083,7 +1103,7 @@ class NodesManager:
)
# Check if the slots are not fully covered
- if not fully_covered and self._require_full_coverage:
+ if not fully_covered and self.require_full_coverage:
# Despite the requirement that the slots be covered, there
# isn't a full coverage
raise RedisClusterException(