diff options
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index a470b6f..64848f4 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -836,7 +836,7 @@ class Connection: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") - async def disconnect(self) -> None: + async def disconnect(self, nowait: bool = False) -> None: """Disconnects from the Redis server""" try: async with async_timeout.timeout(self.socket_connect_timeout): @@ -846,8 +846,9 @@ class Connection: try: if os.getpid() == self.pid: self._writer.close() # type: ignore[union-attr] - # py3.6 doesn't have this method - if hasattr(self._writer, "wait_closed"): + # wait for close to finish, except when handling errors and + # forcefully disconnecting. + if not nowait: await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass @@ -902,10 +903,10 @@ class Connection: self._writer.writelines(command) await self._writer.drain() except asyncio.TimeoutError: - await self.disconnect() + await self.disconnect(nowait=True) raise TimeoutError("Timeout writing to socket") from None except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) if len(e.args) == 1: err_no, errmsg = "UNKNOWN", e.args[0] else: @@ -915,7 +916,7 @@ class Connection: f"Error {err_no} while writing to socket. {errmsg}." ) from e except Exception: - await self.disconnect() + await self.disconnect(nowait=True) raise async def send_command(self, *args: Any, **kwargs: Any) -> None: @@ -931,7 +932,7 @@ class Connection: try: return await self._parser.can_read(timeout) except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port}: {e.args}" ) @@ -949,15 +950,15 @@ class Connection: disable_decoding=disable_decoding ) except asyncio.TimeoutError: - await self.disconnect() + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) except Exception: - await self.disconnect() + await self.disconnect(nowait=True) raise if self.health_check_interval: |