diff options
-rw-r--r-- | redis/cluster.py | 5 | ||||
-rw-r--r-- | tests/test_cluster.py | 17 |
2 files changed, 20 insertions, 2 deletions
diff --git a/redis/cluster.py b/redis/cluster.py index 235d8f2..5f730a8 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -121,6 +121,7 @@ REDIS_ALLOWED_KEYS = ( "charset", "connection_class", "connection_pool", + "connection_pool_class", "client_name", "credential_provider", "db", @@ -1267,6 +1268,7 @@ class NodesManager: require_full_coverage=False, lock=None, dynamic_startup_nodes=True, + connection_pool_class=ConnectionPool, **kwargs, ): self.nodes_cache = {} @@ -1277,6 +1279,7 @@ class NodesManager: self.from_url = from_url self._require_full_coverage = require_full_coverage self._dynamic_startup_nodes = dynamic_startup_nodes + self.connection_pool_class = connection_pool_class self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1420,7 +1423,7 @@ class NodesManager: # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - r = Redis(connection_pool=ConnectionPool(**kwargs)) + r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: r = Redis(host=host, port=port, **kwargs) return r diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 55b660c..da6a8e4 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -18,7 +18,7 @@ from redis.cluster import ( get_node_name, ) from redis.commands import CommandsParser -from redis.connection import Connection +from redis.connection import BlockingConnectionPool, Connection, ConnectionPool from redis.crc import key_slot from redis.exceptions import ( AskError, @@ -2496,6 +2496,21 @@ class TestNodesManager: else: assert startup_nodes == ["my@DNS.com:7000"] + @pytest.mark.parametrize( + "connection_pool_class", [ConnectionPool, BlockingConnectionPool] + ) + def test_connection_pool_class(self, connection_pool_class): + rc = get_mocked_redis_client( + url="redis://my@DNS.com:7000", + cluster_slots=default_cluster_slots, + connection_pool_class=connection_pool_class, + ) + + for node in rc.nodes_manager.nodes_cache.values(): + assert isinstance( + node.redis_connection.connection_pool, connection_pool_class + ) + @pytest.mark.onlycluster class TestClusterPubSubObject: |