diff options
Diffstat (limited to 'redis/asyncio/cluster.py')
-rw-r--r-- | redis/asyncio/cluster.py | 328 |
1 files changed, 312 insertions, 16 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 39aa536..3405a49 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,19 +3,32 @@ import collections import random import socket import warnings -from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union +from typing import ( + Any, + Deque, + Dict, + Generator, + List, + Mapping, + Optional, + Type, + TypeVar, + Union, +) from redis.asyncio.client import ResponseCallbackT from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url from redis.asyncio.parser import CommandsParser from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( + PIPELINE_BLOCKED_COMMANDS, PRIMARY, READ_COMMANDS, REPLICA, SLOT_ID, AbstractRedisCluster, LoadBalancer, + block_pipeline_command, get_node_name, parse_cluster_slots, ) @@ -37,8 +50,8 @@ from redis.exceptions import ( TimeoutError, TryAgainError, ) -from redis.typing import EncodableT, KeyT -from redis.utils import dict_merge, str_if_bytes +from redis.typing import AnyKeyT, EncodableT, KeyT +from redis.utils import dict_merge, safe_str, str_if_bytes TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -719,6 +732,24 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand raise ClusterError("TTL exhausted.") + def pipeline( + self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None + ) -> "ClusterPipeline": + """ + Create & return a new :class:`~.ClusterPipeline` object. + + Cluster implementation of pipeline does not support transaction or shard_hint. + + :raises RedisClusterException: if transaction or shard_hint are truthy values + """ + if shard_hint: + raise RedisClusterException("shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return ClusterPipeline(self) + class ClusterNode: """ @@ -729,6 +760,7 @@ class ClusterNode: """ __slots__ = ( + "_command_stack", "_connections", "_free", "connection_class", @@ -768,6 +800,7 @@ class ClusterNode: self._connections: List[Connection] = [] self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + self._command_stack: List["PipelineCommand"] = [] def __repr__(self) -> str: return ( @@ -806,27 +839,26 @@ class ClusterNode: if exc: raise exc - async def execute_command(self, *args: Any, **kwargs: Any) -> Any: - # Acquire connection - connection = None + def acquire_connection(self) -> Connection: if self._free: for _ in range(len(self._free)): connection = self._free.popleft() if connection.is_connected: - break + return connection self._free.append(connection) - else: - connection = self._free.popleft() + + return self._free.popleft() else: if len(self._connections) < self.max_connections: connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) + return connection else: raise ConnectionError("Too many connections") - # Execute command - command = connection.pack_command(*args) - await connection.send_packed_command(command, False) + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: try: if NEVER_DECODE in kwargs: response = await connection.read_response_without_lock( @@ -838,16 +870,49 @@ class ClusterNode: if EMPTY_RESPONSE in kwargs: return kwargs[EMPTY_RESPONSE] raise - finally: - # Release connection - self._free.append(connection) # Return response try: - return self.response_callbacks[args[0]](response, **kwargs) + return self.response_callbacks[command](response, **kwargs) except KeyError: return response + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: + # Acquire connection + connection = self.acquire_connection() + + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) + + # Read response + try: + return await self.parse_response(connection, args[0], **kwargs) + finally: + # Release connection + self._free.append(connection) + + async def execute_pipeline(self) -> None: + # Acquire connection + connection = self.acquire_connection() + + # Execute command + await connection.send_packed_command( + connection.pack_commands(cmd.args for cmd in self._command_stack), False + ) + + # Read responses + try: + for cmd in self._command_stack: + try: + cmd.result = await self.parse_response( + connection, cmd.args[0], **cmd.kwargs + ) + except Exception as e: + cmd.result = e + finally: + # Release connection + self._free.append(connection) + class NodesManager: __slots__ = ( @@ -1131,3 +1196,234 @@ class NodesManager: for node in getattr(self, attr).values() ) ) + + +class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new ClusterPipeline object. + + Usage:: + + result = await ( + rc.pipeline() + .set("A", 1) + .get("A") + .hset("K", "F", "V") + .hgetall("K") + .mset_nonatomic({"A": 2, "B": 3}) + .get("A") + .get("B") + .delete("A", "B", "K") + .execute() + ) + # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1] + + Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which + are split across multiple nodes, you'll get multiple results for them in the array. + + Retryable errors: + - :class:`~.ClusterDownError` + - :class:`~.ConnectionError` + - :class:`~.TimeoutError` + + Redirection errors: + - :class:`~.TryAgainError` + - :class:`~.MovedError` + - :class:`~.AskError` + + :param client: + | Existing :class:`~.RedisCluster` client + """ + + __slots__ = ("_command_stack", "_client") + + def __init__(self, client: RedisCluster) -> None: + self._client = client + + self._command_stack: List["PipelineCommand"] = [] + + async def initialize(self) -> "ClusterPipeline": + if self._client._initialize: + await self._client.initialize() + self._command_stack = [] + return self + + async def __aenter__(self) -> "ClusterPipeline": + return await self.initialize() + + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + self._command_stack = [] + + def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: + return self.initialize().__await__() + + def __bool__(self) -> bool: + return bool(self._command_stack) + + def __len__(self) -> int: + return len(self._command_stack) + + def execute_command( + self, *args: Union[KeyT, EncodableT], **kwargs: Any + ) -> "ClusterPipeline": + """ + Append a raw command to the pipeline. + + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + """ + self._command_stack.append( + PipelineCommand(len(self._command_stack), *args, **kwargs) + ) + return self + + async def execute( + self, raise_on_error: bool = True, allow_redirections: bool = True + ) -> List[Any]: + """ + Execute the pipeline. + + It will retry the commands as specified by :attr:`cluster_error_retry_attempts` + & then raise an exception. + + :param raise_on_error: + | Raise the first error if there are any errors + :param allow_redirections: + | Whether to retry each failed command individually in case of redirection + errors + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot + """ + if not self._command_stack: + return [] + + try: + for _ in range(self._client.cluster_error_retry_attempts): + if self._client._initialize: + await self._client.initialize() + + try: + return await self._execute( + self._command_stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except BaseException as e: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # Try again with the new cluster setup. + exception = e + await self._client.close() + await asyncio.sleep(0.25) + else: + # All other errors should be raised. + raise e + + # If it fails the configured number of times then raise an exception + raise exception + finally: + self._command_stack = [] + + async def _execute( + self, + stack: List["PipelineCommand"], + raise_on_error: bool = True, + allow_redirections: bool = True, + ) -> List[Any]: + client = self._client + nodes = {} + for cmd in stack: + if not cmd.result or isinstance(cmd.result, Exception): + target_nodes = await client._determine_nodes(*cmd.args) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {cmd.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException( + f"Too many targets for command {cmd.args}" + ) + + node = target_nodes[0] + if node.name not in nodes: + nodes[node.name] = node + node._command_stack = [] + node._command_stack.append(cmd) + + await asyncio.gather( + *(asyncio.ensure_future(node.execute_pipeline()) for node in nodes.values()) + ) + + if allow_redirections: + # send each errored command individually + for cmd in stack: + if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): + try: + cmd.result = await client.execute_command( + *cmd.args, **cmd.kwargs + ) + except Exception as e: + cmd.result = e + + responses = [cmd.result for cmd in stack] + + if raise_on_error: + for cmd in stack: + result = cmd.result + if isinstance(result, Exception): + command = " ".join(map(safe_str, cmd.args)) + msg = ( + f"Command # {cmd.position + 1} ({command}) of pipeline " + f"caused error: {result.args}" + ) + result.args = (msg,) + result.args[1:] + raise result + + return responses + + def _split_command_across_slots( + self, command: str, *keys: KeyT + ) -> "ClusterPipeline": + for slot_keys in self._client._partition_keys_by_slot(keys).values(): + self.execute_command(command, *slot_keys) + + return self + + def mset_nonatomic( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> "ClusterPipeline": + encoder = self._client.encoder + + slots_pairs = {} + for pair in mapping.items(): + slot = key_slot(encoder.encode(pair[0])) + slots_pairs.setdefault(slot, []).extend(pair) + + for pairs in slots_pairs.values(): + self.execute_command("MSET", *pairs) + + return self + + +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + if command == "mset_nonatomic": + continue + + setattr(ClusterPipeline, command, block_pipeline_command(command)) + + +class PipelineCommand: + def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.position = position + self.result: Union[Any, Exception] = None + + def __repr__(self) -> str: + return f"[{self.position}] {self.args} ({self.kwargs})" |