summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristján Valur Jónsson <sweskman@gmail.com>2022-09-29 11:56:50 +0000
committerGitHub <noreply@github.com>2022-09-29 14:56:50 +0300
commitf014dc3a5d76935914c6e2975e66da44f2e6263b (patch)
tree465591eb372ccea9d210710d2f0b1760f9ef288d
parent652ca790b0bbaefa78278ddf57074316a3fb21bd (diff)
downloadredis-py-f014dc3a5d76935914c6e2975e66da44f2e6263b.tar.gz
Dev/no can read (#2360)
* make can_read() destructive for simplicity, and rename the method. Remove timeout argument, always timeout immediately. * don't use can_read in pubsub * connection.connect() now has its own retry, don't need it inside a retry loop
-rw-r--r--redis/asyncio/client.py22
-rw-r--r--redis/asyncio/connection.py51
-rw-r--r--tests/test_asyncio/test_cluster.py8
-rw-r--r--tests/test_asyncio/test_connection_pool.py2
-rw-r--r--tests/test_asyncio/test_pubsub.py2
5 files changed, 41 insertions, 44 deletions
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 9c8caae..c13054b 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -24,6 +24,8 @@ from typing import (
cast,
)
+import async_timeout
+
from redis.asyncio.connection import (
Connection,
ConnectionPool,
@@ -754,15 +756,21 @@ class PubSub:
await self.check_health()
- async def try_read():
- if not block:
- if not await conn.can_read(timeout=timeout):
+ if not conn.is_connected:
+ await conn.connect()
+
+ if not block:
+
+ async def read_with_timeout():
+ try:
+ async with async_timeout.timeout(timeout):
+ return await conn.read_response()
+ except asyncio.TimeoutError:
return None
- else:
- await conn.connect()
- return await conn.read_response()
- response = await self._execute(conn, try_read)
+ response = await self._execute(conn, read_with_timeout)
+ else:
+ response = await self._execute(conn, conn.read_response)
if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 64848f4..53b41af 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -208,7 +208,7 @@ class BaseParser:
def on_connect(self, connection: "Connection"):
raise NotImplementedError()
- async def can_read(self, timeout: float) -> bool:
+ async def can_read_destructive(self) -> bool:
raise NotImplementedError()
async def read_response(
@@ -286,9 +286,9 @@ class SocketBuffer:
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
- async def can_read(self, timeout: float) -> bool:
+ async def can_read_destructive(self) -> bool:
return bool(self.length) or await self._read_from_socket(
- timeout=timeout, raise_on_timeout=False
+ timeout=0, raise_on_timeout=False
)
async def read(self, length: int) -> bytes:
@@ -386,8 +386,8 @@ class PythonParser(BaseParser):
self._buffer = None
self.encoder = None
- async def can_read(self, timeout: float):
- return self._buffer and bool(await self._buffer.can_read(timeout))
+ async def can_read_destructive(self):
+ return self._buffer and bool(await self._buffer.can_read_destructive())
async def read_response(
self, disable_decoding: bool = False
@@ -444,9 +444,7 @@ class PythonParser(BaseParser):
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""
- __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")
-
- _next_response: bool
+ __slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")
def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
@@ -466,23 +464,18 @@ class HiredisParser(BaseParser):
kwargs["errors"] = connection.encoder.encoding_errors
self._reader = hiredis.Reader(**kwargs)
- self._next_response = False
self._socket_timeout = connection.socket_timeout
def on_disconnect(self):
self._stream = None
self._reader = None
- self._next_response = False
- async def can_read(self, timeout: float):
+ async def can_read_destructive(self):
if not self._stream or not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
-
- if self._next_response is False:
- self._next_response = self._reader.gets()
- if self._next_response is False:
- return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
- return True
+ if self._reader.gets():
+ return True
+ return await self.read_from_socket(timeout=0, raise_on_timeout=False)
async def read_from_socket(
self,
@@ -523,12 +516,6 @@ class HiredisParser(BaseParser):
self.on_disconnect()
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
- # _next_response might be cached from a can_read() call
- if self._next_response is not False:
- response = self._next_response
- self._next_response = False
- return response
-
response = self._reader.gets()
while response is False:
await self.read_from_socket()
@@ -925,12 +912,10 @@ class Connection:
self.pack_command(*args), check_health=kwargs.get("check_health", True)
)
- async def can_read(self, timeout: float = 0):
+ async def can_read_destructive(self):
"""Poll the socket to see if there's data that can be read."""
- if not self.is_connected:
- await self.connect()
try:
- return await self._parser.can_read(timeout)
+ return await self._parser.can_read_destructive()
except OSError as e:
await self.disconnect(nowait=True)
raise ConnectionError(
@@ -957,6 +942,10 @@ class Connection:
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
+ except asyncio.CancelledError:
+ # need this check for 3.7, where CancelledError
+ # is subclass of Exception, not BaseException
+ raise
except Exception:
await self.disconnect(nowait=True)
raise
@@ -1498,12 +1487,12 @@ class ConnectionPool:
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
- if await connection.can_read():
+ if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except ConnectionError:
await connection.disconnect()
await connection.connect()
- if await connection.can_read():
+ if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
except BaseException:
# release the connection back to the pool so that we don't
@@ -1699,12 +1688,12 @@ class BlockingConnectionPool(ConnectionPool):
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
- if await connection.can_read():
+ if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except ConnectionError:
await connection.disconnect()
await connection.connect()
- if await connection.can_read():
+ if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
except BaseException:
# release the connection back to the pool so that we don't leak it
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 88cfb1f..f1bbe42 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -433,7 +433,7 @@ class TestRedisClusterObj:
Connection,
send_packed_command=mock.DEFAULT,
connect=mock.DEFAULT,
- can_read=mock.DEFAULT,
+ can_read_destructive=mock.DEFAULT,
) as mocks:
# simulate 7006 as a failed node
def execute_command_mock(self, *args, **options):
@@ -473,7 +473,7 @@ class TestRedisClusterObj:
execute_command.successful_calls = 0
execute_command.failed_calls = 0
initialize.side_effect = initialize_mock
- mocks["can_read"].return_value = False
+ mocks["can_read_destructive"].return_value = False
mocks["send_packed_command"].return_value = "MOCK_OK"
mocks["connect"].return_value = None
with mock.patch.object(
@@ -514,7 +514,7 @@ class TestRedisClusterObj:
send_command=mock.DEFAULT,
read_response=mock.DEFAULT,
_connect=mock.DEFAULT,
- can_read=mock.DEFAULT,
+ can_read_destructive=mock.DEFAULT,
on_connect=mock.DEFAULT,
) as mocks:
with mock.patch.object(
@@ -546,7 +546,7 @@ class TestRedisClusterObj:
mocks["send_command"].return_value = True
mocks["read_response"].return_value = "OK"
mocks["_connect"].return_value = True
- mocks["can_read"].return_value = False
+ mocks["can_read_destructive"].return_value = False
mocks["on_connect"].return_value = True
# Create a cluster with reading from replications
diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py
index d281608..35f23f4 100644
--- a/tests/test_asyncio/test_connection_pool.py
+++ b/tests/test_asyncio/test_connection_pool.py
@@ -103,7 +103,7 @@ class DummyConnection(Connection):
async def disconnect(self):
pass
- async def can_read(self, timeout: float = 0):
+ async def can_read_destructive(self, timeout: float = 0):
return False
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 32268fe..86584e4 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -847,7 +847,7 @@ class TestPubSubAutoReconnect:
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as m:
m.read_response.side_effect = socket.error
- m.can_read.side_effect = socket.error
+ m.can_read_destructive.side_effect = socket.error
# wait until task noticies the disconnect until we
# undo the patch
await self.cond.wait_for(lambda: self.state >= 2)