summaryrefslogtreecommitdiff
path: root/redis/asyncio/connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/asyncio/connection.py')
-rw-r--r--redis/asyncio/connection.py140
1 files changed, 83 insertions, 57 deletions
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 91961ba..9de2d46 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -24,7 +24,6 @@ from typing import (
Type,
TypeVar,
Union,
- cast,
)
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
@@ -110,32 +109,32 @@ class Encoder:
def encode(self, value: EncodableT) -> EncodedT:
"""Return a bytestring or bytes-like representation of the value"""
+ if isinstance(value, str):
+ return value.encode(self.encoding, self.encoding_errors)
if isinstance(value, (bytes, memoryview)):
return value
- if isinstance(value, bool):
- # special case bool since it is a subclass of int
- raise DataError(
- "Invalid input of type: 'bool'. "
- "Convert to a bytes, string, int or float first."
- )
if isinstance(value, (int, float)):
+ if isinstance(value, bool):
+ # special case bool since it is a subclass of int
+ raise DataError(
+ "Invalid input of type: 'bool'. "
+ "Convert to a bytes, string, int or float first."
+ )
return repr(value).encode()
- if not isinstance(value, str):
- # a value we don't know how to deal with. throw an error
- typename = value.__class__.__name__ # type: ignore[unreachable]
- raise DataError(
- f"Invalid input of type: {typename!r}. "
- "Convert to a bytes, string, int or float first."
- )
- return value.encode(self.encoding, self.encoding_errors)
+ # a value we don't know how to deal with. throw an error
+ typename = value.__class__.__name__
+ raise DataError(
+ f"Invalid input of type: {typename!r}. "
+ "Convert to a bytes, string, int or float first."
+ )
def decode(self, value: EncodableT, force=False) -> EncodableT:
"""Return a unicode string from the bytes-like representation"""
if self.decode_responses or force:
- if isinstance(value, memoryview):
- return value.tobytes().decode(self.encoding, self.encoding_errors)
if isinstance(value, bytes):
return value.decode(self.encoding, self.encoding_errors)
+ if isinstance(value, memoryview):
+ return value.tobytes().decode(self.encoding, self.encoding_errors)
return value
@@ -336,7 +335,7 @@ class SocketBuffer:
def close(self):
try:
self.purge()
- self._buffer.close() # type: ignore[union-attr]
+ self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
# BadFileDescriptor error. Perhaps the client ran out of
@@ -466,7 +465,7 @@ class HiredisParser(BaseParser):
self._next_response = False
async def can_read(self, timeout: float):
- if not self._reader:
+ if not self._stream or not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
if self._next_response is False:
@@ -480,14 +479,14 @@ class HiredisParser(BaseParser):
timeout: Union[float, None, _Sentinel] = SENTINEL,
raise_on_timeout: bool = True,
):
- if self._stream is None or self._reader is None:
- raise RedisError("Parser already closed.")
-
timeout = self._socket_timeout if timeout is SENTINEL else timeout
try:
- async with async_timeout.timeout(timeout):
+ if timeout is None:
buffer = await self._stream.read(self._read_size)
- if not isinstance(buffer, bytes) or len(buffer) == 0:
+ else:
+ async with async_timeout.timeout(timeout):
+ buffer = await self._stream.read(self._read_size)
+ if not buffer or not isinstance(buffer, bytes):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
self._reader.feed(buffer)
# data was read from the socket and added to the buffer.
@@ -516,9 +515,6 @@ class HiredisParser(BaseParser):
self.on_disconnect()
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None
- response: Union[
- EncodableT, ConnectionError, List[Union[EncodableT, ConnectionError]]
- ]
# _next_response might be cached from a can_read() call
if self._next_response is not False:
response = self._next_response
@@ -541,8 +537,7 @@ class HiredisParser(BaseParser):
and isinstance(response[0], ConnectionError)
):
raise response[0]
- # cast as there won't be a ConnectionError here.
- return cast(Union[EncodableT, List[EncodableT]], response)
+ return response
DefaultParser: Type[Union[PythonParser, HiredisParser]]
@@ -637,7 +632,7 @@ class Connection:
self.socket_type = socket_type
self.retry_on_timeout = retry_on_timeout
if retry_on_timeout:
- if retry is None:
+ if not retry:
self.retry = Retry(NoBackoff(), 1)
else:
# deep-copy the Retry object as it is mutable
@@ -681,7 +676,7 @@ class Connection:
@property
def is_connected(self):
- return bool(self._reader and self._writer)
+ return self._reader and self._writer
def register_connect_callback(self, callback):
self._connect_callbacks.append(weakref.WeakMethod(callback))
@@ -713,7 +708,7 @@ class Connection:
raise ConnectionError(exc) from exc
try:
- if self.redis_connect_func is None:
+ if not self.redis_connect_func:
# Use the default on_connect function
await self.on_connect()
else:
@@ -745,7 +740,7 @@ class Connection:
self._reader = reader
self._writer = writer
sock = writer.transport.get_extra_info("socket")
- if sock is not None:
+ if sock:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
try:
# TCP_KEEPALIVE
@@ -856,32 +851,29 @@ class Connection:
await self.retry.call_with_retry(self._send_ping, self._ping_failed)
async def _send_packed_command(self, command: Iterable[bytes]) -> None:
- if self._writer is None:
- raise RedisError("Connection already closed.")
-
self._writer.writelines(command)
await self._writer.drain()
async def send_packed_command(
- self,
- command: Union[bytes, str, Iterable[bytes]],
- check_health: bool = True,
- ):
- """Send an already packed command to the Redis server"""
- if not self._writer:
+ self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
+ ) -> None:
+ if not self.is_connected:
await self.connect()
- # guard against health check recursion
- if check_health:
+ elif check_health:
await self.check_health()
+
try:
if isinstance(command, str):
command = command.encode()
if isinstance(command, bytes):
command = [command]
- await asyncio.wait_for(
- self._send_packed_command(command),
- self.socket_timeout,
- )
+ if self.socket_timeout:
+ await asyncio.wait_for(
+ self._send_packed_command(command), self.socket_timeout
+ )
+ else:
+ self._writer.writelines(command)
+ await self._writer.drain()
except asyncio.TimeoutError:
await self.disconnect()
raise TimeoutError("Timeout writing to socket") from None
@@ -901,8 +893,6 @@ class Connection:
async def send_command(self, *args, **kwargs):
"""Pack and send a command to the Redis server"""
- if not self.is_connected:
- await self.connect()
await self.send_packed_command(
self.pack_command(*args), check_health=kwargs.get("check_health", True)
)
@@ -923,10 +913,50 @@ class Connection:
"""Read the response from a previously sent command"""
try:
async with self._lock:
+ if self.socket_timeout:
+ async with async_timeout.timeout(self.socket_timeout):
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding
+ )
+ else:
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding
+ )
+ except asyncio.TimeoutError:
+ await self.disconnect()
+ raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
+ except OSError as e:
+ await self.disconnect()
+ raise ConnectionError(
+ f"Error while reading from {self.host}:{self.port} : {e.args}"
+ )
+ except BaseException:
+ await self.disconnect()
+ raise
+
+ if self.health_check_interval:
+ if sys.version_info[0:2] == (3, 6):
+ func = asyncio.get_event_loop
+ else:
+ func = asyncio.get_running_loop
+ self.next_health_check = func().time() + self.health_check_interval
+
+ if isinstance(response, ResponseError):
+ raise response from None
+ return response
+
+ async def read_response_without_lock(self, disable_decoding: bool = False):
+ """Read the response from a previously sent command"""
+ try:
+ if self.socket_timeout:
async with async_timeout.timeout(self.socket_timeout):
response = await self._parser.read_response(
disable_decoding=disable_decoding
)
+ else:
+ response = await self._parser.read_response(
+ disable_decoding=disable_decoding
+ )
except asyncio.TimeoutError:
await self.disconnect()
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
@@ -1182,10 +1212,7 @@ class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init]
self._lock = asyncio.Lock()
def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]:
- pieces = [
- ("path", self.path),
- ("db", self.db),
- ]
+ pieces = [("path", self.path), ("db", self.db)]
if self.client_name:
pieces.append(("client_name", self.client_name))
return pieces
@@ -1254,12 +1281,11 @@ def parse_url(url: str) -> ConnectKwargs:
parser = URL_QUERY_ARGUMENT_PARSERS.get(name)
if parser:
try:
- # We can't type this.
- kwargs[name] = parser(value) # type: ignore[misc]
+ kwargs[name] = parser(value)
except (TypeError, ValueError):
raise ValueError(f"Invalid value for `{name}` in connection URL.")
else:
- kwargs[name] = value # type: ignore[misc]
+ kwargs[name] = value
if parsed.username:
kwargs["username"] = unquote(parsed.username)