summaryrefslogtreecommitdiff
path: root/redis/asyncio/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/asyncio/client.py')
-rw-r--r--redis/asyncio/client.py20
1 files changed, 16 insertions, 4 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index ffd68c1..5ef1f32 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -57,7 +57,7 @@ from redis.exceptions import (
WatchError,
)
from redis.typing import ChannelT, EncodableT, KeyT
-from redis.utils import safe_str, str_if_bytes
+from redis.utils import HIREDIS_AVAILABLE, _set_info_logger, safe_str, str_if_bytes
PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
_KeyT = TypeVar("_KeyT", bound=KeyT)
@@ -658,6 +658,7 @@ class PubSub:
shard_hint: Optional[str] = None,
ignore_subscribe_messages: bool = False,
encoder=None,
+ push_handler_func: Optional[Callable] = None,
):
self.connection_pool = connection_pool
self.shard_hint = shard_hint
@@ -666,6 +667,7 @@ class PubSub:
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
+ self.push_handler_func = push_handler_func
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
if self.encoder.decode_responses:
@@ -678,6 +680,8 @@ class PubSub:
b"pong",
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
]
+ if self.push_handler_func is None:
+ _set_info_logger()
self.channels = {}
self.pending_unsubscribe_channels = set()
self.patterns = {}
@@ -757,6 +761,8 @@ class PubSub:
self.connection.register_connect_callback(self.on_connect)
else:
await self.connection.connect()
+ if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
+ self.connection._parser.set_push_handler(self.push_handler_func)
async def _disconnect_raise_connect(self, conn, error):
"""
@@ -797,7 +803,9 @@ class PubSub:
await conn.connect()
read_timeout = None if block else timeout
- response = await self._execute(conn, conn.read_response, timeout=read_timeout)
+ response = await self._execute(
+ conn, conn.read_response, timeout=read_timeout, push_request=True
+ )
if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
@@ -927,8 +935,8 @@ class PubSub:
"""
Ping the Redis server
"""
- message = "" if message is None else message
- return self.execute_command("PING", message)
+ args = ["PING", message] if message is not None else ["PING"]
+ return self.execute_command(*args)
async def handle_message(self, response, ignore_subscribe_messages=False):
"""
@@ -936,6 +944,10 @@ class PubSub:
with a message handler, the handler is invoked instead of a parsed
message being returned.
"""
+ if response is None:
+ return None
+ if isinstance(response, bytes):
+ response = [b"pong", response] if response != b"PONG" else [b"pong", b""]
message_type = str_if_bytes(response[0])
if message_type == "pmessage":
message = {