From aba7f3ffd0fe76f0a6068c05133b5bb2fef94909 Mon Sep 17 00:00:00 2001 From: Joe Gordon Date: Wed, 17 Aug 2022 10:01:18 -0700 Subject: Start to add type hints First pass at adding some type hints to pymemcache to make it easier to develop against etc. --- pymemcache/client/base.py | 178 +++++++++++++++++++++++++++++++--------------- pymemcache/client/hash.py | 2 +- pymemcache/serde.py | 10 +-- pyproject.toml | 1 + 4 files changed, 126 insertions(+), 65 deletions(-) diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 3d4234c..ef6bcac 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -16,20 +16,18 @@ import errno from functools import partial import platform import socket -from typing import Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable, Iterable from pymemcache import pool - -from pymemcache.serde import LegacyWrappingSerde from pymemcache.exceptions import ( MemcacheClientError, - MemcacheUnknownCommandError, MemcacheIllegalInputError, MemcacheServerError, - MemcacheUnknownError, MemcacheUnexpectedCloseError, + MemcacheUnknownCommandError, + MemcacheUnknownError, ) - +from pymemcache.serde import LegacyWrappingSerde RECV_SIZE = 4096 VALID_STORE_RESULTS = { @@ -53,23 +51,24 @@ STORE_RESULTS_VALUE = { } ServerSpec = Union[Tuple[str, int], str] +Key = Union[bytes, str] # Some of the values returned by the "stats" command # need mapping into native Python types -def _parse_bool_int(value): +def _parse_bool_int(value: bytes) -> bool: return int(value) != 0 -def _parse_bool_string_is_yes(value): +def _parse_bool_string_is_yes(value: bytes) -> bool: return value == b"yes" -def _parse_float(value): +def _parse_float(value: bytes) -> float: return float(value.replace(b":", b".")) -def _parse_hex(value): +def _parse_hex(value: bytes) -> int: return int(value, 8) @@ -96,7 +95,9 @@ STAT_TYPES = { # Common helper functions. -def check_key_helper(key, allow_unicode_keys, key_prefix=b""): +def check_key_helper( + key: Key, allow_unicode_keys: bool, key_prefix: bytes = b"" +) -> bytes: """Checks key and add key_prefix.""" if allow_unicode_keys: if isinstance(key, str): @@ -160,7 +161,7 @@ class KeepaliveOpts: __slots__ = ("idle", "intvl", "cnt") - def __init__(self, idle=1, intvl=1, cnt=5): + def __init__(self, idle: int = 1, intvl: int = 1, cnt: int = 5) -> None: if idle < 1: raise ValueError("The idle parameter must be greater or equal to 1.") self.idle = idle @@ -275,15 +276,15 @@ class Client: serializer=None, deserializer=None, connect_timeout=None, - timeout=None, - no_delay=False, - ignore_exc=False, + timeout: Optional[float] = None, + no_delay: bool = False, + ignore_exc: bool = False, socket_module=socket, - socket_keepalive=None, - key_prefix=b"", + socket_keepalive: Optional[KeepaliveOpts] = None, + key_prefix: bytes = b"", default_noreply=True, - allow_unicode_keys=False, - encoding="ascii", + allow_unicode_keys: bool = False, + encoding: str = "ascii", tls_context=None, ): """ @@ -354,7 +355,7 @@ class Client: "KeepaliveOpts object. That's the only supported type " "of structure." ) - self.sock = None + self.sock: Optional[socket.socket] = None if isinstance(key_prefix, str): key_prefix = key_prefix.encode("ascii") if not isinstance(key_prefix, bytes): @@ -365,13 +366,13 @@ class Client: self.encoding = encoding self.tls_context = tls_context - def check_key(self, key): + def check_key(self, key: Key) -> bytes: """Checks key and add key_prefix.""" return check_key_helper( key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix ) - def _connect(self): + def _connect(self) -> None: self.close() s = self.socket_module @@ -426,7 +427,7 @@ class Client: self.sock = sock - def close(self): + def close(self) -> None: """Close the connection to memcached, if it is open. The next call to a method that requires a connection will re-open it.""" if self.sock is not None: @@ -439,7 +440,14 @@ class Client: disconnect_all = close - def set(self, key, value, expire=0, noreply=None, flags=None): + def set( + self, + key: Key, + value: Any, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> Optional[bool]: """ The memcached "set" command. @@ -460,9 +468,17 @@ class Client: """ if noreply is None: noreply = self.default_noreply + # Optional because _store_cmd lookup in STORE_RESULTS_VALUE can return None in some cases. + # TODO: refactor to fix return self._store_cmd(b"set", {key: value}, expire, noreply, flags=flags)[key] - def set_many(self, values, expire=0, noreply=None, flags=None): + def set_many( + self, + values: Dict[Key, Any], + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ) -> List[Key]: """ A convenience function for setting multiple values. @@ -487,7 +503,14 @@ class Client: set_multi = set_many - def add(self, key, value, expire=0, noreply=None, flags=None): + def add( + self, + key: Key, + value: Any, + expire: int = 0, + noreply: Optional[bool] = None, + flags: Optional[int] = None, + ): """ The memcached "add" command. @@ -608,7 +631,7 @@ class Client: b"cas", {key: value}, expire, noreply, flags=flags, cas=cas )[key] - def get(self, key, default=None): + def get(self, key: Key, default: Optional[Any] = None): """ The memcached "get" command, but only for one key, as a convenience. @@ -673,7 +696,7 @@ class Client: return self._fetch_cmd(b"gets", keys, True) - def delete(self, key, noreply=None): + def delete(self, key: Key, noreply=None): """ The memcached "delete" command. @@ -698,7 +721,7 @@ class Client: return True return results[0] == b"DELETED" - def delete_many(self, keys, noreply=None): + def delete_many(self, keys: Iterable[Key], noreply: Optional[bool] = None) -> bool: """ A convenience function to delete multiple keys. @@ -732,7 +755,9 @@ class Client: delete_multi = delete_many - def incr(self, key, value, noreply=False): + def incr( + self, key: Key, value: int, noreply: Optional[bool] = False + ) -> Optional[int]: """ The memcached "incr" command. @@ -746,8 +771,8 @@ class Client: value of the key, or None if the key wasn't found. """ key = self.check_key(key) - value = self._check_integer(value, "value") - cmd = b"incr " + key + b" " + value + val = self._check_integer(value, "value") + cmd = b"incr " + key + b" " + val if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -758,7 +783,9 @@ class Client: return None return int(results[0]) - def decr(self, key, value, noreply=False): + def decr( + self, key: Key, value: int, noreply: Optional[bool] = False + ) -> Optional[int]: """ The memcached "decr" command. @@ -772,8 +799,8 @@ class Client: value of the key, or None if the key wasn't found. """ key = self.check_key(key) - value = self._check_integer(value, "value") - cmd = b"decr " + key + b" " + value + val = self._check_integer(value, "value") + cmd = b"decr " + key + b" " + val if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -784,7 +811,7 @@ class Client: return None return int(results[0]) - def touch(self, key, expire=0, noreply=None): + def touch(self, key: Key, expire: int = 0, noreply: Optional[bool] = None) -> bool: """ The memcached "touch" command. @@ -802,8 +829,8 @@ class Client: if noreply is None: noreply = self.default_noreply key = self.check_key(key) - expire = self._check_integer(expire, "expire") - cmd = b"touch " + key + b" " + expire + expire_bytes = self._check_integer(expire, "expire") + cmd = b"touch " + key + b" " + expire_bytes if noreply: cmd += b" noreply" cmd += b"\r\n" @@ -914,7 +941,7 @@ class Client: return True return results[0] == b"OK" - def quit(self): + def quit(self) -> None: """ The memcached "quit" command. @@ -964,7 +991,7 @@ class Client: error = line[line.find(b" ") + 1 :] raise MemcacheServerError(error) - def _check_integer(self, value, name): + def _check_integer(self, value: int, name: str) -> bytes: """Check that a value is an integer and encode it as a binary string""" if not isinstance(value, int): raise MemcacheIllegalInputError( @@ -973,7 +1000,7 @@ class Client: return str(value).encode(self.encoding) - def _check_cas(self, cas): + def _check_cas(self, cas: Union[int, str, bytes]) -> bytes: """Check that a value is a valid input for 'cas' -- either an int or a string containing only 0-9 @@ -997,7 +1024,14 @@ class Client: return cas - def _extract_value(self, expect_cas, line, buf, remapped_keys, prefixed_keys): + def _extract_value( + self, + expect_cas: bool, + line: bytes, + buf: bytes, + remapped_keys, + prefixed_keys: List[bytes], + ): """ This function is abstracted from _fetch_cmd to support different ways of value extraction. In order to use this feature, _extract_value needs @@ -1009,7 +1043,7 @@ class Client: try: _, key, flags, size = line.split() except Exception as e: - raise ValueError(f"Unable to parse line {line}: {e}") + raise ValueError(f"Unable to parse line {line!r}: {e}") value = None try: @@ -1025,7 +1059,7 @@ class Client: else: return key, value, buf - def _fetch_cmd(self, name, keys, expect_cas): + def _fetch_cmd(self, name: bytes, keys: Iterable[Key], expect_cas: bool): prefixed_keys = [self.check_key(k) for k in keys] remapped_keys = dict(zip(prefixed_keys, keys)) @@ -1039,11 +1073,14 @@ class Client: if self.sock is None: self._connect() + # For typing + assert self.sock is not None + self.sock.sendall(cmd) buf = b"" line = None - result = {} + result: Dict[bytes, bytes] = {} while True: try: buf, line = _readline(self.sock, buf) @@ -1073,7 +1110,15 @@ class Client: return {} raise - def _store_cmd(self, name, values, expire, noreply, flags=None, cas=None): + def _store_cmd( + self, + name: bytes, + values: Dict[Key, Any], + expire: int, + noreply: bool, + flags: Optional[int] = None, + cas: Optional[bytes] = None, + ) -> Dict[Key, Optional[bool]]: cmds = [] keys = [] @@ -1082,7 +1127,7 @@ class Client: extra += b" " + cas if noreply: extra += b" noreply" - expire = self._check_integer(expire, "expire") + expire_bytes = self._check_integer(expire, "expire") for key, data in values.items(): # must be able to reliably map responses back to the original order @@ -1111,7 +1156,7 @@ class Client: + b" " + str(data_flags).encode(self.encoding) + b" " - + expire + + expire_bytes + b" " + str(len(data)).encode(self.encoding) + extra @@ -1123,6 +1168,9 @@ class Client: if self.sock is None: self._connect() + # For typing + assert self.sock is not None + try: self.sock.sendall(b"".join(cmds)) if noreply: @@ -1148,10 +1196,17 @@ class Client: self.close() raise - def _misc_cmd(self, cmds, cmd_name, noreply, end_tokens=None): + def _misc_cmd( + self, + cmds: Iterable[bytes], + cmd_name: bytes, + noreply: Optional[bool], + end_tokens=None, + ) -> List[bytes]: # If no end_tokens have been given, just assume standard memcached # operations, which end in "\r\n", use regular code for that. + _reader: Callable[[socket.socket, bytes], Tuple[bytes, bytes]] if end_tokens: _reader = partial(_readsegment, end_tokens=end_tokens) else: @@ -1160,6 +1215,9 @@ class Client: if self.sock is None: self._connect() + # For typing + assert self.sock is not None + try: self.sock.sendall(b"".join(cmds)) @@ -1236,7 +1294,7 @@ class PooledClient: max_pool_size=None, pool_idle_timeout=0, lock_generator=None, - default_noreply=True, + default_noreply: bool = True, allow_unicode_keys=False, encoding="ascii", tls_context=None, @@ -1266,7 +1324,7 @@ class PooledClient: self.encoding = encoding self.tls_context = tls_context - def check_key(self, key): + def check_key(self, key: Key) -> bytes: """Checks key and add key_prefix.""" return check_key_helper( key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix @@ -1443,7 +1501,7 @@ class PooledClient: self.delete(key, noreply=True) -def _readline(sock, buf): +def _readline(sock: socket.socket, buf: bytes) -> Tuple[bytes, bytes]: """Read line of text from the socket. Read a line of text (delimited by "\r\n") from the socket, and @@ -1452,18 +1510,18 @@ def _readline(sock, buf): Args: sock: Socket object, should be connected. - buf: String, zero or more characters, returned from an earlier - call to _readline or _readvalue (pass an empty string on the + buf: Bytes, zero or more characters, returned from an earlier + call to _readline or _readvalue (pass an empty byte string on the first call). Returns: A tuple of (buf, line) where line is the full line read from the socket (minus the "\r\n" characters) and buf is any trailing characters read after the "\r\n" was found (which may be an empty - string). + byte string). """ - chunks = [] + chunks: List[bytes] = [] last_char = b"" while True: @@ -1494,7 +1552,7 @@ def _readline(sock, buf): raise MemcacheUnexpectedCloseError() -def _readvalue(sock, buf, size): +def _readvalue(sock, buf, size: int): """Read specified amount of bytes from the socket. Read size bytes, followed by the "\r\n" characters, from the socket, @@ -1539,7 +1597,9 @@ def _readvalue(sock, buf, size): return buf[rlen:], b"".join(chunks) -def _readsegment(sock, buf, end_tokens): +def _readsegment( + sock: socket.socket, buf: bytes, end_tokens: bytes +) -> Tuple[bytes, bytes]: """Read a segment from the socket. Read a segment from the socket, up to the first end_token sub-string/bytes, @@ -1575,7 +1635,7 @@ def _readsegment(sock, buf, end_tokens): raise MemcacheUnexpectedCloseError() -def _recv(sock, size): +def _recv(sock: socket.socket, size: int) -> bytes: """sock.recv() with retry on EINTR""" while True: try: diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index db44108..57b07db 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -154,7 +154,7 @@ class HashClient: self._dead_clients[server] = dead_time self.hasher.remove_node(key) - def _retry_dead(self): + def _retry_dead(self) -> None: current_time = time.time() ldc = self._last_dead_check_time # We have reached the retry timeout diff --git a/pymemcache/serde.py b/pymemcache/serde.py index 6e77766..42ec922 100644 --- a/pymemcache/serde.py +++ b/pymemcache/serde.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import logging -from io import BytesIO import pickle import zlib +from functools import partial +from io import BytesIO FLAG_BYTES = 0 FLAG_PICKLE = 1 << 0 @@ -61,7 +61,7 @@ def _python_memcache_serializer(key, value, pickle_version=None): return value, flags -def get_python_memcache_serializer(pickle_version=DEFAULT_PICKLE_VERSION): +def get_python_memcache_serializer(pickle_version: int = DEFAULT_PICKLE_VERSION): """Return a serializer using a specific pickle version""" return partial(_python_memcache_serializer, pickle_version=pickle_version) @@ -112,7 +112,7 @@ class PickleSerde: for :py:class:`pymemcache.client.base.Client` """ - def __init__(self, pickle_version=DEFAULT_PICKLE_VERSION): + def __init__(self, pickle_version: int = DEFAULT_PICKLE_VERSION) -> None: self._serialize_func = get_python_memcache_serializer(pickle_version) def serialize(self, key, value): @@ -182,7 +182,7 @@ class LegacyWrappingSerde: case that they are missing. """ - def __init__(self, serializer_func, deserializer_func): + def __init__(self, serializer_func, deserializer_func) -> None: self.serialize = serializer_func or self._default_serialize self.deserialize = deserializer_func or self._default_deserialize diff --git a/pyproject.toml b/pyproject.toml index 8d26def..d08806c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,3 +3,4 @@ target-version = ['py37', 'py38', 'py39', 'py310'] [tool.mypy] python_version = 3.7 +ignore_missing_imports = true -- cgit v1.2.1