diff options
Diffstat (limited to 'tests/test_asyncio/test_cluster.py')
-rw-r--r-- | tests/test_asyncio/test_cluster.py | 110 |
1 files changed, 109 insertions, 1 deletions
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 13e5e26..6d0aba7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -11,7 +11,7 @@ import pytest_asyncio from _pytest.fixtures import FixtureRequest from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster -from redis.asyncio.connection import Connection, SSLConnection +from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff @@ -49,6 +49,71 @@ default_cluster_slots = [ ] +class NodeProxy: + """A class to proxy a node connection to a different port""" + + def __init__(self, addr, redis_addr): + self.addr = addr + self.redis_addr = redis_addr + self.send_event = asyncio.Event() + self.server = None + self.task = None + self.n_connections = 0 + + async def start(self): + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + async def handle(self, reader, writer): + # establish connection to redis + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) + try: + self.n_connections += 1 + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() + + async def aclose(self): + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + await self.server.wait_closed() + + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ): + while True: + data = await reader.read(1000) + if not data: + break + writer.write(data) + await writer.drain() + + +@pytest.fixture +def redis_addr(request): + redis_url = request.config.getoption("--redis-url") + scheme, netloc = urlparse(redis_url)[:2] + assert scheme == "redis" + if ":" in netloc: + host, port = netloc.split(":") + return host, int(port) + else: + return netloc, 6379 + + @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -809,6 +874,49 @@ class TestRedisClusterObj: # Rollback to the old default node r.replace_default_node(curr_default_node) + async def test_address_remap(self, create_redis, redis_addr): + """Test that we can create a rediscluster object with + a host-port remapper and map connections through proxy objects + """ + + # we remap the first n nodes + offset = 1000 + n = 6 + ports = [redis_addr[1] + i for i in range(n)] + + def address_remap(address): + # remap first three nodes to our local proxy + # old = host, port + host, port = address + if int(port) in ports: + host, port = "127.0.0.1", int(port) + offset + # print(f"{old} {host, port}") + return host, port + + # create the proxies + proxies = [ + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) + for port in ports + ] + await asyncio.gather(*[p.start() for p in proxies]) + try: + # create cluster: + r = await create_redis( + cls=RedisCluster, flushdb=False, address_remap=address_remap + ) + try: + assert await r.ping() is True + assert await r.set("byte_string", b"giraffe") + assert await r.get("byte_string") == b"giraffe" + finally: + await r.close() + finally: + await asyncio.gather(*[p.aclose() for p in proxies]) + + # verify that the proxies were indeed used + n_used = sum((1 if p.n_connections else 0) for p in proxies) + assert n_used > 1 + class TestClusterRedisCommands: """ |