summaryrefslogtreecommitdiff
path: root/redis/asyncio/cluster.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/asyncio/cluster.py')
-rw-r--r--redis/asyncio/cluster.py328
1 files changed, 312 insertions, 16 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 39aa536..3405a49 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -3,19 +3,32 @@ import collections
import random
import socket
import warnings
-from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union
+from typing import (
+ Any,
+ Deque,
+ Dict,
+ Generator,
+ List,
+ Mapping,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+)
from redis.asyncio.client import ResponseCallbackT
from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url
from redis.asyncio.parser import CommandsParser
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
+ PIPELINE_BLOCKED_COMMANDS,
PRIMARY,
READ_COMMANDS,
REPLICA,
SLOT_ID,
AbstractRedisCluster,
LoadBalancer,
+ block_pipeline_command,
get_node_name,
parse_cluster_slots,
)
@@ -37,8 +50,8 @@ from redis.exceptions import (
TimeoutError,
TryAgainError,
)
-from redis.typing import EncodableT, KeyT
-from redis.utils import dict_merge, str_if_bytes
+from redis.typing import AnyKeyT, EncodableT, KeyT
+from redis.utils import dict_merge, safe_str, str_if_bytes
TargetNodesT = TypeVar(
"TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"]
@@ -719,6 +732,24 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
raise ClusterError("TTL exhausted.")
+ def pipeline(
+ self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None
+ ) -> "ClusterPipeline":
+ """
+ Create & return a new :class:`~.ClusterPipeline` object.
+
+ Cluster implementation of pipeline does not support transaction or shard_hint.
+
+ :raises RedisClusterException: if transaction or shard_hint are truthy values
+ """
+ if shard_hint:
+ raise RedisClusterException("shard_hint is deprecated in cluster mode")
+
+ if transaction:
+ raise RedisClusterException("transaction is deprecated in cluster mode")
+
+ return ClusterPipeline(self)
+
class ClusterNode:
"""
@@ -729,6 +760,7 @@ class ClusterNode:
"""
__slots__ = (
+ "_command_stack",
"_connections",
"_free",
"connection_class",
@@ -768,6 +800,7 @@ class ClusterNode:
self._connections: List[Connection] = []
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
+ self._command_stack: List["PipelineCommand"] = []
def __repr__(self) -> str:
return (
@@ -806,27 +839,26 @@ class ClusterNode:
if exc:
raise exc
- async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
- # Acquire connection
- connection = None
+ def acquire_connection(self) -> Connection:
if self._free:
for _ in range(len(self._free)):
connection = self._free.popleft()
if connection.is_connected:
- break
+ return connection
self._free.append(connection)
- else:
- connection = self._free.popleft()
+
+ return self._free.popleft()
else:
if len(self._connections) < self.max_connections:
connection = self.connection_class(**self.connection_kwargs)
self._connections.append(connection)
+ return connection
else:
raise ConnectionError("Too many connections")
- # Execute command
- command = connection.pack_command(*args)
- await connection.send_packed_command(command, False)
+ async def parse_response(
+ self, connection: Connection, command: str, **kwargs: Any
+ ) -> Any:
try:
if NEVER_DECODE in kwargs:
response = await connection.read_response_without_lock(
@@ -838,16 +870,49 @@ class ClusterNode:
if EMPTY_RESPONSE in kwargs:
return kwargs[EMPTY_RESPONSE]
raise
- finally:
- # Release connection
- self._free.append(connection)
# Return response
try:
- return self.response_callbacks[args[0]](response, **kwargs)
+ return self.response_callbacks[command](response, **kwargs)
except KeyError:
return response
+ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
+ # Acquire connection
+ connection = self.acquire_connection()
+
+ # Execute command
+ await connection.send_packed_command(connection.pack_command(*args), False)
+
+ # Read response
+ try:
+ return await self.parse_response(connection, args[0], **kwargs)
+ finally:
+ # Release connection
+ self._free.append(connection)
+
+ async def execute_pipeline(self) -> None:
+ # Acquire connection
+ connection = self.acquire_connection()
+
+ # Execute command
+ await connection.send_packed_command(
+ connection.pack_commands(cmd.args for cmd in self._command_stack), False
+ )
+
+ # Read responses
+ try:
+ for cmd in self._command_stack:
+ try:
+ cmd.result = await self.parse_response(
+ connection, cmd.args[0], **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+ finally:
+ # Release connection
+ self._free.append(connection)
+
class NodesManager:
__slots__ = (
@@ -1131,3 +1196,234 @@ class NodesManager:
for node in getattr(self, attr).values()
)
)
+
+
+class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
+ """
+ Create a new ClusterPipeline object.
+
+ Usage::
+
+ result = await (
+ rc.pipeline()
+ .set("A", 1)
+ .get("A")
+ .hset("K", "F", "V")
+ .hgetall("K")
+ .mset_nonatomic({"A": 2, "B": 3})
+ .get("A")
+ .get("B")
+ .delete("A", "B", "K")
+ .execute()
+ )
+ # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1]
+
+ Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which
+ are split across multiple nodes, you'll get multiple results for them in the array.
+
+ Retryable errors:
+ - :class:`~.ClusterDownError`
+ - :class:`~.ConnectionError`
+ - :class:`~.TimeoutError`
+
+ Redirection errors:
+ - :class:`~.TryAgainError`
+ - :class:`~.MovedError`
+ - :class:`~.AskError`
+
+ :param client:
+ | Existing :class:`~.RedisCluster` client
+ """
+
+ __slots__ = ("_command_stack", "_client")
+
+ def __init__(self, client: RedisCluster) -> None:
+ self._client = client
+
+ self._command_stack: List["PipelineCommand"] = []
+
+ async def initialize(self) -> "ClusterPipeline":
+ if self._client._initialize:
+ await self._client.initialize()
+ self._command_stack = []
+ return self
+
+ async def __aenter__(self) -> "ClusterPipeline":
+ return await self.initialize()
+
+ async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
+ self._command_stack = []
+
+ def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
+ return self.initialize().__await__()
+
+ def __bool__(self) -> bool:
+ return bool(self._command_stack)
+
+ def __len__(self) -> int:
+ return len(self._command_stack)
+
+ def execute_command(
+ self, *args: Union[KeyT, EncodableT], **kwargs: Any
+ ) -> "ClusterPipeline":
+ """
+ Append a raw command to the pipeline.
+
+ :param args:
+ | Raw command args
+ :param kwargs:
+
+ - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode`
+ or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
+ - Rest of the kwargs are passed to the Redis connection
+ """
+ self._command_stack.append(
+ PipelineCommand(len(self._command_stack), *args, **kwargs)
+ )
+ return self
+
+ async def execute(
+ self, raise_on_error: bool = True, allow_redirections: bool = True
+ ) -> List[Any]:
+ """
+ Execute the pipeline.
+
+ It will retry the commands as specified by :attr:`cluster_error_retry_attempts`
+ & then raise an exception.
+
+ :param raise_on_error:
+ | Raise the first error if there are any errors
+ :param allow_redirections:
+ | Whether to retry each failed command individually in case of redirection
+ errors
+
+ :raises RedisClusterException: if target_nodes is not provided & the command
+ can't be mapped to a slot
+ """
+ if not self._command_stack:
+ return []
+
+ try:
+ for _ in range(self._client.cluster_error_retry_attempts):
+ if self._client._initialize:
+ await self._client.initialize()
+
+ try:
+ return await self._execute(
+ self._command_stack,
+ raise_on_error=raise_on_error,
+ allow_redirections=allow_redirections,
+ )
+ except BaseException as e:
+ if type(e) in self.__class__.ERRORS_ALLOW_RETRY:
+ # Try again with the new cluster setup.
+ exception = e
+ await self._client.close()
+ await asyncio.sleep(0.25)
+ else:
+ # All other errors should be raised.
+ raise e
+
+ # If it fails the configured number of times then raise an exception
+ raise exception
+ finally:
+ self._command_stack = []
+
+ async def _execute(
+ self,
+ stack: List["PipelineCommand"],
+ raise_on_error: bool = True,
+ allow_redirections: bool = True,
+ ) -> List[Any]:
+ client = self._client
+ nodes = {}
+ for cmd in stack:
+ if not cmd.result or isinstance(cmd.result, Exception):
+ target_nodes = await client._determine_nodes(*cmd.args)
+ if not target_nodes:
+ raise RedisClusterException(
+ f"No targets were found to execute {cmd.args} command on"
+ )
+ if len(target_nodes) > 1:
+ raise RedisClusterException(
+ f"Too many targets for command {cmd.args}"
+ )
+
+ node = target_nodes[0]
+ if node.name not in nodes:
+ nodes[node.name] = node
+ node._command_stack = []
+ node._command_stack.append(cmd)
+
+ await asyncio.gather(
+ *(asyncio.ensure_future(node.execute_pipeline()) for node in nodes.values())
+ )
+
+ if allow_redirections:
+ # send each errored command individually
+ for cmd in stack:
+ if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
+ try:
+ cmd.result = await client.execute_command(
+ *cmd.args, **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+
+ responses = [cmd.result for cmd in stack]
+
+ if raise_on_error:
+ for cmd in stack:
+ result = cmd.result
+ if isinstance(result, Exception):
+ command = " ".join(map(safe_str, cmd.args))
+ msg = (
+ f"Command # {cmd.position + 1} ({command}) of pipeline "
+ f"caused error: {result.args}"
+ )
+ result.args = (msg,) + result.args[1:]
+ raise result
+
+ return responses
+
+ def _split_command_across_slots(
+ self, command: str, *keys: KeyT
+ ) -> "ClusterPipeline":
+ for slot_keys in self._client._partition_keys_by_slot(keys).values():
+ self.execute_command(command, *slot_keys)
+
+ return self
+
+ def mset_nonatomic(
+ self, mapping: Mapping[AnyKeyT, EncodableT]
+ ) -> "ClusterPipeline":
+ encoder = self._client.encoder
+
+ slots_pairs = {}
+ for pair in mapping.items():
+ slot = key_slot(encoder.encode(pair[0]))
+ slots_pairs.setdefault(slot, []).extend(pair)
+
+ for pairs in slots_pairs.values():
+ self.execute_command("MSET", *pairs)
+
+ return self
+
+
+for command in PIPELINE_BLOCKED_COMMANDS:
+ command = command.replace(" ", "_").lower()
+ if command == "mset_nonatomic":
+ continue
+
+ setattr(ClusterPipeline, command, block_pipeline_command(command))
+
+
+class PipelineCommand:
+ def __init__(self, position: int, *args: Any, **kwargs: Any) -> None:
+ self.args = args
+ self.kwargs = kwargs
+ self.position = position
+ self.result: Union[Any, Exception] = None
+
+ def __repr__(self) -> str:
+ return f"[{self.position}] {self.args} ({self.kwargs})"