diff options
author | Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> | 2022-02-22 05:29:55 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-22 12:29:55 +0200 |
commit | d56baeb683fc1935cfa343fa2eeb0fa9bd955283 (patch) | |
tree | 47357a74bf1d1428cfbcf0d8b2c781f1f971cf77 /tests | |
parent | e3c989d93e914e6502bd5a72f15ded49a135c5be (diff) | |
download | redis-py-d56baeb683fc1935cfa343fa2eeb0fa9bd955283.tar.gz |
Add Async Support (#1899)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/conftest.py | 79 | ||||
-rw-r--r-- | tests/test_asyncio/__init__.py | 0 | ||||
-rw-r--r-- | tests/test_asyncio/compat.py | 6 | ||||
-rw-r--r-- | tests/test_asyncio/conftest.py | 205 | ||||
-rw-r--r-- | tests/test_asyncio/test_commands.py | 3 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection.py | 64 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection_pool.py | 884 | ||||
-rw-r--r-- | tests/test_asyncio/test_encoding.py | 126 | ||||
-rw-r--r-- | tests/test_asyncio/test_lock.py | 242 | ||||
-rw-r--r-- | tests/test_asyncio/test_monitor.py | 67 | ||||
-rw-r--r-- | tests/test_asyncio/test_pipeline.py | 409 | ||||
-rw-r--r-- | tests/test_asyncio/test_pubsub.py | 660 | ||||
-rw-r--r-- | tests/test_asyncio/test_retry.py | 70 | ||||
-rw-r--r-- | tests/test_asyncio/test_scripting.py | 159 | ||||
-rw-r--r-- | tests/test_asyncio/test_sentinel.py | 249 |
15 files changed, 3216 insertions, 7 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index d9de876..2534ca0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +import argparse import random import time +from typing import Callable, TypeVar from unittest.mock import Mock from urllib.parse import urlparse @@ -21,6 +23,54 @@ default_redis_unstable_url = "redis://localhost:6378" default_redis_ssl_url = "rediss://localhost:6666" default_cluster_nodes = 6 +_DecoratedTest = TypeVar("_DecoratedTest", bound="Callable") +_TestDecorator = Callable[[_DecoratedTest], _DecoratedTest] + + +# Taken from python3.9 +class BooleanOptionalAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None, + ): + + _option_strings = [] + for option_string in option_strings: + _option_strings.append(option_string) + + if option_string.startswith("--"): + option_string = "--no-" + option_string[2:] + _option_strings.append(option_string) + + if help is not None and default is not None: + help += f" (default: {default})" + + super().__init__( + option_strings=_option_strings, + dest=dest, + nargs=0, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar, + ) + + def __call__(self, parser, namespace, values, option_string=None): + if option_string in self.option_strings: + setattr(namespace, self.dest, not option_string.startswith("--no-")) + + def format_usage(self): + return " | ".join(self.option_strings) + def pytest_addoption(parser): parser.addoption( @@ -62,6 +112,9 @@ def pytest_addoption(parser): help="Redis unstable (latest version) connection string " "defaults to %(default)s`", ) + parser.addoption( + "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop" + ) def _get_info(redis_url): @@ -101,6 +154,18 @@ def pytest_sessionstart(session): cluster_nodes = session.config.getoption("--redis-cluster-nodes") wait_for_cluster_creation(redis_url, cluster_nodes) + use_uvloop = session.config.getoption("--uvloop") + + if use_uvloop: + try: + import uvloop + + uvloop.install() + except ImportError as e: + raise RuntimeError( + "Can not import uvloop, make sure it is installed" + ) from e + def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): """ @@ -133,19 +198,19 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=60): ) -def skip_if_server_version_lt(min_version): +def skip_if_server_version_lt(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = Version(redis_version) < Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}") -def skip_if_server_version_gte(min_version): +def skip_if_server_version_gte(min_version: str) -> _TestDecorator: redis_version = REDIS_INFO["version"] check = Version(redis_version) >= Version(min_version) return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}") -def skip_unless_arch_bits(arch_bits): +def skip_unless_arch_bits(arch_bits: int) -> _TestDecorator: return pytest.mark.skipif( REDIS_INFO["arch_bits"] != arch_bits, reason=f"server is not {arch_bits}-bit" ) @@ -169,17 +234,17 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): raise AttributeError(f"No redis module named {module_name}") -def skip_if_redis_enterprise(): +def skip_if_redis_enterprise() -> _TestDecorator: check = REDIS_INFO["enterprise"] is True return pytest.mark.skipif(check, reason="Redis enterprise") -def skip_ifnot_redis_enterprise(): +def skip_ifnot_redis_enterprise() -> _TestDecorator: check = REDIS_INFO["enterprise"] is False return pytest.mark.skipif(check, reason="Not running in redis enterprise") -def skip_if_nocryptography(): +def skip_if_nocryptography() -> _TestDecorator: try: import cryptography # noqa @@ -188,7 +253,7 @@ def skip_if_nocryptography(): return pytest.mark.skipif(True, reason="No cryptography dependency") -def skip_if_cryptography(): +def skip_if_cryptography() -> _TestDecorator: try: import cryptography # noqa diff --git a/tests/test_asyncio/__init__.py b/tests/test_asyncio/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/test_asyncio/__init__.py diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py new file mode 100644 index 0000000..ced4974 --- /dev/null +++ b/tests/test_asyncio/compat.py @@ -0,0 +1,6 @@ +from unittest import mock + +try: + mock.AsyncMock +except AttributeError: + import mock diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py new file mode 100644 index 0000000..0e9c73e --- /dev/null +++ b/tests/test_asyncio/conftest.py @@ -0,0 +1,205 @@ +import asyncio +import random +import sys +from typing import Union +from urllib.parse import urlparse + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import pytest +from packaging.version import Version + +import redis.asyncio as redis +from redis.asyncio.client import Monitor +from redis.asyncio.connection import ( + HIREDIS_AVAILABLE, + HiredisParser, + PythonParser, + parse_url, +) +from tests.conftest import REDIS_INFO + +from .compat import mock + + +async def _get_info(redis_url): + client = redis.Redis.from_url(redis_url) + info = await client.info() + await client.connection_pool.disconnect() + return info + + +@pytest_asyncio.fixture( + params=[ + (True, PythonParser), + (False, PythonParser), + pytest.param( + (True, HiredisParser), + marks=pytest.mark.skipif( + not HIREDIS_AVAILABLE, reason="hiredis is not installed" + ), + ), + pytest.param( + (False, HiredisParser), + marks=pytest.mark.skipif( + not HIREDIS_AVAILABLE, reason="hiredis is not installed" + ), + ), + ], + ids=[ + "single-python-parser", + "pool-python-parser", + "single-hiredis", + "pool-hiredis", + ], +) +def create_redis(request, event_loop: asyncio.BaseEventLoop): + """Wrapper around redis.create_redis.""" + single_connection, parser_cls = request.param + + async def f(url: str = request.config.getoption("--redis-url"), **kwargs): + single = kwargs.pop("single_connection_client", False) or single_connection + parser_class = kwargs.pop("parser_class", None) or parser_cls + url_options = parse_url(url) + url_options.update(kwargs) + pool = redis.ConnectionPool(parser_class=parser_class, **url_options) + client: redis.Redis = redis.Redis(connection_pool=pool) + if single: + client = client.client() + await client.initialize() + + def teardown(): + async def ateardown(): + if "username" in kwargs: + return + try: + await client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() + + if event_loop.is_running(): + event_loop.create_task(ateardown()) + else: + event_loop.run_until_complete(ateardown()) + + request.addfinalizer(teardown) + + return client + + return f + + +@pytest_asyncio.fixture() +async def r(create_redis): + yield await create_redis() + + +@pytest_asyncio.fixture() +async def r2(create_redis): + """A second client for tests that need multiple""" + yield await create_redis() + + +def _gen_cluster_mock_resp(r, response): + connection = mock.AsyncMock() + connection.read_response.return_value = response + r.connection = connection + return r + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_ok(create_redis, **kwargs): + r = await create_redis(**kwargs) + return _gen_cluster_mock_resp(r, "OK") + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_int(create_redis, **kwargs): + r = await create_redis(**kwargs) + return _gen_cluster_mock_resp(r, "2") + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_info(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "cluster_state:ok\r\ncluster_slots_assigned:16384\r\n" + "cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n" + "cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n" + "cluster_size:3\r\ncluster_current_epoch:7\r\n" + "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" + "cluster_stats_messages_received:105653\r\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_nodes(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest_asyncio.fixture() +async def mock_cluster_resp_slaves(create_redis, **kwargs): + r = await create_redis(**kwargs) + response = ( + "['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836789290 3 connected']" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest_asyncio.fixture(scope="session") +def master_host(request): + url = request.config.getoption("--redis-url") + parts = urlparse(url) + yield parts.hostname + + +async def wait_for_command( + client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None +): + # issue a command with a key name that's local to this process. + # if we find a command with our key before the command we're waiting + # for, something went wrong + if key is None: + # generate key + redis_version = REDIS_INFO["version"] + if Version(redis_version) >= Version("5.0.0"): + id_str = str(client.client_id()) + else: + id_str = f"{random.randrange(2 ** 32):08x}" + key = f"__REDIS-PY-{id_str}__" + await client.get(key) + while True: + monitor_response = await monitor.next_command() + if command in monitor_response["command"]: + return monitor_response + if key in monitor_response["command"]: + return None diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py new file mode 100644 index 0000000..12855c9 --- /dev/null +++ b/tests/test_asyncio/test_commands.py @@ -0,0 +1,3 @@ +""" +Tests async overrides of commands from their mixins +""" diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py new file mode 100644 index 0000000..46abec0 --- /dev/null +++ b/tests/test_asyncio/test_connection.py @@ -0,0 +1,64 @@ +import asyncio +import types + +import pytest + +from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection +from redis.exceptions import InvalidResponse +from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import skip_if_server_version_lt + +from .compat import mock + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +async def test_invalid_response(create_redis): + r = await create_redis(single_connection_client=True) + + raw = b"x" + readline_mock = mock.AsyncMock(return_value=raw) + + parser: "PythonParser" = r.connection._parser + with mock.patch.object(parser._buffer, "readline", readline_mock): + with pytest.raises(InvalidResponse) as cm: + await parser.read_response() + assert str(cm.value) == f"Protocol Error: {raw!r}" + + +@skip_if_server_version_lt("4.0.0") +@pytest.mark.redismod +@pytest.mark.onlynoncluster +async def test_loading_external_modules(modclient): + def inner(): + pass + + modclient.load_external_module("myfuncname", inner) + assert getattr(modclient, "myfuncname") == inner + assert isinstance(getattr(modclient, "myfuncname"), types.FunctionType) + + # and call it + from redis.commands import RedisModuleCommands + + j = RedisModuleCommands.json + modclient.load_external_module("sometestfuncname", j) + + # d = {'hello': 'world!'} + # mod = j(modclient) + # mod.set("fookey", ".", d) + # assert mod.get('fookey') == d + + +@pytest.mark.onlynoncluster +async def test_socket_param_regression(r): + """A regression test for issue #1060""" + conn = UnixDomainSocketConnection() + _ = await conn.disconnect() is True + + +@pytest.mark.onlynoncluster +async def test_can_run_concurrent_commands(r): + assert await r.ping() is True + assert all(await asyncio.gather(*(r.ping() for _ in range(10)))) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py new file mode 100644 index 0000000..f9dfefd --- /dev/null +++ b/tests/test_asyncio/test_connection_pool.py @@ -0,0 +1,884 @@ +import asyncio +import os +import re +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio as redis +from redis.asyncio.connection import Connection, to_bool +from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt + +from .compat import mock +from .test_pubsub import wait_for_message + +pytestmark = pytest.mark.asyncio + + +class TestRedisAutoReleaseConnectionPool: + @pytest_asyncio.fixture + async def r(self, create_redis) -> redis.Redis: + """This is necessary since r and r2 create ConnectionPools behind the scenes""" + r = await create_redis() + r.auto_close_connection_pool = True + yield r + + @staticmethod + def get_total_connected_connections(pool): + return len(pool._available_connections) + len(pool._in_use_connections) + + @staticmethod + async def create_two_conn(r: redis.Redis): + if not r.single_connection_client: # Single already initialized connection + r.connection = await r.connection_pool.get_connection("_") + return await r.connection_pool.get_connection("_") + + @staticmethod + def has_no_connected_connections(pool: redis.ConnectionPool): + return not any( + x.is_connected + for x in pool._available_connections + list(pool._in_use_connections) + ) + + async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis): + new_conn = await self.create_two_conn(r) + assert new_conn != r.connection + assert self.get_total_connected_connections(r.connection_pool) == 2 + await r.close() + assert self.has_no_connected_connections(r.connection_pool) + + async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): + assert r2.auto_close_connection_pool is False, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + new_conn = await self.create_two_conn(r2) + assert self.get_total_connected_connections(r2.connection_pool) == 2 + await r2.close() + assert r2.connection_pool._in_use_connections == {new_conn} + assert new_conn.is_connected + assert len(r2.connection_pool._available_connections) == 1 + assert r2.connection_pool._available_connections[0].is_connected + + async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis): + assert r.auto_close_connection_pool is True, "This is from the class fixture" + await self.create_two_conn(r) + await r.close() + assert self.get_total_connected_connections(r.connection_pool) == 2, ( + "The connection pool should not be disconnected as a manually created " + "connection pool was passed in in conftest.py" + ) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_close_override(self, r: redis.Redis, auto_close_conn_pool): + r.auto_close_connection_pool = auto_close_conn_pool + await self.create_two_conn(r) + await r.close(close_connection_pool=True) + assert self.has_no_connected_connections(r.connection_pool) + + @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) + async def test_negate_auto_close_client_pool( + self, r: redis.Redis, auto_close_conn_pool + ): + r.auto_close_connection_pool = auto_close_conn_pool + new_conn = await self.create_two_conn(r) + await r.close(close_connection_pool=False) + assert not self.has_no_connected_connections(r.connection_pool) + assert r.connection_pool._in_use_connections == {new_conn} + assert r.connection_pool._available_connections[0].is_connected + assert self.get_total_connected_connections(r.connection_pool) == 2 + + +class DummyConnection(Connection): + description_format = "DummyConnection<>" + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.pid = os.getpid() + + async def connect(self): + pass + + async def disconnect(self): + pass + + async def can_read(self, timeout: float = 0): + return False + + +@pytest.mark.onlynoncluster +class TestConnectionPool: + def get_pool( + self, + connection_kwargs=None, + max_connections=None, + connection_class=redis.Connection, + ): + connection_kwargs = connection_kwargs or {} + pool = redis.ConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, + ) + return pool + + async def test_connection_creation(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=DummyConnection + ) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_max_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.get_connection("_") + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = { + "host": "localhost", + "port": 6379, + "db": 1, + "client_name": "test-client", + } + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=redis.Connection + ) + expected = ( + "ConnectionPool<Connection<" + "host=localhost,port=6379,db=1,client_name=test-client>>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, + connection_class=redis.UnixDomainSocketConnection, + ) + expected = ( + "ConnectionPool<UnixDomainSocketConnection<" + "path=/abc,db=1,client_name=test-client>>" + ) + assert repr(pool) == expected + + +@pytest.mark.onlynoncluster +class TestBlockingConnectionPool: + def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + connection_kwargs = connection_kwargs or {} + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs, + ) + return pool + + async def test_connection_creation(self, master_host): + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_disconnect(self, master_host): + """A regression test for #1047""" + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "host": master_host[0], + "port": master_host[1], + } + pool = self.get_pool(connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.disconnect() + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_connection_pool_blocks_until_timeout(self, master_host): + """When out of connections, block for timeout seconds, then raise""" + connection_kwargs = {"host": master_host} + pool = self.get_pool( + max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs + ) + await pool.get_connection("_") + + start = asyncio.get_event_loop().time() + with pytest.raises(redis.ConnectionError): + await pool.get_connection("_") + # we should have waited at least 0.1 seconds + assert asyncio.get_event_loop().time() - start >= 0.1 + + async def test_connection_pool_blocks_until_conn_available(self, master_host): + """ + When out of connections, block until another connection is released + to the pool + """ + connection_kwargs = {"host": master_host[0], "port": master_host[1]} + pool = self.get_pool( + max_connections=1, timeout=2, connection_kwargs=connection_kwargs + ) + c1 = await pool.get_connection("_") + + async def target(): + await asyncio.sleep(0.1) + await pool.release(c1) + + start = asyncio.get_event_loop().time() + await asyncio.gather(target(), pool.get_connection("_")) + assert asyncio.get_event_loop().time() - start >= 0.1 + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + pool = redis.ConnectionPool( + host="localhost", port=6379, client_name="test-client" + ) + expected = ( + "ConnectionPool<Connection<" + "host=localhost,port=6379,db=0,client_name=test-client>>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + pool = redis.ConnectionPool( + connection_class=redis.UnixDomainSocketConnection, + path="abc", + client_name="test-client", + ) + expected = ( + "ConnectionPool<UnixDomainSocketConnection<" + "path=abc,db=0,client_name=test-client>>" + ) + assert repr(pool) == expected + + +@pytest.mark.onlynoncluster +class TestConnectionPoolURLParsing: + def test_hostname(self): + pool = redis.ConnectionPool.from_url("redis://my.host") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_quoted_hostname(self): + pool = redis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "my / host +=+", + } + + def test_port(self): + pool = redis.ConnectionPool.from_url("redis://localhost:6380") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "port": 6380, + } + + @skip_if_server_version_lt("6.0.0") + def test_username(self): + pool = redis.ConnectionPool.from_url("redis://myuser:@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + } + + @skip_if_server_version_lt("6.0.0") + def test_quoted_username(self): + pool = redis.ConnectionPool.from_url( + "redis://%2Fmyuser%2F%2B name%3D%24+:@localhost" + ) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = redis.ConnectionPool.from_url("redis://:mypassword@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + "redis://:%2Fmypass%2F%2B word%3D%24+@localhost" + ) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "/mypass/+ word=$+", + } + + @skip_if_server_version_lt("6.0.0") + def test_username_and_password(self): + pool = redis.ConnectionPool.from_url("redis://myuser:mypass@localhost") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + "password": "mypass", + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url("redis://localhost", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 1, + } + + def test_db_in_path(self): + pool = redis.ConnectionPool.from_url("redis://localhost/2", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 3, + } + + def test_extra_typed_querystring_options(self): + pool = redis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10" + ) + + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, + } + assert pool.max_connections == 10 + + def test_boolean_parsing(self): + for expected, value in ( + (None, None), + (None, ""), + (False, 0), + (False, "0"), + (False, "f"), + (False, "F"), + (False, "False"), + (False, "n"), + (False, "N"), + (False, "No"), + (True, 1), + (True, "1"), + (True, "y"), + (True, "Y"), + (True, "Yes"), + ): + assert expected is to_bool(value) + + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_invalid_extra_typed_querystring_options(self): + with pytest.raises(ValueError): + redis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=_&" "socket_connect_timeout=abc" + ) + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url("redis://localhost?a=1&b=2") + assert pool.connection_class == redis.Connection + assert pool.connection_kwargs == {"host": "localhost", "a": "1", "b": "2"} + + def test_calling_from_subclass_returns_correct_instance(self): + pool = redis.BlockingConnectionPool.from_url("redis://localhost") + assert isinstance(pool, redis.BlockingConnectionPool) + + def test_client_creates_connection_pool(self): + r = redis.Redis.from_url("redis://myhost") + assert r.connection_pool.connection_class == redis.Connection + assert r.connection_pool.connection_kwargs == { + "host": "myhost", + } + + def test_invalid_scheme_raises_error(self): + with pytest.raises(ValueError) as cm: + redis.ConnectionPool.from_url("localhost") + assert str(cm.value) == ( + "Redis URL must specify one of the following schemes " + "(redis://, rediss://, unix://)" + ) + + +@pytest.mark.onlynoncluster +class TestConnectionPoolUnixSocketURLParsing: + def test_defaults(self): + pool = redis.ConnectionPool.from_url("unix:///socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + } + + @skip_if_server_version_lt("6.0.0") + def test_username(self): + pool = redis.ConnectionPool.from_url("unix://myuser:@/socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "myuser", + } + + @skip_if_server_version_lt("6.0.0") + def test_quoted_username(self): + pool = redis.ConnectionPool.from_url( + "unix://%2Fmyuser%2F%2B name%3D%24+:@/socket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = redis.ConnectionPool.from_url("unix://:mypassword@/socket") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = redis.ConnectionPool.from_url( + "unix://:%2Fmypass%2F%2B word%3D%24+@/socket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "/mypass/+ word=$+", + } + + def test_quoted_path(self): + pool = redis.ConnectionPool.from_url( + "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket" + ) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/my/path/to/../+_+=$ocket", + "password": "mypassword", + } + + def test_db_as_argument(self): + pool = redis.ConnectionPool.from_url("unix:///socket", db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 1, + } + + def test_db_in_querystring(self): + pool = redis.ConnectionPool.from_url("unix:///socket?db=2", db=1) + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 2, + } + + def test_client_name_in_querystring(self): + pool = redis.ConnectionPool.from_url("redis://location?client_name=test-client") + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_extra_querystring_options(self): + pool = redis.ConnectionPool.from_url("unix:///socket?a=1&b=2") + assert pool.connection_class == redis.UnixDomainSocketConnection + assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} + + +@pytest.mark.onlynoncluster +class TestSSLConnectionURLParsing: + def test_host(self): + pool = redis.ConnectionPool.from_url("rediss://my.host") + assert pool.connection_class == redis.SSLConnection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_cert_reqs_options(self): + import ssl + + class DummyConnectionPool(redis.ConnectionPool): + def get_connection(self, *args, **kwargs): + return self.make_connection() + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") + assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") + assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") + assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") + assert pool.get_connection("_").check_hostname is False + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") + assert pool.get_connection("_").check_hostname is True + + +@pytest.mark.onlynoncluster +class TestConnection: + async def test_on_connect_error(self): + """ + An error in Connection.on_connect should disconnect from the server + see for details: https://github.com/andymccurdy/redis-py/issues/368 + """ + # this assumes the Redis server being tested against doesn't have + # 9999 databases ;) + bad_connection = redis.Redis(db=9999) + # an error should be raised on connect + with pytest.raises(redis.RedisError): + await bad_connection.info() + pool = bad_connection.connection_pool + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_disconnects_socket(self, r): + """ + If Redis raises a LOADING error, the connection should be + disconnected and a BusyLoadingError raised + """ + with pytest.raises(redis.BusyLoadingError): + await r.execute_command("DEBUG", "ERROR", "LOADING fake message") + if r.connection: + assert not r.connection._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_from_pipeline_immediate_command(self, r): + """ + BusyLoadingErrors should raise from Pipelines that execute a + command immediately, like WATCH does. + """ + pipe = r.pipeline() + with pytest.raises(redis.BusyLoadingError): + await pipe.immediate_execute_command( + "DEBUG", "ERROR", "LOADING fake message" + ) + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_busy_loading_from_pipeline(self, r): + """ + BusyLoadingErrors should be raised from a pipeline execution + regardless of the raise_on_error flag. + """ + pipe = r.pipeline() + pipe.execute_command("DEBUG", "ERROR", "LOADING fake message") + with pytest.raises(redis.BusyLoadingError): + await pipe.execute() + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @skip_if_server_version_lt("2.8.8") + @skip_if_redis_enterprise() + async def test_read_only_error(self, r): + """READONLY errors get turned in ReadOnlyError exceptions""" + with pytest.raises(redis.ReadOnlyError): + await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + + def test_connect_from_url_tcp(self): + connection = redis.Redis.from_url("redis://localhost") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "Connection", + "host=localhost,port=6379,db=0", + ) + + def test_connect_from_url_unix(self): + connection = redis.Redis.from_url("unix:///path/to/socket") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "UnixDomainSocketConnection", + "path=/path/to/socket,db=0", + ) + + @skip_if_redis_enterprise() + async def test_connect_no_auth_supplied_when_required(self, r): + """ + AuthenticationError should be raised when the server requires a + password but one isn't supplied. + """ + with pytest.raises(redis.AuthenticationError): + await r.execute_command( + "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set" + ) + + @skip_if_redis_enterprise() + async def test_connect_invalid_password_supplied(self, r): + """AuthenticationError should be raised when sending the wrong password""" + with pytest.raises(redis.AuthenticationError): + await r.execute_command("DEBUG", "ERROR", "ERR invalid password") + + +@pytest.mark.onlynoncluster +class TestMultiConnectionClient: + @pytest_asyncio.fixture() + async def r(self, create_redis, server): + redis = await create_redis(single_connection_client=False) + yield redis + await redis.flushall() + + +@pytest.mark.onlynoncluster +class TestHealthCheck: + interval = 60 + + @pytest_asyncio.fixture() + async def r(self, create_redis): + redis = await create_redis(health_check_interval=self.interval) + yield redis + await redis.flushall() + + def assert_interval_advanced(self, connection): + diff = connection.next_health_check - asyncio.get_event_loop().time() + assert self.interval > diff > (self.interval - 1) + + async def test_health_check_runs(self, r): + if r.connection: + r.connection.next_health_check = asyncio.get_event_loop().time() - 1 + await r.connection.check_health() + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_invokes_health_check(self, r): + # invoke a command to make sure the connection is entirely setup + if r.connection: + await r.get("foo") + r.connection.next_health_check = asyncio.get_event_loop().time() + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + m.assert_called_with("PING", check_health=False) + + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_advances_next_health_check(self, r): + if r.connection: + await r.get("foo") + next_health_check = r.connection.next_health_check + await r.get("foo") + assert next_health_check < r.connection.next_health_check + + async def test_health_check_not_invoked_within_interval(self, r): + if r.connection: + await r.get("foo") + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + ping_call_spec = (("PING",), {"check_health": False}) + assert ping_call_spec not in m.call_args_list + + async def test_health_check_in_pipeline(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_transaction(self, r): + async with r.pipeline(transaction=True) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_watched_pipeline(self, r): + await r.set("foo", "bar") + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + await pipe.watch("foo") + # the health check should be called when watching + m.assert_called_with("PING", check_health=False) + self.assert_interval_advanced(pipe.connection) + assert await pipe.get("foo") == b"bar" + + # reset the mock to clear the call list and schedule another + # health check + m.reset_mock() + pipe.connection.next_health_check = 0 + + pipe.multi() + responses = await pipe.set("foo", "not-bar").get("foo").execute() + assert responses == [True, b"not-bar"] + m.assert_any_call("PING", check_health=False) + + async def test_health_check_in_pubsub_before_subscribe(self, r): + """A health check happens before the first [p]subscribe""" + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + assert not p.subscribed + await p.subscribe("foo") + # the connection is not yet in pubsub mode, so the normal + # ping/pong within connection.send_command should check + # the health of the connection + m.assert_any_call("PING", check_health=False) + self.assert_interval_advanced(p.connection) + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + + async def test_health_check_in_pubsub_after_subscribed(self, r): + """ + Pubsub can handle a new subscribe when it's time to check the + connection health + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + # because we weren't subscribed when sending the subscribe + # message to 'foo', the connection's standard check_health ran + # prior to subscribing. + m.assert_any_call("PING", check_health=False) + + p.connection.next_health_check = 0 + m.reset_mock() + + await p.subscribe("bar") + # the second subscribe issues exactly only command (the subscribe) + # and the health check is not invoked + m.assert_called_once_with("SUBSCRIBE", "bar", check_health=False) + + # since no message has been read since the health check was + # reset, it should still be 0 + assert p.connection.next_health_check == 0 + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + assert await wait_for_message(p) is None + # now that the connection is subscribed, the pubsub health + # check should have taken over and include the HEALTH_CHECK_MESSAGE + m.assert_any_call("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) + + async def test_health_check_in_pubsub_poll(self, r): + """ + Polling a pubsub connection that's subscribed will regularly + check the connection's health. + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + + # polling the connection before the health check interval + # doesn't result in another health check + m.reset_mock() + next_health_check = p.connection.next_health_check + assert await wait_for_message(p) is None + assert p.connection.next_health_check == next_health_check + m.assert_not_called() + + # reset the health check and poll again + # we should not receive a pong message, but the next_health_check + # should be advanced + p.connection.next_health_check = 0 + assert await wait_for_message(p) is None + m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) diff --git a/tests/test_asyncio/test_encoding.py b/tests/test_asyncio/test_encoding.py new file mode 100644 index 0000000..da29837 --- /dev/null +++ b/tests/test_asyncio/test_encoding.py @@ -0,0 +1,126 @@ +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio as redis +from redis.exceptions import DataError + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestEncoding: + @pytest_asyncio.fixture() + async def r(self, create_redis): + redis = await create_redis(decode_responses=True) + yield redis + await redis.flushall() + + @pytest_asyncio.fixture() + async def r_no_decode(self, create_redis): + redis = await create_redis(decode_responses=False) + yield redis + await redis.flushall() + + async def test_simple_encoding(self, r_no_decode: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r_no_decode.set("unicode-string", unicode_string.encode("utf-8")) + cached_val = await r_no_decode.get("unicode-string") + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_simple_encoding_and_decoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r.set("unicode-string", unicode_string) + cached_val = await r.get("unicode-string") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_memoryview_encoding(self, r_no_decode: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r_no_decode.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r_no_decode.get("unicode-string-memoryview") + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_memoryview_encoding_and_decoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r.get("unicode-string-memoryview") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_list_encoding(self, r: redis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + result = [unicode_string, unicode_string, unicode_string] + await r.rpush("a", *result) + assert await r.lrange("a", 0, -1) == result + + +@pytest.mark.onlynoncluster +class TestEncodingErrors: + async def test_ignore(self, create_redis): + r = await create_redis( + decode_responses=True, + encoding_errors="ignore", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo" + + async def test_replace(self, create_redis): + r = await create_redis( + decode_responses=True, + encoding_errors="replace", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo\ufffd" + + +@pytest.mark.onlynoncluster +class TestMemoryviewsAreNotPacked: + async def test_memoryviews_are_not_packed(self, r): + arg = memoryview(b"some_arg") + arg_list = ["SOME_COMMAND", arg] + c = r.connection or await r.connection_pool.get_connection("_") + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + +class TestCommandsAreNotEncoded: + @pytest_asyncio.fixture() + async def r(self, create_redis): + redis = await create_redis(encoding="utf-16") + yield redis + await redis.flushall() + + async def test_basic_command(self, r: redis.Redis): + await r.set("hello", "world") + + +class TestInvalidUserInput: + async def test_boolean_fails(self, r: redis.Redis): + with pytest.raises(DataError): + await r.set("a", True) # type: ignore + + async def test_none_fails(self, r: redis.Redis): + with pytest.raises(DataError): + await r.set("a", None) # type: ignore + + async def test_user_type_fails(self, r: redis.Redis): + class Foo: + def __str__(self): + return "Foo" + + with pytest.raises(DataError): + await r.set("a", Foo()) # type: ignore diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py new file mode 100644 index 0000000..c496718 --- /dev/null +++ b/tests/test_asyncio/test_lock.py @@ -0,0 +1,242 @@ +import asyncio +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +from redis.asyncio.lock import Lock +from redis.exceptions import LockError, LockNotOwnedError + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestLock: + @pytest_asyncio.fixture() + async def r_decoded(self, create_redis): + redis = await create_redis(decode_responses=True) + yield redis + await redis.flushall() + + def get_lock(self, redis, *args, **kwargs): + kwargs["lock_class"] = Lock + return redis.lock(*args, **kwargs) + + async def test_lock(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + assert await r.get("foo") == lock.local.token + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + + async def test_lock_token(self, r): + lock = self.get_lock(r, "foo") + await self._test_lock_token(r, lock) + + async def test_lock_token_thread_local_false(self, r): + lock = self.get_lock(r, "foo", thread_local=False) + await self._test_lock_token(r, lock) + + async def _test_lock_token(self, r, lock): + assert await lock.acquire(blocking=False, token="test") + assert await r.get("foo") == b"test" + assert lock.local.token == b"test" + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + assert lock.local.token is None + + async def test_locked(self, r): + lock = self.get_lock(r, "foo") + assert await lock.locked() is False + await lock.acquire(blocking=False) + assert await lock.locked() is True + await lock.release() + assert await lock.locked() is False + + async def _test_owned(self, client): + lock = self.get_lock(client, "foo") + assert await lock.owned() is False + await lock.acquire(blocking=False) + assert await lock.owned() is True + await lock.release() + assert await lock.owned() is False + + lock2 = self.get_lock(client, "foo") + assert await lock.owned() is False + assert await lock2.owned() is False + await lock2.acquire(blocking=False) + assert await lock.owned() is False + assert await lock2.owned() is True + await lock2.release() + assert await lock.owned() is False + assert await lock2.owned() is False + + async def test_owned(self, r): + await self._test_owned(r) + + async def test_owned_with_decoded_responses(self, r_decoded): + await self._test_owned(r_decoded) + + async def test_competing_locks(self, r): + lock1 = self.get_lock(r, "foo") + lock2 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + assert not await lock2.acquire(blocking=False) + await lock1.release() + assert await lock2.acquire(blocking=False) + assert not await lock1.acquire(blocking=False) + await lock2.release() + + async def test_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8 < (await r.ttl("foo")) <= 10 + await lock.release() + + async def test_float_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=9.5) + assert await lock.acquire(blocking=False) + assert 8 < (await r.pttl("foo")) <= 9500 + await lock.release() + + async def test_blocking_timeout(self, r, event_loop): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + bt = 0.2 + sleep = 0.05 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = event_loop.time() + assert not await lock2.acquire() + # The elapsed duration should be less than the total blocking_timeout + assert bt > (event_loop.time() - start) > bt - sleep + await lock1.release() + + async def test_context_manager(self, r): + # blocking_timeout prevents a deadlock if the lock can't be acquired + # for some reason + async with self.get_lock(r, "foo", blocking_timeout=0.2) as lock: + assert await r.get("foo") == lock.local.token + assert await r.get("foo") is None + + async def test_context_manager_raises_when_locked_not_acquired(self, r): + await r.set("foo", "bar") + with pytest.raises(LockError): + async with self.get_lock(r, "foo", blocking_timeout=0.1): + pass + + async def test_high_sleep_small_blocking_timeout(self, r): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + sleep = 60 + bt = 1 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = asyncio.get_event_loop().time() + assert not await lock2.acquire() + # the elapsed timed is less than the blocking_timeout as the lock is + # unattainable given the sleep/blocking_timeout configuration + assert bt > (asyncio.get_event_loop().time() - start) + await lock1.release() + + async def test_releasing_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo") + with pytest.raises(LockError): + await lock.release() + + async def test_releasing_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo") + await lock.acquire(blocking=False) + # manually change the token + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.release() + # even though we errored, the token is still cleared + assert lock.local.token is None + + async def test_extend_lock(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10) + assert 16000 < (await r.pttl("foo")) <= 20000 + await lock.release() + + async def test_extend_lock_replace_ttl(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10, replace_ttl=True) + assert 8000 < (await r.pttl("foo")) <= 10000 + await lock.release() + + async def test_extend_lock_float(self, r): + lock = self.get_lock(r, "foo", timeout=10.0) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10.0) + assert 16000 < (await r.pttl("foo")) <= 20000 + await lock.release() + + async def test_extending_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + with pytest.raises(LockError): + await lock.extend(10) + + async def test_extending_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + with pytest.raises(LockError): + await lock.extend(10) + await lock.release() + + async def test_extending_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.extend(10) + + async def test_reacquire_lock(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert await r.pexpire("foo", 5000) + assert await r.pttl("foo") <= 5000 + assert await lock.reacquire() + assert 8000 < (await r.pttl("foo")) <= 10000 + await lock.release() + + async def test_reacquiring_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + with pytest.raises(LockError): + await lock.reacquire() + + async def test_reacquiring_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + with pytest.raises(LockError): + await lock.reacquire() + await lock.release() + + async def test_reacquiring_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.reacquire() + + +@pytest.mark.onlynoncluster +class TestLockClassSelection: + def test_lock_class_argument(self, r): + class MyLock: + def __init__(self, *args, **kwargs): + + pass + + lock = r.lock("foo", lock_class=MyLock) + assert type(lock) == MyLock diff --git a/tests/test_asyncio/test_monitor.py b/tests/test_asyncio/test_monitor.py new file mode 100644 index 0000000..783ba26 --- /dev/null +++ b/tests/test_asyncio/test_monitor.py @@ -0,0 +1,67 @@ +import pytest + +from tests.conftest import skip_if_redis_enterprise, skip_ifnot_redis_enterprise + +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestMonitor: + async def test_wait_command_not_found(self, r): + """Make sure the wait_for_command func works when command is not found""" + async with r.monitor() as m: + response = await wait_for_command(r, m, "nothing") + assert response is None + + async def test_response_values(self, r): + db = r.connection_pool.connection_kwargs.get("db", 0) + async with r.monitor() as m: + await r.ping() + response = await wait_for_command(r, m, "PING") + assert isinstance(response["time"], float) + assert response["db"] == db + assert response["client_type"] in ("tcp", "unix") + assert isinstance(response["client_address"], str) + assert isinstance(response["client_port"], str) + assert response["command"] == "PING" + + async def test_command_with_quoted_key(self, r): + async with r.monitor() as m: + await r.get('foo"bar') + response = await wait_for_command(r, m, 'GET foo"bar') + assert response["command"] == 'GET foo"bar' + + async def test_command_with_binary_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\x92") + assert response["command"] == "GET foo\\x92" + + async def test_command_with_escaped_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\\\x92") + assert response["command"] == "GET foo\\\\x92" + + @skip_if_redis_enterprise() + async def test_lua_script(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response["command"] == "GET foo" + assert response["client_type"] == "lua" + assert response["client_address"] == "lua" + assert response["client_port"] == "" + + @skip_ifnot_redis_enterprise() + async def test_lua_script_in_enterprise(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response is None diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py new file mode 100644 index 0000000..5bb1a8a --- /dev/null +++ b/tests/test_asyncio/test_pipeline.py @@ -0,0 +1,409 @@ +import pytest + +import redis +from tests.conftest import skip_if_server_version_lt + +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.onlynoncluster +class TestPipeline: + async def test_pipeline_is_true(self, r): + """Ensure pipeline instances are not false-y""" + async with r.pipeline() as pipe: + assert pipe + + async def test_pipeline(self, r): + async with r.pipeline() as pipe: + ( + pipe.set("a", "a1") + .get("a") + .zadd("z", {"z1": 1}) + .zadd("z", {"z2": 4}) + .zincrby("z", 1, "z1") + .zrange("z", 0, 5, withscores=True) + ) + assert await pipe.execute() == [ + True, + b"a1", + True, + True, + 2.0, + [(b"z1", 2.0), (b"z2", 4)], + ] + + async def test_pipeline_memoryview(self, r): + async with r.pipeline() as pipe: + (pipe.set("a", memoryview(b"a1")).get("a")) + assert await pipe.execute() == [ + True, + b"a1", + ] + + async def test_pipeline_length(self, r): + async with r.pipeline() as pipe: + # Initially empty. + assert len(pipe) == 0 + + # Fill 'er up! + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert len(pipe) == 3 + + # Execute calls reset(), so empty once again. + await pipe.execute() + assert len(pipe) == 0 + + @pytest.mark.onlynoncluster + async def test_pipeline_no_transaction(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert await pipe.execute() == [True, True, True] + assert await r.get("a") == b"a1" + assert await r.get("b") == b"b1" + assert await r.get("c") == b"c1" + + async def test_pipeline_no_transaction_watch(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + pipe.multi() + pipe.set("a", int(a) + 1) + assert await pipe.execute() == [True] + + async def test_pipeline_no_transaction_watch_failure(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + await r.set("a", "bad") + + pipe.multi() + pipe.set("a", int(a) + 1) + + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert await r.get("a") == b"bad" + + async def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + result = await pipe.execute(raise_on_error=False) + + assert result[0] + assert await r.get("a") == b"1" + assert result[1] + assert await r.get("b") == b"2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], redis.ResponseError) + assert await r.get("c") == b"a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert await r.get("d") == b"4" + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_exec_error_raised(self, r): + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH c 3) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_transaction_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline() as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + @pytest.mark.onlynoncluster + async def test_pipeline_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + async def test_parse_error_raised(self, r): + async with r.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_parse_error_raised_transaction(self, r): + async with r.pipeline() as pipe: + pipe.multi() + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + @pytest.mark.onlynoncluster + async def test_watch_succeed(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_watch_failure(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + pipe.get("a") + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_watch_failure_in_empty_transaction(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + with pytest.raises(redis.WatchError): + await pipe.execute() + + assert not pipe.watching + + @pytest.mark.onlynoncluster + async def test_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + await pipe.unwatch() + assert not pipe.watching + pipe.get("a") + assert await pipe.execute() == [b"1"] + + @pytest.mark.onlynoncluster + async def test_watch_exec_no_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is None, "should not send UNWATCH" + + @pytest.mark.onlynoncluster + async def test_watch_reset_unwatch(self, r): + await r.set("a", 1) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a") + assert pipe.watching + await pipe.reset() + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is not None + assert unwatch_command["command"] == "UNWATCH" + + @pytest.mark.onlynoncluster + async def test_transaction_callable(self, r): + await r.set("a", 1) + await r.set("b", 2) + has_run = [] + + async def my_transaction(pipe): + a_value = await pipe.get("a") + assert a_value in (b"1", b"2") + b_value = await pipe.get("b") + assert b_value == b"2" + + # silly run-once code... incr's "a" so WatchError should be raised + # forcing this all to run again. this should incr "a" once to "2" + if not has_run: + await r.incr("a") + has_run.append("it has") + + pipe.multi() + pipe.set("c", int(a_value) + int(b_value)) + + result = await r.transaction(my_transaction, "a", "b") + assert result == [True] + assert await r.get("c") == b"4" + + @pytest.mark.onlynoncluster + async def test_transaction_callable_returns_value_from_callable(self, r): + async def callback(pipe): + # No need to do anything here since we only want the return value + return "a" + + res = await r.transaction(callback, "my-key", value_from_callable=True) + assert res == "a" + + async def test_exec_error_in_no_transaction_pipeline(self, r): + await r.set("a", 1) + async with r.pipeline(transaction=False) as pipe: + pipe.llen("a") + pipe.expire("a", 100) + + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (LLEN a) of " "pipeline caused error: " + ) + + assert await r.get("a") == b"1" + + async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): + key = chr(3456) + "abcd" + chr(3421) + await r.set(key, 1) + async with r.pipeline(transaction=False) as pipe: + pipe.llen(key) + pipe.expire(key, 100) + + with pytest.raises(redis.ResponseError) as ex: + await pipe.execute() + + expected = f"Command # 1 (LLEN {key}) of pipeline caused error: " + assert str(ex.value).startswith(expected) + + assert await r.get(key) == b"1" + + async def test_pipeline_with_bitfield(self, r): + async with r.pipeline() as pipe: + pipe.set("a", "1") + bf = pipe.bitfield("b") + pipe2 = ( + bf.set("u8", 8, 255) + .get("u8", 0) + .get("u4", 8) # 1111 + .get("u4", 12) # 1111 + .get("u4", 13) # 1110 + .execute() + ) + pipe.get("a") + response = await pipe.execute() + + assert pipe == pipe2 + assert response == [True, [0, 0, 15, 15, 14], b"1"] + + async def test_pipeline_get(self, r): + await r.set("a", "a1") + async with r.pipeline() as pipe: + await pipe.get("a") + assert await pipe.execute() == [b"a1"] + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.0.0") + async def test_pipeline_discard(self, r): + + # empty pipeline should raise an error + async with r.pipeline() as pipe: + pipe.set("key", "someval") + await pipe.discard() + with pytest.raises(redis.exceptions.ResponseError): + await pipe.execute() + + # setting a pipeline and discarding should do the same + async with r.pipeline() as pipe: + pipe.set("key", "someval") + pipe.set("someotherkey", "val") + response = await pipe.execute() + pipe.set("key", "another value!") + await pipe.discard() + pipe.set("key", "another vae!") + with pytest.raises(redis.exceptions.ResponseError): + await pipe.execute() + + pipe.set("foo", "bar") + response = await pipe.execute() + assert response[0] + assert await r.get("foo") == b"bar" diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py new file mode 100644 index 0000000..9efcd3c --- /dev/null +++ b/tests/test_asyncio/test_pubsub.py @@ -0,0 +1,660 @@ +import asyncio +import sys +from typing import Optional + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio as redis +from redis.exceptions import ConnectionError +from redis.typing import EncodableT +from tests.conftest import skip_if_server_version_lt + +from .compat import mock + +pytestmark = pytest.mark.asyncio(forbid_global_loop=True) + + +async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): + now = asyncio.get_event_loop().time() + timeout = now + timeout + while now < timeout: + message = await pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) + if message is not None: + return message + await asyncio.sleep(0.01) + now = asyncio.get_event_loop().time() + return None + + +def make_message( + type, channel: Optional[str], data: EncodableT, pattern: Optional[str] = None +): + return { + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + +def make_subscribe_test_data(pubsub, type): + if type == "channel": + return { + "p": pubsub, + "sub_type": "subscribe", + "unsub_type": "unsubscribe", + "sub_func": pubsub.subscribe, + "unsub_func": pubsub.unsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } + elif type == "pattern": + return { + "p": pubsub, + "sub_type": "psubscribe", + "unsub_type": "punsubscribe", + "sub_func": pubsub.psubscribe, + "unsub_func": pubsub.punsubscribe, + "keys": ["f*", "b*", "uni" + chr(4456) + "*"], + } + assert False, f"invalid subscribe type: {type}" + + +@pytest.mark.onlynoncluster +class TestPubSubSubscribeUnsubscribe: + async def _test_subscribe_unsubscribe( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + for key in keys: + assert await unsub_func(key) is None + + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys): + i = len(keys) - 1 - i + assert await wait_for_message(p) == make_message(unsub_type, key, i) + + async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribe_unsubscribe(**kwargs) + + async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribe_unsubscribe(**kwargs) + + @pytest.mark.onlynoncluster + async def _test_resubscribe_on_reconnection( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + # manually disconnect + await p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + # note, we may not re-subscribe to channels in exactly the same order + # so we have to do some extra checks to make sure we got them all + messages = [] + for i in range(len(keys)): + messages.append(await wait_for_message(p)) + + unique_channels = set() + assert len(messages) == len(keys) + for i, message in enumerate(messages): + assert message["type"] == sub_type + assert message["data"] == i + 1 + assert isinstance(message["channel"], bytes) + channel = message["channel"].decode("utf-8") + unique_channels.add(channel) + + assert len(unique_channels) == len(keys) + for channel in unique_channels: + assert channel in keys + + async def test_resubscribe_to_channels_on_reconnection(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def test_resubscribe_to_patterns_on_reconnection(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def _test_subscribed_property( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + assert p.subscribed is False + await sub_func(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all channels + await unsub_func() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + await sub_func(keys[0]) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + + # unsubscribe again + await unsub_func() + assert p.subscribed is True + # subscribe to another channel before reading the unsubscribe response + await sub_func(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert await wait_for_message(p) == make_message(sub_type, keys[1], 1) + await unsub_func() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[1], 0) + # now we're finally unsubscribed + assert p.subscribed is False + + async def test_subscribe_property_with_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribed_property(**kwargs) + + @pytest.mark.onlynoncluster + async def test_subscribe_property_with_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribed_property(**kwargs) + + async def test_ignore_all_subscribe_messages(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + assert await wait_for_message(p) is None + assert p.subscribed is False + + async def test_ignore_individual_subscribe_messages(self, r: redis.Redis): + p = r.pubsub() + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + message = await wait_for_message(p, ignore_subscribe_messages=True) + assert message is None + assert p.subscribed is False + + async def test_sub_unsub_resub_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_resub(**kwargs) + + @pytest.mark.onlynoncluster + async def test_sub_unsub_resub_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_resub(**kwargs) + + async def _test_sub_unsub_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func(key) + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + async def test_sub_unsub_all_resub_channels(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_all_resub(**kwargs) + + async def test_sub_unsub_all_resub_patterns(self, r: redis.Redis): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_all_resub(**kwargs) + + async def _test_sub_unsub_all_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func() + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + +@pytest.mark.onlynoncluster +class TestPubSubMessages: + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + async def async_message_handler(self, message): + self.async_message = message + + async def test_published_message_to_channel(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await r.publish("foo", "test message") == 1 + + message = await wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("message", "foo", "test message") + + async def test_published_message_to_pattern(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + await p.psubscribe("f*") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await wait_for_message(p) == make_message("psubscribe", "f*", 2) + # 1 to pattern, 1 to channel + assert await r.publish("foo", "test message") == 2 + + message1 = await wait_for_message(p) + message2 = await wait_for_message(p) + assert isinstance(message1, dict) + assert isinstance(message2, dict) + + expected = [ + make_message("message", "foo", "test message"), + make_message("pmessage", "foo", "test message", pattern="f*"), + ] + + assert message1 in expected + assert message2 in expected + assert message1 != message2 + + async def test_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + + async def test_channel_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.async_message == make_message("message", "foo", "test message") + + async def test_channel_sync_async_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + await p.subscribe(bar=self.async_message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await r.publish("bar", "test message 2") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + assert self.async_message == make_message("message", "bar", "test message 2") + + @pytest.mark.onlynoncluster + async def test_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{"f*": self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", "foo", "test message", pattern="f*" + ) + + async def test_unicode_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + await p.subscribe(**channels) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", channel, "test message") + + @pytest.mark.onlynoncluster + # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + # #known-limitations-with-pubsub + async def test_unicode_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + pattern = "uni" + chr(4456) + "*" + channel = "uni" + chr(4456) + "code" + await p.psubscribe(**{pattern: self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", channel, "test message", pattern=pattern + ) + + async def test_get_message_without_subscribe(self, r: redis.Redis): + p = r.pubsub() + with pytest.raises(RuntimeError) as info: + await p.get_message() + expect = ( + "connection not set: " "did you forget to call subscribe() or psubscribe()?" + ) + assert expect in info.exconly() + + +@pytest.mark.onlynoncluster +class TestPubSubAutoDecoding: + """These tests only validate that we get unicode values back""" + + channel = "uni" + chr(4456) + "code" + pattern = "uni" + chr(4456) + "*" + data = "abc" + chr(4458) + "123" + + def make_message(self, type, channel, data, pattern=None): + return {"type": type, "channel": channel, "pattern": pattern, "data": data} + + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + @pytest_asyncio.fixture() + async def r(self, create_redis): + return await create_redis( + decode_responses=True, + ) + + async def test_channel_subscribe_unsubscribe(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + + await p.unsubscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "unsubscribe", self.channel, 0 + ) + + async def test_pattern_subscribe_unsubscribe(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + + await p.punsubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "punsubscribe", self.pattern, 0 + ) + + async def test_channel_publish(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "message", self.channel, self.data + ) + + @pytest.mark.onlynoncluster + async def test_pattern_publish(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + async def test_channel_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(**{self.channel: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, new_data) + + async def test_pattern_message_handler(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{self.pattern: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + # test that we reconnected to the correct pattern + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + "pmessage", self.channel, new_data, pattern=self.pattern + ) + + async def test_context_manager(self, r: redis.Redis): + async with r.pubsub() as pubsub: + await pubsub.subscribe("foo") + assert pubsub.connection is not None + + assert pubsub.connection is None + assert pubsub.channels == {} + assert pubsub.patterns == {} + + +@pytest.mark.onlynoncluster +class TestPubSubRedisDown: + async def test_channel_subscribe(self, r: redis.Redis): + r = redis.Redis(host="localhost", port=6390) + p = r.pubsub() + with pytest.raises(ConnectionError): + await p.subscribe("foo") + + +@pytest.mark.onlynoncluster +class TestPubSubSubcommands: + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_channels(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert (await wait_for_message(p))["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in await r.pubsub_channels() for channel in expected]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numsub(self, r: redis.Redis): + p1 = r.pubsub() + await p1.subscribe("foo", "bar", "baz") + for i in range(3): + assert (await wait_for_message(p1))["type"] == "subscribe" + p2 = r.pubsub() + await p2.subscribe("bar", "baz") + for i in range(2): + assert (await wait_for_message(p2))["type"] == "subscribe" + p3 = r.pubsub() + await p3.subscribe("baz") + assert (await wait_for_message(p3))["type"] == "subscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert await r.pubsub_numsub("foo", "bar", "baz") == channels + + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numpat(self, r: redis.Redis): + p = r.pubsub() + await p.psubscribe("*oo", "*ar", "b*z") + for i in range(3): + assert (await wait_for_message(p))["type"] == "psubscribe" + assert await r.pubsub_numpat() == 3 + + +@pytest.mark.onlynoncluster +class TestPubSubPings: + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping() + assert await wait_for_message(p) == make_message( + type="pong", channel=None, data="", pattern=None + ) + + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping_message(self, r: redis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping(message="hello world") + assert await wait_for_message(p) == make_message( + type="pong", channel=None, data="hello world", pattern=None + ) + + +@pytest.mark.onlynoncluster +class TestPubSubConnectionKilled: + @skip_if_server_version_lt("3.0.0") + async def test_connection_error_raised_when_connection_dies(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + for client in await r.client_list(): + if client["cmd"] == "subscribe": + await r.client_kill_filter(_id=client["id"]) + with pytest.raises(ConnectionError): + await wait_for_message(p) + + +@pytest.mark.onlynoncluster +class TestPubSubTimeouts: + async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await p.get_message(timeout=0.01) is None + + +@pytest.mark.onlynoncluster +class TestPubSubRun: + async def _subscribe(self, p, *args, **kwargs): + await p.subscribe(*args, **kwargs) + # Wait for the server to act on the subscription, to be sure that + # a subsequent publish on another connection will reach the pubsub. + while True: + message = await p.get_message(timeout=1) + if ( + message is not None + and message["type"] == "subscribe" + and message["channel"] == b"foo" + ): + return + + async def test_callbacks(self, r: redis.Redis): + def callback(message): + messages.put_nowait(message) + + messages = asyncio.Queue() + p = r.pubsub() + await self._subscribe(p, foo=callback) + task = asyncio.get_event_loop().create_task(p.run()) + await r.publish("foo", "bar") + message = await messages.get() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + assert message == { + "channel": b"foo", + "data": b"bar", + "pattern": None, + "type": "message", + } + + async def test_exception_handler(self, r: redis.Redis): + def exception_handler_callback(e, pubsub) -> None: + assert pubsub == p + exceptions.put_nowait(e) + + exceptions = asyncio.Queue() + p = r.pubsub() + await self._subscribe(p, foo=lambda x: None) + with mock.patch.object(p, "get_message", side_effect=Exception("error")): + task = asyncio.get_event_loop().create_task( + p.run(exception_handler=exception_handler_callback) + ) + e = await exceptions.get() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + assert str(e) == "error" diff --git a/tests/test_asyncio/test_retry.py b/tests/test_asyncio/test_retry.py new file mode 100644 index 0000000..e83e001 --- /dev/null +++ b/tests/test_asyncio/test_retry.py @@ -0,0 +1,70 @@ +import pytest + +from redis.asyncio.connection import Connection, UnixDomainSocketConnection +from redis.asyncio.retry import Retry +from redis.backoff import AbstractBackoff, NoBackoff +from redis.exceptions import ConnectionError + + +class BackoffMock(AbstractBackoff): + def __init__(self): + self.reset_calls = 0 + self.calls = 0 + + def reset(self): + self.reset_calls += 1 + + def compute(self, failures): + self.calls += 1 + return 0 + + +@pytest.mark.onlynoncluster +class TestConnectionConstructorWithRetry: + "Test that the Connection constructors properly handles Retry objects" + + @pytest.mark.parametrize("retry_on_timeout", [False, True]) + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_timeout_boolean(self, Class, retry_on_timeout): + c = Class(retry_on_timeout=retry_on_timeout) + assert c.retry_on_timeout == retry_on_timeout + assert isinstance(c.retry, Retry) + assert c.retry._retries == (1 if retry_on_timeout else 0) + + @pytest.mark.parametrize("retries", range(10)) + @pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection]) + def test_retry_on_timeout_retry(self, Class, retries: int): + retry_on_timeout = retries > 0 + c = Class(retry_on_timeout=retry_on_timeout, retry=Retry(NoBackoff(), retries)) + assert c.retry_on_timeout == retry_on_timeout + assert isinstance(c.retry, Retry) + assert c.retry._retries == retries + + +@pytest.mark.onlynoncluster +class TestRetry: + "Test that Retry calls backoff and retries the expected number of times" + + def setup_method(self, test_method): + self.actual_attempts = 0 + self.actual_failures = 0 + + async def _do(self): + self.actual_attempts += 1 + raise ConnectionError() + + async def _fail(self, error): + self.actual_failures += 1 + + @pytest.mark.parametrize("retries", range(10)) + @pytest.mark.asyncio + async def test_retry(self, retries: int): + backoff = BackoffMock() + retry = Retry(backoff, retries) + with pytest.raises(ConnectionError): + await retry.call_with_retry(self._do, self._fail) + + assert self.actual_attempts == 1 + retries + assert self.actual_failures == 1 + retries + assert backoff.reset_calls == 1 + assert backoff.calls == retries diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py new file mode 100644 index 0000000..764525f --- /dev/null +++ b/tests/test_asyncio/test_scripting.py @@ -0,0 +1,159 @@ +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +from redis import exceptions +from tests.conftest import skip_if_server_version_lt + +multiply_script = """ +local value = redis.call('GET', KEYS[1]) +value = tonumber(value) +return value * ARGV[1]""" + +msgpack_hello_script = """ +local message = cmsgpack.unpack(ARGV[1]) +local name = message['name'] +return "hello " .. name +""" +msgpack_hello_script_broken = """ +local message = cmsgpack.unpack(ARGV[1]) +local names = message['name'] +return "hello " .. name +""" + + +@pytest.mark.onlynoncluster +class TestScripting: + @pytest_asyncio.fixture + async def r(self, create_redis): + redis = await create_redis() + yield redis + await redis.script_flush() + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_eval(self, r): + await r.flushdb() + await r.set("a", 2) + # 2 * 3 == 6 + assert await r.eval(multiply_script, 1, "a", 3) == 6 + + @pytest.mark.asyncio(forbid_global_loop=True) + @skip_if_server_version_lt("6.2.0") + async def test_script_flush(self, r): + await r.set("a", 2) + await r.script_load(multiply_script) + await r.script_flush("ASYNC") + + await r.set("a", 2) + await r.script_load(multiply_script) + await r.script_flush("SYNC") + + await r.set("a", 2) + await r.script_load(multiply_script) + await r.script_flush() + + with pytest.raises(exceptions.DataError): + await r.set("a", 2) + await r.script_load(multiply_script) + await r.script_flush("NOTREAL") + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_evalsha(self, r): + await r.set("a", 2) + sha = await r.script_load(multiply_script) + # 2 * 3 == 6 + assert await r.evalsha(sha, 1, "a", 3) == 6 + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_evalsha_script_not_loaded(self, r): + await r.set("a", 2) + sha = await r.script_load(multiply_script) + # remove the script from Redis's cache + await r.script_flush() + with pytest.raises(exceptions.NoScriptError): + await r.evalsha(sha, 1, "a", 3) + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_script_loading(self, r): + # get the sha, then clear the cache + sha = await r.script_load(multiply_script) + await r.script_flush() + assert await r.script_exists(sha) == [False] + await r.script_load(multiply_script) + assert await r.script_exists(sha) == [True] + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_script_object(self, r): + await r.script_flush() + await r.set("a", 2) + multiply = r.register_script(multiply_script) + precalculated_sha = multiply.sha + assert precalculated_sha + assert await r.script_exists(multiply.sha) == [False] + # Test second evalsha block (after NoScriptError) + assert await multiply(keys=["a"], args=[3]) == 6 + # At this point, the script should be loaded + assert await r.script_exists(multiply.sha) == [True] + # Test that the precalculated sha matches the one from redis + assert multiply.sha == precalculated_sha + # Test first evalsha block + assert await multiply(keys=["a"], args=[3]) == 6 + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_script_object_in_pipeline(self, r): + await r.script_flush() + multiply = r.register_script(multiply_script) + precalculated_sha = multiply.sha + assert precalculated_sha + pipe = r.pipeline() + pipe.set("a", 2) + pipe.get("a") + await multiply(keys=["a"], args=[3], client=pipe) + assert await r.script_exists(multiply.sha) == [False] + # [SET worked, GET 'a', result of multiple script] + assert await pipe.execute() == [True, b"2", 6] + # The script should have been loaded by pipe.execute() + assert await r.script_exists(multiply.sha) == [True] + # The precalculated sha should have been the correct one + assert multiply.sha == precalculated_sha + + # purge the script from redis's cache and re-run the pipeline + # the multiply script should be reloaded by pipe.execute() + await r.script_flush() + pipe = r.pipeline() + pipe.set("a", 2) + pipe.get("a") + await multiply(keys=["a"], args=[3], client=pipe) + assert await r.script_exists(multiply.sha) == [False] + # [SET worked, GET 'a', result of multiple script] + assert await pipe.execute() == [True, b"2", 6] + assert await r.script_exists(multiply.sha) == [True] + + @pytest.mark.asyncio(forbid_global_loop=True) + async def test_eval_msgpack_pipeline_error_in_lua(self, r): + msgpack_hello = r.register_script(msgpack_hello_script) + assert msgpack_hello.sha + + pipe = r.pipeline() + + # avoiding a dependency to msgpack, this is the output of + # msgpack.dumps({"name": "joe"}) + msgpack_message_1 = b"\x81\xa4name\xa3Joe" + + await msgpack_hello(args=[msgpack_message_1], client=pipe) + + assert await r.script_exists(msgpack_hello.sha) == [False] + assert (await pipe.execute())[0] == b"hello Joe" + assert await r.script_exists(msgpack_hello.sha) == [True] + + msgpack_hello_broken = r.register_script(msgpack_hello_script_broken) + + await msgpack_hello_broken(args=[msgpack_message_1], client=pipe) + with pytest.raises(exceptions.ResponseError) as excinfo: + await pipe.execute() + assert excinfo.type == exceptions.ResponseError diff --git a/tests/test_asyncio/test_sentinel.py b/tests/test_asyncio/test_sentinel.py new file mode 100644 index 0000000..cd6810c --- /dev/null +++ b/tests/test_asyncio/test_sentinel.py @@ -0,0 +1,249 @@ +import socket +import sys + +import pytest + +if sys.version_info[0:2] == (3, 6): + import pytest as pytest_asyncio +else: + import pytest_asyncio + +import redis.asyncio.sentinel +from redis import exceptions +from redis.asyncio.sentinel import ( + MasterNotFoundError, + Sentinel, + SentinelConnectionPool, + SlaveNotFoundError, +) + +pytestmark = pytest.mark.asyncio + + +@pytest_asyncio.fixture(scope="module") +def master_ip(master_host): + yield socket.gethostbyname(master_host) + + +class SentinelTestClient: + def __init__(self, cluster, id): + self.cluster = cluster + self.id = id + + async def sentinel_masters(self): + self.cluster.connection_error_if_down(self) + self.cluster.timeout_if_down(self) + return {self.cluster.service_name: self.cluster.master} + + async def sentinel_slaves(self, master_name): + self.cluster.connection_error_if_down(self) + self.cluster.timeout_if_down(self) + if master_name != self.cluster.service_name: + return [] + return self.cluster.slaves + + async def execute_command(self, *args, **kwargs): + # wrapper purely to validate the calls don't explode + from redis.asyncio.client import bool_ok + + return bool_ok + + +class SentinelTestCluster: + def __init__(self, service_name="mymaster", ip="127.0.0.1", port=6379): + self.clients = {} + self.master = { + "ip": ip, + "port": port, + "is_master": True, + "is_sdown": False, + "is_odown": False, + "num-other-sentinels": 0, + } + self.service_name = service_name + self.slaves = [] + self.nodes_down = set() + self.nodes_timeout = set() + + def connection_error_if_down(self, node): + if node.id in self.nodes_down: + raise exceptions.ConnectionError + + def timeout_if_down(self, node): + if node.id in self.nodes_timeout: + raise exceptions.TimeoutError + + def client(self, host, port, **kwargs): + return SentinelTestClient(self, (host, port)) + + +@pytest_asyncio.fixture() +async def cluster(master_ip): + + cluster = SentinelTestCluster(ip=master_ip) + saved_Redis = redis.asyncio.sentinel.Redis + redis.asyncio.sentinel.Redis = cluster.client + yield cluster + redis.asyncio.sentinel.Redis = saved_Redis + + +@pytest_asyncio.fixture() +def sentinel(request, cluster): + return Sentinel([("foo", 26379), ("bar", 26379)]) + + +@pytest.mark.onlynoncluster +async def test_discover_master(sentinel, master_ip): + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + + +@pytest.mark.onlynoncluster +async def test_discover_master_error(sentinel): + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("xxx") + + +@pytest.mark.onlynoncluster +async def test_discover_master_sentinel_down(cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_down.add(("foo", 26379)) + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ("bar", 26379) + + +@pytest.mark.onlynoncluster +async def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_timeout.add(("foo", 26379)) + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ("bar", 26379) + + +@pytest.mark.onlynoncluster +async def test_master_min_other_sentinels(cluster, master_ip): + sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1) + # min_other_sentinels + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + cluster.master["num-other-sentinels"] = 2 + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + + +@pytest.mark.onlynoncluster +async def test_master_odown(cluster, sentinel): + cluster.master["is_odown"] = True + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + + +@pytest.mark.onlynoncluster +async def test_master_sdown(cluster, sentinel): + cluster.master["is_sdown"] = True + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + + +@pytest.mark.onlynoncluster +async def test_discover_slaves(cluster, sentinel): + assert await sentinel.discover_slaves("mymaster") == [] + + cluster.slaves = [ + {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False}, + ] + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + + # slave0 -> ODOWN + cluster.slaves[0]["is_odown"] = True + assert await sentinel.discover_slaves("mymaster") == [("slave1", 1234)] + + # slave1 -> SDOWN + cluster.slaves[1]["is_sdown"] = True + assert await sentinel.discover_slaves("mymaster") == [] + + cluster.slaves[0]["is_odown"] = False + cluster.slaves[1]["is_sdown"] = False + + # node0 -> DOWN + cluster.nodes_down.add(("foo", 26379)) + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + cluster.nodes_down.clear() + + # node0 -> TIMEOUT + cluster.nodes_timeout.add(("foo", 26379)) + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + + +@pytest.mark.onlynoncluster +async def test_master_for(cluster, sentinel, master_ip): + master = sentinel.master_for("mymaster", db=9) + assert await master.ping() + assert master.connection_pool.master_address == (master_ip, 6379) + + # Use internal connection check + master = sentinel.master_for("mymaster", db=9, check_connection=True) + assert await master.ping() + + +@pytest.mark.onlynoncluster +async def test_slave_for(cluster, sentinel): + cluster.slaves = [ + {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False}, + ] + slave = sentinel.slave_for("mymaster", db=9) + assert await slave.ping() + + +@pytest.mark.onlynoncluster +async def test_slave_for_slave_not_found_error(cluster, sentinel): + cluster.master["is_odown"] = True + slave = sentinel.slave_for("mymaster", db=9) + with pytest.raises(SlaveNotFoundError): + await slave.ping() + + +@pytest.mark.onlynoncluster +async def test_slave_round_robin(cluster, sentinel, master_ip): + cluster.slaves = [ + {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False}, + ] + pool = SentinelConnectionPool("mymaster", sentinel) + rotator = pool.rotate_slaves() + assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 6379)) + assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 6379)) + # Fallback to master + assert await rotator.__anext__() == (master_ip, 6379) + with pytest.raises(SlaveNotFoundError): + await rotator.__anext__() + + +@pytest.mark.onlynoncluster +async def test_ckquorum(cluster, sentinel): + assert await sentinel.sentinel_ckquorum("mymaster") + + +@pytest.mark.onlynoncluster +async def test_flushconfig(cluster, sentinel): + assert await sentinel.sentinel_flushconfig() + + +@pytest.mark.onlynoncluster +async def test_reset(cluster, sentinel): + cluster.master["is_odown"] = True + assert await sentinel.sentinel_reset("mymaster") |