summaryrefslogtreecommitdiff
path: root/tests/test_asyncio/test_cluster.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_asyncio/test_cluster.py')
-rw-r--r--tests/test_asyncio/test_cluster.py110
1 files changed, 109 insertions, 1 deletions
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 13e5e26..6d0aba7 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -11,7 +11,7 @@ import pytest_asyncio
from _pytest.fixtures import FixtureRequest
from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
-from redis.asyncio.connection import Connection, SSLConnection
+from redis.asyncio.connection import Connection, SSLConnection, async_timeout
from redis.asyncio.parser import CommandsParser
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
@@ -49,6 +49,71 @@ default_cluster_slots = [
]
+class NodeProxy:
+ """A class to proxy a node connection to a different port"""
+
+ def __init__(self, addr, redis_addr):
+ self.addr = addr
+ self.redis_addr = redis_addr
+ self.send_event = asyncio.Event()
+ self.server = None
+ self.task = None
+ self.n_connections = 0
+
+ async def start(self):
+ # test that we can connect to redis
+ async with async_timeout(2):
+ _, redis_writer = await asyncio.open_connection(*self.redis_addr)
+ redis_writer.close()
+ self.server = await asyncio.start_server(
+ self.handle, *self.addr, reuse_address=True
+ )
+ self.task = asyncio.create_task(self.server.serve_forever())
+
+ async def handle(self, reader, writer):
+ # establish connection to redis
+ redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
+ try:
+ self.n_connections += 1
+ pipe1 = asyncio.create_task(self.pipe(reader, redis_writer))
+ pipe2 = asyncio.create_task(self.pipe(redis_reader, writer))
+ await asyncio.gather(pipe1, pipe2)
+ finally:
+ redis_writer.close()
+
+ async def aclose(self):
+ self.task.cancel()
+ try:
+ await self.task
+ except asyncio.CancelledError:
+ pass
+ await self.server.wait_closed()
+
+ async def pipe(
+ self,
+ reader: asyncio.StreamReader,
+ writer: asyncio.StreamWriter,
+ ):
+ while True:
+ data = await reader.read(1000)
+ if not data:
+ break
+ writer.write(data)
+ await writer.drain()
+
+
+@pytest.fixture
+def redis_addr(request):
+ redis_url = request.config.getoption("--redis-url")
+ scheme, netloc = urlparse(redis_url)[:2]
+ assert scheme == "redis"
+ if ":" in netloc:
+ host, port = netloc.split(":")
+ return host, int(port)
+ else:
+ return netloc, 6379
+
+
@pytest_asyncio.fixture()
async def slowlog(r: RedisCluster) -> None:
"""
@@ -809,6 +874,49 @@ class TestRedisClusterObj:
# Rollback to the old default node
r.replace_default_node(curr_default_node)
+ async def test_address_remap(self, create_redis, redis_addr):
+ """Test that we can create a rediscluster object with
+ a host-port remapper and map connections through proxy objects
+ """
+
+ # we remap the first n nodes
+ offset = 1000
+ n = 6
+ ports = [redis_addr[1] + i for i in range(n)]
+
+ def address_remap(address):
+ # remap first three nodes to our local proxy
+ # old = host, port
+ host, port = address
+ if int(port) in ports:
+ host, port = "127.0.0.1", int(port) + offset
+ # print(f"{old} {host, port}")
+ return host, port
+
+ # create the proxies
+ proxies = [
+ NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
+ for port in ports
+ ]
+ await asyncio.gather(*[p.start() for p in proxies])
+ try:
+ # create cluster:
+ r = await create_redis(
+ cls=RedisCluster, flushdb=False, address_remap=address_remap
+ )
+ try:
+ assert await r.ping() is True
+ assert await r.set("byte_string", b"giraffe")
+ assert await r.get("byte_string") == b"giraffe"
+ finally:
+ await r.close()
+ finally:
+ await asyncio.gather(*[p.aclose() for p in proxies])
+
+ # verify that the proxies were indeed used
+ n_used = sum((1 if p.n_connections else 0) for p in proxies)
+ assert n_used > 1
+
class TestClusterRedisCommands:
"""