diff options
author | Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> | 2022-02-22 05:29:55 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-22 12:29:55 +0200 |
commit | d56baeb683fc1935cfa343fa2eeb0fa9bd955283 (patch) | |
tree | 47357a74bf1d1428cfbcf0d8b2c781f1f971cf77 /redis/asyncio/sentinel.py | |
parent | e3c989d93e914e6502bd5a72f15ded49a135c5be (diff) | |
download | redis-py-d56baeb683fc1935cfa343fa2eeb0fa9bd955283.tar.gz |
Add Async Support (#1899)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'redis/asyncio/sentinel.py')
-rw-r--r-- | redis/asyncio/sentinel.py | 353 |
1 files changed, 353 insertions, 0 deletions
diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py new file mode 100644 index 0000000..5aefd09 --- /dev/null +++ b/redis/asyncio/sentinel.py @@ -0,0 +1,353 @@ +import asyncio +import random +import weakref +from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type + +from redis.asyncio.client import Redis +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + EncodableT, + SSLConnection, +) +from redis.commands import AsyncSentinelCommands +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.utils import str_if_bytes + + +class MasterNotFoundError(ConnectionError): + pass + + +class SlaveNotFoundError(ConnectionError): + pass + + +class SentinelManagedConnection(Connection): + def __init__(self, **kwargs): + self.connection_pool = kwargs.pop("connection_pool") + super().__init__(**kwargs) + + def __repr__(self): + pool = self.connection_pool + s = f"{self.__class__.__name__}<service={pool.service_name}" + if self.host: + host_info = f",host={self.host},port={self.port}" + s += host_info + return s + ">" + + async def connect_to(self, address): + self.host, self.port = address + await super().connect() + if self.connection_pool.check_connection: + await self.send_command("PING") + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("PING failed") + + async def connect(self): + if self._reader: + return # already connected + if self.connection_pool.is_master: + await self.connect_to(await self.connection_pool.get_master_address()) + else: + async for slave in self.connection_pool.rotate_slaves(): + try: + return await self.connect_to(slave) + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here + + async def read_response(self, disable_decoding: bool = False): + try: + return await super().read_response(disable_decoding=disable_decoding) + except ReadOnlyError: + if self.connection_pool.is_master: + # When talking to a master, a ReadOnlyError when likely + # indicates that the previous master that we're still connected + # to has been demoted to a slave and there's a new master. + # calling disconnect will force the connection to re-query + # sentinel during the next connect() attempt. + await self.disconnect() + raise ConnectionError("The previous master is now a slave") + raise + + +class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): + pass + + +class SentinelConnectionPool(ConnectionPool): + """ + Sentinel backed connection pool. + + If ``check_connection`` flag is set to True, SentinelManagedConnection + sends a PING command right after establishing the connection. + """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.master_address = None + self.slave_rr_counter = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<service={self.service_name}({self.is_master and 'master' or 'slave'})>" + ) + + def reset(self): + super().reset() + self.master_address = None + self.slave_rr_counter = None + + def owns_connection(self, connection: Connection): + check = not self.is_master or ( + self.is_master and self.master_address == (connection.host, connection.port) + ) + return check and super().owns_connection(connection) + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master: + if self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + await self.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + """Round-robin slave balancer""" + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class Sentinel(AsyncSentinelCommands): + """ + Redis Sentinel cluster client + + >>> from redis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> await master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> await slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. + """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(host=hostname, port=port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + async def execute_command(self, *args, **kwargs): + """ + Execute Sentinel command in sentinel nodes. + once - If set to True, then execute the resulting command on a single + node at random, rather than across the entire sentinel cluster. + """ + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") + + if once: + tasks = [ + asyncio.Task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + await asyncio.gather(*tasks) + else: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + return True + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + f"{sentinel.connection_pool.connection_kwargs['host']}:" + f"{sentinel.connection_pool.connection_kwargs['port']}" + ) + return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>" + + def check_master_state(self, state: dict, service_name: str) -> bool: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + async def discover_master(self, service_name: str): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. + """ + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = await sentinel.sentinel_masters() + except (ConnectionError, TimeoutError): + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + raise MasterNotFoundError(f"No master found for {service_name!r}") + + def filter_slaves( + self, slaves: Iterable[Mapping] + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Remove slaves that are in an ODOWN or SDOWN state""" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + async def discover_slaves( + self, service_name: str + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Returns a list of alive slaves for service ``service_name``""" + for sentinel in self.sentinels: + try: + slaves = await sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. + + A :py:class:`~redis.sentinel.SentinelConnectionPool` class is + used to retrieve the master's address before establishing a new + connection. + + NOTE: If the master's address has changed, any cached connections to + the old master are closed. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to + use. The :py:class:`~redis.sentinel.SentinelConnectionPool` + will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = True + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) + + def slave_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns redis client instance for the ``service_name`` slave(s). + + A SentinelConnectionPool class is used to retrieve the slave's + address before establishing a new connection. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to use. + The SentinelConnectionPool will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = False + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) |