From 3748a8b36d5c765f5d21c6d20b041fa1876021ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 7 May 2023 19:33:14 +0000 Subject: Add RedisCluster.remap_host_port, Update tests for CWE 404 (#2706) * Use provided redis address. Bind to IPv4 * Add missing "await" and perform the correct test for pipe eimpty * Wait for a send event, rather than rely on sleep time. Excpect cancel errors. * set delay to 0 except for operation we want to cancel This speeds up the unit tests considerably by eliminating unnecessary delay. * Release resources in test * Fix cluster test to use address_remap and multiple proxies. * Use context manager to manage DelayProxy * Mark failing pipeline tests * lint * Use a common "master_host" test fixture --- tests/conftest.py | 2 +- tests/test_asyncio/conftest.py | 8 - tests/test_asyncio/test_cluster.py | 20 +- tests/test_asyncio/test_connection_pool.py | 10 +- tests/test_asyncio/test_cwe_404.py | 319 +++++++++++++++++++---------- tests/test_asyncio/test_sentinel.py | 2 +- tests/test_cluster.py | 21 +- 7 files changed, 227 insertions(+), 155 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 27dcc74..4cd4c3c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -441,7 +441,7 @@ def mock_cluster_resp_slaves(request, **kwargs): def master_host(request): url = request.config.getoption("--redis-url") parts = urlparse(url) - yield parts.hostname, parts.port + return parts.hostname, (parts.port or 6379) @pytest.fixture() diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6982cc8..121a13b 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,7 +1,6 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager from typing import Union -from urllib.parse import urlparse import pytest import pytest_asyncio @@ -209,13 +208,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs): return _gen_cluster_mock_resp(r, response) -@pytest_asyncio.fixture(scope="session") -def master_host(request): - url = request.config.getoption("--redis-url") - parts = urlparse(url) - return parts.hostname - - async def wait_for_command( client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None ): diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 6d0aba7..2d6099f 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -102,18 +102,6 @@ class NodeProxy: await writer.drain() -@pytest.fixture -def redis_addr(request): - redis_url = request.config.getoption("--redis-url") - scheme, netloc = urlparse(redis_url)[:2] - assert scheme == "redis" - if ":" in netloc: - host, port = netloc.split(":") - return host, int(port) - else: - return netloc, 6379 - - @pytest_asyncio.fixture() async def slowlog(r: RedisCluster) -> None: """ @@ -874,7 +862,7 @@ class TestRedisClusterObj: # Rollback to the old default node r.replace_default_node(curr_default_node) - async def test_address_remap(self, create_redis, redis_addr): + async def test_address_remap(self, create_redis, master_host): """Test that we can create a rediscluster object with a host-port remapper and map connections through proxy objects """ @@ -882,7 +870,8 @@ class TestRedisClusterObj: # we remap the first n nodes offset = 1000 n = 6 - ports = [redis_addr[1] + i for i in range(n)] + hostname, master_port = master_host + ports = [master_port + i for i in range(n)] def address_remap(address): # remap first three nodes to our local proxy @@ -895,8 +884,7 @@ class TestRedisClusterObj: # create the proxies proxies = [ - NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) - for port in ports + NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports ] await asyncio.gather(*[p.start() for p in proxies]) try: diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index d1e52bd..92499e2 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -136,14 +136,14 @@ class TestConnectionPool: assert connection.kwargs == connection_kwargs async def test_multiple_connections(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") c2 = await pool.get_connection("_") assert c1 != c2 async def test_max_connections(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool( max_connections=2, connection_kwargs=connection_kwargs ) as pool: @@ -153,7 +153,7 @@ class TestConnectionPool: await pool.get_connection("_") async def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") await pool.release(c1) @@ -237,7 +237,7 @@ class TestBlockingConnectionPool: async def test_connection_pool_blocks_until_timeout(self, master_host): """When out of connections, block for timeout seconds, then raise""" - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs ) as pool: @@ -270,7 +270,7 @@ class TestBlockingConnectionPool: assert asyncio.get_running_loop().time() - start >= 0.1 async def test_reuse_previously_released_connection(self, master_host): - connection_kwargs = {"host": master_host} + connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: c1 = await pool.get_connection("_") await pool.release(c1) diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index dc62df6..d3a0666 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -1,147 +1,252 @@ import asyncio -import sys +import contextlib import pytest from redis.asyncio import Redis from redis.asyncio.cluster import RedisCluster - - -async def pipe( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name="" -): - while True: - data = await reader.read(1000) - if not data: - break - await asyncio.sleep(delay) - writer.write(data) - await writer.drain() +from redis.asyncio.connection import async_timeout class DelayProxy: - def __init__(self, addr, redis_addr, delay: float): + def __init__(self, addr, redis_addr, delay: float = 0.0): self.addr = addr self.redis_addr = redis_addr self.delay = delay + self.send_event = asyncio.Event() + self.server = None + self.task = None + + async def __aenter__(self): + await self.start() + return self + + async def __aexit__(self, *args): + await self.stop() async def start(self): - self.server = await asyncio.start_server(self.handle, *self.addr) - self.ROUTINE = asyncio.create_task(self.server.serve_forever()) + # test that we can connect to redis + async with async_timeout(2): + _, redis_writer = await asyncio.open_connection(*self.redis_addr) + redis_writer.close() + self.server = await asyncio.start_server( + self.handle, *self.addr, reuse_address=True + ) + self.task = asyncio.create_task(self.server.serve_forever()) + + @contextlib.contextmanager + def set_delay(self, delay: float = 0.0): + """ + Allow to override the delay for parts of tests which aren't time dependent, + to speed up execution. + """ + old_delay = self.delay + self.delay = delay + try: + yield + finally: + self.delay = old_delay async def handle(self, reader, writer): # establish connection to redis redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) - pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:")) - pipe2 = asyncio.create_task( - pipe(redis_reader, writer, self.delay, "from redis:") - ) - await asyncio.gather(pipe1, pipe2) + try: + pipe1 = asyncio.create_task( + self.pipe(reader, redis_writer, "to redis:", self.send_event) + ) + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:")) + await asyncio.gather(pipe1, pipe2) + finally: + redis_writer.close() async def stop(self): # clean up enough so that we can reuse the looper - self.ROUTINE.cancel() + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass loop = self.server.get_loop() await loop.shutdown_asyncgens() + async def pipe( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + name="", + event: asyncio.Event = None, + ): + while True: + data = await reader.read(1000) + if not data: + break + # print(f"{name} read {len(data)} delay {self.delay}") + if event: + event.set() + await asyncio.sleep(self.delay) + writer.write(data) + await writer.drain() + @pytest.mark.onlynoncluster @pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) -async def test_standalone(delay): +async def test_standalone(delay, master_host): # create a tcp socket proxy that relays data to Redis and back, # inserting 0.1 seconds of delay - dp = DelayProxy( - addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2 - ) - await dp.start() - - for b in [True, False]: - # note that we connect to proxy, rather than to Redis directly - async with Redis(host="localhost", port=5380, single_connection_client=b) as r: - - await r.set("foo", "foo") - await r.set("bar", "bar") - - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(delay) - t.cancel() - try: - await t - sys.stderr.write("try again, we did not cancel the task in time\n") - except asyncio.CancelledError: - sys.stderr.write( - "canceled task, connection is left open with unread response\n" - ) - - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" - - await dp.stop() - - + async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: + + for b in [True, False]: + # note that we connect to proxy, rather than to Redis directly + async with Redis( + host="127.0.0.1", port=5380, single_connection_client=b + ) as r: + + await r.set("foo", "foo") + await r.set("bar", "bar") + + async def op(r): + with dp.set_delay(delay * 2): + return await r.get( + "foo" + ) # <-- this is the operation we want to cancel + + dp.send_event.clear() + t = asyncio.create_task(op(r)) + # Wait until the task has sent, and then some, to make sure it has + # settled on the read. + await dp.send_event.wait() + await asyncio.sleep(0.01) # a little extra time for prudence + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # make sure that our previous request, cancelled while waiting for + # a repsponse, didn't leave the connection open andin a bad state + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" + + +@pytest.mark.xfail(reason="cancel does not cause disconnect") @pytest.mark.onlynoncluster @pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2]) -async def test_standalone_pipeline(delay): - dp = DelayProxy( - addr=("localhost", 5380), redis_addr=("localhost", 6379), delay=delay * 2 - ) - await dp.start() - for b in [True, False]: - async with Redis(host="localhost", port=5380, single_connection_client=b) as r: - await r.set("foo", "foo") - await r.set("bar", "bar") - - pipe = r.pipeline() - - pipe2 = r.pipeline() - pipe2.get("bar") - pipe2.ping() - pipe2.get("foo") - - t = asyncio.create_task(pipe.get("foo").execute()) - await asyncio.sleep(delay) - t.cancel() - - pipe.get("bar") - pipe.ping() - pipe.get("foo") - pipe.reset() - - assert await pipe.execute() is None - - # validating that the pipeline can be used as it could previously - pipe.get("bar") - pipe.ping() - pipe.get("foo") - assert await pipe.execute() == [b"bar", True, b"foo"] - assert await pipe2.execute() == [b"bar", True, b"foo"] - - await dp.stop() +async def test_standalone_pipeline(delay, master_host): + async with DelayProxy(addr=("127.0.0.1", 5380), redis_addr=master_host) as dp: + for b in [True, False]: + async with Redis( + host="127.0.0.1", port=5380, single_connection_client=b + ) as r: + await r.set("foo", "foo") + await r.set("bar", "bar") + + pipe = r.pipeline() + + pipe2 = r.pipeline() + pipe2.get("bar") + pipe2.ping() + pipe2.get("foo") + + async def op(pipe): + with dp.set_delay(delay * 2): + return await pipe.get( + "foo" + ).execute() # <-- this is the operation we want to cancel + + dp.send_event.clear() + t = asyncio.create_task(op(pipe)) + # wait until task has settled on the read + await dp.send_event.wait() + await asyncio.sleep(0.01) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + # we have now cancelled the pieline in the middle of a request, + # make sure that the connection is still usable + pipe.get("bar") + pipe.ping() + pipe.get("foo") + await pipe.reset() + + # check that the pipeline is empty after reset + assert await pipe.execute() == [] + + # validating that the pipeline can be used as it could previously + pipe.get("bar") + pipe.ping() + pipe.get("foo") + assert await pipe.execute() == [b"bar", True, b"foo"] + assert await pipe2.execute() == [b"bar", True, b"foo"] @pytest.mark.onlycluster -async def test_cluster(request): +async def test_cluster(master_host): + + delay = 0.1 + cluster_port = 6372 + remap_base = 7372 + n_nodes = 6 + hostname, _ = master_host + + def remap(address): + host, port = address + return host, remap_base + port - cluster_port + + proxies = [] + for i in range(n_nodes): + port = cluster_port + i + remapped = remap_base + i + forward_addr = hostname, port + proxy = DelayProxy(addr=("127.0.0.1", remapped), redis_addr=forward_addr) + proxies.append(proxy) + + def all_clear(): + for p in proxies: + p.send_event.clear() + + async def wait_for_send(): + asyncio.wait( + [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED + ) - dp = DelayProxy(addr=("localhost", 5381), redis_addr=("localhost", 6372), delay=0.1) - await dp.start() + @contextlib.contextmanager + def set_delay(delay: float): + with contextlib.ExitStack() as stack: + for p in proxies: + stack.enter_context(p.set_delay(delay)) + yield + + async with contextlib.AsyncExitStack() as stack: + for p in proxies: + await stack.enter_async_context(p) + + with contextlib.closing( + RedisCluster.from_url( + f"redis://127.0.0.1:{remap_base}", address_remap=remap + ) + ) as r: + await r.initialize() + await r.set("foo", "foo") + await r.set("bar", "bar") - r = RedisCluster.from_url("redis://localhost:5381") - await r.initialize() - await r.set("foo", "foo") - await r.set("bar", "bar") + async def op(r): + with set_delay(delay): + return await r.get("foo") - t = asyncio.create_task(r.get("foo")) - await asyncio.sleep(0.050) - t.cancel() - try: - await t - except asyncio.CancelledError: - pytest.fail("connection is left open with unread response") + all_clear() + t = asyncio.create_task(op(r)) + # Wait for whichever DelayProxy gets the request first + await wait_for_send() + await asyncio.sleep(0.01) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t - assert await r.get("bar") == b"bar" - assert await r.ping() - assert await r.get("foo") == b"foo" + # try a number of requests to excercise all the connections + async def doit(): + assert await r.get("bar") == b"bar" + assert await r.ping() + assert await r.get("foo") == b"foo" - await dp.stop() + await asyncio.gather(*[doit() for _ in range(10)]) diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py index 5a0533b..7866056 100644 --- a/tests/test_asyncio/test_sentinel.py +++ b/tests/test_asyncio/test_sentinel.py @@ -15,7 +15,7 @@ from redis.asyncio.sentinel import ( @pytest_asyncio.fixture(scope="module") def master_ip(master_host): - yield socket.gethostbyname(master_host) + yield socket.gethostbyname(master_host[0]) class SentinelTestClient: diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 1f037c9..8371cc5 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -8,7 +8,6 @@ import warnings from queue import LifoQueue, Queue from time import sleep from unittest.mock import DEFAULT, Mock, call, patch -from urllib.parse import urlparse import pytest @@ -125,18 +124,6 @@ class NodeProxy: self.server.shutdown() -@pytest.fixture -def redis_addr(request): - redis_url = request.config.getoption("--redis-url") - scheme, netloc = urlparse(redis_url)[:2] - assert scheme == "redis" - if ":" in netloc: - host, port = netloc.split(":") - return host, int(port) - else: - return netloc, 6379 - - @pytest.fixture() def slowlog(request, r): """ @@ -907,7 +894,7 @@ class TestRedisClusterObj: assert "myself" not in nodes.get(curr_default_node.name).get("flags") assert r.get_default_node() != curr_default_node - def test_address_remap(self, request, redis_addr): + def test_address_remap(self, request, master_host): """Test that we can create a rediscluster object with a host-port remapper and map connections through proxy objects """ @@ -915,7 +902,8 @@ class TestRedisClusterObj: # we remap the first n nodes offset = 1000 n = 6 - ports = [redis_addr[1] + i for i in range(n)] + hostname, master_port = master_host + ports = [master_port + i for i in range(n)] def address_remap(address): # remap first three nodes to our local proxy @@ -928,8 +916,7 @@ class TestRedisClusterObj: # create the proxies proxies = [ - NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) - for port in ports + NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports ] for p in proxies: p.start() -- cgit v1.2.1