summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCyril Chapellier <tchapi@users.noreply.github.com>2023-05-05 09:25:20 +0200
committerGitHub <noreply@github.com>2023-05-05 14:25:20 +0700
commita228b4838c2ac84c7fca31a1908800edee4b24ed (patch)
tree962bbce56dc7f2140cdf9997d0da48d3007c026d
parent07fef85dd2c648c5b731625e269a18677755c4a9 (diff)
downloadrq-a228b4838c2ac84c7fca31a1908800edee4b24ed.tar.gz
[Hotfix] Fix SSL connection for scheduler (#1894)
* fix: ssl * fix: reinstate a test for parse_connection
-rw-r--r--rq/connections.py19
-rw-r--r--rq/scheduler.py4
-rw-r--r--rq/worker_pool.py13
-rw-r--r--tests/test_connection.py16
-rw-r--r--tests/test_worker_pool.py5
5 files changed, 27 insertions, 30 deletions
diff --git a/rq/connections.py b/rq/connections.py
index dfb590a..36d771d 100644
--- a/rq/connections.py
+++ b/rq/connections.py
@@ -118,23 +118,10 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
- connection_kwargs = connection.connection_pool.connection_kwargs.copy()
- # Redis does not accept parser_class argument which is sometimes present
- # on connection_pool kwargs, for example when hiredis is used
- connection_kwargs.pop('parser_class', None)
+ connection_pool_kwargs = connection.connection_pool.connection_kwargs.copy()
connection_pool_class = connection.connection_pool.connection_class
- if issubclass(connection_pool_class, SSLConnection):
- connection_kwargs['ssl'] = True
- if issubclass(connection_pool_class, UnixDomainSocketConnection):
- # The connection keyword arguments are obtained from
- # `UnixDomainSocketConnection`, which expects `path`, but passed to
- # `redis.client.Redis`, which expects `unix_socket_path`, renaming
- # the key is necessary.
- # `path` is not left in the dictionary as that keyword argument is
- # not expected by `redis.client.Redis` and would raise an exception.
- connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
-
- return connection.__class__, connection_pool_class, connection_kwargs
+
+ return connection.__class__, connection_pool_class, connection_pool_kwargs
_connection_stack = LocalStack()
diff --git a/rq/scheduler.py b/rq/scheduler.py
index 069181d..a64b400 100644
--- a/rq/scheduler.py
+++ b/rq/scheduler.py
@@ -50,7 +50,7 @@ class RQScheduler:
self._acquired_locks: Set[str] = set()
self._scheduled_job_registries: List[ScheduledJobRegistry] = []
self.lock_acquisition_time = None
- self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection)
+ self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection)
self.serializer = resolve_serializer(serializer)
self._connection = None
@@ -71,7 +71,7 @@ class RQScheduler:
if self._connection:
return self._connection
self._connection = self._connection_class(
- connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs)
+ connection_pool=ConnectionPool(connection_class=self._pool_class, **self._pool_kwargs)
)
return self._connection
diff --git a/rq/worker_pool.py b/rq/worker_pool.py
index 4bd21bb..005c3b9 100644
--- a/rq/worker_pool.py
+++ b/rq/worker_pool.py
@@ -11,7 +11,7 @@ from typing import Dict, List, NamedTuple, Optional, Set, Type, Union
from uuid import uuid4
from redis import Redis
-from redis import SSLConnection, UnixDomainSocketConnection
+from redis import ConnectionPool
from rq.serializers import DefaultSerializer
from rq.timeouts import HorseMonitorTimeoutException, UnixSignalDeathPenalty
@@ -65,7 +65,7 @@ class WorkerPool:
# A dictionary of WorkerData keyed by worker name
self.worker_dict: Dict[str, WorkerData] = {}
- self._connection_class, _, self._connection_kwargs = parse_connection(connection)
+ self._connection_class, self._pool_class, self._pool_kwargs = parse_connection(connection)
@property
def queues(self) -> List[Queue]:
@@ -158,7 +158,7 @@ class WorkerPool:
name = uuid4().hex
process = Process(
target=run_worker,
- args=(name, self._queue_names, self._connection_class, self._connection_kwargs),
+ args=(name, self._queue_names, self._connection_class, self._pool_class, self._pool_kwargs),
kwargs={
'_sleep': _sleep,
'burst': burst,
@@ -234,7 +234,8 @@ def run_worker(
worker_name: str,
queue_names: List[str],
connection_class,
- connection_kwargs: dict,
+ connection_pool_class,
+ connection_pool_kwargs: dict,
worker_class: Type[BaseWorker] = Worker,
serializer: Type[DefaultSerializer] = DefaultSerializer,
job_class: Type[Job] = Job,
@@ -242,7 +243,9 @@ def run_worker(
logging_level: str = "INFO",
_sleep: int = 0,
):
- connection = connection_class(**connection_kwargs)
+ connection = connection_class(
+ connection_pool=ConnectionPool(connection_class=connection_pool_class, **connection_pool_kwargs)
+ )
queues = [Queue(name, connection=connection) for name in queue_names]
worker = worker_class(queues, name=worker_name, connection=connection, serializer=serializer, job_class=job_class)
worker.log.info("Starting worker started with PID %s", os.getpid())
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 0b64d2b..5ac76d6 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -1,4 +1,4 @@
-from redis import ConnectionPool, Redis, UnixDomainSocketConnection
+from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
from rq import Connection, Queue
from rq.connections import parse_connection
@@ -38,10 +38,14 @@ class TestConnectionInheritance(RQTestCase):
self.assertEqual(q2.connection, job2.connection)
def test_parse_connection(self):
- """Test parsing `ssl` and UnixDomainSocketConnection"""
- _, _, kwargs = parse_connection(Redis(ssl=True))
- self.assertTrue(kwargs['ssl'])
+ """Test parsing the connection"""
+ conn_class, pool_class, pool_kwargs = parse_connection(Redis(ssl=True))
+ self.assertEqual(conn_class, Redis)
+ self.assertEqual(pool_class, SSLConnection)
+
path = '/tmp/redis.sock'
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
- _, _, kwargs = parse_connection(Redis(connection_pool=pool))
- self.assertTrue(kwargs['unix_socket_path'], path)
+ conn_class, pool_class, pool_kwargs = parse_connection(Redis(connection_pool=pool))
+ self.assertEqual(conn_class, Redis)
+ self.assertEqual(pool_class, UnixDomainSocketConnection)
+ self.assertEqual(pool_kwargs, {"path": path})
diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py
index c836309..ab2e677 100644
--- a/tests/test_worker_pool.py
+++ b/tests/test_worker_pool.py
@@ -8,6 +8,7 @@ from rq.job import JobStatus
from tests import TestCase
from tests.fixtures import CustomJob, _send_shutdown_command, long_running_job, say_hello
+from rq.connections import parse_connection
from rq.queue import Queue
from rq.serializers import JSONSerializer
from rq.worker import SimpleWorker
@@ -108,8 +109,10 @@ class TestWorkerPool(TestCase):
"""Ensure run_worker() properly spawns a Worker"""
queue = Queue('foo', connection=self.connection)
queue.enqueue(say_hello)
+
+ connection_class, pool_class, pool_kwargs = parse_connection(self.connection)
run_worker(
- 'test-worker', ['foo'], self.connection.__class__, self.connection.connection_pool.connection_kwargs.copy()
+ 'test-worker', ['foo'], connection_class, pool_class, pool_kwargs
)
# Worker should have processed the job
self.assertEqual(len(queue), 0)