diff options
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r-- | redis/asyncio/connection.py | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 93db37e..057067a 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -146,7 +146,7 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Except class BaseParser: """Plain Python parsing class""" - __slots__ = "_stream", "_read_size" + __slots__ = "_stream", "_read_size", "_connected" EXCEPTION_CLASSES: ExceptionMappingT = { "ERR": { @@ -177,6 +177,7 @@ class BaseParser: def __init__(self, socket_read_size: int): self._stream: Optional[asyncio.StreamReader] = None self._read_size = socket_read_size + self._connected = False def __del__(self): try: @@ -213,7 +214,7 @@ class BaseParser: class PythonParser(BaseParser): """Plain Python parsing class""" - __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + __slots__ = ("encoder", "_buffer", "_pos", "_chunks") def __init__(self, socket_read_size: int): super().__init__(socket_read_size) @@ -231,21 +232,19 @@ class PythonParser(BaseParser): self._stream = connection._reader if self._stream is None: raise RedisError("Buffer is closed.") - self.encoder = connection.encoder + self._clear() + self._connected = True def on_disconnect(self): """Called when the stream disconnects""" - if self._stream is not None: - self._stream = None - self.encoder = None - self._clear() + self._connected = False async def can_read_destructive(self) -> bool: + if not self._connected: + raise RedisError("Buffer is closed.") if self._buffer: return True - if self._stream is None: - raise RedisError("Buffer is closed.") try: async with async_timeout(0): return await self._stream.read(1) @@ -253,6 +252,8 @@ class PythonParser(BaseParser): return False async def read_response(self, disable_decoding: bool = False): + if not self._connected: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) if self._chunks: # augment parsing buffer with previously read data self._buffer += b"".join(self._chunks) @@ -266,8 +267,6 @@ class PythonParser(BaseParser): async def _read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: - if not self._stream or not self.encoder: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) raw = await self._readline() response: Any byte, response = raw[:1], raw[1:] @@ -354,14 +353,13 @@ class PythonParser(BaseParser): class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" - __slots__ = BaseParser.__slots__ + ("_reader", "_connected") + __slots__ = ("_reader",) def __init__(self, socket_read_size: int): if not HIREDIS_AVAILABLE: raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None - self._connected: bool = False def on_connect(self, connection: "Connection"): self._stream = connection._reader |