summaryrefslogtreecommitdiff
path: root/requests_cache/cache_keys.py
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-07-06 14:49:07 -0500
committerJordan Cook <jordan.cook@pioneer.com>2021-07-06 16:24:26 -0500
commit6f54262b06a96b1019339e655b22c63f5882a632 (patch)
treea6c895c17a67f183c6fd93962f5db3f11aebb0f6 /requests_cache/cache_keys.py
parentdf4d83c606dbd3c993ae0d8705dd0d983c5b9f1d (diff)
downloadrequests-cache-6f54262b06a96b1019339e655b22c63f5882a632.tar.gz
Improve type annotations and fix type checking errors
Diffstat (limited to 'requests_cache/cache_keys.py')
-rw-r--r--requests_cache/cache_keys.py76
1 files changed, 42 insertions, 34 deletions
diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py
index ec0d39f..d1ad8dc 100644
--- a/requests_cache/cache_keys.py
+++ b/requests_cache/cache_keys.py
@@ -1,25 +1,27 @@
import hashlib
import json
from operator import itemgetter
-from typing import Iterable, List, Mapping, Tuple, Union
+from typing import Iterable, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
-import requests
+from requests import PreparedRequest, Request, Session
+from requests.models import CaseInsensitiveDict
+from requests.utils import default_headers
from url_normalize import url_normalize
-DEFAULT_HEADERS = requests.utils.default_headers()
+DEFAULT_HEADERS = default_headers()
RequestContent = Union[Mapping, str, bytes]
def create_key(
- request: requests.PreparedRequest,
+ request: PreparedRequest,
ignored_params: Iterable[str] = None,
include_get_headers: bool = False,
**kwargs,
) -> str:
"""Create a normalized cache key from a request object"""
key = hashlib.sha256()
- key.update(encode(request.method.upper()))
+ key.update(encode((request.method or '').upper()))
url = remove_ignored_url_params(request, ignored_params)
url = url_normalize(url)
key.update(encode(url))
@@ -29,15 +31,15 @@ def create_key(
if body:
key.update(body)
if include_get_headers and request.headers != DEFAULT_HEADERS:
- for name, value in normalize_dict(request.headers).items():
+ for name, value in normalize_dict(request.headers).items(): # type: ignore
key.update(encode(f'{name}={value}'))
return key.hexdigest()
def remove_ignored_params(
- request: requests.PreparedRequest, ignored_params: Iterable[str]
-) -> requests.PreparedRequest:
+ request: PreparedRequest, ignored_params: Optional[Iterable[str]]
+) -> PreparedRequest:
if not ignored_params:
return request
request.headers = remove_ignored_headers(request, ignored_params)
@@ -46,51 +48,55 @@ def remove_ignored_params(
return request
-def remove_ignored_headers(request: requests.PreparedRequest, ignored_params: Iterable[str]) -> dict:
+def remove_ignored_headers(
+ request: PreparedRequest, ignored_params: Optional[Iterable[str]]
+) -> CaseInsensitiveDict:
if not ignored_params:
return request.headers
- headers = request.headers.copy()
- headers.update({k: None for k in ignored_params})
+ headers = CaseInsensitiveDict(request.headers.copy())
+ for k in ignored_params:
+ headers.pop(k, None)
return headers
-def remove_ignored_url_params(request: requests.PreparedRequest, ignored_params: Iterable[str]) -> str:
- url = str(request.url)
+def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str:
+ url_str = str(request.url)
if not ignored_params:
- return url
+ return url_str
- url = urlparse(url)
- query = parse_qsl(url.query)
- query = filter_params(query, ignored_params)
- query = urlencode(query)
- url = urlunparse((url.scheme, url.netloc, url.path, url.params, query, url.fragment))
- return url
+ url = urlparse(url_str)
+ query = filter_params(parse_qsl(url.query), ignored_params)
+ return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment))
def remove_ignored_body_params(
- request: requests.PreparedRequest, ignored_params: Iterable[str]
+ request: PreparedRequest, ignored_params: Optional[Iterable[str]]
) -> bytes:
- body = request.body
+ original_body = request.body
content_type = request.headers.get('content-type')
- if not ignored_params or not body or not content_type:
- return encode(request.body)
+ if not ignored_params or not original_body or not content_type:
+ return encode(original_body)
if content_type == 'application/x-www-form-urlencoded':
- body = parse_qsl(body)
- body = filter_params(body, ignored_params)
- body = urlencode(body)
+ body = filter_params(parse_qsl(decode(original_body)), ignored_params)
+ filtered_body = urlencode(body)
elif content_type == 'application/json':
- body = json.loads(decode(body))
- body = filter_params(sorted(body.items()), ignored_params)
- body = json.dumps(body)
- return encode(body)
+ body = json.loads(decode(original_body)).items()
+ body = filter_params(sorted(body), ignored_params)
+ filtered_body = json.dumps(body)
+ else:
+ filtered_body = original_body # type: ignore
+ return encode(filtered_body)
-def filter_params(data: List[Tuple], ignored_params: Iterable[str]) -> List[Tuple]:
+
+def filter_params(data: List[Tuple[str, str]], ignored_params: Iterable[str]) -> List[Tuple[str, str]]:
return [(k, v) for k, v in data if k not in set(ignored_params)]
-def normalize_dict(items: RequestContent = None, normalize_data: bool = True) -> RequestContent:
+def normalize_dict(
+ items: Optional[RequestContent], normalize_data: bool = True
+) -> Optional[RequestContent]:
"""Sort items in a dict
Args:
@@ -101,6 +107,8 @@ def normalize_dict(items: RequestContent = None, normalize_data: bool = True) ->
def sort_dict(d):
return dict(sorted(d.items(), key=itemgetter(0)))
+ if not items:
+ return None
if isinstance(items, Mapping):
return sort_dict(items)
if normalize_data and isinstance(items, (bytes, str)):
@@ -116,7 +124,7 @@ def normalize_dict(items: RequestContent = None, normalize_data: bool = True) ->
def url_to_key(url: str, *args, **kwargs) -> str:
- request = requests.Session().prepare_request(requests.Request('GET', url))
+ request = Session().prepare_request(Request('GET', url))
return create_key(request, *args, **kwargs)