summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbhimanyu Deora <abhikdeora@gmail.com>2020-10-26 11:16:23 -0500
committerGitHub <noreply@github.com>2020-10-26 09:16:23 -0700
commita404ad3dee93b3b0deb90b4075fd28734d78a282 (patch)
treea9d9e83a2fc9b12419a03246ccce3f8210f6f87e
parent15dafb1414f05ce24ef336fc539e06ad6a2b3d19 (diff)
downloadredis-py-a404ad3dee93b3b0deb90b4075fd28734d78a282.tar.gz
Add optional exception handler to PubSubWorkerThread (#1395)
Add optional exception handler to PubSubWorkerThread Co-authored-by: Abhimanyu Deora <adeora@drwholdings.com>
-rw-r--r--README.rst14
-rwxr-xr-xredis/client.py23
-rw-r--r--tests/test_pubsub.py24
3 files changed, 56 insertions, 5 deletions
diff --git a/README.rst b/README.rst
index fccea62..389d5b4 100644
--- a/README.rst
+++ b/README.rst
@@ -732,6 +732,20 @@ subscribed to patterns or channels that don't have message handlers attached.
# when it's time to shut it down...
>>> thread.stop()
+`run_in_thread` also supports an optional exception handler, which lets you
+catch exceptions that occur within the worker thread and handle them
+appropriately. The exception handler will take as arguments the exception
+itself, the pubsub object, and the worker thread returned by `run_in_thread`.
+
+.. code-block:: pycon
+ >>> p.subscribe(**{'my-channel': my_handler})
+ >>> def exception_handler(ex, pubsub, thread):
+ >>> print(ex)
+ >>> thread.stop()
+ >>> thread.join(timeout=1.0)
+ >>> pubsub.close()
+ >>> thread = p.run_in_thread(exception_handler=exception_handler)
+
A PubSub object adheres to the same encoding semantics as the client instance
it was created from. Any channel or pattern that's unicode will be encoded
using the `charset` specified on the client before being sent to Redis. If the
diff --git a/redis/client.py b/redis/client.py
index 42d1bfa..08b0314 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -3803,7 +3803,8 @@ class PubSub:
return message
- def run_in_thread(self, sleep_time=0, daemon=False):
+ def run_in_thread(self, sleep_time=0, daemon=False,
+ exception_handler=None):
for channel, handler in self.channels.items():
if handler is None:
raise PubSubError("Channel: '%s' has no handler registered" %
@@ -3813,17 +3814,24 @@ class PubSub:
raise PubSubError("Pattern: '%s' has no handler registered" %
pattern)
- thread = PubSubWorkerThread(self, sleep_time, daemon=daemon)
+ thread = PubSubWorkerThread(
+ self,
+ sleep_time,
+ daemon=daemon,
+ exception_handler=exception_handler
+ )
thread.start()
return thread
class PubSubWorkerThread(threading.Thread):
- def __init__(self, pubsub, sleep_time, daemon=False):
+ def __init__(self, pubsub, sleep_time, daemon=False,
+ exception_handler=None):
super().__init__()
self.daemon = daemon
self.pubsub = pubsub
self.sleep_time = sleep_time
+ self.exception_handler = exception_handler
self._running = threading.Event()
def run(self):
@@ -3833,8 +3841,13 @@ class PubSubWorkerThread(threading.Thread):
pubsub = self.pubsub
sleep_time = self.sleep_time
while self._running.is_set():
- pubsub.get_message(ignore_subscribe_messages=True,
- timeout=sleep_time)
+ try:
+ pubsub.get_message(ignore_subscribe_messages=True,
+ timeout=sleep_time)
+ except BaseException as e:
+ if self.exception_handler is None:
+ raise
+ self.exception_handler(e, pubsub, self)
pubsub.close()
def stop(self):
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index ab9f09c..abeaecb 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -1,6 +1,9 @@
import pytest
+import threading
import time
+from unittest import mock
+
import redis
from redis.exceptions import ConnectionError
@@ -543,3 +546,24 @@ class TestPubSubTimeouts:
p.subscribe('foo')
assert wait_for_message(p) == make_message('subscribe', 'foo', 1)
assert p.get_message(timeout=0.01) is None
+
+
+class TestPubSubWorkerThread:
+ def test_pubsub_worker_thread_exception_handler(self, r):
+ event = threading.Event()
+
+ def exception_handler(ex, pubsub, thread):
+ thread.stop()
+ event.set()
+
+ p = r.pubsub()
+ p.subscribe(**{'foo': lambda m: m})
+ with mock.patch.object(p, 'get_message',
+ side_effect=Exception('error')):
+ pubsub_thread = p.run_in_thread(
+ exception_handler=exception_handler
+ )
+
+ assert event.wait(timeout=1.0)
+ pubsub_thread.join(timeout=1.0)
+ assert not pubsub_thread.is_alive()