diff options
author | Abhimanyu Deora <abhikdeora@gmail.com> | 2020-10-26 11:16:23 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-26 09:16:23 -0700 |
commit | a404ad3dee93b3b0deb90b4075fd28734d78a282 (patch) | |
tree | a9d9e83a2fc9b12419a03246ccce3f8210f6f87e | |
parent | 15dafb1414f05ce24ef336fc539e06ad6a2b3d19 (diff) | |
download | redis-py-a404ad3dee93b3b0deb90b4075fd28734d78a282.tar.gz |
Add optional exception handler to PubSubWorkerThread (#1395)
Add optional exception handler to PubSubWorkerThread
Co-authored-by: Abhimanyu Deora <adeora@drwholdings.com>
-rw-r--r-- | README.rst | 14 | ||||
-rwxr-xr-x | redis/client.py | 23 | ||||
-rw-r--r-- | tests/test_pubsub.py | 24 |
3 files changed, 56 insertions, 5 deletions
@@ -732,6 +732,20 @@ subscribed to patterns or channels that don't have message handlers attached. # when it's time to shut it down... >>> thread.stop() +`run_in_thread` also supports an optional exception handler, which lets you +catch exceptions that occur within the worker thread and handle them +appropriately. The exception handler will take as arguments the exception +itself, the pubsub object, and the worker thread returned by `run_in_thread`. + +.. code-block:: pycon + >>> p.subscribe(**{'my-channel': my_handler}) + >>> def exception_handler(ex, pubsub, thread): + >>> print(ex) + >>> thread.stop() + >>> thread.join(timeout=1.0) + >>> pubsub.close() + >>> thread = p.run_in_thread(exception_handler=exception_handler) + A PubSub object adheres to the same encoding semantics as the client instance it was created from. Any channel or pattern that's unicode will be encoded using the `charset` specified on the client before being sent to Redis. If the diff --git a/redis/client.py b/redis/client.py index 42d1bfa..08b0314 100755 --- a/redis/client.py +++ b/redis/client.py @@ -3803,7 +3803,8 @@ class PubSub: return message - def run_in_thread(self, sleep_time=0, daemon=False): + def run_in_thread(self, sleep_time=0, daemon=False, + exception_handler=None): for channel, handler in self.channels.items(): if handler is None: raise PubSubError("Channel: '%s' has no handler registered" % @@ -3813,17 +3814,24 @@ class PubSub: raise PubSubError("Pattern: '%s' has no handler registered" % pattern) - thread = PubSubWorkerThread(self, sleep_time, daemon=daemon) + thread = PubSubWorkerThread( + self, + sleep_time, + daemon=daemon, + exception_handler=exception_handler + ) thread.start() return thread class PubSubWorkerThread(threading.Thread): - def __init__(self, pubsub, sleep_time, daemon=False): + def __init__(self, pubsub, sleep_time, daemon=False, + exception_handler=None): super().__init__() self.daemon = daemon self.pubsub = pubsub self.sleep_time = sleep_time + self.exception_handler = exception_handler self._running = threading.Event() def run(self): @@ -3833,8 +3841,13 @@ class PubSubWorkerThread(threading.Thread): pubsub = self.pubsub sleep_time = self.sleep_time while self._running.is_set(): - pubsub.get_message(ignore_subscribe_messages=True, - timeout=sleep_time) + try: + pubsub.get_message(ignore_subscribe_messages=True, + timeout=sleep_time) + except BaseException as e: + if self.exception_handler is None: + raise + self.exception_handler(e, pubsub, self) pubsub.close() def stop(self): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index ab9f09c..abeaecb 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,6 +1,9 @@ import pytest +import threading import time +from unittest import mock + import redis from redis.exceptions import ConnectionError @@ -543,3 +546,24 @@ class TestPubSubTimeouts: p.subscribe('foo') assert wait_for_message(p) == make_message('subscribe', 'foo', 1) assert p.get_message(timeout=0.01) is None + + +class TestPubSubWorkerThread: + def test_pubsub_worker_thread_exception_handler(self, r): + event = threading.Event() + + def exception_handler(ex, pubsub, thread): + thread.stop() + event.set() + + p = r.pubsub() + p.subscribe(**{'foo': lambda m: m}) + with mock.patch.object(p, 'get_message', + side_effect=Exception('error')): + pubsub_thread = p.run_in_thread( + exception_handler=exception_handler + ) + + assert event.wait(timeout=1.0) + pubsub_thread.join(timeout=1.0) + assert not pubsub_thread.is_alive() |