summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorAndrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com>2022-02-22 05:29:55 -0500
committerGitHub <noreply@github.com>2022-02-22 12:29:55 +0200
commitd56baeb683fc1935cfa343fa2eeb0fa9bd955283 (patch)
tree47357a74bf1d1428cfbcf0d8b2c781f1f971cf77 /tests
parente3c989d93e914e6502bd5a72f15ded49a135c5be (diff)
downloadredis-py-d56baeb683fc1935cfa343fa2eeb0fa9bd955283.tar.gz
Add Async Support (#1899)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com> Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py79
-rw-r--r--tests/test_asyncio/__init__.py0
-rw-r--r--tests/test_asyncio/compat.py6
-rw-r--r--tests/test_asyncio/conftest.py205
-rw-r--r--tests/test_asyncio/test_commands.py3
-rw-r--r--tests/test_asyncio/test_connection.py64
-rw-r--r--tests/test_asyncio/test_connection_pool.py884
-rw-r--r--tests/test_asyncio/test_encoding.py126
-rw-r--r--tests/test_asyncio/test_lock.py242
-rw-r--r--tests/test_asyncio/test_monitor.py67
-rw-r--r--tests/test_asyncio/test_pipeline.py409
-rw-r--r--tests/test_asyncio/test_pubsub.py660
-rw-r--r--tests/test_asyncio/test_retry.py70
-rw-r--r--tests/test_asyncio/test_scripting.py159
-rw-r--r--tests/test_asyncio/test_sentinel.py249
15 files changed, 3216 insertions, 7 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index d9de876..2534ca0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,7 @@
+import argparse
import random
import time
+from typing import Callable, TypeVar
from unittest.mock import Mock
from urllib.parse import urlparse
@@ -21,6 +23,54 @@ default_redis_unstable_url = "redis://localhost:6378"
default_redis_ssl_url = "rediss://localhost:6666"
default_cluster_nodes = 6
+_DecoratedTest = TypeVar("_DecoratedTest", bound="Callable")
+_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest]
+
+
+# Taken from python3.9
+class BooleanOptionalAction(argparse.Action):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=None,
+ type=None,
+ choices=None,
+ required=False,
+ help=None,
+ metavar=None,
+ ):
+
+ _option_strings = []
+ for option_string in option_strings:
+ _option_strings.append(option_string)
+
+ if option_string.startswith("--"):
+ option_string = "--no-" + option_string[2:]
+ _option_strings.append(option_string)
+
+ if help is not None and default is not None:
+ help += f" (default: {default})"
+
+ super().__init__(
+ option_strings=_option_strings,
+ dest=dest,
+ nargs=0,
+ default=default,
+ type=type,
+ choices=choices,
+ required=required,
+ help=help,
+ metavar=metavar,
+ )
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ if option_string in self.option_strings:
+ setattr(namespace, self.dest, not option_string.startswith("--no-"))
+
+ def format_usage(self):
+ return " | ".join(self.option_strings)
+
def pytest_addoption(parser):
parser.addoption(
@@ -62,6 +112,9 @@ def pytest_addoption(parser):
help="Redis unstable (latest version) connection string "
"defaults to %(default)s`",
)
+ parser.addoption(
+ "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop"
+ )
def _get_info(redis_url):
@@ -101,6 +154,18 @@ def pytest_sessionstart(session):
cluster_nodes = session.config.getoption("--redis-cluster-nodes")
wait_for_cluster_creation(redis_url, cluster_nodes)
+ use_uvloop = session.config.getoption("--uvloop")
+
+ if use_uvloop:
+ try:
+ import uvloop
+
+ uvloop.install()
+ except ImportError as e:
+ raise RuntimeError(
+ "Can not import uvloop, make sure it is installed"
+ ) from e
+
def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60):
"""
@@ -133,19 +198,19 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60):
)
-def skip_if_server_version_lt(min_version):
+def skip_if_server_version_lt(min_version: str) -> _TestDecorator:
redis_version = REDIS_INFO["version"]
check = Version(redis_version) < Version(min_version)
return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}")
-def skip_if_server_version_gte(min_version):
+def skip_if_server_version_gte(min_version: str) -> _TestDecorator:
redis_version = REDIS_INFO["version"]
check = Version(redis_version) >= Version(min_version)
return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}")
-def skip_unless_arch_bits(arch_bits):
+def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator:
return pytest.mark.skipif(
REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit"
)
@@ -169,17 +234,17 @@ def skip_ifmodversion_lt(min_version: str, module_name: str):
raise AttributeError(f"No redis module named {module_name}")
-def skip_if_redis_enterprise():
+def skip_if_redis_enterprise() -> _TestDecorator:
check = REDIS_INFO["enterprise"] is True
return pytest.mark.skipif(check, reason="Redis enterprise")
-def skip_ifnot_redis_enterprise():
+def skip_ifnot_redis_enterprise() -> _TestDecorator:
check = REDIS_INFO["enterprise"] is False
return pytest.mark.skipif(check, reason="Not running in redis enterprise")
-def skip_if_nocryptography():
+def skip_if_nocryptography() -> _TestDecorator:
try:
import cryptography # noqa
@@ -188,7 +253,7 @@ def skip_if_nocryptography():
return pytest.mark.skipif(True, reason="No cryptography dependency")
-def skip_if_cryptography():
+def skip_if_cryptography() -> _TestDecorator:
try:
import cryptography # noqa
diff --git a/tests/test_asyncio/__init__.py b/tests/test_asyncio/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/test_asyncio/__init__.py
diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py
new file mode 100644
index 0000000..ced4974
--- /dev/null
+++ b/tests/test_asyncio/compat.py
@@ -0,0 +1,6 @@
+from unittest import mock
+
+try:
+ mock.AsyncMock
+except AttributeError:
+ import mock
diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py
new file mode 100644
index 0000000..0e9c73e
--- /dev/null
+++ b/tests/test_asyncio/conftest.py
@@ -0,0 +1,205 @@
+import asyncio
+import random
+import sys
+from typing import Union
+from urllib.parse import urlparse
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+import pytest
+from packaging.version import Version
+
+import redis.asyncio as redis
+from redis.asyncio.client import Monitor
+from redis.asyncio.connection import (
+ HIREDIS_AVAILABLE,
+ HiredisParser,
+ PythonParser,
+ parse_url,
+)
+from tests.conftest import REDIS_INFO
+
+from .compat import mock
+
+
+async def _get_info(redis_url):
+ client = redis.Redis.from_url(redis_url)
+ info = await client.info()
+ await client.connection_pool.disconnect()
+ return info
+
+
+@pytest_asyncio.fixture(
+ params=[
+ (True, PythonParser),
+ (False, PythonParser),
+ pytest.param(
+ (True, HiredisParser),
+ marks=pytest.mark.skipif(
+ not HIREDIS_AVAILABLE, reason="hiredis is not installed"
+ ),
+ ),
+ pytest.param(
+ (False, HiredisParser),
+ marks=pytest.mark.skipif(
+ not HIREDIS_AVAILABLE, reason="hiredis is not installed"
+ ),
+ ),
+ ],
+ ids=[
+ "single-python-parser",
+ "pool-python-parser",
+ "single-hiredis",
+ "pool-hiredis",
+ ],
+)
+def create_redis(request, event_loop: asyncio.BaseEventLoop):
+ """Wrapper around redis.create_redis."""
+ single_connection, parser_cls = request.param
+
+ async def f(url: str = request.config.getoption("--redis-url"), **kwargs):
+ single = kwargs.pop("single_connection_client", False) or single_connection
+ parser_class = kwargs.pop("parser_class", None) or parser_cls
+ url_options = parse_url(url)
+ url_options.update(kwargs)
+ pool = redis.ConnectionPool(parser_class=parser_class, **url_options)
+ client: redis.Redis = redis.Redis(connection_pool=pool)
+ if single:
+ client = client.client()
+ await client.initialize()
+
+ def teardown():
+ async def ateardown():
+ if "username" in kwargs:
+ return
+ try:
+ await client.flushdb()
+ except redis.ConnectionError:
+ # handle cases where a test disconnected a client
+ # just manually retry the flushdb
+ await client.flushdb()
+ await client.close()
+ await client.connection_pool.disconnect()
+
+ if event_loop.is_running():
+ event_loop.create_task(ateardown())
+ else:
+ event_loop.run_until_complete(ateardown())
+
+ request.addfinalizer(teardown)
+
+ return client
+
+ return f
+
+
+@pytest_asyncio.fixture()
+async def r(create_redis):
+ yield await create_redis()
+
+
+@pytest_asyncio.fixture()
+async def r2(create_redis):
+ """A second client for tests that need multiple"""
+ yield await create_redis()
+
+
+def _gen_cluster_mock_resp(r, response):
+ connection = mock.AsyncMock()
+ connection.read_response.return_value = response
+ r.connection = connection
+ return r
+
+
+@pytest_asyncio.fixture()
+async def mock_cluster_resp_ok(create_redis, **kwargs):
+ r = await create_redis(**kwargs)
+ return _gen_cluster_mock_resp(r, "OK")
+
+
+@pytest_asyncio.fixture()
+async def mock_cluster_resp_int(create_redis, **kwargs):
+ r = await create_redis(**kwargs)
+ return _gen_cluster_mock_resp(r, "2")
+
+
+@pytest_asyncio.fixture()
+async def mock_cluster_resp_info(create_redis, **kwargs):
+ r = await create_redis(**kwargs)
+ response = (
+ "cluster_state:ok\r\ncluster_slots_assigned:16384\r\n"
+ "cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n"
+ "cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n"
+ "cluster_size:3\r\ncluster_current_epoch:7\r\n"
+ "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n"
+ "cluster_stats_messages_received:105653\r\n"
+ )
+ return _gen_cluster_mock_resp(r, response)
+
+
+@pytest_asyncio.fixture()
+async def mock_cluster_resp_nodes(create_redis, **kwargs):
+ r = await create_redis(**kwargs)
+ response = (
+ "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 "
+ "slave aa90da731f673a99617dfe930306549a09f83a6b 0 "
+ "1447836263059 5 connected\n"
+ "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 "
+ "master - 0 1447836264065 0 connected\n"
+ "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 "
+ "myself,master - 0 0 2 connected 5461-10922\n"
+ "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 "
+ "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 "
+ "1447836262556 3 connected\n"
+ "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 "
+ "master - 0 1447836262555 7 connected 0-5460\n"
+ "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 "
+ "master - 0 1447836263562 3 connected 10923-16383\n"
+ "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 "
+ "master,fail - 1447829446956 1447829444948 1 disconnected\n"
+ )
+ return _gen_cluster_mock_resp(r, response)
+
+
+@pytest_asyncio.fixture()
+async def mock_cluster_resp_slaves(create_redis, **kwargs):
+ r = await create_redis(**kwargs)
+ response = (
+ "['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 "
+ "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 "
+ "1447836789290 3 connected']"
+ )
+ return _gen_cluster_mock_resp(r, response)
+
+
+@pytest_asyncio.fixture(scope="session")
+def master_host(request):
+ url = request.config.getoption("--redis-url")
+ parts = urlparse(url)
+ yield parts.hostname
+
+
+async def wait_for_command(
+ client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None
+):
+ # issue a command with a key name that's local to this process.
+ # if we find a command with our key before the command we're waiting
+ # for, something went wrong
+ if key is None:
+ # generate key
+ redis_version = REDIS_INFO["version"]
+ if Version(redis_version) >= Version("5.0.0"):
+ id_str = str(client.client_id())
+ else:
+ id_str = f"{random.randrange(2 ** 32):08x}"
+ key = f"__REDIS-PY-{id_str}__"
+ await client.get(key)
+ while True:
+ monitor_response = await monitor.next_command()
+ if command in monitor_response["command"]:
+ return monitor_response
+ if key in monitor_response["command"]:
+ return None
diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py
new file mode 100644
index 0000000..12855c9
--- /dev/null
+++ b/tests/test_asyncio/test_commands.py
@@ -0,0 +1,3 @@
+"""
+Tests async overrides of commands from their mixins
+"""
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
new file mode 100644
index 0000000..46abec0
--- /dev/null
+++ b/tests/test_asyncio/test_connection.py
@@ -0,0 +1,64 @@
+import asyncio
+import types
+
+import pytest
+
+from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection
+from redis.exceptions import InvalidResponse
+from redis.utils import HIREDIS_AVAILABLE
+from tests.conftest import skip_if_server_version_lt
+
+from .compat import mock
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.onlynoncluster
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+async def test_invalid_response(create_redis):
+ r = await create_redis(single_connection_client=True)
+
+ raw = b"x"
+ readline_mock = mock.AsyncMock(return_value=raw)
+
+ parser: "PythonParser" = r.connection._parser
+ with mock.patch.object(parser._buffer, "readline", readline_mock):
+ with pytest.raises(InvalidResponse) as cm:
+ await parser.read_response()
+ assert str(cm.value) == f"Protocol Error: {raw!r}"
+
+
+@skip_if_server_version_lt("4.0.0")
+@pytest.mark.redismod
+@pytest.mark.onlynoncluster
+async def test_loading_external_modules(modclient):
+ def inner():
+ pass
+
+ modclient.load_external_module("myfuncname", inner)
+ assert getattr(modclient, "myfuncname") == inner
+ assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType)
+
+ # and call it
+ from redis.commands import RedisModuleCommands
+
+ j = RedisModuleCommands.json
+ modclient.load_external_module("sometestfuncname", j)
+
+ # d = {'hello': 'world!'}
+ # mod = j(modclient)
+ # mod.set("fookey", ".", d)
+ # assert mod.get('fookey') == d
+
+
+@pytest.mark.onlynoncluster
+async def test_socket_param_regression(r):
+ """A regression test for issue #1060"""
+ conn = UnixDomainSocketConnection()
+ _ = await conn.disconnect() is True
+
+
+@pytest.mark.onlynoncluster
+async def test_can_run_concurrent_commands(r):
+ assert await r.ping() is True
+ assert all(await asyncio.gather(*(r.ping() for _ in range(10))))
diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py
new file mode 100644
index 0000000..f9dfefd
--- /dev/null
+++ b/tests/test_asyncio/test_connection_pool.py
@@ -0,0 +1,884 @@
+import asyncio
+import os
+import re
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+import redis.asyncio as redis
+from redis.asyncio.connection import Connection, to_bool
+from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt
+
+from .compat import mock
+from .test_pubsub import wait_for_message
+
+pytestmark = pytest.mark.asyncio
+
+
+class TestRedisAutoReleaseConnectionPool:
+ @pytest_asyncio.fixture
+ async def r(self, create_redis) -> redis.Redis:
+ """This is necessary since r and r2 create ConnectionPools behind the scenes"""
+ r = await create_redis()
+ r.auto_close_connection_pool = True
+ yield r
+
+ @staticmethod
+ def get_total_connected_connections(pool):
+ return len(pool._available_connections) + len(pool._in_use_connections)
+
+ @staticmethod
+ async def create_two_conn(r: redis.Redis):
+ if not r.single_connection_client: # Single already initialized connection
+ r.connection = await r.connection_pool.get_connection("_")
+ return await r.connection_pool.get_connection("_")
+
+ @staticmethod
+ def has_no_connected_connections(pool: redis.ConnectionPool):
+ return not any(
+ x.is_connected
+ for x in pool._available_connections + list(pool._in_use_connections)
+ )
+
+ async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis):
+ new_conn = await self.create_two_conn(r)
+ assert new_conn != r.connection
+ assert self.get_total_connected_connections(r.connection_pool) == 2
+ await r.close()
+ assert self.has_no_connected_connections(r.connection_pool)
+
+ async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis):
+ assert r2.auto_close_connection_pool is False, (
+ "The connection pool should not be disconnected as a manually created "
+ "connection pool was passed in in conftest.py"
+ )
+ new_conn = await self.create_two_conn(r2)
+ assert self.get_total_connected_connections(r2.connection_pool) == 2
+ await r2.close()
+ assert r2.connection_pool._in_use_connections == {new_conn}
+ assert new_conn.is_connected
+ assert len(r2.connection_pool._available_connections) == 1
+ assert r2.connection_pool._available_connections[0].is_connected
+
+ async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis):
+ assert r.auto_close_connection_pool is True, "This is from the class fixture"
+ await self.create_two_conn(r)
+ await r.close()
+ assert self.get_total_connected_connections(r.connection_pool) == 2, (
+ "The connection pool should not be disconnected as a manually created "
+ "connection pool was passed in in conftest.py"
+ )
+ assert self.has_no_connected_connections(r.connection_pool)
+
+ @pytest.mark.parametrize("auto_close_conn_pool", [True, False])
+ async def test_close_override(self, r: redis.Redis, auto_close_conn_pool):
+ r.auto_close_connection_pool = auto_close_conn_pool
+ await self.create_two_conn(r)
+ await r.close(close_connection_pool=True)
+ assert self.has_no_connected_connections(r.connection_pool)
+
+ @pytest.mark.parametrize("auto_close_conn_pool", [True, False])
+ async def test_negate_auto_close_client_pool(
+ self, r: redis.Redis, auto_close_conn_pool
+ ):
+ r.auto_close_connection_pool = auto_close_conn_pool
+ new_conn = await self.create_two_conn(r)
+ await r.close(close_connection_pool=False)
+ assert not self.has_no_connected_connections(r.connection_pool)
+ assert r.connection_pool._in_use_connections == {new_conn}
+ assert r.connection_pool._available_connections[0].is_connected
+ assert self.get_total_connected_connections(r.connection_pool) == 2
+
+
+class DummyConnection(Connection):
+ description_format = "DummyConnection<>"
+
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.pid = os.getpid()
+
+ async def connect(self):
+ pass
+
+ async def disconnect(self):
+ pass
+
+ async def can_read(self, timeout: float = 0):
+ return False
+
+
+@pytest.mark.onlynoncluster
+class TestConnectionPool:
+ def get_pool(
+ self,
+ connection_kwargs=None,
+ max_connections=None,
+ connection_class=redis.Connection,
+ ):
+ connection_kwargs = connection_kwargs or {}
+ pool = redis.ConnectionPool(
+ connection_class=connection_class,
+ max_connections=max_connections,
+ **connection_kwargs,
+ )
+ return pool
+
+ async def test_connection_creation(self):
+ connection_kwargs = {"foo": "bar", "biz": "baz"}
+ pool = self.get_pool(
+ connection_kwargs=connection_kwargs, connection_class=DummyConnection
+ )
+ connection = await pool.get_connection("_")
+ assert isinstance(connection, DummyConnection)
+ assert connection.kwargs == connection_kwargs
+
+ async def test_multiple_connections(self, master_host):
+ connection_kwargs = {"host": master_host}
+ pool = self.get_pool(connection_kwargs=connection_kwargs)
+ c1 = await pool.get_connection("_")
+ c2 = await pool.get_connection("_")
+ assert c1 != c2
+
+ async def test_max_connections(self, master_host):
+ connection_kwargs = {"host": master_host}
+ pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs)
+ await pool.get_connection("_")
+ await pool.get_connection("_")
+ with pytest.raises(redis.ConnectionError):
+ await pool.get_connection("_")
+
+ async def test_reuse_previously_released_connection(self, master_host):
+ connection_kwargs = {"host": master_host}
+ pool = self.get_pool(connection_kwargs=connection_kwargs)
+ c1 = await pool.get_connection("_")
+ await pool.release(c1)
+ c2 = await pool.get_connection("_")
+ assert c1 == c2
+
+ def test_repr_contains_db_info_tcp(self):
+ connection_kwargs = {
+ "host": "localhost",
+ "port": 6379,
+ "db": 1,
+ "client_name": "test-client",
+ }
+ pool = self.get_pool(
+ connection_kwargs=connection_kwargs, connection_class=redis.Connection
+ )
+ expected = (
+ "ConnectionPool<Connection<"
+ "host=localhost,port=6379,db=1,client_name=test-client>>"
+ )
+ assert repr(pool) == expected
+
+ def test_repr_contains_db_info_unix(self):
+ connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"}
+ pool = self.get_pool(
+ connection_kwargs=connection_kwargs,
+ connection_class=redis.UnixDomainSocketConnection,
+ )
+ expected = (
+ "ConnectionPool<UnixDomainSocketConnection<"
+ "path=/abc,db=1,client_name=test-client>>"
+ )
+ assert repr(pool) == expected
+
+
+@pytest.mark.onlynoncluster
+class TestBlockingConnectionPool:
+ def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
+ connection_kwargs = connection_kwargs or {}
+ pool = redis.BlockingConnectionPool(
+ connection_class=DummyConnection,
+ max_connections=max_connections,
+ timeout=timeout,
+ **connection_kwargs,
+ )
+ return pool
+
+ async def test_connection_creation(self, master_host):
+ connection_kwargs = {
+ "foo": "bar",
+ "biz": "baz",
+ "host": master_host[0],
+ "port": master_host[1],
+ }
+ pool = self.get_pool(connection_kwargs=connection_kwargs)
+ connection = await pool.get_connection("_")
+ assert isinstance(connection, DummyConnection)
+ assert connection.kwargs == connection_kwargs
+
+ async def test_disconnect(self, master_host):
+ """A regression test for #1047"""
+ connection_kwargs = {
+ "foo": "bar",
+ "biz": "baz",
+ "host": master_host[0],
+ "port": master_host[1],
+ }
+ pool = self.get_pool(connection_kwargs=connection_kwargs)
+ await pool.get_connection("_")
+ await pool.disconnect()
+
+ async def test_multiple_connections(self, master_host):
+ connection_kwargs = {"host": master_host[0], "port": master_host[1]}
+ pool = self.get_pool(connection_kwargs=connection_kwargs)
+ c1 = await pool.get_connection("_")
+ c2 = await pool.get_connection("_")
+ assert c1 != c2
+
+ async def test_connection_pool_blocks_until_timeout(self, master_host):
+ """When out of connections, block for timeout seconds, then raise"""
+ connection_kwargs = {"host": master_host}
+ pool = self.get_pool(
+ max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs
+ )
+ await pool.get_connection("_")
+
+ start = asyncio.get_event_loop().time()
+ with pytest.raises(redis.ConnectionError):
+ await pool.get_connection("_")
+ # we should have waited at least 0.1 seconds
+ assert asyncio.get_event_loop().time() - start >= 0.1
+
+ async def test_connection_pool_blocks_until_conn_available(self, master_host):
+ """
+ When out of connections, block until another connection is released
+ to the pool
+ """
+ connection_kwargs = {"host": master_host[0], "port": master_host[1]}
+ pool = self.get_pool(
+ max_connections=1, timeout=2, connection_kwargs=connection_kwargs
+ )
+ c1 = await pool.get_connection("_")
+
+ async def target():
+ await asyncio.sleep(0.1)
+ await pool.release(c1)
+
+ start = asyncio.get_event_loop().time()
+ await asyncio.gather(target(), pool.get_connection("_"))
+ assert asyncio.get_event_loop().time() - start >= 0.1
+
+ async def test_reuse_previously_released_connection(self, master_host):
+ connection_kwargs = {"host": master_host}
+ pool = self.get_pool(connection_kwargs=connection_kwargs)
+ c1 = await pool.get_connection("_")
+ await pool.release(c1)
+ c2 = await pool.get_connection("_")
+ assert c1 == c2
+
+ def test_repr_contains_db_info_tcp(self):
+ pool = redis.ConnectionPool(
+ host="localhost", port=6379, client_name="test-client"
+ )
+ expected = (
+ "ConnectionPool<Connection<"
+ "host=localhost,port=6379,db=0,client_name=test-client>>"
+ )
+ assert repr(pool) == expected
+
+ def test_repr_contains_db_info_unix(self):
+ pool = redis.ConnectionPool(
+ connection_class=redis.UnixDomainSocketConnection,
+ path="abc",
+ client_name="test-client",
+ )
+ expected = (
+ "ConnectionPool<UnixDomainSocketConnection<"
+ "path=abc,db=0,client_name=test-client>>"
+ )
+ assert repr(pool) == expected
+
+
+@pytest.mark.onlynoncluster
+class TestConnectionPoolURLParsing:
+ def test_hostname(self):
+ pool = redis.ConnectionPool.from_url("redis://my.host")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "my.host",
+ }
+
+ def test_quoted_hostname(self):
+ pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "my / host +=+",
+ }
+
+ def test_port(self):
+ pool = redis.ConnectionPool.from_url("redis://localhost:6380")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "port": 6380,
+ }
+
+ @skip_if_server_version_lt("6.0.0")
+ def test_username(self):
+ pool = redis.ConnectionPool.from_url("redis://myuser:@localhost")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "username": "myuser",
+ }
+
+ @skip_if_server_version_lt("6.0.0")
+ def test_quoted_username(self):
+ pool = redis.ConnectionPool.from_url(
+ "redis://%2Fmyuser%2F%2B name%3D%24+:@localhost"
+ )
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "username": "/myuser/+ name=$+",
+ }
+
+ def test_password(self):
+ pool = redis.ConnectionPool.from_url("redis://:mypassword@localhost")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "password": "mypassword",
+ }
+
+ def test_quoted_password(self):
+ pool = redis.ConnectionPool.from_url(
+ "redis://:%2Fmypass%2F%2B word%3D%24+@localhost"
+ )
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "password": "/mypass/+ word=$+",
+ }
+
+ @skip_if_server_version_lt("6.0.0")
+ def test_username_and_password(self):
+ pool = redis.ConnectionPool.from_url("redis://myuser:mypass@localhost")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "username": "myuser",
+ "password": "mypass",
+ }
+
+ def test_db_as_argument(self):
+ pool = redis.ConnectionPool.from_url("redis://localhost", db=1)
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "db": 1,
+ }
+
+ def test_db_in_path(self):
+ pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1)
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "db": 2,
+ }
+
+ def test_db_in_querystring(self):
+ pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1)
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "db": 3,
+ }
+
+ def test_extra_typed_querystring_options(self):
+ pool = redis.ConnectionPool.from_url(
+ "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10"
+ "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10"
+ )
+
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {
+ "host": "localhost",
+ "db": 2,
+ "socket_timeout": 20.0,
+ "socket_connect_timeout": 10.0,
+ "retry_on_timeout": True,
+ }
+ assert pool.max_connections == 10
+
+ def test_boolean_parsing(self):
+ for expected, value in (
+ (None, None),
+ (None, ""),
+ (False, 0),
+ (False, "0"),
+ (False, "f"),
+ (False, "F"),
+ (False, "False"),
+ (False, "n"),
+ (False, "N"),
+ (False, "No"),
+ (True, 1),
+ (True, "1"),
+ (True, "y"),
+ (True, "Y"),
+ (True, "Yes"),
+ ):
+ assert expected is to_bool(value)
+
+ def test_client_name_in_querystring(self):
+ pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client")
+ assert pool.connection_kwargs["client_name"] == "test-client"
+
+ def test_invalid_extra_typed_querystring_options(self):
+ with pytest.raises(ValueError):
+ redis.ConnectionPool.from_url(
+ "redis://localhost/2?socket_timeout=_&" "socket_connect_timeout=abc"
+ )
+
+ def test_extra_querystring_options(self):
+ pool = redis.ConnectionPool.from_url("redis://localhost?a=1&b=2")
+ assert pool.connection_class == redis.Connection
+ assert pool.connection_kwargs == {"host": "localhost", "a": "1", "b": "2"}
+
+ def test_calling_from_subclass_returns_correct_instance(self):
+ pool = redis.BlockingConnectionPool.from_url("redis://localhost")
+ assert isinstance(pool, redis.BlockingConnectionPool)
+
+ def test_client_creates_connection_pool(self):
+ r = redis.Redis.from_url("redis://myhost")
+ assert r.connection_pool.connection_class == redis.Connection
+ assert r.connection_pool.connection_kwargs == {
+ "host": "myhost",
+ }
+
+ def test_invalid_scheme_raises_error(self):
+ with pytest.raises(ValueError) as cm:
+ redis.ConnectionPool.from_url("localhost")
+ assert str(cm.value) == (
+ "Redis URL must specify one of the following schemes "
+ "(redis://, rediss://, unix://)"
+ )
+
+
+@pytest.mark.onlynoncluster
+class TestConnectionPoolUnixSocketURLParsing:
+ def test_defaults(self):
+ pool = redis.ConnectionPool.from_url("unix:///socket")
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ }
+
+ @skip_if_server_version_lt("6.0.0")
+ def test_username(self):
+ pool = redis.ConnectionPool.from_url("unix://myuser:@/socket")
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ "username": "myuser",
+ }
+
+ @skip_if_server_version_lt("6.0.0")
+ def test_quoted_username(self):
+ pool = redis.ConnectionPool.from_url(
+ "unix://%2Fmyuser%2F%2B name%3D%24+:@/socket"
+ )
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ "username": "/myuser/+ name=$+",
+ }
+
+ def test_password(self):
+ pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket")
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ "password": "mypassword",
+ }
+
+ def test_quoted_password(self):
+ pool = redis.ConnectionPool.from_url(
+ "unix://:%2Fmypass%2F%2B word%3D%24+@/socket"
+ )
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ "password": "/mypass/+ word=$+",
+ }
+
+ def test_quoted_path(self):
+ pool = redis.ConnectionPool.from_url(
+ "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket"
+ )
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/my/path/to/../+_+=$ocket",
+ "password": "mypassword",
+ }
+
+ def test_db_as_argument(self):
+ pool = redis.ConnectionPool.from_url("unix:///socket", db=1)
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ "db": 1,
+ }
+
+ def test_db_in_querystring(self):
+ pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1)
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {
+ "path": "/socket",
+ "db": 2,
+ }
+
+ def test_client_name_in_querystring(self):
+ pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client")
+ assert pool.connection_kwargs["client_name"] == "test-client"
+
+ def test_extra_querystring_options(self):
+ pool = redis.ConnectionPool.from_url("unix:///socket?a=1&b=2")
+ assert pool.connection_class == redis.UnixDomainSocketConnection
+ assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"}
+
+
+@pytest.mark.onlynoncluster
+class TestSSLConnectionURLParsing:
+ def test_host(self):
+ pool = redis.ConnectionPool.from_url("rediss://my.host")
+ assert pool.connection_class == redis.SSLConnection
+ assert pool.connection_kwargs == {
+ "host": "my.host",
+ }
+
+ def test_cert_reqs_options(self):
+ import ssl
+
+ class DummyConnectionPool(redis.ConnectionPool):
+ def get_connection(self, *args, **kwargs):
+ return self.make_connection()
+
+ pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none")
+ assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE
+
+ pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional")
+ assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL
+
+ pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required")
+ assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED
+
+ pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False")
+ assert pool.get_connection("_").check_hostname is False
+
+ pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True")
+ assert pool.get_connection("_").check_hostname is True
+
+
+@pytest.mark.onlynoncluster
+class TestConnection:
+ async def test_on_connect_error(self):
+ """
+ An error in Connection.on_connect should disconnect from the server
+ see for details: https://github.com/andymccurdy/redis-py/issues/368
+ """
+ # this assumes the Redis server being tested against doesn't have
+ # 9999 databases ;)
+ bad_connection = redis.Redis(db=9999)
+ # an error should be raised on connect
+ with pytest.raises(redis.RedisError):
+ await bad_connection.info()
+ pool = bad_connection.connection_pool
+ assert len(pool._available_connections) == 1
+ assert not pool._available_connections[0]._reader
+
+ @pytest.mark.onlynoncluster
+ @skip_if_server_version_lt("2.8.8")
+ @skip_if_redis_enterprise()
+ async def test_busy_loading_disconnects_socket(self, r):
+ """
+ If Redis raises a LOADING error, the connection should be
+ disconnected and a BusyLoadingError raised
+ """
+ with pytest.raises(redis.BusyLoadingError):
+ await r.execute_command("DEBUG", "ERROR", "LOADING fake message")
+ if r.connection:
+ assert not r.connection._reader
+
+ @pytest.mark.onlynoncluster
+ @skip_if_server_version_lt("2.8.8")
+ @skip_if_redis_enterprise()
+ async def test_busy_loading_from_pipeline_immediate_command(self, r):
+ """
+ BusyLoadingErrors should raise from Pipelines that execute a
+ command immediately, like WATCH does.
+ """
+ pipe = r.pipeline()
+ with pytest.raises(redis.BusyLoadingError):
+ await pipe.immediate_execute_command(
+ "DEBUG", "ERROR", "LOADING fake message"
+ )
+ pool = r.connection_pool
+ assert not pipe.connection
+ assert len(pool._available_connections) == 1
+ assert not pool._available_connections[0]._reader
+
+ @pytest.mark.onlynoncluster
+ @skip_if_server_version_lt("2.8.8")
+ @skip_if_redis_enterprise()
+ async def test_busy_loading_from_pipeline(self, r):
+ """
+ BusyLoadingErrors should be raised from a pipeline execution
+ regardless of the raise_on_error flag.
+ """
+ pipe = r.pipeline()
+ pipe.execute_command("DEBUG", "ERROR", "LOADING fake message")
+ with pytest.raises(redis.BusyLoadingError):
+ await pipe.execute()
+ pool = r.connection_pool
+ assert not pipe.connection
+ assert len(pool._available_connections) == 1
+ assert not pool._available_connections[0]._reader
+
+ @skip_if_server_version_lt("2.8.8")
+ @skip_if_redis_enterprise()
+ async def test_read_only_error(self, r):
+ """READONLY errors get turned in ReadOnlyError exceptions"""
+ with pytest.raises(redis.ReadOnlyError):
+ await r.execute_command("DEBUG", "ERROR", "READONLY blah blah")
+
+ def test_connect_from_url_tcp(self):
+ connection = redis.Redis.from_url("redis://localhost")
+ pool = connection.connection_pool
+
+ assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == (
+ "ConnectionPool",
+ "Connection",
+ "host=localhost,port=6379,db=0",
+ )
+
+ def test_connect_from_url_unix(self):
+ connection = redis.Redis.from_url("unix:///path/to/socket")
+ pool = connection.connection_pool
+
+ assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == (
+ "ConnectionPool",
+ "UnixDomainSocketConnection",
+ "path=/path/to/socket,db=0",
+ )
+
+ @skip_if_redis_enterprise()
+ async def test_connect_no_auth_supplied_when_required(self, r):
+ """
+ AuthenticationError should be raised when the server requires a
+ password but one isn't supplied.
+ """
+ with pytest.raises(redis.AuthenticationError):
+ await r.execute_command(
+ "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set"
+ )
+
+ @skip_if_redis_enterprise()
+ async def test_connect_invalid_password_supplied(self, r):
+ """AuthenticationError should be raised when sending the wrong password"""
+ with pytest.raises(redis.AuthenticationError):
+ await r.execute_command("DEBUG", "ERROR", "ERR invalid password")
+
+
+@pytest.mark.onlynoncluster
+class TestMultiConnectionClient:
+ @pytest_asyncio.fixture()
+ async def r(self, create_redis, server):
+ redis = await create_redis(single_connection_client=False)
+ yield redis
+ await redis.flushall()
+
+
+@pytest.mark.onlynoncluster
+class TestHealthCheck:
+ interval = 60
+
+ @pytest_asyncio.fixture()
+ async def r(self, create_redis):
+ redis = await create_redis(health_check_interval=self.interval)
+ yield redis
+ await redis.flushall()
+
+ def assert_interval_advanced(self, connection):
+ diff = connection.next_health_check - asyncio.get_event_loop().time()
+ assert self.interval > diff > (self.interval - 1)
+
+ async def test_health_check_runs(self, r):
+ if r.connection:
+ r.connection.next_health_check = asyncio.get_event_loop().time() - 1
+ await r.connection.check_health()
+ self.assert_interval_advanced(r.connection)
+
+ async def test_arbitrary_command_invokes_health_check(self, r):
+ # invoke a command to make sure the connection is entirely setup
+ if r.connection:
+ await r.get("foo")
+ r.connection.next_health_check = asyncio.get_event_loop().time()
+ with mock.patch.object(
+ r.connection, "send_command", wraps=r.connection.send_command
+ ) as m:
+ await r.get("foo")
+ m.assert_called_with("PING", check_health=False)
+
+ self.assert_interval_advanced(r.connection)
+
+ async def test_arbitrary_command_advances_next_health_check(self, r):
+ if r.connection:
+ await r.get("foo")
+ next_health_check = r.connection.next_health_check
+ await r.get("foo")
+ assert next_health_check < r.connection.next_health_check
+
+ async def test_health_check_not_invoked_within_interval(self, r):
+ if r.connection:
+ await r.get("foo")
+ with mock.patch.object(
+ r.connection, "send_command", wraps=r.connection.send_command
+ ) as m:
+ await r.get("foo")
+ ping_call_spec = (("PING",), {"check_health": False})
+ assert ping_call_spec not in m.call_args_list
+
+ async def test_health_check_in_pipeline(self, r):
+ async with r.pipeline(transaction=False) as pipe:
+ pipe.connection = await pipe.connection_pool.get_connection("_")
+ pipe.connection.next_health_check = 0
+ with mock.patch.object(
+ pipe.connection, "send_command", wraps=pipe.connection.send_command
+ ) as m:
+ responses = await pipe.set("foo", "bar").get("foo").execute()
+ m.assert_any_call("PING", check_health=False)
+ assert responses == [True, b"bar"]
+
+ async def test_health_check_in_transaction(self, r):
+ async with r.pipeline(transaction=True) as pipe:
+ pipe.connection = await pipe.connection_pool.get_connection("_")
+ pipe.connection.next_health_check = 0
+ with mock.patch.object(
+ pipe.connection, "send_command", wraps=pipe.connection.send_command
+ ) as m:
+ responses = await pipe.set("foo", "bar").get("foo").execute()
+ m.assert_any_call("PING", check_health=False)
+ assert responses == [True, b"bar"]
+
+ async def test_health_check_in_watched_pipeline(self, r):
+ await r.set("foo", "bar")
+ async with r.pipeline(transaction=False) as pipe:
+ pipe.connection = await pipe.connection_pool.get_connection("_")
+ pipe.connection.next_health_check = 0
+ with mock.patch.object(
+ pipe.connection, "send_command", wraps=pipe.connection.send_command
+ ) as m:
+ await pipe.watch("foo")
+ # the health check should be called when watching
+ m.assert_called_with("PING", check_health=False)
+ self.assert_interval_advanced(pipe.connection)
+ assert await pipe.get("foo") == b"bar"
+
+ # reset the mock to clear the call list and schedule another
+ # health check
+ m.reset_mock()
+ pipe.connection.next_health_check = 0
+
+ pipe.multi()
+ responses = await pipe.set("foo", "not-bar").get("foo").execute()
+ assert responses == [True, b"not-bar"]
+ m.assert_any_call("PING", check_health=False)
+
+ async def test_health_check_in_pubsub_before_subscribe(self, r):
+ """A health check happens before the first [p]subscribe"""
+ p = r.pubsub()
+ p.connection = await p.connection_pool.get_connection("_")
+ p.connection.next_health_check = 0
+ with mock.patch.object(
+ p.connection, "send_command", wraps=p.connection.send_command
+ ) as m:
+ assert not p.subscribed
+ await p.subscribe("foo")
+ # the connection is not yet in pubsub mode, so the normal
+ # ping/pong within connection.send_command should check
+ # the health of the connection
+ m.assert_any_call("PING", check_health=False)
+ self.assert_interval_advanced(p.connection)
+
+ subscribe_message = await wait_for_message(p)
+ assert subscribe_message["type"] == "subscribe"
+
+ async def test_health_check_in_pubsub_after_subscribed(self, r):
+ """
+ Pubsub can handle a new subscribe when it's time to check the
+ connection health
+ """
+ p = r.pubsub()
+ p.connection = await p.connection_pool.get_connection("_")
+ p.connection.next_health_check = 0
+ with mock.patch.object(
+ p.connection, "send_command", wraps=p.connection.send_command
+ ) as m:
+ await p.subscribe("foo")
+ subscribe_message = await wait_for_message(p)
+ assert subscribe_message["type"] == "subscribe"
+ self.assert_interval_advanced(p.connection)
+ # because we weren't subscribed when sending the subscribe
+ # message to 'foo', the connection's standard check_health ran
+ # prior to subscribing.
+ m.assert_any_call("PING", check_health=False)
+
+ p.connection.next_health_check = 0
+ m.reset_mock()
+
+ await p.subscribe("bar")
+ # the second subscribe issues exactly one command (the subscribe)
+ # and the health check is not invoked
+ m.assert_called_once_with("SUBSCRIBE", "bar", check_health=False)
+
+ # since no message has been read since the health check was
+ # reset, it should still be 0
+ assert p.connection.next_health_check == 0
+
+ subscribe_message = await wait_for_message(p)
+ assert subscribe_message["type"] == "subscribe"
+ assert await wait_for_message(p) is None
+ # now that the connection is subscribed, the pubsub health
+ # check should have taken over and include the HEALTH_CHECK_MESSAGE
+ m.assert_any_call("PING", p.HEALTH_CHECK_MESSAGE, check_health=False)
+ self.assert_interval_advanced(p.connection)
+
+ async def test_health_check_in_pubsub_poll(self, r):
+ """
+ Polling a pubsub connection that's subscribed will regularly
+ check the connection's health.
+ """
+ p = r.pubsub()
+ p.connection = await p.connection_pool.get_connection("_")
+ with mock.patch.object(
+ p.connection, "send_command", wraps=p.connection.send_command
+ ) as m:
+ await p.subscribe("foo")
+ subscribe_message = await wait_for_message(p)
+ assert subscribe_message["type"] == "subscribe"
+ self.assert_interval_advanced(p.connection)
+
+ # polling the connection before the health check interval
+ # doesn't result in another health check
+ m.reset_mock()
+ next_health_check = p.connection.next_health_check
+ assert await wait_for_message(p) is None
+ assert p.connection.next_health_check == next_health_check
+ m.assert_not_called()
+
+ # reset the health check and poll again
+ # we should not receive a pong message, but the next_health_check
+ # should be advanced
+ p.connection.next_health_check = 0
+ assert await wait_for_message(p) is None
+ m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False)
+ self.assert_interval_advanced(p.connection)
diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py
new file mode 100644
index 0000000..da29837
--- /dev/null
+++ b/tests/test_asyncio/test_encoding.py
@@ -0,0 +1,126 @@
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+import redis.asyncio as redis
+from redis.exceptions import DataError
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.onlynoncluster
+class TestEncoding:
+ @pytest_asyncio.fixture()
+ async def r(self, create_redis):
+ redis = await create_redis(decode_responses=True)
+ yield redis
+ await redis.flushall()
+
+ @pytest_asyncio.fixture()
+ async def r_no_decode(self, create_redis):
+ redis = await create_redis(decode_responses=False)
+ yield redis
+ await redis.flushall()
+
+ async def test_simple_encoding(self, r_no_decode: redis.Redis):
+ unicode_string = chr(3456) + "abcd" + chr(3421)
+ await r_no_decode.set("unicode-string", unicode_string.encode("utf-8"))
+ cached_val = await r_no_decode.get("unicode-string")
+ assert isinstance(cached_val, bytes)
+ assert unicode_string == cached_val.decode("utf-8")
+
+ async def test_simple_encoding_and_decoding(self, r: redis.Redis):
+ unicode_string = chr(3456) + "abcd" + chr(3421)
+ await r.set("unicode-string", unicode_string)
+ cached_val = await r.get("unicode-string")
+ assert isinstance(cached_val, str)
+ assert unicode_string == cached_val
+
+ async def test_memoryview_encoding(self, r_no_decode: redis.Redis):
+ unicode_string = chr(3456) + "abcd" + chr(3421)
+ unicode_string_view = memoryview(unicode_string.encode("utf-8"))
+ await r_no_decode.set("unicode-string-memoryview", unicode_string_view)
+ cached_val = await r_no_decode.get("unicode-string-memoryview")
+ # The cached value won't be a memoryview because it's a copy from Redis
+ assert isinstance(cached_val, bytes)
+ assert unicode_string == cached_val.decode("utf-8")
+
+ async def test_memoryview_encoding_and_decoding(self, r: redis.Redis):
+ unicode_string = chr(3456) + "abcd" + chr(3421)
+ unicode_string_view = memoryview(unicode_string.encode("utf-8"))
+ await r.set("unicode-string-memoryview", unicode_string_view)
+ cached_val = await r.get("unicode-string-memoryview")
+ assert isinstance(cached_val, str)
+ assert unicode_string == cached_val
+
+ async def test_list_encoding(self, r: redis.Redis):
+ unicode_string = chr(3456) + "abcd" + chr(3421)
+ result = [unicode_string, unicode_string, unicode_string]
+ await r.rpush("a", *result)
+ assert await r.lrange("a", 0, -1) == result
+
+
+@pytest.mark.onlynoncluster
+class TestEncodingErrors:
+ async def test_ignore(self, create_redis):
+ r = await create_redis(
+ decode_responses=True,
+ encoding_errors="ignore",
+ )
+ await r.set("a", b"foo\xff")
+ assert await r.get("a") == "foo"
+
+ async def test_replace(self, create_redis):
+ r = await create_redis(
+ decode_responses=True,
+ encoding_errors="replace",
+ )
+ await r.set("a", b"foo\xff")
+ assert await r.get("a") == "foo\ufffd"
+
+
+@pytest.mark.onlynoncluster
+class TestMemoryviewsAreNotPacked:
+ async def test_memoryviews_are_not_packed(self, r):
+ arg = memoryview(b"some_arg")
+ arg_list = ["SOME_COMMAND", arg]
+ c = r.connection or await r.connection_pool.get_connection("_")
+ cmd = c.pack_command(*arg_list)
+ assert cmd[1] is arg
+ cmds = c.pack_commands([arg_list, arg_list])
+ assert cmds[1] is arg
+ assert cmds[3] is arg
+
+
+class TestCommandsAreNotEncoded:
+ @pytest_asyncio.fixture()
+ async def r(self, create_redis):
+ redis = await create_redis(encoding="utf-16")
+ yield redis
+ await redis.flushall()
+
+ async def test_basic_command(self, r: redis.Redis):
+ await r.set("hello", "world")
+
+
+class TestInvalidUserInput:
+ async def test_boolean_fails(self, r: redis.Redis):
+ with pytest.raises(DataError):
+ await r.set("a", True) # type: ignore
+
+ async def test_none_fails(self, r: redis.Redis):
+ with pytest.raises(DataError):
+ await r.set("a", None) # type: ignore
+
+ async def test_user_type_fails(self, r: redis.Redis):
+ class Foo:
+ def __str__(self):
+ return "Foo"
+
+ with pytest.raises(DataError):
+ await r.set("a", Foo()) # type: ignore
diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py
new file mode 100644
index 0000000..c496718
--- /dev/null
+++ b/tests/test_asyncio/test_lock.py
@@ -0,0 +1,242 @@
+import asyncio
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+from redis.asyncio.lock import Lock
+from redis.exceptions import LockError, LockNotOwnedError
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.onlynoncluster
+class TestLock:
+ @pytest_asyncio.fixture()
+ async def r_decoded(self, create_redis):
+ redis = await create_redis(decode_responses=True)
+ yield redis
+ await redis.flushall()
+
+ def get_lock(self, redis, *args, **kwargs):
+ kwargs["lock_class"] = Lock
+ return redis.lock(*args, **kwargs)
+
+ async def test_lock(self, r):
+ lock = self.get_lock(r, "foo")
+ assert await lock.acquire(blocking=False)
+ assert await r.get("foo") == lock.local.token
+ assert await r.ttl("foo") == -1
+ await lock.release()
+ assert await r.get("foo") is None
+
+ async def test_lock_token(self, r):
+ lock = self.get_lock(r, "foo")
+ await self._test_lock_token(r, lock)
+
+ async def test_lock_token_thread_local_false(self, r):
+ lock = self.get_lock(r, "foo", thread_local=False)
+ await self._test_lock_token(r, lock)
+
+ async def _test_lock_token(self, r, lock):
+ assert await lock.acquire(blocking=False, token="test")
+ assert await r.get("foo") == b"test"
+ assert lock.local.token == b"test"
+ assert await r.ttl("foo") == -1
+ await lock.release()
+ assert await r.get("foo") is None
+ assert lock.local.token is None
+
+ async def test_locked(self, r):
+ lock = self.get_lock(r, "foo")
+ assert await lock.locked() is False
+ await lock.acquire(blocking=False)
+ assert await lock.locked() is True
+ await lock.release()
+ assert await lock.locked() is False
+
+ async def _test_owned(self, client):
+ lock = self.get_lock(client, "foo")
+ assert await lock.owned() is False
+ await lock.acquire(blocking=False)
+ assert await lock.owned() is True
+ await lock.release()
+ assert await lock.owned() is False
+
+ lock2 = self.get_lock(client, "foo")
+ assert await lock.owned() is False
+ assert await lock2.owned() is False
+ await lock2.acquire(blocking=False)
+ assert await lock.owned() is False
+ assert await lock2.owned() is True
+ await lock2.release()
+ assert await lock.owned() is False
+ assert await lock2.owned() is False
+
+ async def test_owned(self, r):
+ await self._test_owned(r)
+
+ async def test_owned_with_decoded_responses(self, r_decoded):
+ await self._test_owned(r_decoded)
+
+ async def test_competing_locks(self, r):
+ lock1 = self.get_lock(r, "foo")
+ lock2 = self.get_lock(r, "foo")
+ assert await lock1.acquire(blocking=False)
+ assert not await lock2.acquire(blocking=False)
+ await lock1.release()
+ assert await lock2.acquire(blocking=False)
+ assert not await lock1.acquire(blocking=False)
+ await lock2.release()
+
+ async def test_timeout(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert 8 < (await r.ttl("foo")) <= 10
+ await lock.release()
+
+ async def test_float_timeout(self, r):
+ lock = self.get_lock(r, "foo", timeout=9.5)
+ assert await lock.acquire(blocking=False)
+ assert 8 < (await r.pttl("foo")) <= 9500
+ await lock.release()
+
+ async def test_blocking_timeout(self, r, event_loop):
+ lock1 = self.get_lock(r, "foo")
+ assert await lock1.acquire(blocking=False)
+ bt = 0.2
+ sleep = 0.05
+ lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
+ start = event_loop.time()
+ assert not await lock2.acquire()
+ # The elapsed duration should be less than the total blocking_timeout
+ assert bt > (event_loop.time() - start) > bt - sleep
+ await lock1.release()
+
+ async def test_context_manager(self, r):
+ # blocking_timeout prevents a deadlock if the lock can't be acquired
+ # for some reason
+ async with self.get_lock(r, "foo", blocking_timeout=0.2) as lock:
+ assert await r.get("foo") == lock.local.token
+ assert await r.get("foo") is None
+
+ async def test_context_manager_raises_when_locked_not_acquired(self, r):
+ await r.set("foo", "bar")
+ with pytest.raises(LockError):
+ async with self.get_lock(r, "foo", blocking_timeout=0.1):
+ pass
+
+ async def test_high_sleep_small_blocking_timeout(self, r):
+ lock1 = self.get_lock(r, "foo")
+ assert await lock1.acquire(blocking=False)
+ sleep = 60
+ bt = 1
+ lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt)
+ start = asyncio.get_event_loop().time()
+ assert not await lock2.acquire()
+ # the elapsed time is less than the blocking_timeout as the lock is
+ # unattainable given the sleep/blocking_timeout configuration
+ assert bt > (asyncio.get_event_loop().time() - start)
+ await lock1.release()
+
+ async def test_releasing_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ with pytest.raises(LockError):
+ await lock.release()
+
+ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ await lock.acquire(blocking=False)
+ # manually change the token
+ await r.set("foo", "a")
+ with pytest.raises(LockNotOwnedError):
+ await lock.release()
+ # even though we errored, the token is still cleared
+ assert lock.local.token is None
+
+ async def test_extend_lock(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ assert await lock.extend(10)
+ assert 16000 < (await r.pttl("foo")) <= 20000
+ await lock.release()
+
+ async def test_extend_lock_replace_ttl(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ assert await lock.extend(10, replace_ttl=True)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ await lock.release()
+
+ async def test_extend_lock_float(self, r):
+ lock = self.get_lock(r, "foo", timeout=10.0)
+ assert await lock.acquire(blocking=False)
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ assert await lock.extend(10.0)
+ assert 16000 < (await r.pttl("foo")) <= 20000
+ await lock.release()
+
+ async def test_extending_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ with pytest.raises(LockError):
+ await lock.extend(10)
+
+ async def test_extending_lock_with_no_timeout_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ assert await lock.acquire(blocking=False)
+ with pytest.raises(LockError):
+ await lock.extend(10)
+ await lock.release()
+
+ async def test_extending_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ await r.set("foo", "a")
+ with pytest.raises(LockNotOwnedError):
+ await lock.extend(10)
+
+ async def test_reacquire_lock(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ assert await r.pexpire("foo", 5000)
+ assert await r.pttl("foo") <= 5000
+ assert await lock.reacquire()
+ assert 8000 < (await r.pttl("foo")) <= 10000
+ await lock.release()
+
+ async def test_reacquiring_unlocked_lock_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ with pytest.raises(LockError):
+ await lock.reacquire()
+
+ async def test_reacquiring_lock_with_no_timeout_raises_error(self, r):
+ lock = self.get_lock(r, "foo")
+ assert await lock.acquire(blocking=False)
+ with pytest.raises(LockError):
+ await lock.reacquire()
+ await lock.release()
+
+ async def test_reacquiring_lock_no_longer_owned_raises_error(self, r):
+ lock = self.get_lock(r, "foo", timeout=10)
+ assert await lock.acquire(blocking=False)
+ await r.set("foo", "a")
+ with pytest.raises(LockNotOwnedError):
+ await lock.reacquire()
+
+
+@pytest.mark.onlynoncluster
+class TestLockClassSelection:
+ def test_lock_class_argument(self, r):
+ class MyLock:
+ def __init__(self, *args, **kwargs):
+
+ pass
+
+ lock = r.lock("foo", lock_class=MyLock)
+ assert type(lock) == MyLock
diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py
new file mode 100644
index 0000000..783ba26
--- /dev/null
+++ b/tests/test_asyncio/test_monitor.py
@@ -0,0 +1,67 @@
+import pytest
+
+from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise
+
+from .conftest import wait_for_command
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.onlynoncluster
+class TestMonitor:
+ async def test_wait_command_not_found(self, r):
+ """Make sure the wait_for_command func works when command is not found"""
+ async with r.monitor() as m:
+ response = await wait_for_command(r, m, "nothing")
+ assert response is None
+
+ async def test_response_values(self, r):
+ db = r.connection_pool.connection_kwargs.get("db", 0)
+ async with r.monitor() as m:
+ await r.ping()
+ response = await wait_for_command(r, m, "PING")
+ assert isinstance(response["time"], float)
+ assert response["db"] == db
+ assert response["client_type"] in ("tcp", "unix")
+ assert isinstance(response["client_address"], str)
+ assert isinstance(response["client_port"], str)
+ assert response["command"] == "PING"
+
+ async def test_command_with_quoted_key(self, r):
+ async with r.monitor() as m:
+ await r.get('foo"bar')
+ response = await wait_for_command(r, m, 'GET foo"bar')
+ assert response["command"] == 'GET foo"bar'
+
+ async def test_command_with_binary_data(self, r):
+ async with r.monitor() as m:
+ byte_string = b"foo\x92"
+ await r.get(byte_string)
+ response = await wait_for_command(r, m, "GET foo\\x92")
+ assert response["command"] == "GET foo\\x92"
+
+ async def test_command_with_escaped_data(self, r):
+ async with r.monitor() as m:
+ byte_string = b"foo\\x92"
+ await r.get(byte_string)
+ response = await wait_for_command(r, m, "GET foo\\\\x92")
+ assert response["command"] == "GET foo\\\\x92"
+
+ @skip_if_redis_enterprise()
+ async def test_lua_script(self, r):
+ async with r.monitor() as m:
+ script = 'return redis.call("GET", "foo")'
+ assert await r.eval(script, 0) is None
+ response = await wait_for_command(r, m, "GET foo")
+ assert response["command"] == "GET foo"
+ assert response["client_type"] == "lua"
+ assert response["client_address"] == "lua"
+ assert response["client_port"] == ""
+
+ @skip_ifnot_redis_enterprise()
+ async def test_lua_script_in_enterprise(self, r):
+ async with r.monitor() as m:
+ script = 'return redis.call("GET", "foo")'
+ assert await r.eval(script, 0) is None
+ response = await wait_for_command(r, m, "GET foo")
+ assert response is None
diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py
new file mode 100644
index 0000000..5bb1a8a
--- /dev/null
+++ b/tests/test_asyncio/test_pipeline.py
@@ -0,0 +1,409 @@
+import pytest
+
+import redis
+from tests.conftest import skip_if_server_version_lt
+
+from .conftest import wait_for_command
+
+pytestmark = pytest.mark.asyncio
+
+
@pytest.mark.onlynoncluster
class TestPipeline:
    """Async pipeline behavior: command buffering, MULTI/EXEC transactions,
    WATCH/UNWATCH bookkeeping, and error reporting at execute() time.

    All tests receive ``r``, an asyncio Redis client fixture from conftest.
    """

    async def test_pipeline_is_true(self, r):
        """Ensure pipeline instances are not false-y"""
        async with r.pipeline() as pipe:
            assert pipe

    async def test_pipeline(self, r):
        """Queued commands run in order; execute() returns one reply each."""
        async with r.pipeline() as pipe:
            (
                pipe.set("a", "a1")
                .get("a")
                .zadd("z", {"z1": 1})
                .zadd("z", {"z2": 4})
                .zincrby("z", 1, "z1")
                .zrange("z", 0, 5, withscores=True)
            )
            assert await pipe.execute() == [
                True,
                b"a1",
                True,
                True,
                2.0,
                [(b"z1", 2.0), (b"z2", 4)],
            ]

    async def test_pipeline_memoryview(self, r):
        """memoryview values are accepted as command arguments."""
        async with r.pipeline() as pipe:
            (pipe.set("a", memoryview(b"a1")).get("a"))
            assert await pipe.execute() == [
                True,
                b"a1",
            ]

    async def test_pipeline_length(self, r):
        """len(pipe) reflects the number of buffered commands."""
        async with r.pipeline() as pipe:
            # Initially empty.
            assert len(pipe) == 0

            # Fill 'er up!
            pipe.set("a", "a1").set("b", "b1").set("c", "c1")
            assert len(pipe) == 3

            # Execute calls reset(), so empty once again.
            await pipe.execute()
            assert len(pipe) == 0

    @pytest.mark.onlynoncluster
    async def test_pipeline_no_transaction(self, r):
        """transaction=False still batches commands and applies them all."""
        async with r.pipeline(transaction=False) as pipe:
            pipe.set("a", "a1").set("b", "b1").set("c", "c1")
            assert await pipe.execute() == [True, True, True]
            assert await r.get("a") == b"a1"
            assert await r.get("b") == b"b1"
            assert await r.get("c") == b"c1"

    async def test_pipeline_no_transaction_watch(self, r):
        """WATCH works on a non-transactional pipeline via explicit multi()."""
        await r.set("a", 0)

        async with r.pipeline(transaction=False) as pipe:
            await pipe.watch("a")
            a = await pipe.get("a")

            pipe.multi()
            pipe.set("a", int(a) + 1)
            assert await pipe.execute() == [True]

    async def test_pipeline_no_transaction_watch_failure(self, r):
        """A write to a watched key between WATCH and EXEC raises WatchError."""
        await r.set("a", 0)

        async with r.pipeline(transaction=False) as pipe:
            await pipe.watch("a")
            a = await pipe.get("a")

            # Touch the watched key from the other (non-pipeline) connection.
            await r.set("a", "bad")

            pipe.multi()
            pipe.set("a", int(a) + 1)

            with pytest.raises(redis.WatchError):
                await pipe.execute()

            assert await r.get("a") == b"bad"

    async def test_exec_error_in_response(self, r):
        """
        an invalid pipeline command at exec time adds the exception instance
        to the list of returned values
        """
        await r.set("c", "a")
        async with r.pipeline() as pipe:
            pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4)
            result = await pipe.execute(raise_on_error=False)

            assert result[0]
            assert await r.get("a") == b"1"
            assert result[1]
            assert await r.get("b") == b"2"

            # we can't lpush to a key that's a string value, so this should
            # be a ResponseError exception
            assert isinstance(result[2], redis.ResponseError)
            assert await r.get("c") == b"a"

            # since this isn't a transaction, the other commands after the
            # error are still executed
            assert result[3]
            assert await r.get("d") == b"4"

            # make sure the pipe was restored to a working state
            assert await pipe.set("z", "zzz").execute() == [True]
            assert await r.get("z") == b"zzz"

    async def test_exec_error_raised(self, r):
        """With raise_on_error (the default), the failing command's position
        and spelling are reported in the exception message."""
        await r.set("c", "a")
        async with r.pipeline() as pipe:
            pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4)
            with pytest.raises(redis.ResponseError) as ex:
                await pipe.execute()
            assert str(ex.value).startswith(
                "Command # 3 (LPUSH c 3) of " "pipeline caused error: "
            )

            # make sure the pipe was restored to a working state
            assert await pipe.set("z", "zzz").execute() == [True]
            assert await r.get("z") == b"zzz"

    @pytest.mark.onlynoncluster
    async def test_transaction_with_empty_error_command(self, r):
        """
        Commands with custom EMPTY_ERROR functionality return their default
        values in the pipeline no matter the raise_on_error preference
        """
        for error_switch in (True, False):
            async with r.pipeline() as pipe:
                pipe.set("a", 1).mget([]).set("c", 3)
                result = await pipe.execute(raise_on_error=error_switch)

                assert result[0]
                assert result[1] == []
                assert result[2]

    @pytest.mark.onlynoncluster
    async def test_pipeline_with_empty_error_command(self, r):
        """
        Commands with custom EMPTY_ERROR functionality return their default
        values in the pipeline no matter the raise_on_error preference
        """
        for error_switch in (True, False):
            async with r.pipeline(transaction=False) as pipe:
                pipe.set("a", 1).mget([]).set("c", 3)
                result = await pipe.execute(raise_on_error=error_switch)

                assert result[0]
                assert result[1] == []
                assert result[2]

    async def test_parse_error_raised(self, r):
        """A client-side parse error (bad arity) is also positionally reported."""
        async with r.pipeline() as pipe:
            # the zrem is invalid because we don't pass any keys to it
            pipe.set("a", 1).zrem("b").set("b", 2)
            with pytest.raises(redis.ResponseError) as ex:
                await pipe.execute()

            assert str(ex.value).startswith(
                "Command # 2 (ZREM b) of " "pipeline caused error: "
            )

            # make sure the pipe was restored to a working state
            assert await pipe.set("z", "zzz").execute() == [True]
            assert await r.get("z") == b"zzz"

    @pytest.mark.onlynoncluster
    async def test_parse_error_raised_transaction(self, r):
        """Same as above, but inside an explicit MULTI block."""
        async with r.pipeline() as pipe:
            pipe.multi()
            # the zrem is invalid because we don't pass any keys to it
            pipe.set("a", 1).zrem("b").set("b", 2)
            with pytest.raises(redis.ResponseError) as ex:
                await pipe.execute()

            assert str(ex.value).startswith(
                "Command # 2 (ZREM b) of " "pipeline caused error: "
            )

            # make sure the pipe was restored to a working state
            assert await pipe.set("z", "zzz").execute() == [True]
            assert await r.get("z") == b"zzz"

    @pytest.mark.onlynoncluster
    async def test_watch_succeed(self, r):
        """WATCH + unmodified keys: EXEC succeeds and clears watching state."""
        await r.set("a", 1)
        await r.set("b", 2)

        async with r.pipeline() as pipe:
            await pipe.watch("a", "b")
            assert pipe.watching
            a_value = await pipe.get("a")
            b_value = await pipe.get("b")
            assert a_value == b"1"
            assert b_value == b"2"
            pipe.multi()

            pipe.set("c", 3)
            assert await pipe.execute() == [True]
            assert not pipe.watching

    @pytest.mark.onlynoncluster
    async def test_watch_failure(self, r):
        """WATCH + a concurrent write: EXEC raises WatchError."""
        await r.set("a", 1)
        await r.set("b", 2)

        async with r.pipeline() as pipe:
            await pipe.watch("a", "b")
            await r.set("b", 3)
            pipe.multi()
            pipe.get("a")
            with pytest.raises(redis.WatchError):
                await pipe.execute()

            assert not pipe.watching

    @pytest.mark.onlynoncluster
    async def test_watch_failure_in_empty_transaction(self, r):
        """WatchError fires even when the MULTI block queued no commands."""
        await r.set("a", 1)
        await r.set("b", 2)

        async with r.pipeline() as pipe:
            await pipe.watch("a", "b")
            await r.set("b", 3)
            pipe.multi()
            with pytest.raises(redis.WatchError):
                await pipe.execute()

            assert not pipe.watching

    @pytest.mark.onlynoncluster
    async def test_unwatch(self, r):
        """UNWATCH drops the watch, so a concurrent write no longer aborts."""
        await r.set("a", 1)
        await r.set("b", 2)

        async with r.pipeline() as pipe:
            await pipe.watch("a", "b")
            await r.set("b", 3)
            await pipe.unwatch()
            assert not pipe.watching
            pipe.get("a")
            assert await pipe.execute() == [b"1"]

    @pytest.mark.onlynoncluster
    async def test_watch_exec_no_unwatch(self, r):
        """A successful EXEC implicitly clears watches — no UNWATCH is sent
        (verified by scanning MONITOR output)."""
        await r.set("a", 1)
        await r.set("b", 2)

        async with r.monitor() as m:
            async with r.pipeline() as pipe:
                await pipe.watch("a", "b")
                assert pipe.watching
                a_value = await pipe.get("a")
                b_value = await pipe.get("b")
                assert a_value == b"1"
                assert b_value == b"2"
                pipe.multi()
                pipe.set("c", 3)
                assert await pipe.execute() == [True]
                assert not pipe.watching

            unwatch_command = await wait_for_command(r, m, "UNWATCH")
            assert unwatch_command is None, "should not send UNWATCH"

    @pytest.mark.onlynoncluster
    async def test_watch_reset_unwatch(self, r):
        """reset() on a watching pipeline must send an explicit UNWATCH."""
        await r.set("a", 1)

        async with r.monitor() as m:
            async with r.pipeline() as pipe:
                await pipe.watch("a")
                assert pipe.watching
                await pipe.reset()
                assert not pipe.watching

            unwatch_command = await wait_for_command(r, m, "UNWATCH")
            assert unwatch_command is not None
            assert unwatch_command["command"] == "UNWATCH"

    @pytest.mark.onlynoncluster
    async def test_transaction_callable(self, r):
        """r.transaction() retries the callable until the watch holds."""
        await r.set("a", 1)
        await r.set("b", 2)
        has_run = []

        async def my_transaction(pipe):
            a_value = await pipe.get("a")
            assert a_value in (b"1", b"2")
            b_value = await pipe.get("b")
            assert b_value == b"2"

            # silly run-once code... incr's "a" so WatchError should be raised
            # forcing this all to run again. this should incr "a" once to "2"
            if not has_run:
                await r.incr("a")
                has_run.append("it has")

            pipe.multi()
            pipe.set("c", int(a_value) + int(b_value))

        result = await r.transaction(my_transaction, "a", "b")
        assert result == [True]
        assert await r.get("c") == b"4"

    @pytest.mark.onlynoncluster
    async def test_transaction_callable_returns_value_from_callable(self, r):
        """value_from_callable=True returns the callable's result, not EXEC's."""
        async def callback(pipe):
            # No need to do anything here since we only want the return value
            return "a"

        res = await r.transaction(callback, "my-key", value_from_callable=True)
        assert res == "a"

    async def test_exec_error_in_no_transaction_pipeline(self, r):
        """Error position is reported for non-transactional pipelines too."""
        await r.set("a", 1)
        async with r.pipeline(transaction=False) as pipe:
            pipe.llen("a")
            pipe.expire("a", 100)

            with pytest.raises(redis.ResponseError) as ex:
                await pipe.execute()

            assert str(ex.value).startswith(
                "Command # 1 (LLEN a) of " "pipeline caused error: "
            )

        assert await r.get("a") == b"1"

    async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r):
        """The error message renders non-ASCII key names correctly."""
        key = chr(3456) + "abcd" + chr(3421)
        await r.set(key, 1)
        async with r.pipeline(transaction=False) as pipe:
            pipe.llen(key)
            pipe.expire(key, 100)

            with pytest.raises(redis.ResponseError) as ex:
                await pipe.execute()

            expected = f"Command # 1 (LLEN {key}) of pipeline caused error: "
            assert str(ex.value).startswith(expected)

        assert await r.get(key) == b"1"

    async def test_pipeline_with_bitfield(self, r):
        """BITFIELD's fluent builder shares the same pipeline instance."""
        async with r.pipeline() as pipe:
            pipe.set("a", "1")
            bf = pipe.bitfield("b")
            pipe2 = (
                bf.set("u8", 8, 255)
                .get("u8", 0)
                .get("u4", 8)  # 1111
                .get("u4", 12)  # 1111
                .get("u4", 13)  # 1110
                .execute()
            )
            pipe.get("a")
            response = await pipe.execute()

            # bitfield.execute() returns the pipeline it queued into
            assert pipe == pipe2
            assert response == [True, [0, 0, 15, 15, 14], b"1"]

    async def test_pipeline_get(self, r):
        """get() on a pipeline queues rather than fetching immediately."""
        await r.set("a", "a1")
        async with r.pipeline() as pipe:
            await pipe.get("a")
            assert await pipe.execute() == [b"a1"]

    @pytest.mark.onlynoncluster
    @skip_if_server_version_lt("2.0.0")
    async def test_pipeline_discard(self, r):
        """After discard() the pipeline is empty, and execute() raises."""

        # empty pipeline should raise an error
        async with r.pipeline() as pipe:
            pipe.set("key", "someval")
            await pipe.discard()
            with pytest.raises(redis.exceptions.ResponseError):
                await pipe.execute()

        # setting a pipeline and discarding should do the same
        async with r.pipeline() as pipe:
            pipe.set("key", "someval")
            pipe.set("someotherkey", "val")
            response = await pipe.execute()
            pipe.set("key", "another value!")
            await pipe.discard()
            # NOTE(review): "another vae!" looks like a typo for
            # "another value!" — harmless, the value is never asserted.
            pipe.set("key", "another vae!")
            with pytest.raises(redis.exceptions.ResponseError):
                await pipe.execute()

            pipe.set("foo", "bar")
            response = await pipe.execute()
            assert response[0]
            assert await r.get("foo") == b"bar"
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
new file mode 100644
index 0000000..9efcd3c
--- /dev/null
+++ b/tests/test_asyncio/test_pubsub.py
@@ -0,0 +1,660 @@
+import asyncio
+import sys
+from typing import Optional
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+import redis.asyncio as redis
+from redis.exceptions import ConnectionError
+from redis.typing import EncodableT
+from tests.conftest import skip_if_server_version_lt
+
+from .compat import mock
+
+pytestmark = pytest.mark.asyncio(forbid_global_loop=True)
+
+
async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
    """Poll ``pubsub`` until a message arrives or ``timeout`` seconds elapse.

    Returns the first non-None message, or None on timeout.
    """
    loop = asyncio.get_event_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        message = await pubsub.get_message(
            ignore_subscribe_messages=ignore_subscribe_messages
        )
        if message is not None:
            return message
        # Brief pause between polls so we don't spin the event loop.
        await asyncio.sleep(0.01)
    return None
+
+
def make_message(
    type, channel: Optional[str], data: "EncodableT", pattern: Optional[str] = None
):
    """Build the dict shape redis-py produces for a raw pub/sub message.

    Channel and pattern are utf-8 encoded (falsy values become None);
    str data is encoded, anything else passes through untouched.
    """
    encoded_pattern = pattern.encode("utf-8") if pattern else None
    encoded_channel = channel.encode("utf-8") if channel else None
    payload = data.encode("utf-8") if isinstance(data, str) else data
    return {
        "type": type,
        "pattern": encoded_pattern,
        "channel": encoded_channel,
        "data": payload,
    }
+
+
def make_subscribe_test_data(pubsub, type):
    """Return kwargs for the shared channel/pattern subscribe test drivers."""
    unicode_stem = "uni" + chr(4456)
    if type == "channel":
        return {
            "p": pubsub,
            "sub_type": "subscribe",
            "unsub_type": "unsubscribe",
            "sub_func": pubsub.subscribe,
            "unsub_func": pubsub.unsubscribe,
            "keys": ["foo", "bar", unicode_stem + "code"],
        }
    if type == "pattern":
        return {
            "p": pubsub,
            "sub_type": "psubscribe",
            "unsub_type": "punsubscribe",
            "sub_func": pubsub.psubscribe,
            "unsub_func": pubsub.punsubscribe,
            "keys": ["f*", "b*", unicode_stem + "*"],
        }
    assert False, f"invalid subscribe type: {type}"
+
+
@pytest.mark.onlynoncluster
class TestPubSubSubscribeUnsubscribe:
    """Subscribe/unsubscribe confirmation flow for channels and patterns,
    including resubscription after reconnect and the ``subscribed`` flag."""

    async def _test_subscribe_unsubscribe(
        self, p, sub_type, unsub_type, sub_func, unsub_func, keys
    ):
        # Shared driver for the channel and pattern variants below.
        for key in keys:
            assert await sub_func(key) is None

        # should be a message for each channel/pattern we just subscribed to
        for i, key in enumerate(keys):
            assert await wait_for_message(p) == make_message(sub_type, key, i + 1)

        for key in keys:
            assert await unsub_func(key) is None

        # should be a message for each channel/pattern we just unsubscribed
        # from
        for i, key in enumerate(keys):
            # unsubscribe counts down from len(keys)-1 to 0
            i = len(keys) - 1 - i
            assert await wait_for_message(p) == make_message(unsub_type, key, i)

    async def test_channel_subscribe_unsubscribe(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "channel")
        await self._test_subscribe_unsubscribe(**kwargs)

    async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "pattern")
        await self._test_subscribe_unsubscribe(**kwargs)

    @pytest.mark.onlynoncluster
    async def _test_resubscribe_on_reconnection(
        self, p, sub_type, unsub_type, sub_func, unsub_func, keys
    ):
        # Verifies that a dropped connection transparently resubscribes.
        for key in keys:
            assert await sub_func(key) is None

        # should be a message for each channel/pattern we just subscribed to
        for i, key in enumerate(keys):
            assert await wait_for_message(p) == make_message(sub_type, key, i + 1)

        # manually disconnect
        await p.connection.disconnect()

        # calling get_message again reconnects and resubscribes
        # note, we may not re-subscribe to channels in exactly the same order
        # so we have to do some extra checks to make sure we got them all
        messages = []
        for i in range(len(keys)):
            messages.append(await wait_for_message(p))

        unique_channels = set()
        assert len(messages) == len(keys)
        for i, message in enumerate(messages):
            assert message["type"] == sub_type
            assert message["data"] == i + 1
            assert isinstance(message["channel"], bytes)
            channel = message["channel"].decode("utf-8")
            unique_channels.add(channel)

        assert len(unique_channels) == len(keys)
        for channel in unique_channels:
            assert channel in keys

    async def test_resubscribe_to_channels_on_reconnection(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "channel")
        await self._test_resubscribe_on_reconnection(**kwargs)

    async def test_resubscribe_to_patterns_on_reconnection(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "pattern")
        await self._test_resubscribe_on_reconnection(**kwargs)

    async def _test_subscribed_property(
        self, p, sub_type, unsub_type, sub_func, unsub_func, keys
    ):
        # The ``subscribed`` flag flips eagerly on subscribe but only
        # clears once the server's unsubscribe confirmations are consumed.
        assert p.subscribed is False
        await sub_func(keys[0])
        # we're now subscribed even though we haven't processed the
        # reply from the server just yet
        assert p.subscribed is True
        assert await wait_for_message(p) == make_message(sub_type, keys[0], 1)
        # we're still subscribed
        assert p.subscribed is True

        # unsubscribe from all channels
        await unsub_func()
        # we're still technically subscribed until we process the
        # response messages from the server
        assert p.subscribed is True
        assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0)
        # now we're no longer subscribed as no more messages can be delivered
        # to any channels we were listening to
        assert p.subscribed is False

        # subscribing again flips the flag back
        await sub_func(keys[0])
        assert p.subscribed is True
        assert await wait_for_message(p) == make_message(sub_type, keys[0], 1)

        # unsubscribe again
        await unsub_func()
        assert p.subscribed is True
        # subscribe to another channel before reading the unsubscribe response
        await sub_func(keys[1])
        assert p.subscribed is True
        # read the unsubscribe for key1
        assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0)
        # we're still subscribed to key2, so subscribed should still be True
        assert p.subscribed is True
        # read the key2 subscribe message
        assert await wait_for_message(p) == make_message(sub_type, keys[1], 1)
        await unsub_func()
        # haven't read the message yet, so we're still subscribed
        assert p.subscribed is True
        assert await wait_for_message(p) == make_message(unsub_type, keys[1], 0)
        # now we're finally unsubscribed
        assert p.subscribed is False

    async def test_subscribe_property_with_channels(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "channel")
        await self._test_subscribed_property(**kwargs)

    @pytest.mark.onlynoncluster
    async def test_subscribe_property_with_patterns(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "pattern")
        await self._test_subscribed_property(**kwargs)

    async def test_ignore_all_subscribe_messages(self, r: redis.Redis):
        # ignore_subscribe_messages=True suppresses all confirmations.
        p = r.pubsub(ignore_subscribe_messages=True)

        checks = (
            (p.subscribe, "foo"),
            (p.unsubscribe, "foo"),
            (p.psubscribe, "f*"),
            (p.punsubscribe, "f*"),
        )

        assert p.subscribed is False
        for func, channel in checks:
            assert await func(channel) is None
            assert p.subscribed is True
            assert await wait_for_message(p) is None
        assert p.subscribed is False

    async def test_ignore_individual_subscribe_messages(self, r: redis.Redis):
        # The per-call flag on get_message has the same suppressing effect.
        p = r.pubsub()

        checks = (
            (p.subscribe, "foo"),
            (p.unsubscribe, "foo"),
            (p.psubscribe, "f*"),
            (p.punsubscribe, "f*"),
        )

        assert p.subscribed is False
        for func, channel in checks:
            assert await func(channel) is None
            assert p.subscribed is True
            message = await wait_for_message(p, ignore_subscribe_messages=True)
            assert message is None
        assert p.subscribed is False

    async def test_sub_unsub_resub_channels(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "channel")
        await self._test_sub_unsub_resub(**kwargs)

    @pytest.mark.onlynoncluster
    async def test_sub_unsub_resub_patterns(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "pattern")
        await self._test_sub_unsub_resub(**kwargs)

    async def _test_sub_unsub_resub(
        self, p, sub_type, unsub_type, sub_func, unsub_func, keys
    ):
        # https://github.com/andymccurdy/redis-py/issues/764
        key = keys[0]
        await sub_func(key)
        await unsub_func(key)
        await sub_func(key)
        assert p.subscribed is True
        assert await wait_for_message(p) == make_message(sub_type, key, 1)
        assert await wait_for_message(p) == make_message(unsub_type, key, 0)
        assert await wait_for_message(p) == make_message(sub_type, key, 1)
        assert p.subscribed is True

    async def test_sub_unsub_all_resub_channels(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "channel")
        await self._test_sub_unsub_all_resub(**kwargs)

    async def test_sub_unsub_all_resub_patterns(self, r: redis.Redis):
        kwargs = make_subscribe_test_data(r.pubsub(), "pattern")
        await self._test_sub_unsub_all_resub(**kwargs)

    async def _test_sub_unsub_all_resub(
        self, p, sub_type, unsub_type, sub_func, unsub_func, keys
    ):
        # https://github.com/andymccurdy/redis-py/issues/764
        # Same as above but unsubscribing from *all* channels/patterns.
        key = keys[0]
        await sub_func(key)
        await unsub_func()
        await sub_func(key)
        assert p.subscribed is True
        assert await wait_for_message(p) == make_message(sub_type, key, 1)
        assert await wait_for_message(p) == make_message(unsub_type, key, 0)
        assert await wait_for_message(p) == make_message(sub_type, key, 1)
        assert p.subscribed is True
+
+
@pytest.mark.onlynoncluster
class TestPubSubMessages:
    """Message delivery to channels/patterns, and sync vs. async handlers."""

    def setup_method(self, method):
        # Reset the slot the synchronous handler writes into.
        self.message = None

    def message_handler(self, message):
        # Synchronous handler: remember the last delivered message.
        self.message = message

    async def async_message_handler(self, message):
        # Async handler writes to a separate attribute.
        # NOTE(review): self.async_message is not reset in setup_method —
        # presumably fine since each test publishes before asserting; confirm.
        self.async_message = message

    async def test_published_message_to_channel(self, r: redis.Redis):
        p = r.pubsub()
        await p.subscribe("foo")
        assert await wait_for_message(p) == make_message("subscribe", "foo", 1)
        # publish returns the number of subscribers that received it
        assert await r.publish("foo", "test message") == 1

        message = await wait_for_message(p)
        assert isinstance(message, dict)
        assert message == make_message("message", "foo", "test message")

    async def test_published_message_to_pattern(self, r: redis.Redis):
        p = r.pubsub()
        await p.subscribe("foo")
        await p.psubscribe("f*")
        assert await wait_for_message(p) == make_message("subscribe", "foo", 1)
        assert await wait_for_message(p) == make_message("psubscribe", "f*", 2)
        # 1 to pattern, 1 to channel
        assert await r.publish("foo", "test message") == 2

        message1 = await wait_for_message(p)
        message2 = await wait_for_message(p)
        assert isinstance(message1, dict)
        assert isinstance(message2, dict)

        # delivery order between channel and pattern is not guaranteed
        expected = [
            make_message("message", "foo", "test message"),
            make_message("pmessage", "foo", "test message", pattern="f*"),
        ]

        assert message1 in expected
        assert message2 in expected
        assert message1 != message2

    async def test_channel_message_handler(self, r: redis.Redis):
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.subscribe(foo=self.message_handler)
        assert await wait_for_message(p) is None
        assert await r.publish("foo", "test message") == 1
        # handler consumes the message, so get_message yields None
        assert await wait_for_message(p) is None
        assert self.message == make_message("message", "foo", "test message")

    async def test_channel_async_message_handler(self, r):
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.subscribe(foo=self.async_message_handler)
        assert await wait_for_message(p) is None
        assert await r.publish("foo", "test message") == 1
        assert await wait_for_message(p) is None
        assert self.async_message == make_message("message", "foo", "test message")

    async def test_channel_sync_async_message_handler(self, r):
        # Sync and async handlers may be mixed on one pubsub instance.
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.subscribe(foo=self.message_handler)
        await p.subscribe(bar=self.async_message_handler)
        assert await wait_for_message(p) is None
        assert await r.publish("foo", "test message") == 1
        assert await r.publish("bar", "test message 2") == 1
        assert await wait_for_message(p) is None
        assert self.message == make_message("message", "foo", "test message")
        assert self.async_message == make_message("message", "bar", "test message 2")

    @pytest.mark.onlynoncluster
    async def test_pattern_message_handler(self, r: redis.Redis):
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.psubscribe(**{"f*": self.message_handler})
        assert await wait_for_message(p) is None
        assert await r.publish("foo", "test message") == 1
        assert await wait_for_message(p) is None
        assert self.message == make_message(
            "pmessage", "foo", "test message", pattern="f*"
        )

    async def test_unicode_channel_message_handler(self, r: redis.Redis):
        p = r.pubsub(ignore_subscribe_messages=True)
        channel = "uni" + chr(4456) + "code"
        channels = {channel: self.message_handler}
        await p.subscribe(**channels)
        assert await wait_for_message(p) is None
        assert await r.publish(channel, "test message") == 1
        assert await wait_for_message(p) is None
        assert self.message == make_message("message", channel, "test message")

    @pytest.mark.onlynoncluster
    # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html
    # #known-limitations-with-pubsub
    async def test_unicode_pattern_message_handler(self, r: redis.Redis):
        p = r.pubsub(ignore_subscribe_messages=True)
        pattern = "uni" + chr(4456) + "*"
        channel = "uni" + chr(4456) + "code"
        await p.psubscribe(**{pattern: self.message_handler})
        assert await wait_for_message(p) is None
        assert await r.publish(channel, "test message") == 1
        assert await wait_for_message(p) is None
        assert self.message == make_message(
            "pmessage", channel, "test message", pattern=pattern
        )

    async def test_get_message_without_subscribe(self, r: redis.Redis):
        # get_message before any subscribe has no connection to read from
        p = r.pubsub()
        with pytest.raises(RuntimeError) as info:
            await p.get_message()
        expect = (
            "connection not set: " "did you forget to call subscribe() or psubscribe()?"
        )
        assert expect in info.exconly()
+
+
@pytest.mark.onlynoncluster
class TestPubSubAutoDecoding:
    """These tests only validate that we get unicode values back"""

    # Non-ASCII names/data to exercise the decode path end to end.
    channel = "uni" + chr(4456) + "code"
    pattern = "uni" + chr(4456) + "*"
    data = "abc" + chr(4458) + "123"

    def make_message(self, type, channel, data, pattern=None):
        # Unlike the module-level make_message, values stay str because the
        # client below is created with decode_responses=True.
        return {"type": type, "channel": channel, "pattern": pattern, "data": data}

    def setup_method(self, method):
        # Reset the slot the handler writes into.
        self.message = None

    def message_handler(self, message):
        self.message = message

    @pytest_asyncio.fixture()
    async def r(self, create_redis):
        # Override the default fixture with a decoding client.
        return await create_redis(
            decode_responses=True,
        )

    async def test_channel_subscribe_unsubscribe(self, r: redis.Redis):
        p = r.pubsub()
        await p.subscribe(self.channel)
        assert await wait_for_message(p) == self.make_message(
            "subscribe", self.channel, 1
        )

        await p.unsubscribe(self.channel)
        assert await wait_for_message(p) == self.make_message(
            "unsubscribe", self.channel, 0
        )

    async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis):
        p = r.pubsub()
        await p.psubscribe(self.pattern)
        assert await wait_for_message(p) == self.make_message(
            "psubscribe", self.pattern, 1
        )

        await p.punsubscribe(self.pattern)
        assert await wait_for_message(p) == self.make_message(
            "punsubscribe", self.pattern, 0
        )

    async def test_channel_publish(self, r: redis.Redis):
        p = r.pubsub()
        await p.subscribe(self.channel)
        assert await wait_for_message(p) == self.make_message(
            "subscribe", self.channel, 1
        )
        await r.publish(self.channel, self.data)
        assert await wait_for_message(p) == self.make_message(
            "message", self.channel, self.data
        )

    @pytest.mark.onlynoncluster
    async def test_pattern_publish(self, r: redis.Redis):
        p = r.pubsub()
        await p.psubscribe(self.pattern)
        assert await wait_for_message(p) == self.make_message(
            "psubscribe", self.pattern, 1
        )
        await r.publish(self.channel, self.data)
        assert await wait_for_message(p) == self.make_message(
            "pmessage", self.channel, self.data, pattern=self.pattern
        )

    async def test_channel_message_handler(self, r: redis.Redis):
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.subscribe(**{self.channel: self.message_handler})
        assert await wait_for_message(p) is None
        await r.publish(self.channel, self.data)
        assert await wait_for_message(p) is None
        assert self.message == self.make_message("message", self.channel, self.data)

        # test that we reconnected to the correct channel
        self.message = None
        await p.connection.disconnect()
        assert await wait_for_message(p) is None  # should reconnect
        new_data = self.data + "new data"
        await r.publish(self.channel, new_data)
        assert await wait_for_message(p) is None
        assert self.message == self.make_message("message", self.channel, new_data)

    async def test_pattern_message_handler(self, r: redis.Redis):
        p = r.pubsub(ignore_subscribe_messages=True)
        await p.psubscribe(**{self.pattern: self.message_handler})
        assert await wait_for_message(p) is None
        await r.publish(self.channel, self.data)
        assert await wait_for_message(p) is None
        assert self.message == self.make_message(
            "pmessage", self.channel, self.data, pattern=self.pattern
        )

        # test that we reconnected to the correct pattern
        self.message = None
        await p.connection.disconnect()
        assert await wait_for_message(p) is None  # should reconnect
        new_data = self.data + "new data"
        await r.publish(self.channel, new_data)
        assert await wait_for_message(p) is None
        assert self.message == self.make_message(
            "pmessage", self.channel, new_data, pattern=self.pattern
        )

    async def test_context_manager(self, r: redis.Redis):
        # Leaving the context releases the connection and clears state.
        async with r.pubsub() as pubsub:
            await pubsub.subscribe("foo")
            assert pubsub.connection is not None

        assert pubsub.connection is None
        assert pubsub.channels == {}
        assert pubsub.patterns == {}
+
+
@pytest.mark.onlynoncluster
class TestPubSubRedisDown:
    """Pub/sub behavior when no server is listening."""

    async def test_channel_subscribe(self, r: redis.Redis):
        # Point at a port with no redis server; subscribing must fail fast.
        dead_client = redis.Redis(host="localhost", port=6390)
        pubsub = dead_client.pubsub()
        with pytest.raises(ConnectionError):
            await pubsub.subscribe("foo")
+
+
@pytest.mark.onlynoncluster
class TestPubSubSubcommands:
    """PUBSUB CHANNELS / NUMSUB / NUMPAT server introspection commands."""

    @pytest.mark.onlynoncluster
    @skip_if_server_version_lt("2.8.0")
    async def test_pubsub_channels(self, r: redis.Redis):
        p = r.pubsub()
        await p.subscribe("foo", "bar", "baz", "quux")
        # drain the four subscribe confirmations first
        for i in range(4):
            assert (await wait_for_message(p))["type"] == "subscribe"
        expected = [b"bar", b"baz", b"foo", b"quux"]
        assert all([channel in await r.pubsub_channels() for channel in expected])

    @pytest.mark.onlynoncluster
    @skip_if_server_version_lt("2.8.0")
    async def test_pubsub_numsub(self, r: redis.Redis):
        # Three pubsubs with overlapping subscriptions: foo=1, bar=2, baz=3.
        p1 = r.pubsub()
        await p1.subscribe("foo", "bar", "baz")
        for i in range(3):
            assert (await wait_for_message(p1))["type"] == "subscribe"
        p2 = r.pubsub()
        await p2.subscribe("bar", "baz")
        for i in range(2):
            assert (await wait_for_message(p2))["type"] == "subscribe"
        p3 = r.pubsub()
        await p3.subscribe("baz")
        assert (await wait_for_message(p3))["type"] == "subscribe"

        channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)]
        assert await r.pubsub_numsub("foo", "bar", "baz") == channels

    @skip_if_server_version_lt("2.8.0")
    async def test_pubsub_numpat(self, r: redis.Redis):
        p = r.pubsub()
        await p.psubscribe("*oo", "*ar", "b*z")
        for i in range(3):
            assert (await wait_for_message(p))["type"] == "psubscribe"
        assert await r.pubsub_numpat() == 3
+
+
@pytest.mark.onlynoncluster
class TestPubSubPings:
    """PING on a pub/sub connection yields ``pong`` messages."""

    @skip_if_server_version_lt("3.0.0")
    async def test_send_pubsub_ping(self, r: redis.Redis):
        pubsub = r.pubsub(ignore_subscribe_messages=True)
        await pubsub.subscribe("foo")
        await pubsub.ping()
        expected = make_message(type="pong", channel=None, data="", pattern=None)
        assert await wait_for_message(pubsub) == expected

    @skip_if_server_version_lt("3.0.0")
    async def test_send_pubsub_ping_message(self, r: redis.Redis):
        # A ping payload is echoed back as the pong's data.
        pubsub = r.pubsub(ignore_subscribe_messages=True)
        await pubsub.subscribe("foo")
        await pubsub.ping(message="hello world")
        expected = make_message(
            type="pong", channel=None, data="hello world", pattern=None
        )
        assert await wait_for_message(pubsub) == expected
+
+
@pytest.mark.onlynoncluster
class TestPubSubConnectionKilled:
    """Killing the subscribed connection server-side surfaces ConnectionError."""

    @skip_if_server_version_lt("3.0.0")
    async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis):
        pubsub = r.pubsub()
        await pubsub.subscribe("foo")
        assert await wait_for_message(pubsub) == make_message("subscribe", "foo", 1)
        # Find the client whose last command was SUBSCRIBE and kill it.
        for client_info in await r.client_list():
            if client_info["cmd"] == "subscribe":
                await r.client_kill_filter(_id=client_info["id"])
        with pytest.raises(ConnectionError):
            await wait_for_message(pubsub)
+
+
@pytest.mark.onlynoncluster
class TestPubSubTimeouts:
    """get_message(timeout=...) returns None when nothing is published."""

    async def test_get_message_with_timeout_returns_none(self, r: redis.Redis):
        pubsub = r.pubsub()
        await pubsub.subscribe("foo")
        assert await wait_for_message(pubsub) == make_message("subscribe", "foo", 1)
        assert await pubsub.get_message(timeout=0.01) is None
+
+
@pytest.mark.onlynoncluster
class TestPubSubRun:
    """PubSub.run(): the background reader task that dispatches to callbacks."""

    async def _subscribe(self, p, *args, **kwargs):
        await p.subscribe(*args, **kwargs)
        # Wait for the server to act on the subscription, to be sure that
        # a subsequent publish on another connection will reach the pubsub.
        while True:
            message = await p.get_message(timeout=1)
            if (
                message is not None
                and message["type"] == "subscribe"
                and message["channel"] == b"foo"
            ):
                return

    async def test_callbacks(self, r: redis.Redis):
        """run() delivers published messages to the registered callback."""
        def callback(message):
            messages.put_nowait(message)

        messages = asyncio.Queue()
        p = r.pubsub()
        await self._subscribe(p, foo=callback)
        task = asyncio.get_event_loop().create_task(p.run())
        await r.publish("foo", "bar")
        message = await messages.get()
        # run() loops forever; cancel and swallow the expected CancelledError
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        assert message == {
            "channel": b"foo",
            "data": b"bar",
            "pattern": None,
            "type": "message",
        }

    async def test_exception_handler(self, r: redis.Redis):
        """Exceptions inside run() are routed to the exception_handler hook."""
        def exception_handler_callback(e, pubsub) -> None:
            assert pubsub == p
            exceptions.put_nowait(e)

        exceptions = asyncio.Queue()
        p = r.pubsub()
        await self._subscribe(p, foo=lambda x: None)
        # Force get_message to fail so run() invokes the handler.
        with mock.patch.object(p, "get_message", side_effect=Exception("error")):
            task = asyncio.get_event_loop().create_task(
                p.run(exception_handler=exception_handler_callback)
            )
            e = await exceptions.get()
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        assert str(e) == "error"
diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py
new file mode 100644
index 0000000..e83e001
--- /dev/null
+++ b/tests/test_asyncio/test_retry.py
@@ -0,0 +1,70 @@
+import pytest
+
+from redis.asyncio.connection import Connection, UnixDomainSocketConnection
+from redis.asyncio.retry import Retry
+from redis.backoff import AbstractBackoff, NoBackoff
+from redis.exceptions import ConnectionError
+
+
class BackoffMock(AbstractBackoff):
    """Instrumented backoff stub: counts calls and never actually sleeps."""

    def __init__(self):
        # How many times Retry reset the backoff state (expected: once per call).
        self.reset_calls = 0
        # How many times a delay was computed (expected: once per re-attempt).
        self.calls = 0

    def reset(self):
        self.reset_calls += 1

    def compute(self, failures):
        self.calls += 1
        # Zero delay keeps the retry loop instantaneous in tests.
        return 0
+
+
@pytest.mark.onlynoncluster
class TestConnectionConstructorWithRetry:
    """Verify the Connection constructors handle Retry objects properly."""

    @pytest.mark.parametrize("retry_on_timeout", [False, True])
    @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
    def test_retry_on_timeout_boolean(self, Class, retry_on_timeout):
        # Without an explicit Retry, the boolean flag alone decides the budget:
        # one retry when enabled, none otherwise.
        conn = Class(retry_on_timeout=retry_on_timeout)
        assert conn.retry_on_timeout == retry_on_timeout
        assert isinstance(conn.retry, Retry)
        expected_retries = 1 if retry_on_timeout else 0
        assert conn.retry._retries == expected_retries

    @pytest.mark.parametrize("retries", range(10))
    @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
    def test_retry_on_timeout_retry(self, Class, retries: int):
        retry_on_timeout = retries > 0
        conn = Class(
            retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries)
        )
        assert conn.retry_on_timeout == retry_on_timeout
        assert isinstance(conn.retry, Retry)
        # An explicit Retry object wins over the boolean flag's default budget.
        assert conn.retry._retries == retries
+
+
@pytest.mark.onlynoncluster
class TestRetry:
    """Retry should consult the backoff and retry the expected number of times."""

    def setup_method(self, test_method):
        self.actual_attempts = 0
        self.actual_failures = 0

    async def _do(self):
        # Always fail so every attempt flows through the failure path.
        self.actual_attempts += 1
        raise ConnectionError()

    async def _fail(self, error):
        self.actual_failures += 1

    @pytest.mark.parametrize("retries", range(10))
    @pytest.mark.asyncio
    async def test_retry(self, retries: int):
        backoff = BackoffMock()
        with pytest.raises(ConnectionError):
            await Retry(backoff, retries).call_with_retry(self._do, self._fail)

        # One initial attempt plus `retries` re-attempts; the backoff is reset
        # exactly once and consulted once per re-attempt.
        assert self.actual_attempts == retries + 1
        assert self.actual_failures == retries + 1
        assert backoff.reset_calls == 1
        assert backoff.calls == retries
diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py
new file mode 100644
index 0000000..764525f
--- /dev/null
+++ b/tests/test_asyncio/test_scripting.py
@@ -0,0 +1,159 @@
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+from redis import exceptions
+from tests.conftest import skip_if_server_version_lt
+
# Lua script used by most tests: reads KEYS[1] as a number and multiplies it
# by ARGV[1].
multiply_script = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""

# Greets the `name` field of a msgpack-encoded message passed in ARGV[1].
msgpack_hello_script = """
local message = cmsgpack.unpack(ARGV[1])
local name = message['name']
return "hello " .. name
"""
# Deliberately broken variant: it assigns `names` but concatenates the unset
# `name`, so evaluating it raises a server-side error.
msgpack_hello_script_broken = """
local message = cmsgpack.unpack(ARGV[1])
local names = message['name']
return "hello " .. name
"""
+
+
@pytest.mark.onlynoncluster
class TestScripting:
    @pytest_asyncio.fixture
    async def r(self, create_redis):
        # Fresh client per test; the server script cache is flushed on
        # teardown so tests cannot leak scripts into each other.
        redis = await create_redis()
        yield redis
        await redis.script_flush()

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_eval(self, r):
        """EVAL runs an inline script against a key."""
        await r.flushdb()
        await r.set("a", 2)
        # 2 * 3 == 6
        assert await r.eval(multiply_script, 1, "a", 3) == 6

    @pytest.mark.asyncio(forbid_global_loop=True)
    @skip_if_server_version_lt("6.2.0")
    async def test_script_flush(self, r):
        """SCRIPT FLUSH accepts ASYNC, SYNC, or no mode; others are rejected."""
        await r.set("a", 2)
        await r.script_load(multiply_script)
        await r.script_flush("ASYNC")

        await r.set("a", 2)
        await r.script_load(multiply_script)
        await r.script_flush("SYNC")

        await r.set("a", 2)
        await r.script_load(multiply_script)
        await r.script_flush()

        # An unknown flush mode is rejected client-side with DataError.
        with pytest.raises(exceptions.DataError):
            await r.set("a", 2)
            await r.script_load(multiply_script)
            await r.script_flush("NOTREAL")

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_evalsha(self, r):
        """EVALSHA runs a previously loaded script by its SHA digest."""
        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        # 2 * 3 == 6
        assert await r.evalsha(sha, 1, "a", 3) == 6

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_evalsha_script_not_loaded(self, r):
        """EVALSHA of a digest absent from the cache raises NoScriptError."""
        await r.set("a", 2)
        sha = await r.script_load(multiply_script)
        # remove the script from Redis's cache
        await r.script_flush()
        with pytest.raises(exceptions.NoScriptError):
            await r.evalsha(sha, 1, "a", 3)

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_script_loading(self, r):
        """SCRIPT EXISTS reflects whether a digest is in the server cache."""
        # get the sha, then clear the cache
        sha = await r.script_load(multiply_script)
        await r.script_flush()
        assert await r.script_exists(sha) == [False]
        await r.script_load(multiply_script)
        assert await r.script_exists(sha) == [True]

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_script_object(self, r):
        """register_script() computes the sha locally and loads lazily."""
        await r.script_flush()
        await r.set("a", 2)
        multiply = r.register_script(multiply_script)
        precalculated_sha = multiply.sha
        assert precalculated_sha
        assert await r.script_exists(multiply.sha) == [False]
        # Test second evalsha block (after NoScriptError)
        assert await multiply(keys=["a"], args=[3]) == 6
        # At this point, the script should be loaded
        assert await r.script_exists(multiply.sha) == [True]
        # Test that the precalculated sha matches the one from redis
        assert multiply.sha == precalculated_sha
        # Test first evalsha block
        assert await multiply(keys=["a"], args=[3]) == 6

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_script_object_in_pipeline(self, r):
        """Scripts queued in a pipeline get (re)loaded by pipe.execute()."""
        await r.script_flush()
        multiply = r.register_script(multiply_script)
        precalculated_sha = multiply.sha
        assert precalculated_sha
        pipe = r.pipeline()
        pipe.set("a", 2)
        pipe.get("a")
        await multiply(keys=["a"], args=[3], client=pipe)
        assert await r.script_exists(multiply.sha) == [False]
        # [SET worked, GET 'a', result of multiple script]
        assert await pipe.execute() == [True, b"2", 6]
        # The script should have been loaded by pipe.execute()
        assert await r.script_exists(multiply.sha) == [True]
        # The precalculated sha should have been the correct one
        assert multiply.sha == precalculated_sha

        # purge the script from redis's cache and re-run the pipeline
        # the multiply script should be reloaded by pipe.execute()
        await r.script_flush()
        pipe = r.pipeline()
        pipe.set("a", 2)
        pipe.get("a")
        await multiply(keys=["a"], args=[3], client=pipe)
        assert await r.script_exists(multiply.sha) == [False]
        # [SET worked, GET 'a', result of multiple script]
        assert await pipe.execute() == [True, b"2", 6]
        assert await r.script_exists(multiply.sha) == [True]

    @pytest.mark.asyncio(forbid_global_loop=True)
    async def test_eval_msgpack_pipeline_error_in_lua(self, r):
        """A script error inside a pipeline surfaces as ResponseError."""
        msgpack_hello = r.register_script(msgpack_hello_script)
        assert msgpack_hello.sha

        pipe = r.pipeline()

        # avoiding a dependency to msgpack, this is the output of
        # msgpack.dumps({"name": "Joe"})
        msgpack_message_1 = b"\x81\xa4name\xa3Joe"

        await msgpack_hello(args=[msgpack_message_1], client=pipe)

        assert await r.script_exists(msgpack_hello.sha) == [False]
        assert (await pipe.execute())[0] == b"hello Joe"
        assert await r.script_exists(msgpack_hello.sha) == [True]

        msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)

        # The broken script concatenates an unset variable, so executing the
        # pipeline raises a server-side error.
        await msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
        with pytest.raises(exceptions.ResponseError) as excinfo:
            await pipe.execute()
        assert excinfo.type == exceptions.ResponseError
diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py
new file mode 100644
index 0000000..cd6810c
--- /dev/null
+++ b/tests/test_asyncio/test_sentinel.py
@@ -0,0 +1,249 @@
+import socket
+import sys
+
+import pytest
+
+if sys.version_info[0:2] == (3, 6):
+ import pytest as pytest_asyncio
+else:
+ import pytest_asyncio
+
+import redis.asyncio.sentinel
+from redis import exceptions
+from redis.asyncio.sentinel import (
+ MasterNotFoundError,
+ Sentinel,
+ SentinelConnectionPool,
+ SlaveNotFoundError,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
@pytest_asyncio.fixture(scope="module")
def master_ip(master_host):
    """Resolve the master host name to an IP address once per module."""
    resolved = socket.gethostbyname(master_host)
    yield resolved
+
+
class SentinelTestClient:
    """Minimal stand-in for a sentinel client, backed by a SentinelTestCluster."""

    def __init__(self, cluster, id):
        self.cluster = cluster
        self.id = id

    async def sentinel_masters(self):
        # Simulate a network failure/timeout before answering, if configured.
        self.cluster.connection_error_if_down(self)
        self.cluster.timeout_if_down(self)
        return {self.cluster.service_name: self.cluster.master}

    async def sentinel_slaves(self, master_name):
        self.cluster.connection_error_if_down(self)
        self.cluster.timeout_if_down(self)
        if master_name == self.cluster.service_name:
            return self.cluster.slaves
        return []

    async def execute_command(self, *args, **kwargs):
        # wrapper purely to validate the calls don't explode
        from redis.asyncio.client import bool_ok

        return bool_ok
+
+
class SentinelTestCluster:
    """In-memory model of a sentinel-monitored master/slave deployment."""

    def __init__(self, service_name="mymaster", ip="127.0.0.1", port=6379):
        self.clients = {}
        self.master = dict(
            ip=ip,
            port=port,
            is_master=True,
            is_sdown=False,
            is_odown=False,
        )
        self.master["num-other-sentinels"] = 0
        self.service_name = service_name
        self.slaves = []
        # Node ids whose calls should fail with a connection error / timeout.
        self.nodes_down = set()
        self.nodes_timeout = set()

    def connection_error_if_down(self, node):
        if node.id not in self.nodes_down:
            return
        raise exceptions.ConnectionError

    def timeout_if_down(self, node):
        if node.id not in self.nodes_timeout:
            return
        raise exceptions.TimeoutError

    def client(self, host, port, **kwargs):
        return SentinelTestClient(self, (host, port))
+
+
@pytest_asyncio.fixture()
async def cluster(master_ip):
    """Yield a SentinelTestCluster with redis.asyncio.sentinel.Redis patched
    to its client factory, restoring the original class afterwards.

    The try/finally guarantees the module-level patch is undone even if an
    exception propagates through the yield, so one failing test cannot leave
    the fake client installed for the rest of the session.
    """
    cluster = SentinelTestCluster(ip=master_ip)
    saved_Redis = redis.asyncio.sentinel.Redis
    redis.asyncio.sentinel.Redis = cluster.client
    try:
        yield cluster
    finally:
        redis.asyncio.sentinel.Redis = saved_Redis
+
+
@pytest_asyncio.fixture()
def sentinel(request, cluster):
    # Two sentinel endpoints; the `cluster` fixture patches Redis so these
    # resolve to fake SentinelTestClient instances, not real servers.
    return Sentinel([("foo", 26379), ("bar", 26379)])
+
+
@pytest.mark.onlynoncluster
async def test_discover_master(sentinel, master_ip):
    """The sentinel reports the configured master's address."""
    assert await sentinel.discover_master("mymaster") == (master_ip, 6379)
+
+
@pytest.mark.onlynoncluster
async def test_discover_master_error(sentinel):
    # Asking for an unknown service name must raise MasterNotFoundError.
    with pytest.raises(MasterNotFoundError):
        await sentinel.discover_master("xxx")
+
+
@pytest.mark.onlynoncluster
async def test_discover_master_sentinel_down(cluster, sentinel, master_ip):
    """A dead first sentinel is skipped and demoted in the rotation."""
    # Take the first sentinel, 'foo', offline.
    cluster.nodes_down.add(("foo", 26379))
    assert await sentinel.discover_master("mymaster") == (master_ip, 6379)
    # 'bar' must have been promoted to the front of the sentinel list.
    assert sentinel.sentinels[0].id == ("bar", 26379)
+
+
@pytest.mark.onlynoncluster
async def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip):
    """A timing-out first sentinel is skipped and demoted in the rotation."""
    # Make the first sentinel, 'foo', time out.
    cluster.nodes_timeout.add(("foo", 26379))
    assert await sentinel.discover_master("mymaster") == (master_ip, 6379)
    # 'bar' must have been promoted to the front of the sentinel list.
    assert sentinel.sentinels[0].id == ("bar", 26379)
+
+
@pytest.mark.onlynoncluster
async def test_master_min_other_sentinels(cluster, master_ip):
    """A master is only trusted once enough other sentinels can see it."""
    sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1)
    # Not enough corroborating sentinels yet: the lookup must fail.
    with pytest.raises(MasterNotFoundError):
        await sentinel.discover_master("mymaster")
    # Two other sentinels satisfy min_other_sentinels=1.
    cluster.master["num-other-sentinels"] = 2
    assert await sentinel.discover_master("mymaster") == (master_ip, 6379)
+
+
@pytest.mark.onlynoncluster
async def test_master_odown(cluster, sentinel):
    # A master flagged objectively-down must not be returned.
    cluster.master["is_odown"] = True
    with pytest.raises(MasterNotFoundError):
        await sentinel.discover_master("mymaster")
+
+
@pytest.mark.onlynoncluster
async def test_master_sdown(cluster, sentinel):
    # A master flagged subjectively-down must not be returned either.
    cluster.master["is_sdown"] = True
    with pytest.raises(MasterNotFoundError):
        await sentinel.discover_master("mymaster")
+
+
@pytest.mark.onlynoncluster
async def test_discover_slaves(cluster, sentinel):
    """Slaves are filtered by ODOWN/SDOWN state and survive sentinel outages."""
    assert await sentinel.discover_slaves("mymaster") == []

    cluster.slaves = [
        {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False},
        {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False},
    ]
    assert await sentinel.discover_slaves("mymaster") == [
        ("slave0", 1234),
        ("slave1", 1234),
    ]

    # slave0 -> ODOWN: excluded from the results.
    cluster.slaves[0]["is_odown"] = True
    assert await sentinel.discover_slaves("mymaster") == [("slave1", 1234)]

    # slave1 -> SDOWN: now no slave is eligible.
    cluster.slaves[1]["is_sdown"] = True
    assert await sentinel.discover_slaves("mymaster") == []

    cluster.slaves[0]["is_odown"] = False
    cluster.slaves[1]["is_sdown"] = False

    # node0 -> DOWN: the remaining sentinel still serves the slave list.
    cluster.nodes_down.add(("foo", 26379))
    assert await sentinel.discover_slaves("mymaster") == [
        ("slave0", 1234),
        ("slave1", 1234),
    ]
    cluster.nodes_down.clear()

    # node0 -> TIMEOUT: likewise answered by the remaining sentinel.
    cluster.nodes_timeout.add(("foo", 26379))
    assert await sentinel.discover_slaves("mymaster") == [
        ("slave0", 1234),
        ("slave1", 1234),
    ]
+
+
@pytest.mark.onlynoncluster
async def test_master_for(cluster, sentinel, master_ip):
    """master_for() returns a usable client pointed at the master address."""
    master = sentinel.master_for("mymaster", db=9)
    assert await master.ping()
    assert master.connection_pool.master_address == (master_ip, 6379)

    # Repeat with the internal connection check enabled.
    checked = sentinel.master_for("mymaster", db=9, check_connection=True)
    assert await checked.ping()
+
+
@pytest.mark.onlynoncluster
async def test_slave_for(cluster, sentinel):
    """slave_for() returns a client connected to a healthy slave."""
    cluster.slaves = [
        {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False},
    ]
    replica = sentinel.slave_for("mymaster", db=9)
    assert await replica.ping()
+
+
@pytest.mark.onlynoncluster
async def test_slave_for_slave_not_found_error(cluster, sentinel):
    # With no slaves configured and the master down, ping has no node to hit.
    cluster.master["is_odown"] = True
    slave = sentinel.slave_for("mymaster", db=9)
    with pytest.raises(SlaveNotFoundError):
        await slave.ping()
+
+
@pytest.mark.onlynoncluster
async def test_slave_round_robin(cluster, sentinel, master_ip):
    """rotate_slaves() cycles the slaves, falls back to the master, then fails."""
    cluster.slaves = [
        {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False},
        {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False},
    ]
    pool = SentinelConnectionPool("mymaster", sentinel)
    rotator = pool.rotate_slaves()
    expected_slaves = (("slave0", 6379), ("slave1", 6379))
    assert await rotator.__anext__() in expected_slaves
    assert await rotator.__anext__() in expected_slaves
    # Both slaves consumed: the master is offered as a fallback...
    assert await rotator.__anext__() == (master_ip, 6379)
    # ...and after that the rotation is exhausted.
    with pytest.raises(SlaveNotFoundError):
        await rotator.__anext__()
+
+
@pytest.mark.onlynoncluster
async def test_ckquorum(cluster, sentinel):
    # Smoke test: the SENTINEL CKQUORUM wrapper call must not explode.
    assert await sentinel.sentinel_ckquorum("mymaster")
+
+
@pytest.mark.onlynoncluster
async def test_flushconfig(cluster, sentinel):
    # Smoke test: the SENTINEL FLUSHCONFIG wrapper call must not explode.
    assert await sentinel.sentinel_flushconfig()
+
+
@pytest.mark.onlynoncluster
async def test_reset(cluster, sentinel):
    # Smoke test: SENTINEL RESET succeeds even when the master is flagged down.
    cluster.master["is_odown"] = True
    assert await sentinel.sentinel_reset("mymaster")