From bac33d4a92892ca7982b461347151bff5a661f0d Mon Sep 17 00:00:00 2001 From: Utkarsh Gupta Date: Mon, 30 May 2022 21:45:45 +0530 Subject: async_cluster: add pipeline support (#2199) Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- docs/connections.rst | 5 + redis/asyncio/cluster.py | 328 ++++++++++++++++++++++++++++++++++-- redis/cluster.py | 114 +++++++++---- tests/test_asyncio/test_cluster.py | 254 +++++++++++++++++++++++++++- tests/test_asyncio/test_pipeline.py | 7 +- 5 files changed, 652 insertions(+), 56 deletions(-) diff --git a/docs/connections.rst b/docs/connections.rst index e4b82cd..b481689 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -76,6 +76,11 @@ ClusterNode (Async) .. autoclass:: redis.asyncio.cluster.ClusterNode :members: +ClusterPipeline (Async) +=================== +.. autoclass:: redis.asyncio.cluster.ClusterPipeline + :members: + Connection ********** diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 39aa536..3405a49 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -3,19 +3,32 @@ import collections import random import socket import warnings -from typing import Any, Deque, Dict, Generator, List, Optional, Type, TypeVar, Union +from typing import ( + Any, + Deque, + Dict, + Generator, + List, + Mapping, + Optional, + Type, + TypeVar, + Union, +) from redis.asyncio.client import ResponseCallbackT from redis.asyncio.connection import Connection, DefaultParser, Encoder, parse_url from redis.asyncio.parser import CommandsParser from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( + PIPELINE_BLOCKED_COMMANDS, PRIMARY, READ_COMMANDS, REPLICA, SLOT_ID, AbstractRedisCluster, LoadBalancer, + block_pipeline_command, get_node_name, parse_cluster_slots, ) @@ -37,8 +50,8 @@ from redis.exceptions import ( TimeoutError, TryAgainError, ) -from redis.typing import EncodableT, KeyT -from redis.utils import dict_merge, str_if_bytes +from redis.typing import AnyKeyT, EncodableT, KeyT +from redis.utils import dict_merge, safe_str, str_if_bytes TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -719,6 +732,24 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand raise ClusterError("TTL exhausted.") + def pipeline( + self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None + ) -> "ClusterPipeline": + """ + Create & return a new :class:`~.ClusterPipeline` object. + + Cluster implementation of pipeline does not support transaction or shard_hint. + + :raises RedisClusterException: if transaction or shard_hint are truthy values + """ + if shard_hint: + raise RedisClusterException("shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return ClusterPipeline(self) + class ClusterNode: """ @@ -729,6 +760,7 @@ class ClusterNode: """ __slots__ = ( + "_command_stack", "_connections", "_free", "connection_class", @@ -768,6 +800,7 @@ class ClusterNode: self._connections: List[Connection] = [] self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) + self._command_stack: List["PipelineCommand"] = [] def __repr__(self) -> str: return ( @@ -806,27 +839,26 @@ class ClusterNode: if exc: raise exc - async def execute_command(self, *args: Any, **kwargs: Any) -> Any: - # Acquire connection - connection = None + def acquire_connection(self) -> Connection: if self._free: for _ in range(len(self._free)): connection = self._free.popleft() if connection.is_connected: - break + return connection self._free.append(connection) - else: - connection = self._free.popleft() + + return self._free.popleft() else: if len(self._connections) < self.max_connections: connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) + return connection else: raise ConnectionError("Too many connections") - # Execute command - command = connection.pack_command(*args) - await connection.send_packed_command(command, False) + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: try: if NEVER_DECODE in kwargs: response = await connection.read_response_without_lock( @@ -838,16 +870,49 @@ class ClusterNode: if EMPTY_RESPONSE in kwargs: return kwargs[EMPTY_RESPONSE] raise - finally: - # Release connection - self._free.append(connection) # Return response try: - return self.response_callbacks[args[0]](response, **kwargs) + return self.response_callbacks[command](response, **kwargs) except KeyError: return response + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: + # Acquire connection + connection = self.acquire_connection() + + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) + + # Read response + try: + return await self.parse_response(connection, args[0], **kwargs) + finally: + # Release connection + self._free.append(connection) + + async def execute_pipeline(self) -> None: + # Acquire connection + connection = self.acquire_connection() + + # Execute command + await connection.send_packed_command( + connection.pack_commands(cmd.args for cmd in self._command_stack), False + ) + + # Read responses + try: + for cmd in self._command_stack: + try: + cmd.result = await self.parse_response( + connection, cmd.args[0], **cmd.kwargs + ) + except Exception as e: + cmd.result = e + finally: + # Release connection + self._free.append(connection) + class NodesManager: __slots__ = ( @@ -1131,3 +1196,234 @@ class NodesManager: for node in getattr(self, attr).values() ) ) + + +class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): + """ + Create a new ClusterPipeline object. + + Usage:: + + result = await ( + rc.pipeline() + .set("A", 1) + .get("A") + .hset("K", "F", "V") + .hgetall("K") + .mset_nonatomic({"A": 2, "B": 3}) + .get("A") + .get("B") + .delete("A", "B", "K") + .execute() + ) + # result = [True, "1", 1, {"F": "V"}, True, True, "2", "3", 1, 1, 1] + + Note: For commands `DELETE`, `EXISTS`, `TOUCH`, `UNLINK`, `mset_nonatomic`, which + are split across multiple nodes, you'll get multiple results for them in the array. + + Retryable errors: + - :class:`~.ClusterDownError` + - :class:`~.ConnectionError` + - :class:`~.TimeoutError` + + Redirection errors: + - :class:`~.TryAgainError` + - :class:`~.MovedError` + - :class:`~.AskError` + + :param client: + | Existing :class:`~.RedisCluster` client + """ + + __slots__ = ("_command_stack", "_client") + + def __init__(self, client: RedisCluster) -> None: + self._client = client + + self._command_stack: List["PipelineCommand"] = [] + + async def initialize(self) -> "ClusterPipeline": + if self._client._initialize: + await self._client.initialize() + self._command_stack = [] + return self + + async def __aenter__(self) -> "ClusterPipeline": + return await self.initialize() + + async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + self._command_stack = [] + + def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: + return self.initialize().__await__() + + def __bool__(self) -> bool: + return bool(self._command_stack) + + def __len__(self) -> int: + return len(self._command_stack) + + def execute_command( + self, *args: Union[KeyT, EncodableT], **kwargs: Any + ) -> "ClusterPipeline": + """ + Append a raw command to the pipeline. + + :param args: + | Raw command args + :param kwargs: + + - target_nodes: :attr:`NODE_FLAGS` or :class:`~.ClusterNode` + or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] + - Rest of the kwargs are passed to the Redis connection + """ + self._command_stack.append( + PipelineCommand(len(self._command_stack), *args, **kwargs) + ) + return self + + async def execute( + self, raise_on_error: bool = True, allow_redirections: bool = True + ) -> List[Any]: + """ + Execute the pipeline. + + It will retry the commands as specified by :attr:`cluster_error_retry_attempts` + & then raise an exception. + + :param raise_on_error: + | Raise the first error if there are any errors + :param allow_redirections: + | Whether to retry each failed command individually in case of redirection + errors + + :raises RedisClusterException: if target_nodes is not provided & the command + can't be mapped to a slot + """ + if not self._command_stack: + return [] + + try: + for _ in range(self._client.cluster_error_retry_attempts): + if self._client._initialize: + await self._client.initialize() + + try: + return await self._execute( + self._command_stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except BaseException as e: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # Try again with the new cluster setup. + exception = e + await self._client.close() + await asyncio.sleep(0.25) + else: + # All other errors should be raised. + raise e + + # If it fails the configured number of times then raise an exception + raise exception + finally: + self._command_stack = [] + + async def _execute( + self, + stack: List["PipelineCommand"], + raise_on_error: bool = True, + allow_redirections: bool = True, + ) -> List[Any]: + client = self._client + nodes = {} + for cmd in stack: + if not cmd.result or isinstance(cmd.result, Exception): + target_nodes = await client._determine_nodes(*cmd.args) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {cmd.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException( + f"Too many targets for command {cmd.args}" + ) + + node = target_nodes[0] + if node.name not in nodes: + nodes[node.name] = node + node._command_stack = [] + node._command_stack.append(cmd) + + await asyncio.gather( + *(asyncio.ensure_future(node.execute_pipeline()) for node in nodes.values()) + ) + + if allow_redirections: + # send each errored command individually + for cmd in stack: + if isinstance(cmd.result, (TryAgainError, MovedError, AskError)): + try: + cmd.result = await client.execute_command( + *cmd.args, **cmd.kwargs + ) + except Exception as e: + cmd.result = e + + responses = [cmd.result for cmd in stack] + + if raise_on_error: + for cmd in stack: + result = cmd.result + if isinstance(result, Exception): + command = " ".join(map(safe_str, cmd.args)) + msg = ( + f"Command # {cmd.position + 1} ({command}) of pipeline " + f"caused error: {result.args}" + ) + result.args = (msg,) + result.args[1:] + raise result + + return responses + + def _split_command_across_slots( + self, command: str, *keys: KeyT + ) -> "ClusterPipeline": + for slot_keys in self._client._partition_keys_by_slot(keys).values(): + self.execute_command(command, *slot_keys) + + return self + + def mset_nonatomic( + self, mapping: Mapping[AnyKeyT, EncodableT] + ) -> "ClusterPipeline": + encoder = self._client.encoder + + slots_pairs = {} + for pair in mapping.items(): + slot = key_slot(encoder.encode(pair[0])) + slots_pairs.setdefault(slot, []).extend(pair) + + for pairs in slots_pairs.values(): + self.execute_command("MSET", *pairs) + + return self + + +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + if command == "mset_nonatomic": + continue + + setattr(ClusterPipeline, command, block_pipeline_command(command)) + + +class PipelineCommand: + def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.position = position + self.result: Union[Any, Exception] = None + + def __repr__(self) -> str: + return f"[{self.position}] {self.args} ({self.kwargs})" diff --git a/redis/cluster.py b/redis/cluster.py index 46a96a6..2e31063 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,7 +6,7 @@ import sys import threading import time from collections import OrderedDict -from typing import Any, Dict, Tuple +from typing import Any, Callable, Dict, Tuple from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan from redis.commands import CommandsParser, RedisClusterCommands @@ -2130,7 +2130,7 @@ class ClusterPipeline(RedisCluster): return self.execute_command("DEL", names[0]) -def block_pipeline_command(func): +def block_pipeline_command(name: str) -> Callable[..., Any]: """ Prints error because some pipelined commands should be blocked when running in cluster-mode @@ -2138,7 +2138,7 @@ def block_pipeline_command(func): def inner(*args, **kwargs): raise RedisClusterException( - f"ERROR: Calling pipelined function {func.__name__} is blocked " + f"ERROR: Calling pipelined function {name} is blocked " f"when running redis in cluster mode..." ) @@ -2146,39 +2146,81 @@ def block_pipeline_command(func): # Blocked pipeline commands -ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop) -ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush) -ClusterPipeline.client_getname = block_pipeline_command(RedisCluster.client_getname) -ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list) -ClusterPipeline.client_setname = block_pipeline_command(RedisCluster.client_setname) -ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set) -ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize) -ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall) -ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb) -ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys) -ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget) -ClusterPipeline.move = block_pipeline_command(RedisCluster.move) -ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset) -ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx) -ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge) -ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount) -ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping) -ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish) -ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey) -ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename) -ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx) -ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush) -ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan) -ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff) -ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore) -ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter) -ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore) -ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove) -ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) -ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) -ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) -ClusterPipeline.readwrite = block_pipeline_command(RedisCluster.readwrite) -ClusterPipeline.readonly = block_pipeline_command(RedisCluster.readonly) +PIPELINE_BLOCKED_COMMANDS = ( + "BGREWRITEAOF", + "BGSAVE", + "BITOP", + "BRPOPLPUSH", + "CLIENT GETNAME", + "CLIENT KILL", + "CLIENT LIST", + "CLIENT SETNAME", + "CLIENT", + "CONFIG GET", + "CONFIG RESETSTAT", + "CONFIG REWRITE", + "CONFIG SET", + "CONFIG", + "DBSIZE", + "ECHO", + "EVALSHA", + "FLUSHALL", + "FLUSHDB", + "INFO", + "KEYS", + "LASTSAVE", + "MGET", + "MGET NONATOMIC", + "MOVE", + "MSET", + "MSET NONATOMIC", + "MSETNX", + "PFCOUNT", + "PFMERGE", + "PING", + "PUBLISH", + "RANDOMKEY", + "READONLY", + "READWRITE", + "RENAME", + "RENAMENX", + "RPOPLPUSH", + "SAVE", + "SCAN", + "SCRIPT EXISTS", + "SCRIPT FLUSH", + "SCRIPT KILL", + "SCRIPT LOAD", + "SCRIPT", + "SDIFF", + "SDIFFSTORE", + "SENTINEL GET MASTER ADDR BY NAME", + "SENTINEL MASTER", + "SENTINEL MASTERS", + "SENTINEL MONITOR", + "SENTINEL REMOVE", + "SENTINEL SENTINELS", + "SENTINEL SET", + "SENTINEL SLAVES", + "SENTINEL", + "SHUTDOWN", + "SINTER", + "SINTERSTORE", + "SLAVEOF", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "SLOWLOG", + "SMOVE", + "SORT", + "SUNION", + "SUNIONSTORE", + "TIME", +) +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + + setattr(ClusterPipeline, command, block_pipeline_command(command)) class PipelineCommand: diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 123adc8..0c676cb 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -19,7 +19,7 @@ from _pytest.fixtures import FixtureRequest, SubRequest from redis.asyncio import Connection, RedisCluster from redis.asyncio.cluster import ClusterNode, NodesManager from redis.asyncio.parser import CommandsParser -from redis.cluster import PRIMARY, REPLICA, get_node_name +from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -129,6 +129,16 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: return node +def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: + connection = mock.AsyncMock() + connection.is_connected = True + connection.read_response_without_lock.side_effect = exc + while node._free: + node._free.pop() + node._free.append(connection) + return node + + def mock_all_nodes_resp(rc: RedisCluster, response: Any) -> RedisCluster: for node in rc.get_nodes(): mock_node_resp(node, response) @@ -2218,3 +2228,245 @@ class TestNodesManager: async with RedisCluster(startup_nodes=[node_1, node_2]) as rc: assert rc.get_node(host=default_host, port=7001) is not None assert rc.get_node(host=default_host, port=7002) is not None + + +@pytest.mark.onlycluster +class TestClusterPipeline: + """Tests for the ClusterPipeline class.""" + + async def test_blocked_arguments(self, r: RedisCluster) -> None: + """Test handling for blocked pipeline arguments.""" + with pytest.raises(RedisClusterException) as ex: + r.pipeline(transaction=True) + + assert str(ex.value) == "transaction is deprecated in cluster mode" + + with pytest.raises(RedisClusterException) as ex: + r.pipeline(shard_hint=True) + + assert str(ex.value) == "shard_hint is deprecated in cluster mode" + + async def test_blocked_methods(self, r: RedisCluster) -> None: + """Test handling for blocked pipeline commands.""" + pipeline = r.pipeline() + for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + if command == "mset_nonatomic": + continue + + with pytest.raises(RedisClusterException) as exc: + getattr(pipeline, command)() + + assert str(exc.value) == ( + f"ERROR: Calling pipelined function {command} is blocked " + "when running redis in cluster mode..." + ) + + async def test_empty_stack(self, r: RedisCluster) -> None: + """If a pipeline is executed with no commands it should return a empty list.""" + p = r.pipeline() + result = await p.execute() + assert result == [] + + async def test_redis_cluster_pipeline(self, r: RedisCluster) -> None: + """Test that we can use a pipeline with the RedisCluster class""" + result = await ( + r.pipeline() + .set("A", 1) + .get("A") + .hset("K", "F", "V") + .hgetall("K") + .mset_nonatomic({"A": 2, "B": 3}) + .get("A") + .get("B") + .delete("A", "B", "K") + .execute() + ) + assert result == [True, b"1", 1, {b"F": b"V"}, True, True, b"2", b"3", 1, 1, 1] + + async def test_multi_key_operation_with_a_single_slot( + self, r: RedisCluster + ) -> None: + """Test multi key operation with a single slot.""" + pipe = r.pipeline() + pipe.set("a{foo}", 1) + pipe.set("b{foo}", 2) + pipe.set("c{foo}", 3) + pipe.get("a{foo}") + pipe.get("b{foo}") + pipe.get("c{foo}") + + res = await pipe.execute() + assert res == [True, True, True, b"1", b"2", b"3"] + + async def test_multi_key_operation_with_multi_slots(self, r: RedisCluster) -> None: + """Test multi key operation with more than one slot.""" + pipe = r.pipeline() + pipe.set("a{foo}", 1) + pipe.set("b{foo}", 2) + pipe.set("c{foo}", 3) + pipe.set("bar", 4) + pipe.set("bazz", 5) + pipe.get("a{foo}") + pipe.get("b{foo}") + pipe.get("c{foo}") + pipe.get("bar") + pipe.get("bazz") + res = await pipe.execute() + assert res == [True, True, True, True, True, b"1", b"2", b"3", b"4", b"5"] + + async def test_cluster_down_error(self, r: RedisCluster) -> None: + """ + Test that the pipeline retries cluster_error_retry_attempts times before raising + an error. + """ + key = "foo" + node = r.get_node_from_key(key, False) + + parse_response_orig = node.parse_response + with mock.patch.object( + ClusterNode, "parse_response", autospec=True + ) as parse_response_mock: + + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: + if command == "GET": + raise ClusterDownError("error") + return await parse_response_orig(connection, command, **kwargs) + + parse_response_mock.side_effect = parse_response + + # For each ClusterDownError, we launch 4 commands: INFO, CLUSTER SLOTS, + # COMMAND, GET. Before any errors, the first 3 commands are already run + async with r.pipeline() as pipe: + with pytest.raises(ClusterDownError): + await pipe.get(key).execute() + + assert ( + node.parse_response.await_count + == 4 * r.cluster_error_retry_attempts - 3 + ) + + async def test_connection_error_not_raised(self, r: RedisCluster) -> None: + """Test ConnectionError handling with raise_on_error=False.""" + key = "foo" + node = r.get_node_from_key(key, False) + + parse_response_orig = node.parse_response + with mock.patch.object( + ClusterNode, "parse_response", autospec=True + ) as parse_response_mock: + + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: + if command == "GET": + raise ConnectionError("error") + return await parse_response_orig(connection, command, **kwargs) + + parse_response_mock.side_effect = parse_response + + async with r.pipeline() as pipe: + res = await pipe.get(key).get(key).execute(raise_on_error=False) + assert node.parse_response.await_count + assert isinstance(res[0], ConnectionError) + + async def test_connection_error_raised(self, r: RedisCluster) -> None: + """Test ConnectionError handling with raise_on_error=True.""" + key = "foo" + node = r.get_node_from_key(key, False) + + parse_response_orig = node.parse_response + with mock.patch.object( + ClusterNode, "parse_response", autospec=True + ) as parse_response_mock: + + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: + if command == "GET": + raise ConnectionError("error") + return await parse_response_orig(connection, command, **kwargs) + + parse_response_mock.side_effect = parse_response + + async with r.pipeline() as pipe: + with pytest.raises(ConnectionError): + await pipe.get(key).get(key).execute(raise_on_error=True) + + async def test_asking_error(self, r: RedisCluster) -> None: + """Test AskError handling.""" + key = "foo" + first_node = r.get_node_from_key(key, False) + ask_node = None + for node in r.get_nodes(): + if node != first_node: + ask_node = node + break + ask_msg = f"{r.keyslot(key)} {ask_node.host}:{ask_node.port}" + + async with r.pipeline() as pipe: + mock_node_resp_exc(first_node, AskError(ask_msg)) + mock_node_resp(ask_node, "MOCK_OK") + res = await pipe.get(key).execute() + assert first_node._free.pop().read_response_without_lock.await_count + assert ask_node._free.pop().read_response_without_lock.await_count + assert res == ["MOCK_OK"] + + async def test_moved_redirection_on_slave_with_default( + self, r: RedisCluster + ) -> None: + """Test MovedError handling.""" + key = "foo" + await r.set("foo", "bar") + # set read_from_replicas to True + r.read_from_replicas = True + primary = r.get_node_from_key(key, False) + moved_error = f"{r.keyslot(key)} {primary.host}:{primary.port}" + + parse_response_orig = primary.parse_response + with mock.patch.object( + ClusterNode, "parse_response", autospec=True + ) as parse_response_mock: + + async def parse_response( + self, connection: Connection, command: str, **kwargs: Any + ) -> Any: + if ( + command == "GET" + and self.host != primary.host + and self.port != primary.port + ): + raise MovedError(moved_error) + + return await parse_response_orig(connection, command, **kwargs) + + parse_response_mock.side_effect = parse_response + + async with r.pipeline() as readwrite_pipe: + assert r.reinitialize_counter == 0 + readwrite_pipe.get(key).get(key) + assert r.reinitialize_counter == 0 + assert await readwrite_pipe.execute() == [b"bar", b"bar"] + + async def test_readonly_pipeline_from_readonly_client( + self, r: RedisCluster + ) -> None: + """Test that the pipeline uses replicas for read_from_replicas clients.""" + # Create a cluster with reading from replications + r.read_from_replicas = True + key = "bar" + await r.set(key, "foo") + + async with r.pipeline() as pipe: + mock_all_nodes_resp(r, "MOCK_OK") + assert await pipe.get(key).get(key).execute() == ["MOCK_OK", "MOCK_OK"] + slot_nodes = r.nodes_manager.slots_cache[r.keyslot(key)] + executed_on_replica = False + for node in slot_nodes: + if node.server_type == REPLICA: + if node._free.pop().read_response_without_lock.await_count: + executed_on_replica = True + break + assert executed_on_replica diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index 50a1051..dfeb664 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -8,8 +8,8 @@ from .conftest import wait_for_command pytestmark = pytest.mark.asyncio -@pytest.mark.onlynoncluster class TestPipeline: + @pytest.mark.onlynoncluster async def test_pipeline_is_true(self, r): """Ensure pipeline instances are not false-y""" async with r.pipeline() as pipe: @@ -52,7 +52,6 @@ class TestPipeline: await pipe.execute() assert len(pipe) == 0 - @pytest.mark.onlynoncluster async def test_pipeline_no_transaction(self, r): async with r.pipeline(transaction=False) as pipe: pipe.set("a", "a1").set("b", "b1").set("c", "c1") @@ -61,6 +60,7 @@ class TestPipeline: assert await r.get("b") == b"b1" assert await r.get("c") == b"c1" + @pytest.mark.onlynoncluster async def test_pipeline_no_transaction_watch(self, r): await r.set("a", 0) @@ -72,6 +72,7 @@ class TestPipeline: pipe.set("a", int(a) + 1) assert await pipe.execute() == [True] + @pytest.mark.onlynoncluster async def test_pipeline_no_transaction_watch_failure(self, r): await r.set("a", 0) @@ -375,7 +376,7 @@ class TestPipeline: async def test_pipeline_get(self, r): await r.set("a", "a1") async with r.pipeline() as pipe: - await pipe.get("a") + pipe.get("a") assert await pipe.execute() == [b"a1"] @pytest.mark.onlynoncluster -- cgit v1.2.1