diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2021-08-10 16:56:13 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-08-14 21:58:01 -0500 |
commit | 85ccdfd12928964e2861768ce70301232f50c769 (patch) | |
tree | 6862f34ee3ec4e42b8f9e979d32bb99baa032615 /requests_cache | |
parent | d17ad7f422ad01bc7ddcd707396f763ec0b236d3 (diff) | |
download | requests-cache-85ccdfd12928964e2861768ce70301232f50c769.tar.gz |
Replace some 'type: ignore' statements with better type hinting
Diffstat (limited to 'requests_cache')
-rw-r--r-- | requests_cache/backends/base.py | 2 | ||||
-rw-r--r-- | requests_cache/cache_keys.py | 31 | ||||
-rw-r--r-- | requests_cache/patcher.py | 5 | ||||
-rw-r--r-- | requests_cache/session.py | 2 |
4 files changed, 23 insertions, 17 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 3e4b045..2ed9503 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -159,7 +159,7 @@ class BaseCache: def create_key(self, request: AnyRequest, **kwargs) -> str: """Create a normalized cache key from a request object""" - return create_key(request, self.ignored_parameters, self.include_get_headers, **kwargs) # type: ignore + return create_key(request, self.ignored_parameters, self.include_get_headers, **kwargs) def has_key(self, key: str) -> bool: """Returns `True` if cache has `key`, `False` otherwise""" diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py index d1ad8dc..3ded784 100644 --- a/requests_cache/cache_keys.py +++ b/requests_cache/cache_keys.py @@ -1,20 +1,25 @@ +from __future__ import annotations + import hashlib import json from operator import itemgetter -from typing import Iterable, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, Union from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse -from requests import PreparedRequest, Request, Session +from requests import Request, Session from requests.models import CaseInsensitiveDict from requests.utils import default_headers from url_normalize import url_normalize +if TYPE_CHECKING: + from .models import AnyRequest + DEFAULT_HEADERS = default_headers() RequestContent = Union[Mapping, str, bytes] def create_key( - request: PreparedRequest, + request: AnyRequest, ignored_params: Iterable[str] = None, include_get_headers: bool = False, **kwargs, @@ -31,15 +36,16 @@ def create_key( if body: key.update(body) if include_get_headers and request.headers != DEFAULT_HEADERS: - for name, value in normalize_dict(request.headers).items(): # type: ignore + headers = normalize_dict(request.headers) + if TYPE_CHECKING: + assert isinstance(headers, dict) + for name, value in headers.items(): key.update(encode(f'{name}={value}')) return key.hexdigest() -def remove_ignored_params( - request: PreparedRequest, ignored_params: Optional[Iterable[str]] -) -> PreparedRequest: +def remove_ignored_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> AnyRequest: if not ignored_params: return request request.headers = remove_ignored_headers(request, ignored_params) @@ -49,7 +55,7 @@ def remove_ignored_params( def remove_ignored_headers( - request: PreparedRequest, ignored_params: Optional[Iterable[str]] + request: AnyRequest, ignored_params: Optional[Iterable[str]] ) -> CaseInsensitiveDict: if not ignored_params: return request.headers @@ -59,7 +65,7 @@ def remove_ignored_headers( return headers -def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str: +def remove_ignored_url_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> str: url_str = str(request.url) if not ignored_params: return url_str @@ -69,10 +75,9 @@ def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment)) -def remove_ignored_body_params( - request: PreparedRequest, ignored_params: Optional[Iterable[str]] -) -> bytes: +def remove_ignored_body_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> bytes: original_body = request.body + filtered_body: Union[str, bytes] = b'' content_type = request.headers.get('content-type') if not ignored_params or not original_body or not content_type: return encode(original_body) @@ -85,7 +90,7 @@ def remove_ignored_body_params( body = filter_params(sorted(body), ignored_params) filtered_body = json.dumps(body) else: - filtered_body = original_body # type: ignore + filtered_body = original_body return encode(filtered_body) diff --git a/requests_cache/patcher.py b/requests_cache/patcher.py index a0a929e..9b906a7 100644 --- a/requests_cache/patcher.py +++ b/requests_cache/patcher.py @@ -125,8 +125,9 @@ def remove_expired_responses(expire_after: ExpirationTime = None): Args: expire_after: A new expiration time used to revalidate the cache """ - if is_installed(): - return requests.Session().remove_expired_responses(expire_after) # type: ignore + session = requests.Session() + if isinstance(session, CachedSession): + session.remove_expired_responses(expire_after) def _patch_session_factory(session_factory: Type[OriginalSession] = CachedSession): diff --git a/requests_cache/session.py b/requests_cache/session.py index 450865c..871140c 100644 --- a/requests_cache/session.py +++ b/requests_cache/session.py @@ -56,7 +56,7 @@ class CacheMixin(MIXIN_BASE): self._disabled = False self._lock = RLock() - # If the superclass is custom Session, pass along valid kwargs (if any) + # If the superclass is custom Session, pass along any valid kwargs session_kwargs = get_valid_kwargs(super().__init__, kwargs) super().__init__(**session_kwargs) # type: ignore |