summaryrefslogtreecommitdiff
path: root/requests_cache
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-08-10 16:56:13 -0500
committerJordan Cook <jordan.cook@pioneer.com>2021-08-14 21:58:01 -0500
commit85ccdfd12928964e2861768ce70301232f50c769 (patch)
tree6862f34ee3ec4e42b8f9e979d32bb99baa032615 /requests_cache
parentd17ad7f422ad01bc7ddcd707396f763ec0b236d3 (diff)
downloadrequests-cache-85ccdfd12928964e2861768ce70301232f50c769.tar.gz
Replace some 'type: ignore' statements with better type hinting
Diffstat (limited to 'requests_cache')
-rw-r--r--requests_cache/backends/base.py2
-rw-r--r--requests_cache/cache_keys.py31
-rw-r--r--requests_cache/patcher.py5
-rw-r--r--requests_cache/session.py2
4 files changed, 23 insertions, 17 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 3e4b045..2ed9503 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -159,7 +159,7 @@ class BaseCache:
def create_key(self, request: AnyRequest, **kwargs) -> str:
"""Create a normalized cache key from a request object"""
- return create_key(request, self.ignored_parameters, self.include_get_headers, **kwargs) # type: ignore
+ return create_key(request, self.ignored_parameters, self.include_get_headers, **kwargs)
def has_key(self, key: str) -> bool:
"""Returns `True` if cache has `key`, `False` otherwise"""
diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py
index d1ad8dc..3ded784 100644
--- a/requests_cache/cache_keys.py
+++ b/requests_cache/cache_keys.py
@@ -1,20 +1,25 @@
+from __future__ import annotations
+
import hashlib
import json
from operator import itemgetter
-from typing import Iterable, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
-from requests import PreparedRequest, Request, Session
+from requests import Request, Session
from requests.models import CaseInsensitiveDict
from requests.utils import default_headers
from url_normalize import url_normalize
+if TYPE_CHECKING:
+ from .models import AnyRequest
+
DEFAULT_HEADERS = default_headers()
RequestContent = Union[Mapping, str, bytes]
def create_key(
- request: PreparedRequest,
+ request: AnyRequest,
ignored_params: Iterable[str] = None,
include_get_headers: bool = False,
**kwargs,
@@ -31,15 +36,16 @@ def create_key(
if body:
key.update(body)
if include_get_headers and request.headers != DEFAULT_HEADERS:
- for name, value in normalize_dict(request.headers).items(): # type: ignore
+ headers = normalize_dict(request.headers)
+ if TYPE_CHECKING:
+ assert isinstance(headers, dict)
+ for name, value in headers.items():
key.update(encode(f'{name}={value}'))
return key.hexdigest()
-def remove_ignored_params(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
-) -> PreparedRequest:
+def remove_ignored_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> AnyRequest:
if not ignored_params:
return request
request.headers = remove_ignored_headers(request, ignored_params)
@@ -49,7 +55,7 @@ def remove_ignored_params(
def remove_ignored_headers(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
+ request: AnyRequest, ignored_params: Optional[Iterable[str]]
) -> CaseInsensitiveDict:
if not ignored_params:
return request.headers
@@ -59,7 +65,7 @@ def remove_ignored_headers(
return headers
-def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str:
+def remove_ignored_url_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> str:
url_str = str(request.url)
if not ignored_params:
return url_str
@@ -69,10 +75,9 @@ def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional
return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment))
-def remove_ignored_body_params(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
-) -> bytes:
+def remove_ignored_body_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> bytes:
original_body = request.body
+ filtered_body: Union[str, bytes] = b''
content_type = request.headers.get('content-type')
if not ignored_params or not original_body or not content_type:
return encode(original_body)
@@ -85,7 +90,7 @@ def remove_ignored_body_params(
body = filter_params(sorted(body), ignored_params)
filtered_body = json.dumps(body)
else:
- filtered_body = original_body # type: ignore
+ filtered_body = original_body
return encode(filtered_body)
diff --git a/requests_cache/patcher.py b/requests_cache/patcher.py
index a0a929e..9b906a7 100644
--- a/requests_cache/patcher.py
+++ b/requests_cache/patcher.py
@@ -125,8 +125,9 @@ def remove_expired_responses(expire_after: ExpirationTime = None):
Args:
expire_after: A new expiration time used to revalidate the cache
"""
- if is_installed():
- return requests.Session().remove_expired_responses(expire_after) # type: ignore
+ session = requests.Session()
+ if isinstance(session, CachedSession):
+ session.remove_expired_responses(expire_after)
def _patch_session_factory(session_factory: Type[OriginalSession] = CachedSession):
diff --git a/requests_cache/session.py b/requests_cache/session.py
index 450865c..871140c 100644
--- a/requests_cache/session.py
+++ b/requests_cache/session.py
@@ -56,7 +56,7 @@ class CacheMixin(MIXIN_BASE):
self._disabled = False
self._lock = RLock()
- # If the superclass is custom Session, pass along valid kwargs (if any)
+ # If the superclass is custom Session, pass along any valid kwargs
session_kwargs = get_valid_kwargs(super().__init__, kwargs)
super().__init__(**session_kwargs) # type: ignore