summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristján Valur Jónsson <sweskman@gmail.com>2022-09-29 11:23:07 +0000
committerGitHub <noreply@github.com>2022-09-29 14:23:07 +0300
commit652ca790b0bbaefa78278ddf57074316a3fb21bd (patch)
tree6c8b6a942851b79ab1f184b27a5e9a537d81807a
parent9fe836698ca5930e39633b86839b6c1bae07237e (diff)
downloadredis-py-652ca790b0bbaefa78278ddf57074316a3fb21bd.tar.gz
Add `nowait` flag to `asyncio.Connection.disconnect()` (#2356)
* Don't wait for disconnect() when handling errors. This can result in other errors such as timeouts. * add CHANGES * Update redis/asyncio/connection.py Co-authored-by: Aarni Koskela <akx@iki.fi> * await a task to try to diagnose unittest failures in CI Co-authored-by: Aarni Koskela <akx@iki.fi>
-rw-r--r--CHANGES1
-rw-r--r--redis/asyncio/connection.py21
-rw-r--r--tests/test_asyncio/test_pubsub.py73
3 files changed, 55 insertions, 40 deletions
diff --git a/CHANGES b/CHANGES
index 75fa83c..a5b5029 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,4 @@
+ * add `nowait` flag to `asyncio.Connection.disconnect()`
* Update README.md links
* Fix timezone handling for datetime to unixtime conversions
* Fix start_id type for XAUTOCLAIM
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index a470b6f..64848f4 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -836,7 +836,7 @@ class Connection:
if str_if_bytes(await self.read_response()) != "OK":
raise ConnectionError("Invalid Database")
- async def disconnect(self) -> None:
+ async def disconnect(self, nowait: bool = False) -> None:
"""Disconnects from the Redis server"""
try:
async with async_timeout.timeout(self.socket_connect_timeout):
@@ -846,8 +846,9 @@ class Connection:
try:
if os.getpid() == self.pid:
self._writer.close() # type: ignore[union-attr]
- # py3.6 doesn't have this method
- if hasattr(self._writer, "wait_closed"):
+ # wait for close to finish, except when handling errors and
+ # forcefully disconnecting.
+ if not nowait:
await self._writer.wait_closed() # type: ignore[union-attr]
except OSError:
pass
@@ -902,10 +903,10 @@ class Connection:
self._writer.writelines(command)
await self._writer.drain()
except asyncio.TimeoutError:
- await self.disconnect()
+ await self.disconnect(nowait=True)
raise TimeoutError("Timeout writing to socket") from None
except OSError as e:
- await self.disconnect()
+ await self.disconnect(nowait=True)
if len(e.args) == 1:
err_no, errmsg = "UNKNOWN", e.args[0]
else:
@@ -915,7 +916,7 @@ class Connection:
f"Error {err_no} while writing to socket. {errmsg}."
) from e
except Exception:
- await self.disconnect()
+ await self.disconnect(nowait=True)
raise
async def send_command(self, *args: Any, **kwargs: Any) -> None:
@@ -931,7 +932,7 @@ class Connection:
try:
return await self._parser.can_read(timeout)
except OSError as e:
- await self.disconnect()
+ await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port}: {e.args}"
)
@@ -949,15 +950,15 @@ class Connection:
disable_decoding=disable_decoding
)
except asyncio.TimeoutError:
- await self.disconnect()
+ await self.disconnect(nowait=True)
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
except OSError as e:
- await self.disconnect()
+ await self.disconnect(nowait=True)
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except Exception:
- await self.disconnect()
+ await self.disconnect(nowait=True)
raise
if self.health_check_interval:
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 555cfdb..32268fe 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -819,7 +819,7 @@ class TestPubSubAutoReconnect:
"type": "subscribe",
}
- async def mycleanup(self):
+ async def myfinish(self):
message = await self.messages.get()
assert message == {
"channel": b"foo",
@@ -827,6 +827,8 @@ class TestPubSubAutoReconnect:
"pattern": None,
"type": "subscribe",
}
+
+ async def mykill(self):
# kill thread
async with self.cond:
self.state = 4 # quit
@@ -836,41 +838,52 @@ class TestPubSubAutoReconnect:
"""
Test that a socket error will cause reconnect
"""
- async with async_timeout.timeout(self.timeout):
- await self.mysetup(r, method)
- # now, disconnect the connection, and wait for it to be re-established
- async with self.cond:
- assert self.state == 0
- self.state = 1
- with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
- mockobj.read_response.side_effect = socket.error
- mockobj.can_read.side_effect = socket.error
- # wait until task noticies the disconnect until we undo the patch
- await self.cond.wait_for(lambda: self.state >= 2)
- assert not self.pubsub.connection.is_connected
- # it is in a disconnecte state
- # wait for reconnect
- await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
- assert self.state == 3
+ try:
+ async with async_timeout.timeout(self.timeout):
+ await self.mysetup(r, method)
+ # now, disconnect the connection, and wait for it to be re-established
+ async with self.cond:
+ assert self.state == 0
+ self.state = 1
+ with mock.patch.object(self.pubsub.connection, "_parser") as m:
+ m.read_response.side_effect = socket.error
+ m.can_read.side_effect = socket.error
+ # wait until task noticies the disconnect until we
+ # undo the patch
+ await self.cond.wait_for(lambda: self.state >= 2)
+ assert not self.pubsub.connection.is_connected
+ # it is in a disconnecte state
+ # wait for reconnect
+ await self.cond.wait_for(
+ lambda: self.pubsub.connection.is_connected
+ )
+ assert self.state == 3
- await self.mycleanup()
+ await self.myfinish()
+ finally:
+ await self.mykill()
async def test_reconnect_disconnect(self, r: redis.Redis, method):
"""
Test that a manual disconnect() will cause reconnect
"""
- async with async_timeout.timeout(self.timeout):
- await self.mysetup(r, method)
- # now, disconnect the connection, and wait for it to be re-established
- async with self.cond:
- self.state = 1
- await self.pubsub.connection.disconnect()
- assert not self.pubsub.connection.is_connected
- # wait for reconnect
- await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
- assert self.state == 3
-
- await self.mycleanup()
+ try:
+ async with async_timeout.timeout(self.timeout):
+ await self.mysetup(r, method)
+ # now, disconnect the connection, and wait for it to be re-established
+ async with self.cond:
+ self.state = 1
+ await self.pubsub.connection.disconnect()
+ assert not self.pubsub.connection.is_connected
+ # wait for reconnect
+ await self.cond.wait_for(
+ lambda: self.pubsub.connection.is_connected
+ )
+ assert self.state == 3
+
+ await self.myfinish()
+ finally:
+ await self.mykill()
async def loop(self):
# reader loop, performing state transitions as it