summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristján Valur Jónsson <sweskman@gmail.com>2022-05-02 09:25:34 +0000
committerGitHub <noreply@github.com>2022-05-02 12:25:34 +0300
commitfdb9075745060e7a3633248fa6f419e895f010b7 (patch)
treed229abb20e27df2a4e2952b499ae056b595a749e
parent696d984a74ef6cd3e3968df8d11cf9af80057424 (diff)
downloadredis-py-fdb9075745060e7a3633248fa6f419e895f010b7.tar.gz
Async Connection: Allow `PubSub.run()` without previous `subscribe()` (#2148)
-rw-r--r--redis/asyncio/client.py15
-rw-r--r--tests/test_asyncio/test_pubsub.py33
2 files changed, 45 insertions, 3 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 7689e12..8dde96e 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -693,6 +693,15 @@ class PubSub:
# legitimate message off the stack if the connection is already
# subscribed to one or more channels
+ await self.connect()
+ connection = self.connection
+ kwargs = {"check_health": not self.subscribed}
+ await self._execute(connection, connection.send_command, *args, **kwargs)
+
+ async def connect(self):
+ """
+ Ensure that the PubSub is connected
+ """
if self.connection is None:
self.connection = await self.connection_pool.get_connection(
"pubsub", self.shard_hint
@@ -700,9 +709,8 @@ class PubSub:
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
- connection = self.connection
- kwargs = {"check_health": not self.subscribed}
- await self._execute(connection, connection.send_command, *args, **kwargs)
+ else:
+ await self.connection.connect()
async def _disconnect_raise_connect(self, conn, error):
"""
@@ -962,6 +970,7 @@ class PubSub:
if handler is None:
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
+ await self.connect()
while True:
try:
await self.get_message(
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 9efcd3c..f71ec7e 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -2,6 +2,7 @@ import asyncio
import sys
from typing import Optional
+import async_timeout
import pytest
if sys.version_info[0:2] == (3, 6):
@@ -658,3 +659,35 @@ class TestPubSubRun:
except asyncio.CancelledError:
pass
assert str(e) == "error"
+
+ async def test_late_subscribe(self, r: redis.Redis):
+ def callback(message):
+ messages.put_nowait(message)
+
+ messages = asyncio.Queue()
+ p = r.pubsub()
+ task = asyncio.get_event_loop().create_task(p.run())
+ # wait until loop gets settled. Add a subscription
+ await asyncio.sleep(0.1)
+ await p.subscribe(foo=callback)
+ # wait tof the subscribe to finish. Cannot use _subscribe() because
+ # p.run() is already accepting messages
+ await asyncio.sleep(0.1)
+ await r.publish("foo", "bar")
+ message = None
+ try:
+ async with async_timeout.timeout(0.1):
+ message = await messages.get()
+ except asyncio.TimeoutError:
+ pass
+ task.cancel()
+ # we expect a cancelled error, not the Runtime error
+ # ("did you forget to call subscribe()"")
+ with pytest.raises(asyncio.CancelledError):
+ await task
+ assert message == {
+ "channel": b"foo",
+ "data": b"bar",
+ "pattern": None,
+ "type": "message",
+ }