path: root/redis/asyncio/connection.py
author    Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com>  2022-02-22 05:29:55 -0500
committer GitHub <noreply@github.com>  2022-02-22 12:29:55 +0200
commit    d56baeb683fc1935cfa343fa2eeb0fa9bd955283 (patch)
tree      47357a74bf1d1428cfbcf0d8b2c781f1f971cf77 /redis/asyncio/connection.py
parent    e3c989d93e914e6502bd5a72f15ded49a135c5be (diff)
download  redis-py-d56baeb683fc1935cfa343fa2eeb0fa9bd955283.tar.gz
Add Async Support (#1899)
Co-authored-by: Chayim I. Kirshen <c@kirshen.com>
Co-authored-by: dvora-h <dvora.heller@redis.com>
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r--  redis/asyncio/connection.py  1707
1 file changed, 1707 insertions(+), 0 deletions(-)
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
new file mode 100644
index 0000000..ae54d81
--- /dev/null
+++ b/redis/asyncio/connection.py
@@ -0,0 +1,1707 @@
+import asyncio
+import copy
+import enum
+import errno
+import inspect
+import io
+import os
+import socket
+import ssl
+import sys
+import threading
+import weakref
+from itertools import chain
+from types import MappingProxyType
+from typing import (
+ Any,
+ Callable,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+from urllib.parse import ParseResult, parse_qs, unquote, urlparse
+
+import async_timeout
+
+from redis.asyncio.retry import Retry
+from redis.backoff import NoBackoff
+from redis.compat import Protocol, TypedDict
+from redis.exceptions import (
+ AuthenticationError,
+ AuthenticationWrongNumberOfArgsError,
+ BusyLoadingError,
+ ChildDeadlockedError,
+ ConnectionError,
+ DataError,
+ ExecAbortError,
+ InvalidResponse,
+ ModuleError,
+ NoPermissionError,
+ NoScriptError,
+ ReadOnlyError,
+ RedisError,
+ ResponseError,
+ TimeoutError,
+)
+from redis.typing import EncodableT, EncodedT
+from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
+
+hiredis = None
+if HIREDIS_AVAILABLE:
+ import hiredis
+
+NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {
+ BlockingIOError: errno.EWOULDBLOCK,
+ ssl.SSLWantReadError: 2,
+ ssl.SSLWantWriteError: 2,
+ ssl.SSLError: 2,
+}
+
+NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys())
+
+
+SYM_STAR = b"*"
+SYM_DOLLAR = b"$"
+SYM_CRLF = b"\r\n"
+SYM_LF = b"\n"
+SYM_EMPTY = b""
+
+SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
+
+
+class _Sentinel(enum.Enum):
+ sentinel = object()
+
+
+SENTINEL = _Sentinel.sentinel
+MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
+NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
+MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
+MODULE_EXPORTS_DATA_TYPES_ERROR = (
+ "Error unloading module: the module "
+ "exports one or more module-side data "
+ "types, can't unload"
+)
+
+
+class _HiredisReaderArgs(TypedDict, total=False):
+ protocolError: Callable[[str], Exception]
+ replyError: Callable[[str], Exception]
+ encoding: Optional[str]
+ errors: Optional[str]
+
+
+class Encoder:
+ """Encode strings to bytes-like and decode bytes-like to strings"""
+
+ __slots__ = "encoding", "encoding_errors", "decode_responses"
+
+ def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool):
+ self.encoding = encoding
+ self.encoding_errors = encoding_errors
+ self.decode_responses = decode_responses
+
+ def encode(self, value: EncodableT) -> EncodedT:
+ """Return a bytestring or bytes-like representation of the value"""
+ if isinstance(value, (bytes, memoryview)):
+ return value
+ if isinstance(value, bool):
+ # special case bool since it is a subclass of int
+ raise DataError(
+ "Invalid input of type: 'bool'. "
+ "Convert to a bytes, string, int or float first."
+ )
+ if isinstance(value, (int, float)):
+ return repr(value).encode()
+ if not isinstance(value, str):
+ # a value we don't know how to deal with. throw an error
+ typename = value.__class__.__name__ # type: ignore[unreachable]
+ raise DataError(
+ f"Invalid input of type: {typename!r}. "
+ "Convert to a bytes, string, int or float first."
+ )
+ return value.encode(self.encoding, self.encoding_errors)
+
+ def decode(self, value: EncodableT, force=False) -> EncodableT:
+ """Return a unicode string from the bytes-like representation"""
+ if self.decode_responses or force:
+ if isinstance(value, memoryview):
+ return value.tobytes().decode(self.encoding, self.encoding_errors)
+ if isinstance(value, bytes):
+ return value.decode(self.encoding, self.encoding_errors)
+ return value
+
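+ # A minimal usage sketch of Encoder (editor's illustration, not part of
+ # the original commit), assuming default utf-8 settings:
+ #
+ # >>> enc = Encoder("utf-8", "strict", decode_responses=True)
+ # >>> enc.encode("hello")  # str -> bytes
+ # b'hello'
+ # >>> enc.encode(42)  # ints/floats are encoded via repr()
+ # b'42'
+ # >>> enc.decode(b"world")  # decoded because decode_responses=True
+ # 'world'
+ # >>> enc.encode(True)  # bools are rejected explicitly
+ # Traceback (most recent call last):
+ #   ...
+ # redis.exceptions.DataError: Invalid input of type: 'bool'. ...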
+
+ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]]
+
+
+class BaseParser:
+ """Plain Python parsing class"""
+
+ __slots__ = "_stream", "_buffer", "_read_size"
+
+ EXCEPTION_CLASSES: ExceptionMappingT = {
+ "ERR": {
+ "max number of clients reached": ConnectionError,
+ "Client sent AUTH, but no password is set": AuthenticationError,
+ "invalid password": AuthenticationError,
+ # some Redis server versions report invalid command syntax
+ # in lowercase
+ "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501
+ # some Redis server versions report invalid command syntax
+ # in uppercase
+ "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501
+ MODULE_LOAD_ERROR: ModuleError,
+ MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
+ NO_SUCH_MODULE_ERROR: ModuleError,
+ MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
+ },
+ "EXECABORT": ExecAbortError,
+ "LOADING": BusyLoadingError,
+ "NOSCRIPT": NoScriptError,
+ "READONLY": ReadOnlyError,
+ "NOAUTH": AuthenticationError,
+ "NOPERM": NoPermissionError,
+ }
+
+ def __init__(self, socket_read_size: int):
+ self._stream: Optional[asyncio.StreamReader] = None
+ self._buffer: Optional[SocketBuffer] = None
+ self._read_size = socket_read_size
+
+ def __del__(self):
+ try:
+ self.on_disconnect()
+ except Exception:
+ pass
+
+ def parse_error(self, response: str) -> ResponseError:
+ """Parse an error response"""
+ error_code = response.split(" ")[0]
+ if error_code in self.EXCEPTION_CLASSES:
+ response = response[len(error_code) + 1 :]
+ exception_class = self.EXCEPTION_CLASSES[error_code]
+ if isinstance(exception_class, dict):
+ exception_class = exception_class.get(response, ResponseError)
+ return exception_class(response)
+ return ResponseError(response)
+
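+ # Illustrative mappings produced by parse_error above (editor's note,
+ # not part of the original commit); the classes come from EXCEPTION_CLASSES:
+ #
+ # >>> parser.parse_error("ERR max number of clients reached")
+ # ConnectionError('max number of clients reached')
+ # >>> parser.parse_error("NOSCRIPT No matching script.")
+ # NoScriptError('No matching script.')
+ # >>> parser.parse_error("WRONGTYPE Operation against a key")
+ # ResponseError('WRONGTYPE Operation against a key')
+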
+ def on_disconnect(self):
+ raise NotImplementedError()
+
+ def on_connect(self, connection: "Connection"):
+ raise NotImplementedError()
+
+ async def can_read(self, timeout: float) -> bool:
+ raise NotImplementedError()
+
+ async def read_response(
+ self, disable_decoding: bool = False
+ ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
+ raise NotImplementedError()
+
+
+class SocketBuffer:
+ """Async-friendly re-impl of redis-py's SocketBuffer.
+
+ TODO: We're currently passing through two buffers,
+ the asyncio.StreamReader and this. I imagine we can reduce the layers here
+ while maintaining compliance with prior art.
+ """
+
+ def __init__(
+ self,
+ stream_reader: asyncio.StreamReader,
+ socket_read_size: int,
+ socket_timeout: Optional[float],
+ ):
+ self._stream: Optional[asyncio.StreamReader] = stream_reader
+ self.socket_read_size = socket_read_size
+ self.socket_timeout = socket_timeout
+ self._buffer: Optional[io.BytesIO] = io.BytesIO()
+ # number of bytes written to the buffer from the socket
+ self.bytes_written = 0
+ # number of bytes read from the buffer
+ self.bytes_read = 0
+
+ @property
+ def length(self):
+ return self.bytes_written - self.bytes_read
+
+ async def _read_from_socket(
+ self,
+ length: Optional[int] = None,
+ timeout: Union[float, None, _Sentinel] = SENTINEL,
+ raise_on_timeout: bool = True,
+ ) -> bool:
+ buf = self._buffer
+ if buf is None or self._stream is None:
+ raise RedisError("Buffer is closed.")
+ buf.seek(self.bytes_written)
+ marker = 0
+ timeout = timeout if timeout is not SENTINEL else self.socket_timeout
+
+ try:
+ while True:
+ async with async_timeout.timeout(timeout):
+ data = await self._stream.read(self.socket_read_size)
+ # an empty bytestring indicates that the server shut down the socket
+ if isinstance(data, bytes) and len(data) == 0:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ buf.write(data)
+ data_length = len(data)
+ self.bytes_written += data_length
+ marker += data_length
+
+ if length is not None and length > marker:
+ continue
+ return True
+ except (socket.timeout, asyncio.TimeoutError):
+ if raise_on_timeout:
+ raise TimeoutError("Timeout reading from socket")
+ return False
+ except NONBLOCKING_EXCEPTIONS as ex:
+ # if we're in nonblocking mode and the recv raises a
+ # blocking error, simply return False indicating that
+ # there's no data to be read. otherwise raise the
+ # original exception.
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
+ if not raise_on_timeout and ex.errno == allowed:
+ return False
+ raise ConnectionError(f"Error while reading from socket: {ex.args}")
+
+ async def can_read(self, timeout: float) -> bool:
+ return bool(self.length) or await self._read_from_socket(
+ timeout=timeout, raise_on_timeout=False
+ )
+
+ async def read(self, length: int) -> bytes:
+ length = length + 2 # make sure to read the \r\n terminator
+ # make sure we've read enough data from the socket
+ if length > self.length:
+ await self._read_from_socket(length - self.length)
+
+ if self._buffer is None:
+ raise RedisError("Buffer is closed.")
+
+ self._buffer.seek(self.bytes_read)
+ data = self._buffer.read(length)
+ self.bytes_read += len(data)
+
+ # purge the buffer when we've consumed it all so it doesn't
+ # grow forever
+ if self.bytes_read == self.bytes_written:
+ self.purge()
+
+ return data[:-2]
+
+ async def readline(self) -> bytes:
+ buf = self._buffer
+ if buf is None:
+ raise RedisError("Buffer is closed.")
+
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+ while not data.endswith(SYM_CRLF):
+ # there's more data in the socket that we need
+ await self._read_from_socket()
+ buf.seek(self.bytes_read)
+ data = buf.readline()
+
+ self.bytes_read += len(data)
+
+ # purge the buffer when we've consumed it all so it doesn't
+ # grow forever
+ if self.bytes_read == self.bytes_written:
+ self.purge()
+
+ return data[:-2]
+
+ def purge(self):
+ if self._buffer is None:
+ raise RedisError("Buffer is closed.")
+
+ self._buffer.seek(0)
+ self._buffer.truncate()
+ self.bytes_written = 0
+ self.bytes_read = 0
+
+ def close(self):
+ try:
+ self.purge()
+ self._buffer.close() # type: ignore[union-attr]
+ except Exception:
+ # issue #633 suggests the purge/close somehow raised a
+ # BadFileDescriptor error. Perhaps the client ran out of
+ # memory or something else? It's probably OK to ignore
+ # any error being raised from purge/close since we're
+ # removing the reference to the instance below.
+ pass
+ self._buffer = None
+ self._stream = None
+
+
+class PythonParser(BaseParser):
+ """Plain Python parsing class"""
+
+ __slots__ = BaseParser.__slots__ + ("encoder",)
+
+ def __init__(self, socket_read_size: int):
+ super().__init__(socket_read_size)
+ self.encoder: Optional[Encoder] = None
+
+ def on_connect(self, connection: "Connection"):
+ """Called when the stream connects"""
+ self._stream = connection._reader
+ if self._stream is None:
+ raise RedisError("Buffer is closed.")
+
+ self._buffer = SocketBuffer(
+ self._stream, self._read_size, connection.socket_timeout
+ )
+ self.encoder = connection.encoder
+
+ def on_disconnect(self):
+ """Called when the stream disconnects"""
+ if self._stream is not None:
+ self._stream = None
+ if self._buffer is not None:
+ self._buffer.close()
+ self._buffer = None
+ self.encoder = None
+
+ async def can_read(self, timeout: float):
+ return self._buffer and bool(await self._buffer.can_read(timeout))
+
+ async def read_response(
+ self, disable_decoding: bool = False
+ ) -> Union[EncodableT, ResponseError, None]:
+ if not self._buffer or not self.encoder:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ raw = await self._buffer.readline()
+ if not raw:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+ response: Any
+ byte, response = raw[:1], raw[1:]
+
+ if byte not in (b"-", b"+", b":", b"$", b"*"):
+ raise InvalidResponse(f"Protocol Error: {raw!r}")
+
+ # server returned an error
+ if byte == b"-":
+ response = response.decode("utf-8", errors="replace")
+ error = self.parse_error(response)
+ # if the error is a ConnectionError, raise immediately so the user
+ # is notified
+ if isinstance(error, ConnectionError):
+ raise error
+ # otherwise, we're dealing with a ResponseError that might belong
+ # inside a pipeline response. the connection's read_response()
+ # and/or the pipeline's execute() will raise this error if
+ # necessary, so just return the exception instance here.
+ return error
+ # single value
+ elif byte == b"+":
+ pass
+ # int value
+ elif byte == b":":
+ response = int(response)
+ # bulk response
+ elif byte == b"$":
+ length = int(response)
+ if length == -1:
+ return None
+ response = await self._buffer.read(length)
+ # multi-bulk response
+ elif byte == b"*":
+ length = int(response)
+ if length == -1:
+ return None
+ response = [
+ (await self.read_response(disable_decoding)) for _ in range(length)
+ ]
+ if isinstance(response, bytes) and disable_decoding is False:
+ response = self.encoder.decode(response)
+ return response
+
+
+class HiredisParser(BaseParser):
+ """Parser class for connections using Hiredis"""
+
+ __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")
+
+ _next_response: bool
+
+ def __init__(self, socket_read_size: int):
+ if not HIREDIS_AVAILABLE:
+ raise RedisError("Hiredis is not available.")
+ super().__init__(socket_read_size=socket_read_size)
+ self._reader: Optional[hiredis.Reader] = None
+ self._socket_timeout: Optional[float] = None
+
+ def on_connect(self, connection: "Connection"):
+ self._stream = connection._reader
+ kwargs: _HiredisReaderArgs = {
+ "protocolError": InvalidResponse,
+ "replyError": self.parse_error,
+ }
+ if connection.encoder.decode_responses:
+ kwargs["encoding"] = connection.encoder.encoding
+ kwargs["errors"] = connection.encoder.encoding_errors
+
+ self._reader = hiredis.Reader(**kwargs)
+ self._next_response = False
+ self._socket_timeout = connection.socket_timeout
+
+ def on_disconnect(self):
+ self._stream = None
+ self._reader = None
+ self._next_response = False
+
+ async def can_read(self, timeout: float):
+ if not self._reader:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
+
+ if self._next_response is False:
+ self._next_response = self._reader.gets()
+ if self._next_response is False:
+ return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
+ return True
+
+ async def read_from_socket(
+ self,
+ timeout: Union[float, None, _Sentinel] = SENTINEL,
+ raise_on_timeout: bool = True,
+ ):
+ if self._stream is None or self._reader is None:
+ raise RedisError("Parser already closed.")
+
+ timeout = self._socket_timeout if timeout is SENTINEL else timeout
+ try:
+ async with async_timeout.timeout(timeout):
+ buffer = await self._stream.read(self._read_size)
+ if not isinstance(buffer, bytes) or len(buffer) == 0:
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
+ self._reader.feed(buffer)
+ # data was read from the socket and added to the buffer.
+ # return True to indicate that data was read.
+ return True
+ except asyncio.CancelledError:
+ raise
+ except (socket.timeout, asyncio.TimeoutError):
+ if raise_on_timeout:
+ raise TimeoutError("Timeout reading from socket") from None
+ return False
+ except NONBLOCKING_EXCEPTIONS as ex:
+ # if we're in nonblocking mode and the recv raises a
+ # blocking error, simply return False indicating that
+ # there's no data to be read. otherwise raise the
+ # original exception.
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1)
+ if not raise_on_timeout and ex.errno == allowed:
+ return False
+ raise ConnectionError(f"Error while reading from socket: {ex.args}")
+
+ async def read_response(
+ self, disable_decoding: bool = False
+ ) -> Union[EncodableT, List[EncodableT]]:
+ if not self._stream or not self._reader:
+ self.on_disconnect()
+ raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
+
+ response: Union[
+ EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]
+ ]
+ # _next_response might be cached from a can_read() call
+ if self._next_response is not False:
+ response = self._next_response
+ self._next_response = False
+ return response
+
+ response = self._reader.gets()
+ while response is False:
+ await self.read_from_socket()
+ response = self._reader.gets()
+
+ # if the response is a ConnectionError or the response is a list and
+ # the first item is a ConnectionError, raise it as something bad
+ # happened
+ if isinstance(response, ConnectionError):
+ raise response
+ elif (
+ isinstance(response, list)
+ and response
+ and isinstance(response[0], ConnectionError)
+ ):
+ raise response[0]
+ # cast as there won't be a ConnectionError here.
+ return cast(Union[EncodableT, List[EncodableT]], response)
+
+
+DefaultParser: Type[Union[PythonParser, HiredisParser]]
+if HIREDIS_AVAILABLE:
+ DefaultParser = HiredisParser
+else:
+ DefaultParser = PythonParser
+
+
+class ConnectCallbackProtocol(Protocol):
+ def __call__(self, connection: "Connection"):
+ ...
+
+
+class AsyncConnectCallbackProtocol(Protocol):
+ async def __call__(self, connection: "Connection"):
+ ...
+
+
+ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol]
+
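+ # Hypothetical sketch of a connect callback (editor's illustration; the
+ # Resubscriber class is invented). Note register_connect_callback() stores
+ # a weakref.WeakMethod, so the callback must be a *bound method* and its
+ # owner must be kept alive by the caller:
+ #
+ # class Resubscriber:
+ #     def __init__(self, channel: str):
+ #         self.channel = channel
+ #
+ #     async def on_connect(self, connection: "Connection"):
+ #         await connection.send_command("SUBSCRIBE", self.channel)
+ #
+ # resub = Resubscriber("my-channel")
+ # conn = Connection(host="localhost")
+ # conn.register_connect_callback(resub.on_connect)  # re-run on (re)connect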
+
+class Connection:
+ """Manages TCP communication to and from a Redis server"""
+
+ __slots__ = (
+ "pid",
+ "host",
+ "port",
+ "db",
+ "username",
+ "client_name",
+ "password",
+ "socket_timeout",
+ "socket_connect_timeout",
+ "socket_keepalive",
+ "socket_keepalive_options",
+ "socket_type",
+ "redis_connect_func",
+ "retry_on_timeout",
+ "health_check_interval",
+ "next_health_check",
+ "last_active_at",
+ "encoder",
+ "ssl_context",
+ "_reader",
+ "_writer",
+ "_parser",
+ "_connect_callbacks",
+ "_buffer_cutoff",
+ "_lock",
+ "_socket_read_size",
+ "__dict__",
+ )
+
+ def __init__(
+ self,
+ *,
+ host: str = "localhost",
+ port: Union[str, int] = 6379,
+ db: Union[str, int] = 0,
+ password: Optional[str] = None,
+ socket_timeout: Optional[float] = None,
+ socket_connect_timeout: Optional[float] = None,
+ socket_keepalive: bool = False,
+ socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None,
+ socket_type: int = 0,
+ retry_on_timeout: bool = False,
+ encoding: str = "utf-8",
+ encoding_errors: str = "strict",
+ decode_responses: bool = False,
+ parser_class: Type[BaseParser] = DefaultParser,
+ socket_read_size: int = 65536,
+ health_check_interval: float = 0,
+ client_name: Optional[str] = None,
+ username: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ redis_connect_func: Optional[ConnectCallbackT] = None,
+ encoder_class: Type[Encoder] = Encoder,
+ ):
+ self.pid = os.getpid()
+ self.host = host
+ self.port = int(port)
+ self.db = db
+ self.username = username
+ self.client_name = client_name
+ self.password = password
+ self.socket_timeout = socket_timeout
+ self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
+ self.socket_keepalive = socket_keepalive
+ self.socket_keepalive_options = socket_keepalive_options or {}
+ self.socket_type = socket_type
+ self.retry_on_timeout = retry_on_timeout
+ if retry_on_timeout:
+ if retry is None:
+ self.retry = Retry(NoBackoff(), 1)
+ else:
+ # deep-copy the Retry object as it is mutable
+ self.retry = copy.deepcopy(retry)
+ else:
+ self.retry = Retry(NoBackoff(), 0)
+ self.health_check_interval = health_check_interval
+ self.next_health_check: float = -1
+ self.ssl_context: Optional[RedisSSLContext] = None
+ self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
+ self.redis_connect_func = redis_connect_func
+ self._reader: Optional[asyncio.StreamReader] = None
+ self._writer: Optional[asyncio.StreamWriter] = None
+ self._socket_read_size = socket_read_size
+ self.set_parser(parser_class)
+ self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
+ self._buffer_cutoff = 6000
+ self._lock = asyncio.Lock()
+
+ def __repr__(self):
+ repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
+ return f"{self.__class__.__name__}<{repr_args}>"
+
+ def repr_pieces(self):
+ pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
+ if self.client_name:
+ pieces.append(("client_name", self.client_name))
+ return pieces
+
+ def __del__(self):
+ try:
+ if self.is_connected:
+ loop = asyncio.get_event_loop()
+ coro = self.disconnect()
+ if loop.is_running():
+ loop.create_task(coro)
+ else:
+ loop.run_until_complete(coro)
+ except Exception:
+ pass
+
+ @property
+ def is_connected(self):
+ return bool(self._reader and self._writer)
+
+ def register_connect_callback(self, callback):
+ self._connect_callbacks.append(weakref.WeakMethod(callback))
+
+ def clear_connect_callbacks(self):
+ self._connect_callbacks = []
+
+ def set_parser(self, parser_class):
+ """
+ Create a new instance of parser_class, configured with this
+ connection's _socket_read_size, and assign it as the connection's parser.
+
+ :param parser_class: The parser class to use
+ """
+ self._parser = parser_class(socket_read_size=self._socket_read_size)
+
+ async def connect(self):
+ """Connects to the Redis server if not already connected"""
+ if self.is_connected:
+ return
+ try:
+ await self._connect()
+ except asyncio.CancelledError:
+ raise
+ except (socket.timeout, asyncio.TimeoutError):
+ raise TimeoutError("Timeout connecting to server")
+ except OSError as e:
+ raise ConnectionError(self._error_message(e))
+ except Exception as exc:
+ raise ConnectionError(exc) from exc
+
+ try:
+ if self.redis_connect_func is None:
+ # Use the default on_connect function
+ await self.on_connect()
+ else:
+ # Use the passed function redis_connect_func
+ if asyncio.iscoroutinefunction(self.redis_connect_func):
+     await self.redis_connect_func(self)
+ else:
+     self.redis_connect_func(self)
+ except RedisError:
+ # clean up after any error in on_connect
+ await self.disconnect()
+ raise
+
+ # run any user callbacks. right now the only internal callback
+ # is for pubsub channel/pattern resubscription
+ for ref in self._connect_callbacks:
+ callback = ref()
+ # the WeakMethod may be dead if the callback's owner was garbage
+ # collected; skip it in that case
+ if callback is None:
+     continue
+ task = callback(self)
+ if task and inspect.isawaitable(task):
+     await task
+
+ async def _connect(self):
+ """Create a TCP socket connection"""
+ async with async_timeout.timeout(self.socket_connect_timeout):
+ reader, writer = await asyncio.open_connection(
+ host=self.host,
+ port=self.port,
+ ssl=self.ssl_context.get() if self.ssl_context else None,
+ )
+ self._reader = reader
+ self._writer = writer
+ sock = writer.transport.get_extra_info("socket")
+ if sock is not None:
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ try:
+ # TCP_KEEPALIVE
+ if self.socket_keepalive:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ for k, v in self.socket_keepalive_options.items():
+ sock.setsockopt(socket.SOL_TCP, k, v)
+
+ except (OSError, TypeError):
+ # `socket_keepalive_options` might contain invalid options
+ # causing an error. Do not leave the connection open.
+ writer.close()
+ raise
+
+ def _error_message(self, exception):
+ # args for socket.error can either be (errno, "message")
+ # or just "message"
+ if len(exception.args) == 1:
+ return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}."
+ else:
+ return (
+ f"Error {exception.args[0]} connecting to {self.host}:{self.port}. "
+ f"{exception.args[0]}."
+ )
+
+ async def on_connect(self):
+ """Initialize the connection, authenticate and select a database"""
+ self._parser.on_connect(self)
+
+ # if username and/or password are set, authenticate
+ if self.username or self.password:
+ auth_args: Union[Tuple[str], Tuple[str, str]]
+ if self.username:
+ auth_args = (self.username, self.password or "")
+ else:
+ # Mypy bug: https://github.com/python/mypy/issues/10944
+ auth_args = (self.password or "",)
+ # avoid checking health here -- PING will fail if we try
+ # to check the health prior to the AUTH
+ await self.send_command("AUTH", *auth_args, check_health=False)
+
+ try:
+ auth_response = await self.read_response()
+ except AuthenticationWrongNumberOfArgsError:
+ # a username and password were specified but the Redis
+ # server seems to be < 6.0.0 which expects a single password
+ # arg. retry auth with just the password.
+ # https://github.com/andymccurdy/redis-py/issues/1274
+ await self.send_command("AUTH", self.password, check_health=False)
+ auth_response = await self.read_response()
+
+ if str_if_bytes(auth_response) != "OK":
+ raise AuthenticationError("Invalid Username or Password")
+
+ # if a client_name is given, set it
+ if self.client_name:
+ await self.send_command("CLIENT", "SETNAME", self.client_name)
+ if str_if_bytes(await self.read_response()) != "OK":
+ raise ConnectionError("Error setting client name")
+
+ # if a database is specified, switch to it
+ if self.db:
+ await self.send_command("SELECT", self.db)
+ if str_if_bytes(await self.read_response()) != "OK":
+ raise ConnectionError("Invalid Database")
+
+ async def disconnect(self):
+ """Disconnects from the Redis server"""
+ try:
+ async with async_timeout.timeout(self.socket_connect_timeout):
+ self._parser.on_disconnect()
+ if not self.is_connected:
+ return
+ try:
+ if os.getpid() == self.pid:
+ self._writer.close() # type: ignore[union-attr]
+ # py3.6 doesn't have this method
+ if hasattr(self._writer, "wait_closed"):
+ await self._writer.wait_closed() # type: ignore[union-attr]
+ except OSError:
+ pass
+ self._reader = None
+ self._writer = None
+ except asyncio.TimeoutError:
+ raise TimeoutError(
+ f"Timed out closing connection after {self.socket_connect_timeout}"
+ ) from None
+
+ async def _send_ping(self):
+ """Send PING, expect PONG in return"""
+ await self.send_command("PING", check_health=False)
+ if str_if_bytes(await self.read_response()) != "PONG":
+ raise ConnectionError("Bad response from PING health check")
+
+ async def _ping_failed(self, error):
+ """Function to call when PING fails"""
+ await self.disconnect()
+
+ async def check_health(self):
+ """Check the health of the connection with a PING/PONG"""
+ if sys.version_info[0:2] == (3, 6):
+ func = asyncio.get_event_loop
+ else:
+ func = asyncio.get_running_loop
+
+ if self.health_check_interval and func().time() > self.next_health_check:
+ await self.retry.call_with_retry(self._send_ping, self._ping_failed)
+
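+ # Illustrative health-check timeline (editor's note; the interval is an
+ # assumed value):
+ #
+ # >>> conn = Connection(health_check_interval=30)  # seconds
+ #
+ # Each successful read_response() sets next_health_check to
+ # loop.time() + 30; the first send_packed_command() after that deadline
+ # calls check_health(), which PINGs the server and, via _ping_failed(),
+ # disconnects on failure so the Retry policy can re-run the PING.
+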
+ async def _send_packed_command(self, command: Iterable[bytes]) -> None:
+ if self._writer is None:
+ raise RedisError("Connection already closed.")
+
+ self._writer.writelines(command)
+ await self._writer.drain()
+
+ async def send_packed_command(
+ self,
+ command: Union[bytes, str, Iterable[bytes]],
+ check_health: bool = True,
+ ):
+ """Send an already packed command to the Redis server"""
+ if not self._writer:
+ await self.connect()
+ # guard against health check recursion
+ if check_health:
+ await self.check_health()
+ try:
+ if isinstance(command, str):
+ command = command.encode()
+ if isinstance(command, bytes):
+ command = [command]
+ await asyncio.wait_for(
+ self._send_packed_command(command),
+ self.socket_timeout,
+ )
+ except asyncio.TimeoutError:
+ await self.disconnect()
+ raise TimeoutError("Timeout writing to socket") from None
+ except OSError as e:
+ await self.disconnect()
+ if len(e.args) == 1:
+ err_no, errmsg = "UNKNOWN", e.args[0]
+ else:
+ err_no = e.args[0]
+ errmsg = e.args[1]
+ raise ConnectionError(
+ f"Error {err_no} while writing to socket. {errmsg}."
+ ) from e
+ except BaseException:
+ await self.disconnect()
+ raise
+
+ async def send_command(self, *args, **kwargs):
+ """Pack and send a command to the Redis server"""
+ if not self.is_connected:
+ await self.connect()
+ await self.send_packed_command(
+ self.pack_command(*args), check_health=kwargs.get("check_health", True)
+ )
+
+ async def can_read(self, timeout: float = 0):
+ """Poll the socket to see if there's data that can be read."""
+ if not self.is_connected:
+ await self.connect()
+ return await self._parser.can_read(timeout)
+
+ async def read_response(self, disable_decoding: bool = False):
+ """Read the response from a previously sent command"""
+ try:
+ async with self._lock:
+ async with async_timeout.timeout(self.socket_timeout):
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding
+ )
+ except asyncio.TimeoutError:
+ await self.disconnect()
+ raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
+ except OSError as e:
+ await self.disconnect()
+ raise ConnectionError(
+ f"Error while reading from {self.host}:{self.port} : {e.args}"
+ )
+ except BaseException:
+ await self.disconnect()
+ raise
+
+ if self.health_check_interval:
+ if sys.version_info[0:2] == (3, 6):
+ func = asyncio.get_event_loop
+ else:
+ func = asyncio.get_running_loop
+ self.next_health_check = func().time() + self.health_check_interval
+
+ if isinstance(response, ResponseError):
+ raise response from None
+ return response
+
+ def pack_command(self, *args: EncodableT) -> List[bytes]:
+ """Pack a series of arguments into the Redis protocol"""
+ output = []
+ # the client might have included 1 or more literal arguments in
+ # the command name, e.g., 'CONFIG GET'. The Redis server expects these
+ # arguments to be sent separately, so split the first argument
+ # manually. These arguments should be bytestrings so that they are
+ # not encoded.
+ assert not isinstance(args[0], float)
+ if isinstance(args[0], str):
+ args = tuple(args[0].encode().split()) + args[1:]
+ elif b" " in args[0]:
+ args = tuple(args[0].split()) + args[1:]
+
+ buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF))
+
+ buffer_cutoff = self._buffer_cutoff
+ for arg in map(self.encoder.encode, args):
+ # to avoid large string mallocs, chunk the command into the
+ # output list if we're sending large values or memoryviews
+ arg_length = len(arg)
+ if (
+ len(buff) > buffer_cutoff
+ or arg_length > buffer_cutoff
+ or isinstance(arg, memoryview)
+ ):
+ buff = SYM_EMPTY.join(
+ (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF)
+ )
+ output.append(buff)
+ output.append(arg)
+ buff = SYM_CRLF
+ else:
+ buff = SYM_EMPTY.join(
+ (
+ buff,
+ SYM_DOLLAR,
+ str(arg_length).encode(),
+ SYM_CRLF,
+ arg,
+ SYM_CRLF,
+ )
+ )
+ output.append(buff)
+ return output
+
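+ # Worked example of the RESP framing built above (editor's illustration,
+ # not part of the original commit; conn is an assumed Connection
+ # instance). A small command fits in one chunk:
+ #
+ # >>> conn.pack_command("SET", "foo", "bar")
+ # [b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n']
+ #
+ # Arguments larger than _buffer_cutoff (6000 bytes), or memoryviews, are
+ # emitted as separate list items so large values are never copied into
+ # one oversized bytestring.
+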
+ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]:
+ """Pack multiple commands into the Redis protocol"""
+ output: List[bytes] = []
+ pieces: List[bytes] = []
+ buffer_length = 0
+ buffer_cutoff = self._buffer_cutoff
+
+ for cmd in commands:
+ for chunk in self.pack_command(*cmd):
+ chunklen = len(chunk)
+ if (
+ buffer_length > buffer_cutoff
+ or chunklen > buffer_cutoff
+ or isinstance(chunk, memoryview)
+ ):
+ output.append(SYM_EMPTY.join(pieces))
+ buffer_length = 0
+ pieces = []
+
+ if chunklen > buffer_cutoff or isinstance(chunk, memoryview):
+ output.append(chunk)
+ else:
+ pieces.append(chunk)
+ buffer_length += chunklen
+
+ if pieces:
+ output.append(SYM_EMPTY.join(pieces))
+ return output
+
+
+class SSLConnection(Connection):
+ def __init__(
+ self,
+ ssl_keyfile: Optional[str] = None,
+ ssl_certfile: Optional[str] = None,
+ ssl_cert_reqs: str = "required",
+ ssl_ca_certs: Optional[str] = None,
+ ssl_check_hostname: bool = False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.ssl_context: RedisSSLContext = RedisSSLContext(
+ keyfile=ssl_keyfile,
+ certfile=ssl_certfile,
+ cert_reqs=ssl_cert_reqs,
+ ca_certs=ssl_ca_certs,
+ check_hostname=ssl_check_hostname,
+ )
+
+ @property
+ def keyfile(self):
+ return self.ssl_context.keyfile
+
+ @property
+ def certfile(self):
+ return self.ssl_context.certfile
+
+ @property
+ def cert_reqs(self):
+ return self.ssl_context.cert_reqs
+
+ @property
+ def ca_certs(self):
+ return self.ssl_context.ca_certs
+
+ @property
+ def check_hostname(self):
+ return self.ssl_context.check_hostname
+
+
+class RedisSSLContext:
+ __slots__ = (
+ "keyfile",
+ "certfile",
+ "cert_reqs",
+ "ca_certs",
+ "context",
+ "check_hostname",
+ )
+
+ def __init__(
+ self,
+ keyfile: Optional[str] = None,
+ certfile: Optional[str] = None,
+ cert_reqs: Optional[str] = None,
+ ca_certs: Optional[str] = None,
+ check_hostname: bool = False,
+ ):
+ self.keyfile = keyfile
+ self.certfile = certfile
+ if cert_reqs is None:
+ self.cert_reqs = ssl.CERT_NONE
+ elif isinstance(cert_reqs, str):
+ CERT_REQS = {
+ "none": ssl.CERT_NONE,
+ "optional": ssl.CERT_OPTIONAL,
+ "required": ssl.CERT_REQUIRED,
+ }
+ if cert_reqs not in CERT_REQS:
+ raise RedisError(
+ f"Invalid SSL Certificate Requirements Flag: {cert_reqs}"
+ )
+ self.cert_reqs = CERT_REQS[cert_reqs]
+ self.ca_certs = ca_certs
+ self.check_hostname = check_hostname
+ self.context: Optional[ssl.SSLContext] = None
+
+ def get(self) -> ssl.SSLContext:
+ if not self.context:
+ context = ssl.create_default_context()
+ context.check_hostname = self.check_hostname
+ context.verify_mode = self.cert_reqs
+ if self.certfile and self.keyfile:
+ context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
+ if self.ca_certs:
+ context.load_verify_locations(self.ca_certs)
+ self.context = context
+ return self.context
+
+
+class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init]
+ def __init__(
+ self,
+ *,
+ path: str = "",
+ db: Union[str, int] = 0,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
+ socket_timeout: Optional[float] = None,
+ socket_connect_timeout: Optional[float] = None,
+ encoding: str = "utf-8",
+ encoding_errors: str = "strict",
+ decode_responses: bool = False,
+ retry_on_timeout: bool = False,
+ parser_class: Type[BaseParser] = DefaultParser,
+ socket_read_size: int = 65536,
+ health_check_interval: float = 0.0,
+ client_name: Optional[str] = None,
+ retry: Optional[Retry] = None,
+ ):
+ """
+ Initialize a new UnixDomainSocketConnection.
+ To specify a retry policy, first set `retry_on_timeout` to `True`
+ then set `retry` to a valid `Retry` object
+ """
+ self.pid = os.getpid()
+ self.path = path
+ self.db = db
+ self.username = username
+ self.client_name = client_name
+ self.password = password
+ self.socket_timeout = socket_timeout
+ self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None
+ self.retry_on_timeout = retry_on_timeout
+ if retry_on_timeout:
+ if retry is None:
+ self.retry = Retry(NoBackoff(), 1)
+ else:
+ # deep-copy the Retry object as it is mutable
+ self.retry = copy.deepcopy(retry)
+ else:
+ self.retry = Retry(NoBackoff(), 0)
+ self.health_check_interval = health_check_interval
+ self.next_health_check = -1
+ self.encoder = Encoder(encoding, encoding_errors, decode_responses)
+ self._sock = None
+ self._reader = None
+ self._writer = None
+ self._socket_read_size = socket_read_size
+ self.set_parser(parser_class)
+ self._connect_callbacks = []
+ self._buffer_cutoff = 6000
+ self._lock = asyncio.Lock()
+
+ def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
+ pieces = [
+ ("path", self.path),
+ ("db", self.db),
+ ]
+ if self.client_name:
+ pieces.append(("client_name", self.client_name))
+ return pieces
+
+ async def _connect(self):
+ async with async_timeout.timeout(self.socket_connect_timeout):
+ reader, writer = await asyncio.open_unix_connection(path=self.path)
+ self._reader = reader
+ self._writer = writer
+ await self.on_connect()
+
+ def _error_message(self, exception):
+ # args for socket.error can either be (errno, "message")
+ # or just "message"
+ if len(exception.args) == 1:
+ return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
+ else:
+ return (
+ f"Error {exception.args[0]} connecting to unix socket: "
+ f"{self.path}. {exception.args[1]}."
+ )
+
+
+FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO")
+
+
+def to_bool(value) -> Optional[bool]:
+ if value is None or value == "":
+ return None
+ if isinstance(value, str) and value.upper() in FALSE_STRINGS:
+ return False
+ return bool(value)
+
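+ # Illustrative behaviour of to_bool (editor's note): querystring booleans
+ # are matched case-insensitively against FALSE_STRINGS:
+ #
+ # >>> to_bool("0"), to_bool("false"), to_bool("No")
+ # (False, False, False)
+ # >>> to_bool("1"), to_bool("True"), to_bool("yes")
+ # (True, True, True)
+ # >>> to_bool("") is None and to_bool(None) is None
+ # True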
+
+URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType(
+ {
+ "db": int,
+ "socket_timeout": float,
+ "socket_connect_timeout": float,
+ "socket_keepalive": to_bool,
+ "retry_on_timeout": to_bool,
+ "max_connections": int,
+ "health_check_interval": int,
+ "ssl_check_hostname": to_bool,
+ }
+)
+
+
+class ConnectKwargs(TypedDict, total=False):
+ username: str
+ password: str
+ connection_class: Type[Connection]
+ host: str
+ port: int
+ db: int
+ path: str
+
+
+def parse_url(url: str) -> ConnectKwargs:
+ parsed: ParseResult = urlparse(url)
+ kwargs: ConnectKwargs = {}
+
+ for name, value_list in parse_qs(parsed.query).items():
+ if value_list:
+ value = unquote(value_list[0])
+ parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
+ if parser:
+ try:
+ # We can't type this.
+ kwargs[name] = parser(value) # type: ignore[misc]
+ except (TypeError, ValueError):
+ raise ValueError(f"Invalid value for `{name}` in connection URL.")
+ else:
+ kwargs[name] = value # type: ignore[misc]
+
+ if parsed.username:
+ kwargs["username"] = unquote(parsed.username)
+ if parsed.password:
+ kwargs["password"] = unquote(parsed.password)
+
+ # We only support redis://, rediss:// and unix:// schemes.
+ if parsed.scheme == "unix":
+ if parsed.path:
+ kwargs["path"] = unquote(parsed.path)
+ kwargs["connection_class"] = UnixDomainSocketConnection
+
+ elif parsed.scheme in ("redis", "rediss"):
+ if parsed.hostname:
+ kwargs["host"] = unquote(parsed.hostname)
+ if parsed.port:
+ kwargs["port"] = int(parsed.port)
+
+ # If there's a path argument, use it as the db argument if a
+ # querystring value wasn't specified
+ if parsed.path and "db" not in kwargs:
+ try:
+ kwargs["db"] = int(unquote(parsed.path).replace("/", ""))
+ except (AttributeError, ValueError):
+ pass
+
+ if parsed.scheme == "rediss":
+ kwargs["connection_class"] = SSLConnection
+ else:
+ valid_schemes = "redis://, rediss://, unix://"
+ raise ValueError(
+ f"Redis URL must specify one of the following schemes ({valid_schemes})"
+ )
+
+ return kwargs
+
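+ # Worked example of parse_url (editor's illustration; the URL values are
+ # hypothetical):
+ #
+ # >>> parse_url("redis://user:secret@example.com:6380/2?socket_timeout=5")
+ # {'socket_timeout': 5.0, 'username': 'user', 'password': 'secret',
+ #  'host': 'example.com', 'port': 6380, 'db': 2}
+ # >>> parse_url("unix:///tmp/redis.sock?db=1")["connection_class"]
+ # <class 'redis.asyncio.connection.UnixDomainSocketConnection'>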
+
+_CP = TypeVar("_CP", bound="ConnectionPool")
+
+
+class ConnectionPool:
+ """
+ Create a connection pool. If ``max_connections`` is set, then this
+ object raises :py:class:`~redis.ConnectionError` when the pool's
+ limit is reached.
+
+ By default, TCP connections are created unless ``connection_class``
+ is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
+ unix sockets.
+
+ Any additional keyword arguments are passed to the constructor of
+ ``connection_class``.
+ """
+
+ @classmethod
+ def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP:
+ """
+ Return a connection pool configured from the given URL.
+
+ For example::
+
+ redis://[[username]:[password]]@localhost:6379/0
+ rediss://[[username]:[password]]@localhost:6379/0
+ unix://[[username]:[password]]@/path/to/socket.sock?db=0
+
+ Three URL schemes are supported:
+
+ - ``redis://`` creates a TCP socket connection. See more at:
+ <https://www.iana.org/assignments/uri-schemes/prov/redis>
+ - ``rediss://`` creates an SSL-wrapped TCP socket connection. See more at:
+ <https://www.iana.org/assignments/uri-schemes/prov/rediss>
+ - ``unix://`` creates a Unix Domain Socket connection.
+
+ The username, password, hostname, path and all querystring values
+ are passed through urllib.parse.unquote in order to replace any
+ percent-encoded values with their corresponding characters.
+
+ There are several ways to specify a database number. The first value
+ found will be used:
+ 1. A ``db`` querystring option, e.g. redis://localhost?db=0
+ 2. If using the redis:// or rediss:// schemes, the path argument
+ of the url, e.g. redis://localhost/0
+ 3. A ``db`` keyword argument to this function.
+
+ If none of these options are specified, the default db=0 is used.
+
+ All querystring options are cast to their appropriate Python types.
+ Boolean arguments can be specified with string values "True"/"False"
+ or "Yes"/"No". Values that cannot be properly cast cause a
+ ``ValueError`` to be raised. Once parsed, the querystring arguments
+ and keyword arguments are passed to the ``ConnectionPool``'s
+ class initializer. In the case of conflicting arguments, querystring
+ arguments always win.
+ """
+ url_options = parse_url(url)
+ kwargs.update(url_options)
+ return cls(**kwargs)
+
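+ # Example usage (editor's illustration, not part of the original commit):
+ #
+ # >>> pool = ConnectionPool.from_url("redis://localhost:6379/0?socket_timeout=5")
+ # >>> pool.connection_kwargs["db"], pool.connection_kwargs["socket_timeout"]
+ # (0, 5.0)
+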
+ def __init__(
+ self,
+ connection_class: Type[Connection] = Connection,
+ max_connections: Optional[int] = None,
+ **connection_kwargs,
+ ):
+ max_connections = max_connections or 2 ** 31
+ if not isinstance(max_connections, int) or max_connections < 0:
+ raise ValueError('"max_connections" must be a positive integer')
+
+ self.connection_class = connection_class
+ self.connection_kwargs = connection_kwargs
+ self.max_connections = max_connections
+
+ # a lock to protect the critical section in _checkpid().
+ # this lock is acquired when the process id changes, such as
+ # after a fork. during this time, multiple threads in the child
+ # process could attempt to acquire this lock. the first thread
+ # to acquire the lock will reset the data structures and lock
+ # object of this pool. subsequent threads acquiring this lock
+ # will notice the first thread already did the work and simply
+ # release the lock.
+ self._fork_lock = threading.Lock()
+ self._lock = asyncio.Lock()
+ self._created_connections: int
+ self._available_connections: List[Connection]
+ self._in_use_connections: Set[Connection]
+ self.reset() # lgtm [py/init-calls-subclass]
+ self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}"
+ f"<{self.connection_class(**self.connection_kwargs)!r}>"
+ )
+
+ def reset(self):
+ self._lock = asyncio.Lock()
+ self._created_connections = 0
+ self._available_connections = []
+ self._in_use_connections = set()
+
+ # this must be the last operation in this method. while reset() is
+ # called when holding _fork_lock, other threads in this process
+ # can call _checkpid() which compares self.pid and os.getpid() without
+ # holding any lock (for performance reasons). keeping this assignment
+ # as the last operation ensures that those other threads will also
+ # notice a pid difference and block waiting for the first thread to
+ # release _fork_lock. when each of these threads eventually acquire
+ # _fork_lock, they will notice that another thread already called
+ # reset() and they will immediately release _fork_lock and continue on.
+ self.pid = os.getpid()
+
+ def _checkpid(self):
+ # _checkpid() attempts to keep ConnectionPool fork-safe on modern
+ # systems. this is called by all ConnectionPool methods that
+ # manipulate the pool's state such as get_connection() and release().
+ #
+ # _checkpid() determines whether the process has forked by comparing
+ # the current process id to the process id saved on the ConnectionPool
+ # instance. if these values are the same, _checkpid() simply returns.
+ #
+ # when the process ids differ, _checkpid() assumes that the process
+ # has forked and that we're now running in the child process. the child
+ # process cannot use the parent's file descriptors (e.g., sockets).
+ # therefore, when _checkpid() sees the process id change, it calls
+ # reset() in order to reinitialize the child's ConnectionPool. this
+ # will cause the child to make all new connection objects.
+ #
+ # _checkpid() is protected by self._fork_lock to ensure that multiple
+ # threads in the child process do not call reset() multiple times.
+ #
+ # there is an extremely small chance this could fail in the following
+ # scenario:
+ # 1. process A calls _checkpid() for the first time and acquires
+ # self._fork_lock.
+ # 2. while holding self._fork_lock, process A forks (the fork()
+ # could happen in a different thread owned by process A)
+ # 3. process B (the forked child process) inherits the
+ # ConnectionPool's state from the parent. that state includes
+ # a locked _fork_lock. process B will not be notified when
+ # process A releases the _fork_lock and will thus never be
+ # able to acquire the _fork_lock.
+ #
+ # to mitigate this possible deadlock, _checkpid() will only wait 5
+ # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
+ # that time it is assumed that the child is deadlocked and a
+ # redis.ChildDeadlockedError error is raised.
+ if self.pid != os.getpid():
+ acquired = self._fork_lock.acquire(timeout=5)
+ if not acquired:
+ raise ChildDeadlockedError
+ # reset() the instance for the new process if another thread
+ # hasn't already done so
+ try:
+ if self.pid != os.getpid():
+ self.reset()
+ finally:
+ self._fork_lock.release()
+
+ async def get_connection(self, command_name, *keys, **options):
+ """Get a connection from the pool"""
+ self._checkpid()
+ async with self._lock:
+ try:
+ connection = self._available_connections.pop()
+ except IndexError:
+ connection = self.make_connection()
+ self._in_use_connections.add(connection)
+
+ try:
+ # ensure this connection is connected to Redis
+ await connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ try:
+ if await connection.can_read():
+ raise ConnectionError("Connection has data") from None
+ except ConnectionError:
+ await connection.disconnect()
+ await connection.connect()
+ if await connection.can_read():
+ raise ConnectionError("Connection not ready") from None
+ except BaseException:
+ # release the connection back to the pool so that we don't
+ # leak it
+ await self.release(connection)
+ raise
+
+ return connection
+
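+ # Typical checkout/checkin cycle (editor's sketch; the command and key
+ # are hypothetical):
+ #
+ # >>> pool = ConnectionPool(max_connections=10)
+ # >>> conn = await pool.get_connection("GET")
+ # >>> try:
+ # ...     await conn.send_command("GET", "foo")
+ # ...     value = await conn.read_response()
+ # ... finally:
+ # ...     await pool.release(conn)
+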
+ def get_encoder(self):
+ """Return an encoder based on encoding settings"""
+ kwargs = self.connection_kwargs
+ return self.encoder_class(
+ encoding=kwargs.get("encoding", "utf-8"),
+ encoding_errors=kwargs.get("encoding_errors", "strict"),
+ decode_responses=kwargs.get("decode_responses", False),
+ )
+
+ def make_connection(self):
+ """Create a new connection"""
+ if self._created_connections >= self.max_connections:
+ raise ConnectionError("Too many connections")
+ self._created_connections += 1
+ return self.connection_class(**self.connection_kwargs)
+
+ async def release(self, connection: Connection):
+ """Releases the connection back to the pool"""
+ self._checkpid()
+ async with self._lock:
+ try:
+ self._in_use_connections.remove(connection)
+ except KeyError:
+ # Gracefully fail when a connection is returned to this pool
+ # that the pool doesn't actually own
+ pass
+
+ if self.owns_connection(connection):
+ self._available_connections.append(connection)
+ else:
+ # pool doesn't own this connection. do not add it back
+ # to the pool and decrement the count so that another
+ # connection can take its place if needed
+ self._created_connections -= 1
+ await connection.disconnect()
+ return
+
+ def owns_connection(self, connection: Connection):
+ return connection.pid == self.pid
+
+ async def disconnect(self, inuse_connections: bool = True):
+ """
+ Disconnects connections in the pool
+
+ If ``inuse_connections`` is True, disconnect connections that are
+ currently in use, potentially by other tasks. Otherwise only disconnect
+ connections that are idle in the pool.
+ """
+ self._checkpid()
+ async with self._lock:
+ if inuse_connections:
+ connections: Iterable[Connection] = chain(
+ self._available_connections, self._in_use_connections
+ )
+ else:
+ connections = self._available_connections
+ resp = await asyncio.gather(
+ *(connection.disconnect() for connection in connections),
+ return_exceptions=True,
+ )
+ exc = next((r for r in resp if isinstance(r, BaseException)), None)
+ if exc:
+ raise exc
+
+
+class BlockingConnectionPool(ConnectionPool):
+ """
+ Thread-safe blocking connection pool::
+
+ >>> from redis.client import Redis
+ >>> client = Redis(connection_pool=BlockingConnectionPool())
+
+ It performs the same function as the default
+ :py:class:`~redis.ConnectionPool` implementation, in that
+ it maintains a pool of reusable connections that can be shared by
+ multiple redis clients (safely across threads if required).
+
+ The difference is that, in the event that a client tries to get a
+ connection from the pool when all of the connections are in use, rather than
+ raising a :py:class:`~redis.ConnectionError` (as the default
+ :py:class:`~redis.ConnectionPool` implementation does), it
+ makes the client wait ("blocks") for a specified number of seconds until
+ a connection becomes available.
+
+ Use ``max_connections`` to increase / decrease the pool size::
+
+ >>> pool = BlockingConnectionPool(max_connections=10)
+
+ Use ``timeout`` to tell it either how many seconds to wait for a connection
+ to become available, or to block forever:
+
+ >>> # Block forever.
+ >>> pool = BlockingConnectionPool(timeout=None)
+
+ >>> # Raise a ``ConnectionError`` after five seconds if a connection is
+ >>> # not available.
+ >>> pool = BlockingConnectionPool(timeout=5)
+ """
+
+ def __init__(
+ self,
+ max_connections: int = 50,
+ timeout: Optional[int] = 20,
+ connection_class: Type[Connection] = Connection,
+ queue_class: Type[asyncio.Queue] = asyncio.LifoQueue,
+ **connection_kwargs,
+ ):
+
+ self.queue_class = queue_class
+ self.timeout = timeout
+ self._connections: List[Connection]
+ super().__init__(
+ connection_class=connection_class,
+ max_connections=max_connections,
+ **connection_kwargs,
+ )
+
+ def reset(self):
+ # Create and fill up a thread safe queue with ``None`` values.
+ self.pool = self.queue_class(self.max_connections)
+ while True:
+ try:
+ self.pool.put_nowait(None)
+ except asyncio.QueueFull:
+ break
+
+ # Keep a list of actual connection instances so that we can
+ # disconnect them later.
+ self._connections = []
+
+ # this must be the last operation in this method. while reset() is
+ # called when holding _fork_lock, other threads in this process
+ # can call _checkpid() which compares self.pid and os.getpid() without
+ # holding any lock (for performance reasons). keeping this assignment
+ # as the last operation ensures that those other threads will also
+ # notice a pid difference and block waiting for the first thread to
+ # release _fork_lock. when each of these threads eventually acquire
+ # _fork_lock, they will notice that another thread already called
+ # reset() and they will immediately release _fork_lock and continue on.
+ self.pid = os.getpid()
+
+ def make_connection(self):
+ """Make a fresh connection."""
+ connection = self.connection_class(**self.connection_kwargs)
+ self._connections.append(connection)
+ return connection
+
+ async def get_connection(self, command_name, *keys, **options):
+ """
+ Get a connection, blocking for ``self.timeout`` until a connection
+ is available from the pool.
+
+ If the connection returned is ``None`` then creates a new connection.
+ Because we use a last-in first-out queue, the existing connections
+ (having been returned to the pool after the initial ``None`` values
+ were added) will be returned before ``None`` values. This means we only
+ create new connections when we need to, i.e.: the actual number of
+ connections will only increase in response to demand.
+ """
+ # Make sure we haven't changed process.
+ self._checkpid()
+
+ # Try and get a connection from the pool. If one isn't available within
+ # self.timeout then raise a ``ConnectionError``.
+ connection = None
+ try:
+ async with async_timeout.timeout(self.timeout):
+ connection = await self.pool.get()
+ except (asyncio.QueueEmpty, asyncio.TimeoutError):
+ # Note that this is not caught by the redis client and will be
+ # raised unless handled by application code. If you want never to
+ # see this, pass timeout=None so the pool blocks until a connection
+ # becomes available.
+ raise ConnectionError("No connection available.")
+
+ # If the ``connection`` is actually ``None`` then that's a cue to make
+ # a new connection to add to the pool.
+ if connection is None:
+ connection = self.make_connection()
+
+ try:
+ # ensure this connection is connected to Redis
+ await connection.connect()
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
+ # pool before all data has been read or the socket has been
+ # closed. either way, reconnect and verify everything is good.
+ try:
+ if await connection.can_read():
+ raise ConnectionError("Connection has data") from None
+ except ConnectionError:
+ await connection.disconnect()
+ await connection.connect()
+ if await connection.can_read():
+ raise ConnectionError("Connection not ready") from None
+ except BaseException:
+ # release the connection back to the pool so that we don't leak it
+ await self.release(connection)
+ raise
+
+ return connection
+
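+ # Illustrative queue states for the LIFO design described above (editor's
+ # note; assumes max_connections=3):
+ #
+ #   reset():        pool = [None, None, None]
+ #   1st checkout:   get() -> None -> make_connection() -> pool = [None, None]
+ #   release(conn):  pool = [None, None, conn]
+ #   2nd checkout:   get() -> conn (LIFO), so real connections are reused
+ #                   before new ones are created
+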
+ async def release(self, connection: Connection):
+ """Releases the connection back to the pool."""
+ # Make sure we haven't changed process.
+ self._checkpid()
+ if not self.owns_connection(connection):
+ # pool doesn't own this connection. do not add it back
+ # to the pool. instead add a None value which is a placeholder
+ # that will cause the pool to recreate the connection if
+ # it's needed.
+ await connection.disconnect()
+ self.pool.put_nowait(None)
+ return
+
+ # Put the connection back into the pool.
+ try:
+ self.pool.put_nowait(connection)
+ except asyncio.QueueFull:
+ # perhaps the pool has been reset() after a fork? regardless,
+ # we don't want this connection
+ pass
+
+ async def disconnect(self, inuse_connections: bool = True):
+ """Disconnects all connections in the pool."""
+ self._checkpid()
+ async with self._lock:
+ resp = await asyncio.gather(
+ *(connection.disconnect() for connection in self._connections),
+ return_exceptions=True,
+ )
+ exc = next((r for r in resp if isinstance(r, BaseException)), None)
+ if exc:
+ raise exc