From d6cb997bc7b96f4e646a497a89f9466c97aeffef Mon Sep 17 00:00:00 2001 From: Bar Shaul <88437685+barshaul@users.noreply.github.com> Date: Thu, 23 Dec 2021 11:42:44 +0200 Subject: Fixing read race condition during pubsub (#1737) --- redis/client.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++----- tests/test_pubsub.py | 43 +++++++++++++++++++++++------- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/redis/client.py b/redis/client.py index 0236f20..5116482 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1288,18 +1288,17 @@ class PubSub: self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages self.connection = None + self.subscribed_event = threading.Event() # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. self.encoder = encoder if self.encoder is None: self.encoder = self.connection_pool.get_encoder() + self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) if self.encoder.decode_responses: self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: - self.health_check_response = [ - b"pong", - self.encoder.encode(self.HEALTH_CHECK_MESSAGE), - ] + self.health_check_response = [b"pong", self.health_check_response_b] self.reset() def __enter__(self): @@ -1324,9 +1323,11 @@ class PubSub: self.connection_pool.release(self.connection) self.connection = None self.channels = {} + self.health_check_response_counter = 0 self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.subscribed_event.clear() def close(self): self.reset() @@ -1352,7 +1353,7 @@ class PubSub: @property def subscribed(self): "Indicates if there are subscriptions to any channels or patterns" - return bool(self.channels or self.patterns) + return self.subscribed_event.is_set() def execute_command(self, *args): "Execute a publish/subscribe command" @@ -1370,8 +1371,28 @@ class PubSub: self.connection.register_connect_callback(self.on_connect) connection = self.connection kwargs = {"check_health": not self.subscribed} + if not self.subscribed: + self.clean_health_check_responses() self._execute(connection, connection.send_command, *args, **kwargs) + def clean_health_check_responses(self): + """ + If any health check responses are present, clean them + """ + ttl = 10 + conn = self.connection + while self.health_check_response_counter > 0 and ttl > 0: + if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): + response = self._execute(conn, conn.read_response) + if self.is_health_check_response(response): + self.health_check_response_counter -= 1 + else: + raise PubSubError( + "A non health check response was cleaned by " + "execute_command: {0}".format(response) + ) + ttl -= 1 + def _disconnect_raise_connect(self, conn, error): """ Close the connection and raise an exception @@ -1411,11 +1432,23 @@ class PubSub: return None response = self._execute(conn, conn.read_response) - if conn.health_check_interval and response == self.health_check_response: + if self.is_health_check_response(response): # ignore the health check message as user might not expect it + self.health_check_response_counter -= 1 return None return response + def is_health_check_response(self, response): + """ + Check if the response is a health check response. + If there are no subscriptions redis responds to PING command with a + bulk response, instead of a multi-bulk with "pong" and the response. + """ + return response in [ + self.health_check_response, # If there was a subscription + self.health_check_response_b, # If there wasn't + ] + def check_health(self): conn = self.connection if conn is None: @@ -1426,6 +1459,7 @@ class PubSub: if conn.health_check_interval and time.time() > conn.next_health_check: conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) + self.health_check_response_counter += 1 def _normalize_keys(self, data): """ @@ -1455,6 +1489,11 @@ class PubSub: # for the reconnection. new_patterns = self._normalize_keys(new_patterns) self.patterns.update(new_patterns) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 self.pending_unsubscribe_patterns.difference_update(new_patterns) return ret_val @@ -1489,6 +1528,11 @@ class PubSub: # for the reconnection. new_channels = self._normalize_keys(new_channels) self.channels.update(new_channels) + if not self.subscribed: + # Set the subscribed_event flag to True + self.subscribed_event.set() + # Clear the health check counter + self.health_check_response_counter = 0 self.pending_unsubscribe_channels.difference_update(new_channels) return ret_val @@ -1520,6 +1564,20 @@ class PubSub: before returning. Timeout should be specified as a floating point number. """ + if not self.subscribed: + # Wait for subscription + start_time = time.time() + if self.subscribed_event.wait(timeout) is True: + # The connection was subscribed during the timeout time frame. + # The timeout should be adjusted based on the time spent + # waiting for the subscription + time_spent = time.time() - start_time + timeout = max(0.0, timeout - time_spent) + else: + # The connection isn't subscribed to any channels or patterns, + # so no messages are available + return None + response = self.parse_response(block=False, timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) @@ -1575,6 +1633,10 @@ class PubSub: if channel in self.pending_unsubscribe_channels: self.pending_unsubscribe_channels.remove(channel) self.channels.pop(channel, None) + if not self.channels and not self.patterns: + # There are no subscriptions anymore, set subscribed_event flag + # to false + self.subscribed_event.clear() if message_type in self.PUBLISH_MESSAGE_TYPES: # if there's a message handler, invoke it diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 20ae0a0..23af461 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -2,6 +2,7 @@ import platform import threading import time from unittest import mock +from unittest.mock import patch import pytest @@ -348,15 +349,6 @@ class TestPubSubMessages: "pmessage", channel, "test message", pattern=pattern ) - def test_get_message_without_subscribe(self, r): - p = r.pubsub() - with pytest.raises(RuntimeError) as info: - p.get_message() - expect = ( - "connection not set: " "did you forget to call subscribe() or psubscribe()?" - ) - assert expect in info.exconly() - class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" @@ -549,6 +541,39 @@ class TestPubSubTimeouts: assert wait_for_message(p) == make_message("subscribe", "foo", 1) assert p.get_message(timeout=0.01) is None + def test_get_message_not_subscribed_return_none(self, r): + p = r.pubsub() + assert p.subscribed is False + assert p.get_message() is None + assert p.get_message(timeout=0.1) is None + with patch.object(threading.Event, "wait") as mock: + mock.return_value = False + assert p.get_message(timeout=0.01) is None + assert mock.called + + def test_get_message_subscribe_during_waiting(self, r): + p = r.pubsub() + + def poll(ps, expected_res): + assert ps.get_message() is None + message = ps.get_message(timeout=1) + assert message == expected_res + + subscribe_response = make_message("subscribe", "foo", 1) + poller = threading.Thread(target=poll, args=(p, subscribe_response)) + poller.start() + time.sleep(0.2) + p.subscribe("foo") + poller.join() + + def test_get_message_wait_for_subscription_not_being_called(self, r): + p = r.pubsub() + p.subscribe("foo") + with patch.object(threading.Event, "wait") as mock: + assert p.subscribed is True + assert wait_for_message(p) == make_message("subscribe", "foo", 1) + assert mock.called is False + class TestPubSubWorkerThread: @pytest.mark.skipif( -- cgit v1.2.1