diff options
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 392 |
1 files changed, 26 insertions, 366 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 057067a..d9c9583 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -38,26 +38,23 @@ from redis.credentials import CredentialProvider, UsernamePasswordCredentialProv from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.typing import EncodableT, EncodedT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -hiredis = None -if HIREDIS_AVAILABLE: - import hiredis +from ..parsers import ( + BaseParser, + Encoder, + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, +) SYM_STAR = b"*" SYM_DOLLAR = b"$" @@ -65,371 +62,19 @@ SYM_CRLF = b"\r\n" SYM_LF = b"\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - class _Sentinel(enum.Enum): sentinel = object() SENTINEL = _Sentinel.sentinel -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH <password> called without any password " - "configured for the default user. Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class _HiredisReaderArgs(TypedDict, total=False): - protocolError: Callable[[str], Exception] - replyError: Callable[[str], Exception] - encoding: Optional[str] - errors: Optional[str] - - -class Encoder: - """Encode strings to bytes-like and decode bytes-like to strings""" - - __slots__ = "encoding", "encoding_errors", "decode_responses" - - def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value: EncodableT) -> EncodedT: - """Return a bytestring or bytes-like representation of the value""" - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - if isinstance(value, (bytes, memoryview)): - return value - if isinstance(value, (int, float)): - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) - return repr(value).encode() - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - - def decode(self, value: EncodableT, force=False) -> EncodableT: - """Return a unicode string from the bytes-like representation""" - if self.decode_responses or force: - if isinstance(value, bytes): - return value.decode(self.encoding, self.encoding_errors) - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) - return value - - -ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] - - -class BaseParser: - """Plain Python parsing class""" - - __slots__ = "_stream", "_read_size", "_connected" - - EXCEPTION_CLASSES: ExceptionMappingT = { - "ERR": { - "max number of clients reached": ConnectionError, - "Client sent AUTH, but no password is set": AuthenticationError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def __init__(self, socket_read_size: int): - self._stream: Optional[asyncio.StreamReader] = None - self._read_size = socket_read_size - self._connected = False - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def parse_error(self, response: str) -> ResponseError: - """Parse an error response""" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - def on_disconnect(self): - raise NotImplementedError() - - def on_connect(self, connection: "Connection"): - raise NotImplementedError() - - async def can_read_destructive(self) -> bool: - raise NotImplementedError() - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: - raise NotImplementedError() - - -class PythonParser(BaseParser): - """Plain Python parsing class""" - - __slots__ = ("encoder", "_buffer", "_pos", "_chunks") - - def __init__(self, socket_read_size: int): - super().__init__(socket_read_size) - self.encoder: Optional[Encoder] = None - self._buffer = b"" - self._chunks = [] - self._pos = 0 - - def _clear(self): - self._buffer = b"" - self._chunks.clear() - - def on_connect(self, connection: "Connection"): - """Called when the stream connects""" - self._stream = connection._reader - if self._stream is None: - raise RedisError("Buffer is closed.") - self.encoder = connection.encoder - self._clear() - self._connected = True - - def on_disconnect(self): - """Called when the stream disconnects""" - self._connected = False - - async def can_read_destructive(self) -> bool: - if not self._connected: - raise RedisError("Buffer is closed.") - if self._buffer: - return True - try: - async with async_timeout(0): - return await self._stream.read(1) - except asyncio.TimeoutError: - return False - - async def read_response(self, disable_decoding: bool = False): - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._chunks: - # augment parsing buffer with previously read data - self._buffer += b"".join(self._chunks) - self._chunks.clear() - self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - # Successfully parsing a response allows us to clear our parsing buffer - self._clear() - return response - async def _read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None]: - raw = await self._readline() - response: Any - byte, response = raw[:1], raw[1:] - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - self._clear() # Successful parse - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. - return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - return int(response) - # bulk response - elif byte == b"$" and response == b"-1": - return None - elif byte == b"$": - response = await self._read(int(response)) - # multi-bulk response - elif byte == b"*" and response == b"-1": - return None - elif byte == b"*": - response = [ - (await self._read_response(disable_decoding)) - for _ in range(int(response)) # noqa - ] - else: - raise InvalidResponse(f"Protocol Error: {raw!r}") - - if disable_decoding is False: - response = self.encoder.decode(response) - return response - - async def _read(self, length: int) -> bytes: - """ - Read `length` bytes of data. These are assumed to be followed - by a '\r\n' terminator which is subsequently discarded. - """ - want = length + 2 - end = self._pos + want - if len(self._buffer) >= end: - result = self._buffer[self._pos : end - 2] - else: - tail = self._buffer[self._pos :] - try: - data = await self._stream.readexactly(want - len(tail)) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += want - return result - - async def _readline(self) -> bytes: - """ - read an unknown number of bytes up to the next '\r\n' - line separator, which is discarded. - """ - found = self._buffer.find(b"\r\n", self._pos) - if found >= 0: - result = self._buffer[self._pos : found] - else: - tail = self._buffer[self._pos :] - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += len(result) + 2 - return result - - -class HiredisParser(BaseParser): - """Parser class for connections using Hiredis""" - - __slots__ = ("_reader",) - - def __init__(self, socket_read_size: int): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not available.") - super().__init__(socket_read_size=socket_read_size) - self._reader: Optional[hiredis.Reader] = None - - def on_connect(self, connection: "Connection"): - self._stream = connection._reader - kwargs: _HiredisReaderArgs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - } - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - kwargs["errors"] = connection.encoder.encoding_errors - - self._reader = hiredis.Reader(**kwargs) - self._connected = True - - def on_disconnect(self): - self._connected = False - async def can_read_destructive(self): - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._reader.gets(): - return True - try: - async with async_timeout(0): - return await self.read_from_socket() - except asyncio.TimeoutError: - return False - - async def read_from_socket(self): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, List[EncodableT]]: - # If `on_disconnect()` has been called, prohibit any more reads - # even if they could happen because data might be present. - # We still allow reads in progress to finish - if not self._connected: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - - response = self._reader.gets() - while response is False: - await self.read_from_socket() - response = self._reader.gets() - - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - - -DefaultParser: Type[Union[PythonParser, HiredisParser]] +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _AsyncHiredisParser else: - DefaultParser = PythonParser + DefaultParser = _AsyncRESP2Parser class ConnectCallbackProtocol(Protocol): @@ -470,6 +115,7 @@ class Connection: "last_active_at", "encoder", "ssl_context", + "protocol", "_reader", "_writer", "_parser", @@ -506,6 +152,7 @@ class Connection: redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): if (username or password) and credential_provider is not None: raise DataError( @@ -556,6 +203,7 @@ class Connection: self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 + self.protocol = protocol def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) @@ -710,6 +358,18 @@ class Connection: if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + if self.protocol != 2: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol) + response = await self.read_response() + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): + raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: await self.send_command("CLIENT", "SETNAME", self.client_name) |