summaryrefslogtreecommitdiff
path: root/redis/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/client.py')
-rwxr-xr-xredis/client.py74
1 files changed, 68 insertions, 6 deletions
diff --git a/redis/client.py b/redis/client.py
index 0236f20..5116482 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1288,18 +1288,17 @@ class PubSub:
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection = None
+ self.subscribed_event = threading.Event()
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
+ self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
- self.health_check_response = [
- b"pong",
- self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
- ]
+ self.health_check_response = [b"pong", self.health_check_response_b]
self.reset()
def __enter__(self):
@@ -1324,9 +1323,11 @@ class PubSub:
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
+ self.health_check_response_counter = 0
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
+ self.subscribed_event.clear()
def close(self):
self.reset()
@@ -1352,7 +1353,7 @@ class PubSub:
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
- return bool(self.channels or self.patterns)
+ return self.subscribed_event.is_set()
def execute_command(self, *args):
"Execute a publish/subscribe command"
@@ -1370,8 +1371,28 @@ class PubSub:
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
+ if not self.subscribed:
+ self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)
+ def clean_health_check_responses(self):
+ """
+ If any health check responses are present, clean them
+ """
+ ttl = 10
+ conn = self.connection
+ while self.health_check_response_counter > 0 and ttl > 0:
+ if self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
+ response = self._execute(conn, conn.read_response)
+ if self.is_health_check_response(response):
+ self.health_check_response_counter -= 1
+ else:
+ raise PubSubError(
+ "A non health check response was cleaned by "
+ "execute_command: {0}".format(response)
+ )
+ ttl -= 1
+
def _disconnect_raise_connect(self, conn, error):
"""
Close the connection and raise an exception
@@ -1411,11 +1432,23 @@ class PubSub:
return None
response = self._execute(conn, conn.read_response)
- if conn.health_check_interval and response == self.health_check_response:
+ if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
+ self.health_check_response_counter -= 1
return None
return response
+ def is_health_check_response(self, response):
+ """
+ Check if the response is a health check response.
+ If there are no subscriptions redis responds to PING command with a
+ bulk response, instead of a multi-bulk with "pong" and the response.
+ """
+ return response in [
+ self.health_check_response, # If there was a subscription
+ self.health_check_response_b, # If there wasn't
+ ]
+
def check_health(self):
conn = self.connection
if conn is None:
@@ -1426,6 +1459,7 @@ class PubSub:
if conn.health_check_interval and time.time() > conn.next_health_check:
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
+ self.health_check_response_counter += 1
def _normalize_keys(self, data):
"""
@@ -1455,6 +1489,11 @@ class PubSub:
# for the reconnection.
new_patterns = self._normalize_keys(new_patterns)
self.patterns.update(new_patterns)
+ if not self.subscribed:
+ # Set the subscribed_event flag to True
+ self.subscribed_event.set()
+ # Clear the health check counter
+ self.health_check_response_counter = 0
self.pending_unsubscribe_patterns.difference_update(new_patterns)
return ret_val
@@ -1489,6 +1528,11 @@ class PubSub:
# for the reconnection.
new_channels = self._normalize_keys(new_channels)
self.channels.update(new_channels)
+ if not self.subscribed:
+ # Set the subscribed_event flag to True
+ self.subscribed_event.set()
+ # Clear the health check counter
+ self.health_check_response_counter = 0
self.pending_unsubscribe_channels.difference_update(new_channels)
return ret_val
@@ -1520,6 +1564,20 @@ class PubSub:
before returning. Timeout should be specified as a floating point
number.
"""
+ if not self.subscribed:
+ # Wait for subscription
+ start_time = time.time()
+ if self.subscribed_event.wait(timeout) is True:
+ # The connection was subscribed during the timeout time frame.
+ # The timeout should be adjusted based on the time spent
+ # waiting for the subscription
+ time_spent = time.time() - start_time
+ timeout = max(0.0, timeout - time_spent)
+ else:
+ # The connection isn't subscribed to any channels or patterns,
+ # so no messages are available
+ return None
+
response = self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
@@ -1575,6 +1633,10 @@ class PubSub:
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
+ if not self.channels and not self.patterns:
+ # There are no subscriptions anymore, set subscribed_event flag
+ # to false
+ self.subscribed_event.clear()
if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it