diff options
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2c75d4f..862f6f0 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -350,13 +350,14 @@ class PythonParser(BaseParser): class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_reader",) + __slots__ = BaseParser.__slots__ + ("_reader", "_connected") def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None + self._connected: bool = False def on_connect(self, connection: "Connection"): self._stream = connection._reader @@ -369,13 +370,13 @@ class HiredisParser(BaseParser): kwargs["errors"] = connection.encoder.encoding_errors self._reader = hiredis.Reader(**kwargs) + self._connected = True def on_disconnect(self): - self._stream = None - self._reader = None + self._connected = False async def can_read_destructive(self): - if not self._stream or not self._reader: + if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._reader.gets(): return True @@ -397,8 +398,10 @@ class HiredisParser(BaseParser): async def read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, List[EncodableT]]: - if not self._stream or not self._reader: - self.on_disconnect() + # If `on_disconnect()` has been called, prohibit any more reads + # even if they could happen because data might be present. + # We still allow reads in progress to finish + if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None response = self._reader.gets() |