summaryrefslogtreecommitdiff
path: root/redis/asyncio/cluster.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/asyncio/cluster.py')
-rw-r--r--redis/asyncio/cluster.py165
1 files changed, 86 insertions, 79 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 3405a49..b53f848 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -468,9 +468,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
return key_slot(k)
async def _determine_nodes(
- self, *args: Any, node_flag: Optional[str] = None
+ self, command: str, *args: Any, node_flag: Optional[str] = None
) -> List["ClusterNode"]:
- command = args[0]
if not node_flag:
# get the nodes group for this command if it was predefined
node_flag = self.command_flags.get(command)
@@ -495,16 +494,15 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
# get the node that holds the key's slot
return [
self.nodes_manager.get_node_from_slot(
- await self._determine_slot(*args),
+ await self._determine_slot(command, *args),
self.read_from_replicas and command in READ_COMMANDS,
)
]
- async def _determine_slot(self, *args: Any) -> int:
- command = args[0]
+ async def _determine_slot(self, command: str, *args: Any) -> int:
if self.command_flags.get(command) == SLOT_ID:
# The command contains the slot ID
- return int(args[1])
+ return int(args[0])
# Get the keys in the command
@@ -516,19 +514,17 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
# - fix: https://github.com/redis/redis/pull/9733
if command in ("EVAL", "EVALSHA"):
# command syntax: EVAL "script body" num_keys ...
- if len(args) <= 2:
- raise RedisClusterException(f"Invalid args in command: {args}")
- num_actual_keys = args[2]
- eval_keys = args[3 : 3 + num_actual_keys]
+ if len(args) < 2:
+ raise RedisClusterException(
+ f"Invalid args in command: {command, *args}"
+ )
+ keys = args[2 : 2 + args[1]]
# if there are 0 keys, that means the script can be run on any node
# so we can just return a random slot
- if not eval_keys:
+ if not keys:
return random.randrange(0, REDIS_CLUSTER_HASH_SLOTS)
- keys = eval_keys
else:
- keys = await self.commands_parser.get_keys(
- self.nodes_manager.default_node, *args
- )
+ keys = await self.commands_parser.get_keys(command, *args)
if not keys:
# FCALL can call a function with 0 keys, that means the function
# can be run on any node so we can just return a random slot
@@ -848,13 +844,13 @@ class ClusterNode:
self._free.append(connection)
return self._free.popleft()
- else:
- if len(self._connections) < self.max_connections:
- connection = self.connection_class(**self.connection_kwargs)
- self._connections.append(connection)
- return connection
- else:
- raise ConnectionError("Too many connections")
+
+ if len(self._connections) < self.max_connections:
+ connection = self.connection_class(**self.connection_kwargs)
+ self._connections.append(connection)
+ return connection
+
+ raise ConnectionError("Too many connections")
async def parse_response(
self, connection: Connection, command: str, **kwargs: Any
@@ -872,10 +868,10 @@ class ClusterNode:
raise
# Return response
- try:
+ if command in self.response_callbacks:
return self.response_callbacks[command](response, **kwargs)
- except KeyError:
- return response
+
+ return response
async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
@@ -891,7 +887,7 @@ class ClusterNode:
# Release connection
self._free.append(connection)
- async def execute_pipeline(self) -> None:
+ async def execute_pipeline(self) -> bool:
# Acquire connection
connection = self.acquire_connection()
@@ -901,17 +897,20 @@ class ClusterNode:
)
# Read responses
- try:
- for cmd in self._command_stack:
- try:
- cmd.result = await self.parse_response(
- connection, cmd.args[0], **cmd.kwargs
- )
- except Exception as e:
- cmd.result = e
- finally:
- # Release connection
- self._free.append(connection)
+ ret = False
+ for cmd in self._command_stack:
+ try:
+ cmd.result = await self.parse_response(
+ connection, cmd.args[0], **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+ ret = True
+
+ # Release connection
+ self._free.append(connection)
+
+ return ret
class NodesManager:
@@ -1257,6 +1256,13 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
def __await__(self) -> Generator[Any, None, "ClusterPipeline"]:
return self.initialize().__await__()
+ def __enter__(self) -> "ClusterPipeline":
+ self._command_stack = []
+ return self
+
+ def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None:
+ self._command_stack = []
+
def __bool__(self) -> bool:
return bool(self._command_stack)
@@ -1310,6 +1316,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
try:
return await self._execute(
+ self._client,
self._command_stack,
raise_on_error=raise_on_error,
allow_redirections=allow_redirections,
@@ -1331,60 +1338,60 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
async def _execute(
self,
+ client: "RedisCluster",
stack: List["PipelineCommand"],
raise_on_error: bool = True,
allow_redirections: bool = True,
) -> List[Any]:
- client = self._client
+ todo = [
+ cmd for cmd in stack if not cmd.result or isinstance(cmd.result, Exception)
+ ]
+
nodes = {}
- for cmd in stack:
- if not cmd.result or isinstance(cmd.result, Exception):
- target_nodes = await client._determine_nodes(*cmd.args)
- if not target_nodes:
- raise RedisClusterException(
- f"No targets were found to execute {cmd.args} command on"
- )
- if len(target_nodes) > 1:
- raise RedisClusterException(
- f"Too many targets for command {cmd.args}"
- )
+ for cmd in todo:
+ target_nodes = await client._determine_nodes(*cmd.args)
+ if not target_nodes:
+ raise RedisClusterException(
+ f"No targets were found to execute {cmd.args} command on"
+ )
+ if len(target_nodes) > 1:
+ raise RedisClusterException(f"Too many targets for command {cmd.args}")
- node = target_nodes[0]
- if node.name not in nodes:
- nodes[node.name] = node
- node._command_stack = []
- node._command_stack.append(cmd)
+ node = target_nodes[0]
+ if node.name not in nodes:
+ nodes[node.name] = node
+ node._command_stack = []
+ node._command_stack.append(cmd)
- await asyncio.gather(
+ errors = await asyncio.gather(
*(asyncio.ensure_future(node.execute_pipeline()) for node in nodes.values())
)
- if allow_redirections:
- # send each errored command individually
- for cmd in stack:
- if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
- try:
- cmd.result = await client.execute_command(
- *cmd.args, **cmd.kwargs
+ if any(errors):
+ if allow_redirections:
+ # send each errored command individually
+ for cmd in todo:
+ if isinstance(cmd.result, (TryAgainError, MovedError, AskError)):
+ try:
+ cmd.result = await client.execute_command(
+ *cmd.args, **cmd.kwargs
+ )
+ except Exception as e:
+ cmd.result = e
+
+ if raise_on_error:
+ for cmd in todo:
+ result = cmd.result
+ if isinstance(result, Exception):
+ command = " ".join(map(safe_str, cmd.args))
+ msg = (
+ f"Command # {cmd.position + 1} ({command}) of pipeline "
+ f"caused error: {result.args}"
)
- except Exception as e:
- cmd.result = e
-
- responses = [cmd.result for cmd in stack]
-
- if raise_on_error:
- for cmd in stack:
- result = cmd.result
- if isinstance(result, Exception):
- command = " ".join(map(safe_str, cmd.args))
- msg = (
- f"Command # {cmd.position + 1} ({command}) of pipeline "
- f"caused error: {result.args}"
- )
- result.args = (msg,) + result.args[1:]
- raise result
+ result.args = (msg,) + result.args[1:]
+ raise result
- return responses
+ return [cmd.result for cmd in stack]
def _split_command_across_slots(
self, command: str, *keys: KeyT