summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMaksim Novikov <maksim.novikov@aiven.io>2021-12-02 15:02:43 +0100
committerGitHub <noreply@github.com>2021-12-02 16:02:43 +0200
commitd4a9825a72e1b7715d79ce8134e678d9ef537dce (patch)
tree713d5c211f70013811f6e2dfe9b434723beab2be
parent1a59a7a45feaed2bd0e33ccdbcd92cd305fd7e44 (diff)
downloadredis-py-d4a9825a72e1b7715d79ce8134e678d9ef537dce.tar.gz
Allow overriding connection class via keyword arguments (#1752)
-rwxr-xr-xredis/connection.py4
-rw-r--r--tests/test_connection_pool.py18
2 files changed, 22 insertions, 0 deletions
diff --git a/redis/connection.py b/redis/connection.py
index d13fe65..2001c64 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -1131,6 +1131,10 @@ class ConnectionPool:
arguments always win.
"""
url_options = parse_url(url)
+
+ if "connection_class" in kwargs:
+ url_options["connection_class"] = kwargs["connection_class"]
+
kwargs.update(url_options)
return cls(**kwargs)
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index 2602af8..276e77c 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -445,6 +445,15 @@ class TestConnectionPoolUnixSocketURLParsing:
assert pool.connection_class == redis.UnixDomainSocketConnection
assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"}
+ def test_connection_class_override(self):
+ class MyConnection(redis.UnixDomainSocketConnection):
+ pass
+
+ pool = redis.ConnectionPool.from_url(
+ 'unix:///socket', connection_class=MyConnection
+ )
+ assert pool.connection_class == MyConnection
+
@pytest.mark.skipif(not ssl_available, reason="SSL not installed")
class TestSSLConnectionURLParsing:
@@ -455,6 +464,15 @@ class TestSSLConnectionURLParsing:
"host": "my.host",
}
+ def test_connection_class_override(self):
+ class MyConnection(redis.SSLConnection):
+ pass
+
+ pool = redis.ConnectionPool.from_url(
+ 'rediss://my.host', connection_class=MyConnection
+ )
+ assert pool.connection_class == MyConnection
+
def test_cert_reqs_options(self):
import ssl