summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-11-11 12:22:43 -0600
committerJordan Cook <jordan.cook@pioneer.com>2021-11-13 12:46:02 -0600
commit8a6f8691bf3cadd98bf60365880ba04c7a708765 (patch)
tree2b9195f35f48617c6268643d2c530ce975c3ac77
parentda847a2fc0dbf40fbb8596a74450f84b3c2db114 (diff)
downloadrequests-cache-8a6f8691bf3cadd98bf60365880ba04c7a708765.tar.gz
Add support for BaseCache keyword arguments passed along with a backend instance
-rw-r--r--HISTORY.md1
-rw-r--r--requests_cache/backends/__init__.py18
-rw-r--r--requests_cache/backends/base.py20
-rw-r--r--requests_cache/session.py1
-rw-r--r--tests/unit/test_session.py19
5 files changed, 49 insertions, 10 deletions
diff --git a/HISTORY.md b/HISTORY.md
index b00c478..afed5f5 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -22,6 +22,7 @@
**Bugfixes:**
* Handle some additional corner cases when normalizing request data
+* Add support for `BaseCache` keyword arguments passed along with a backend instance
* Fix issue with cache headers not being used correctly if `cache_control=True` is used with an `expire_after` value
* Fix license metadata as shown on PyPI
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py
index bc921f7..0c525fe 100644
--- a/requests_cache/backends/__init__.py
+++ b/requests_cache/backends/__init__.py
@@ -3,7 +3,7 @@
from logging import getLogger
from typing import Callable, Dict, Iterable, Optional, Type, Union
-from .._utils import get_placeholder_class
+from .._utils import get_placeholder_class, get_valid_kwargs
from .base import BaseCache, BaseStorage, DictStorage
# Backend-specific keyword arguments equivalent to 'cache_name'
@@ -83,14 +83,26 @@ def init_backend(cache_name: str, backend: Optional[BackendSpecifier], **kwargs)
# Determine backend class
if isinstance(backend, BaseCache):
- return backend
+ return _set_backend_kwargs(cache_name, backend, **kwargs)
elif isinstance(backend, type):
return backend(cache_name, **kwargs)
elif not backend:
- backend = 'sqlite' if BACKEND_CLASSES['sqlite'] else 'memory'
+ sqlite_supported = issubclass(BACKEND_CLASSES['sqlite'], BaseCache)
+ backend = 'sqlite' if sqlite_supported else 'memory'
backend = str(backend).lower()
if backend not in BACKEND_CLASSES:
raise ValueError(f'Invalid backend: {backend}. Choose from: {BACKEND_CLASSES.keys()}')
return BACKEND_CLASSES[backend](cache_name, **kwargs)
+
+
+def _set_backend_kwargs(cache_name, backend, **kwargs):
+ """Set any backend arguments if they are passed along with a backend instance"""
+ backend_kwargs = get_valid_kwargs(BaseCache.__init__, kwargs)
+ backend_kwargs.setdefault('match_headers', kwargs.pop('include_get_headers', False))
+ for k, v in backend_kwargs.items():
+ setattr(backend, k, v)
+ if cache_name:
+ backend.cache_name = cache_name
+ return backend
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 7d46a85..ddd47e7 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -44,7 +44,7 @@ class BaseCache:
def __init__(
self,
- *args,
+ cache_name: str = 'http_cache',
match_headers: Union[Iterable[str], bool] = False,
ignored_parameters: Iterable[str] = None,
key_fn: KEY_FN = None,
@@ -52,9 +52,9 @@ class BaseCache:
):
self.responses: BaseStorage = DictStorage()
self.redirects: BaseStorage = DictStorage()
+ self.cache_name = cache_name
self.ignored_parameters = ignored_parameters
self.key_fn = key_fn or create_key
- self.name: str = kwargs.get('cache_name', '')
self.match_headers = match_headers or kwargs.pop('include_get_headers', False)
@property
@@ -235,7 +235,7 @@ class BaseCache:
return f'Total rows: {len(self.responses)} responses, {len(self.redirects)} redirects'
def __repr__(self):
- return f'<{self.__class__.__name__}(name={self.name})>'
+ return f'<{self.__class__.__name__}(name={self.cache_name})>'
class BaseStorage(MutableMapping, ABC):
@@ -261,9 +261,17 @@ class BaseStorage(MutableMapping, ABC):
serializer=None,
**kwargs,
):
- self.serializer = init_serializer(serializer, **kwargs)
+ self._serializer = init_serializer(serializer, **kwargs)
logger.debug(f'Initializing {type(self).__name__} with serializer: {self.serializer}')
+ @property
+ def serializer(self):
+ return self._serializer
+
+ @serializer.setter
+ def serializer(self, value):
+ self._serializer = init_serializer(value)
+
def bulk_delete(self, keys: Iterable[str]):
"""Delete multiple keys from the cache, without raising errors for missing keys. This is a
naive implementation that subclasses should override with a more efficient backend-specific
@@ -289,6 +297,10 @@ class DictStorage(UserDict, BaseStorage):
"""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._serializer = None
+
def __getitem__(self, key):
"""An additional step is needed here for response data. Since the original response object
is still in memory, its content has already been read and needs to be reset.
diff --git a/requests_cache/session.py b/requests_cache/session.py
index 3b26b4a..9b2531d 100644
--- a/requests_cache/session.py
+++ b/requests_cache/session.py
@@ -66,7 +66,6 @@ class CacheMixin(MIXIN_BASE):
self.filter_fn = filter_fn or (lambda r: True)
self.stale_if_error = stale_if_error or kwargs.pop('old_data_on_error', False)
- self.cache.name = cache_name # Set to handle backend=<instance>
self._disabled = False
self._lock = RLock()
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index d3509de..d225e57 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -52,9 +52,24 @@ def test_init_backend_instance():
assert session.cache is backend
+def test_init_backend_instance__kwargs():
+ backend = MyCache()
+ session = CachedSession(
+ 'test_cache',
+ backend=backend,
+ ignored_parameters=['foo'],
+ include_get_headers=True,
+ )
+
+ assert session.cache.cache_name == 'test_cache'
+ assert session.cache.ignored_parameters == ['foo']
+ assert session.cache.match_headers is True
+
+
def test_init_backend_class():
- session = CachedSession(backend=MyCache)
+ session = CachedSession('test_cache', backend=MyCache)
assert isinstance(session.cache, MyCache)
+ assert session.cache.cache_name == 'test_cache'
@pytest.mark.parametrize('method', ALL_METHODS)
@@ -138,7 +153,7 @@ def test_repr(mock_session):
mock_session.cache.redirects['key'] = 'value'
mock_session.cache.redirects['key_2'] = 'value'
- assert mock_session.cache.name in repr(mock_session) and '10.5' in repr(mock_session)
+ assert mock_session.cache.cache_name in repr(mock_session) and '10.5' in repr(mock_session)
assert '2 redirects' in str(mock_session.cache) and '1 responses' in str(mock_session.cache)