Diffstat (limited to 'redis/asyncio')
-rw-r--r--  redis/asyncio/__init__.py      58
-rw-r--r--  redis/asyncio/client.py      1344
-rw-r--r--  redis/asyncio/connection.py  1707
-rw-r--r--  redis/asyncio/lock.py         311
-rw-r--r--  redis/asyncio/retry.py         58
-rw-r--r--  redis/asyncio/sentinel.py     353
-rw-r--r--  redis/asyncio/utils.py         28
7 files changed, 3859 insertions, 0 deletions
diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py
new file mode 100644
index 0000000..c655c7d
--- /dev/null
+++ b/redis/asyncio/__init__.py
@@ -0,0 +1,58 @@
+from redis.asyncio.client import Redis, StrictRedis
+from redis.asyncio.connection import (
+ BlockingConnectionPool,
+ Connection,
+ ConnectionPool,
+ SSLConnection,
+ UnixDomainSocketConnection,
+)
+from redis.asyncio.sentinel import (
+ Sentinel,
+ SentinelConnectionPool,
+ SentinelManagedConnection,
+ SentinelManagedSSLConnection,
+)
+from redis.asyncio.utils import from_url
+from redis.exceptions import (
+ AuthenticationError,
+ AuthenticationWrongNumberOfArgsError,
+ BusyLoadingError,
+ ChildDeadlockedError,
+ ConnectionError,
+ DataError,
+ InvalidResponse,
+ PubSubError,
+ ReadOnlyError,
+ RedisError,
+ ResponseError,
+ TimeoutError,
+ WatchError,
+)
+
+__all__ = [
+ "AuthenticationError",
+ "AuthenticationWrongNumberOfArgsError",
+ "BlockingConnectionPool",
+ "BusyLoadingError",
+ "ChildDeadlockedError",
+ "Connection",
+ "ConnectionError",
+ "ConnectionPool",
+ "DataError",
+ "from_url",
+ "InvalidResponse",
+ "PubSubError",
+ "ReadOnlyError",
+ "Redis",
+ "RedisError",
+ "ResponseError",
+ "Sentinel",
+ "SentinelConnectionPool",
+ "SentinelManagedConnection",
+ "SentinelManagedSSLConnection",
+ "SSLConnection",
+ "StrictRedis",
+ "TimeoutError",
+ "UnixDomainSocketConnection",
+ "WatchError",
+]
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
new file mode 100644
index 0000000..619592e
--- /dev/null
+++ b/redis/asyncio/client.py
@@ -0,0 +1,1344 @@
+import asyncio
+import copy
+import inspect
+import re
+import warnings
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from redis.asyncio.connection import (
+ Connection,
+ ConnectionPool,
+ SSLConnection,
+ UnixDomainSocketConnection,
+)
+from redis.client import (
+ EMPTY_RESPONSE,
+ NEVER_DECODE,
+ AbstractRedis,
+ CaseInsensitiveDict,
+ bool_ok,
+)
+from redis.commands import (
+ AsyncCoreCommands,
+ AsyncSentinelCommands,
+ RedisModuleCommands,
+ list_or_args,
+)
+from redis.compat import Protocol, TypedDict
+from redis.exceptions import (
+ ConnectionError,
+ ExecAbortError,
+ PubSubError,
+ RedisError,
+ ResponseError,
+ TimeoutError,
+ WatchError,
+)
+from redis.asyncio.lock import Lock
+from redis.asyncio.retry import Retry
+from redis.typing import ChannelT, EncodableT, KeyT
+from redis.utils import safe_str, str_if_bytes
+
+PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]]
+_KeyT = TypeVar("_KeyT", bound=KeyT)
+_ArgT = TypeVar("_ArgT", KeyT, EncodableT)
+_RedisT = TypeVar("_RedisT", bound="Redis")
+_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object])
+if TYPE_CHECKING:
+ from redis.commands.core import Script
+
+
+class ResponseCallbackProtocol(Protocol):
+ def __call__(self, response: Any, **kwargs):
+ ...
+
+
+class AsyncResponseCallbackProtocol(Protocol):
+ async def __call__(self, response: Any, **kwargs):
+ ...
+
+
+ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol]
+
+
+class Redis(
+ AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands
+):
+ """
+ Implementation of the Redis protocol.
+
+ This abstract class provides a Python interface to all Redis commands
+ and an implementation of the Redis protocol.
+
+ Pipelines derive from this, implementing how commands are sent to and
+ received from the Redis server. Based on configuration, an instance
+ will use either a ConnectionPool or a Connection object to talk to Redis.
+ """
+
+ response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]
+
+ @classmethod
+ def from_url(cls, url: str, **kwargs):
+ """
+ Return a Redis client object configured from the given URL
+
+ For example::
+
+ redis://[[username]:[password]]@localhost:6379/0
+ rediss://[[username]:[password]]@localhost:6379/0
+ unix://[[username]:[password]]@/path/to/socket.sock?db=0
+
+ Three URL schemes are supported:
+
+ - `redis://` creates a TCP socket connection. See more at:
+ <https://www.iana.org/assignments/uri-schemes/prov/redis>
+ - `rediss://` creates an SSL-wrapped TCP socket connection. See more at:
+ <https://www.iana.org/assignments/uri-schemes/prov/rediss>
+ - `unix://` creates a Unix Domain Socket connection.
+
+ The username, password, hostname, path and all querystring values
+ are passed through urllib.parse.unquote in order to replace any
+ percent-encoded values with their corresponding characters.
+
+ There are several ways to specify a database number. The first value
+ found will be used:
+ 1. A ``db`` querystring option, e.g. redis://localhost?db=0
+ 2. If using the redis:// or rediss:// schemes, the path argument
+ of the url, e.g. redis://localhost/0
+ 3. A ``db`` keyword argument to this function.
+
+ If none of these options are specified, the default db=0 is used.
+
+ All querystring options are cast to their appropriate Python types.
+ Boolean arguments can be specified with string values "True"/"False"
+ or "Yes"/"No". Values that cannot be properly cast cause a
+ ``ValueError`` to be raised. Once parsed, the querystring arguments
+ and keyword arguments are passed to the ``ConnectionPool``'s
+ class initializer. In the case of conflicting arguments, querystring
+ arguments always win.
+
+ """
+ connection_pool = ConnectionPool.from_url(url, **kwargs)
+ return cls(connection_pool=connection_pool)
+
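+ # A minimal usage sketch of ``from_url`` (assumes a Redis server is
+ # reachable at localhost:6379; the key and value are illustrative):
+ #
+ #     import asyncio
+ #     from redis.asyncio import Redis
+ #
+ #     async def main():
+ #         client = Redis.from_url("redis://localhost:6379/0", decode_responses=True)
+ #         await client.set("greeting", "hello")
+ #         print(await client.get("greeting"))  # -> "hello"
+ #         await client.close()
+ #
+ #     asyncio.run(main())
+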
+ def __init__(
+ self,
+ *,
+ host: str = "localhost",
+ port: int = 6379,
+ db: Union[str, int] = 0,
+ password: Optional[str] = None,
+ socket_timeout: Optional[float] = None,
+ socket_connect_timeout: Optional[float] = None,
+ socket_keepalive: Optional[bool] = None,
+ socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
+ connection_pool: Optional[ConnectionPool] = None,
+ unix_socket_path: Optional[str] = None,
+ encoding: str = "utf-8",
+ encoding_errors: str = "strict",
+ decode_responses: bool = False,
+ retry_on_timeout: bool = False,
+ ssl: bool = False,
+ ssl_keyfile: Optional[str] = None,
+ ssl_certfile: Optional[str] = None,
+ ssl_cert_reqs: str = "required",
+ ssl_ca_certs: Optional[str] = None,
+ ssl_check_hostname: bool = False,
+ max_connections: Optional[int] = None,
+ single_connection_client: bool = False,
+ health_check_interval: int = 0,
+ client_name: Optional[str] = None,
+ username: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ auto_close_connection_pool: bool = True,
+ ):
+ """
+ Initialize a new Redis client.
+ To specify a retry policy, first set `retry_on_timeout` to `True`,
+ then set `retry` to a valid `Retry` object.
+ """
+ kwargs: Dict[str, Any]
+ # auto_close_connection_pool only has an effect if connection_pool is
+ # None. It is similar to the __del__ added to resolve #1103, but it
+ # accounts for whether a user wants to manually close the connection
+ # pool, much like ConnectionPool's __del__ does.
+ self.auto_close_connection_pool = (
+ auto_close_connection_pool if connection_pool is None else False
+ )
+ if not connection_pool:
+ kwargs = {
+ "db": db,
+ "username": username,
+ "password": password,
+ "socket_timeout": socket_timeout,
+ "encoding": encoding,
+ "encoding_errors": encoding_errors,
+ "decode_responses": decode_responses,
+ "retry_on_timeout": retry_on_timeout,
+ "retry": copy.deepcopy(retry),
+ "max_connections": max_connections,
+ "health_check_interval": health_check_interval,
+ "client_name": client_name,
+ }
+ # based on input, setup appropriate connection args
+ if unix_socket_path is not None:
+ kwargs.update(
+ {
+ "path": unix_socket_path,
+ "connection_class": UnixDomainSocketConnection,
+ }
+ )
+ else:
+ # TCP specific options
+ kwargs.update(
+ {
+ "host": host,
+ "port": port,
+ "socket_connect_timeout": socket_connect_timeout,
+ "socket_keepalive": socket_keepalive,
+ "socket_keepalive_options": socket_keepalive_options,
+ }
+ )
+
+ if ssl:
+ kwargs.update(
+ {
+ "connection_class": SSLConnection,
+ "ssl_keyfile": ssl_keyfile,
+ "ssl_certfile": ssl_certfile,
+ "ssl_cert_reqs": ssl_cert_reqs,
+ "ssl_ca_certs": ssl_ca_certs,
+ "ssl_check_hostname": ssl_check_hostname,
+ }
+ )
+ connection_pool = ConnectionPool(**kwargs)
+ self.connection_pool = connection_pool
+ self.single_connection_client = single_connection_client
+ self.connection: Optional[Connection] = None
+
+ self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}<{self.connection_pool!r}>"
+
+ def __await__(self):
+ return self.initialize().__await__()
+
+ async def initialize(self: _RedisT) -> _RedisT:
+ if self.single_connection_client and self.connection is None:
+ self.connection = await self.connection_pool.get_connection("_")
+ return self
+
+ def set_response_callback(self, command: str, callback: ResponseCallbackT):
+ """Set a custom Response Callback"""
+ self.response_callbacks[command] = callback
+
+ def load_external_module(
+ self,
+ funcname,
+ func,
+ ):
+ """
+ This function can be used to add externally defined redis modules,
+ and their namespaces to the redis client.
+
+ funcname - A string containing the name of the function to create
+ func - The function, being added to this class.
+
+ ex: Assume that one has a custom redis module named foomod that
+ creates commands named 'foo.dothing' and 'foo.anotherthing' in redis.
+ To load these functions into this namespace:
+
+ from redis import Redis
+ from foomodule import F
+ r = Redis()
+ r.load_external_module("foo", F)
+ r.foo().dothing('your', 'arguments')
+
+ For a concrete example see the reimport of the redisjson module in
+ tests/test_connection.py::test_loading_external_modules
+ """
+ setattr(self, funcname, func)
+
+ def pipeline(
+ self, transaction: bool = True, shard_hint: Optional[str] = None
+ ) -> "Pipeline":
+ """
+ Return a new pipeline object that can queue multiple commands for
+ later execution. ``transaction`` indicates whether all commands
+ should be executed atomically. Apart from making a group of operations
+ atomic, pipelines are useful for reducing the back-and-forth overhead
+ between the client and server.
+ """
+ return Pipeline(
+ self.connection_pool, self.response_callbacks, transaction, shard_hint
+ )
+
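+ # A sketch of typical pipeline use (assumes an initialized client ``r``;
+ # key names are illustrative):
+ #
+ #     async with r.pipeline(transaction=True) as pipe:
+ #         pipe.set("counter", 1)
+ #         pipe.incr("counter")
+ #         results = await pipe.execute()  # -> [True, 2]
+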
+ async def transaction(
+ self,
+ func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]],
+ *watches: KeyT,
+ shard_hint: Optional[str] = None,
+ value_from_callable: bool = False,
+ watch_delay: Optional[float] = None,
+ ):
+ """
+ Convenience method for executing the callable `func` as a transaction
+ while watching all keys specified in `watches`. The 'func' callable
+ should expect a single argument which is a Pipeline object.
+ """
+ pipe: Pipeline
+ async with self.pipeline(True, shard_hint) as pipe:
+ while True:
+ try:
+ if watches:
+ await pipe.watch(*watches)
+ func_value = func(pipe)
+ if inspect.isawaitable(func_value):
+ func_value = await func_value
+ exec_value = await pipe.execute()
+ return func_value if value_from_callable else exec_value
+ except WatchError:
+ if watch_delay is not None and watch_delay > 0:
+ await asyncio.sleep(watch_delay)
+ continue
+
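+ # A sketch of ``transaction`` with optimistic locking via WATCH (assumes
+ # client ``r``; the key name is illustrative):
+ #
+ #     async def double_it(pipe):
+ #         value = int(await pipe.get("counter") or 0)
+ #         pipe.multi()
+ #         pipe.set("counter", value * 2)
+ #
+ #     await r.transaction(double_it, "counter")
+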
+ def lock(
+ self,
+ name: KeyT,
+ timeout: Optional[float] = None,
+ sleep: float = 0.1,
+ blocking_timeout: Optional[float] = None,
+ lock_class: Optional[Type[Lock]] = None,
+ thread_local: bool = True,
+ ) -> Lock:
+ """
+ Return a new Lock object using key ``name`` that mimics
+ the behavior of threading.Lock.
+
+ If specified, ``timeout`` indicates a maximum life for the lock.
+ By default, it will remain locked until release() is called.
+
+ ``sleep`` indicates the amount of time to sleep per loop iteration
+ when the lock is in blocking mode and another client is currently
+ holding the lock.
+
+ ``blocking_timeout`` indicates the maximum amount of time in seconds to
+ spend trying to acquire the lock. A value of ``None`` means keep
+ trying forever. ``blocking_timeout`` can be specified as a
+ float or integer, both representing the number of seconds to wait.
+
+ ``lock_class`` forces the specified lock implementation.
+
+ ``thread_local`` indicates whether the lock token is placed in
+ thread-local storage. By default, the token is placed in thread local
+ storage so that a thread only sees its token, not a token set by
+ another thread. Consider the following timeline:
+
+ time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
+ thread-1 sets the token to "abc"
+ time: 1, thread-2 blocks trying to acquire `my-lock` using the
+ Lock instance.
+ time: 5, thread-1 has not yet completed. redis expires the lock
+ key.
+ time: 5, thread-2 acquired `my-lock` now that it's available.
+ thread-2 sets the token to "xyz"
+ time: 6, thread-1 finishes its work and calls release(). if the
+ token is *not* stored in thread local storage, then
+ thread-1 would see the token value as "xyz" and would be
+ able to successfully release thread-2's lock.
+
+ In some use cases it's necessary to disable thread local storage. For
+ example, consider code where one thread acquires a lock and passes
+ that lock instance to a worker thread to release later. If thread
+ local storage isn't disabled in this case, the worker thread won't see
+ the token set by the thread that acquired the lock. Our assumption
+ is that these cases aren't common, so we default to using
+ thread local storage."""
+ if lock_class is None:
+ lock_class = Lock
+ return lock_class(
+ self,
+ name,
+ timeout=timeout,
+ sleep=sleep,
+ blocking_timeout=blocking_timeout,
+ thread_local=thread_local,
+ )
+
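+ # A sketch of ``lock`` usage (assumes client ``r``; the lock name,
+ # timeout, and critical section are illustrative):
+ #
+ #     lock = r.lock("my-lock", timeout=5, blocking_timeout=3)
+ #     if await lock.acquire():
+ #         try:
+ #             ...  # critical section
+ #         finally:
+ #             await lock.release()
+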
+ def pubsub(self, **kwargs) -> "PubSub":
+ """
+ Return a Publish/Subscribe object. With this object, you can
+ subscribe to channels and listen for messages that get published to
+ them.
+ """
+ return PubSub(self.connection_pool, **kwargs)
+
+ def monitor(self) -> "Monitor":
+ return Monitor(self.connection_pool)
+
+ def client(self) -> "Redis":
+ return self.__class__(
+ connection_pool=self.connection_pool, single_connection_client=True
+ )
+
+ async def __aenter__(self: _RedisT) -> _RedisT:
+ return await self.initialize()
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+ _DEL_MESSAGE = "Unclosed Redis client"
+
+ def __del__(self, _warnings: Any = warnings) -> None:
+ if self.connection is not None:
+ _warnings.warn(
+ f"Unclosed client session {self!r}",
+ ResourceWarning,
+ source=self,
+ )
+ context = {"client": self, "message": self._DEL_MESSAGE}
+ asyncio.get_event_loop().call_exception_handler(context)
+
+ async def close(self, close_connection_pool: Optional[bool] = None) -> None:
+ """
+ Closes Redis client connection
+
+ :param close_connection_pool: decides whether to close the connection pool used
+ by this Redis client, overriding Redis.auto_close_connection_pool. By default,
+ let Redis.auto_close_connection_pool decide whether to close the connection
+ pool.
+ """
+ conn = self.connection
+ if conn:
+ self.connection = None
+ await self.connection_pool.release(conn)
+ if close_connection_pool or (
+ close_connection_pool is None and self.auto_close_connection_pool
+ ):
+ await self.connection_pool.disconnect()
+
+ async def _send_command_parse_response(self, conn, command_name, *args, **options):
+ """
+ Send a command and parse the response
+ """
+ await conn.send_command(*args)
+ return await self.parse_response(conn, command_name, **options)
+
+ async def _disconnect_raise(self, conn: Connection, error: Exception):
+ """
+ Close the connection and raise an exception
+ if retry_on_timeout is not set or the error
+ is not a TimeoutError
+ """
+ await conn.disconnect()
+ if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+ raise error
+
+ # COMMAND EXECUTION AND PROTOCOL PARSING
+ async def execute_command(self, *args, **options):
+ """Execute a command and return a parsed response"""
+ await self.initialize()
+ pool = self.connection_pool
+ command_name = args[0]
+ conn = self.connection or await pool.get_connection(command_name, **options)
+
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, command_name, *args, **options
+ ),
+ lambda error: self._disconnect_raise(conn, error),
+ )
+ finally:
+ if not self.connection:
+ await pool.release(conn)
+
+ async def parse_response(
+ self, connection: Connection, command_name: Union[str, bytes], **options
+ ):
+ """Parses a response from the Redis server"""
+ try:
+ if NEVER_DECODE in options:
+ response = await connection.read_response(disable_encoding=True)
+ else:
+ response = await connection.read_response()
+ except ResponseError:
+ if EMPTY_RESPONSE in options:
+ return options[EMPTY_RESPONSE]
+ raise
+ if command_name in self.response_callbacks:
+ # Mypy bug: https://github.com/python/mypy/issues/10977
+ command_name = cast(str, command_name)
+ retval = self.response_callbacks[command_name](response, **options)
+ return await retval if inspect.isawaitable(retval) else retval
+ return response
+
+
+StrictRedis = Redis
+
+
+class MonitorCommandInfo(TypedDict):
+ time: float
+ db: int
+ client_address: str
+ client_port: str
+ client_type: str
+ command: str
+
+
+class Monitor:
+ """
+ Monitor is useful for handling the MONITOR command to the redis server.
+ The next_command() method returns one command from the monitor, and the
+ listen() method yields commands from the monitor.
+ """
+
+ monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)")
+ command_re = re.compile(r'"(.*?)(?<!\\)"')
+
+ def __init__(self, connection_pool: ConnectionPool):
+ self.connection_pool = connection_pool
+ self.connection: Optional[Connection] = None
+
+ async def connect(self):
+ if self.connection is None:
+ self.connection = await self.connection_pool.get_connection("MONITOR")
+
+ async def __aenter__(self):
+ await self.connect()
+ await self.connection.send_command("MONITOR")
+ # check that monitor returns 'OK', but don't return it to user
+ response = await self.connection.read_response()
+ if not bool_ok(response):
+ raise RedisError(f"MONITOR failed: {response}")
+ return self
+
+ async def __aexit__(self, *args):
+ await self.connection.disconnect()
+ await self.connection_pool.release(self.connection)
+
+ async def next_command(self) -> MonitorCommandInfo:
+ """Parse the response from a monitor command"""
+ await self.connect()
+ response = await self.connection.read_response()
+ if isinstance(response, bytes):
+ response = self.connection.encoder.decode(response, force=True)
+ command_time, command_data = response.split(" ", 1)
+ m = self.monitor_re.match(command_data)
+ db_id, client_info, command = m.groups()
+ command = " ".join(self.command_re.findall(command))
+ # Redis escapes double quotes because each piece of the command
+ # string is surrounded by double quotes. We don't have that
+ # requirement so remove the escaping and leave the quote.
+ command = command.replace('\\"', '"')
+
+ if client_info == "lua":
+ client_address = "lua"
+ client_port = ""
+ client_type = "lua"
+ elif client_info.startswith("unix"):
+ client_address = "unix"
+ client_port = client_info[5:]
+ client_type = "unix"
+ else:
+ # use rsplit as ipv6 addresses contain colons
+ client_address, client_port = client_info.rsplit(":", 1)
+ client_type = "tcp"
+ return {
+ "time": float(command_time),
+ "db": int(db_id),
+ "client_address": client_address,
+ "client_port": client_port,
+ "client_type": client_type,
+ "command": command,
+ }
+
+ async def listen(self) -> AsyncIterator[MonitorCommandInfo]:
+ """Listen for commands coming to the server."""
+ while True:
+ yield await self.next_command()
+
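+# A sketch of Monitor usage (assumes client ``r``; MONITOR streams every
+# command the server processes):
+#
+#     async with r.monitor() as m:
+#         async for info in m.listen():
+#             print(info["command"])
+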
+
+class PubSub:
+ """
+ PubSub provides publish, subscribe and listen support to Redis channels.
+
+ After subscribing to one or more channels, the listen() method will block
+ until a message arrives on one of the subscribed channels. That message
+ will be returned and it's safe to start listening again.
+ """
+
+ PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
+ UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
+ HEALTH_CHECK_MESSAGE = "redis-py-health-check"
+
+ def __init__(
+ self,
+ connection_pool: ConnectionPool,
+ shard_hint: Optional[str] = None,
+ ignore_subscribe_messages: bool = False,
+ encoder=None,
+ ):
+ self.connection_pool = connection_pool
+ self.shard_hint = shard_hint
+ self.ignore_subscribe_messages = ignore_subscribe_messages
+ self.connection = None
+ # we need to know the encoding options for this connection in order
+ # to look up channel and pattern names for callback handlers.
+ self.encoder = encoder
+ if self.encoder is None:
+ self.encoder = self.connection_pool.get_encoder()
+ if self.encoder.decode_responses:
+ self.health_check_response: Iterable[Union[str, bytes]] = [
+ "pong",
+ self.HEALTH_CHECK_MESSAGE,
+ ]
+ else:
+ self.health_check_response = [
+ b"pong",
+ self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
+ ]
+ self.channels = {}
+ self.pending_unsubscribe_channels = set()
+ self.patterns = {}
+ self.pending_unsubscribe_patterns = set()
+ self._lock = asyncio.Lock()
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.reset()
+
+ def __del__(self):
+ if self.connection:
+ self.connection.clear_connect_callbacks()
+
+ async def reset(self):
+ async with self._lock:
+ if self.connection:
+ await self.connection.disconnect()
+ self.connection.clear_connect_callbacks()
+ await self.connection_pool.release(self.connection)
+ self.connection = None
+ self.channels = {}
+ self.pending_unsubscribe_channels = set()
+ self.patterns = {}
+ self.pending_unsubscribe_patterns = set()
+
+ def close(self) -> Awaitable[None]:
+ return self.reset()
+
+ async def on_connect(self, connection: Connection):
+ """Re-subscribe to any channels and patterns previously subscribed to"""
+ # NOTE: for python3, we can't pass bytestrings as keyword arguments
+ # so we need to decode channel/pattern names back to unicode strings
+ # before passing them to [p]subscribe.
+ self.pending_unsubscribe_channels.clear()
+ self.pending_unsubscribe_patterns.clear()
+ if self.channels:
+ channels = {}
+ for k, v in self.channels.items():
+ channels[self.encoder.decode(k, force=True)] = v
+ await self.subscribe(**channels)
+ if self.patterns:
+ patterns = {}
+ for k, v in self.patterns.items():
+ patterns[self.encoder.decode(k, force=True)] = v
+ await self.psubscribe(**patterns)
+
+ @property
+ def subscribed(self):
+ """Indicates if there are subscriptions to any channels or patterns"""
+ return bool(self.channels or self.patterns)
+
+ async def execute_command(self, *args: EncodableT):
+ """Execute a publish/subscribe command"""
+
+ # NOTE: don't parse the response in this function -- it could pull a
+ # legitimate message off the stack if the connection is already
+ # subscribed to one or more channels
+
+ if self.connection is None:
+ self.connection = await self.connection_pool.get_connection(
+ "pubsub", self.shard_hint
+ )
+ # register a callback that re-subscribes to any channels we
+ # were listening to when we were disconnected
+ self.connection.register_connect_callback(self.on_connect)
+ connection = self.connection
+ kwargs = {"check_health": not self.subscribed}
+ await self._execute(connection, connection.send_command, *args, **kwargs)
+
+ async def _disconnect_raise_connect(self, conn, error):
+ """
+ Close the connection and raise an exception
+ if retry_on_timeout is not set or the error
+ is not a TimeoutError. Otherwise, try to reconnect
+ """
+ await conn.disconnect()
+ if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+ raise error
+ await conn.connect()
+
+ async def _execute(self, conn, command, *args, **kwargs):
+ """
+ Connect manually upon disconnection. If the Redis server is down,
+ this will fail and raise a ConnectionError as desired.
+ After reconnection, the ``on_connect`` callback should have been
+ called by the connection to resubscribe us to any channels and
+ patterns we were previously listening to.
+ """
+ return await conn.retry.call_with_retry(
+ lambda: command(*args, **kwargs),
+ lambda error: self._disconnect_raise_connect(conn, error),
+ )
+
+ async def parse_response(self, block: bool = True, timeout: float = 0):
+ """Parse the response from a publish/subscribe command"""
+ conn = self.connection
+ if conn is None:
+ raise RuntimeError(
+ "pubsub connection not set: "
+ "did you forget to call subscribe() or psubscribe()?"
+ )
+
+ await self.check_health()
+
+ if not block and not await self._execute(conn, conn.can_read, timeout=timeout):
+ return None
+ response = await self._execute(conn, conn.read_response)
+
+ if conn.health_check_interval and response == self.health_check_response:
+ # ignore the health check message as user might not expect it
+ return None
+ return response
+
+ async def check_health(self):
+ conn = self.connection
+ if conn is None:
+ raise RuntimeError(
+ "pubsub connection not set: "
+ "did you forget to call subscribe() or psubscribe()?"
+ )
+
+ if (
+ conn.health_check_interval
+ and asyncio.get_event_loop().time() > conn.next_health_check
+ ):
+ await conn.send_command(
+ "PING", self.HEALTH_CHECK_MESSAGE, check_health=False
+ )
+
+ def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT:
+ """
+ normalize channel/pattern names to be either bytes or strings
+ based on whether responses are automatically decoded. this saves us
+ from coercing the value for each message coming in.
+ """
+ encode = self.encoder.encode
+ decode = self.encoder.decode
+ return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501
+
+ async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
+ """
+ Subscribe to channel patterns. Patterns supplied as keyword arguments
+ expect a pattern name as the key and a callable as the value. A
+ pattern's callable will be invoked automatically when a message is
+ received on that pattern rather than producing a message via
+ ``listen()``.
+ """
+ parsed_args = list_or_args((args[0],), args[1:]) if args else args
+ new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args)
+ # Mypy bug: https://github.com/python/mypy/issues/10970
+ new_patterns.update(kwargs) # type: ignore[arg-type]
+ ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys())
+ # update the patterns dict AFTER we send the command. we don't want to
+ # subscribe twice to these patterns, once for the command and again
+ # for the reconnection.
+ new_patterns = self._normalize_keys(new_patterns)
+ self.patterns.update(new_patterns)
+ self.pending_unsubscribe_patterns.difference_update(new_patterns)
+ return ret_val
+
+ def punsubscribe(self, *args: ChannelT) -> Awaitable:
+ """
+ Unsubscribe from the supplied patterns. If empty, unsubscribe from
+ all patterns.
+ """
+ patterns: Iterable[ChannelT]
+ if args:
+ parsed_args = list_or_args((args[0],), args[1:])
+ patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys()
+ else:
+ parsed_args = []
+ patterns = self.patterns
+ self.pending_unsubscribe_patterns.update(patterns)
+ return self.execute_command("PUNSUBSCRIBE", *parsed_args)
+
+ async def subscribe(self, *args: ChannelT, **kwargs: Callable):
+ """
+ Subscribe to channels. Channels supplied as keyword arguments expect
+ a channel name as the key and a callable as the value. A channel's
+ callable will be invoked automatically when a message is received on
+ that channel rather than producing a message via ``listen()`` or
+ ``get_message()``.
+ """
+ parsed_args = list_or_args((args[0],), args[1:]) if args else ()
+ new_channels = dict.fromkeys(parsed_args)
+ # Mypy bug: https://github.com/python/mypy/issues/10970
+ new_channels.update(kwargs) # type: ignore[arg-type]
+ ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys())
+ # update the channels dict AFTER we send the command. we don't want to
+ # subscribe twice to these channels, once for the command and again
+ # for the reconnection.
+ new_channels = self._normalize_keys(new_channels)
+ self.channels.update(new_channels)
+ self.pending_unsubscribe_channels.difference_update(new_channels)
+ return ret_val
+
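+ # A sketch of subscribing with and without message handlers (assumes
+ # client ``r``; channel and pattern names are illustrative):
+ #
+ #     async def on_news(message):
+ #         print(message["channel"], message["data"])
+ #
+ #     pubsub = r.pubsub()
+ #     await pubsub.subscribe("alerts", news=on_news)  # handler for "news"
+ #     await pubsub.psubscribe("metrics.*")  # messages surface via get_message()
+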
+ def unsubscribe(self, *args) -> Awaitable:
+ """
+ Unsubscribe from the supplied channels. If empty, unsubscribe from
+ all channels.
+ """
+ if args:
+ parsed_args = list_or_args(args[0], args[1:])
+ channels = self._normalize_keys(dict.fromkeys(parsed_args))
+ else:
+ parsed_args = []
+ channels = self.channels
+ self.pending_unsubscribe_channels.update(channels)
+ return self.execute_command("UNSUBSCRIBE", *parsed_args)
+
+ async def listen(self) -> AsyncIterator:
+ """Listen for messages on channels this client has been subscribed to"""
+ while self.subscribed:
+ response = await self.handle_message(await self.parse_response(block=True))
+ if response is not None:
+ yield response
+
+ async def get_message(
+ self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
+ ):
+ """
+ Get the next message if one is available, otherwise None.
+
+ If timeout is specified, the system will wait for `timeout` seconds
+ before returning. Timeout should be specified as a floating point
+ number.
+ """
+ response = await self.parse_response(block=False, timeout=timeout)
+ if response:
+ return await self.handle_message(response, ignore_subscribe_messages)
+ return None
+
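+ # A sketch of a polling loop built on ``get_message`` (assumes an
+ # already-subscribed ``pubsub`` object):
+ #
+ #     while True:
+ #         message = await pubsub.get_message(
+ #             ignore_subscribe_messages=True, timeout=1.0
+ #         )
+ #         if message is not None:
+ #             print(message["data"])
+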
+ def ping(self, message=None) -> Awaitable:
+ """
+ Ping the Redis server
+ """
+ message = "" if message is None else message
+ return self.execute_command("PING", message)
+
+ async def handle_message(self, response, ignore_subscribe_messages=False):
+ """
+ Parses a pub/sub message. If the channel or pattern was subscribed to
+ with a message handler, the handler is invoked instead of a parsed
+ message being returned.
+ """
+ message_type = str_if_bytes(response[0])
+ if message_type == "pmessage":
+ message = {
+ "type": message_type,
+ "pattern": response[1],
+ "channel": response[2],
+ "data": response[3],
+ }
+ elif message_type == "pong":
+ message = {
+ "type": message_type,
+ "pattern": None,
+ "channel": None,
+ "data": response[1],
+ }
+ else:
+ message = {
+ "type": message_type,
+ "pattern": None,
+ "channel": response[1],
+ "data": response[2],
+ }
+
+ # if this is an unsubscribe message, remove it from memory
+ if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
+ if message_type == "punsubscribe":
+ pattern = response[1]
+ if pattern in self.pending_unsubscribe_patterns:
+ self.pending_unsubscribe_patterns.remove(pattern)
+ self.patterns.pop(pattern, None)
+ else:
+ channel = response[1]
+ if channel in self.pending_unsubscribe_channels:
+ self.pending_unsubscribe_channels.remove(channel)
+ self.channels.pop(channel, None)
+
+ if message_type in self.PUBLISH_MESSAGE_TYPES:
+ # if there's a message handler, invoke it
+ if message_type == "pmessage":
+ handler = self.patterns.get(message["pattern"], None)
+ else:
+ handler = self.channels.get(message["channel"], None)
+ if handler:
+ if inspect.iscoroutinefunction(handler):
+ await handler(message)
+ else:
+ handler(message)
+ return None
+ elif message_type != "pong":
+ # this is a subscribe/unsubscribe message. ignore if we don't
+ # want them
+ if ignore_subscribe_messages or self.ignore_subscribe_messages:
+ return None
+
+ return message
+
+ async def run(
+ self,
+ *,
+ exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
+ poll_timeout: float = 1.0,
+ ) -> None:
+ """Process pub/sub messages using registered callbacks.
+
+ This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
+ redis-py, but it is a coroutine. To launch it as a separate task, use
+ ``asyncio.create_task``:
+
+ >>> task = asyncio.create_task(pubsub.run())
+
+ To shut it down, use asyncio cancellation:
+
+ >>> task.cancel()
+ >>> await task
+ """
+ for channel, handler in self.channels.items():
+ if handler is None:
+ raise PubSubError(f"Channel: '{channel}' has no handler registered")
+ for pattern, handler in self.patterns.items():
+ if handler is None:
+ raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
+
+ while True:
+ try:
+ await self.get_message(
+ ignore_subscribe_messages=True, timeout=poll_timeout
+ )
+ except asyncio.CancelledError:
+ raise
+ except BaseException as e:
+ if exception_handler is None:
+ raise
+ res = exception_handler(e, self)
+ if inspect.isawaitable(res):
+ await res
+ # Ensure that other tasks on the event loop get a chance to run
+ # if we didn't have to block for I/O anywhere.
+ await asyncio.sleep(0)
+
+
+class PubsubWorkerExceptionHandler(Protocol):
+ def __call__(self, e: BaseException, pubsub: PubSub):
+ ...
+
+
+class AsyncPubsubWorkerExceptionHandler(Protocol):
+ async def __call__(self, e: BaseException, pubsub: PubSub):
+ ...
+
+
+PSWorkerThreadExcHandlerT = Union[
+ PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
+]
+
+
+CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
+CommandStackT = List[CommandT]
+
+
+class Pipeline(Redis): # lgtm [py/init-calls-subclass]
+ """
+ Pipelines provide a way to transmit multiple commands to the Redis server
+ in one transmission. This is convenient for batch processing, such as
+ saving all the values in a list to Redis.
+
+ All commands executed within a pipeline are wrapped with MULTI and EXEC
+ calls. This guarantees all commands executed in the pipeline will be
+ executed atomically.
+
+ Any command raising an exception does *not* halt the execution of
+ subsequent commands in the pipeline. Instead, the exception is caught
+ and its instance is placed into the response list returned by execute().
+ Code iterating over the response list should be able to deal with an
+ instance of an exception as a potential value. In general, these will be
+ ResponseError exceptions, such as those raised when issuing a command
+ on a key of a different datatype.
+ """
+
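+ # A sketch of how command errors surface from ``execute`` (assumes
+ # client ``r``; "foo" holds a string, so the LPUSH fails):
+ #
+ #     pipe = r.pipeline()
+ #     pipe.set("foo", "bar").lpush("foo", "baz")
+ #     results = await pipe.execute(raise_on_error=False)
+ #     # results[1] is a ResponseError instance rather than a value
+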
+ UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"}
+
+ def __init__(
+ self,
+ connection_pool: ConnectionPool,
+ response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT],
+ transaction: bool,
+ shard_hint: Optional[str],
+ ):
+ self.connection_pool = connection_pool
+ self.connection = None
+ self.response_callbacks = response_callbacks
+ self.is_transaction = transaction
+ self.shard_hint = shard_hint
+ self.watching = False
+ self.command_stack: CommandStackT = []
+ self.scripts: Set["Script"] = set()
+ self.explicit_transaction = False
+
+ async def __aenter__(self: _RedisT) -> _RedisT:
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.reset()
+
+ def __await__(self):
+ return self._async_self().__await__()
+
+ _DEL_MESSAGE = "Unclosed Pipeline client"
+
+ def __len__(self):
+ return len(self.command_stack)
+
+ def __bool__(self):
+ """Pipeline instances should always evaluate to True"""
+ return True
+
+ async def _async_self(self):
+ return self
+
+ async def reset(self):
+ self.command_stack = []
+ self.scripts = set()
+ # make sure to reset the connection state in the event that we were
+ # watching something
+ if self.watching and self.connection:
+ try:
+ # call this manually since our unwatch or
+ # immediate_execute_command methods can call reset()
+ await self.connection.send_command("UNWATCH")
+ await self.connection.read_response()
+ except ConnectionError:
+ # disconnect will also remove any previous WATCHes
+ if self.connection:
+ await self.connection.disconnect()
+ # clean up the other instance attributes
+ self.watching = False
+ self.explicit_transaction = False
+ # we can safely return the connection to the pool here since we're
+ # sure we're no longer WATCHing anything
+ if self.connection:
+ await self.connection_pool.release(self.connection)
+ self.connection = None
+
+ def multi(self):
+ """
+ Start a transactional block of the pipeline after WATCH commands
+ are issued. End the transactional block with `execute`.
+ """
+ if self.explicit_transaction:
+ raise RedisError("Cannot issue nested calls to MULTI")
+ if self.command_stack:
+ raise RedisError(
+ "Commands without an initial WATCH have already " "been issued"
+ )
+ self.explicit_transaction = True
+
+ def execute_command(
+ self, *args, **kwargs
+ ) -> Union["Pipeline", Awaitable["Pipeline"]]:
+ if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
+ return self.immediate_execute_command(*args, **kwargs)
+ return self.pipeline_execute_command(*args, **kwargs)
+
+ async def _disconnect_reset_raise(self, conn, error):
+ """
+ Close the connection, reset watching state and
+ raise an exception if we were watching,
+ retry_on_timeout is not set,
+ or the error is not a TimeoutError
+ """
+ await conn.disconnect()
+ # if we were already watching a variable, the watch is no longer
+ # valid since this connection has died. raise a WatchError, which
+ # indicates the user should retry this transaction.
+ if self.watching:
+ await self.reset()
+ raise WatchError(
+ "A ConnectionError occurred while watching one or more keys"
+ )
+ # if retry_on_timeout is not set, or the error is not
+ # a TimeoutError, raise it
+ if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+ await self.reset()
+ raise
+
+ async def immediate_execute_command(self, *args, **options):
+ """
+ Execute a command immediately, but don't auto-retry on a
+ ConnectionError if we're already WATCHing a variable. Used when
+ issuing WATCH or subsequent commands that retrieve their values,
+ before MULTI is called.
+ """
+ command_name = args[0]
+ conn = self.connection
+ # if this is the first call, we need a connection
+ if not conn:
+ conn = await self.connection_pool.get_connection(
+ command_name, self.shard_hint
+ )
+ self.connection = conn
+
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, command_name, *args, **options
+ ),
+ lambda error: self._disconnect_reset_raise(conn, error),
+ )
+
+ def pipeline_execute_command(self, *args, **options):
+ """
+ Stage a command to be executed when execute() is next called
+
+ Returns the current Pipeline object back so commands can be
+ chained together, such as:
+
+ pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
+
+ At some other point, you can then run: pipe.execute(),
+ which will execute all commands queued in the pipe.
+ """
+ self.command_stack.append((args, options))
+ return self
+
+ async def _execute_transaction( # noqa: C901
+ self, connection: Connection, commands: CommandStackT, raise_on_error
+ ):
+ pre: CommandT = (("MULTI",), {})
+ post: CommandT = (("EXEC",), {})
+ cmds = (pre, *commands, post)
+ all_cmds = connection.pack_commands(
+ args for args, options in cmds if EMPTY_RESPONSE not in options
+ )
+ await connection.send_packed_command(all_cmds)
+ errors = []
+
+ # parse off the response for MULTI
+ # NOTE: we need to handle ResponseErrors here and continue
+ # so that we read all the additional command messages from
+ # the socket
+ try:
+ await self.parse_response(connection, "_")
+ except ResponseError as err:
+ errors.append((0, err))
+
+ # and all the other commands
+ for i, command in enumerate(commands):
+ if EMPTY_RESPONSE in command[1]:
+ errors.append((i, command[1][EMPTY_RESPONSE]))
+ else:
+ try:
+ await self.parse_response(connection, "_")
+ except ResponseError as err:
+ self.annotate_exception(err, i + 1, command[0])
+ errors.append((i, err))
+
+ # parse the EXEC.
+ try:
+ response = await self.parse_response(connection, "_")
+ except ExecAbortError as err:
+ if errors:
+ raise errors[0][1] from err
+ raise
+
+ # EXEC clears any watched keys
+ self.watching = False
+
+ if response is None:
+ raise WatchError("Watched variable changed.") from None
+
+ # put any parse errors into the response
+ for i, e in errors:
+ response.insert(i, e)
+
+ if len(response) != len(commands):
+ if self.connection:
+ await self.connection.disconnect()
+ raise ResponseError(
+ "Wrong number of response items from pipeline execution"
+ ) from None
+
+ # find any errors in the response and raise if necessary
+ if raise_on_error:
+ self.raise_first_error(commands, response)
+
+ # We have to run response callbacks manually
+ data = []
+ for r, cmd in zip(response, commands):
+ if not isinstance(r, Exception):
+ args, options = cmd
+ command_name = args[0]
+ if command_name in self.response_callbacks:
+ r = self.response_callbacks[command_name](r, **options)
+ if inspect.isawaitable(r):
+ r = await r
+ data.append(r)
+ return data
+
+ async def _execute_pipeline(
+ self, connection: Connection, commands: CommandStackT, raise_on_error: bool
+ ):
+ # build up all commands into a single request to increase network perf
+ all_cmds = connection.pack_commands([args for args, _ in commands])
+ await connection.send_packed_command(all_cmds)
+
+ response = []
+ for args, options in commands:
+ try:
+ response.append(
+ await self.parse_response(connection, args[0], **options)
+ )
+ except ResponseError as e:
+ response.append(e)
+
+ if raise_on_error:
+ self.raise_first_error(commands, response)
+ return response
+
+ def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]):
+ for i, r in enumerate(response):
+ if isinstance(r, ResponseError):
+ self.annotate_exception(r, i + 1, commands[i][0])
+ raise r
+
+ def annotate_exception(
+ self, exception: Exception, number: int, command: Iterable[object]
+ ) -> None:
+ cmd = " ".join(map(safe_str, command))
+ msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}"
+ exception.args = (msg,) + exception.args[1:]
+
+ async def parse_response(
+ self, connection: Connection, command_name: Union[str, bytes], **options
+ ):
+ result = await super().parse_response(connection, command_name, **options)
+ if command_name in self.UNWATCH_COMMANDS:
+ self.watching = False
+ elif command_name == "WATCH":
+ self.watching = True
+ return result
+
+ async def load_scripts(self):
+ # make sure all scripts that are about to be run on this pipeline exist
+ scripts = list(self.scripts)
+ immediate = self.immediate_execute_command
+ shas = [s.sha for s in scripts]
+ # we can't use the normal script_* methods because they would just
+ # get buffered in the pipeline.
+ exists = await immediate("SCRIPT EXISTS", *shas)
+ if not all(exists):
+ for s, exist in zip(scripts, exists):
+ if not exist:
+ s.sha = await immediate("SCRIPT LOAD", s.script)
+
+ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
+ """
+ Close the connection, raise a WatchError if we were watching,
+ and re-raise the error if retry_on_timeout is not set
+ or the error is not a TimeoutError
+ """
+ await conn.disconnect()
+ # if we were watching a variable, the watch is no longer valid
+ # since this connection has died. raise a WatchError, which
+ # indicates the user should retry this transaction.
+ if self.watching:
+ raise WatchError(
+ "A ConnectionError occurred while watching one or more keys"
+ )
+ # if retry_on_timeout is not set, or the error is not
+ # a TimeoutError, raise it
+ if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
+ await self.reset()
+ raise
+
+ async def execute(self, raise_on_error: bool = True):
+ """Execute all the commands in the current pipeline"""
+ stack = self.command_stack
+ if not stack and not self.watching:
+ return []
+ if self.scripts:
+ await self.load_scripts()
+ if self.is_transaction or self.explicit_transaction:
+ execute = self._execute_transaction
+ else:
+ execute = self._execute_pipeline
+
+ conn = self.connection
+ if not conn:
+ conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
+ # assign to self.connection so reset() releases the connection
+ # back to the pool after we're done
+ self.connection = conn
+ conn = cast(Connection, conn)
+
+ try:
+ return await conn.retry.call_with_retry(
+ lambda: execute(conn, stack, raise_on_error),
+ lambda error: self._disconnect_raise_reset(conn, error),
+ )
+ finally:
+ await self.reset()
+
+ async def discard(self):
+ """Flushes all previously queued commands
+ See: https://redis.io/commands/DISCARD
+ """
+ await self.execute_command("DISCARD")
+
+ async def watch(self, *names: KeyT):
+ """Watches the values at keys ``names``"""
+ if self.explicit_transaction:
+ raise RedisError("Cannot issue a WATCH after a MULTI")
+ return await self.execute_command("WATCH", *names)
+
+ async def unwatch(self):
+ """Unwatches all previously specified keys"""
+ return self.watching and await self.execute_command("UNWATCH") or True
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
new file mode 100644
index 0000000..ae54d81
--- /dev/null
+++ b/redis/asyncio/connection.py
@@ -0,0 +1,1707 @@
+import asyncio
+import copy
+import enum
+import errno
+import inspect
+import io
+import os
+import socket
+import ssl
+import sys
+import threading
+import weakref
+from itertools import chain
+from types import MappingProxyType
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+from urllib.parse import ParseResult, parse_qs, unquote, urlparse
+
+import async_timeout
+
+from redis.asyncio.retry import Retry
+from redis.backoff import NoBackoff
+from redis.compat import Protocol, TypedDict
+from redis.exceptions import (
+ AuthenticationError,
+ AuthenticationWrongNumberOfArgsError,
+ BusyLoadingError,
+ ChildDeadlockedError,
+ ConnectionError,
+ DataError,
+ ExecAbortError,
+ InvalidResponse,
+ ModuleError,
+ NoPermissionError,
+ NoScriptError,
+ ReadOnlyError,
+ RedisError,
+ ResponseError,
+ TimeoutError,
+)
+from redis.typing import EncodableT, EncodedT
+from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
+
+hiredis = None
+if HIREDIS_AVAILABLE:
+ import hiredis
+
+NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
+ BlockingIOError: errno.EWOULDBLOCK,
+ ssl.SSLWantReadError: 2,
+ ssl.SSLWantWriteError: 2,
+ ssl.SSLError: 2,
+}
+
+NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
+
+
+SYM_STAR = b"*"
+SYM_DOLLAR = b"$"
+SYM_CRLF = b"\r\n"
+SYM_LF = b"\n"
+SYM_EMPTY = b""
+
+SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
+
+
+class _Sentinel(enum.Enum):
+ sentinel = object()
+
+
+SENTINEL = _Sentinel.sentinel
+MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
+NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
+MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
+MODULE_EXPORTS_DATA_TYPES_ERROR = (
+ "Error unloading module: the module "
+ "exports one or more module-side data "
+ "types, can't unload"
+)
+
+
+class _HiredisReaderArgs(TypedDict, total=False):
+ protocolError: Callable[[str], Exception]
+ replyError: Callable[[str], Exception]
+ encoding: Optional[str]
+ errors: Optional[str]
+
+
+class Encoder:
+ """Encode strings to bytes-like and decode bytes-like to strings"""
+
+ __slots__ = "encoding", "encoding_errors", "decode_responses"
+
+ def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool):
+ self.encoding = encoding
+ self.encoding_errors = encoding_errors
+ self.decode_responses = decode_responses
+
+ def encode(self, value: EncodableT) -> EncodedT:
+ """Return a bytestring or bytes-like representation of the value"""
+ if isinstance(value, (bytes, memoryview)):
+ return value
+ if isinstance(value, bool):
+ # special case bool since it is a subclass of int
+ raise DataError(
+ "Invalid input of type: 'bool'. "
+ "Convert to a bytes, string, int or float first."
+ )
+ if isinstance(value, (int, float)):
+ return repr(value).encode()
+ if not isinstance(value, str):
+ # a value we don't know how to deal with. throw an error
+ typename = value.__class__.__name__ # type: ignore[unreachable]
+ raise DataError(
+ f"Invalid input of type: {typename!r}. "
+ "Convert to a bytes, string, int or float first."
+ )
+ return value.encode(self.encoding, self.encoding_errors)
+
+ def decode(self, value: EncodableT, force=False) -> EncodableT:
+ """Return a unicode string from the bytes-like representation"""
+ if self.decode_responses or force:
+ if isinstance(value, memoryview):
+ return value.tobytes().decode(self.encoding, self.encoding_errors)
+ if isinstance(value, bytes):
+ return value.decode(self.encoding, self.encoding_errors)
+ return value
+
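+# A sketch of the Encoder's rules (values are illustrative; bools are
+# rejected, numbers go through repr(), and decode() is a no-op unless
+# decode_responses or force is set):
+#
+#     enc = Encoder("utf-8", "strict", decode_responses=False)
+#     enc.encode("hi")               # -> b"hi"
+#     enc.encode(3.5)                # -> b"3.5"
+#     enc.encode(True)               # raises DataError
+#     enc.decode(b"hi")              # -> b"hi" (decoding disabled)
+#     enc.decode(b"hi", force=True)  # -> "hi"
+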
+
+ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]]
+
+
+class BaseParser:
+ """Plain Python parsing class"""
+
+ __slots__ = "_stream", "_buffer", "_read_size"
+
+ EXCEPTION_CLASSES: ExceptionMappingT = {
+ "ERR": {
+ "max number of clients reached": ConnectionError,
+ "Client sent AUTH, but no password is set": AuthenticationError,
+ "invalid password": AuthenticationError,
+ # some Redis server versions report invalid command syntax
+ # in lowercase
+ "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501
+ # some Redis server versions report invalid command syntax
+ # in uppercase
+ "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501
+ MODULE_LOAD_ERROR: ModuleError,
+ MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
+ NO_SUCH_MODULE_ERROR: ModuleError,
+ MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
+ },
+ "EXECABORT": ExecAbortError,
+ "LOADING": BusyLoadingError,
+ "NOSCRIPT": NoScriptError,
+ "READONLY": ReadOnlyError,
+ "NOAUTH": AuthenticationError,
+ "NOPERM": NoPermissionError,
+ }
+
+ def __init__(self, socket_read_size: int):
+ self._stream: Optional[asyncio.StreamReader] = None
+ self._buffer: Optional[SocketBuffer] = None
+ self._read_size = socket_read_size
+
+ def __del__(self):
+ try:
+ self.on_disconnect()
+ except Exception:
+ pass
+
+ def parse_error(self, response: str) -> ResponseError:
+ """Parse an error response"""
+ error_code = response.split(" ")[0]
+ if error_code in self.EXCEPTION_CLASSES:
+ response = response[len(error_code) + 1 :]
+ exception_class = self.EXCEPTION_CLASSES[error_code]
+ if isinstance(exception_class, dict):
+ exception_class = exception_class.get(response, ResponseError)
+ return exception_class(response)
+ return ResponseError(response)
+
+ def on_disconnect(self):
+ raise NotImplementedError()
+
+ def on_connect(self, connection: "Connection"):
+ raise NotImplementedError()
+
+ async def can_read(self, timeout: float) -> bool:
+ raise NotImplementedError()
+
+ async def read_response(
+ self, disable_decoding: bool = False
+ ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
+ raise NotImplementedError()
+
+
+class SocketBuffer:
+ """Async-friendly re-impl of redis-py's SocketBuffer.
+
+ TODO: We're currently passing through two buffers,
+ the asyncio.StreamReader and this. I imagine we can reduce the layers here
+ while maintaining compliance with prior art.
+ """
+
+ def __init__(
+ self,
+ stream_reader: asyncio.StreamReader,
+ socket_read_size: int,
+ socket_timeout: Optional[float],
+ ):
+ self._stream: Optional[asyncio.StreamReader] = stream_reader
+ self.socket_read_size = socket_read_size
+ self.socket_timeout = socket_timeout
+ self._buffer: Optional[io.BytesIO] = io.BytesIO()
+ # number of bytes written to the buffer from the socket
+ self.bytes_written = 0
+ # number of bytes read from the buffer
+ self.bytes_read = 0
+
+ @property
+ def length(self):
+ return self.bytes_written - self.bytes_read
+
+ async def _read_from_socket(
+ self,
+ length: Optional[int] = None,
+ timeout: Union[float, None, _Sentinel] = SENTINEL,
+ raise_on_timeout: bool = True,
+ ) -> bool:
+ buf = self._buffer
+ if buf is None or self._stream is None:
+ raise RedisError("Buffer is closed.")
+ buf.seek(self.bytes_written)
+ marker = 0
+ timeout = timeout if timeout is not SENTINEL else self.socket_timeout
+
+ try:
+ while True:
+ async with async_timeout.timeout(timeout):
+ data = await self._stream.read(self.socket_read_size)
+ # an empty string indicates the server shut down the socket
+ if isinstance(data, bytes) and len(data) == 0:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ buf.write(data)
+ data_length = len(data)
+ self.bytes_written += data_length
+ marker += data_length
+
+ if length is not None and length > marker:
+ continue
+ return True
+ except (socket.timeout, asyncio.TimeoutError):
+ if raise_on_timeout:
+ raise TimeoutError("Timeout reading from socket")
+ return False
+ except NONBLOCKING_EXCEPTIONS as ex:
+ # if we're in nonblocking mode and the recv raises a
+ # blocking error, simply return False indicating that
+ # there's no data to be read. otherwise raise the
+ # original exception.
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
+ if not raise_on_timeout and ex.errno == allowed:
+ return False
+ raise ConnectionError(f"Error while reading from socket: {ex.args}")
+
+ async def can_read(self, timeout: float) -> bool:
+ return bool(self.length) or await self._read_from_socket(
+ timeout=timeout, raise_on_timeout=False
+ )
+
+ async def read(self, length: int) -> bytes:
+ length = length + 2 # make sure to read the \r\n terminator
+ # make sure we've read enough data from the socket
+ if length > self.length:
+ await self._read_from_socket(length - self.length)
+
+ if self._buffer is None:
+ raise RedisError("Buffer is closed.")
+
+ self._buffer.seek(self.bytes_read)
+ data = self._buffer.read(length)
+ self.bytes_read += len(data)
+
+ # purge the buffer when we've consumed it all so it doesn't
+ # grow forever
+ if self.bytes_read == self.bytes_written:
+ self.purge()
+
+ return data[:-2]
+
+ async def readline(self) -> bytes:
+ buf = self._buffer
+ if buf is None:
+ raise RedisError("Buffer is closed.")
+
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+ while not data.endswith(SYM_CRLF):
+ # there's more data in the socket that we need
+ await self._read_from_socket()
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+
+ self.bytes_read += len(data)
+
+ # purge the buffer when we've consumed it all so it doesn't
+ # grow forever
+ if self.bytes_read == self.bytes_written:
+ self.purge()
+
+ return data[:-2]
+
+ def purge(self):
+ if self._buffer is None:
+ raise RedisError("Buffer is closed.")
+
+ self._buffer.seek(0)
+ self._buffer.truncate()
+ self.bytes_written = 0
+ self.bytes_read = 0
+
+ def close(self):
+ try:
+ self.purge()
+ self._buffer.close() # type: ignore[union-attr]
+ except Exception:
+ # issue #633 suggests the purge/close somehow raised a
+ # BadFileDescriptor error. Perhaps the client ran out of
+ # memory or something else? It's probably OK to ignore
+ # any error being raised from purge/close since we're
+ # removing the reference to the instance below.
+ pass
+ self._buffer = None
+ self._stream = None
+
+
+class PythonParser(BaseParser):
+ """Plain Python parsing class"""
+
+ __slots__ = BaseParser.__slots__ + ("encoder",)
+
+ def __init__(self, socket_read_size: int):
+ super().__init__(socket_read_size)
+ self.encoder: Optional[Encoder] = None
+
+ def on_connect(self, connection: "Connection"):
+ """Called when the stream connects"""
+ self._stream = connection._reader
+ if self._stream is None:
+ raise RedisError("Buffer is closed.")
+
+ self._buffer = SocketBuffer(
+ self._stream, self._read_size, connection.socket_timeout
+ )
+ self.encoder = connection.encoder
+
+ def on_disconnect(self):
+ """Called when the stream disconnects"""
+ if self._stream is not None:
+ self._stream = None
+ if self._buffer is not None:
+ self._buffer.close()
+ self._buffer = None
+ self.encoder = None
+
+ async def can_read(self, timeout: float):
+ return self._buffer and bool(await self._buffer.can_read(timeout))
+
+ async def read_response(
+ self, disable_decoding: bool = False
+ ) -> Union[EncodableT, ResponseError, None]:
+ if not self._buffer or not self.encoder:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ raw = await self._buffer.readline()
+ if not raw:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ response: Any
+ byte, response = raw[:1], raw[1:]
+
+ if byte not in (b"-", b"+", b":", b"$", b"*"):
+ raise InvalidResponse(f"Protocol Error: {raw!r}")
+
+ # server returned an error
+ if byte == b"-":
+ response = response.decode("utf-8", errors="replace")
+ error = self.parse_error(response)
+ # if the error is a ConnectionError, raise immediately so the user
+ # is notified
+ if isinstance(error, ConnectionError):
+ raise error
+ # otherwise, we're dealing with a ResponseError that might belong
+ # inside a pipeline response. the connection's read_response()
+ # and/or the pipeline's execute() will raise this error if
+ # necessary, so just return the exception instance here.
+ return error
+ # single value
+ elif byte == b"+":
+ pass
+ # int value
+ elif byte == b":":
+ response = int(response)
+ # bulk response
+ elif byte == b"$":
+ length = int(response)
+ if length == -1:
+ return None
+ response = await self._buffer.read(length)
+ # multi-bulk response
+ elif byte == b"*":
+ length = int(response)
+ if length == -1:
+ return None
+ response = [
+ (await self.read_response(disable_decoding)) for _ in range(length)
+ ]
+ if isinstance(response, bytes) and disable_decoding is False:
+ response = self.encoder.decode(response)
+ return response
+
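For readers unfamiliar with the wire protocol the parser above implements, the sketch below (illustrative only, not part of the diff) summarizes how each RESP type byte maps to the value `read_response()` produces:

    # RESP reply framing as handled by PythonParser.read_response() above.
    # Each entry pairs a raw reply with the value the parser yields.
    examples = {
        b"+OK\r\n": "simple string -> 'OK'",
        b"-ERR unknown command\r\n": "error -> ResponseError (returned or raised)",
        b":42\r\n": "integer -> 42",
        b"$3\r\nfoo\r\n": "bulk string -> b'foo' (length -1 -> None)",
        b"*2\r\n:1\r\n:2\r\n": "array -> [1, 2] (length -1 -> None)",
    }
    for raw, parsed in examples.items():
        print(raw[:1], parsed)
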
+
+class HiredisParser(BaseParser):
+ """Parser class for connections using Hiredis"""
+
+ __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")
+
+ _next_response: bool
+
+ def __init__(self, socket_read_size: int):
+ if not HIREDIS_AVAILABLE:
+ raise RedisError("Hiredis is not available.")
+ super().__init__(socket_read_size=socket_read_size)
+ self._reader: Optional[hiredis.Reader] = None
+ self._socket_timeout: Optional[float] = None
+
+ def on_connect(self, connection: "Connection"):
+ self._stream = connection._reader
+ kwargs: _HiredisReaderArgs = {
+ "protocolError": InvalidResponse,
+ "replyError": self.parse_error,
+ }
+ if connection.encoder.decode_responses:
+ kwargs["encoding"] = connection.encoder.encoding
+ kwargs["errors"] = connection.encoder.encoding_errors
+
+ self._reader = hiredis.Reader(**kwargs)
+ self._next_response = False
+ self._socket_timeout = connection.socket_timeout
+
+ def on_disconnect(self):
+ self._stream = None
+ self._reader = None
+ self._next_response = False
+
+ async def can_read(self, timeout: float):
+ if not self._reader:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+
+ if self._next_response is False:
+ self._next_response = self._reader.gets()
+ if self._next_response is False:
+ return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
+ return True
+
+ async def read_from_socket(
+ self,
+ timeout: Union[float, None, _Sentinel] = SENTINEL,
+ raise_on_timeout: bool = True,
+ ):
+ if self._stream is None or self._reader is None:
+ raise RedisError("Parser already closed.")
+
+ timeout = self._socket_timeout if timeout is SENTINEL else timeout
+ try:
+ async with async_timeout.timeout(timeout):
+ buffer = await self._stream.read(self._read_size)
+ if not isinstance(buffer, bytes) or len(buffer) == 0:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
+ self._reader.feed(buffer)
+ # data was read from the socket and added to the buffer.
+ # return True to indicate that data was read.
+ return True
+ except asyncio.CancelledError:
+ raise
+ except (socket.timeout, asyncio.TimeoutError):
+ if raise_on_timeout:
+ raise TimeoutError("Timeout reading from socket") from None
+ return False
+ except NONBLOCKING_EXCEPTIONS as ex:
+ # if we're in nonblocking mode and the recv raises a
+ # blocking error, simply return False indicating that
+ # there's no data to be read. otherwise raise the
+ # original exception.
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
+ if not raise_on_timeout and ex.errno == allowed:
+ return False
+ raise ConnectionError(f"Error while reading from socket: {ex.args}")
+
+ async def read_response(
+ self, disable_decoding: bool = False
+ ) -> Union[EncodableT, List[EncodableT]]:
+ if not self._stream or not self._reader:
+ self.on_disconnect()
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
+
+ response: Union[
+ EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]
+ ]
+ # _next_response might be cached from a can_read() call
+ if self._next_response is not False:
+ response = self._next_response
+ self._next_response = False
+ return response
+
+ response = self._reader.gets()
+ while response is False:
+ await self.read_from_socket()
+ response = self._reader.gets()
+
+ # if the response is a ConnectionError or the response is a list and
+ # the first item is a ConnectionError, raise it as something bad
+ # happened
+ if isinstance(response, ConnectionError):
+ raise response
+ elif (
+ isinstance(response, list)
+ and response
+ and isinstance(response[0], ConnectionError)
+ ):
+ raise response[0]
+ # cast as there won't be a ConnectionError here.
+ return cast(Union[EncodableT, List[EncodableT]], response)
+
+
+DefaultParser: Type[Union[PythonParser, HiredisParser]]
+if HIREDIS_AVAILABLE:
+ DefaultParser = HiredisParser
+else:
+ DefaultParser = PythonParser
+
+
+class ConnectCallbackProtocol(Protocol):
+ def __call__(self, connection: "Connection"):
+ ...
+
+
+class AsyncConnectCallbackProtocol(Protocol):
+ async def __call__(self, connection: "Connection"):
+ ...
+
+
+ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
+
+
+class Connection:
+ """Manages TCP communication to and from a Redis server"""
+
+ __slots__ = (
+ "pid",
+ "host",
+ "port",
+ "db",
+ "username",
+ "client_name",
+ "password",
+ "socket_timeout",
+ "socket_connect_timeout",
+ "socket_keepalive",
+ "socket_keepalive_options",
+ "socket_type",
+ "redis_connect_func",
+ "retry_on_timeout",
+ "health_check_interval",
+ "next_health_check",
+ "last_active_at",
+ "encoder",
+ "ssl_context",
+ "_reader",
+ "_writer",
+ "_parser",
+ "_connect_callbacks",
+ "_buffer_cutoff",
+ "_lock",
+ "_socket_read_size",
+ "__dict__",
+ )
+
+ def __init__(
+ self,
+ *,
+ host: str = "localhost",
+ port: Union[str, int] = 6379,
+ db: Union[str, int] = 0,
+ password: Optional[str] = None,
+ socket_timeout: Optional[float] = None,
+ socket_connect_timeout: Optional[float] = None,
+ socket_keepalive: bool = False,
+ socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
+ socket_type: int = 0,
+ retry_on_timeout: bool = False,
+ encoding: str = "utf-8",
+ encoding_errors: str = "strict",
+ decode_responses: bool = False,
+ parser_class: Type[BaseParser] = DefaultParser,
+ socket_read_size: int = 65536,
+ health_check_interval: float = 0,
+ client_name: Optional[str] = None,
+ username: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ redis_connect_func: Optional[ConnectCallbackT] = None,
+ encoder_class: Type[Encoder] = Encoder,
+ ):
+ self.pid = os.getpid()
+ self.host = host
+ self.port = int(port)
+ self.db = db
+ self.username = username
+ self.client_name = client_name
+ self.password = password
+ self.socket_timeout = socket_timeout
+ self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
+ self.socket_keepalive = socket_keepalive
+ self.socket_keepalive_options = socket_keepalive_options or {}
+ self.socket_type = socket_type
+ self.retry_on_timeout = retry_on_timeout
+ if retry_on_timeout:
+ if retry is None:
+ self.retry = Retry(NoBackoff(), 1)
+ else:
+ # deep-copy the Retry object as it is mutable
+ self.retry = copy.deepcopy(retry)
+ else:
+ self.retry = Retry(NoBackoff(), 0)
+ self.health_check_interval = health_check_interval
+ self.next_health_check: float = -1
+ self.ssl_context: Optional[RedisSSLContext] = None
+ self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
+ self.redis_connect_func = redis_connect_func
+ self._reader: Optional[asyncio.StreamReader] = None
+ self._writer: Optional[asyncio.StreamWriter] = None
+ self._socket_read_size = socket_read_size
+ self.set_parser(parser_class)
+ self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
+ self._buffer_cutoff = 6000
+ self._lock = asyncio.Lock()
+
+ def __repr__(self):
+ repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
+ return f"{self.__class__.__name__}<{repr_args}>"
+
+ def repr_pieces(self):
+ pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
+ if self.client_name:
+ pieces.append(("client_name", self.client_name))
+ return pieces
+
+ def __del__(self):
+ try:
+ if self.is_connected:
+ loop = asyncio.get_event_loop()
+ coro = self.disconnect()
+ if loop.is_running():
+ loop.create_task(coro)
+ else:
+ loop.run_until_complete(coro)
+ except Exception:
+ pass
+
+ @property
+ def is_connected(self):
+ return bool(self._reader and self._writer)
+
+ def register_connect_callback(self, callback):
+ self._connect_callbacks.append(weakref.WeakMethod(callback))
+
+ def clear_connect_callbacks(self):
+ self._connect_callbacks = []
+
+ def set_parser(self, parser_class):
+ """
+ Creates a new instance of parser_class with socket size:
+ _socket_read_size and assigns it to the parser for the connection
+ :param parser_class: The required parser class
+ """
+ self._parser = parser_class(socket_read_size=self._socket_read_size)
+
+ async def connect(self):
+ """Connects to the Redis server if not already connected"""
+ if self.is_connected:
+ return
+ try:
+ await self._connect()
+ except asyncio.CancelledError:
+ raise
+ except (socket.timeout, asyncio.TimeoutError):
+ raise TimeoutError("Timeout connecting to server")
+ except OSError as e:
+ raise ConnectionError(self._error_message(e))
+ except Exception as exc:
+ raise ConnectionError(exc) from exc
+
+ try:
+ if self.redis_connect_func is None:
+ # Use the default on_connect function
+ await self.on_connect()
+ else:
+ # Use the passed function redis_connect_func
+ await self.redis_connect_func(self) if asyncio.iscoroutinefunction(
+ self.redis_connect_func
+ ) else self.redis_connect_func(self)
+ except RedisError:
+ # clean up after any error in on_connect
+ await self.disconnect()
+ raise
+
+ # run any user callbacks. right now the only internal callback
+ # is for pubsub channel/pattern resubscription
+ for ref in self._connect_callbacks:
+ callback = ref()
+ # the weak reference may be dead; skip callbacks whose owner
+ # has been garbage collected
+ if callback is None:
+ continue
+ task = callback(self)
+ if task and inspect.isawaitable(task):
+ await task
+
+ async def _connect(self):
+ """Create a TCP socket connection"""
+ async with async_timeout.timeout(self.socket_connect_timeout):
+ reader, writer = await asyncio.open_connection(
+ host=self.host,
+ port=self.port,
+ ssl=self.ssl_context.get() if self.ssl_context else None,
+ )
+ self._reader = reader
+ self._writer = writer
+ sock = writer.transport.get_extra_info("socket")
+ if sock is not None:
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ try:
+ # TCP_KEEPALIVE
+ if self.socket_keepalive:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ for k, v in self.socket_keepalive_options.items():
+ sock.setsockopt(socket.SOL_TCP, k, v)
+
+ except (OSError, TypeError):
+ # `socket_keepalive_options` might contain invalid options
+ # causing an error. Do not leave the connection open.
+ writer.close()
+ raise
+
+ def _error_message(self, exception):
+ # args for socket.error can either be (errno, "message")
+ # or just "message"
+ if len(exception.args) == 1:
+ return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}."
+ else:
+ return (
+ f"Error {exception.args[0]} connecting to {self.host}:{self.port}. "
+ f"{exception.args[0]}."
+ )
+
+ async def on_connect(self):
+ """Initialize the connection, authenticate and select a database"""
+ self._parser.on_connect(self)
+
+ # if username and/or password are set, authenticate
+ if self.username or self.password:
+ auth_args: Union[Tuple[str], Tuple[str, str]]
+ if self.username:
+ auth_args = (self.username, self.password or "")
+ else:
+ # Mypy bug: https://github.com/python/mypy/issues/10944
+ auth_args = (self.password or "",)
+ # avoid checking health here -- PING will fail if we try
+ # to check the health prior to the AUTH
+ await self.send_command("AUTH", *auth_args, check_health=False)
+
+ try:
+ auth_response = await self.read_response()
+ except AuthenticationWrongNumberOfArgsError:
+ # a username and password were specified but the Redis
+ # server seems to be < 6.0.0 which expects a single password
+ # arg. retry auth with just the password.
+ # https://github.com/andymccurdy/redis-py/issues/1274
+ await self.send_command("AUTH", self.password, check_health=False)
+ auth_response = await self.read_response()
+
+ if str_if_bytes(auth_response) != "OK":
+ raise AuthenticationError("Invalid Username or Password")
+
+ # if a client_name is given, set it
+ if self.client_name:
+ await self.send_command("CLIENT", "SETNAME", self.client_name)
+ if str_if_bytes(await self.read_response()) != "OK":
+ raise ConnectionError("Error setting client name")
+
+ # if a database is specified, switch to it
+ if self.db:
+ await self.send_command("SELECT", self.db)
+ if str_if_bytes(await self.read_response()) != "OK":
+ raise ConnectionError("Invalid Database")
+
+ async def disconnect(self):
+ """Disconnects from the Redis server"""
+ try:
+ async with async_timeout.timeout(self.socket_connect_timeout):
+ self._parser.on_disconnect()
+ if not self.is_connected:
+ return
+ try:
+ if os.getpid() == self.pid:
+ self._writer.close() # type: ignore[union-attr]
+ # py3.6 doesn't have this method
+ if hasattr(self._writer, "wait_closed"):
+ await self._writer.wait_closed() # type: ignore[union-attr]
+ except OSError:
+ pass
+ self._reader = None
+ self._writer = None
+ except asyncio.TimeoutError:
+ raise TimeoutError(
+ f"Timed out closing connection after {self.socket_connect_timeout}"
+ ) from None
+
+ async def _send_ping(self):
+ """Send PING, expect PONG in return"""
+ await self.send_command("PING", check_health=False)
+ if str_if_bytes(await self.read_response()) != "PONG":
+ raise ConnectionError("Bad response from PING health check")
+
+ async def _ping_failed(self, error):
+ """Function to call when PING fails"""
+ await self.disconnect()
+
+ async def check_health(self):
+ """Check the health of the connection with a PING/PONG"""
+ if sys.version_info[0:2] == (3, 6):
+ func = asyncio.get_event_loop
+ else:
+ func = asyncio.get_running_loop
+
+ if self.health_check_interval and func().time() > self.next_health_check:
+ await self.retry.call_with_retry(self._send_ping, self._ping_failed)
+
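As a usage sketch (assuming a Redis server on localhost:6379, the default), enabling ``health_check_interval`` makes the client verify idle connections with a PING before reusing them:

    import asyncio
    from redis.asyncio import Redis

    async def main():
        # A connection idle for more than 30 seconds is PINGed (and
        # reconnected on failure) before the next command is sent.
        client = Redis(health_check_interval=30)
        print(await client.ping())  # True

    asyncio.run(main())
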
+ async def _send_packed_command(self, command: Iterable[bytes]) -> None:
+ if self._writer is None:
+ raise RedisError("Connection already closed.")
+
+ self._writer.writelines(command)
+ await self._writer.drain()
+
+ async def send_packed_command(
+ self,
+ command: Union[bytes, str, Iterable[bytes]],
+ check_health: bool = True,
+ ):
+ """Send an already packed command to the Redis server"""
+ if not self._writer:
+ await self.connect()
+ # guard against health check recursion
+ if check_health:
+ await self.check_health()
+ try:
+ if isinstance(command, str):
+ command = command.encode()
+ if isinstance(command, bytes):
+ command = [command]
+ await asyncio.wait_for(
+ self._send_packed_command(command),
+ self.socket_timeout,
+ )
+ except asyncio.TimeoutError:
+ await self.disconnect()
+ raise TimeoutError("Timeout writing to socket") from None
+ except OSError as e:
+ await self.disconnect()
+ if len(e.args) == 1:
+ err_no, errmsg = "UNKNOWN", e.args[0]
+ else:
+ err_no = e.args[0]
+ errmsg = e.args[1]
+ raise ConnectionError(
+ f"Error {err_no} while writing to socket. {errmsg}."
+ ) from e
+ except BaseException:
+ await self.disconnect()
+ raise
+
+ async def send_command(self, *args, **kwargs):
+ """Pack and send a command to the Redis server"""
+ if not self.is_connected:
+ await self.connect()
+ await self.send_packed_command(
+ self.pack_command(*args), check_health=kwargs.get("check_health", True)
+ )
+
+ async def can_read(self, timeout: float = 0):
+ """Poll the socket to see if there's data that can be read."""
+ if not self.is_connected:
+ await self.connect()
+ return await self._parser.can_read(timeout)
+
+ async def read_response(self, disable_decoding: bool = False):
+ """Read the response from a previously sent command"""
+ try:
+ async with self._lock:
+ async with async_timeout.timeout(self.socket_timeout):
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding
+ )
+ except asyncio.TimeoutError:
+ await self.disconnect()
+ raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
+ except OSError as e:
+ await self.disconnect()
+ raise ConnectionError(
+ f"Error while reading from {self.host}:{self.port} : {e.args}"
+ )
+ except BaseException:
+ await self.disconnect()
+ raise
+
+ if self.health_check_interval:
+ if sys.version_info[0:2] == (3, 6):
+ func = asyncio.get_event_loop
+ else:
+ func = asyncio.get_running_loop
+ self.next_health_check = func().time() + self.health_check_interval
+
+ if isinstance(response, ResponseError):
+ raise response from None
+ return response
+
+ def pack_command(self, *args: EncodableT) -> List[bytes]:
+ """Pack a series of arguments into the Redis protocol"""
+ output = []
+ # the client might have included 1 or more literal arguments in
+ # the command name, e.g., 'CONFIG GET'. The Redis server expects these
+ # arguments to be sent separately, so split the first argument
+ # manually. These arguments should be bytestrings so that they are
+ # not encoded.
+ assert not isinstance(args[0], float)
+ if isinstance(args[0], str):
+ args = tuple(args[0].encode().split()) + args[1:]
+ elif b" " in args[0]:
+ args = tuple(args[0].split()) + args[1:]
+
+ buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
+
+ buffer_cutoff = self._buffer_cutoff
+ for arg in map(self.encoder.encode, args):
+ # to avoid large string mallocs, chunk the command into the
+ # output list if we're sending large values or memoryviews
+ arg_length = len(arg)
+ if (
+ len(buff) > buffer_cutoff
+ or arg_length > buffer_cutoff
+ or isinstance(arg, memoryview)
+ ):
+ buff = SYM_EMPTY.join(
+ (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
+ )
+ output.append(buff)
+ output.append(arg)
+ buff = SYM_CRLF
+ else:
+ buff = SYM_EMPTY.join(
+ (
+ buff,
+ SYM_DOLLAR,
+ str(arg_length).encode(),
+ SYM_CRLF,
+ arg,
+ SYM_CRLF,
+ )
+ )
+ output.append(buff)
+ return output
+
+ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
+ """Pack multiple commands into the Redis protocol"""
+ output: List[bytes] = []
+ pieces: List[bytes] = []
+ buffer_length = 0
+ buffer_cutoff = self._buffer_cutoff
+
+ for cmd in commands:
+ for chunk in self.pack_command(*cmd):
+ chunklen = len(chunk)
+ if (
+ buffer_length > buffer_cutoff
+ or chunklen > buffer_cutoff
+ or isinstance(chunk, memoryview)
+ ):
+ output.append(SYM_EMPTY.join(pieces))
+ buffer_length = 0
+ pieces = []
+
+ if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
+ output.append(chunk)
+ else:
+ pieces.append(chunk)
+ buffer_length += chunklen
+
+ if pieces:
+ output.append(SYM_EMPTY.join(pieces))
+ return output
+
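To make the framing above concrete, here is a minimal sketch of what ``pack_command`` produces for ``SET foo bar``. No I/O occurs: ``Connection()`` only stores configuration and does not connect until ``connect()`` is awaited.

    from redis.asyncio.connection import Connection

    conn = Connection()
    print(conn.pack_command("SET", "foo", "bar"))
    # [b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n']
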
+
+class SSLConnection(Connection):
+ def __init__(
+ self,
+ ssl_keyfile: Optional[str] = None,
+ ssl_certfile: Optional[str] = None,
+ ssl_cert_reqs: str = "required",
+ ssl_ca_certs: Optional[str] = None,
+ ssl_check_hostname: bool = False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.ssl_context: RedisSSLContext = RedisSSLContext(
+ keyfile=ssl_keyfile,
+ certfile=ssl_certfile,
+ cert_reqs=ssl_cert_reqs,
+ ca_certs=ssl_ca_certs,
+ check_hostname=ssl_check_hostname,
+ )
+
+ @property
+ def keyfile(self):
+ return self.ssl_context.keyfile
+
+ @property
+ def certfile(self):
+ return self.ssl_context.certfile
+
+ @property
+ def cert_reqs(self):
+ return self.ssl_context.cert_reqs
+
+ @property
+ def ca_certs(self):
+ return self.ssl_context.ca_certs
+
+ @property
+ def check_hostname(self):
+ return self.ssl_context.check_hostname
+
+
+class RedisSSLContext:
+ __slots__ = (
+ "keyfile",
+ "certfile",
+ "cert_reqs",
+ "ca_certs",
+ "context",
+ "check_hostname",
+ )
+
+ def __init__(
+ self,
+ keyfile: Optional[str] = None,
+ certfile: Optional[str] = None,
+ cert_reqs: Optional[str] = None,
+ ca_certs: Optional[str] = None,
+ check_hostname: bool = False,
+ ):
+ self.keyfile = keyfile
+ self.certfile = certfile
+ if cert_reqs is None:
+ self.cert_reqs = ssl.CERT_NONE
+ elif isinstance(cert_reqs, str):
+ CERT_REQS = {
+ "none": ssl.CERT_NONE,
+ "optional": ssl.CERT_OPTIONAL,
+ "required": ssl.CERT_REQUIRED,
+ }
+ if cert_reqs not in CERT_REQS:
+ raise RedisError(
+ f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
+ )
+ self.cert_reqs = CERT_REQS[cert_reqs]
+ self.ca_certs = ca_certs
+ self.check_hostname = check_hostname
+ self.context: Optional[ssl.SSLContext] = None
+
+ def get(self) -> ssl.SSLContext:
+ if not self.context:
+ context = ssl.create_default_context()
+ context.check_hostname = self.check_hostname
+ context.verify_mode = self.cert_reqs
+ if self.certfile and self.keyfile:
+ context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
+ if self.ca_certs:
+ context.load_verify_locations(self.ca_certs)
+ self.context = context
+ return self.context
+
+
+class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init]
+ def __init__(
+ self,
+ *,
+ path: str = "",
+ db: Union[str, int] = 0,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ socket_timeout: Optional[float] = None,
+ socket_connect_timeout: Optional[float] = None,
+ encoding: str = "utf-8",
+ encoding_errors: str = "strict",
+ decode_responses: bool = False,
+ retry_on_timeout: bool = False,
+ parser_class: Type[BaseParser] = DefaultParser,
+ socket_read_size: int = 65536,
+ health_check_interval: float = 0.0,
+ client_name: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ ):
+ """
+ Initialize a new UnixDomainSocketConnection.
+ To specify a retry policy, first set `retry_on_timeout` to `True`
+ then set `retry` to a valid `Retry` object
+ """
+ self.pid = os.getpid()
+ self.path = path
+ self.db = db
+ self.username = username
+ self.client_name = client_name
+ self.password = password
+ self.socket_timeout = socket_timeout
+ self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
+ self.retry_on_timeout = retry_on_timeout
+ if retry_on_timeout:
+ if retry is None:
+ self.retry = Retry(NoBackoff(), 1)
+ else:
+ # deep-copy the Retry object as it is mutable
+ self.retry = copy.deepcopy(retry)
+ else:
+ self.retry = Retry(NoBackoff(), 0)
+ self.health_check_interval = health_check_interval
+ self.next_health_check = -1
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
+ self._sock = None
+ self._reader = None
+ self._writer = None
+ self._socket_read_size = socket_read_size
+ self.set_parser(parser_class)
+ self._connect_callbacks = []
+ self._buffer_cutoff = 6000
+ self._lock = asyncio.Lock()
+
+ def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
+ pieces = [
+ ("path", self.path),
+ ("db", self.db),
+ ]
+ if self.client_name:
+ pieces.append(("client_name", self.client_name))
+ return pieces
+
+ async def _connect(self):
+ async with async_timeout.timeout(self.socket_connect_timeout):
+ reader, writer = await asyncio.open_unix_connection(path=self.path)
+ self._reader = reader
+ self._writer = writer
+ await self.on_connect()
+
+ def _error_message(self, exception):
+ # args for socket.error can either be (errno, "message")
+ # or just "message"
+ if len(exception.args) == 1:
+ return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
+ else:
+ return (
+ f"Error {exception.args[0]} connecting to unix socket: "
+ f"{self.path}. {exception.args[1]}."
+ )
+
+
+FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
+
+
+def to_bool(value) -> Optional[bool]:
+ if value is None or value == "":
+ return None
+ if isinstance(value, str) and value.upper() in FALSE_STRINGS:
+ return False
+ return bool(value)
+
+
+URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
+ {
+ "db": int,
+ "socket_timeout": float,
+ "socket_connect_timeout": float,
+ "socket_keepalive": to_bool,
+ "retry_on_timeout": to_bool,
+ "max_connections": int,
+ "health_check_interval": int,
+ "ssl_check_hostname": to_bool,
+ }
+)
+
+
+class ConnectKwargs(TypedDict, total=False):
+ username: str
+ password: str
+ connection_class: Type[Connection]
+ host: str
+ port: int
+ db: int
+ path: str
+
+
+def parse_url(url: str) -> ConnectKwargs:
+ parsed: ParseResult = urlparse(url)
+ kwargs: ConnectKwargs = {}
+
+ for name, value_list in parse_qs(parsed.query).items():
+ if value_list:
+ value = unquote(value_list[0])
+ parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
+ if parser:
+ try:
+ # We can't type this.
+ kwargs[name] = parser(value) # type: ignore[misc]
+ except (TypeError, ValueError):
+ raise ValueError(f"Invalid value for `{name}` in connection URL.")
+ else:
+ kwargs[name] = value # type: ignore[misc]
+
+ if parsed.username:
+ kwargs["username"] = unquote(parsed.username)
+ if parsed.password:
+ kwargs["password"] = unquote(parsed.password)
+
+ # We only support redis://, rediss:// and unix:// schemes.
+ if parsed.scheme == "unix":
+ if parsed.path:
+ kwargs["path"] = unquote(parsed.path)
+ kwargs["connection_class"] = UnixDomainSocketConnection
+
+ elif parsed.scheme in ("redis", "rediss"):
+ if parsed.hostname:
+ kwargs["host"] = unquote(parsed.hostname)
+ if parsed.port:
+ kwargs["port"] = int(parsed.port)
+
+ # If there's a path argument, use it as the db argument if a
+ # querystring value wasn't specified
+ if parsed.path and "db" not in kwargs:
+ try:
+ kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
+ except (AttributeError, ValueError):
+ pass
+
+ if parsed.scheme == "rediss":
+ kwargs["connection_class"] = SSLConnection
+ else:
+ valid_schemes = "redis://, rediss://, unix://"
+ raise ValueError(
+ f"Redis URL must specify one of the following schemes ({valid_schemes})"
+ )
+
+ return kwargs
+
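A quick sketch of ``parse_url`` in action; note how the URL path supplies ``db`` and the querystring value is cast by the parsers table above:

    from redis.asyncio.connection import parse_url

    print(parse_url("redis://user:secret@example.com:6380/2?socket_timeout=1.5"))
    # {'socket_timeout': 1.5, 'username': 'user', 'password': 'secret',
    #  'host': 'example.com', 'port': 6380, 'db': 2}
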
+
+_CP = TypeVar("_CP", bound="ConnectionPool")
+
+
+class ConnectionPool:
+ """
+ Create a connection pool. If ``max_connections`` is set, then this
+ object raises :py:class:`~redis.ConnectionError` when the pool's
+ limit is reached.
+
+ By default, TCP connections are created unless ``connection_class``
+ is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
+ unix sockets.
+
+ Any additional keyword arguments are passed to the constructor of
+ ``connection_class``.
+ """
+
+ @classmethod
+ def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
+ """
+ Return a connection pool configured from the given URL.
+
+ For example::
+
+ redis://[[username]:[password]]@localhost:6379/0
+ rediss://[[username]:[password]]@localhost:6379/0
+ unix://[[username]:[password]]@/path/to/socket.sock?db=0
+
+ Three URL schemes are supported:
+
+ - ``redis://`` creates a TCP socket connection. See more at:
+ <https://www.iana.org/assignments/uri-schemes/prov/redis>
+ - ``rediss://`` creates an SSL-wrapped TCP socket connection. See more at:
+ <https://www.iana.org/assignments/uri-schemes/prov/rediss>
+ - ``unix://`` creates a Unix Domain Socket connection.
+
+ The username, password, hostname, path and all querystring values
+ are passed through urllib.parse.unquote in order to replace any
+ percent-encoded values with their corresponding characters.
+
+ There are several ways to specify a database number. The first value
+ found will be used:
+ 1. A ``db`` querystring option, e.g. redis://localhost?db=0
+ 2. If using the redis:// or rediss:// schemes, the path argument
+ of the url, e.g. redis://localhost/0
+ 3. A ``db`` keyword argument to this function.
+
+ If none of these options are specified, the default db=0 is used.
+
+ All querystring options are cast to their appropriate Python types.
+ Boolean arguments can be specified with string values "True"/"False"
+ or "Yes"/"No". Values that cannot be properly cast cause a
+ ``ValueError`` to be raised. Once parsed, the querystring arguments
+ and keyword arguments are passed to the ``ConnectionPool``'s
+ class initializer. In the case of conflicting arguments, querystring
+ arguments always win.
+ """
+ url_options = parse_url(url)
+ kwargs.update(url_options)
+ return cls(**kwargs)
+
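A usage sketch (assuming a reachable server at localhost:6379): build a pool from a URL, hand it to a client, and tear it down when finished:

    import asyncio
    from redis.asyncio import ConnectionPool, Redis

    async def main():
        pool = ConnectionPool.from_url("redis://localhost:6379/0", decode_responses=True)
        client = Redis(connection_pool=pool)
        await client.set("greeting", "hello")
        print(await client.get("greeting"))  # "hello"
        await pool.disconnect()

    asyncio.run(main())
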
+ def __init__(
+ self,
+ connection_class: Type[Connection] = Connection,
+ max_connections: Optional[int] = None,
+ **connection_kwargs,
+ ):
+ max_connections = max_connections or 2 ** 31
+ if not isinstance(max_connections, int) or max_connections < 0:
+ raise ValueError('"max_connections" must be a positive integer')
+
+ self.connection_class = connection_class
+ self.connection_kwargs = connection_kwargs
+ self.max_connections = max_connections
+
+ # a lock to protect the critical section in _checkpid().
+ # this lock is acquired when the process id changes, such as
+ # after a fork. during this time, multiple threads in the child
+ # process could attempt to acquire this lock. the first thread
+ # to acquire the lock will reset the data structures and lock
+ # object of this pool. subsequent threads acquiring this lock
+ # will notice the first thread already did the work and simply
+ # release the lock.
+ self._fork_lock = threading.Lock()
+ self._lock = asyncio.Lock()
+ self._created_connections: int
+ self._available_connections: List[Connection]
+ self._in_use_connections: Set[Connection]
+ self.reset() # lgtm [py/init-calls-subclass]
+ self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}"
+ f"<{self.connection_class(**self.connection_kwargs)!r}>"
+ )
+
+ def reset(self):
+ self._lock = asyncio.Lock()
+ self._created_connections = 0
+ self._available_connections = []
+ self._in_use_connections = set()
+
+ # this must be the last operation in this method. while reset() is
+ # called when holding _fork_lock, other threads in this process
+ # can call _checkpid() which compares self.pid and os.getpid() without
+ # holding any lock (for performance reasons). keeping this assignment
+ # as the last operation ensures that those other threads will also
+ # notice a pid difference and block waiting for the first thread to
+ # release _fork_lock. when each of these threads eventually acquires
+ # _fork_lock, they will notice that another thread already called
+ # reset() and they will immediately release _fork_lock and continue on.
+ self.pid = os.getpid()
+
+ def _checkpid(self):
+ # _checkpid() attempts to keep ConnectionPool fork-safe on modern
+ # systems. this is called by all ConnectionPool methods that
+ # manipulate the pool's state such as get_connection() and release().
+ #
+ # _checkpid() determines whether the process has forked by comparing
+ # the current process id to the process id saved on the ConnectionPool
+ # instance. if these values are the same, _checkpid() simply returns.
+ #
+ # when the process ids differ, _checkpid() assumes that the process
+ # has forked and that we're now running in the child process. the child
+ # process cannot use the parent's file descriptors (e.g., sockets).
+ # therefore, when _checkpid() sees the process id change, it calls
+ # reset() in order to reinitialize the child's ConnectionPool. this
+ # will cause the child to make all new connection objects.
+ #
+ # _checkpid() is protected by self._fork_lock to ensure that multiple
+ # threads in the child process do not call reset() multiple times.
+ #
+ # there is an extremely small chance this could fail in the following
+ # scenario:
+ # 1. process A calls _checkpid() for the first time and acquires
+ # self._fork_lock.
+ # 2. while holding self._fork_lock, process A forks (the fork()
+ # could happen in a different thread owned by process A)
+ # 3. process B (the forked child process) inherits the
+ # ConnectionPool's state from the parent. that state includes
+ # a locked _fork_lock. process B will not be notified when
+ # process A releases the _fork_lock and will thus never be
+ # able to acquire the _fork_lock.
+ #
+ # to mitigate this possible deadlock, _checkpid() will only wait 5
+ # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
+ # that time it is assumed that the child is deadlocked and a
+ # redis.ChildDeadlockedError error is raised.
+ if self.pid != os.getpid():
+ acquired = self._fork_lock.acquire(timeout=5)
+ if not acquired:
+ raise ChildDeadlockedError
+ # reset() the instance for the new process if another thread
+ # hasn't already done so
+ try:
+ if self.pid != os.getpid():
+ self.reset()
+ finally:
+ self._fork_lock.release()
+
+ async def get_connection(self, command_name, *keys, **options):
+ """Get a connection from the pool"""
+ self._checkpid()
+ async with self._lock:
+ try:
+ connection = self._available_connections.pop()
+ except IndexError:
+ connection = self.make_connection()
+ self._in_use_connections.add(connection)
+
+ try:
+ # ensure this connection is connected to Redis
+ await connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ try:
+ if await connection.can_read():
+ raise ConnectionError("Connection has data") from None
+ except ConnectionError:
+ await connection.disconnect()
+ await connection.connect()
+ if await connection.can_read():
+ raise ConnectionError("Connection not ready") from None
+ except BaseException:
+ # release the connection back to the pool so that we don't
+ # leak it
+ await self.release(connection)
+ raise
+
+ return connection
+
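Most callers go through the client, but the pool can also be driven directly. A sketch (same localhost assumption) of the check-out/check-in cycle that ``get_connection`` and ``release`` implement:

    import asyncio
    from redis.asyncio import ConnectionPool

    async def main():
        pool = ConnectionPool(host="localhost", port=6379)
        conn = await pool.get_connection("PING")  # command name is informational
        try:
            await conn.send_command("PING")
            print(await conn.read_response())  # b'PONG'
        finally:
            await pool.release(conn)
        await pool.disconnect()

    asyncio.run(main())
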
+ def get_encoder(self):
+ """Return an encoder based on encoding settings"""
+ kwargs = self.connection_kwargs
+ return self.encoder_class(
+ encoding=kwargs.get("encoding", "utf-8"),
+ encoding_errors=kwargs.get("encoding_errors", "strict"),
+ decode_responses=kwargs.get("decode_responses", False),
+ )
+
+ def make_connection(self):
+ """Create a new connection"""
+ if self._created_connections >= self.max_connections:
+ raise ConnectionError("Too many connections")
+ self._created_connections += 1
+ return self.connection_class(**self.connection_kwargs)
+
+ async def release(self, connection: Connection):
+ """Releases the connection back to the pool"""
+ self._checkpid()
+ async with self._lock:
+ try:
+ self._in_use_connections.remove(connection)
+ except KeyError:
+ # Gracefully fail when a connection is returned to this pool
+ # that the pool doesn't actually own
+ pass
+
+ if self.owns_connection(connection):
+ self._available_connections.append(connection)
+ else:
+ # pool doesn't own this connection. do not add it back
+ # to the pool and decrement the count so that another
+ # connection can take its place if needed
+ self._created_connections -= 1
+ await connection.disconnect()
+ return
+
+ def owns_connection(self, connection: Connection):
+ return connection.pid == self.pid
+
+ async def disconnect(self, inuse_connections: bool = True):
+ """
+ Disconnects connections in the pool
+
+ If ``inuse_connections`` is True, disconnect connections that are
+ currently in use, potentially by other tasks. Otherwise only disconnect
+ connections that are idle in the pool.
+ """
+ self._checkpid()
+ async with self._lock:
+ if inuse_connections:
+ connections: Iterable[Connection] = chain(
+ self._available_connections, self._in_use_connections
+ )
+ else:
+ connections = self._available_connections
+ resp = await asyncio.gather(
+ *(connection.disconnect() for connection in connections),
+ return_exceptions=True,
+ )
+ exc = next((r for r in resp if isinstance(r, BaseException)), None)
+ if exc:
+ raise exc
+
+
+class BlockingConnectionPool(ConnectionPool):
+ """
+ Thread-safe blocking connection pool::
+
+ >>> from redis.client import Redis
+ >>> client = Redis(connection_pool=BlockingConnectionPool())
+
+ It performs the same function as the default
+ :py:class:`~redis.ConnectionPool` implementation, in that,
+ it maintains a pool of reusable connections that can be shared by
+ multiple redis clients (safely across threads if required).
+
+ The difference is that, in the event that a client tries to get a
+ connection from the pool when all connections are in use, rather than
+ raising a :py:class:`~redis.ConnectionError` (as the default
+ :py:class:`~redis.ConnectionPool` implementation does), it
+ makes the client wait ("blocks") for a specified number of seconds until
+ a connection becomes available.
+
+ Use ``max_connections`` to increase / decrease the pool size::
+
+ >>> pool = BlockingConnectionPool(max_connections=10)
+
+ Use ``timeout`` to tell it either how many seconds to wait for a connection
+ to become available, or to block forever:
+
+ >>> # Block forever.
+ >>> pool = BlockingConnectionPool(timeout=None)
+
+ >>> # Raise a ``ConnectionError`` after five seconds if a connection is
+ >>> # not available.
+ >>> pool = BlockingConnectionPool(timeout=5)
+ """
+
+ def __init__(
+ self,
+ max_connections: int = 50,
+ timeout: Optional[int] = 20,
+ connection_class: Type[Connection] = Connection,
+ queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,
+ **connection_kwargs,
+ ):
+
+ self.queue_class = queue_class
+ self.timeout = timeout
+ self._connections: List[Connection]
+ super().__init__(
+ connection_class=connection_class,
+ max_connections=max_connections,
+ **connection_kwargs,
+ )
+
+ def reset(self):
+ # Create and fill up a thread-safe queue with ``None`` values.
+ self.pool = self.queue_class(self.max_connections)
+ while True:
+ try:
+ self.pool.put_nowait(None)
+ except asyncio.QueueFull:
+ break
+
+ # Keep a list of actual connection instances so that we can
+ # disconnect them later.
+ self._connections = []
+
+ # this must be the last operation in this method. while reset() is
+ # called when holding _fork_lock, other threads in this process
+ # can call _checkpid() which compares self.pid and os.getpid() without
+ # holding any lock (for performance reasons). keeping this assignment
+ # as the last operation ensures that those other threads will also
+ # notice a pid difference and block waiting for the first thread to
+ # release _fork_lock. when each of these threads eventually acquires
+ # _fork_lock, they will notice that another thread already called
+ # reset() and they will immediately release _fork_lock and continue on.
+ self.pid = os.getpid()
+
+ def make_connection(self):
+ """Make a fresh connection."""
+ connection = self.connection_class(**self.connection_kwargs)
+ self._connections.append(connection)
+ return connection
+
+ async def get_connection(self, command_name, *keys, **options):
+ """
+ Get a connection, blocking for ``self.timeout`` until a connection
+ is available from the pool.
+
+ If the connection returned is ``None``, a new connection is created.
+ Because we use a last-in first-out queue, the existing connections
+ (having been returned to the pool after the initial ``None`` values
+ were added) will be returned before ``None`` values. This means we only
+ create new connections when we need to, i.e.: the actual number of
+ connections will only increase in response to demand.
+ """
+ # Make sure we haven't changed process.
+ self._checkpid()
+
+ # Try and get a connection from the pool. If one isn't available within
+ # self.timeout then raise a ``ConnectionError``.
+ connection = None
+ try:
+ async with async_timeout.timeout(self.timeout):
+ connection = await self.pool.get()
+ except (asyncio.QueueEmpty, asyncio.TimeoutError):
+ # Note that this is not caught by the redis client and will be
+ # raised unless handled by application code.
+ raise ConnectionError("No connection available.")
+
+ # If the ``connection`` is actually ``None`` then that's a cue to make
+ # a new connection to add to the pool.
+ if connection is None:
+ connection = self.make_connection()
+
+ try:
+ # ensure this connection is connected to Redis
+ await connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ try:
+ if await connection.can_read():
+ raise ConnectionError("Connection has data") from None
+ except ConnectionError:
+ await connection.disconnect()
+ await connection.connect()
+ if await connection.can_read():
+ raise ConnectionError("Connection not ready") from None
+ except BaseException:
+ # release the connection back to the pool so that we don't leak it
+ await self.release(connection)
+ raise
+
+ return connection
+
+ async def release(self, connection: Connection):
+ """Releases the connection back to the pool."""
+ # Make sure we haven't changed process.
+ self._checkpid()
+ if not self.owns_connection(connection):
+ # pool doesn't own this connection. do not add it back
+ # to the pool. instead add a None value which is a placeholder
+ # that will cause the pool to recreate the connection if
+ # it's needed.
+ await connection.disconnect()
+ self.pool.put_nowait(None)
+ return
+
+ # Put the connection back into the pool.
+ try:
+ self.pool.put_nowait(connection)
+ except asyncio.QueueFull:
+ # perhaps the pool has been reset() after a fork? regardless,
+ # we don't want this connection
+ pass
+
+ async def disconnect(self, inuse_connections: bool = True):
+ """Disconnects all connections in the pool."""
+ self._checkpid()
+ async with self._lock:
+ resp = await asyncio.gather(
+ *(connection.disconnect() for connection in self._connections),
+ return_exceptions=True,
+ )
+ exc = next((r for r in resp if isinstance(r, BaseException)), None)
+ if exc:
+ raise exc
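
A sketch of the blocking behaviour described in the docstring above (assuming a local server): with ``max_connections=2``, surplus concurrent commands wait up to ``timeout`` seconds for a free connection instead of raising:

    import asyncio
    from redis.asyncio import BlockingConnectionPool, Redis

    async def main():
        pool = BlockingConnectionPool(max_connections=2, timeout=5)
        client = Redis(connection_pool=pool)
        # Ten concurrent PINGs share two connections; none raises.
        print(await asyncio.gather(*(client.ping() for _ in range(10))))
        await pool.disconnect()

    asyncio.run(main())
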
diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py
new file mode 100644
index 0000000..d486132
--- /dev/null
+++ b/redis/asyncio/lock.py
@@ -0,0 +1,311 @@
+import asyncio
+import sys
+import threading
+import uuid
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Awaitable, NoReturn, Optional, Union
+
+from redis.exceptions import LockError, LockNotOwnedError
+
+if TYPE_CHECKING:
+ from redis.asyncio import Redis
+
+
+class Lock:
+ """
+ A shared, distributed Lock. Using Redis for locking allows the Lock
+ to be shared across processes and/or machines.
+
+ It's left to the user to resolve deadlock issues and make sure
+ multiple clients play nicely together.
+ """
+
+ lua_release = None
+ lua_extend = None
+ lua_reacquire = None
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # return 1 if the lock was released, otherwise 0
+ LUA_RELEASE_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ redis.call('del', KEYS[1])
+ return 1
+ """
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # ARGV[2] - additional milliseconds
+ # ARGV[3] - "0" if the additional time should be added to the lock's
+ # existing ttl or "1" if the existing ttl should be replaced
+ # return 1 if the locks time was extended, otherwise 0
+ LUA_EXTEND_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ local expiration = redis.call('pttl', KEYS[1])
+ if not expiration then
+ expiration = 0
+ end
+ if expiration < 0 then
+ return 0
+ end
+
+ local newttl = ARGV[2]
+ if ARGV[3] == "0" then
+ newttl = ARGV[2] + expiration
+ end
+ redis.call('pexpire', KEYS[1], newttl)
+ return 1
+ """
+
+ # KEYS[1] - lock name
+ # ARGV[1] - token
+ # ARGV[2] - milliseconds
+ # return 1 if the locks time was reacquired, otherwise 0
+ LUA_REACQUIRE_SCRIPT = """
+ local token = redis.call('get', KEYS[1])
+ if not token or token ~= ARGV[1] then
+ return 0
+ end
+ redis.call('pexpire', KEYS[1], ARGV[2])
+ return 1
+ """
+
+ def __init__(
+ self,
+ redis: "Redis",
+ name: Union[str, bytes, memoryview],
+ timeout: Optional[float] = None,
+ sleep: float = 0.1,
+ blocking: bool = True,
+ blocking_timeout: Optional[float] = None,
+ thread_local: bool = True,
+ ):
+ """
+ Create a new Lock instance named ``name`` using the Redis client
+ supplied by ``redis``.
+
+ ``timeout`` indicates a maximum life for the lock in seconds.
+ By default, it will remain locked until release() is called.
+ ``timeout`` can be specified as a float or integer, both representing
+ the number of seconds to wait.
+
+ ``sleep`` indicates the amount of time to sleep in seconds per loop
+ iteration when the lock is in blocking mode and another client is
+ currently holding the lock.
+
+ ``blocking`` indicates whether calling ``acquire`` should block until
+ the lock has been acquired or fail immediately, causing ``acquire``
+ to return False and leaving the lock unacquired. Defaults to True.
+ Note this value can be overridden by passing a ``blocking``
+ argument to ``acquire``.
+
+ ``blocking_timeout`` indicates the maximum amount of time in seconds to
+ spend trying to acquire the lock. A value of ``None`` indicates that it
+ should keep trying forever. ``blocking_timeout`` can be specified as a
+ float or integer, both representing the number of seconds to wait.
+
+ ``thread_local`` indicates whether the lock token is placed in
+ thread-local storage. By default, the token is placed in thread local
+ storage so that a thread only sees its token, not a token set by
+ another thread. Consider the following timeline:
+
+ time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
+ thread-1 sets the token to "abc"
+ time: 1, thread-2 blocks trying to acquire `my-lock` using the
+ Lock instance.
+ time: 5, thread-1 has not yet completed. redis expires the lock
+ key.
+ time: 5, thread-2 acquired `my-lock` now that it's available.
+ thread-2 sets the token to "xyz"
+ time: 6, thread-1 finishes its work and calls release(). if the
+ token is *not* stored in thread local storage, then
+ thread-1 would see the token value as "xyz" and would be
+ able to successfully release thread-2's lock.
+
+ In some use cases it's necessary to disable thread local storage. For
+ example, if you have code where one thread acquires a lock and passes
+ that lock instance to a worker thread to release later. If thread
+ local storage isn't disabled in this case, the worker thread won't see
+ the token set by the thread that acquired the lock. Our assumption
+ is that these cases aren't common and as such default to using
+ thread local storage.
+ """
+ self.redis = redis
+ self.name = name
+ self.timeout = timeout
+ self.sleep = sleep
+ self.blocking = blocking
+ self.blocking_timeout = blocking_timeout
+ self.thread_local = bool(thread_local)
+ self.local = threading.local() if self.thread_local else SimpleNamespace()
+ self.local.token = None
+ self.register_scripts()
+
+ def register_scripts(self):
+ cls = self.__class__
+ client = self.redis
+ if cls.lua_release is None:
+ cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
+ if cls.lua_extend is None:
+ cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
+ if cls.lua_reacquire is None:
+ cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
+
+ async def __aenter__(self):
+ if await self.acquire():
+ return self
+ raise LockError("Unable to acquire lock within the time specified")
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.release()
+
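A usage sketch of the async context manager defined above (assuming a local server): ``__aenter__`` acquires, raising ``LockError`` on failure, and ``__aexit__`` releases:

    import asyncio
    from redis.asyncio import Redis
    from redis.asyncio.lock import Lock

    async def main():
        client = Redis()
        lock = Lock(client, "resource-key", timeout=10, blocking_timeout=5)
        async with lock:
            ...  # critical section: at most one holder at a time

    asyncio.run(main())
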
+ async def acquire(
+ self,
+ blocking: Optional[bool] = None,
+ blocking_timeout: Optional[float] = None,
+ token: Optional[Union[str, bytes]] = None,
+ ):
+ """
+ Use Redis to hold a shared, distributed lock named ``name``.
+ Returns True once the lock is acquired.
+
+ If ``blocking`` is False, always return immediately. If the lock
+ was acquired, return True, otherwise return False.
+
+ ``blocking_timeout`` specifies the maximum number of seconds to
+ wait trying to acquire the lock.
+
+ ``token`` specifies the token value to be used. If provided, token
+ must be a bytes object or a string that can be encoded to a bytes
+ object with the default encoding. If a token isn't specified, a UUID
+ will be generated.
+ """
+ if sys.version_info[0:2] != (3, 6):
+ loop = asyncio.get_running_loop()
+ else:
+ loop = asyncio.get_event_loop()
+
+ sleep = self.sleep
+ if token is None:
+ token = uuid.uuid1().hex.encode()
+ else:
+ encoder = self.redis.connection_pool.get_encoder()
+ token = encoder.encode(token)
+ if blocking is None:
+ blocking = self.blocking
+ if blocking_timeout is None:
+ blocking_timeout = self.blocking_timeout
+ stop_trying_at = None
+ if blocking_timeout is not None:
+ stop_trying_at = loop.time() + blocking_timeout
+ while True:
+ if await self.do_acquire(token):
+ self.local.token = token
+ return True
+ if not blocking:
+ return False
+ next_try_at = loop.time() + sleep
+ if stop_trying_at is not None and next_try_at > stop_trying_at:
+ return False
+ await asyncio.sleep(sleep)
+
+ async def do_acquire(self, token: Union[str, bytes]) -> bool:
+ if self.timeout:
+ # convert to milliseconds
+ timeout = int(self.timeout * 1000)
+ else:
+ timeout = None
+ if await self.redis.set(self.name, token, nx=True, px=timeout):
+ return True
+ return False
+
+ async def locked(self) -> bool:
+ """
+ Returns True if this key is locked by any process, otherwise False.
+ """
+ return await self.redis.get(self.name) is not None
+
+ async def owned(self) -> bool:
+ """
+ Returns True if this key is locked by this lock, otherwise False.
+ """
+ stored_token = await self.redis.get(self.name)
+ # need to always compare bytes to bytes
+ # TODO: this can be simplified when the context manager is finished
+ if stored_token and not isinstance(stored_token, bytes):
+ encoder = self.redis.connection_pool.get_encoder()
+ stored_token = encoder.encode(stored_token)
+ return self.local.token is not None and stored_token == self.local.token
+
+ def release(self) -> Awaitable[NoReturn]:
+ """Releases the already acquired lock"""
+ expected_token = self.local.token
+ if expected_token is None:
+ raise LockError("Cannot release an unlocked lock")
+ self.local.token = None
+ return self.do_release(expected_token)
+
+ async def do_release(self, expected_token: bytes):
+ if not bool(
+ await self.lua_release(
+ keys=[self.name], args=[expected_token], client=self.redis
+ )
+ ):
+ raise LockNotOwnedError("Cannot release a lock" " that's no longer owned")
+
+ def extend(
+ self, additional_time: float, replace_ttl: bool = False
+ ) -> Awaitable[bool]:
+ """
+ Adds more time to an already acquired lock.
+
+ ``additional_time`` can be specified as an integer or a float, both
+ representing the number of seconds to add.
+
+ ``replace_ttl`` if False (the default), add `additional_time` to
+ the lock's existing ttl. If True, replace the lock's ttl with
+ `additional_time`.
+ """
+ if self.local.token is None:
+ raise LockError("Cannot extend an unlocked lock")
+ if self.timeout is None:
+ raise LockError("Cannot extend a lock with no timeout")
+ return self.do_extend(additional_time, replace_ttl)
+
+ async def do_extend(self, additional_time, replace_ttl) -> bool:
+ additional_time = int(additional_time * 1000)
+ if not bool(
+ await self.lua_extend(
+ keys=[self.name],
+ args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
+ client=self.redis,
+ )
+ ):
+ raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned")
+ return True
+
+ def reacquire(self) -> Awaitable[bool]:
+ """
+ Resets the TTL of an already acquired lock back to the ``timeout`` value.
+ """
+ if self.local.token is None:
+ raise LockError("Cannot reacquire an unlocked lock")
+ if self.timeout is None:
+ raise LockError("Cannot reacquire a lock with no timeout")
+ return self.do_reacquire()
+
+ async def do_reacquire(self) -> bool:
+ timeout = int(self.timeout * 1000)
+ if not bool(
+ await self.lua_reacquire(
+ keys=[self.name], args=[self.local.token, timeout], client=self.redis
+ )
+ ):
+ raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
+ return True
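
Tying the TTL-management methods together, a sketch (local server assumed; ``extend`` and ``reacquire`` both require the lock to have been created with a ``timeout``):

    import asyncio
    from redis.asyncio import Redis
    from redis.asyncio.lock import Lock

    async def main():
        lock = Lock(Redis(), "job-lock", timeout=10)
        assert await lock.acquire()
        await lock.extend(15)                    # add 15s to the remaining TTL
        await lock.extend(30, replace_ttl=True)  # set the TTL to exactly 30s
        await lock.reacquire()                   # reset the TTL back to 10s
        await lock.release()

    asyncio.run(main())
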
diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py
new file mode 100644
index 0000000..9b53494
--- /dev/null
+++ b/redis/asyncio/retry.py
@@ -0,0 +1,58 @@
+from asyncio import sleep
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, TypeVar
+
+from redis.exceptions import ConnectionError, RedisError, TimeoutError
+
+if TYPE_CHECKING:
+ from redis.backoff import AbstractBackoff
+
+
+T = TypeVar("T")
+
+
+class Retry:
+ """Retry a specific number of times after a failure"""
+
+ __slots__ = "_backoff", "_retries", "_supported_errors"
+
+ def __init__(
+ self,
+ backoff: "AbstractBackoff",
+ retries: int,
+ supported_errors: Tuple[RedisError, ...] = (
+ ConnectionError,
+ TimeoutError,
+ ),
+ ):
+ """
+ Initialize a `Retry` object with a `Backoff` object
+ that retries a maximum of `retries` times.
+ You can specify the types of supported errors which trigger
+ a retry with the `supported_errors` parameter.
+ """
+ self._backoff = backoff
+ self._retries = retries
+ self._supported_errors = supported_errors
+
+ async def call_with_retry(
+ self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]
+ ) -> T:
+ """
+ Execute an operation that might fail and returns its result, or
+ raise the exception that was thrown depending on the `Backoff` object.
+ `do`: the operation to call. Expects no argument.
+ `fail`: the failure handler, expects the last error that was thrown
+ """
+ self._backoff.reset()
+ failures = 0
+ while True:
+ try:
+ return await do()
+ except self._supported_errors as error:
+ failures += 1
+ await fail(error)
+ if failures > self._retries:
+ raise error
+ backoff = self._backoff.compute(failures)
+ if backoff > 0:
+ await sleep(backoff)
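
A usage sketch wiring this class into a client (assuming ``redis.backoff.ExponentialBackoff`` is available, as in the synchronous package): a timed-out command is retried up to three times, with exponentially growing pauses between attempts:

    import asyncio
    from redis.asyncio import Redis
    from redis.asyncio.retry import Retry
    from redis.backoff import ExponentialBackoff

    async def main():
        client = Redis(retry_on_timeout=True, retry=Retry(ExponentialBackoff(), 3))
        print(await client.ping())

    asyncio.run(main())
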
diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py
new file mode 100644
index 0000000..5aefd09
--- /dev/null
+++ b/redis/asyncio/sentinel.py
@@ -0,0 +1,353 @@
+import asyncio
+import random
+import weakref
+from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type
+
+from redis.asyncio.client import Redis
+from redis.asyncio.connection import (
+ Connection,
+ ConnectionPool,
+ EncodableT,
+ SSLConnection,
+)
+from redis.commands import AsyncSentinelCommands
+from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
+from redis.utils import str_if_bytes
+
+
+class MasterNotFoundError(ConnectionError):
+ pass
+
+
+class SlaveNotFoundError(ConnectionError):
+ pass
+
+
+class SentinelManagedConnection(Connection):
+ def __init__(self, **kwargs):
+ self.connection_pool = kwargs.pop("connection_pool")
+ super().__init__(**kwargs)
+
+ def __repr__(self):
+ pool = self.connection_pool
+ s = f"{self.__class__.__name__}<service={pool.service_name}"
+ if self.host:
+ host_info = f",host={self.host},port={self.port}"
+ s += host_info
+ return s + ">"
+
+ async def connect_to(self, address):
+ self.host, self.port = address
+ await super().connect()
+ if self.connection_pool.check_connection:
+ await self.send_command("PING")
+ if str_if_bytes(await self.read_response()) != "PONG":
+ raise ConnectionError("PING failed")
+
+ async def connect(self):
+ if self._reader:
+ return # already connected
+ if self.connection_pool.is_master:
+ await self.connect_to(await self.connection_pool.get_master_address())
+ else:
+ async for slave in self.connection_pool.rotate_slaves():
+ try:
+ return await self.connect_to(slave)
+ except ConnectionError:
+ continue
+ raise SlaveNotFoundError # should never get here
+
+ async def read_response(self, disable_decoding: bool = False):
+ try:
+ return await super().read_response(disable_decoding=disable_decoding)
+ except ReadOnlyError:
+ if self.connection_pool.is_master:
+ # When talking to a master, a ReadOnlyError likely
+ # indicates that the previous master that we're still connected
+ # to has been demoted to a slave and there's a new master.
+ # calling disconnect will force the connection to re-query
+ # sentinel during the next connect() attempt.
+ await self.disconnect()
+ raise ConnectionError("The previous master is now a slave")
+ raise
+
+
+class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
+ pass
+
+
+class SentinelConnectionPool(ConnectionPool):
+ """
+ Sentinel backed connection pool.
+
+ If the ``check_connection`` flag is set to True, SentinelManagedConnection
+ sends a PING command right after establishing the connection.
+ """
+
+ def __init__(self, service_name, sentinel_manager, **kwargs):
+ kwargs["connection_class"] = kwargs.get(
+ "connection_class",
+ SentinelManagedSSLConnection
+ if kwargs.pop("ssl", False)
+ else SentinelManagedConnection,
+ )
+ self.is_master = kwargs.pop("is_master", True)
+ self.check_connection = kwargs.pop("check_connection", False)
+ super().__init__(**kwargs)
+ self.connection_kwargs["connection_pool"] = weakref.proxy(self)
+ self.service_name = service_name
+ self.sentinel_manager = sentinel_manager
+ self.master_address = None
+ self.slave_rr_counter = None
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}"
+ f"<service={self.service_name}({'master' if self.is_master else 'slave'})>"
+ )
+
+ def reset(self):
+ super().reset()
+ self.master_address = None
+ self.slave_rr_counter = None
+
+ def owns_connection(self, connection: Connection):
+ address = (connection.host, connection.port)
+ check = not self.is_master or self.master_address == address
+ return check and super().owns_connection(connection)
+
+ async def get_master_address(self):
+ master_address = await self.sentinel_manager.discover_master(self.service_name)
+ if self.is_master:
+ if self.master_address != master_address:
+ self.master_address = master_address
+ # disconnect any idle connections so that they reconnect
+ # to the new master the next time that they are used.
+ await self.disconnect(inuse_connections=False)
+ return master_address
+
+ async def rotate_slaves(self) -> AsyncIterator:
+ """Round-robin slave balancer"""
+ slaves = await self.sentinel_manager.discover_slaves(self.service_name)
+ if slaves:
+ if self.slave_rr_counter is None:
+ self.slave_rr_counter = random.randint(0, len(slaves) - 1)
+ for _ in range(len(slaves)):
+ self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves)
+ slave = slaves[self.slave_rr_counter]
+ yield slave
+ # Fall back to the master address
+ try:
+ yield await self.get_master_address()
+ except MasterNotFoundError:
+ pass
+ raise SlaveNotFoundError(f"No slave found for {self.service_name!r}")
+
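
To make the moving parts concrete, here is a sketch of using this pool directly (the service name and sentinel address are hypothetical; most callers go through `Sentinel.master_for`, shown further below, rather than constructing the pool by hand):

    import asyncio

    from redis.asyncio.sentinel import Sentinel, SentinelConnectionPool

    async def main():
        sentinel = Sentinel([("localhost", 26379)])  # hypothetical address
        # a master pool routes every connection to the current master;
        # passing ssl=True would select SentinelManagedSSLConnection instead
        pool = SentinelConnectionPool("mymaster", sentinel, is_master=True)
        conn = await pool.get_connection("PING")
        try:
            await conn.send_command("PING")
            print(await conn.read_response())
        finally:
            await pool.release(conn)

    asyncio.run(main())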
+
+class Sentinel(AsyncSentinelCommands):
+ """
+ Redis Sentinel cluster client
+
+ >>> from redis.asyncio.sentinel import Sentinel
+ >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
+ >>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
+ >>> await master.set('foo', 'bar')
+ >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
+ >>> await slave.get('foo')
+ b'bar'
+
+ ``sentinels`` is a list of sentinel nodes. Each node is represented by
+ a pair (hostname, port).
+
+ ``min_other_sentinels`` defines a minimum number of peers for a sentinel.
+ When querying a sentinel, if it doesn't meet this threshold, responses
+ from that sentinel won't be considered valid.
+
+ ``sentinel_kwargs`` is a dictionary of connection arguments used when
+ connecting to sentinel instances. Any argument that can be passed to
+ a normal Redis connection can be specified here. If ``sentinel_kwargs`` is
+ not specified, any ``socket_*`` options specified in ``connection_kwargs``
+ will be used.
+
+ ``connection_kwargs`` are keyword arguments that will be used when
+ establishing a connection to a Redis server.
+ """
+
+ def __init__(
+ self,
+ sentinels,
+ min_other_sentinels=0,
+ sentinel_kwargs=None,
+ **connection_kwargs,
+ ):
+ # if sentinel_kwargs isn't defined, use the socket_* options from
+ # connection_kwargs
+ if sentinel_kwargs is None:
+ sentinel_kwargs = {
+ k: v for k, v in connection_kwargs.items() if k.startswith("socket_")
+ }
+ self.sentinel_kwargs = sentinel_kwargs
+
+ self.sentinels = [
+ Redis(host=hostname, port=port, **self.sentinel_kwargs)
+ for hostname, port in sentinels
+ ]
+ self.min_other_sentinels = min_other_sentinels
+ self.connection_kwargs = connection_kwargs
+
+ async def execute_command(self, *args, **kwargs):
+ """
+ Execute a Sentinel command on the sentinel nodes.
+ once - If set to True, then execute the resulting command on a single
+ node at random, rather than across the entire sentinel cluster.
+ """
+ once = bool(kwargs.pop("once", False))
+
+ if once:
+ # run the command on a single, randomly chosen sentinel
+ await random.choice(self.sentinels).execute_command(*args, **kwargs)
+ else:
+ # run the command on every sentinel in parallel
+ await asyncio.gather(
+ *(sentinel.execute_command(*args, **kwargs) for sentinel in self.sentinels)
+ )
+ return True
+
+ def __repr__(self):
+ sentinel_addresses = []
+ for sentinel in self.sentinels:
+ sentinel_addresses.append(
+ f"{sentinel.connection_pool.connection_kwargs['host']}:"
+ f"{sentinel.connection_pool.connection_kwargs['port']}"
+ )
+ return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>"
+
+ def check_master_state(self, state: dict, service_name: str) -> bool:
+ if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
+ return False
+ # Check that our sentinel sees enough other sentinels
+ if state["num-other-sentinels"] < self.min_other_sentinels:
+ return False
+ return True
+
+ async def discover_master(self, service_name: str):
+ """
+ Asks sentinel servers for the Redis master's address corresponding
+ to the service labeled ``service_name``.
+
+ Returns a pair (address, port) or raises MasterNotFoundError if no
+ master is found.
+ """
+ for sentinel_no, sentinel in enumerate(self.sentinels):
+ try:
+ masters = await sentinel.sentinel_masters()
+ except (ConnectionError, TimeoutError):
+ continue
+ state = masters.get(service_name)
+ if state and self.check_master_state(state, service_name):
+ # Put this sentinel at the top of the list
+ self.sentinels[0], self.sentinels[sentinel_no] = (
+ sentinel,
+ self.sentinels[0],
+ )
+ return state["ip"], state["port"]
+ raise MasterNotFoundError(f"No master found for {service_name!r}")
+
+ def filter_slaves(
+ self, slaves: Iterable[Mapping]
+ ) -> Sequence[Tuple[EncodableT, EncodableT]]:
+ """Remove slaves that are in an ODOWN or SDOWN state"""
+ slaves_alive = []
+ for slave in slaves:
+ if slave["is_odown"] or slave["is_sdown"]:
+ continue
+ slaves_alive.append((slave["ip"], slave["port"]))
+ return slaves_alive
+
+ async def discover_slaves(
+ self, service_name: str
+ ) -> Sequence[Tuple[EncodableT, EncodableT]]:
+ """Returns a list of alive slaves for service ``service_name``"""
+ for sentinel in self.sentinels:
+ try:
+ slaves = await sentinel.sentinel_slaves(service_name)
+ except (ConnectionError, ResponseError, TimeoutError):
+ continue
+ slaves = self.filter_slaves(slaves)
+ if slaves:
+ return slaves
+ return []
+
+ def master_for(
+ self,
+ service_name: str,
+ redis_class: Type[Redis] = Redis,
+ connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
+ **kwargs,
+ ):
+ """
+ Returns a redis client instance for the ``service_name`` master.
+
+ A :py:class:`~redis.sentinel.SentinelConnectionPool` class is
+ used to retrieve the master's address before establishing a new
+ connection.
+
+ NOTE: If the master's address has changed, any cached connections to
+ the old master are closed.
+
+ By default clients will be a :py:class:`~redis.Redis` instance.
+ Specify a different class to the ``redis_class`` argument if you
+ desire something different.
+
+ The ``connection_pool_class`` specifies the connection pool to
+ use. The :py:class:`~redis.sentinel.SentinelConnectionPool`
+ will be used by default.
+
+ All other keyword arguments are merged with any connection_kwargs
+ passed to this class and passed to the connection pool as keyword
+ arguments to be used to initialize Redis connections.
+ """
+ kwargs["is_master"] = True
+ connection_kwargs = dict(self.connection_kwargs)
+ connection_kwargs.update(kwargs)
+ return redis_class(
+ connection_pool=connection_pool_class(
+ service_name, self, **connection_kwargs
+ )
+ )
+
+ def slave_for(
+ self,
+ service_name: str,
+ redis_class: Type[Redis] = Redis,
+ connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
+ **kwargs,
+ ):
+ """
+ Returns a redis client instance for the ``service_name`` slave(s).
+
+ A SentinelConnectionPool class is used to retrieve the slave's
+ address before establishing a new connection.
+
+ By default clients will be a :py:class:`~redis.Redis` instance.
+ Specify a different class to the ``redis_class`` argument if you
+ desire something different.
+
+ The ``connection_pool_class`` specifies the connection pool to use.
+ The SentinelConnectionPool will be used by default.
+
+ All other keyword arguments are merged with any connection_kwargs
+ passed to this class and passed to the connection pool as keyword
+ arguments to be used to initialize Redis connections.
+ """
+ kwargs["is_master"] = False
+ connection_kwargs = dict(self.connection_kwargs)
+ connection_kwargs.update(kwargs)
+ return redis_class(
+ connection_pool=connection_pool_class(
+ service_name, self, **connection_kwargs
+ )
+ )
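
End to end, the classes above are typically used like this (a minimal sketch assuming a sentinel quorum at localhost:26379 monitoring a service named 'mymaster', matching the doctest in the class docstring):

    import asyncio

    from redis.asyncio.sentinel import Sentinel

    async def main():
        sentinel = Sentinel([("localhost", 26379)], socket_timeout=0.5)
        print(await sentinel.discover_master("mymaster"))  # e.g. ('127.0.0.1', 6379)
        master = sentinel.master_for("mymaster", socket_timeout=0.5)
        replica = sentinel.slave_for("mymaster", socket_timeout=0.5)
        await master.set("foo", "bar")
        print(await replica.get("foo"))  # b'bar'

    asyncio.run(main())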
diff --git a/redis/asyncio/utils.py b/redis/asyncio/utils.py
new file mode 100644
index 0000000..5a55b36
--- /dev/null
+++ b/redis/asyncio/utils.py
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from redis.asyncio.client import Pipeline, Redis
+
+
+def from_url(url, **kwargs):
+ """
+ Returns an active Redis client generated from the given database URL.
+
+ Will attempt to extract the database id from the URL's path component
+ if one is not otherwise specified.
+ """
+ from redis.asyncio.client import Redis
+
+ return Redis.from_url(url, **kwargs)
+
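
A quick sketch of the helper in use (the URL and options are illustrative):

    import asyncio

    from redis.asyncio import from_url

    async def main():
        # database id 0 comes from the URL path; decode_responses is optional
        client = from_url("redis://localhost:6379/0", decode_responses=True)
        await client.set("greeting", "hello")
        print(await client.get("greeting"))  # "hello"

    asyncio.run(main())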
+
+class pipeline:
+ """Async context manager that collects commands and executes them on exit."""
+
+ def __init__(self, redis_obj: "Redis"):
+ self.p: "Pipeline" = redis_obj.pipeline()
+
+ async def __aenter__(self) -> "Pipeline":
+ return self.p
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ # execute everything queued on the pipeline when the block exits,
+ # then drop the reference so the pipeline cannot be reused
+ await self.p.execute()
+ del self.p
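
And a sketch of the `pipeline` helper above; note that the queued commands are only sent when the `async with` block exits (a local Redis server is assumed):

    import asyncio

    from redis.asyncio import Redis
    from redis.asyncio.utils import pipeline

    async def main():
        client = Redis()
        async with pipeline(client) as pipe:
            # commands are buffered on the pipeline, not sent yet
            pipe.set("counter", 1)
            pipe.incr("counter")
        # the pipeline was executed on exit; read the result afterwards
        print(await client.get("counter"))  # b'2'

    asyncio.run(main())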