From fdb9075745060e7a3633248fa6f419e895f010b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 2 May 2022 09:25:34 +0000 Subject: Async Connection: Allow `PubSub.run()` without previous `subscribe()` (#2148) --- redis/asyncio/client.py | 15 ++++++++++++--- tests/test_asyncio/test_pubsub.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 7689e12..8dde96e 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -693,6 +693,15 @@ class PubSub: # legitimate message off the stack if the connection is already # subscribed to one or more channels + await self.connect() + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def connect(self): + """ + Ensure that the PubSub is connected + """ if self.connection is None: self.connection = await self.connection_pool.get_connection( "pubsub", self.shard_hint @@ -700,9 +709,8 @@ class PubSub: # register a callback that re-subscribes to any channels we # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) - connection = self.connection - kwargs = {"check_health": not self.subscribed} - await self._execute(connection, connection.send_command, *args, **kwargs) + else: + await self.connection.connect() async def _disconnect_raise_connect(self, conn, error): """ @@ -962,6 +970,7 @@ class PubSub: if handler is None: raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + await self.connect() while True: try: await self.get_message( diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 9efcd3c..f71ec7e 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -2,6 +2,7 @@ import asyncio import sys from typing import Optional +import async_timeout import pytest if sys.version_info[0:2] == (3, 6): @@ -658,3 +659,35 @@ class TestPubSubRun: except asyncio.CancelledError: pass assert str(e) == "error" + + async def test_late_subscribe(self, r: redis.Redis): + def callback(message): + messages.put_nowait(message) + + messages = asyncio.Queue() + p = r.pubsub() + task = asyncio.get_event_loop().create_task(p.run()) + # wait until loop gets settled. Add a subscription + await asyncio.sleep(0.1) + await p.subscribe(foo=callback) + # wait tof the subscribe to finish. Cannot use _subscribe() because + # p.run() is already accepting messages + await asyncio.sleep(0.1) + await r.publish("foo", "bar") + message = None + try: + async with async_timeout.timeout(0.1): + message = await messages.get() + except asyncio.TimeoutError: + pass + task.cancel() + # we expect a cancelled error, not the Runtime error + # ("did you forget to call subscribe()"") + with pytest.raises(asyncio.CancelledError): + await task + assert message == { + "channel": b"foo", + "data": b"bar", + "pattern": None, + "type": "message", + } -- cgit v1.2.1