summaryrefslogtreecommitdiff
path: root/redis/asyncio/connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r--redis/asyncio/connection.py392
1 files changed, 26 insertions, 366 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 057067a..d9c9583 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -38,26 +38,23 @@ from redis.credentials import CredentialProvider, UsernamePasswordCredentialProv
from redis.exceptions import (
AuthenticationError,
AuthenticationWrongNumberOfArgsError,
- BusyLoadingError,
ChildDeadlockedError,
ConnectionError,
DataError,
- ExecAbortError,
- InvalidResponse,
- ModuleError,
- NoPermissionError,
- NoScriptError,
- ReadOnlyError,
RedisError,
ResponseError,
TimeoutError,
)
-from redis.typing import EncodableT, EncodedT
+from redis.typing import EncodableT
from redis.utils import HIREDIS_AVAILABLE, str_if_bytes
-hiredis = None
-if HIREDIS_AVAILABLE:
- import hiredis
+from ..parsers import (
+ BaseParser,
+ Encoder,
+ _AsyncHiredisParser,
+ _AsyncRESP2Parser,
+ _AsyncRESP3Parser,
+)
SYM_STAR = b"*"
SYM_DOLLAR = b"$"
@@ -65,371 +62,19 @@ SYM_CRLF = b"\r\n"
SYM_LF = b"\n"
SYM_EMPTY = b""
-SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
-
class _Sentinel(enum.Enum):
sentinel = object()
SENTINEL = _Sentinel.sentinel
-MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs."
-NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name"
-MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible."
-MODULE_EXPORTS_DATA_TYPES_ERROR = (
- "Error unloading module: the module "
- "exports one or more module-side data "
- "types, can't unload"
-)
-# user send an AUTH cmd to a server without authorization configured
-NO_AUTH_SET_ERROR = {
- # Redis >= 6.0
- "AUTH <password> called without any password "
- "configured for the default user. Are you sure "
- "your configuration is correct?": AuthenticationError,
- # Redis < 6.0
- "Client sent AUTH, but no password is set": AuthenticationError,
-}
-
-
-class _HiredisReaderArgs(TypedDict, total=False):
- protocolError: Callable[[str], Exception]
- replyError: Callable[[str], Exception]
- encoding: Optional[str]
- errors: Optional[str]
-
-
-class Encoder:
- """Encode strings to bytes-like and decode bytes-like to strings"""
-
- __slots__ = "encoding", "encoding_errors", "decode_responses"
-
- def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool):
- self.encoding = encoding
- self.encoding_errors = encoding_errors
- self.decode_responses = decode_responses
-
- def encode(self, value: EncodableT) -> EncodedT:
- """Return a bytestring or bytes-like representation of the value"""
- if isinstance(value, str):
- return value.encode(self.encoding, self.encoding_errors)
- if isinstance(value, (bytes, memoryview)):
- return value
- if isinstance(value, (int, float)):
- if isinstance(value, bool):
- # special case bool since it is a subclass of int
- raise DataError(
- "Invalid input of type: 'bool'. "
- "Convert to a bytes, string, int or float first."
- )
- return repr(value).encode()
- # a value we don't know how to deal with. throw an error
- typename = value.__class__.__name__
- raise DataError(
- f"Invalid input of type: {typename!r}. "
- "Convert to a bytes, string, int or float first."
- )
-
- def decode(self, value: EncodableT, force=False) -> EncodableT:
- """Return a unicode string from the bytes-like representation"""
- if self.decode_responses or force:
- if isinstance(value, bytes):
- return value.decode(self.encoding, self.encoding_errors)
- if isinstance(value, memoryview):
- return value.tobytes().decode(self.encoding, self.encoding_errors)
- return value
-
-
-ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]]
-
-
-class BaseParser:
- """Plain Python parsing class"""
-
- __slots__ = "_stream", "_read_size", "_connected"
-
- EXCEPTION_CLASSES: ExceptionMappingT = {
- "ERR": {
- "max number of clients reached": ConnectionError,
- "Client sent AUTH, but no password is set": AuthenticationError,
- "invalid password": AuthenticationError,
- # some Redis server versions report invalid command syntax
- # in lowercase
- "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501
- # some Redis server versions report invalid command syntax
- # in uppercase
- "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501
- MODULE_LOAD_ERROR: ModuleError,
- MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError,
- NO_SUCH_MODULE_ERROR: ModuleError,
- MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError,
- **NO_AUTH_SET_ERROR,
- },
- "WRONGPASS": AuthenticationError,
- "EXECABORT": ExecAbortError,
- "LOADING": BusyLoadingError,
- "NOSCRIPT": NoScriptError,
- "READONLY": ReadOnlyError,
- "NOAUTH": AuthenticationError,
- "NOPERM": NoPermissionError,
- }
-
- def __init__(self, socket_read_size: int):
- self._stream: Optional[asyncio.StreamReader] = None
- self._read_size = socket_read_size
- self._connected = False
-
- def __del__(self):
- try:
- self.on_disconnect()
- except Exception:
- pass
-
- def parse_error(self, response: str) -> ResponseError:
- """Parse an error response"""
- error_code = response.split(" ")[0]
- if error_code in self.EXCEPTION_CLASSES:
- response = response[len(error_code) + 1 :]
- exception_class = self.EXCEPTION_CLASSES[error_code]
- if isinstance(exception_class, dict):
- exception_class = exception_class.get(response, ResponseError)
- return exception_class(response)
- return ResponseError(response)
-
- def on_disconnect(self):
- raise NotImplementedError()
-
- def on_connect(self, connection: "Connection"):
- raise NotImplementedError()
-
- async def can_read_destructive(self) -> bool:
- raise NotImplementedError()
-
- async def read_response(
- self, disable_decoding: bool = False
- ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]:
- raise NotImplementedError()
-
-
-class PythonParser(BaseParser):
- """Plain Python parsing class"""
-
- __slots__ = ("encoder", "_buffer", "_pos", "_chunks")
-
- def __init__(self, socket_read_size: int):
- super().__init__(socket_read_size)
- self.encoder: Optional[Encoder] = None
- self._buffer = b""
- self._chunks = []
- self._pos = 0
-
- def _clear(self):
- self._buffer = b""
- self._chunks.clear()
-
- def on_connect(self, connection: "Connection"):
- """Called when the stream connects"""
- self._stream = connection._reader
- if self._stream is None:
- raise RedisError("Buffer is closed.")
- self.encoder = connection.encoder
- self._clear()
- self._connected = True
-
- def on_disconnect(self):
- """Called when the stream disconnects"""
- self._connected = False
-
- async def can_read_destructive(self) -> bool:
- if not self._connected:
- raise RedisError("Buffer is closed.")
- if self._buffer:
- return True
- try:
- async with async_timeout(0):
- return await self._stream.read(1)
- except asyncio.TimeoutError:
- return False
-
- async def read_response(self, disable_decoding: bool = False):
- if not self._connected:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- if self._chunks:
- # augment parsing buffer with previously read data
- self._buffer += b"".join(self._chunks)
- self._chunks.clear()
- self._pos = 0
- response = await self._read_response(disable_decoding=disable_decoding)
- # Successfully parsing a response allows us to clear our parsing buffer
- self._clear()
- return response
- async def _read_response(
- self, disable_decoding: bool = False
- ) -> Union[EncodableT, ResponseError, None]:
- raw = await self._readline()
- response: Any
- byte, response = raw[:1], raw[1:]
-
- # server returned an error
- if byte == b"-":
- response = response.decode("utf-8", errors="replace")
- error = self.parse_error(response)
- # if the error is a ConnectionError, raise immediately so the user
- # is notified
- if isinstance(error, ConnectionError):
- self._clear() # Successful parse
- raise error
- # otherwise, we're dealing with a ResponseError that might belong
- # inside a pipeline response. the connection's read_response()
- # and/or the pipeline's execute() will raise this error if
- # necessary, so just return the exception instance here.
- return error
- # single value
- elif byte == b"+":
- pass
- # int value
- elif byte == b":":
- return int(response)
- # bulk response
- elif byte == b"$" and response == b"-1":
- return None
- elif byte == b"$":
- response = await self._read(int(response))
- # multi-bulk response
- elif byte == b"*" and response == b"-1":
- return None
- elif byte == b"*":
- response = [
- (await self._read_response(disable_decoding))
- for _ in range(int(response)) # noqa
- ]
- else:
- raise InvalidResponse(f"Protocol Error: {raw!r}")
-
- if disable_decoding is False:
- response = self.encoder.decode(response)
- return response
-
- async def _read(self, length: int) -> bytes:
- """
- Read `length` bytes of data. These are assumed to be followed
- by a '\r\n' terminator which is subsequently discarded.
- """
- want = length + 2
- end = self._pos + want
- if len(self._buffer) >= end:
- result = self._buffer[self._pos : end - 2]
- else:
- tail = self._buffer[self._pos :]
- try:
- data = await self._stream.readexactly(want - len(tail))
- except asyncio.IncompleteReadError as error:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
- result = (tail + data)[:-2]
- self._chunks.append(data)
- self._pos += want
- return result
-
- async def _readline(self) -> bytes:
- """
- read an unknown number of bytes up to the next '\r\n'
- line separator, which is discarded.
- """
- found = self._buffer.find(b"\r\n", self._pos)
- if found >= 0:
- result = self._buffer[self._pos : found]
- else:
- tail = self._buffer[self._pos :]
- data = await self._stream.readline()
- if not data.endswith(b"\r\n"):
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- result = (tail + data)[:-2]
- self._chunks.append(data)
- self._pos += len(result) + 2
- return result
-
-
-class HiredisParser(BaseParser):
- """Parser class for connections using Hiredis"""
-
- __slots__ = ("_reader",)
-
- def __init__(self, socket_read_size: int):
- if not HIREDIS_AVAILABLE:
- raise RedisError("Hiredis is not available.")
- super().__init__(socket_read_size=socket_read_size)
- self._reader: Optional[hiredis.Reader] = None
-
- def on_connect(self, connection: "Connection"):
- self._stream = connection._reader
- kwargs: _HiredisReaderArgs = {
- "protocolError": InvalidResponse,
- "replyError": self.parse_error,
- }
- if connection.encoder.decode_responses:
- kwargs["encoding"] = connection.encoder.encoding
- kwargs["errors"] = connection.encoder.encoding_errors
-
- self._reader = hiredis.Reader(**kwargs)
- self._connected = True
-
- def on_disconnect(self):
- self._connected = False
- async def can_read_destructive(self):
- if not self._connected:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
- if self._reader.gets():
- return True
- try:
- async with async_timeout(0):
- return await self.read_from_socket()
- except asyncio.TimeoutError:
- return False
-
- async def read_from_socket(self):
- buffer = await self._stream.read(self._read_size)
- if not buffer or not isinstance(buffer, bytes):
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
- self._reader.feed(buffer)
- # data was read from the socket and added to the buffer.
- # return True to indicate that data was read.
- return True
-
- async def read_response(
- self, disable_decoding: bool = False
- ) -> Union[EncodableT, List[EncodableT]]:
- # If `on_disconnect()` has been called, prohibit any more reads
- # even if they could happen because data might be present.
- # We still allow reads in progress to finish
- if not self._connected:
- raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
-
- response = self._reader.gets()
- while response is False:
- await self.read_from_socket()
- response = self._reader.gets()
-
- # if the response is a ConnectionError or the response is a list and
- # the first item is a ConnectionError, raise it as something bad
- # happened
- if isinstance(response, ConnectionError):
- raise response
- elif (
- isinstance(response, list)
- and response
- and isinstance(response[0], ConnectionError)
- ):
- raise response[0]
- return response
-
-
-DefaultParser: Type[Union[PythonParser, HiredisParser]]
+DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]]
if HIREDIS_AVAILABLE:
- DefaultParser = HiredisParser
+ DefaultParser = _AsyncHiredisParser
else:
- DefaultParser = PythonParser
+ DefaultParser = _AsyncRESP2Parser
class ConnectCallbackProtocol(Protocol):
@@ -470,6 +115,7 @@ class Connection:
"last_active_at",
"encoder",
"ssl_context",
+ "protocol",
"_reader",
"_writer",
"_parser",
@@ -506,6 +152,7 @@ class Connection:
redis_connect_func: Optional[ConnectCallbackT] = None,
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
+ protocol: Optional[int] = 2,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -556,6 +203,7 @@ class Connection:
self.set_parser(parser_class)
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
self._buffer_cutoff = 6000
+ self.protocol = protocol
def __repr__(self):
repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces()))
@@ -710,6 +358,18 @@ class Connection:
if str_if_bytes(auth_response) != "OK":
raise AuthenticationError("Invalid Username or Password")
+ # if resp version is specified, switch to it
+ if self.protocol != 2:
+ if isinstance(self._parser, _AsyncRESP2Parser):
+ self.set_parser(_AsyncRESP3Parser)
+ self._parser.on_connect(self)
+ await self.send_command("HELLO", self.protocol)
+ response = await self.read_response()
+ if response.get(b"proto") != int(self.protocol) and response.get(
+ "proto"
+ ) != int(self.protocol):
+ raise ConnectionError("Invalid RESP version")
+
# if a client_name is given, set it
if self.client_name:
await self.send_command("CLIENT", "SETNAME", self.client_name)