diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2022-06-11 11:12:43 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2022-06-11 19:33:54 -0500 |
commit | d04cc094efb44586dd996e6a0554dec99f7d40c6 (patch) | |
tree | ee526f73a5096a4ec1c648c3207e4155efe113db /requests_cache | |
parent | 18d9c24564b40744847fbe49d58ae980d612e352 (diff) | |
download | requests-cache-d04cc094efb44586dd996e6a0554dec99f7d40c6.tar.gz |
Consolidate BaseCache convenience methods into contains(), filter(), and delete()
Diffstat (limited to 'requests_cache')
-rw-r--r-- | requests_cache/backends/base.py | 320 | ||||
-rw-r--r-- | requests_cache/backends/filesystem.py | 4 | ||||
-rw-r--r-- | requests_cache/backends/gridfs.py | 4 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 38 | ||||
-rw-r--r-- | requests_cache/cache_keys.py | 13 | ||||
-rwxr-xr-x | requests_cache/models/response.py | 2 | ||||
-rw-r--r-- | requests_cache/patcher.py | 2 | ||||
-rw-r--r-- | requests_cache/serializers/pipeline.py | 13 | ||||
-rw-r--r-- | requests_cache/session.py | 7 |
9 files changed, 214 insertions, 189 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index e586138..e307cb4 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -8,21 +8,18 @@ from __future__ import annotations from abc import ABC from collections import UserDict -from collections.abc import MutableMapping from datetime import datetime from logging import getLogger from pickle import PickleError -from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, Iterator, List, MutableMapping, Optional, TypeVar +from warnings import warn -from requests import PreparedRequest, Response - -from requests_cache.serializers.cattrs import CattrStage -from requests_cache.serializers.pipeline import SerializerPipeline +from requests import Request, Response from ..cache_keys import create_key, redact_response -from ..models import CachedResponse +from ..models import AnyRequest, CachedResponse from ..policy import DEFAULT_CACHE_NAME, CacheSettings, ExpirationTime -from ..serializers import SERIALIZERS, SerializerType, pickle_serializer +from ..serializers import SERIALIZERS, SerializerPipeline, SerializerType, pickle_serializer # Specific exceptions that may be raised during deserialization DESERIALIZE_ERRORS = (AttributeError, ImportError, PickleError, TypeError, ValueError) @@ -38,10 +35,12 @@ class BaseCache: * Saving and retrieving responses * Managing redirect history * Convenience methods for general cache info + * Dict-like wrapper methods around the underlying storage - Lower-level storage operations are handled by :py:class:`.BaseStorage`. + Notes: - To extend this with your own custom backend, see :ref:`custom-backends`. + * Lower-level storage operations are handled by :py:class:`.BaseStorage`. + * To extend this with your own custom backend, see :ref:`custom-backends`. 
Args: cache_name: Cache prefix or namespace, depending on backend @@ -51,15 +50,12 @@ class BaseCache: def __init__(self, cache_name: str = DEFAULT_CACHE_NAME, **kwargs): self.cache_name = cache_name - self.responses: BaseStorage = DictStorage() - self.redirects: BaseStorage = DictStorage() + self.responses: BaseStorage[str, CachedResponse] = DictStorage() + self.redirects: BaseStorage[str, str] = DictStorage() self._settings = CacheSettings() # Init and public access is done in CachedSession - @property - def urls(self) -> Iterator[str]: - """Get all URLs currently in the cache (excluding redirects)""" - for response in self.values(): - yield response.url + # Main cache operations + # --------------------- def get_response(self, key: str, default=None) -> Optional[CachedResponse]: """Retrieve a response from the cache, if it exists @@ -77,7 +73,7 @@ class BaseCache: except KeyError: return default except DESERIALIZE_ERRORS as e: - logger.error(f'Unable to deserialize response with key {key}: {str(e)}') + logger.error(f'Unable to deserialize response {key}: {str(e)}') logger.debug(e, exc_info=True) return default @@ -108,7 +104,7 @@ class BaseCache: self.responses.close() self.redirects.close() - def create_key(self, request: PreparedRequest = None, **kwargs) -> str: + def create_key(self, request: AnyRequest = None, **kwargs) -> str: """Create a normalized cache key from a request object""" key_fn = self._settings.key_fn or create_key return key_fn( @@ -119,122 +115,92 @@ class BaseCache: **kwargs, ) - def delete(self, key: str): - """Delete a response or redirect from the cache, as well any associated redirect history""" - # If it's a response key, first delete any associated redirect history - try: - for r in self.responses[key].history: - del self.redirects[create_key(r.request, self._settings.ignored_parameters)] - except (KeyError, *DESERIALIZE_ERRORS): - pass - # Then delete the response itself, or just the redirect if it's a redirect key - for cache in 
[self.responses, self.redirects]: - try: - del cache[key] - except KeyError: - pass - - def delete_url(self, url: str, method: str = 'GET', **kwargs): - """Delete a cached response for the specified request""" - key = self.create_key(method=method, url=url, **kwargs) - self.delete(key) - - def delete_urls(self, urls: Iterable[str], method: str = 'GET', **kwargs): - """Delete all cached responses for the specified requests""" - keys = [self.create_key(method=method, url=url, **kwargs) for url in urls] - self.bulk_delete(keys) - - def bulk_delete(self, keys: Iterable[str]): - """Remove multiple responses and their associated redirects from the cache""" - self.responses.bulk_delete(keys) - self.remove_invalid_redirects() - - def has_key(self, key: str) -> bool: - """Returns ``True`` if ``key`` is in the cache""" - return key in self.responses or key in self.redirects - - def has_url(self, url: str, method: str = 'GET', **kwargs) -> bool: - """Returns ``True`` if the specified request is cached""" - key = self.create_key(method=method, url=url, **kwargs) - return self.has_key(key) # noqa: W601 + # Convenience methods + # -------------------- - def keys(self, include_expired: bool = True) -> Iterator[str]: - """Get all cache keys for redirects and responses combined""" - yield from self.redirects.keys() - for key, _ in self.items(include_expired=include_expired): - yield key - - def values(self, include_expired: bool = True) -> Iterator[CachedResponse]: - """Get all response objects from the cache""" - for _, response in self.items(include_expired=include_expired): - if TYPE_CHECKING: - assert response is not None - yield response - - def items( - self, include_expired: bool = True, include_invalid: bool = False - ) -> Iterator[Tuple[str, Optional[CachedResponse]]]: - """Get all keys and responses from the cache, and optionally skip any expired or invalid - ones + def contains( + self, + key: str = None, + request: AnyRequest = None, + ): + """Check if the specified 
request is cached Args: - include_expired: Include expired responses in the results - include_invalid: Include invalid responses in the results + key: Check for a specific cache key + request: Check for a matching request, according to current request matching settings """ - for key in self.responses.keys(): - try: - response = self.responses[key] - response.cache_key = key - if include_expired or not response.is_expired: - yield key, response - except DESERIALIZE_ERRORS: - if include_invalid: - yield key, None - - def remove( - self, expired: bool = False, invalid: bool = True, older_than: ExpirationTime = None + if request and not key: + key = self.create_key(request) + return key in self.responses or key in self.redirects + + def delete( + self, + *keys: str, + expired: bool = False, + invalid: bool = False, + older_than: ExpirationTime = None, + requests: Iterable[AnyRequest] = None, ): - """Remove responses from the cache according to the specified condition(s). + """Remove responses from the cache according to one or more conditions. 
Args: + keys: Remove responses with these cache keys expired: Remove all expired responses - invalid: Remove all invalid responses (ones that can't be deserialized with current - settings) - older_than: Remove all cache items older than this value, **relative to the cache - creation time** + invalid: Remove all invalid responses (that can't be deserialized with current settings) + older_than: Remove responses older than this value, relative to ``response.created_at`` + requests: Remove matching responses, according to current request matching settings """ - if expired: - logger.info('Removing expired responses') - if older_than: - logger.info(f'Removing responses older than {older_than}') - keys_to_delete = [] - - for key, response in self.items(include_invalid=invalid): - if ( - response is None # If the response was invalid - or (expired and response.is_expired) - or (older_than is not None and response.is_older_than(older_than)) - ): - keys_to_delete.append(key) + delete_keys: List[str] = list(keys) if keys else [] + if requests: + delete_keys += [self.create_key(request) for request in requests] - # Delay deletes until the end, to use more efficient bulk_delete - logger.debug(f'Deleting {len(keys_to_delete)} expired responses') - self.bulk_delete(keys_to_delete) + for response in self.filter( + valid=False, expired=expired, invalid=invalid, older_than=older_than + ): + delete_keys.append(response.cache_key) - def remove_expired_responses(self, expire_after: ExpirationTime = None): - """Remove expired and invalid responses from the cache + logger.debug(f'Deleting {len(delete_keys)} responses') + self.responses.bulk_delete(delete_keys) + self._prune_redirects() - **Deprecated:** Please use :py:meth:`.remove` with ``expire=True`` instead. 
- """ - self.remove(expired=True, invalid=True) - if expire_after: - self.reset_expiration(expire_after) - - def remove_invalid_redirects(self): + def _prune_redirects(self): """Remove any redirects that no longer point to an existing response""" invalid_redirects = [k for k, v in self.redirects.items() if v not in self.responses] self.redirects.bulk_delete(invalid_redirects) + def filter( + self, + valid: bool = True, + expired: bool = True, + invalid: bool = False, + older_than: ExpirationTime = None, + ) -> Iterator[CachedResponse]: + """Get responses from the cache, with optional filters + + Args: + valid: Include valid and unexpired responses; set to ``False`` to get **only** + expired/invalid/old responses + expired: Include expired responses + invalid: Include invalid responses (as an empty ``CachedResponse``) + older_than: Get responses older than this value, relative to ``response.created_at`` + """ + if not any([valid, expired, invalid, older_than]): + return + for key in self.responses.keys(): + response = self.get_response(key) + + # Use an empty response as a placeholder for an invalid response, if specified + if invalid and response is None: + response = CachedResponse(status_code=504) + response.cache_key = key + yield response + elif response is not None and ( + (valid and not response.is_expired) + or (expired and response.is_expired) + or (older_than and response.is_older_than(older_than)) + ): + yield response + def reset_expiration(self, expire_after: ExpirationTime = None): """Set a new expiration value to set on existing cache items @@ -242,35 +208,85 @@ class BaseCache: expire_after: New expiration value, **relative to the current time** """ logger.info(f'Resetting expiration with: {expire_after}') - for key, response in self.items(): - if TYPE_CHECKING: - assert response is not None + for response in self.filter(): response.reset_expiration(expire_after) - self.responses[key] = response + self.responses[response.cache_key] = response - def 
response_count(self, include_expired: bool = True) -> int: - """Get the number of responses in the cache, excluding invalid responses. - Can also optionally exclude expired responses. - """ - return len(list(self.values(include_expired=include_expired))) - - def update(self, other: 'BaseCache'): + def update(self, other: 'BaseCache'): # type: ignore """Update this cache with the contents of another cache""" logger.debug(f'Copying {len(other.responses)} responses from {repr(other)} to {repr(self)}') self.responses.update(other.responses) self.redirects.update(other.redirects) + def urls(self, **kwargs) -> List[str]: + """Get all unique cached URLs. Optionally takes keyword arguments for :py:meth:`.filter`.""" + return sorted({response.url for response in self.filter(**kwargs)}) + def __str__(self): - """Show a count of total **rows** currently stored in the backend. For performance reasons, - this does not check for invalid or expired responses. - """ return f'<{self.__class__.__name__}(name={self.cache_name})>' def __repr__(self): return str(self) + # Deprecated methods + # -------------------- + + def delete_url(self, url: str, method: str = 'GET', **kwargs): + warn( + 'BaseCache.delete_url() is deprecated; please use .delete() instead', DeprecationWarning + ) + self.delete(requests=[Request(method, url, **kwargs)]) + + def delete_urls(self, urls: Iterable[str], method: str = 'GET', **kwargs): + warn( + 'BaseCache.delete_urls() is deprecated; please use .delete() instead', + DeprecationWarning, + ) + self.delete(requests=[Request(method, url, **kwargs) for url in urls]) + + def has_url(self, url: str, method: str = 'GET', **kwargs) -> bool: + warn( + 'BaseCache.has_url() is deprecated; please use .contains() instead', DeprecationWarning + ) + return self.contains(request=Request(method, url, **kwargs)) + + def keys(self, check_expiry: bool = False) -> Iterator[str]: + warn( + 'BaseCache.keys() is deprecated; please use .filter() or ' + 'BaseCache.responses.keys() 
instead', + DeprecationWarning, + ) + yield from self.redirects.keys() + for response in self.filter(expired=not check_expiry): + yield response.cache_key + + def response_count(self, check_expiry: bool = False) -> int: + warn( + 'BaseCache.response_count() is deprecated; please use .filter() or ' + 'len(BaseCache.responses) instead', + DeprecationWarning, + ) + return len(list(self.filter(expired=not check_expiry))) -class BaseStorage(MutableMapping, ABC): + def remove_expired_responses(self, expire_after: ExpirationTime = None): + warn( + 'BaseCache.remove_expired_responses() is deprecated; please use .delete() instead', + DeprecationWarning, + ) + self.delete(expired=True, invalid=True) + if expire_after: + self.reset_expiration(expire_after) + + def values(self, check_expiry: bool = False) -> Iterator[CachedResponse]: + warn('BaseCache.values() is deprecated; please use .filter() instead', DeprecationWarning) + yield from self.filter(expired=not check_expiry) + + +KT = TypeVar('KT') +VT = TypeVar('VT') + + +class BaseStorage(MutableMapping[KT, VT], ABC): """Base class for client-agnostic storage implementations. 
Notes: * This provides a common dictionary-like interface for the underlying storage operations @@ -301,20 +317,28 @@ class BaseStorage(MutableMapping, ABC): decode_content: bool = False, **kwargs, ): - # Set a default serializer, unless explicitly disabled - self.serializer = None if no_serializer else (serializer or self.default_serializer) + self.serializer: Optional[SerializerPipeline] = None + if not no_serializer: + self._init_serializer(serializer or self.default_serializer, decode_content) + def _init_serializer(self, serializer: SerializerType, decode_content: bool): # Look up a serializer by name, if needed - if isinstance(self.serializer, str): - self.serializer = SERIALIZERS[self.serializer] - if isinstance(self.serializer, (SerializerPipeline, CattrStage)): - self.serializer.decode_content = decode_content + if isinstance(serializer, str): + serializer = SERIALIZERS[serializer] + + # Wrap in a SerializerPipeline, if needed + if not isinstance(serializer, SerializerPipeline): + serializer = SerializerPipeline([serializer], name=str(serializer)) + serializer.decode_content = decode_content + + self.serializer = serializer logger.debug(f'Initialized {type(self).__name__} with serializer: {self.serializer}') - def bulk_delete(self, keys: Iterable[str]): - """Delete multiple keys from the cache, without raising errors for missing keys. This is a - naive, generic implementation that subclasses should override with a more efficient - backend-specific implementation, if possible. + def bulk_delete(self, keys: Iterable[KT]): + """Delete multiple keys from the cache, without raising errors for missing keys. + + This is a naive, generic implementation that subclasses should override with a more + efficient backend-specific implementation, if possible. 
""" for k in keys: try: @@ -325,12 +349,16 @@ class BaseStorage(MutableMapping, ABC): def close(self): """Close any open backend connections""" - def serialize(self, value): + def serialize(self, value: VT): """Serialize value, if a serializer is available""" + if TYPE_CHECKING: + assert hasattr(self.serializer, 'dumps') return self.serializer.dumps(value) if self.serializer else value - def deserialize(self, value): + def deserialize(self, value: VT): """Deserialize value, if a serializer is available""" + if TYPE_CHECKING: + assert hasattr(self.serializer, 'loads') return self.serializer.loads(value) if self.serializer else value def __str__(self): diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py index bbb8f2d..a349dae 100644 --- a/requests_cache/backends/filesystem.py +++ b/requests_cache/backends/filesystem.py @@ -60,9 +60,9 @@ class FileCache(BaseCache): self.responses.clear() self.redirects.init_db() - def remove(self, *args, **kwargs): + def delete(self, *args, **kwargs): with self.responses._lock: - return super().remove(*args, **kwargs) + return super().delete(*args, **kwargs) class FileDict(BaseStorage): diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index be6814d..6380578 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py @@ -38,9 +38,9 @@ class GridFSCache(BaseCache): **kwargs ) - def remove(self, *args, **kwargs): + def delete(self, *args, **kwargs): with self.responses._lock: - return super().remove(*args, **kwargs) + return super().delete(*args, **kwargs) class GridFSDict(BaseStorage): diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index ff55a4b..0fa5d99 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -18,7 +18,6 @@ from typing import Collection, Iterator, List, Tuple, Type, Union from platformdirs import user_cache_dir from .._utils import chunkify, 
get_valid_kwargs -from ..policy.expiration import ExpirationTime from . import BaseCache, BaseStorage MEMORY_URI = 'file::memory:?cache=shared' @@ -53,15 +52,6 @@ class SQLiteCache(BaseCache): def db_path(self) -> AnyPath: return self.responses.db_path - def bulk_delete(self, keys): - """Remove multiple responses and their associated redirects from the cache, with additional cleanup""" - self.responses.bulk_delete(keys=keys) - self.responses.vacuum() - - self.redirects.bulk_delete(keys=keys) - self.redirects.bulk_delete(values=keys) - self.redirects.vacuum() - def clear(self): """Clear the cache. If this fails due to a corrupted cache or other I/O error, this will attempt to delete the cache file and re-initialize. @@ -75,17 +65,27 @@ class SQLiteCache(BaseCache): self.responses.init_db() self.redirects.init_db() - def remove( - self, expired: bool = False, invalid: bool = False, older_than: ExpirationTime = None + def delete( + self, + *keys: str, + expired: bool = False, + **kwargs, ): - if invalid or older_than is not None: + """More efficient implementation of :py:meth:`BaseCache.delete`""" + if keys: + self.responses.bulk_delete(keys) + if expired: + self.responses.delete_expired() + if kwargs: with self.responses._lock, self.redirects._lock: - return super().remove(expired=expired, invalid=invalid, older_than=older_than) + return super().delete(**kwargs) else: - self.responses.clear_expired() - self.remove_invalid_redirects() + self._prune_redirects() + self.responses.vacuum() + self.redirects.vacuum() - def remove_invalid_redirects(self): + def _prune_redirects(self): + """More efficient implementation of :py:meth:`BaseCache._prune_redirects`""" with self.redirects.connection(commit=True) as conn: t1 = self.redirects.table_name t2 = self.responses.table_name @@ -260,8 +260,8 @@ class SQLiteDict(BaseStorage): self.init_db() self.vacuum() - def clear_expired(self): - """Remove expired items from the cache""" + def delete_expired(self): + """Delete 
expired items from the cache""" with self._lock, self.connection(commit=True) as con: con.execute(f"DELETE FROM {self.table_name} WHERE expires <= ?", (round(time()),)) self.vacuum() diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py index d551979..6038716 100644 --- a/requests_cache/cache_keys.py +++ b/requests_cache/cache_keys.py @@ -16,7 +16,7 @@ from requests import Request, Session from requests.models import CaseInsensitiveDict from url_normalize import url_normalize -from ._utils import decode, encode, get_valid_kwargs +from ._utils import decode, encode __all__ = ['create_key', 'normalize_request'] if TYPE_CHECKING: @@ -32,25 +32,20 @@ logger = getLogger(__name__) def create_key( - request: AnyRequest = None, + request: AnyRequest, ignored_parameters: ParamList = None, match_headers: Union[ParamList, bool] = False, serializer: Any = None, **request_kwargs, ) -> str: - """Create a normalized cache key from either a request object or :py:class:`~requests.Request` - arguments + """Create a normalized cache key based on a request object Args: request: Request object to generate a cache key from ignored_parameters: Request paramters, headers, and/or JSON body params to exclude match_headers: Match only the specified headers, or ``True`` to match all headers - request_kwargs: Request arguments to generate a cache key from + request_kwargs: Additional keyword arguments for :py:func:`~requests.request` """ - # Convert raw request arguments into a request object, if needed - if not request: - request = Request(**get_valid_kwargs(Request.__init__, request_kwargs)) - # Normalize and gather all relevant request info to match against request = normalize_request(request, ignored_parameters) key_parts = [ diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py index 4606f23..68df763 100755 --- a/requests_cache/models/response.py +++ b/requests_cache/models/response.py @@ -31,7 +31,7 @@ class BaseResponse(Response): 
created_at: datetime = field(factory=datetime.utcnow) expires: Optional[datetime] = field(default=None) - cache_key: Optional[str] = None # Not serialized; set by BaseCache.get_response() + cache_key: str = '' # Not serialized; set by BaseCache.get_response() revalidated: bool = False # Not serialized; set by CacheActions.update_revalidated_response() @property diff --git a/requests_cache/patcher.py b/requests_cache/patcher.py index b11cbec..30ba908 100644 --- a/requests_cache/patcher.py +++ b/requests_cache/patcher.py @@ -110,7 +110,7 @@ def remove_expired_responses(): """Remove expired and invalid responses from the cache""" session = requests.Session() if isinstance(session, CachedSession): - session.cache.remove(expired=True) + session.cache.delete(expired=True) def _patch_session_factory(session_factory: Type[OriginalSession] = CachedSession): diff --git a/requests_cache/serializers/pipeline.py b/requests_cache/serializers/pipeline.py index 25da5b0..08e1cac 100644 --- a/requests_cache/serializers/pipeline.py +++ b/requests_cache/serializers/pipeline.py @@ -60,15 +60,20 @@ class SerializerPipeline: value = step(value) return value + # TODO: I don't love this. Could BaseStorage init be refactored to not need this getter/setter? 
@property def decode_content(self) -> bool: - return getattr(self.stages[0], 'decode_content', False) if self.stages else False + for stage in self.stages: + if hasattr(stage, 'decode_content'): + return stage.decode_content + return False @decode_content.setter def decode_content(self, value: bool): - """Set decode_content, if the pipeline is based on CattrStage""" - if self.stages and hasattr(self.stages[0], 'decode_content'): - self.stages[0].decode_content = value + """Set decode_content, if the pipeline contains a CattrStage or compatible object""" + for stage in self.stages: + if hasattr(stage, 'decode_content'): + stage.decode_content = value def __str__(self) -> str: return f'SerializerPipeline(name={self.name}, n_stages={len(self.dump_stages)})' diff --git a/requests_cache/session.py b/requests_cache/session.py index b7a8c78..b65da1e 100644 --- a/requests_cache/session.py +++ b/requests_cache/session.py @@ -308,11 +308,8 @@ class CacheMixin(MIXIN_BASE): self.cache.close() def remove_expired_responses(self): - """Remove expired and invalid responses from the cache - - **Deprecated:** Use ``session.cache.remove(expired=True)`` instead. - """ - self.cache.remove(expired=True, invalid=True) + """**Deprecated:** Use ``session.cache.delete(expired=True)`` instead""" + self.cache.delete(expired=True, invalid=True) def __repr__(self): return f'<CachedSession(cache={repr(self.cache)}, settings={self.settings})>' |