diff options
author | MyrikLD <myrik260138@tut.by> | 2021-02-03 03:50:47 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-03 07:50:47 +0700 |
commit | 14ca7881e4f0f2c8f70f41395168fcbebb2f2c71 (patch) | |
tree | f464258b27176394c6924c98d33dba5ae6abf987 | |
parent | 59d1b40d142ac0d70e07d306f3aea3d7ebae9417 (diff) | |
download | rq-14ca7881e4f0f2c8f70f41395168fcbebb2f2c71.tar.gz |
Add runner for asyncio tasks (#1405)
* add asyncio runner
* add asyncio runner
* fix for old version
* fix tests
-rw-r--r-- | rq/job.py | 8 | ||||
-rw-r--r-- | tests/fixtures.py | 5 | ||||
-rw-r--r-- | tests/test_job.py | 15 |
3 files changed, 27 insertions, 1 deletions
@@ -8,6 +8,7 @@ import pickle import warnings import zlib +import asyncio from collections.abc import Iterable from distutils.version import StrictVersion from functools import partial @@ -720,7 +721,12 @@ class Job(object): pipeline.hmset(self.key, mapping) def _execute(self): - return self.func(*self.args, **self.kwargs) + result = self.func(*self.args, **self.kwargs) + if asyncio.iscoroutine(result): + loop = asyncio.get_event_loop() + coro_result = loop.run_until_complete(result) + return coro_result + return result def get_ttl(self, default_ttl=None): """Returns ttl for a job that determines how long a job will be diff --git a/tests/fixtures.py b/tests/fixtures.py index b2d4af1..82e98bc 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -32,6 +32,11 @@ def say_hello(name=None): return 'Hi there, %s!' % (name,) +async def say_hello_async(name=None): + """A async job with a single argument and a return value.""" + return say_hello(name) + + def say_hello_unicode(name=None): """A job with a single argument and a return value.""" return text_type(say_hello(name)) # noqa diff --git a/tests/test_job.py b/tests/test_job.py index 8b921d7..fa7eff1 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -748,6 +748,21 @@ class TestJob(RQTestCase): self.assertRaises(TypeError, queue.enqueue, fixtures.say_hello, job_id=1234) + def test_create_job_with_async(self): + """test creating jobs with async function""" + queue = Queue(connection=self.testconn) + + async_job = queue.enqueue(fixtures.say_hello_async, job_id="async_job") + sync_job = queue.enqueue(fixtures.say_hello, job_id="sync_job") + + self.assertEqual(async_job.id, "async_job") + self.assertEqual(sync_job.id, "sync_job") + + async_task_result = async_job.perform() + sync_task_result = sync_job.perform() + + self.assertEqual(sync_task_result, async_task_result) + def test_get_call_string_unicode(self): """test call string with unicode keyword arguments""" queue = Queue(connection=self.testconn) |