diff options
-rwxr-xr-x | redis/connection.py | 4 | ||||
-rw-r--r-- | tests/test_connection_pool.py | 18 |
2 files changed, 22 insertions, 0 deletions
diff --git a/redis/connection.py b/redis/connection.py index d13fe65..2001c64 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1131,6 +1131,10 @@ class ConnectionPool: arguments always win. """ url_options = parse_url(url) + + if "connection_class" in kwargs: + url_options["connection_class"] = kwargs["connection_class"] + kwargs.update(url_options) return cls(**kwargs) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 2602af8..276e77c 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -445,6 +445,15 @@ class TestConnectionPoolUnixSocketURLParsing: assert pool.connection_class == redis.UnixDomainSocketConnection assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} + def test_connection_class_override(self): + class MyConnection(redis.UnixDomainSocketConnection): + pass + + pool = redis.ConnectionPool.from_url( + 'unix:///socket', connection_class=MyConnection + ) + assert pool.connection_class == MyConnection + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") class TestSSLConnectionURLParsing: @@ -455,6 +464,15 @@ class TestSSLConnectionURLParsing: "host": "my.host", } + def test_connection_class_override(self): + class MyConnection(redis.SSLConnection): + pass + + pool = redis.ConnectionPool.from_url( + 'rediss://my.host', connection_class=MyConnection + ) + assert pool.connection_class == MyConnection + def test_cert_reqs_options(self): import ssl |