summaryrefslogtreecommitdiff
path: root/tests/test_asyncio/test_connection_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_asyncio/test_connection_pool.py')
-rw-r--r--tests/test_asyncio/test_connection_pool.py160
1 files changed, 87 insertions, 73 deletions
diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py
index 6c56558..c8eb918 100644
--- a/tests/test_asyncio/test_connection_pool.py
+++ b/tests/test_asyncio/test_connection_pool.py
@@ -7,6 +7,8 @@ import pytest
if sys.version_info[0:2] == (3, 6):
import pytest as pytest_asyncio
+
+ pytestmark = pytest.mark.asyncio
else:
import pytest_asyncio
@@ -15,10 +17,9 @@ from redis.asyncio.connection import Connection, to_bool
from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt
from .compat import mock
+from .conftest import asynccontextmanager
from .test_pubsub import wait_for_message
-pytestmark = pytest.mark.asyncio
-
@pytest.mark.onlynoncluster
class TestRedisAutoReleaseConnectionPool:
@@ -114,7 +115,8 @@ class DummyConnection(Connection):
class TestConnectionPool:
- def get_pool(
+ @asynccontextmanager
+ async def get_pool(
self,
connection_kwargs=None,
max_connections=None,
@@ -126,71 +128,77 @@ class TestConnectionPool:
max_connections=max_connections,
**connection_kwargs,
)
- return pool
+ try:
+ yield pool
+ finally:
+ await pool.disconnect(inuse_connections=True)
async def test_connection_creation(self):
connection_kwargs = {"foo": "bar", "biz": "baz"}
- pool = self.get_pool(
+ async with self.get_pool(
connection_kwargs=connection_kwargs, connection_class=DummyConnection
- )
- connection = await pool.get_connection("_")
- assert isinstance(connection, DummyConnection)
- assert connection.kwargs == connection_kwargs
+ ) as pool:
+ connection = await pool.get_connection("_")
+ assert isinstance(connection, DummyConnection)
+ assert connection.kwargs == connection_kwargs
async def test_multiple_connections(self, master_host):
connection_kwargs = {"host": master_host}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
- c1 = await pool.get_connection("_")
- c2 = await pool.get_connection("_")
- assert c1 != c2
+ async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
+ c1 = await pool.get_connection("_")
+ c2 = await pool.get_connection("_")
+ assert c1 != c2
async def test_max_connections(self, master_host):
connection_kwargs = {"host": master_host}
- pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs)
- await pool.get_connection("_")
- await pool.get_connection("_")
- with pytest.raises(redis.ConnectionError):
+ async with self.get_pool(
+ max_connections=2, connection_kwargs=connection_kwargs
+ ) as pool:
+ await pool.get_connection("_")
await pool.get_connection("_")
+ with pytest.raises(redis.ConnectionError):
+ await pool.get_connection("_")
async def test_reuse_previously_released_connection(self, master_host):
connection_kwargs = {"host": master_host}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
- c1 = await pool.get_connection("_")
- await pool.release(c1)
- c2 = await pool.get_connection("_")
- assert c1 == c2
+ async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
+ c1 = await pool.get_connection("_")
+ await pool.release(c1)
+ c2 = await pool.get_connection("_")
+ assert c1 == c2
- def test_repr_contains_db_info_tcp(self):
+ async def test_repr_contains_db_info_tcp(self):
connection_kwargs = {
"host": "localhost",
"port": 6379,
"db": 1,
"client_name": "test-client",
}
- pool = self.get_pool(
+ async with self.get_pool(
connection_kwargs=connection_kwargs, connection_class=redis.Connection
- )
- expected = (
- "ConnectionPool<Connection<"
- "host=localhost,port=6379,db=1,client_name=test-client>>"
- )
- assert repr(pool) == expected
+ ) as pool:
+ expected = (
+ "ConnectionPool<Connection<"
+ "host=localhost,port=6379,db=1,client_name=test-client>>"
+ )
+ assert repr(pool) == expected
- def test_repr_contains_db_info_unix(self):
+ async def test_repr_contains_db_info_unix(self):
connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"}
- pool = self.get_pool(
+ async with self.get_pool(
connection_kwargs=connection_kwargs,
connection_class=redis.UnixDomainSocketConnection,
- )
- expected = (
- "ConnectionPool<UnixDomainSocketConnection<"
- "path=/abc,db=1,client_name=test-client>>"
- )
- assert repr(pool) == expected
+ ) as pool:
+ expected = (
+ "ConnectionPool<UnixDomainSocketConnection<"
+ "path=/abc,db=1,client_name=test-client>>"
+ )
+ assert repr(pool) == expected
class TestBlockingConnectionPool:
- def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
+ @asynccontextmanager
+ async def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
connection_kwargs = connection_kwargs or {}
pool = redis.BlockingConnectionPool(
connection_class=DummyConnection,
@@ -198,7 +206,10 @@ class TestBlockingConnectionPool:
timeout=timeout,
**connection_kwargs,
)
- return pool
+ try:
+ yield pool
+ finally:
+ await pool.disconnect(inuse_connections=True)
async def test_connection_creation(self, master_host):
connection_kwargs = {
@@ -207,10 +218,10 @@ class TestBlockingConnectionPool:
"host": master_host[0],
"port": master_host[1],
}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
- connection = await pool.get_connection("_")
- assert isinstance(connection, DummyConnection)
- assert connection.kwargs == connection_kwargs
+ async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
+ connection = await pool.get_connection("_")
+ assert isinstance(connection, DummyConnection)
+ assert connection.kwargs == connection_kwargs
async def test_disconnect(self, master_host):
"""A regression test for #1047"""
@@ -220,30 +231,31 @@ class TestBlockingConnectionPool:
"host": master_host[0],
"port": master_host[1],
}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
- await pool.get_connection("_")
- await pool.disconnect()
+ async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
+ await pool.get_connection("_")
+ await pool.disconnect()
async def test_multiple_connections(self, master_host):
connection_kwargs = {"host": master_host[0], "port": master_host[1]}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
- c1 = await pool.get_connection("_")
- c2 = await pool.get_connection("_")
- assert c1 != c2
+ async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
+ c1 = await pool.get_connection("_")
+ c2 = await pool.get_connection("_")
+ assert c1 != c2
async def test_connection_pool_blocks_until_timeout(self, master_host):
"""When out of connections, block for timeout seconds, then raise"""
connection_kwargs = {"host": master_host}
- pool = self.get_pool(
+ async with self.get_pool(
max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs
- )
- await pool.get_connection("_")
+ ) as pool:
+ c1 = await pool.get_connection("_")
- start = asyncio.get_event_loop().time()
- with pytest.raises(redis.ConnectionError):
- await pool.get_connection("_")
- # we should have waited at least 0.1 seconds
- assert asyncio.get_event_loop().time() - start >= 0.1
+ start = asyncio.get_event_loop().time()
+ with pytest.raises(redis.ConnectionError):
+ await pool.get_connection("_")
+ # we should have waited at least 0.1 seconds
+ assert asyncio.get_event_loop().time() - start >= 0.1
+ await c1.disconnect()
async def test_connection_pool_blocks_until_conn_available(self, master_host):
"""
@@ -251,26 +263,26 @@ class TestBlockingConnectionPool:
to the pool
"""
connection_kwargs = {"host": master_host[0], "port": master_host[1]}
- pool = self.get_pool(
+ async with self.get_pool(
max_connections=1, timeout=2, connection_kwargs=connection_kwargs
- )
- c1 = await pool.get_connection("_")
+ ) as pool:
+ c1 = await pool.get_connection("_")
- async def target():
- await asyncio.sleep(0.1)
- await pool.release(c1)
+ async def target():
+ await asyncio.sleep(0.1)
+ await pool.release(c1)
- start = asyncio.get_event_loop().time()
- await asyncio.gather(target(), pool.get_connection("_"))
- assert asyncio.get_event_loop().time() - start >= 0.1
+ start = asyncio.get_event_loop().time()
+ await asyncio.gather(target(), pool.get_connection("_"))
+ assert asyncio.get_event_loop().time() - start >= 0.1
async def test_reuse_previously_released_connection(self, master_host):
connection_kwargs = {"host": master_host}
- pool = self.get_pool(connection_kwargs=connection_kwargs)
- c1 = await pool.get_connection("_")
- await pool.release(c1)
- c2 = await pool.get_connection("_")
- assert c1 == c2
+ async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
+ c1 = await pool.get_connection("_")
+ await pool.release(c1)
+ c2 = await pool.get_connection("_")
+ assert c1 == c2
def test_repr_contains_db_info_tcp(self):
pool = redis.ConnectionPool(
@@ -689,6 +701,8 @@ class TestHealthCheck:
if r.connection:
await r.get("foo")
next_health_check = r.connection.next_health_check
+ # ensure that the event loop's `time()` advances a bit
+ await asyncio.sleep(0.001)
await r.get("foo")
assert next_health_check < r.connection.next_health_check