diff options
Diffstat (limited to 'tests/test_asyncio/test_connection_pool.py')
-rw-r--r-- | tests/test_asyncio/test_connection_pool.py | 160 |
1 files changed, 87 insertions, 73 deletions
diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 6c56558..c8eb918 100644 --- a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -7,6 +7,8 @@ import pytest if sys.version_info[0:2] == (3, 6): import pytest as pytest_asyncio + + pytestmark = pytest.mark.asyncio else: import pytest_asyncio @@ -15,10 +17,9 @@ from redis.asyncio.connection import Connection, to_bool from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt from .compat import mock +from .conftest import asynccontextmanager from .test_pubsub import wait_for_message -pytestmark = pytest.mark.asyncio - @pytest.mark.onlynoncluster class TestRedisAutoReleaseConnectionPool: @@ -114,7 +115,8 @@ class DummyConnection(Connection): class TestConnectionPool: - def get_pool( + @asynccontextmanager + async def get_pool( self, connection_kwargs=None, max_connections=None, @@ -126,71 +128,77 @@ class TestConnectionPool: max_connections=max_connections, **connection_kwargs, ) - return pool + try: + yield pool + finally: + await pool.disconnect(inuse_connections=True) async def test_connection_creation(self): connection_kwargs = {"foo": "bar", "biz": "baz"} - pool = self.get_pool( + async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection - ) - connection = await pool.get_connection("_") - assert isinstance(connection, DummyConnection) - assert connection.kwargs == connection_kwargs + ) as pool: + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") - assert c1 != c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 async def test_max_connections(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) - await pool.get_connection("_") - await pool.get_connection("_") - with pytest.raises(redis.ConnectionError): + async with self.get_pool( + max_connections=2, connection_kwargs=connection_kwargs + ) as pool: + await pool.get_connection("_") await pool.get_connection("_") + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - await pool.release(c1) - c2 = await pool.get_connection("_") - assert c1 == c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 - def test_repr_contains_db_info_tcp(self): + async def test_repr_contains_db_info_tcp(self): connection_kwargs = { "host": "localhost", "port": 6379, "db": 1, "client_name": "test-client", } - pool = self.get_pool( + async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=redis.Connection - ) - expected = ( - "ConnectionPool<Connection<" - "host=localhost,port=6379,db=1,client_name=test-client>>" - ) - assert repr(pool) == expected + ) as pool: + expected = ( + "ConnectionPool<Connection<" + "host=localhost,port=6379,db=1,client_name=test-client>>" + ) + assert repr(pool) == expected - def test_repr_contains_db_info_unix(self): + async def test_repr_contains_db_info_unix(self): connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} - pool = self.get_pool( + async with self.get_pool( connection_kwargs=connection_kwargs, connection_class=redis.UnixDomainSocketConnection, - ) - expected = ( - "ConnectionPool<UnixDomainSocketConnection<" - "path=/abc,db=1,client_name=test-client>>" - ) - assert repr(pool) == expected + ) as pool: + expected = ( + "ConnectionPool<UnixDomainSocketConnection<" + "path=/abc,db=1,client_name=test-client>>" + ) + assert repr(pool) == expected class TestBlockingConnectionPool: - def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + @asynccontextmanager + async def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): connection_kwargs = connection_kwargs or {} pool = redis.BlockingConnectionPool( connection_class=DummyConnection, @@ -198,7 +206,10 @@ class TestBlockingConnectionPool: timeout=timeout, **connection_kwargs, ) - return pool + try: + yield pool + finally: + await pool.disconnect(inuse_connections=True) async def test_connection_creation(self, master_host): connection_kwargs = { @@ -207,10 +218,10 @@ class TestBlockingConnectionPool: "host": master_host[0], "port": master_host[1], } - pool = self.get_pool(connection_kwargs=connection_kwargs) - connection = await pool.get_connection("_") - assert isinstance(connection, DummyConnection) - assert connection.kwargs == connection_kwargs + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs async def test_disconnect(self, master_host): """A regression test for #1047""" @@ -220,30 +231,31 @@ class TestBlockingConnectionPool: "host": master_host[0], "port": master_host[1], } - pool = self.get_pool(connection_kwargs=connection_kwargs) - await pool.get_connection("_") - await pool.disconnect() + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + await pool.get_connection("_") + await pool.disconnect() async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - c2 = await pool.get_connection("_") - assert c1 != c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 async def test_connection_pool_blocks_until_timeout(self, master_host): """When out of connections, block for timeout seconds, then raise""" connection_kwargs = {"host": master_host} - pool = self.get_pool( + async with self.get_pool( max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs - ) - await pool.get_connection("_") + ) as pool: + c1 = await pool.get_connection("_") - start = asyncio.get_event_loop().time() - with pytest.raises(redis.ConnectionError): - await pool.get_connection("_") - # we should have waited at least 0.1 seconds - assert asyncio.get_event_loop().time() - start >= 0.1 + start = asyncio.get_event_loop().time() + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") + # we should have waited at least 0.1 seconds + assert asyncio.get_event_loop().time() - start >= 0.1 + await c1.disconnect() async def test_connection_pool_blocks_until_conn_available(self, master_host): """ @@ -251,26 +263,26 @@ class TestBlockingConnectionPool: to the pool """ connection_kwargs = {"host": master_host[0], "port": master_host[1]} - pool = self.get_pool( + async with self.get_pool( max_connections=1, timeout=2, connection_kwargs=connection_kwargs - ) - c1 = await pool.get_connection("_") + ) as pool: + c1 = await pool.get_connection("_") - async def target(): - await asyncio.sleep(0.1) - await pool.release(c1) + async def target(): + await asyncio.sleep(0.1) + await pool.release(c1) - start = asyncio.get_event_loop().time() - await asyncio.gather(target(), pool.get_connection("_")) - assert asyncio.get_event_loop().time() - start >= 0.1 + start = asyncio.get_event_loop().time() + await asyncio.gather(target(), pool.get_connection("_")) + assert asyncio.get_event_loop().time() - start >= 0.1 async def test_reuse_previously_released_connection(self, master_host): connection_kwargs = {"host": master_host} - pool = self.get_pool(connection_kwargs=connection_kwargs) - c1 = await pool.get_connection("_") - await pool.release(c1) - c2 = await pool.get_connection("_") - assert c1 == c2 + async with self.get_pool(connection_kwargs=connection_kwargs) as pool: + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 def test_repr_contains_db_info_tcp(self): pool = redis.ConnectionPool( @@ -689,6 +701,8 @@ class TestHealthCheck: if r.connection: await r.get("foo") next_health_check = r.connection.next_health_check + # ensure that the event loop's `time()` advances a bit + await asyncio.sleep(0.001) await r.get("foo") assert next_health_check < r.connection.next_health_check |