diff options
author | Maksim Novikov <maksim.novikov@aiven.io> | 2021-12-02 15:02:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-02 16:02:43 +0200 |
commit | d4a9825a72e1b7715d79ce8134e678d9ef537dce (patch) | |
tree | 713d5c211f70013811f6e2dfe9b434723beab2be | |
parent | 1a59a7a45feaed2bd0e33ccdbcd92cd305fd7e44 (diff) | |
download | redis-py-d4a9825a72e1b7715d79ce8134e678d9ef537dce.tar.gz |
Allow overriding connection class via keyword arguments (#1752)
-rwxr-xr-x | redis/connection.py | 4 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 18 |
2 files changed, 22 insertions, 0 deletions
diff --git a/redis/connection.py b/redis/connection.py index d13fe65..2001c64 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1131,6 +1131,10 @@ class ConnectionPool: arguments always win. """ url_options = parse_url(url) + + if "connection_class" in kwargs: + url_options["connection_class"] = kwargs["connection_class"] + kwargs.update(url_options) return cls(**kwargs) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 2602af8..276e77c 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -445,6 +445,15 @@ class TestConnectionPoolUnixSocketURLParsing: assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} + def test_connection_class_override(self): + class MyConnection(redis.UnixDomainSocketConnection): + pass + + pool = redis.ConnectionPool.from_url( + 'unix:///socket', connection_class=MyConnection + ) + assert pool.connection_class == MyConnection + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") class TestSSLConnectionURLParsing: @@ -455,6 +464,15 @@ class TestSSLConnectionURLParsing: "host": "my.host", } + def test_connection_class_override(self): + class MyConnection(redis.SSLConnection): + pass + + pool = redis.ConnectionPool.from_url( + 'rediss://my.host', connection_class=MyConnection + ) + assert pool.connection_class == MyConnection + def test_cert_reqs_options(self): import ssl |