Diffstat (limited to 'redis/asyncio')
-rw-r--r-- | redis/asyncio/__init__.py   |   58
-rw-r--r-- | redis/asyncio/client.py     | 1344
-rw-r--r-- | redis/asyncio/connection.py | 1707
-rw-r--r-- | redis/asyncio/lock.py       |  311
-rw-r--r-- | redis/asyncio/retry.py      |   58
-rw-r--r-- | redis/asyncio/sentinel.py   |  353
-rw-r--r-- | redis/asyncio/utils.py      |   28
7 files changed, 3859 insertions, 0 deletions
diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py new file mode 100644 index 0000000..c655c7d --- /dev/null +++ b/redis/asyncio/__init__.py @@ -0,0 +1,58 @@ +from redis.asyncio.client import Redis, StrictRedis +from redis.asyncio.connection import ( + BlockingConnectionPool, + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.asyncio.sentinel import ( + Sentinel, + SentinelConnectionPool, + SentinelManagedConnection, + SentinelManagedSSLConnection, +) +from redis.asyncio.utils import from_url +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + InvalidResponse, + PubSubError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) + +__all__ = [ + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", + "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", + "Connection", + "ConnectionError", + "ConnectionPool", + "DataError", + "from_url", + "InvalidResponse", + "PubSubError", + "ReadOnlyError", + "Redis", + "RedisError", + "ResponseError", + "Sentinel", + "SentinelConnectionPool", + "SentinelManagedConnection", + "SentinelManagedSSLConnection", + "SSLConnection", + "StrictRedis", + "TimeoutError", + "UnixDomainSocketConnection", + "WatchError", +] diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py new file mode 100644 index 0000000..619592e --- /dev/null +++ b/redis/asyncio/client.py @@ -0,0 +1,1344 @@ +import asyncio +import copy +import inspect +import re +import warnings +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from redis.client import ( + EMPTY_RESPONSE, + NEVER_DECODE, + AbstractRedis, + CaseInsensitiveDict, + bool_ok, +) +from redis.commands import ( + AsyncCoreCommands, + AsyncSentinelCommands, + RedisModuleCommands, + list_or_args, +) +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + ConnectionError, + ExecAbortError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from redis.lock import Lock +from redis.retry import Retry +from redis.typing import ChannelT, EncodableT, KeyT +from redis.utils import safe_str, str_if_bytes + +PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] +_KeyT = TypeVar("_KeyT", bound=KeyT) +_ArgT = TypeVar("_ArgT", KeyT, EncodableT) +_RedisT = TypeVar("_RedisT", bound="Redis") +_NormalizeKeysT = TypeVar("_NormalizeKeysT", bound=Mapping[ChannelT, object]) +if TYPE_CHECKING: + from redis.commands.core import Script + + +class ResponseCallbackProtocol(Protocol): + def __call__(self, response: Any, **kwargs): + ... + + +class AsyncResponseCallbackProtocol(Protocol): + async def __call__(self, response: Any, **kwargs): + ... + + +ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] + + +class Redis( + AbstractRedis, RedisModuleCommands, AsyncCoreCommands, AsyncSentinelCommands +): + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. 
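A minimal quick-start sketch of the async client this commit adds, assuming a Redis server on localhost:6379; the key and value are illustrative and this snippet is not part of the diff:

    import asyncio

    import redis.asyncio as redis


    async def main():
        # decode_responses=True returns str instead of bytes from GET
        client = redis.Redis(host="localhost", port=6379, decode_responses=True)
        await client.set("greeting", "hello")   # illustrative key/value
        print(await client.get("greeting"))     # -> "hello"
        await client.close()                    # release the connection/pool


    asyncio.run(main())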
+ + Pipelines derive from this, implementing how + the commands are sent and received to the Redis server. Based on + configuration, an instance will either use a ConnectionPool, or + Connection object to talk to redis. + """ + + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT] + + @classmethod + def from_url(cls, url: str, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + connection_pool = ConnectionPool.from_url(url, **kwargs) + return cls(connection_pool=connection_pool) + + def __init__( + self, + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: Optional[bool] = None, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + connection_pool: Optional[ConnectionPool] = None, + unix_socket_path: Optional[str] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + ssl: bool = False, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + max_connections: Optional[int] = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + auto_close_connection_pool: bool = True, + ): + """ + Initialize a new Redis client. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + kwargs: Dict[str, Any] + # auto_close_connection_pool only has an effect if connection_pool is + # None. 
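A sketch of the three URL forms described in the from_url docstring above; credentials are illustrative and the socket path is the placeholder from the docstring:

    import redis.asyncio as redis

    # db can come from a ?db= querystring, the URL path, or a keyword argument
    r_tcp = redis.Redis.from_url("redis://localhost:6379/0")
    r_tls = redis.Redis.from_url("rediss://user:secret@localhost:6379/0")  # illustrative credentials
    r_unix = redis.Redis.from_url("unix:///path/to/socket.sock?db=0")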
This is a similar feature to the missing __del__ to resolve #1103, + # but it accounts for whether a user wants to manually close the connection + # pool, as a similar feature to ConnectionPool's __del__. + self.auto_close_connection_pool = ( + auto_close_connection_pool if connection_pool is None else False + ) + if not connection_pool: + kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "retry": copy.deepcopy(retry), + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + } + # based on input, setup appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_check_hostname": ssl_check_hostname, + } + ) + connection_pool = ConnectionPool(**kwargs) + self.connection_pool = connection_pool + self.single_connection_client = single_connection_client + self.connection: Optional[Connection] = None + + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + + def __repr__(self): + return f"{self.__class__.__name__}<{self.connection_pool!r}>" + + def __await__(self): + return self.initialize().__await__() + + async def initialize(self: _RedisT) -> _RedisT: + if self.single_connection_client and self.connection is None: + self.connection = await self.connection_pool.get_connection("_") + return self + + def set_response_callback(self, command: str, callback: ResponseCallbackT): + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def load_external_module( + self, + funcname, + func, + ): + """ + This function can be used to add externally defined redis modules, + and their namespaces to the redis client. + + funcname - A string containing the name of the function to create + func - The function, being added to this class. + + ex: Assume that one has a custom redis module named foomod that + creates command named 'foo.dothing' and 'foo.anotherthing' in redis. + To load function functions into this namespace: + + from redis import Redis + from foomodule import F + r = Redis() + r.load_external_module("foo", F) + r.foo().dothing('your', 'arguments') + + For a concrete example see the reimport of the redisjson module in + tests/test_connection.py::test_loading_external_modules + """ + setattr(self, funcname, func) + + def pipeline( + self, transaction: bool = True, shard_hint: Optional[str] = None + ) -> "Pipeline": + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. 
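A short sketch of the batching the pipeline() docstring above describes: queued commands travel to the server in a single round trip (key names illustrative, not part of the diff):

    import asyncio

    import redis.asyncio as redis


    async def main():
        client = redis.Redis()
        async with client.pipeline(transaction=True) as pipe:
            # nothing is sent yet; commands are buffered on the pipeline
            pipe.set("page:home:hits", 1)
            pipe.incr("page:home:hits")
            pipe.expire("page:home:hits", 3600)
            results = await pipe.execute()  # one request, wrapped in MULTI/EXEC
        print(results)  # e.g. [True, 2, True]
        await client.close()


    asyncio.run(main())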
+ """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + pipe: Pipeline + async with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + await pipe.watch(*watches) + func_value = func(pipe) + if inspect.isawaitable(func_value): + func_value = await func_value + exec_value = await pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + await asyncio.sleep(watch_delay) + continue + + def lock( + self, + name: KeyT, + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking_timeout: Optional[float] = None, + lock_class: Optional[Type[Lock]] = None, + thread_local: bool = True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs) -> "PubSub": + """ + Return a Publish/Subscribe object. 
With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self) -> "Monitor": + return Monitor(self.connection_pool) + + def client(self) -> "Redis": + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + async def __aenter__(self: _RedisT) -> _RedisT: + return await self.initialize() + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + _DEL_MESSAGE = "Unclosed Redis client" + + def __del__(self, _warnings: Any = warnings) -> None: + if self.connection is not None: + _warnings.warn( + f"Unclosed client session {self!r}", + ResourceWarning, + source=self, + ) + context = {"client": self, "message": self._DEL_MESSAGE} + asyncio.get_event_loop().call_exception_handler(context) + + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Closes Redis client connection + + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection + pool. + """ + conn = self.connection + if conn: + self.connection = None + await self.connection_pool.release(conn) + if close_connection_pool or ( + close_connection_pool is None and self.auto_close_connection_pool + ): + await self.connection_pool.disconnect() + + async def _send_command_parse_response(self, conn, command_name, *args, **options): + """ + Send a command and parse the response + """ + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + + async def _disconnect_raise(self, conn: Connection, error: Exception): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + + try: + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + finally: + if not self.connection: + await pool.release(conn) + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + """Parses a response from the Redis server""" + try: + if NEVER_DECODE in options: + response = await connection.read_response(disable_encoding=True) + else: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + if command_name in self.response_callbacks: + # Mypy bug: https://github.com/python/mypy/issues/10977 + command_name = cast(str, command_name) + retval = self.response_callbacks[command_name](response, **options) + return await retval if inspect.isawaitable(retval) else retval + return response + + +StrictRedis = Redis + + +class MonitorCommandInfo(TypedDict): + time: float + db: int + client_address: str + client_port: str + client_type: str + command: str + + +class Monitor: + """ + Monitor is 
useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + command_re = re.compile(r'"(.*?)(?<!\\)"') + + def __init__(self, connection_pool: ConnectionPool): + self.connection_pool = connection_pool + self.connection: Optional[Connection] = None + + async def connect(self): + if self.connection is None: + self.connection = await self.connection_pool.get_connection("MONITOR") + + async def __aenter__(self): + await self.connect() + await self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = await self.connection.read_response() + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + return self + + async def __aexit__(self, *args): + await self.connection.disconnect() + await self.connection_pool.release(self.connection) + + async def next_command(self) -> MonitorCommandInfo: + """Parse the response from a monitor command""" + await self.connect() + response = await self.connection.read_response() + if isinstance(response, bytes): + response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) + m = self.monitor_re.match(command_data) + db_id, client_info, command = m.groups() + command = " ".join(self.command_re.findall(command)) + # Redis escapes double quotes because each piece of the command + # string is surrounded by double quotes. We don't have that + # requirement so remove the escaping and leave the quote. + command = command.replace('\\"', '"') + + if client_info == "lua": + client_address = "lua" + client_port = "" + client_type = "lua" + elif client_info.startswith("unix"): + client_address = "unix" + client_port = client_info[5:] + client_type = "unix" + else: + # use rsplit as ipv6 addresses contain colons + client_address, client_port = client_info.rsplit(":", 1) + client_type = "tcp" + return { + "time": float(command_time), + "db": int(db_id), + "client_address": client_address, + "client_port": client_port, + "client_type": client_type, + "command": command, + } + + async def listen(self) -> AsyncIterator[MonitorCommandInfo]: + """Listen for commands coming to the server.""" + while True: + yield await self.next_command() + + +class PubSub: + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + + PUBLISH_MESSAGE_TYPES = ("message", "pmessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + HEALTH_CHECK_MESSAGE = "redis-py-health-check" + + def __init__( + self, + connection_pool: ConnectionPool, + shard_hint: Optional[str] = None, + ignore_subscribe_messages: bool = False, + encoder=None, + ): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.ignore_subscribe_messages = ignore_subscribe_messages + self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. 
+ self.encoder = encoder + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response: Iterable[Union[str, bytes]] = [ + "pong", + self.HEALTH_CHECK_MESSAGE, + ] + else: + self.health_check_response = [ + b"pong", + self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + ] + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self._lock = asyncio.Lock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __del__(self): + if self.connection: + self.connection.clear_connect_callbacks() + + async def reset(self): + async with self._lock: + if self.connection: + await self.connection.disconnect() + self.connection.clear_connect_callbacks() + await self.connection_pool.release(self.connection) + self.connection = None + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + + def close(self) -> Awaitable[NoReturn]: + return self.reset() + + async def on_connect(self, connection: Connection): + """Re-subscribe to any channels and patterns previously subscribed to""" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + if self.channels: + channels = {} + for k, v in self.channels.items(): + channels[self.encoder.decode(k, force=True)] = v + await self.subscribe(**channels) + if self.patterns: + patterns = {} + for k, v in self.patterns.items(): + patterns[self.encoder.decode(k, force=True)] = v + await self.psubscribe(**patterns) + + @property + def subscribed(self): + """Indicates if there are subscriptions to any channels or patterns""" + return bool(self.channels or self.patterns) + + async def execute_command(self, *args: EncodableT): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def _disconnect_raise_connect(self, conn, error): + """ + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect + """ + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error + await conn.connect() + + async def _execute(self, conn, command, *args, **kwargs): + """ + Connect manually upon disconnection. If the Redis server is down, + this will fail and raise a ConnectionError as desired. 
+ After reconnection, the ``on_connect`` callback should have been + called by the # connection to resubscribe us to any channels and + patterns we were previously listening to + """ + return await conn.retry.call_with_retry( + lambda: command(*args, **kwargs), + lambda error: self._disconnect_raise_connect(conn, error), + ) + + async def parse_response(self, block: bool = True, timeout: float = 0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + await self.check_health() + + if not block and not await self._execute(conn, conn.can_read, timeout=timeout): + return None + response = await self._execute(conn, conn.read_response) + + if conn.health_check_interval and response == self.health_check_response: + # ignore the health check message as user might not expect it + return None + return response + + async def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if ( + conn.health_check_interval + and asyncio.get_event_loop().time() > conn.next_health_check + ): + await conn.send_command( + "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + ) + + def _normalize_keys(self, data: _NormalizeKeysT) -> _NormalizeKeysT: + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} # type: ignore[return-value] # noqa: E501 + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else args + new_patterns: Dict[ChannelT, PubSubHandler] = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_patterns.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args: ChannelT) -> Awaitable: + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + patterns: Iterable[ChannelT] + if args: + parsed_args = list_or_args((args[0],), args[1:]) + patterns = self._normalize_keys(dict.fromkeys(parsed_args)).keys() + else: + parsed_args = [] + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *parsed_args) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. 
A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + parsed_args = list_or_args((args[0],), args[1:]) if args else () + new_channels = dict.fromkeys(parsed_args) + # Mypy bug: https://github.com/python/mypy/issues/10970 + new_channels.update(kwargs) # type: ignore[arg-type] + ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args) -> Awaitable: + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + parsed_args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(parsed_args)) + else: + parsed_args = [] + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *parsed_args) + + async def listen(self) -> AsyncIterator: + """Listen for messages on channels this client has been subscribed to""" + while self.subscribed: + response = await self.handle_message(await self.parse_response(block=True)) + if response is not None: + yield response + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number. + """ + response = await self.parse_response(block=False, timeout=timeout) + if response: + return await self.handle_message(response, ignore_subscribe_messages) + return None + + def ping(self, message=None) -> Awaitable: + """ + Ping the Redis server + """ + message = "" if message is None else message + return self.execute_command("PING", message) + + async def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. 
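A sketch of the handler-based flow described by subscribe() and handle_message() above; the channel name and handler are illustrative:

    import asyncio

    import redis.asyncio as redis


    async def on_news(message):
        # message is the dict built by handle_message(): type/pattern/channel/data
        print("received:", message["data"])


    async def main():
        client = redis.Redis()
        pubsub = client.pubsub()
        await pubsub.subscribe(news=on_news)        # handler registered for channel "news"
        reader = asyncio.create_task(pubsub.run())  # dispatch messages to handlers
        await client.publish("news", "hello")       # illustrative publish
        await asyncio.sleep(1)
        reader.cancel()
        try:
            await reader
        except asyncio.CancelledError:
            pass
        await pubsub.close()
        await client.close()


    asyncio.run(main())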
+ """ + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + if inspect.iscoroutinefunction(handler): + await handler(message) + else: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. ignore if we don't + # want them + if ignore_subscribe_messages or self.ignore_subscribe_messages: + return None + + return message + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + for channel, handler in self.channels.items(): + if handler is None: + raise PubSubError(f"Channel: '{channel}' has no handler registered") + for pattern, handler in self.patterns.items(): + if handler is None: + raise PubSubError(f"Pattern: '{pattern}' has no handler registered") + + while True: + try: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + except asyncio.CancelledError: + raise + except BaseException as e: + if exception_handler is None: + raise + res = exception_handler(e, self) + if inspect.isawaitable(res): + await res + # Ensure that other tasks on the event loop get a chance to run + # if we didn't have to block for I/O anywhere. + await asyncio.sleep(0) + + +class PubsubWorkerExceptionHandler(Protocol): + def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +class AsyncPubsubWorkerExceptionHandler(Protocol): + async def __call__(self, e: BaseException, pubsub: PubSub): + ... + + +PSWorkerThreadExcHandlerT = Union[ + PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler +] + + +CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]] +CommandStackT = List[CommandT] + + +class Pipeline(Redis): # lgtm [py/init-calls-subclass] + """ + Pipelines provide a way to transmit multiple commands to the Redis server + in one transmission. This is convenient for batch processing, such as + saving all the values in a list to Redis. + + All commands executed within a pipeline are wrapped with MULTI and EXEC + calls. 
This guarantees all commands executed in the pipeline will be + executed atomically. + + Any command raising an exception does *not* halt the execution of + subsequent commands in the pipeline. Instead, the exception is caught + and its instance is placed into the response list returned by execute(). + Code iterating over the response list should be able to deal with an + instance of an exception as a potential value. In general, these will be + ResponseError exceptions, such as those raised when issuing a command + on a key of a different datatype. + """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT], + transaction: bool, + shard_hint: Optional[str], + ): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.is_transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.command_stack: CommandStackT = [] + self.scripts: Set["Script"] = set() + self.explicit_transaction = False + + async def __aenter__(self: _RedisT) -> _RedisT: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __await__(self): + return self._async_self().__await__() + + _DEL_MESSAGE = "Unclosed Pipeline client" + + def __len__(self): + return len(self.command_stack) + + def __bool__(self): + """Pipeline instances should always evaluate to True""" + return True + + async def _async_self(self): + return self + + async def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self.connection.send_command("UNWATCH") + await self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + if self.connection: + await self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + await self.connection_pool.release(self.connection) + self.connection = None + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already " "been issued" + ) + self.explicit_transaction = True + + def execute_command( + self, *args, **kwargs + ) -> Union["Pipeline", Awaitable["Pipeline"]]: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + async def _disconnect_reset_raise(self, conn, error): + """ + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. 
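A sketch of the error-in-results behavior the Pipeline docstring describes: with raise_on_error=False the exception instance is returned in place of a value (keys illustrative):

    import asyncio

    import redis.asyncio as redis


    async def main():
        client = redis.Redis()
        async with client.pipeline(transaction=True) as pipe:
            pipe.set("counter", "not-a-number")  # illustrative
            pipe.incr("counter")                 # fails: value is not an integer
            results = await pipe.execute(raise_on_error=False)
        for result in results:
            if isinstance(result, Exception):    # errors come back in-place
                print("command failed:", result)
        await client.close()


    asyncio.run(main())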
raise a WatchError, which + # indicates the user should retry this transaction. + if self.watching: + await self.reset() + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = await self.connection_pool.get_connection( + command_name, self.shard_hint + ) + self.connection = conn + + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise(conn, error), + ) + + def pipeline_execute_command(self, *args, **options): + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. + """ + self.command_stack.append((args, options)) + return self + + async def _execute_transaction( # noqa: C901 + self, connection: Connection, commands: CommandStackT, raise_on_error + ): + pre: CommandT = (("MULTI",), {}) + post: CommandT = (("EXEC",), {}) + cmds = (pre, *commands, post) + all_cmds = connection.pack_commands( + args for args, options in cmds if EMPTY_RESPONSE not in options + ) + await connection.send_packed_command(all_cmds) + errors = [] + + # parse off the response for MULTI + # NOTE: we need to handle ResponseErrors here and continue + # so that we read all the additional command messages from + # the socket + try: + await self.parse_response(connection, "_") + except ResponseError as err: + errors.append((0, err)) + + # and all the other commands + for i, command in enumerate(commands): + if EMPTY_RESPONSE in command[1]: + errors.append((i, command[1][EMPTY_RESPONSE])) + else: + try: + await self.parse_response(connection, "_") + except ResponseError as err: + self.annotate_exception(err, i + 1, command[0]) + errors.append((i, err)) + + # parse the EXEC. 
+ try: + response = await self.parse_response(connection, "_") + except ExecAbortError as err: + if errors: + raise errors[0][1] from err + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") from None + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + if self.connection: + await self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) from None + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + if inspect.isawaitable(r): + r = await r + data.append(r) + return data + + async def _execute_pipeline( + self, connection: Connection, commands: CommandStackT, raise_on_error: bool + ): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + await connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + await self.parse_response(connection, args[0], **options) + ) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception( + self, exception: Exception, number: int, command: Iterable[object] + ) -> None: + cmd = " ".join(map(safe_str, command)) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + exception.args = (msg,) + exception.args[1:] + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + result = await super().parse_response(connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + async def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. + exists = await immediate("SCRIPT EXISTS", *shas) + if not all(exists): + for s, exist in zip(scripts, exists): + if not exist: + s.sha = await immediate("SCRIPT LOAD", s.script) + + async def _disconnect_raise_reset(self, conn: Connection, error: Exception): + """ + Close the connection, raise an exception if we were watching, + and raise an exception if retry_on_timeout is not set, + or the error is not a TimeoutError + """ + await conn.disconnect() + # if we were watching a variable, the watch is no longer valid + # since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. 
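The WatchError retry referenced above, written out as the manual WATCH/MULTI pattern (key name illustrative; the transaction() helper shown earlier wraps the same loop):

    import asyncio

    import redis.asyncio as redis


    async def main():
        client = redis.Redis(decode_responses=True)
        async with client.pipeline() as pipe:
            while True:
                try:
                    await pipe.watch("balance")  # illustrative key
                    balance = int(await pipe.get("balance") or 0)
                    pipe.multi()
                    pipe.set("balance", balance + 1)
                    await pipe.execute()
                    break
                except redis.WatchError:
                    continue  # the watched key changed under us; retry
        await client.close()


    asyncio.run(main())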
+ if self.watching: + raise WatchError( + "A ConnectionError occurred on while " "watching one or more keys" + ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise + + async def execute(self, raise_on_error: bool = True): + """Execute all the commands in the current pipeline""" + stack = self.command_stack + if not stack and not self.watching: + return [] + if self.scripts: + await self.load_scripts() + if self.is_transaction or self.explicit_transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline + + conn = self.connection + if not conn: + conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) + # assign to self.connection so reset() releases the connection + # back to the pool after we're done + self.connection = conn + conn = cast(Connection, conn) + + try: + return await conn.retry.call_with_retry( + lambda: execute(conn, stack, raise_on_error), + lambda error: self._disconnect_raise_reset(conn, error), + ) + finally: + await self.reset() + + async def discard(self): + """Flushes all previously queued commands + See: https://redis.io/commands/DISCARD + """ + await self.execute_command("DISCARD") + + async def watch(self, *names: KeyT): + """Watches the values at keys ``names``""" + if self.explicit_transaction: + raise RedisError("Cannot issue a WATCH after a MULTI") + return await self.execute_command("WATCH", *names) + + async def unwatch(self): + """Unwatches all previously specified keys""" + return self.watching and await self.execute_command("UNWATCH") or True diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py new file mode 100644 index 0000000..ae54d81 --- /dev/null +++ b/redis/asyncio/connection.py @@ -0,0 +1,1707 @@ +import asyncio +import copy +import enum +import errno +import inspect +import io +import os +import socket +import ssl +import sys +import threading +import weakref +from itertools import chain +from types import MappingProxyType +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from urllib.parse import ParseResult, parse_qs, unquote, urlparse + +import async_timeout + +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.compat import Protocol, TypedDict +from redis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + ExecAbortError, + InvalidResponse, + ModuleError, + NoPermissionError, + NoScriptError, + ReadOnlyError, + RedisError, + ResponseError, + TimeoutError, +) +from redis.typing import EncodableT, EncodedT +from redis.utils import HIREDIS_AVAILABLE, str_if_bytes + +hiredis = None +if HIREDIS_AVAILABLE: + import hiredis + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { + BlockingIOError: errno.EWOULDBLOCK, + ssl.SSLWantReadError: 2, + ssl.SSLWantWriteError: 2, + ssl.SSLError: 2, +} + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + + +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_LF = b"\n" +SYM_EMPTY = b"" + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." + + +class _Sentinel(enum.Enum): + sentinel = object() + + +SENTINEL = _Sentinel.sentinel +MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." 
+NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class Encoder: + """Encode strings to bytes-like and decode bytes-like to strings""" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value: EncodableT) -> EncodedT: + """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, (bytes, memoryview)): + return value + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) + if isinstance(value, (int, float)): + return repr(value).encode() + if not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ # type: ignore[unreachable] + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) + return value.encode(self.encoding, self.encoding_errors) + + def decode(self, value: EncodableT, force=False) -> EncodableT: + """Return a unicode string from the bytes-like representation""" + if self.decode_responses or force: + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) + if isinstance(value, bytes): + return value.decode(self.encoding, self.encoding_errors) + return value + + +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + + +class BaseParser: + """Plain Python parsing class""" + + __slots__ = "_stream", "_buffer", "_read_size" + + EXCEPTION_CLASSES: ExceptionMappingT = { + "ERR": { + "max number of clients reached": ConnectionError, + "Client sent AUTH, but no password is set": AuthenticationError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + }, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + def __init__(self, socket_read_size: int): + self._stream: Optional[asyncio.StreamReader] = None + self._buffer: Optional[SocketBuffer] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def parse_error(self, response: str) -> ResponseError: + """Parse an error response""" + error_code = response.split(" ")[0] + if 
error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection: "Connection"): + raise NotImplementedError() + + async def can_read(self, timeout: float) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class SocketBuffer: + """Async-friendly re-impl of redis-py's SocketBuffer. + + TODO: We're currently passing through two buffers, + the asyncio.StreamReader and this. I imagine we can reduce the layers here + while maintaining compliance with prior art. + """ + + def __init__( + self, + stream_reader: asyncio.StreamReader, + socket_read_size: int, + socket_timeout: Optional[float], + ): + self._stream: Optional[asyncio.StreamReader] = stream_reader + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer: Optional[io.BytesIO] = io.BytesIO() + # number of bytes written to the buffer from the socket + self.bytes_written = 0 + # number of bytes read from the buffer + self.bytes_read = 0 + + @property + def length(self): + return self.bytes_written - self.bytes_read + + async def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, + ) -> bool: + buf = self._buffer + if buf is None or self._stream is None: + raise RedisError("Buffer is closed.") + buf.seek(self.bytes_written) + marker = 0 + timeout = timeout if timeout is not SENTINEL else self.socket_timeout + + try: + while True: + async with async_timeout.timeout(timeout): + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def can_read(self, timeout: float) -> bool: + return bool(self.length) or await self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + async def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # make sure we've read enough data from the socket + if length > self.length: + await self._read_from_socket(length - self.length) + + if self._buffer is None: + raise RedisError("Buffer is closed.") + + self._buffer.seek(self.bytes_read) + data = self._buffer.read(length) + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + async def readline(self) -> bytes: + buf = self._buffer + if buf is None: + raise RedisError("Buffer is closed.") + + buf.seek(self.bytes_read) + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + await self._read_from_socket() + buf.seek(self.bytes_read) + data = buf.readline() + + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + def purge(self): + if self._buffer is None: + raise RedisError("Buffer is closed.") + + self._buffer.seek(0) + self._buffer.truncate() + self.bytes_written = 0 + self.bytes_read = 0 + + def close(self): + try: + self.purge() + self._buffer.close() # type: ignore[union-attr] + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. 
+ pass + self._buffer = None + self._stream = None + + +class PythonParser(BaseParser): + """Plain Python parsing class""" + + __slots__ = BaseParser.__slots__ + ("encoder",) + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + + def on_connect(self, connection: "Connection"): + """Called when the stream connects""" + self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + + self._buffer = SocketBuffer( + self._stream, self._read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + """Called when the stream disconnects""" + if self._stream is not None: + self._stream = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + async def can_read(self, timeout: float): + return self._buffer and bool(await self._buffer.can_read(timeout)) + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._buffer or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + response: Any + byte, response = raw[:1], raw[1:] + + if byte not in (b"-", b"+", b":", b"$", b"*"): + raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
+ return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + response = int(response) + # bulk response + elif byte == b"$": + length = int(response) + if length == -1: + return None + response = await self._buffer.read(length) + # multi-bulk response + elif byte == b"*": + length = int(response) + if length == -1: + return None + response = [ + (await self.read_response(disable_decoding)) for _ in range(length) + ] + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class HiredisParser(BaseParser): + """Parser class for connections using Hiredis""" + + __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") + + _next_response: bool + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader: Optional[hiredis.Reader] = None + self._socket_timeout: Optional[float] = None + + def on_connect(self, connection: "Connection"): + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + self._socket_timeout = connection.socket_timeout + + def on_disconnect(self): + self._stream = None + self._reader = None + self._next_response = False + + async def can_read(self, timeout: float): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return await self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + async def read_from_socket( + self, + timeout: Union[float, None, _Sentinel] = SENTINEL, + raise_on_timeout: bool = True, + ): + if self._stream is None or self._reader is None: + raise RedisError("Parser already closed.") + + timeout = self._socket_timeout if timeout is SENTINEL else timeout + try: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not isinstance(buffer, bytes) or len(buffer) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") from None + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
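For reference, the five RESP type bytes handled above look like this on the wire. A self-contained sketch that decodes one complete reply held in memory (the real parsers read incrementally from the socket, so this is illustrative only):

    def decode_resp(payload: bytes):
        """Decode one complete RESP reply; returns (value, remaining_bytes)."""
        line, _, rest = payload.partition(b"\r\n")
        byte, body = line[:1], line[1:]
        if byte == b"+":              # simple string, e.g. b"+OK\r\n"
            return body, rest
        if byte == b"-":              # error, e.g. b"-ERR unknown command\r\n"
            return ValueError(body.decode()), rest  # stand-in for ResponseError
        if byte == b":":              # integer, e.g. b":1000\r\n"
            return int(body), rest
        if byte == b"$":              # bulk string, e.g. b"$6\r\nfoobar\r\n"
            length = int(body)
            if length == -1:          # null bulk string
                return None, rest
            return rest[:length], rest[length + 2:]
        if byte == b"*":              # array of `length` nested replies
            length = int(body)
            if length == -1:
                return None, rest
            items = []
            for _ in range(length):
                item, rest = decode_resp(rest)
                items.append(item)
            return items, rest
        raise ValueError(f"Protocol Error: {line!r}")

    print(decode_resp(b"*2\r\n$3\r\nfoo\r\n:42\r\n")[0])  # [b'foo', 42]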
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + if not self._stream or not self._reader: + self.on_disconnect() + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response: Union[ + EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] + ] + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + # cast as there won't be a ConnectionError here. + return cast(Union[EncodableT, List[EncodableT]], response) + + +DefaultParser: Type[Union[PythonParser, HiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = HiredisParser +else: + DefaultParser = PythonParser + + +class ConnectCallbackProtocol(Protocol): + def __call__(self, connection: "Connection"): + ... + + +class AsyncConnectCallbackProtocol(Protocol): + async def __call__(self, connection: "Connection"): + ... + + +ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] + + +class Connection: + """Manages TCP communication to and from a Redis server""" + + __slots__ = ( + "pid", + "host", + "port", + "db", + "username", + "client_name", + "password", + "socket_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_type", + "redis_connect_func", + "retry_on_timeout", + "health_check_interval", + "next_health_check", + "last_active_at", + "encoder", + "ssl_context", + "_reader", + "_writer", + "_parser", + "_connect_callbacks", + "_buffer_cutoff", + "_lock", + "_socket_read_size", + "__dict__", + ) + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + db: Union[str, int] = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + retry_on_timeout: bool = False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0, + client_name: Optional[str] = None, + username: Optional[str] = None, + retry: Optional[Retry] = None, + redis_connect_func: Optional[ConnectCallbackT] = None, + encoder_class: Type[Encoder] = Encoder, + ): + self.pid = os.getpid() + self.host = host + self.port = int(port) + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = 
socket_keepalive_options or {} + self.socket_type = socket_type + self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check: float = -1 + self.ssl_context: Optional[RedisSSLContext] = None + self.encoder = encoder_class(encoding, encoding_errors, decode_responses) + self.redis_connect_func = redis_connect_func + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] + self._buffer_cutoff = 6000 + self._lock = asyncio.Lock() + + def __repr__(self): + repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) + return f"{self.__class__.__name__}<{repr_args}>" + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def __del__(self): + try: + if self.is_connected: + loop = asyncio.get_event_loop() + coro = self.disconnect() + if loop.is_running(): + loop.create_task(coro) + else: + loop.run_until_complete(coro) + except Exception: + pass + + @property + def is_connected(self): + return bool(self._reader and self._writer) + + def register_connect_callback(self, callback): + self._connect_callbacks.append(weakref.WeakMethod(callback)) + + def clear_connect_callbacks(self): + self._connect_callbacks = [] + + def set_parser(self, parser_class): + """ + Creates a new instance of parser_class with socket size: + _socket_read_size and assigns it to the parser for the connection + :param parser_class: The required parser class + """ + self._parser = parser_class(socket_read_size=self._socket_read_size) + + async def connect(self): + """Connects to the Redis server if not already connected""" + if self.is_connected: + return + try: + await self._connect() + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + except Exception as exc: + raise ConnectionError(exc) from exc + + try: + if self.redis_connect_func is None: + # Use the default on_connect function + await self.on_connect() + else: + # Use the passed function redis_connect_func + await self.redis_connect_func(self) if asyncio.iscoroutinefunction( + self.redis_connect_func + ) else self.redis_connect_func(self) + except RedisError: + # clean up after any error in on_connect + await self.disconnect() + raise + + # run any user callbacks. 
 right now the only internal callback
+        # is for pubsub channel/pattern resubscription
+        for ref in self._connect_callbacks:
+            callback = ref()
+            task = callback(self)
+            if task and inspect.isawaitable(task):
+                await task
+
+    async def _connect(self):
+        """Create a TCP socket connection"""
+        async with async_timeout.timeout(self.socket_connect_timeout):
+            reader, writer = await asyncio.open_connection(
+                host=self.host,
+                port=self.port,
+                ssl=self.ssl_context.get() if self.ssl_context else None,
+            )
+        self._reader = reader
+        self._writer = writer
+        sock = writer.transport.get_extra_info("socket")
+        if sock is not None:
+            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            try:
+                # TCP_KEEPALIVE
+                if self.socket_keepalive:
+                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+                    for k, v in self.socket_keepalive_options.items():
+                        sock.setsockopt(socket.SOL_TCP, k, v)
+
+            except (OSError, TypeError):
+                # `socket_keepalive_options` might contain invalid options
+                # causing an error. Do not leave the connection open.
+                writer.close()
+                raise
+
+    def _error_message(self, exception):
+        # args for socket.error can either be (errno, "message")
+        # or just "message"
+        if len(exception.args) == 1:
+            return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}."
+        else:
+            return (
+                f"Error {exception.args[0]} connecting to {self.host}:{self.port}. "
+                f"{exception.args[1]}."
+            )
+
+    async def on_connect(self):
+        """Initialize the connection, authenticate and select a database"""
+        self._parser.on_connect(self)
+
+        # if username and/or password are set, authenticate
+        if self.username or self.password:
+            auth_args: Union[Tuple[str], Tuple[str, str]]
+            if self.username:
+                auth_args = (self.username, self.password or "")
+            else:
+                # Mypy bug: https://github.com/python/mypy/issues/10944
+                auth_args = (self.password or "",)
+            # avoid checking health here -- PING will fail if we try
+            # to check the health prior to the AUTH
+            await self.send_command("AUTH", *auth_args, check_health=False)
+
+            try:
+                auth_response = await self.read_response()
+            except AuthenticationWrongNumberOfArgsError:
+                # a username and password were specified but the Redis
+                # server seems to be < 6.0.0 which expects a single password
+                # arg. retry auth with just the password.
+ # https://github.com/andymccurdy/redis-py/issues/1274 + await self.send_command("AUTH", self.password, check_health=False) + auth_response = await self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if a client_name is given, set it + if self.client_name: + await self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + # if a database is specified, switch to it + if self.db: + await self.send_command("SELECT", self.db) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + async def disconnect(self): + """Disconnects from the Redis server""" + try: + async with async_timeout.timeout(self.socket_connect_timeout): + self._parser.on_disconnect() + if not self.is_connected: + return + try: + if os.getpid() == self.pid: + self._writer.close() # type: ignore[union-attr] + # py3.6 doesn't have this method + if hasattr(self._writer, "wait_closed"): + await self._writer.wait_closed() # type: ignore[union-attr] + except OSError: + pass + self._reader = None + self._writer = None + except asyncio.TimeoutError: + raise TimeoutError( + f"Timed out closing connection after {self.socket_connect_timeout}" + ) from None + + async def _send_ping(self): + """Send PING, expect PONG in return""" + await self.send_command("PING", check_health=False) + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + + async def _ping_failed(self, error): + """Function to call when PING fails""" + await self.disconnect() + + async def check_health(self): + """Check the health of the connection with a PING/PONG""" + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + + if self.health_check_interval and func().time() > self.next_health_check: + await self.retry.call_with_retry(self._send_ping, self._ping_failed) + + async def _send_packed_command(self, command: Iterable[bytes]) -> None: + if self._writer is None: + raise RedisError("Connection already closed.") + + self._writer.writelines(command) + await self._writer.drain() + + async def send_packed_command( + self, + command: Union[bytes, str, Iterable[bytes]], + check_health: bool = True, + ): + """Send an already packed command to the Redis server""" + if not self._writer: + await self.connect() + # guard against health check recursion + if check_health: + await self.check_health() + try: + if isinstance(command, str): + command = command.encode() + if isinstance(command, bytes): + command = [command] + await asyncio.wait_for( + self._send_packed_command(command), + self.socket_timeout, + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError("Timeout writing to socket") from None + except OSError as e: + await self.disconnect() + if len(e.args) == 1: + err_no, errmsg = "UNKNOWN", e.args[0] + else: + err_no = e.args[0] + errmsg = e.args[1] + raise ConnectionError( + f"Error {err_no} while writing to socket. {errmsg}." 
+ ) from e + except BaseException: + await self.disconnect() + raise + + async def send_command(self, *args, **kwargs): + """Pack and send a command to the Redis server""" + if not self.is_connected: + await self.connect() + await self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) + + async def can_read(self, timeout: float = 0): + """Poll the socket to see if there's data that can be read.""" + if not self.is_connected: + await self.connect() + return await self._parser.can_read(timeout) + + async def read_response(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + async with self._lock: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + + def pack_command(self, *args: EncodableT) -> List[bytes]: + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. 
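`pack_command` frames a command as a RESP request and can be exercised on an unconnected `Connection` instance, so no server is needed (keys and values here are illustrative):

    from redis.asyncio.connection import Connection

    conn = Connection()  # never connected; only the encoder is used
    print(b"".join(conn.pack_command("SET", "foo", "bar")))
    # b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n'

    # Multi-word command names are split into separate RESP arguments:
    print(b"".join(conn.pack_command("CONFIG GET", "maxmemory")))
    # b'*3\r\n$6\r\nCONFIG\r\n$3\r\nGET\r\n$9\r\nmaxmemory\r\n'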
+ assert not isinstance(args[0], float) + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encoder.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]: + """Pack multiple commands into the Redis protocol""" + output: List[bytes] = [] + pieces: List[bytes] = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self.pack_command(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + +class SSLConnection(Connection): + def __init__( + self, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + ssl_check_hostname: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.ssl_context: RedisSSLContext = RedisSSLContext( + keyfile=ssl_keyfile, + certfile=ssl_certfile, + cert_reqs=ssl_cert_reqs, + ca_certs=ssl_ca_certs, + check_hostname=ssl_check_hostname, + ) + + @property + def keyfile(self): + return self.ssl_context.keyfile + + @property + def certfile(self): + return self.ssl_context.certfile + + @property + def cert_reqs(self): + return self.ssl_context.cert_reqs + + @property + def ca_certs(self): + return self.ssl_context.ca_certs + + @property + def check_hostname(self): + return self.ssl_context.check_hostname + + +class RedisSSLContext: + __slots__ = ( + "keyfile", + "certfile", + "cert_reqs", + "ca_certs", + "context", + "check_hostname", + ) + + def __init__( + self, + keyfile: Optional[str] = None, + certfile: Optional[str] = None, + cert_reqs: Optional[str] = None, + ca_certs: Optional[str] = None, + check_hostname: bool = False, + ): + self.keyfile = keyfile + self.certfile = certfile + if cert_reqs is None: + self.cert_reqs = ssl.CERT_NONE + elif isinstance(cert_reqs, str): + CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, + } + if cert_reqs not in CERT_REQS: + raise RedisError( + f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" + ) + self.cert_reqs = CERT_REQS[cert_reqs] + self.ca_certs = ca_certs + self.check_hostname = check_hostname + self.context: Optional[ssl.SSLContext] = None + + def get(self) -> ssl.SSLContext: + if not self.context: + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = 
self.cert_reqs + if self.certfile and self.keyfile: + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) + if self.ca_certs: + context.load_verify_locations(self.ca_certs) + self.context = context + return self.context + + +class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] + def __init__( + self, + *, + path: str = "", + db: Union[str, int] = 0, + username: Optional[str] = None, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0.0, + client_name: str = None, + retry: Optional[Retry] = None, + ): + """ + Initialize a new UnixDomainSocketConnection. + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + """ + self.pid = os.getpid() + self.path = path + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.retry_on_timeout = retry_on_timeout + if retry_on_timeout: + if retry is None: + self.retry = Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + else: + self.retry = Retry(NoBackoff(), 0) + self.health_check_interval = health_check_interval + self.next_health_check = -1 + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self._sock = None + self._reader = None + self._writer = None + self._socket_read_size = socket_read_size + self.set_parser(parser_class) + self._connect_callbacks = [] + self._buffer_cutoff = 6000 + self._lock = asyncio.Lock() + + def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: + pieces = [ + ("path", self.path), + ("db", self.db), + ] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + async def _connect(self): + async with async_timeout.timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_unix_connection(path=self.path) + self._reader = reader + self._writer = writer + await self.on_connect() + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{self.path}. {exception.args[1]}." 
+ ) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value) -> Optional[bool]: + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + } +) + + +class ConnectKwargs(TypedDict, total=False): + username: str + password: str + connection_class: Type[Connection] + host: str + port: int + db: int + path: str + + +def parse_url(url: str) -> ConnectKwargs: + parsed: ParseResult = urlparse(url) + kwargs: ConnectKwargs = {} + + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + # We can't type this. + kwargs[name] = parser(value) # type: ignore[misc] + except (TypeError, ValueError): + raise ValueError(f"Invalid value for `{name}` in connection URL.") + else: + kwargs[name] = value # type: ignore[misc] + + if parsed.username: + kwargs["username"] = unquote(parsed.username) + if parsed.password: + kwargs["password"] = unquote(parsed.password) + + # We only support redis://, rediss:// and unix:// schemes. + if parsed.scheme == "unix": + if parsed.path: + kwargs["path"] = unquote(parsed.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + elif parsed.scheme in ("redis", "rediss"): + if parsed.hostname: + kwargs["host"] = unquote(parsed.hostname) + if parsed.port: + kwargs["port"] = int(parsed.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if parsed.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(parsed.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if parsed.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + else: + valid_schemes = "redis://, rediss://, unix://" + raise ValueError( + f"Redis URL must specify one of the following schemes ({valid_schemes})" + ) + + return kwargs + + +_CP = TypeVar("_CP", bound="ConnectionPool") + + +class ConnectionPool: + """ + Create a connection pool. ``If max_connections`` is set, then this + object raises :py:class:`~redis.ConnectionError` when the pool's + limit is reached. + + By default, TCP connections are created unless ``connection_class`` + is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for + unix sockets. + + Any additional keyword arguments are passed to the constructor of + ``connection_class``. + """ + + @classmethod + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: + """ + Return a connection pool configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection. 
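`parse_url` above is what turns these URLs into connection keyword arguments; it can be called directly to see the mapping (values here are illustrative):

    from redis.asyncio.connection import parse_url

    print(parse_url("redis://user:secret@example.com:6380/2?socket_timeout=1.5"))
    # {'socket_timeout': 1.5, 'username': 'user', 'password': 'secret',
    #  'host': 'example.com', 'port': 6380, 'db': 2}

    print(parse_url("unix:///tmp/redis.sock?db=3"))
    # {'db': 3, 'path': '/tmp/redis.sock',
    #  'connection_class': <class 'redis.asyncio.connection.UnixDomainSocketConnection'>}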
+ + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + """ + url_options = parse_url(url) + kwargs.update(url_options) + return cls(**kwargs) + + def __init__( + self, + connection_class: Type[Connection] = Connection, + max_connections: Optional[int] = None, + **connection_kwargs, + ): + max_connections = max_connections or 2 ** 31 + if not isinstance(max_connections, int) or max_connections < 0: + raise ValueError('"max_connections" must be a positive integer') + + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections + + # a lock to protect the critical section in _checkpid(). + # this lock is acquired when the process id changes, such as + # after a fork. during this time, multiple threads in the child + # process could attempt to acquire this lock. the first thread + # to acquire the lock will reset the data structures and lock + # object of this pool. subsequent threads acquiring this lock + # will notice the first thread already did the work and simply + # release the lock. + self._fork_lock = threading.Lock() + self._lock = asyncio.Lock() + self._created_connections: int + self._available_connections: List[Connection] + self._in_use_connections: Set[Connection] + self.reset() # lgtm [py/init-calls-subclass] + self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<{self.connection_class(**self.connection_kwargs)!r}>" + ) + + def reset(self): + self._lock = asyncio.Lock() + self._created_connections = 0 + self._available_connections = [] + self._in_use_connections = set() + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def _checkpid(self): + # _checkpid() attempts to keep ConnectionPool fork-safe on modern + # systems. this is called by all ConnectionPool methods that + # manipulate the pool's state such as get_connection() and release(). 
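Putting ``from_url`` and the constructor together, a typical setup looks like the sketch below (assumes a server on localhost; the client accepts a pre-built pool via ``connection_pool=``):

    import asyncio
    from redis.asyncio import ConnectionPool, Redis

    async def main():
        # Build the pool explicitly...
        pool = ConnectionPool(host="localhost", port=6379, db=0, max_connections=20)
        # ...or parse it from a URL (querystring arguments win over keyword arguments).
        pool = ConnectionPool.from_url("redis://localhost:6379/0", max_connections=20)

        client = Redis(connection_pool=pool)
        await client.set("greeting", "hello")
        print(await client.get("greeting"))  # b'hello'
        await pool.disconnect()

    asyncio.run(main())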
+ # + # _checkpid() determines whether the process has forked by comparing + # the current process id to the process id saved on the ConnectionPool + # instance. if these values are the same, _checkpid() simply returns. + # + # when the process ids differ, _checkpid() assumes that the process + # has forked and that we're now running in the child process. the child + # process cannot use the parent's file descriptors (e.g., sockets). + # therefore, when _checkpid() sees the process id change, it calls + # reset() in order to reinitialize the child's ConnectionPool. this + # will cause the child to make all new connection objects. + # + # _checkpid() is protected by self._fork_lock to ensure that multiple + # threads in the child process do not call reset() multiple times. + # + # there is an extremely small chance this could fail in the following + # scenario: + # 1. process A calls _checkpid() for the first time and acquires + # self._fork_lock. + # 2. while holding self._fork_lock, process A forks (the fork() + # could happen in a different thread owned by process A) + # 3. process B (the forked child process) inherits the + # ConnectionPool's state from the parent. that state includes + # a locked _fork_lock. process B will not be notified when + # process A releases the _fork_lock and will thus never be + # able to acquire the _fork_lock. + # + # to mitigate this possible deadlock, _checkpid() will only wait 5 + # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in + # that time it is assumed that the child is deadlocked and a + # redis.ChildDeadlockedError error is raised. + if self.pid != os.getpid(): + acquired = self._fork_lock.acquire(timeout=5) + if not acquired: + raise ChildDeadlockedError + # reset() the instance for the new process if another thread + # hasn't already done so + try: + if self.pid != os.getpid(): + self.reset() + finally: + self._fork_lock.release() + + async def get_connection(self, command_name, *keys, **options): + """Get a connection from the pool""" + self._checkpid() + async with self._lock: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. 
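For code that needs a raw connection instead of the client API, the checkout/verify/release cycle described above looks roughly like this (low-level usage; assumes a local server):

    import asyncio
    from redis.asyncio import ConnectionPool

    async def main():
        pool = ConnectionPool(host="localhost", port=6379)
        conn = await pool.get_connection("PING")  # the command name is informational
        try:
            await conn.send_command("PING")
            print(await conn.read_response())  # b'PONG'
        finally:
            await pool.release(conn)  # always hand the connection back
        await pool.disconnect()

    asyncio.run(main())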
+ try: + if await connection.can_read(): + raise ConnectionError("Connection has data") from None + except ConnectionError: + await connection.disconnect() + await connection.connect() + if await connection.can_read(): + raise ConnectionError("Connection not ready") from None + except BaseException: + # release the connection back to the pool so that we don't + # leak it + await self.release(connection) + raise + + return connection + + def get_encoder(self): + """Return an encoder based on encoding settings""" + kwargs = self.connection_kwargs + return self.encoder_class( + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), + ) + + def make_connection(self): + """Create a new connection""" + if self._created_connections >= self.max_connections: + raise ConnectionError("Too many connections") + self._created_connections += 1 + return self.connection_class(**self.connection_kwargs) + + async def release(self, connection: Connection): + """Releases the connection back to the pool""" + self._checkpid() + async with self._lock: + try: + self._in_use_connections.remove(connection) + except KeyError: + # Gracefully fail when a connection is returned to this pool + # that the pool doesn't actually own + pass + + if self.owns_connection(connection): + self._available_connections.append(connection) + else: + # pool doesn't own this connection. do not add it back + # to the pool and decrement the count so that another + # connection can take its place if needed + self._created_connections -= 1 + await connection.disconnect() + return + + def owns_connection(self, connection: Connection): + return connection.pid == self.pid + + async def disconnect(self, inuse_connections: bool = True): + """ + Disconnects connections in the pool + + If ``inuse_connections`` is True, disconnect connections that are + current in use, potentially by other tasks. Otherwise only disconnect + connections that are idle in the pool. + """ + self._checkpid() + async with self._lock: + if inuse_connections: + connections: Iterable[Connection] = chain( + self._available_connections, self._in_use_connections + ) + else: + connections = self._available_connections + resp = await asyncio.gather( + *(connection.disconnect() for connection in connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc + + +class BlockingConnectionPool(ConnectionPool): + """ + Thread-safe blocking connection pool:: + + >>> from redis.client import Redis + >>> client = Redis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + :py:class:`~redis.ConnectionPool` implementation, in that, + it maintains a pool of reusable connections that can be shared by + multiple redis clients (safely across threads if required). + + The difference is that, in the event that a client tries to get a + connection from the pool when all of connections are in use, rather than + raising a :py:class:`~redis.ConnectionError` (as the default + :py:class:`~redis.ConnectionPool` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available. 
+ + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever: + + >>> # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + >>> # Raise a ``ConnectionError`` after five seconds if a connection is + >>> # not available. + >>> pool = BlockingConnectionPool(timeout=5) + """ + + def __init__( + self, + max_connections: int = 50, + timeout: Optional[int] = 20, + connection_class: Type[Connection] = Connection, + queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, + **connection_kwargs, + ): + + self.queue_class = queue_class + self.timeout = timeout + self._connections: List[Connection] + super().__init__( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, + ) + + def reset(self): + # Create and fill up a thread safe queue with ``None`` values. + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except asyncio.QueueFull: + break + + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def make_connection(self): + """Make a fresh connection.""" + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + async def get_connection(self, command_name, *keys, **options): + """ + Get a connection, blocking for ``self.timeout`` until a connection + is available from the pool. + + If the connection returned is ``None`` then creates a new connection. + Because we use a last-in first-out queue, the existing connections + (having been returned to the pool after the initial ``None`` values + were added) will be returned before ``None`` values. This means we only + create new connections when we need to, i.e.: the actual number of + connections will only increase in response to demand. + """ + # Make sure we haven't changed process. + self._checkpid() + + # Try and get a connection from the pool. If one isn't available within + # self.timeout then raise a ``ConnectionError``. + connection = None + try: + async with async_timeout.timeout(self.timeout): + connection = await self.pool.get() + except (asyncio.QueueEmpty, asyncio.TimeoutError): + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. 
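Because the checkout above raises ``ConnectionError`` once ``timeout`` elapses with no free connection, callers can treat that exception as "pool exhausted". A sketch (pool size and timeout are illustrative):

    import asyncio
    from redis.asyncio import BlockingConnectionPool, ConnectionError, Redis

    async def main():
        # At most 2 sockets; wait up to 1 second for one to become free.
        pool = BlockingConnectionPool(max_connections=2, timeout=1)
        client = Redis(connection_pool=pool)
        try:
            await client.ping()
        except ConnectionError:
            print("no connection became available within the timeout")
        finally:
            await pool.disconnect()

    asyncio.run(main())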
+ if connection is None: + connection = self.make_connection() + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if await connection.can_read(): + raise ConnectionError("Connection has data") from None + except ConnectionError: + await connection.disconnect() + await connection.connect() + if await connection.can_read(): + raise ConnectionError("Connection not ready") from None + except BaseException: + # release the connection back to the pool so that we don't leak it + await self.release(connection) + raise + + return connection + + async def release(self, connection: Connection): + """Releases the connection back to the pool.""" + # Make sure we haven't changed process. + self._checkpid() + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + await connection.disconnect() + self.pool.put_nowait(None) + return + + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except asyncio.QueueFull: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass + + async def disconnect(self, inuse_connections: bool = True): + """Disconnects all connections in the pool.""" + self._checkpid() + async with self._lock: + resp = await asyncio.gather( + *(connection.disconnect() for connection in self._connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc diff --git a/redis/asyncio/lock.py b/redis/asyncio/lock.py new file mode 100644 index 0000000..d486132 --- /dev/null +++ b/redis/asyncio/lock.py @@ -0,0 +1,311 @@ +import asyncio +import sys +import threading +import uuid +from types import SimpleNamespace +from typing import TYPE_CHECKING, Awaitable, NoReturn, Optional, Union + +from redis.exceptions import LockError, LockNotOwnedError + +if TYPE_CHECKING: + from redis.asyncio import Redis + + +class Lock: + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together. 
+ """ + + lua_release = None + lua_extend = None + lua_reacquire = None + + # KEYS[1] - lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - additional milliseconds + # ARGV[3] - "0" if the additional time should be added to the lock's + # existing ttl or "1" if the existing ttl should be replaced + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + + local newttl = ARGV[2] + if ARGV[3] == "0" then + newttl = ARGV[2] + expiration + end + redis.call('pexpire', KEYS[1], newttl) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + + def __init__( + self, + redis: "Redis", + name: Union[str, bytes, memoryview], + timeout: Optional[float] = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: Optional[float] = None, + thread_local: bool = True, + ): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock in seconds. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep in seconds per loop + iteration when the lock is in blocking mode and another client is + currently holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. 
+ + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.thread_local = bool(thread_local) + self.local = threading.local() if self.thread_local else SimpleNamespace() + self.local.token = None + self.register_scripts() + + def register_scripts(self): + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) + + async def __aenter__(self): + if await self.acquire(): + return self + raise LockError("Unable to acquire lock within the time specified") + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + + async def acquire( + self, + blocking: Optional[bool] = None, + blocking_timeout: Optional[float] = None, + token: Optional[Union[str, bytes]] = None, + ): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. + """ + if sys.version_info[0:2] != (3, 6): + loop = asyncio.get_running_loop() + else: + loop = asyncio.get_event_loop() + + sleep = self.sleep + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.connection_pool.get_encoder() + token = encoder.encode(token) + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = loop.time() + blocking_timeout + while True: + if await self.do_acquire(token): + self.local.token = token + return True + if not blocking: + return False + next_try_at = loop.time() + sleep + if stop_trying_at is not None and next_try_at > stop_trying_at: + return False + await asyncio.sleep(sleep) + + async def do_acquire(self, token: Union[str, bytes]) -> bool: + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if await self.redis.set(self.name, token, nx=True, px=timeout): + return True + return False + + async def locked(self) -> bool: + """ + Returns True if this key is locked by any process, otherwise False. + """ + return await self.redis.get(self.name) is not None + + async def owned(self) -> bool: + """ + Returns True if this key is locked by this lock, otherwise False. 
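A usage sketch for the lock described above, combining the async context manager with an explicit acquire/extend/release cycle (key names and timeouts are illustrative; assumes a reachable server):

    import asyncio
    from redis.asyncio import Redis
    from redis.asyncio.lock import Lock

    async def main():
        client = Redis()

        # Simplest form: acquire on enter, release on exit.
        async with Lock(client, "resource:42", timeout=10, blocking_timeout=5):
            ...  # do work while the lock is held

        # Explicit control, e.g. extending the TTL of a long-running job.
        lock = Lock(client, "resource:42", timeout=10)
        if await lock.acquire(blocking=False):
            try:
                await lock.extend(10)  # add ten more seconds to the TTL
            finally:
                await lock.release()

        await client.connection_pool.disconnect()

    asyncio.run(main())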
+ """ + stored_token = await self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.connection_pool.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and stored_token == self.local.token + + def release(self) -> Awaitable[NoReturn]: + """Releases the already acquired lock""" + expected_token = self.local.token + if expected_token is None: + raise LockError("Cannot release an unlocked lock") + self.local.token = None + return self.do_release(expected_token) + + async def do_release(self, expected_token: bytes): + if not bool( + await self.lua_release( + keys=[self.name], args=[expected_token], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") + + def extend( + self, additional_time: float, replace_ttl: bool = False + ) -> Awaitable[bool]: + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + + ``replace_ttl`` if False (the default), add `additional_time` to + the lock's existing ttl. If True, replace the lock's ttl with + `additional_time`. + """ + if self.local.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time, replace_ttl) + + async def do_extend(self, additional_time, replace_ttl) -> bool: + additional_time = int(additional_time * 1000) + if not bool( + await self.lua_extend( + keys=[self.name], + args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + client=self.redis, + ) + ): + raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") + return True + + def reacquire(self) -> Awaitable[bool]: + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + async def do_reacquire(self) -> bool: + timeout = int(self.timeout * 1000) + if not bool( + await self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") + return True diff --git a/redis/asyncio/retry.py b/redis/asyncio/retry.py new file mode 100644 index 0000000..9b53494 --- /dev/null +++ b/redis/asyncio/retry.py @@ -0,0 +1,58 @@ +from asyncio import sleep +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, TypeVar + +from redis.exceptions import ConnectionError, RedisError, TimeoutError + +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff + + +T = TypeVar("T") + + +class Retry: + """Retry a specific number of times after a failure""" + + __slots__ = "_backoff", "_retries", "_supported_errors" + + def __init__( + self, + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[RedisError, ...] = ( + ConnectionError, + TimeoutError, + ), + ): + """ + Initialize a `Retry` object with a `Backoff` object + that retries a maximum of `retries` times. + You can specify the types of supported errors which trigger + a retry with the `supported_errors` parameter. 
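A self-contained sketch of the retry loop described above, using ``call_with_retry`` (defined just below) with a deliberately flaky coroutine; ``NoBackoff`` comes from ``redis.backoff`` as in the sync package:

    import asyncio
    from redis.asyncio.retry import Retry
    from redis.backoff import NoBackoff
    from redis.exceptions import ConnectionError

    attempts = 0

    async def flaky():
        global attempts
        attempts += 1
        if attempts < 3:
            raise ConnectionError("simulated network hiccup")
        return "ok"

    async def on_failure(error):
        print(f"attempt failed: {error}")

    async def main():
        retry = Retry(NoBackoff(), retries=3)  # first attempt plus up to 3 retries
        print(await retry.call_with_retry(flaky, on_failure))  # "ok" on the third try

    asyncio.run(main())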
+ """ + self._backoff = backoff + self._retries = retries + self._supported_errors = supported_errors + + async def call_with_retry( + self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any] + ) -> T: + """ + Execute an operation that might fail and returns its result, or + raise the exception that was thrown depending on the `Backoff` object. + `do`: the operation to call. Expects no argument. + `fail`: the failure handler, expects the last error that was thrown + """ + self._backoff.reset() + failures = 0 + while True: + try: + return await do() + except self._supported_errors as error: + failures += 1 + await fail(error) + if failures > self._retries: + raise error + backoff = self._backoff.compute(failures) + if backoff > 0: + await sleep(backoff) diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py new file mode 100644 index 0000000..5aefd09 --- /dev/null +++ b/redis/asyncio/sentinel.py @@ -0,0 +1,353 @@ +import asyncio +import random +import weakref +from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type + +from redis.asyncio.client import Redis +from redis.asyncio.connection import ( + Connection, + ConnectionPool, + EncodableT, + SSLConnection, +) +from redis.commands import AsyncSentinelCommands +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.utils import str_if_bytes + + +class MasterNotFoundError(ConnectionError): + pass + + +class SlaveNotFoundError(ConnectionError): + pass + + +class SentinelManagedConnection(Connection): + def __init__(self, **kwargs): + self.connection_pool = kwargs.pop("connection_pool") + super().__init__(**kwargs) + + def __repr__(self): + pool = self.connection_pool + s = f"{self.__class__.__name__}<service={pool.service_name}" + if self.host: + host_info = f",host={self.host},port={self.port}" + s += host_info + return s + ">" + + async def connect_to(self, address): + self.host, self.port = address + await super().connect() + if self.connection_pool.check_connection: + await self.send_command("PING") + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("PING failed") + + async def connect(self): + if self._reader: + return # already connected + if self.connection_pool.is_master: + await self.connect_to(await self.connection_pool.get_master_address()) + else: + async for slave in self.connection_pool.rotate_slaves(): + try: + return await self.connect_to(slave) + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here + + async def read_response(self, disable_decoding: bool = False): + try: + return await super().read_response(disable_decoding=disable_decoding) + except ReadOnlyError: + if self.connection_pool.is_master: + # When talking to a master, a ReadOnlyError when likely + # indicates that the previous master that we're still connected + # to has been demoted to a slave and there's a new master. + # calling disconnect will force the connection to re-query + # sentinel during the next connect() attempt. + await self.disconnect() + raise ConnectionError("The previous master is now a slave") + raise + + +class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection): + pass + + +class SentinelConnectionPool(ConnectionPool): + """ + Sentinel backed connection pool. + + If ``check_connection`` flag is set to True, SentinelManagedConnection + sends a PING command right after establishing the connection. 
+ """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", + SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.master_address = None + self.slave_rr_counter = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<service={self.service_name}({self.is_master and 'master' or 'slave'})>" + ) + + def reset(self): + super().reset() + self.master_address = None + self.slave_rr_counter = None + + def owns_connection(self, connection: Connection): + check = not self.is_master or ( + self.is_master and self.master_address == (connection.host, connection.port) + ) + return check and super().owns_connection(connection) + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master: + if self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + await self.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + """Round-robin slave balancer""" + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class Sentinel(AsyncSentinelCommands): + """ + Redis Sentinel cluster client + + >>> from redis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> await master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> await slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. 
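``master_for``/``slave_for`` (further down) build this pool for you, but it can also be wired up by hand when flags such as ``check_connection`` matter. A sketch, assuming a sentinel on the conventional port 26379 and a service named ``mymaster``:

    import asyncio
    from redis.asyncio import Redis
    from redis.asyncio.sentinel import Sentinel, SentinelConnectionPool

    async def main():
        sentinel = Sentinel([("localhost", 26379)], socket_timeout=0.5)

        # Roughly what master_for("mymaster", check_connection=True) sets up:
        # every connection is PINGed right after it is established.
        pool = SentinelConnectionPool(
            "mymaster", sentinel, is_master=True, check_connection=True
        )
        master = Redis(connection_pool=pool)
        await master.set("foo", "bar")
        await pool.disconnect()

    asyncio.run(main())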
+ """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(host=hostname, port=port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + async def execute_command(self, *args, **kwargs): + """ + Execute Sentinel command in sentinel nodes. + once - If set to True, then execute the resulting command on a single + node at random, rather than across the entire sentinel cluster. + """ + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") + + if once: + tasks = [ + asyncio.Task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + await asyncio.gather(*tasks) + else: + await random.choice(self.sentinels).execute_command(*args, **kwargs) + return True + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + f"{sentinel.connection_pool.connection_kwargs['host']}:" + f"{sentinel.connection_pool.connection_kwargs['port']}" + ) + return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>" + + def check_master_state(self, state: dict, service_name: str) -> bool: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + async def discover_master(self, service_name: str): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. 
+ """ + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = await sentinel.sentinel_masters() + except (ConnectionError, TimeoutError): + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + raise MasterNotFoundError(f"No master found for {service_name!r}") + + def filter_slaves( + self, slaves: Iterable[Mapping] + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Remove slaves that are in an ODOWN or SDOWN state""" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + async def discover_slaves( + self, service_name: str + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Returns a list of alive slaves for service ``service_name``""" + for sentinel in self.sentinels: + try: + slaves = await sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. + + A :py:class:`~redis.sentinel.SentinelConnectionPool` class is + used to retrieve the master's address before establishing a new + connection. + + NOTE: If the master's address has changed, any cached connections to + the old master are closed. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to + use. The :py:class:`~redis.sentinel.SentinelConnectionPool` + will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. + """ + kwargs["is_master"] = True + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) + + def slave_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns redis client instance for the ``service_name`` slave(s). + + A SentinelConnectionPool class is used to retrieve the slave's + address before establishing a new connection. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to use. + The SentinelConnectionPool will be used by default. + + All other keyword arguments are merged with any connection_kwargs + passed to this class and passed to the connection pool as keyword + arguments to be used to initialize Redis connections. 
+ """ + kwargs["is_master"] = False + connection_kwargs = dict(self.connection_kwargs) + connection_kwargs.update(kwargs) + return redis_class( + connection_pool=connection_pool_class( + service_name, self, **connection_kwargs + ) + ) diff --git a/redis/asyncio/utils.py b/redis/asyncio/utils.py new file mode 100644 index 0000000..5a55b36 --- /dev/null +++ b/redis/asyncio/utils.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from redis.asyncio.client import Pipeline, Redis + + +def from_url(url, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. + """ + from redis.asyncio.client import Redis + + return Redis.from_url(url, **kwargs) + + +class pipeline: + def __init__(self, redis_obj: "Redis"): + self.p: "Pipeline" = redis_obj.pipeline() + + async def __aenter__(self) -> "Pipeline": + return self.p + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.p.execute() + del self.p |