summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <JWCook@users.noreply.github.com>2021-04-02 10:31:45 -0500
committerJordan Cook <jordan.cook@pioneer.com>2021-04-02 10:37:55 -0500
commitddba563e19f45015b251eafe776741952451433c (patch)
treeedc2b9896d08129a5fac089654c5f9473d4056ca
parentad0aaaa66569799e127a3cd143f3a9434ad6449e (diff)
parentaa1447ee8cf8db112fce0f7f4644f4bdc84197e8 (diff)
downloadrequests-cache-ddba563e19f45015b251eafe776741952451433c.tar.gz
Merge pull request #215 from JWCook/backend-init
Improvements to backend initialization
-rw-r--r--requests_cache/backends/__init__.py124
-rw-r--r--requests_cache/backends/sqlite.py25
-rw-r--r--requests_cache/core.py20
-rwxr-xr-x[-rw-r--r--]runtests.sh0
-rw-r--r--tests/conftest.py2
-rw-r--r--tests/unit/test_cache.py26
-rw-r--r--tests/unit/test_monkey_patch.py13
7 files changed, 111 insertions, 99 deletions
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py
index f7bde00..9f9af90 100644
--- a/requests_cache/backends/__init__.py
+++ b/requests_cache/backends/__init__.py
@@ -1,20 +1,21 @@
"""Classes and functions for cache persistence"""
# flake8: noqa: F401
-from ..cache_keys import normalize_dict
-from .base import BaseCache
+from logging import getLogger
+from typing import Type, Union
-# All backend-specific keyword arguments combined
-BACKEND_KWARGS = [
+from .base import BaseCache, BaseStorage
+
+# Backend-specific keyword arguments equivalent to 'cache_name'
+CACHE_NAME_KWARGS = ['db_path', 'db_name', 'namespace', 'table_name']
+
+# All backend-specific keyword arguments
+BACKEND_KWARGS = CACHE_NAME_KWARGS + [
'connection',
- 'db_name',
- 'endpont_url',
- 'extension',
+ 'endpoint_url',
'fast_save',
'ignored_parameters',
'include_get_headers',
- 'location',
'name',
- 'namespace',
'read_capacity_units',
'region_name',
'salt',
@@ -23,71 +24,76 @@ BACKEND_KWARGS = [
'write_capacity_units',
]
-registry = {
- 'memory': BaseCache,
-}
+BackendSpecifier = Union[str, BaseCache, Type[BaseCache], None]
+logger = getLogger(__name__)
-_backend_dependencies = {
- 'sqlite': 'sqlite3',
- 'mongo': 'pymongo',
- 'redis': 'redis',
- 'dynamodb': 'dynamodb',
-}
-try:
- # Heroku doesn't allow the SQLite3 module to be installed
- from .sqlite import DbCache
+def get_placeholder_backend(original_exception: Exception = None) -> Type[BaseCache]:
+ """Create a placeholder type for a backend class that does not have dependencies installed.
+ This allows delaying ImportErrors until init time, rather than at import time.
+ """
- registry['sqlite'] = DbCache
-except ImportError:
- DbCache = None
+ class PlaceholderBackend(BaseCache):
+ def __init__(*args, **kwargs):
+ msg = 'Dependencies are not installed for this backend'
+ logger.error(msg)
+ raise original_exception or ImportError(msg)
-try:
- from .mongo import MongoCache
-
- registry['mongo'] = registry['mongodb'] = MongoCache
-except ImportError:
- MongoCache = None
+ return PlaceholderBackend
+# Import all backends for which dependencies are installed
+try:
+ from .dynamodb import DynamoDbCache
+except ImportError as e:
+ DynamoDbCache = get_placeholder_backend(e) # type: ignore
try:
from .gridfs import GridFSCache
-
- registry['gridfs'] = GridFSCache
-except ImportError:
- GridFSCache = None
-
+except ImportError as e:
+ GridFSCache = get_placeholder_backend(e) # type: ignore
+try:
+ from .mongo import MongoCache
+except ImportError as e:
+ MongoCache = get_placeholder_backend(e) # type: ignore
try:
from .redis import RedisCache
+except ImportError as e:
+ RedisCache = get_placeholder_backend(e) # type: ignore
+try:
+ # Note: Heroku doesn't support SQLite due to ephemeral storage
+ from .sqlite import DbCache
+except ImportError as e:
+ DbCache = get_placeholder_backend(e) # type: ignore
- registry['redis'] = RedisCache
-except ImportError:
- RedisCache = None
-try:
- from .dynamodb import DynamoDbCache
+BACKEND_CLASSES = {
+ 'dynamodb': DynamoDbCache,
+ 'gridfs': GridFSCache,
+ 'memory': BaseCache,
+ 'mongo': MongoCache,
+ 'redis': RedisCache,
+ 'sqlite': DbCache,
+}
- registry['dynamodb'] = DynamoDbCache
-except ImportError:
- DynamoDbCache = None
+def init_backend(backend: BackendSpecifier, *args, **kwargs) -> BaseCache:
+ """Initialize a backend given a name, class, or instance"""
+ logger.info(f'Initializing backend: {backend}')
-def create_backend(backend_name, cache_name, kwargs):
- if isinstance(backend_name, BaseCache):
- return backend_name
+ # Omit 'cache_name' positional arg if an equivalent backend-specific kwarg is specified
+ if any([k in kwargs for k in CACHE_NAME_KWARGS]):
+ args = tuple()
- if backend_name is None:
- backend_name = _get_default_backend_name()
- try:
- return registry[backend_name](cache_name, **kwargs)
- except KeyError:
- if backend_name in _backend_dependencies:
- raise ImportError('You must install the python package: %s' % _backend_dependencies[backend_name])
- else:
- raise ValueError('Unsupported backend "%s" try one of: %s' % (backend_name, ', '.join(registry.keys())))
+ # Determine backend class
+ if isinstance(backend, BaseCache):
+ return backend
+ elif isinstance(backend, type):
+ return backend(*args, **kwargs)
+ elif not backend:
+ backend = 'sqlite' if BACKEND_CLASSES['sqlite'] else 'memory'
+ backend = str(backend).lower().replace('mongodb', 'mongo')
+ if backend not in BACKEND_CLASSES:
+ raise ValueError(f'Invalid backend: {backend}. Choose from: {BACKEND_CLASSES.keys()}')
-def _get_default_backend_name():
- if 'sqlite' in registry:
- return 'sqlite'
- return 'memory'
+ return BACKEND_CLASSES[backend](*args, **kwargs)
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 39b80ce..c2e21ac 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -2,7 +2,7 @@ import sqlite3
import threading
from contextlib import contextmanager
from logging import getLogger
-from os.path import expanduser
+from os.path import basename, expanduser
from .base import BaseCache, BaseStorage
@@ -15,17 +15,20 @@ class DbCache(BaseCache):
Reading is fast, saving is a bit slower. It can store big amount of data with low memory usage.
Args:
- location: database file path (without extension)
- extension: Database file extension
+ db_path: Database file path
fast_save: Speedup cache saving up to 50 times but with possibility of data loss.
See :ref:`backends.DbDict <backends_dbdict>` for more info
timeout: Timeout for acquiring a database lock
"""
- def __init__(self, location='http_cache', extension='.sqlite', fast_save=False, **kwargs):
+ def __init__(self, db_path: str = 'http_cache', fast_save: bool = False, **kwargs):
super().__init__(**kwargs)
kwargs.setdefault('suppress_warnings', True)
- db_path = expanduser(str(location) + extension)
+ # Allow paths with user directories (~/*), and add file extension if not specified
+ db_path = expanduser(str(db_path))
+ if '.' not in basename(db_path):
+ db_path += '.sqlite'
+
self.responses = DbPickleDict(db_path, table_name='responses', fast_save=fast_save, **kwargs)
self.redirects = DbDict(db_path, table_name='redirects', **kwargs)
@@ -49,17 +52,17 @@ class DbDict(BaseStorage):
All data will be stored in separate tables in the file ``test.sqlite``.
Args:
- filename: Filename for database (without extension)
+ db_path: Database file path
table_name: Table name
fast_save: Use `"PRAGMA synchronous = 0;" <http://www.sqlite.org/pragma.html#pragma_synchronous>`_
to speed up cache saving, but with the potential for data loss
timeout: Timeout for acquiring a database lock
"""
- def __init__(self, filename, table_name='http_cache', fast_save=False, timeout=5.0, **kwargs):
+ def __init__(self, db_path, table_name='http_cache', fast_save=False, timeout=5.0, **kwargs):
super().__init__(**kwargs)
+ self.db_path = db_path
self.fast_save = fast_save
- self.filename = filename
self.table_name = table_name
self.timeout = timeout
@@ -72,14 +75,14 @@ class DbDict(BaseStorage):
@contextmanager
def connection(self, commit_on_success=False):
- logger.debug(f'Opening connection to {self.filename}:{self.table_name}')
+ logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
with self._lock:
if self._bulk_commit:
if self._pending_connection is None:
- self._pending_connection = sqlite3.connect(self.filename, timeout=self.timeout)
+ self._pending_connection = sqlite3.connect(self.db_path, timeout=self.timeout)
con = self._pending_connection
else:
- con = sqlite3.connect(self.filename, timeout=self.timeout)
+ con = sqlite3.connect(self.db_path, timeout=self.timeout)
try:
if self.fast_save:
con.execute("PRAGMA synchronous = 0;")
diff --git a/requests_cache/core.py b/requests_cache/core.py
index 3a13b21..393c4f4 100644
--- a/requests_cache/core.py
+++ b/requests_cache/core.py
@@ -2,14 +2,14 @@
from contextlib import contextmanager
from fnmatch import fnmatch
from logging import getLogger
-from typing import Any, Callable, Dict, Iterable, Type
+from typing import Any, Callable, Dict, Iterable, Optional, Type
import requests
from requests import PreparedRequest
from requests import Session as OriginalSession
from requests.hooks import dispatch_hook
-from . import backends
+from .backends import BACKEND_KWARGS, BackendSpecifier, BaseCache, init_backend
from .cache_keys import normalize_dict
from .response import AnyResponse, ExpirationTime, set_response_defaults
@@ -24,8 +24,8 @@ class CacheMixin:
def __init__(
self,
- cache_name: str = 'cache',
- backend: str = None,
+ cache_name: str = 'http-cache',
+ backend: BackendSpecifier = None,
expire_after: ExpirationTime = -1,
urls_expire_after: Dict[str, ExpirationTime] = None,
allowable_codes: Iterable[int] = (200,),
@@ -34,7 +34,7 @@ class CacheMixin:
old_data_on_error: bool = False,
**kwargs,
):
- self.cache = backends.create_backend(backend, cache_name, kwargs)
+ self.cache = init_backend(backend, cache_name, **kwargs)
self.allowable_codes = allowable_codes
self.allowable_methods = allowable_methods
self.expire_after = expire_after
@@ -47,7 +47,7 @@ class CacheMixin:
self._disabled = False
# Remove any requests-cache-specific kwargs before passing along to superclass
- session_kwargs = {k: v for k, v in kwargs.items() if k not in backends.BACKEND_KWARGS}
+ session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS}
super().__init__(**session_kwargs)
def request(
@@ -82,7 +82,7 @@ class CacheMixin:
2. :py:meth:`.CachedSession.request`
3. :py:meth:`requests.Session.request`
4. :py:meth:`.CachedSession.send`
- 5. :py:meth:`.backends.BaseCache.get_response`
+ 5. :py:meth:`.BaseCache.get_response`
6. :py:meth:`requests.Session.send` (if not cached)
"""
with self.request_expire_after(expire_after):
@@ -225,8 +225,8 @@ class CachedSession(CacheMixin, OriginalSession):
Args:
cache_name: Cache prefix or namespace, depending on backend
- backend: Cache backend name; one of ``['sqlite', 'mongodb', 'gridfs', 'redis', 'dynamodb', 'memory']``.
- Default behavior is to use ``'sqlite'`` if available, otherwise fallback to ``'memory'``.
+ backend: Cache backend name, class, or instance; name may be one of
+ ``['sqlite', 'mongodb', 'gridfs', 'redis', 'dynamodb', 'memory']``.
expire_after: Time after which cached items will expire
urls_expire_after: Expiration times to apply for different URL patterns
allowable_codes: Only cache responses with one of these codes
@@ -352,7 +352,7 @@ def enabled(*args, **kwargs):
uninstall_cache()
-def get_cache() -> backends.BaseCache:
+def get_cache() -> Optional[BaseCache]:
"""Returns internal cache object from globally installed ``CachedSession``"""
return requests.Session().cache if is_installed() else None
diff --git a/runtests.sh b/runtests.sh
index 7614288..7614288 100644..100755
--- a/runtests.sh
+++ b/runtests.sh
diff --git a/tests/conftest.py b/tests/conftest.py
index 2bc21d8..466dd41 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -44,7 +44,7 @@ def mock_session() -> CachedSession:
"""
with NamedTemporaryFile(suffix='.db') as temp:
session = CachedSession(
- cache_name=temp.name,
+ db_path=temp.name,
backend='sqlite',
allowable_methods=ALL_METHODS,
suppress_warnings=True,
diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py
index e38f7e5..28f038e 100644
--- a/tests/unit/test_cache.py
+++ b/tests/unit/test_cache.py
@@ -1,5 +1,6 @@
"""CachedSession + BaseCache tests that use mocked responses only"""
# flake8: noqa: F841
+# TODO: This could be split up into some smaller test modules
import json
import pickle
import pytest
@@ -16,6 +17,7 @@ from itsdangerous.serializer import Serializer
from requests.structures import CaseInsensitiveDict
from requests_cache import ALL_METHODS, CachedSession
+from requests_cache.backends import BACKEND_CLASSES, BaseCache, get_placeholder_backend
from requests_cache.backends.sqlite import DbDict, DbPickleDict
from tests.conftest import (
MOCKED_URL,
@@ -26,17 +28,31 @@ from tests.conftest import (
)
-def test_unregistered_backend():
+def test_init_unregistered_backend():
with pytest.raises(ValueError):
CachedSession(backend='nonexistent')
-@patch('requests_cache.backends.registry')
-def test_missing_backend_dependency(mocked_registry):
+@patch.dict(BACKEND_CLASSES, {'mongo': get_placeholder_backend()})
+def test_init_missing_backend_dependency():
"""Test that the correct error is thrown when a user does not have a dependency installed"""
- mocked_registry.__getitem__.side_effect = KeyError
with pytest.raises(ImportError):
- CachedSession(backend='redis')
+ CachedSession(backend='mongo')
+
+
+class MyCache(BaseCache):
+ pass
+
+
+def test_init_backend_instance():
+ backend = MyCache()
+ session = CachedSession(backend=backend)
+ assert session.cache is backend
+
+
+def test_init_backend_class():
+ session = CachedSession(backend=MyCache)
+ assert isinstance(session.cache, MyCache)
@pytest.mark.parametrize('method', ALL_METHODS)
diff --git a/tests/unit/test_monkey_patch.py b/tests/unit/test_monkey_patch.py
index 9d43aa5..54c5c33 100644
--- a/tests/unit/test_monkey_patch.py
+++ b/tests/unit/test_monkey_patch.py
@@ -44,19 +44,6 @@ def test_inheritance_after_monkey_patch(installed_session):
assert "new_one" in s.__attrs__
-def test_passing_backend_instance_support():
- class MyCache(BaseCache):
- pass
-
- backend = MyCache()
- requests_cache.install_cache(name=CACHE_NAME, backend=backend)
- assert requests.Session().cache is backend
- requests_cache.uninstall_cache()
-
- session = CachedSession(backend=backend)
- assert session.cache is backend
-
-
@patch.object(BaseCache, 'clear')
def test_clear(mock_clear, installed_session):
requests_cache.clear()