diff options
author | Kristján Valur Jónsson <sweskman@gmail.com> | 2022-09-29 11:23:07 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-29 14:23:07 +0300 |
commit | 652ca790b0bbaefa78278ddf57074316a3fb21bd (patch) | |
tree | 6c8b6a942851b79ab1f184b27a5e9a537d81807a | |
parent | 9fe836698ca5930e39633b86839b6c1bae07237e (diff) | |
download | redis-py-652ca790b0bbaefa78278ddf57074316a3fb21bd.tar.gz |
Add `nowait` flag to `asyncio.Connection.disconnect()` (#2356)
* Don't wait for disconnect() when handling errors.
This can result in other errors such as timeouts.
* add CHANGES
* Update redis/asyncio/connection.py
Co-authored-by: Aarni Koskela <akx@iki.fi>
* await a task to try to diagnose unittest failures in CI
Co-authored-by: Aarni Koskela <akx@iki.fi>
-rw-r--r-- | CHANGES | 1 | ||||
-rw-r--r-- | redis/asyncio/connection.py | 21 | ||||
-rw-r--r-- | tests/test_asyncio/test_pubsub.py | 73 |
3 files changed, 55 insertions, 40 deletions
@@ -1,3 +1,4 @@ + * add `nowait` flag to `asyncio.Connection.disconnect()` * Update README.md links * Fix timezone handling for datetime to unixtime conversions * Fix start_id type for XAUTOCLAIM diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index a470b6f..64848f4 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -836,7 +836,7 @@ class Connection: if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Invalid Database") - async def disconnect(self) -> None: + async def disconnect(self, nowait: bool = False) -> None: """Disconnects from the Redis server""" try: async with async_timeout.timeout(self.socket_connect_timeout): @@ -846,8 +846,9 @@ class Connection: try: if os.getpid() == self.pid: self._writer.close() # type: ignore[union-attr] - # py3.6 doesn't have this method - if hasattr(self._writer, "wait_closed"): + # wait for close to finish, except when handling errors and + # forcefully disconnecting. + if not nowait: await self._writer.wait_closed() # type: ignore[union-attr] except OSError: pass @@ -902,10 +903,10 @@ class Connection: self._writer.writelines(command) await self._writer.drain() except asyncio.TimeoutError: - await self.disconnect() + await self.disconnect(nowait=True) raise TimeoutError("Timeout writing to socket") from None except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) if len(e.args) == 1: err_no, errmsg = "UNKNOWN", e.args[0] else: @@ -915,7 +916,7 @@ class Connection: f"Error {err_no} while writing to socket. {errmsg}." ) from e except Exception: - await self.disconnect() + await self.disconnect(nowait=True) raise async def send_command(self, *args: Any, **kwargs: Any) -> None: @@ -931,7 +932,7 @@ class Connection: try: return await self._parser.can_read(timeout) except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port}: {e.args}" ) @@ -949,15 +950,15 @@ class Connection: disable_decoding=disable_decoding ) except asyncio.TimeoutError: - await self.disconnect() + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect() + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) except Exception: - await self.disconnect() + await self.disconnect(nowait=True) raise if self.health_check_interval: diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 555cfdb..32268fe 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -819,7 +819,7 @@ class TestPubSubAutoReconnect: "type": "subscribe", } - async def mycleanup(self): + async def myfinish(self): message = await self.messages.get() assert message == { "channel": b"foo", @@ -827,6 +827,8 @@ class TestPubSubAutoReconnect: "pattern": None, "type": "subscribe", } + + async def mykill(self): # kill thread async with self.cond: self.state = 4 # quit @@ -836,41 +838,52 @@ class TestPubSubAutoReconnect: """ Test that a socket error will cause reconnect """ - async with async_timeout.timeout(self.timeout): - await self.mysetup(r, method) - # now, disconnect the connection, and wait for it to be re-established - async with self.cond: - assert self.state == 0 - self.state = 1 - with mock.patch.object(self.pubsub.connection, "_parser") as mockobj: - mockobj.read_response.side_effect = socket.error - mockobj.can_read.side_effect = socket.error - # wait until task noticies the disconnect until we undo the patch - await self.cond.wait_for(lambda: self.state >= 2) - assert not self.pubsub.connection.is_connected - # it is in a disconnecte state - # wait for reconnect - await self.cond.wait_for(lambda: self.pubsub.connection.is_connected) - assert self.state == 3 + try: + async with async_timeout.timeout(self.timeout): + await self.mysetup(r, method) + # now, disconnect the connection, and wait for it to be re-established + async with self.cond: + assert self.state == 0 + self.state = 1 + with mock.patch.object(self.pubsub.connection, "_parser") as m: + m.read_response.side_effect = socket.error + m.can_read.side_effect = socket.error + # wait until task noticies the disconnect until we + # undo the patch + await self.cond.wait_for(lambda: self.state >= 2) + assert not self.pubsub.connection.is_connected + # it is in a disconnecte state + # wait for reconnect + await self.cond.wait_for( + lambda: self.pubsub.connection.is_connected + ) + assert self.state == 3 - await self.mycleanup() + await self.myfinish() + finally: + await self.mykill() async def test_reconnect_disconnect(self, r: redis.Redis, method): """ Test that a manual disconnect() will cause reconnect """ - async with async_timeout.timeout(self.timeout): - await self.mysetup(r, method) - # now, disconnect the connection, and wait for it to be re-established - async with self.cond: - self.state = 1 - await self.pubsub.connection.disconnect() - assert not self.pubsub.connection.is_connected - # wait for reconnect - await self.cond.wait_for(lambda: self.pubsub.connection.is_connected) - assert self.state == 3 - - await self.mycleanup() + try: + async with async_timeout.timeout(self.timeout): + await self.mysetup(r, method) + # now, disconnect the connection, and wait for it to be re-established + async with self.cond: + self.state = 1 + await self.pubsub.connection.disconnect() + assert not self.pubsub.connection.is_connected + # wait for reconnect + await self.cond.wait_for( + lambda: self.pubsub.connection.is_connected + ) + assert self.state == 3 + + await self.myfinish() + finally: + await self.mykill() async def loop(self): # reader loop, performing state transitions as it |