summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBar Shaul <88437685+barshaul@users.noreply.github.com>2021-12-23 11:42:44 +0200
committerGitHub <noreply@github.com>2021-12-23 11:42:44 +0200
commitd6cb997bc7b96f4e646a497a89f9466c97aeffef (patch)
treedf29aca1fe7c72ca90cb8b9b63aba3d28c171d6b
parentddc51c4ace0caa0787715801b9df42e65c790d46 (diff)
downloadredis-py-d6cb997bc7b96f4e646a497a89f9466c97aeffef.tar.gz
Fixing read race condition during pubsub (#1737)
-rwxr-xr-xredis/client.py74
-rw-r--r--tests/test_pubsub.py43
2 files changed, 102 insertions, 15 deletions
diff --git a/redis/client.py b/redis/client.py
index 0236f20..5116482 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1288,18 +1288,17 @@ class PubSub:
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
self.connection = None
+ self.subscribed_event = threading.Event()
# we need to know the encoding options for this connection in order
# to lookup channel and pattern names for callback handlers.
self.encoder = encoder
if self.encoder is None:
self.encoder = self.connection_pool.get_encoder()
+ self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
if self.encoder.decode_responses:
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
- self.health_check_response = [
- b"pong",
- self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
- ]
+ self.health_check_response = [b"pong", self.health_check_response_b]
self.reset()
def __enter__(self):
@@ -1324,9 +1323,11 @@ class PubSub:
self.connection_pool.release(self.connection)
self.connection = None
self.channels = {}
+ self.health_check_response_counter = 0
self.pending_unsubscribe_channels = set()
self.patterns = {}
self.pending_unsubscribe_patterns = set()
+ self.subscribed_event.clear()
def close(self):
self.reset()
@@ -1352,7 +1353,7 @@ class PubSub:
@property
def subscribed(self):
"Indicates if there are subscriptions to any channels or patterns"
- return bool(self.channels or self.patterns)
+ return self.subscribed_event.is_set()
def execute_command(self, *args):
"Execute a publish/subscribe command"
@@ -1370,8 +1371,28 @@ class PubSub:
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
+ if not self.subscribed:
+ self.clean_health_check_responses()
self._execute(connection, connection.send_command, *args, **kwargs)
+ def clean_health_check_responses(self):
+ """
+ If any health check responses are present, clean them
+ """
+ ttl = 10
+ conn = self.connection
+ while self.health_check_response_counter > 0 and ttl > 0:
+ if self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
+ response = self._execute(conn, conn.read_response)
+ if self.is_health_check_response(response):
+ self.health_check_response_counter -= 1
+ else:
+ raise PubSubError(
+ "A non health check response was cleaned by "
+ "execute_command: {0}".format(response)
+ )
+ ttl -= 1
+
def _disconnect_raise_connect(self, conn, error):
"""
Close the connection and raise an exception
@@ -1411,11 +1432,23 @@ class PubSub:
return None
response = self._execute(conn, conn.read_response)
- if conn.health_check_interval and response == self.health_check_response:
+ if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
+ self.health_check_response_counter -= 1
return None
return response
+ def is_health_check_response(self, response):
+ """
+ Check if the response is a health check response.
+ If there are no subscriptions redis responds to PING command with a
+ bulk response, instead of a multi-bulk with "pong" and the response.
+ """
+ return response in [
+ self.health_check_response, # If there was a subscription
+ self.health_check_response_b, # If there wasn't
+ ]
+
def check_health(self):
conn = self.connection
if conn is None:
@@ -1426,6 +1459,7 @@ class PubSub:
if conn.health_check_interval and time.time() > conn.next_health_check:
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
+ self.health_check_response_counter += 1
def _normalize_keys(self, data):
"""
@@ -1455,6 +1489,11 @@ class PubSub:
# for the reconnection.
new_patterns = self._normalize_keys(new_patterns)
self.patterns.update(new_patterns)
+ if not self.subscribed:
+ # Set the subscribed_event flag to True
+ self.subscribed_event.set()
+ # Clear the health check counter
+ self.health_check_response_counter = 0
self.pending_unsubscribe_patterns.difference_update(new_patterns)
return ret_val
@@ -1489,6 +1528,11 @@ class PubSub:
# for the reconnection.
new_channels = self._normalize_keys(new_channels)
self.channels.update(new_channels)
+ if not self.subscribed:
+ # Set the subscribed_event flag to True
+ self.subscribed_event.set()
+ # Clear the health check counter
+ self.health_check_response_counter = 0
self.pending_unsubscribe_channels.difference_update(new_channels)
return ret_val
@@ -1520,6 +1564,20 @@ class PubSub:
before returning. Timeout should be specified as a floating point
number.
"""
+ if not self.subscribed:
+ # Wait for subscription
+ start_time = time.time()
+ if self.subscribed_event.wait(timeout) is True:
+ # The connection was subscribed during the timeout time frame.
+ # The timeout should be adjusted based on the time spent
+ # waiting for the subscription
+ time_spent = time.time() - start_time
+ timeout = max(0.0, timeout - time_spent)
+ else:
+ # The connection isn't subscribed to any channels or patterns,
+ # so no messages are available
+ return None
+
response = self.parse_response(block=False, timeout=timeout)
if response:
return self.handle_message(response, ignore_subscribe_messages)
@@ -1575,6 +1633,10 @@ class PubSub:
if channel in self.pending_unsubscribe_channels:
self.pending_unsubscribe_channels.remove(channel)
self.channels.pop(channel, None)
+ if not self.channels and not self.patterns:
+ # There are no subscriptions anymore, set subscribed_event flag
+ # to false
+ self.subscribed_event.clear()
if message_type in self.PUBLISH_MESSAGE_TYPES:
# if there's a message handler, invoke it
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index 20ae0a0..23af461 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -2,6 +2,7 @@ import platform
import threading
import time
from unittest import mock
+from unittest.mock import patch
import pytest
@@ -348,15 +349,6 @@ class TestPubSubMessages:
"pmessage", channel, "test message", pattern=pattern
)
- def test_get_message_without_subscribe(self, r):
- p = r.pubsub()
- with pytest.raises(RuntimeError) as info:
- p.get_message()
- expect = (
- "connection not set: " "did you forget to call subscribe() or psubscribe()?"
- )
- assert expect in info.exconly()
-
class TestPubSubAutoDecoding:
"These tests only validate that we get unicode values back"
@@ -549,6 +541,39 @@ class TestPubSubTimeouts:
assert wait_for_message(p) == make_message("subscribe", "foo", 1)
assert p.get_message(timeout=0.01) is None
+ def test_get_message_not_subscribed_return_none(self, r):
+ p = r.pubsub()
+ assert p.subscribed is False
+ assert p.get_message() is None
+ assert p.get_message(timeout=0.1) is None
+ with patch.object(threading.Event, "wait") as mock:
+ mock.return_value = False
+ assert p.get_message(timeout=0.01) is None
+ assert mock.called
+
+ def test_get_message_subscribe_during_waiting(self, r):
+ p = r.pubsub()
+
+ def poll(ps, expected_res):
+ assert ps.get_message() is None
+ message = ps.get_message(timeout=1)
+ assert message == expected_res
+
+ subscribe_response = make_message("subscribe", "foo", 1)
+ poller = threading.Thread(target=poll, args=(p, subscribe_response))
+ poller.start()
+ time.sleep(0.2)
+ p.subscribe("foo")
+ poller.join()
+
+ def test_get_message_wait_for_subscription_not_being_called(self, r):
+ p = r.pubsub()
+ p.subscribe("foo")
+ with patch.object(threading.Event, "wait") as mock:
+ assert p.subscribed is True
+ assert wait_for_message(p) == make_message("subscribe", "foo", 1)
+ assert mock.called is False
+
class TestPubSubWorkerThread:
@pytest.mark.skipif(