diff options
Diffstat (limited to 'redis/asyncio/client.py')
-rw-r--r-- | redis/asyncio/client.py | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index ffd68c1..5ef1f32 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -57,7 +57,7 @@ from redis.exceptions import ( WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT -from redis.utils import safe_str, str_if_bytes +from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) @@ -658,6 +658,7 @@ class PubSub: shard_hint: Optional[str] = None, ignore_subscribe_messages: bool = False, encoder=None, + push_handler_func: Optional[Callable] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -666,6 +667,7 @@ class PubSub: # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder + self.push_handler_func = push_handler_func if self.encoder is None: self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: @@ -678,6 +680,8 @@ class PubSub: b"pong", self.encoder.encode(self.HEALTH_CHECK_MESSAGE), ] + if self.push_handler_func is None: + _set_info_logger() self.channels = {} self.pending_unsubscribe_channels = set() self.patterns = {} @@ -757,6 +761,8 @@ class PubSub: self.connection.register_connect_callback(self.on_connect) else: await self.connection.connect() + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) async def _disconnect_raise_connect(self, conn, error): """ @@ -797,7 +803,9 @@ class PubSub: await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, conn.read_response, timeout=read_timeout, push_request=True + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it @@ -927,8 +935,8 @@ class PubSub: """ Ping the Redis server """ - message = "" if message is None else message - return self.execute_command("PING", message) + args = ["PING", message] if message is not None else ["PING"] + return self.execute_command(*args) async def handle_message(self, response, ignore_subscribe_messages=False): """ @@ -936,6 +944,10 @@ class PubSub: with a message handler, the handler is invoked instead of a parsed message being returned. """ + if response is None: + return None + if isinstance(response, bytes): + response = [b"pong", response] if response != b"PONG" else [b"pong", b""] message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { |