diff options
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 140 |
1 files changed, 83 insertions, 57 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 91961ba..9de2d46 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -24,7 +24,6 @@ from typing import ( Type, TypeVar, Union, - cast, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse @@ -110,32 +109,32 @@ class Encoder: def encode(self, value: EncodableT) -> EncodedT: """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, str): + return value.encode(self.encoding, self.encoding_errors) if isinstance(value, (bytes, memoryview)): return value - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) if isinstance(value, (int, float)): + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." + ) return repr(value).encode() - if not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ # type: ignore[unreachable] - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." - ) - return value.encode(self.encoding, self.encoding_errors) + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) def decode(self, value: EncodableT, force=False) -> EncodableT: """Return a unicode string from the bytes-like representation""" if self.decode_responses or force: - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) if isinstance(value, bytes): return value.decode(self.encoding, self.encoding_errors) + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) return value @@ -336,7 +335,7 @@ class SocketBuffer: def close(self): try: self.purge() - self._buffer.close() # type: ignore[union-attr] + self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -466,7 +465,7 @@ class HiredisParser(BaseParser): self._next_response = False async def can_read(self, timeout: float): - if not self._reader: + if not self._stream or not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._next_response is False: @@ -480,14 +479,14 @@ class HiredisParser(BaseParser): timeout: Union[float, None, _Sentinel] = SENTINEL, raise_on_timeout: bool = True, ): - if self._stream is None or self._reader is None: - raise RedisError("Parser already closed.") - timeout = self._socket_timeout if timeout is SENTINEL else timeout try: - async with async_timeout.timeout(timeout): + if timeout is None: buffer = await self._stream.read(self._read_size) - if not isinstance(buffer, bytes) or len(buffer) == 0: + else: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None self._reader.feed(buffer) # data was read from the socket and added to the buffer. @@ -516,9 +515,6 @@ class HiredisParser(BaseParser): self.on_disconnect() raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - response: Union[ - EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]] - ] # _next_response might be cached from a can_read() call if self._next_response is not False: response = self._next_response @@ -541,8 +537,7 @@ class HiredisParser(BaseParser): and isinstance(response[0], ConnectionError) ): raise response[0] - # cast as there won't be a ConnectionError here. - return cast(Union[EncodableT, List[EncodableT]], response) + return response DefaultParser: Type[Union[PythonParser, HiredisParser]] @@ -637,7 +632,7 @@ class Connection: self.socket_type = socket_type self.retry_on_timeout = retry_on_timeout if retry_on_timeout: - if retry is None: + if not retry: self.retry = Retry(NoBackoff(), 1) else: # deep-copy the Retry object as it is mutable @@ -681,7 +676,7 @@ class Connection: @property def is_connected(self): - return bool(self._reader and self._writer) + return self._reader and self._writer def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -713,7 +708,7 @@ class Connection: raise ConnectionError(exc) from exc try: - if self.redis_connect_func is None: + if not self.redis_connect_func: # Use the default on_connect function await self.on_connect() else: @@ -745,7 +740,7 @@ class Connection: self._reader = reader self._writer = writer sock = writer.transport.get_extra_info("socket") - if sock is not None: + if sock: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: # TCP_KEEPALIVE @@ -856,32 +851,29 @@ class Connection: await self.retry.call_with_retry(self._send_ping, self._ping_failed) async def _send_packed_command(self, command: Iterable[bytes]) -> None: - if self._writer is None: - raise RedisError("Connection already closed.") - self._writer.writelines(command) await self._writer.drain() async def send_packed_command( - self, - command: Union[bytes, str, Iterable[bytes]], - check_health: bool = True, - ): - """Send an already packed command to the Redis server""" - if not self._writer: + self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True + ) -> None: + if not self.is_connected: await self.connect() - # guard against health check recursion - if check_health: + elif check_health: await self.check_health() + try: if isinstance(command, str): command = command.encode() if isinstance(command, bytes): command = [command] - await asyncio.wait_for( - self._send_packed_command(command), - self.socket_timeout, - ) + if self.socket_timeout: + await asyncio.wait_for( + self._send_packed_command(command), self.socket_timeout + ) + else: + self._writer.writelines(command) + await self._writer.drain() except asyncio.TimeoutError: await self.disconnect() raise TimeoutError("Timeout writing to socket") from None @@ -901,8 +893,6 @@ class Connection: async def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" - if not self.is_connected: - await self.connect() await self.send_packed_command( self.pack_command(*args), check_health=kwargs.get("check_health", True) ) @@ -923,10 +913,50 @@ class Connection: """Read the response from a previously sent command""" try: async with self._lock: + if self.socket_timeout: + async with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) + except asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except OSError as e: + await self.disconnect() + raise ConnectionError( + f"Error while reading from {self.host}:{self.port} : {e.args}" + ) + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + if sys.version_info[0:2] == (3, 6): + func = asyncio.get_event_loop + else: + func = asyncio.get_running_loop + self.next_health_check = func().time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + + async def read_response_without_lock(self, disable_decoding: bool = False): + """Read the response from a previously sent command""" + try: + if self.socket_timeout: async with async_timeout.timeout(self.socket_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding ) + else: + response = await self._parser.read_response( + disable_decoding=disable_decoding + ) except asyncio.TimeoutError: await self.disconnect() raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") @@ -1182,10 +1212,7 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] self._lock = asyncio.Lock() def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: - pieces = [ - ("path", self.path), - ("db", self.db), - ] + pieces = [("path", self.path), ("db", self.db)] if self.client_name: pieces.append(("client_name", self.client_name)) return pieces @@ -1254,12 +1281,11 @@ def parse_url(url: str) -> ConnectKwargs: parser = URL_QUERY_ARGUMENT_PARSERS.get(name) if parser: try: - # We can't type this. - kwargs[name] = parser(value) # type: ignore[misc] + kwargs[name] = parser(value) except (TypeError, ValueError): raise ValueError(f"Invalid value for `{name}` in connection URL.") else: - kwargs[name] = value # type: ignore[misc] + kwargs[name] = value if parsed.username: kwargs["username"] = unquote(parsed.username) |