diff options
author | Vivanov98 <66319645+Vivanov98@users.noreply.github.com> | 2023-01-29 13:48:50 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-29 15:48:50 +0200 |
commit | 428d60940f386d3680a413aa327889308f82c5de (patch) | |
tree | 9c49ab0693e0af183e084cf8b1930521fc315c41 | |
parent | 9e6a9b52e5aab021d239ca56e27f06bca871cbf0 (diff) | |
download | redis-py-428d60940f386d3680a413aa327889308f82c5de.tar.gz |
Fix issue 2540: Synchronise concurrent command calls to single-client mode. (#2568)
Co-authored-by: Viktor Ivanov <viktor@infogrid.io>
-rw-r--r-- | redis/asyncio/client.py | 15 | ||||
-rw-r--r-- | tests/test_asyncio/test_connection.py | 45 |
2 files changed, 58 insertions, 2 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index e56fd02..3fc7fad 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -253,6 +253,11 @@ class Redis( self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + # If using a single connection client, we need to lock creation-of and use-of + # the client in order to avoid race conditions such as using asyncio.gather + # on a set of redis commands + self._single_conn_lock = asyncio.Lock() + def __repr__(self): return f"{self.__class__.__name__}<{self.connection_pool!r}>" @@ -260,8 +265,10 @@ class Redis( return self.initialize().__await__() async def initialize(self: _RedisT) -> _RedisT: - if self.single_connection_client and self.connection is None: - self.connection = await self.connection_pool.get_connection("_") + if self.single_connection_client: + async with self._single_conn_lock: + if self.connection is None: + self.connection = await self.connection_pool.get_connection("_") return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -501,6 +508,8 @@ class Redis( command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) + if self.single_connection_client: + await self._single_conn_lock.acquire() try: return await conn.retry.call_with_retry( lambda: self._send_command_parse_response( @@ -509,6 +518,8 @@ class Redis( lambda error: self._disconnect_raise(conn, error), ) finally: + if self.single_connection_client: + self._single_conn_lock.release() if not self.connection: await pool.release(conn) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index bf59dbe..8e4fdac 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -6,6 +6,7 @@ from unittest.mock import patch import pytest import redis +from redis.asyncio import Redis from redis.asyncio.connection import ( BaseParser, Connection, @@ -41,6 +42,50 @@ async def test_invalid_response(create_redis): await r.connection.disconnect() +@pytest.mark.onlynoncluster +async def test_single_connection(): + """Test that concurrent requests on a single client are synchronised.""" + r = Redis(single_connection_client=True) + + init_call_count = 0 + command_call_count = 0 + in_use = False + + class Retry_: + async def call_with_retry(self, _, __): + # If we remove the single-client lock, this error gets raised as two + # coroutines will be vying for the `in_use` flag due to the two + # asymmetric sleep calls + nonlocal command_call_count + nonlocal in_use + if in_use is True: + raise ValueError("Commands should be executed one at a time.") + in_use = True + await asyncio.sleep(0.01) + command_call_count += 1 + await asyncio.sleep(0.03) + in_use = False + return "foo" + + mock_conn = mock.MagicMock() + mock_conn.retry = Retry_() + + async def get_conn(_): + # Validate only one client is created in single-client mode when + # concurrent requests are made + nonlocal init_call_count + await asyncio.sleep(0.01) + init_call_count += 1 + return mock_conn + + with mock.patch.object(r.connection_pool, "get_connection", get_conn): + with mock.patch.object(r.connection_pool, "release"): + await asyncio.gather(r.set("a", "b"), r.set("c", "d")) + + assert init_call_count == 1 + assert command_call_count == 2 + + @skip_if_server_version_lt("4.0.0") @pytest.mark.redismod @pytest.mark.onlynoncluster |