diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2022-04-17 18:46:29 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2022-04-18 13:40:23 -0500 |
commit | 642f872204809f5159f4e6078cd98e981beee221 (patch) | |
tree | e939bcc28c5c9c5c145b7de0067fca29785cbdb5 | |
parent | 7ebf9df7ae2534bad66dc4f102993f5fb6d789b2 (diff) | |
download | requests-cache-642f872204809f5159f4e6078cd98e981beee221.tar.gz |
Refactor utilities for parsing cache headers into CacheDirectives class
-rw-r--r-- | docs/reference.md | 4 | ||||
-rw-r--r-- | requests_cache/backends/base.py | 5 | ||||
-rw-r--r-- | requests_cache/cache_keys.py | 2 | ||||
-rwxr-xr-x | requests_cache/models/response.py | 3 | ||||
-rw-r--r-- | requests_cache/policy/__init__.py | 18 | ||||
-rw-r--r-- | requests_cache/policy/actions.py | 150 | ||||
-rw-r--r-- | requests_cache/policy/directives.py | 79 | ||||
-rw-r--r-- | requests_cache/policy/expiration.py | 24 | ||||
-rw-r--r-- | requests_cache/policy/settings.py | 9 | ||||
-rw-r--r-- | requests_cache/serializers/__init__.py | 2 | ||||
-rw-r--r-- | requests_cache/session.py | 9 |
11 files changed, 157 insertions, 148 deletions
diff --git a/docs/reference.md b/docs/reference.md index bc74488..97ef991 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -25,7 +25,6 @@ modules/requests_cache.session modules/requests_cache.patcher modules/requests_cache.backends modules/requests_cache.models -modules/requests_cache.settings ``` ## Secondary Modules @@ -33,7 +32,6 @@ The following modules are mainly for internal use, and are relevant for contribu ```{toctree} :maxdepth: 2 modules/requests_cache.cache_keys -modules/requests_cache.cache_control -modules/requests_cache.expiration +modules/requests_cache.policy modules/requests_cache.serializers ``` diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 46193b6..250b5e1 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -12,7 +12,7 @@ from collections.abc import MutableMapping from datetime import datetime from logging import getLogger from pickle import PickleError -from typing import Iterable, Iterator, Optional, Tuple, Union +from typing import Iterable, Iterator, Optional, Tuple from requests import PreparedRequest, Response @@ -25,7 +25,6 @@ from ..serializers import init_serializer # Specific exceptions that may be raised during deserialization DESERIALIZE_ERRORS = (AttributeError, ImportError, PickleError, TypeError, ValueError) -ResponseOrKey = Union[CachedResponse, str] logger = getLogger(__name__) @@ -259,7 +258,7 @@ class BaseStorage(MutableMapping, ABC): """ def __init__(self, serializer=None, **kwargs): - self.serializer = init_serializer(serializer, **kwargs) + self.serializer = init_serializer(serializer) logger.debug(f'Initializing {type(self).__name__} with serializer: {self.serializer}') def bulk_delete(self, keys: Iterable[str]): diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py index 3d6f63a..71606fe 100644 --- a/requests_cache/cache_keys.py +++ b/requests_cache/cache_keys.py @@ -1,4 +1,4 @@ -"""Internal utilities for generating the cache keys that are used to match requests +"""Internal utilities for generating cache keys that are used for request matching .. automodsumm:: requests_cache.cache_keys :functions-only: diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py index bf89c8c..4e2643b 100755 --- a/requests_cache/models/response.py +++ b/requests_cache/models/response.py @@ -2,7 +2,7 @@ from __future__ import annotations from datetime import datetime, timedelta, timezone from logging import getLogger -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional import attr from attr import define, field @@ -18,7 +18,6 @@ if TYPE_CHECKING: from ..policy.actions import CacheActions DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S %Z' # Format used for __str__ only -HeaderList = List[Tuple[str, str]] logger = getLogger(__name__) diff --git a/requests_cache/policy/__init__.py b/requests_cache/policy/__init__.py index 9d4f7d6..dbd5bab 100644 --- a/requests_cache/policy/__init__.py +++ b/requests_cache/policy/__init__.py @@ -1,5 +1,21 @@ +"""Modules that implement cache policy, based on a combination of standard HTTP headers and +additional settings and features specific to requests-cache. +""" # flake8: noqa: E402,F401 # isort: skip_file +from datetime import datetime, timedelta +from typing import Callable, Dict, Union, MutableMapping + +from requests import Response + +ExpirationTime = Union[None, int, float, str, datetime, timedelta] +ExpirationPatterns = Dict[str, ExpirationTime] +FilterCallback = Callable[[Response], bool] +KeyCallback = Callable[..., str] +HeaderDict = MutableMapping[str, str] + + from .expiration import * from .settings import * -from .actions import * +from .directives import CacheDirectives, set_request_headers +from .actions import CacheActions diff --git a/requests_cache/policy/actions.py b/requests_cache/policy/actions.py index ba333a6..3411d2f 100644 --- a/requests_cache/policy/actions.py +++ b/requests_cache/policy/actions.py @@ -1,38 +1,25 @@ -"""Internal utilities for determining cache expiration and other cache actions. - -.. automodsumm:: requests_cache.cache_control - :classes-only: - :nosignatures: - -.. automodsumm:: requests_cache.cache_control - :functions-only: - :nosignatures: -""" from datetime import datetime from logging import getLogger -from typing import Dict, MutableMapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional from attr import define, field from requests import PreparedRequest, Response -from requests.models import CaseInsensitiveDict -from .._utils import coalesce, try_int -from ..models import CachedResponse -from .expiration import ( +from .._utils import coalesce +from . import ( DO_NOT_CACHE, EXPIRE_IMMEDIATELY, NEVER_EXPIRE, + CacheDirectives, ExpirationTime, get_expiration_datetime, - get_expiration_seconds, get_url_expiration, ) from .settings import CacheSettings -__all__ = ['CacheActions'] +if TYPE_CHECKING: + from ..models import CachedResponse -CacheDirective = Union[None, bool, int, str] -HeaderDict = MutableMapping[str, str] logger = getLogger(__name__) @@ -40,8 +27,7 @@ logger = getLogger(__name__) @define class CacheActions: """Translates cache settings and headers into specific actions to take for a given cache item. - This class defines the caching policy, and resulting actions are handled in - :py:meth:`CachedSession.send`. + The resulting actions are then handled in :py:meth:`CachedSession.send`. .. rubric:: Notes @@ -71,7 +57,6 @@ class CacheActions: # Inputs/internal attributes _settings: CacheSettings = field(default=None, repr=False, init=False) _validation_headers: Dict[str, str] = field(factory=dict, repr=False, init=False) - # TODO: It would be nice to not need these temp variables _only_if_cached: bool = field(default=False) _refresh: bool = field(default=False) @@ -84,19 +69,19 @@ class CacheActions: `max-age=0` would be used by a client to request a refresh. However, this would conflict with the `expire_after` option provided in :py:meth:`.CachedSession.request`. """ - directives = get_cache_directives(request.headers) + settings = settings or CacheSettings() + directives = CacheDirectives.from_headers(request.headers) logger.debug(f'Cache directives from request headers: {directives}') # Merge relevant headers with session + request settings - settings = settings or CacheSettings() - only_if_cached = settings.only_if_cached or 'only-if-cached' in directives - expire_immediately = directives.get('max-age') == EXPIRE_IMMEDIATELY - refresh = expire_immediately or 'must-revalidate' in directives - force_refresh = 'no-cache' in directives + expire_immediately = directives.max_age == EXPIRE_IMMEDIATELY + only_if_cached = settings.only_if_cached or directives.only_if_cached + refresh = expire_immediately or directives.must_revalidate + force_refresh = directives.no_cache # Check expiration values in order of precedence expire_after = coalesce( - directives.get('max-age'), + directives.max_age, get_url_expiration(request.url, settings.urls_expire_after), settings.expire_after, ) @@ -105,7 +90,7 @@ class CacheActions: read_criteria = { 'disabled cache': settings.disabled, 'disabled method': str(request.method) not in settings.allowable_methods, - 'disabled by headers': 'no-store' in directives, + 'disabled by headers': directives.no_store, 'disabled by refresh': force_refresh, 'disabled by expiration': expire_after == DO_NOT_CACHE, } @@ -117,7 +102,7 @@ class CacheActions: only_if_cached=only_if_cached, refresh=refresh, skip_read=any(read_criteria.values()), - skip_write='no-store' in directives, + skip_write=directives.no_store, ) actions._settings = settings return actions @@ -127,7 +112,7 @@ class CacheActions: """Convert the user/header-provided expiration value to a datetime""" return get_expiration_datetime(self.expire_after) - def update_from_cached_response(self, cached_response: CachedResponse): + def update_from_cached_response(self, cached_response: 'CachedResponse'): """Check for relevant cache headers from a cached response, and set headers for a conditional request, if possible. @@ -147,25 +132,25 @@ class CacheActions: self._update_validation_headers(cached_response) logger.debug(f'Post-read cache actions: {self}') - def _update_validation_headers(self, response: CachedResponse): + def _update_validation_headers(self, response: 'CachedResponse'): """If needed, get validation headers based on a cached response. Revalidation may be triggered by a stale response, request headers, or cached response headers. """ - directives = get_cache_directives(response.headers) - revalidate = _has_validator(response.headers) and ( + directives = CacheDirectives.from_headers(response.headers) + revalidate = directives.has_validator and ( response.is_expired or self._refresh - or 'no-cache' in directives - or 'must-revalidate' in directives - and directives.get('max-age') == 0 + or directives.no_cache + or directives.must_revalidate + and directives.max_age == 0 ) # Add the appropriate validation headers, if needed if revalidate: - if response.headers.get('ETag'): - self._validation_headers['If-None-Match'] = response.headers['ETag'] - if response.headers.get('Last-Modified'): - self._validation_headers['If-Modified-Since'] = response.headers['Last-Modified'] + if directives.etag: + self._validation_headers['If-None-Match'] = directives.etag + if directives.last_modified: + self._validation_headers['If-Modified-Since'] = directives.last_modified self.send_request = True self.resend_request = False @@ -174,13 +159,13 @@ class CacheActions: Used after receiving a new response, but before saving it to the cache. """ + directives = CacheDirectives.from_headers(response.headers) if self._settings.cache_control: - self._update_from_response_headers(response) + self._update_from_response_headers(directives) # If "expired" but there's a validator, save it to the cache and revalidate on use do_not_cache = self.expire_after == DO_NOT_CACHE - expire_immediately = self.expire_after == EXPIRE_IMMEDIATELY - has_validator = _has_validator(response.headers) + skip_stale = self.expire_after == EXPIRE_IMMEDIATELY and not directives.has_validator # Apply filter callback, if any callback = self._settings.filter_fn @@ -193,25 +178,24 @@ class CacheActions: 'disabled status': response.status_code not in self._settings.allowable_codes, 'disabled by filter': filtered_out, 'disabled by headers': self.skip_write, - 'disabled by expiration': do_not_cache or (expire_immediately and not has_validator), + 'disabled by expiration': do_not_cache or skip_stale, } self.skip_write = any(write_criteria.values()) _log_cache_criteria('write', write_criteria) - def _update_from_response_headers(self, response: Response): + def _update_from_response_headers(self, directives: CacheDirectives): """Check response headers for expiration and other cache directives""" - directives = get_cache_directives(response.headers) logger.debug(f'Cache directives from response headers: {directives}') - if directives.get('immutable'): + if directives.immutable: self.expire_after = NEVER_EXPIRE else: self.expire_after = coalesce( - directives.get('max-age'), - directives.get('expires'), + directives.max_age, + directives.expires, self.expire_after, ) - self.skip_write = self.skip_write or 'no-store' in directives + self.skip_write = self.skip_write or directives.no_store def update_request(self, request: PreparedRequest) -> PreparedRequest: """Apply validation headers (if any) before sending a request""" @@ -219,8 +203,8 @@ class CacheActions: return request def update_revalidated_response( - self, response: Response, cached_response: CachedResponse - ) -> CachedResponse: + self, response: Response, cached_response: 'CachedResponse' + ) -> 'CachedResponse': """After revalidation, update the cached response's headers and reset its expiration""" logger.debug( f'Response for URL {response.request.url} has not been modified; ' @@ -232,64 +216,6 @@ class CacheActions: return cached_response -def append_directive(headers: HeaderDict, directive: str) -> HeaderDict: - """Append a Cache-Control directive to existing headers (if any)""" - directives = headers['Cache-Control'].split(',') if headers.get('Cache-Control') else [] - directives.append(directive) - headers['Cache-Control'] = ','.join(directives) - return headers - - -def get_cache_directives(headers: HeaderDict) -> Dict[str, CacheDirective]: - """Get all Cache-Control directives as a dict. Handles duplicate headers (with - CaseInsensitiveDict) and comma-separated lists. - Key-only directives are returned as ``{key: True}``. - """ - if not headers: - return {} - - kv_directives: Dict[str, CacheDirective] = {} - if headers.get('Cache-Control'): - cache_directives = headers['Cache-Control'].split(',') - kv_directives = dict([_split_kv_directive(value) for value in cache_directives]) - - if 'Expires' in headers: - kv_directives['expires'] = headers['Expires'] - return kv_directives - - -def _split_kv_directive(header_value: str) -> Tuple[str, CacheDirective]: - """Split a cache directive into a ``(key, int)`` pair, if possible; otherwise just - ``(key, True)``. - """ - header_value = header_value.strip() - if '=' in header_value: - k, v = header_value.split('=', 1) - return k, try_int(v) - else: - return header_value, True - - -def set_request_headers( - headers: Optional[HeaderDict], expire_after, only_if_cached, refresh, force_refresh -): - """Translate keyword arguments into equivalent request headers, to be handled in CacheActions""" - headers = CaseInsensitiveDict(headers) - if expire_after is not None: - headers = append_directive(headers, f'max-age={get_expiration_seconds(expire_after)}') - if only_if_cached: - headers = append_directive(headers, 'only-if-cached') - if refresh: - headers = append_directive(headers, 'must-revalidate') - if force_refresh: - headers = append_directive(headers, 'no-cache') - return headers - - -def _has_validator(headers: HeaderDict) -> bool: - return bool(headers.get('ETag') or headers.get('Last-Modified')) - - def _log_cache_criteria(operation: str, criteria: Dict): """Log details on any failed checks for cache read or write""" if any(criteria.values()): diff --git a/requests_cache/policy/directives.py b/requests_cache/policy/directives.py new file mode 100644 index 0000000..b3f09c0 --- /dev/null +++ b/requests_cache/policy/directives.py @@ -0,0 +1,79 @@ +from typing import Optional + +from attr import define, field +from requests.models import CaseInsensitiveDict + +from .._utils import get_valid_kwargs, try_int +from . import HeaderDict, get_expiration_seconds + + +@define +class CacheDirectives: + """Parses Cache-Control directives and other relevant cache settings from either request or + response headers + """ + + expires: str = field(default=None) + immutable: bool = field(default=False) + max_age: int = field(default=None, converter=try_int) + must_revalidate: bool = field(default=False) + no_cache: bool = field(default=False) + no_store: bool = field(default=False) + only_if_cached: bool = field(default=False) + etag: str = field(default=None) + last_modified: str = field(default=None) + + # Not yet implemented: + # max_stale: int = field(default=None, converter=try_int) + # min_fresh: int = field(default=None, converter=try_int) + # stale_if_error: int = field(default=None, converter=try_int) + # stale_while_revalidate: bool = field(default=False) + + @classmethod + def from_headers(cls, headers: HeaderDict): + """Parse cache directives and other settings from request or response headers""" + headers = CaseInsensitiveDict(headers) + directives = headers.get('Cache-Control', '').split(',') + kv_directives = dict(_split_kv_directive(value) for value in directives) + kwargs = get_valid_kwargs( + cls.__init__, {k.replace('-', '_'): v for k, v in kv_directives.items()} + ) + + kwargs['expires'] = headers.get('Expires') + kwargs['etag'] = headers.get('ETag') + kwargs['last_modified'] = headers.get('Last-Modified') + return cls(**kwargs) + + # def to_dict(self) -> CaseInsensitiveDict: + # return {k.title().replace('_', '-'): v for k, v in asdict(self).items() if v is not None} + + @property + def has_validator(self) -> bool: + return bool(self.etag or self.last_modified) + + +def _split_kv_directive(directive: str): + """Split a cache directive into a `(key, value)` pair, or `(key, True)` if value-only""" + directive = directive.strip().lower() + return directive.split('=', 1) if '=' in directive else (directive, True) + + +def set_request_headers( + headers: Optional[HeaderDict], expire_after, only_if_cached, refresh, force_refresh +): + """Translate keyword arguments into equivalent request headers""" + headers = CaseInsensitiveDict(headers) + directives = headers['Cache-Control'].split(',') if headers.get('Cache-Control') else [] + + if expire_after is not None: + directives.append(f'max-age={get_expiration_seconds(expire_after)}') + if only_if_cached: + directives.append('only-if-cached') + if refresh: + directives.append('must-revalidate') + if force_refresh: + directives.append('no-cache') + + if directives: + headers['Cache-Control'] = ','.join(directives) + return headers diff --git a/requests_cache/policy/expiration.py b/requests_cache/policy/expiration.py index 7219718..ee030b0 100644 --- a/requests_cache/policy/expiration.py +++ b/requests_cache/policy/expiration.py @@ -1,23 +1,19 @@ -"""Utility functions used for converting expiration values""" +"""Utility functions for parsing and converting expiration values""" from datetime import datetime, timedelta, timezone from email.utils import parsedate_to_datetime from fnmatch import fnmatch from logging import getLogger from math import ceil -from typing import Dict, Optional, Union +from typing import Optional from .._utils import try_int - -__all__ = ['DO_NOT_CACHE', 'EXPIRE_IMMEDIATELY', 'NEVER_EXPIRE', 'get_expiration_datetime'] +from . import ExpirationPatterns, ExpirationTime # Special expiration values that may be set by either headers or keyword args DO_NOT_CACHE = 0x0D0E0200020704 # Per RFC 4824 EXPIRE_IMMEDIATELY = 0 NEVER_EXPIRE = -1 -ExpirationTime = Union[None, int, float, str, datetime, timedelta] -ExpirationPatterns = Dict[str, ExpirationTime] - logger = getLogger(__name__) @@ -31,9 +27,9 @@ def get_expiration_datetime(expire_after: ExpirationTime) -> Optional[datetime]: return datetime.utcnow() # Already a datetime or datetime str if isinstance(expire_after, str): - return parse_http_date(expire_after) + return _parse_http_date(expire_after) elif isinstance(expire_after, datetime): - return to_utc(expire_after) + return _to_utc(expire_after) # Otherwise, it must be a timedelta or time in seconds if not isinstance(expire_after, timedelta): @@ -57,23 +53,23 @@ def get_url_expiration( return None for pattern, expire_after in (urls_expire_after or {}).items(): - if url_match(url, pattern): + if _url_match(url, pattern): logger.debug(f'URL {url} matched pattern "{pattern}": {expire_after}') return expire_after return None -def parse_http_date(value: str) -> Optional[datetime]: +def _parse_http_date(value: str) -> Optional[datetime]: """Attempt to parse an HTTP (RFC 5322-compatible) timestamp""" try: expire_after = parsedate_to_datetime(value) - return to_utc(expire_after) + return _to_utc(expire_after) except (TypeError, ValueError): logger.debug(f'Failed to parse timestamp: {value}') return None -def to_utc(dt: datetime): +def _to_utc(dt: datetime): """All internal datetimes are UTC and timezone-naive. Convert any user/header-provided datetimes to the same format. """ @@ -83,7 +79,7 @@ def to_utc(dt: datetime): return dt -def url_match(url: str, pattern: str) -> bool: +def _url_match(url: str, pattern: str) -> bool: """Determine if a URL matches a pattern Args: diff --git a/requests_cache/policy/settings.py b/requests_cache/policy/settings.py index e23fd50..d0eadab 100644 --- a/requests_cache/policy/settings.py +++ b/requests_cache/policy/settings.py @@ -1,10 +1,9 @@ -from typing import Callable, Dict, Iterable, Union +from typing import Dict, Iterable, Union from attr import define, field -from requests import Response from .._utils import get_valid_kwargs -from .expiration import ExpirationTime +from . import ExpirationTime, FilterCallback, KeyCallback ALL_METHODS = ('GET', 'HEAD', 'OPTIONS', 'POST', 'PUT', 'PATCH', 'DELETE') DEFAULT_CACHE_NAME = 'http_cache' @@ -14,10 +13,6 @@ DEFAULT_STATUS_CODES = (200,) # Default params and/or headers that are excluded from cache keys and redacted from cached responses DEFAULT_IGNORED_PARAMS = ('Authorization', 'X-API-KEY', 'access_token', 'api_key') -# Signatures for user-provided callbacks -FilterCallback = Callable[[Response], bool] -KeyCallback = Callable[..., str] - @define class CacheSettings: diff --git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py index dec86ef..72420b7 100644 --- a/requests_cache/serializers/__init__.py +++ b/requests_cache/serializers/__init__.py @@ -38,7 +38,7 @@ SERIALIZERS = { } -def init_serializer(serializer=None, **kwargs): +def init_serializer(serializer=None): """Initialize a serializer from a name or instance""" serializer = serializer or 'pickle' if isinstance(serializer, str): diff --git a/requests_cache/session.py b/requests_cache/session.py index 8260c64..68d0c68 100644 --- a/requests_cache/session.py +++ b/requests_cache/session.py @@ -2,7 +2,7 @@ from contextlib import contextmanager, nullcontext from logging import getLogger from threading import RLock -from typing import TYPE_CHECKING, Dict, Iterable, MutableMapping, Optional, Union +from typing import TYPE_CHECKING, Iterable, MutableMapping, Optional, Union from requests import PreparedRequest from requests import Session as OriginalSession @@ -19,6 +19,7 @@ from .policy import ( DEFAULT_STATUS_CODES, CacheActions, CacheSettings, + ExpirationPatterns, ExpirationTime, FilterCallback, KeyCallback, @@ -27,13 +28,13 @@ from .policy import ( from .serializers import SerializerPipeline __all__ = ['CachedSession', 'CacheMixin'] - -logger = getLogger(__name__) if TYPE_CHECKING: MIXIN_BASE = OriginalSession else: MIXIN_BASE = object +logger = getLogger(__name__) + class CacheMixin(MIXIN_BASE): """Mixin class that extends :py:class:`requests.Session` with caching features. @@ -46,7 +47,7 @@ class CacheMixin(MIXIN_BASE): backend: BackendSpecifier = None, serializer: Union[str, SerializerPipeline] = None, expire_after: ExpirationTime = -1, - urls_expire_after: Dict[str, ExpirationTime] = None, + urls_expire_after: ExpirationPatterns = None, cache_control: bool = False, allowable_codes: Iterable[int] = DEFAULT_STATUS_CODES, allowable_methods: Iterable[str] = DEFAULT_METHODS, |