diff options
author | Selwin Ong <selwin.ong@gmail.com> | 2017-11-24 21:14:53 +0700 |
---|---|---|
committer | Selwin Ong <selwin.ong@gmail.com> | 2017-11-24 21:14:53 +0700 |
commit | 7d23d752cc59e0700b0d941b62009ef9319cd759 (patch) | |
tree | a94b29744d4c06af7f69febd97ff202c17999a1d | |
parent | e25c5dbc16333c2c976539bad48a583cd007bfb7 (diff) | |
download | rq-7d23d752cc59e0700b0d941b62009ef9319cd759.tar.gz |
Added worker_registration.unregister.
-rw-r--r-- | rq/worker.py | 24 | ||||
-rw-r--r-- | rq/worker_registration.py | 34 | ||||
-rw-r--r-- | tests/test_worker.py | 9 | ||||
-rw-r--r-- | tests/test_worker_registry.py | 33 |
4 files changed, 90 insertions, 10 deletions
diff --git a/rq/worker.py b/rq/worker.py index ec25119..6912e8c 100644 --- a/rq/worker.py +++ b/rq/worker.py @@ -21,6 +21,7 @@ except ImportError: from redis import WatchError +from . import worker_registration from .compat import PY2, as_text, string_types, text_type from .connections import get_current_connection, push_connection, pop_connection from .defaults import DEFAULT_RESULT_TTL, DEFAULT_WORKER_TTL @@ -96,10 +97,12 @@ class Worker(object): job_class = Job @classmethod - def all(cls, connection=None, job_class=None, queue_class=None): + def all(cls, connection=None, job_class=None, queue_class=None, queue=None): """Returns an iterable of all Workers. """ - if connection is None: + if queue: + connection = queue.connection + elif connection is None: connection = get_current_connection() reported_working = connection.smembers(cls.redis_workers_keys) workers = [cls.find_by_key(as_text(key), @@ -110,6 +113,11 @@ class Worker(object): return compact(workers) @classmethod + def all_keys(cls, connection): + return [as_text(key) + for key in connection.smembers(cls.redis_workers_keys)] + + @classmethod def find_by_key(cls, worker_key, connection=None, job_class=None, queue_class=None): """Returns a Worker instance, based on the naming conventions for @@ -132,7 +140,7 @@ class Worker(object): connection=connection, job_class=job_class, queue_class=queue_class) - + worker.refresh() return worker @@ -185,7 +193,7 @@ class Worker(object): if exc_handler is not None: self.push_exc_handler(exc_handler) warnings.warn( - "use of exc_handler is deprecated, pass a list to exception_handlers instead.", + "exc_handler is deprecated, pass a list to exception_handlers instead.", DeprecationWarning ) elif isinstance(exception_handlers, list): @@ -268,7 +276,7 @@ class Worker(object): p.hset(key, 'birth', now_in_string) p.hset(key, 'last_heartbeat', now_in_string) p.hset(key, 'queues', queues) - p.sadd(self.redis_workers_keys, key) + worker_registration.register(self, p) p.expire(key, self.default_worker_ttl) p.execute() @@ -278,7 +286,7 @@ class Worker(object): with self.connection._pipeline() as p: # We cannot use self.state = 'dead' here, because that would # rollback the pipeline - p.srem(self.redis_workers_keys, self.key) + worker_registration.unregister(self, p) p.hset(self.key, 'death', utcformat(utcnow())) p.expire(self.key, 60) p.execute() @@ -560,7 +568,7 @@ class Worker(object): connection=self.connection, job_class=self.job_class) for queue in queues.split(',')] - + def increment_failed_job_count(self, pipeline=None): connection = pipeline if pipeline is not None else self.connection connection.hincrby(self.key, 'failed_job_count', 1) @@ -765,7 +773,7 @@ class Worker(object): self.connection, job_class=self.job_class) - try: + try: job.started_at = utcnow() with self.death_penalty_class(job.timeout or self.queue_class.DEFAULT_TIMEOUT): rv = job.perform() diff --git a/rq/worker_registration.py b/rq/worker_registration.py new file mode 100644 index 0000000..a5c240b --- /dev/null +++ b/rq/worker_registration.py @@ -0,0 +1,34 @@ +# from .worker import Worker + + +workers_by_queue_key = 'rq:workers:%s' + + +def register(worker, pipeline=None): + """ + Store worker key in Redis data structures so we can easily discover + all active workers. + """ + connection = pipeline if pipeline is not None else worker.connection + connection.sadd(worker.redis_workers_keys, worker.key) + for name in worker.queue_names(): + redis_key = workers_by_queue_key % name + connection.sadd(redis_key, worker.key) + + +def unregister(worker, pipeline=None): + """ + Remove worker key from Redis. + """ + if pipeline is None: + connection = worker.connection._pipeline() + else: + connection = pipeline + + connection.srem(worker.redis_workers_keys, worker.key) + for name in worker.queue_names(): + redis_key = workers_by_queue_key % name + connection.srem(redis_key, worker.key) + + if pipeline is None: + connection.execute() diff --git a/tests/test_worker.py b/tests/test_worker.py index 8e44d29..a196756 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -119,7 +119,12 @@ class TestWorker(RQTestCase): self.assertEqual(worker.queues, queues) self.assertEqual(worker.get_state(), WorkerStatus.STARTED) self.assertEqual(worker._job_id, None) - w.register_death() + self.assertTrue(worker.key in Worker.all_keys(worker.connection)) + + # If worker is gone, its keys should also be removed + worker.connection.delete(worker.key) + Worker.find_by_key(worker.key) + self.assertFalse(worker.key in Worker.all_keys(worker.connection)) def test_worker_ttl(self): """Worker ttl.""" @@ -183,7 +188,7 @@ class TestWorker(RQTestCase): # importable from the worker process. job = Job.create(func=div_by_zero, args=(3,)) job.save() - + job_data = job.data invalid_data = job_data.replace(b'div_by_zero', b'nonexisting') assert job_data != invalid_data diff --git a/tests/test_worker_registry.py b/tests/test_worker_registry.py new file mode 100644 index 0000000..8393348 --- /dev/null +++ b/tests/test_worker_registry.py @@ -0,0 +1,33 @@ +from tests import RQTestCase + +from rq import Queue, Worker +from rq.worker_registration import register, unregister, workers_by_queue_key + + +class TestWorkerRegistry(RQTestCase): + + def test_worker_registration(self): + """Ensure worker.key is correctly set in Redis.""" + foo_queue = Queue(name='foo') + bar_queue = Queue(name='bar') + worker = Worker([foo_queue, bar_queue]) + + register(worker) + redis = worker.connection + + self.assertTrue(redis.sismember(worker.redis_workers_keys, worker.key)) + self.assertTrue( + redis.sismember(workers_by_queue_key % foo_queue.name, worker.key) + ) + self.assertTrue( + redis.sismember(workers_by_queue_key % bar_queue.name, worker.key) + ) + + unregister(worker) + self.assertFalse(redis.sismember(worker.redis_workers_keys, worker.key)) + self.assertFalse( + redis.sismember(workers_by_queue_key % foo_queue.name, worker.key) + ) + self.assertFalse( + redis.sismember(workers_by_queue_key % bar_queue.name, worker.key) + ) |