summaryrefslogtreecommitdiff
path: root/redis
diff options
context:
space:
mode:
Diffstat (limited to 'redis')
-rw-r--r--redis/asyncio/cluster.py31
-rw-r--r--redis/cluster.py22
2 files changed, 52 insertions, 1 deletions
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index a4a9561..eb5f4db 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -5,12 +5,14 @@ import socket
import warnings
from typing import (
Any,
+ Callable,
Deque,
Dict,
Generator,
List,
Mapping,
Optional,
+ Tuple,
Type,
TypeVar,
Union,
@@ -147,6 +149,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
maximum number of connections are already created, a
:class:`~.MaxConnectionsError` is raised. This error may be retried as defined
by :attr:`connection_error_retry_attempts`
+ :param address_remap:
+ | An optional callable which, when provided with an internal network
+ address of a node, e.g. a `(host, port)` tuple, will return the address
+ where the node is reachable. This can be used to map the addresses at
+ which the nodes _think_ they are, to addresses at which a client may
+ reach them, such as when they sit behind a proxy.
| Rest of the arguments will be passed to the
:class:`~redis.asyncio.connection.Connection` instances when created
@@ -250,6 +258,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
ssl_certfile: Optional[str] = None,
ssl_check_hostname: bool = False,
ssl_keyfile: Optional[str] = None,
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
if db:
raise RedisClusterException(
@@ -337,7 +346,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
- self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
+ self.nodes_manager = NodesManager(
+ startup_nodes,
+ require_full_coverage,
+ kwargs,
+ address_remap=address_remap,
+ )
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
self.reinitialize_steps = reinitialize_steps
@@ -1059,6 +1073,7 @@ class NodesManager:
"require_full_coverage",
"slots_cache",
"startup_nodes",
+ "address_remap",
)
def __init__(
@@ -1066,10 +1081,12 @@ class NodesManager:
startup_nodes: List["ClusterNode"],
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
self.connection_kwargs = connection_kwargs
+ self.address_remap = address_remap
self.default_node: "ClusterNode" = None
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1228,6 +1245,7 @@ class NodesManager:
if host == "":
host = startup_node.host
port = int(primary_node[1])
+ host, port = self.remap_host_port(host, port)
target_node = tmp_nodes_cache.get(get_node_name(host, port))
if not target_node:
@@ -1246,6 +1264,7 @@ class NodesManager:
for replica_node in replica_nodes:
host = replica_node[0]
port = replica_node[1]
+ host, port = self.remap_host_port(host, port)
target_replica_node = tmp_nodes_cache.get(
get_node_name(host, port)
@@ -1319,6 +1338,16 @@ class NodesManager:
)
)
+ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
+ """
+ Remap the host and port returned from the cluster to a different
+ internal value. Useful if the client is not connecting directly
+ to the cluster.
+ """
+ if self.address_remap:
+ return self.address_remap((host, port))
+ return host, port
+
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
diff --git a/redis/cluster.py b/redis/cluster.py
index 5e6e7da..3ecc2da 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -466,6 +466,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
read_from_replicas: bool = False,
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
"""
@@ -514,6 +515,12 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
reinitialize_steps to 1.
To avoid reinitializing the cluster on moved errors, set
reinitialize_steps to 0.
+ :param address_remap:
+ An optional callable which, when provided with an internal network
+ address of a node, e.g. a `(host, port)` tuple, will return the address
+ where the node is reachable. This can be used to map the addresses at
+ which the nodes _think_ they are, to addresses at which a client may
+ reach them, such as when they sit behind a proxy.
:**kwargs:
Extra arguments that will be sent into Redis instance when created
@@ -594,6 +601,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
from_url=from_url,
require_full_coverage=require_full_coverage,
dynamic_startup_nodes=dynamic_startup_nodes,
+ address_remap=address_remap,
**kwargs,
)
@@ -1269,6 +1277,7 @@ class NodesManager:
lock=None,
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
+ address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
**kwargs,
):
self.nodes_cache = {}
@@ -1280,6 +1289,7 @@ class NodesManager:
self._require_full_coverage = require_full_coverage
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
+ self.address_remap = address_remap
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
@@ -1502,6 +1512,7 @@ class NodesManager:
if host == "":
host = startup_node.host
port = int(primary_node[1])
+ host, port = self.remap_host_port(host, port)
target_node = self._get_or_create_cluster_node(
host, port, PRIMARY, tmp_nodes_cache
@@ -1518,6 +1529,7 @@ class NodesManager:
for replica_node in replica_nodes:
host = str_if_bytes(replica_node[0])
port = replica_node[1]
+ host, port = self.remap_host_port(host, port)
target_replica_node = self._get_or_create_cluster_node(
host, port, REPLICA, tmp_nodes_cache
@@ -1591,6 +1603,16 @@ class NodesManager:
# The read_load_balancer is None, do nothing
pass
+ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
+ """
+ Remap the host and port returned from the cluster to a different
+ internal value. Useful if the client is not connecting directly
+ to the cluster.
+ """
+ if self.address_remap:
+ return self.address_remap((host, port))
+ return host, port
+
class ClusterPubSub(PubSub):
"""