diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2021-07-06 14:49:07 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-07-06 16:24:26 -0500 |
commit | 6f54262b06a96b1019339e655b22c63f5882a632 (patch) | |
tree | a6c895c17a67f183c6fd93962f5db3f11aebb0f6 /requests_cache/cache_keys.py | |
parent | df4d83c606dbd3c993ae0d8705dd0d983c5b9f1d (diff) | |
download | requests-cache-6f54262b06a96b1019339e655b22c63f5882a632.tar.gz |
Improve type annotations and fix type checking errors
Diffstat (limited to 'requests_cache/cache_keys.py')
-rw-r--r-- | requests_cache/cache_keys.py | 76 |
1 files changed, 42 insertions, 34 deletions
diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py index ec0d39f..d1ad8dc 100644 --- a/requests_cache/cache_keys.py +++ b/requests_cache/cache_keys.py @@ -1,25 +1,27 @@ import hashlib import json from operator import itemgetter -from typing import Iterable, List, Mapping, Tuple, Union +from typing import Iterable, List, Mapping, Optional, Tuple, Union from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse -import requests +from requests import PreparedRequest, Request, Session +from requests.models import CaseInsensitiveDict +from requests.utils import default_headers from url_normalize import url_normalize -DEFAULT_HEADERS = requests.utils.default_headers() +DEFAULT_HEADERS = default_headers() RequestContent = Union[Mapping, str, bytes] def create_key( - request: requests.PreparedRequest, + request: PreparedRequest, ignored_params: Iterable[str] = None, include_get_headers: bool = False, **kwargs, ) -> str: """Create a normalized cache key from a request object""" key = hashlib.sha256() - key.update(encode(request.method.upper())) + key.update(encode((request.method or '').upper())) url = remove_ignored_url_params(request, ignored_params) url = url_normalize(url) key.update(encode(url)) @@ -29,15 +31,15 @@ def create_key( if body: key.update(body) if include_get_headers and request.headers != DEFAULT_HEADERS: - for name, value in normalize_dict(request.headers).items(): + for name, value in normalize_dict(request.headers).items(): # type: ignore key.update(encode(f'{name}={value}')) return key.hexdigest() def remove_ignored_params( - request: requests.PreparedRequest, ignored_params: Iterable[str] -) -> requests.PreparedRequest: + request: PreparedRequest, ignored_params: Optional[Iterable[str]] +) -> PreparedRequest: if not ignored_params: return request request.headers = remove_ignored_headers(request, ignored_params) @@ -46,51 +48,55 @@ def remove_ignored_params( return request -def remove_ignored_headers(request: requests.PreparedRequest, ignored_params: Iterable[str]) -> dict: +def remove_ignored_headers( + request: PreparedRequest, ignored_params: Optional[Iterable[str]] +) -> CaseInsensitiveDict: if not ignored_params: return request.headers - headers = request.headers.copy() - headers.update({k: None for k in ignored_params}) + headers = CaseInsensitiveDict(request.headers.copy()) + for k in ignored_params: + headers.pop(k, None) return headers -def remove_ignored_url_params(request: requests.PreparedRequest, ignored_params: Iterable[str]) -> str: - url = str(request.url) +def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str: + url_str = str(request.url) if not ignored_params: - return url + return url_str - url = urlparse(url) - query = parse_qsl(url.query) - query = filter_params(query, ignored_params) - query = urlencode(query) - url = urlunparse((url.scheme, url.netloc, url.path, url.params, query, url.fragment)) - return url + url = urlparse(url_str) + query = filter_params(parse_qsl(url.query), ignored_params) + return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment)) def remove_ignored_body_params( - request: requests.PreparedRequest, ignored_params: Iterable[str] + request: PreparedRequest, ignored_params: Optional[Iterable[str]] ) -> bytes: - body = request.body + original_body = request.body content_type = request.headers.get('content-type') - if not ignored_params or not body or not content_type: - return encode(request.body) + if not ignored_params or not original_body or not content_type: + return encode(original_body) if content_type == 'application/x-www-form-urlencoded': - body = parse_qsl(body) - body = filter_params(body, ignored_params) - body = urlencode(body) + body = filter_params(parse_qsl(decode(original_body)), ignored_params) + filtered_body = urlencode(body) elif content_type == 'application/json': - body = json.loads(decode(body)) - body = filter_params(sorted(body.items()), ignored_params) - body = json.dumps(body) - return encode(body) + body = json.loads(decode(original_body)).items() + body = filter_params(sorted(body), ignored_params) + filtered_body = json.dumps(body) + else: + filtered_body = original_body # type: ignore + return encode(filtered_body) -def filter_params(data: List[Tuple], ignored_params: Iterable[str]) -> List[Tuple]: + +def filter_params(data: List[Tuple[str, str]], ignored_params: Iterable[str]) -> List[Tuple[str, str]]: return [(k, v) for k, v in data if k not in set(ignored_params)] -def normalize_dict(items: RequestContent = None, normalize_data: bool = True) -> RequestContent: +def normalize_dict( + items: Optional[RequestContent], normalize_data: bool = True +) -> Optional[RequestContent]: """Sort items in a dict Args: @@ -101,6 +107,8 @@ def normalize_dict(items: RequestContent = None, normalize_data: bool = True) -> def sort_dict(d): return dict(sorted(d.items(), key=itemgetter(0))) + if not items: + return None if isinstance(items, Mapping): return sort_dict(items) if normalize_data and isinstance(items, (bytes, str)): @@ -116,7 +124,7 @@ def normalize_dict(items: RequestContent = None, normalize_data: bool = True) -> def url_to_key(url: str, *args, **kwargs) -> str: - request = requests.Session().prepare_request(requests.Request('GET', url)) + request = Session().prepare_request(Request('GET', url)) return create_key(request, *args, **kwargs) |