diff options
Diffstat (limited to 'redis/asyncio/cluster.py')
-rw-r--r-- | redis/asyncio/cluster.py | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index a4a9561..eb5f4db 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -5,12 +5,14 @@ import socket import warnings from typing import ( Any, + Callable, Deque, Dict, Generator, List, Mapping, Optional, + Tuple, Type, TypeVar, Union, @@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand maximum number of connections are already created, a :class:`~.MaxConnectionsError` is raised. This error may be retried as defined by :attr:`connection_error_retry_attempts` + :param address_remap: + | An optional callable which, when provided with an internal network + address of a node, e.g. a `(host, port)` tuple, will return the address + where the node is reachable. This can be used to map the addresses at + which the nodes _think_ they are, to addresses at which a client may + reach them, such as when they sit behind a proxy. | Rest of the arguments will be passed to the :class:`~redis.asyncio.connection.Connection` instances when created @@ -250,6 +258,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -337,7 +346,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs) + self.nodes_manager = NodesManager( + startup_nodes, + require_full_coverage, + kwargs, + address_remap=address_remap, + ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas self.reinitialize_steps = reinitialize_steps @@ -1059,6 +1073,7 @@ class NodesManager: "require_full_coverage", "slots_cache", "startup_nodes", + "address_remap", ) def __init__( @@ -1066,10 +1081,12 @@ class NodesManager: startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage self.connection_kwargs = connection_kwargs + self.address_remap = address_remap self.default_node: "ClusterNode" = None self.nodes_cache: Dict[str, "ClusterNode"] = {} @@ -1228,6 +1245,7 @@ class NodesManager: if host == "": host = startup_node.host port = int(primary_node[1]) + host, port = self.remap_host_port(host, port) target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: @@ -1246,6 +1264,7 @@ class NodesManager: for replica_node in replica_nodes: host = replica_node[0] port = replica_node[1] + host, port = self.remap_host_port(host, port) target_replica_node = tmp_nodes_cache.get( get_node_name(host, port) @@ -1319,6 +1338,16 @@ class NodesManager: ) ) + def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: + """ + Remap the host and port returned from the cluster to a different + internal value. Useful if the client is not connecting directly + to the cluster. + """ + if self.address_remap: + return self.address_remap((host, port)) + return host, port + class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ |