diff options
author | Jordan Cook <JWCook@users.noreply.github.com> | 2021-04-02 10:31:45 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-04-02 10:37:55 -0500 |
commit | ddba563e19f45015b251eafe776741952451433c (patch) | |
tree | edc2b9896d08129a5fac089654c5f9473d4056ca | |
parent | ad0aaaa66569799e127a3cd143f3a9434ad6449e (diff) | |
parent | aa1447ee8cf8db112fce0f7f4644f4bdc84197e8 (diff) | |
download | requests-cache-ddba563e19f45015b251eafe776741952451433c.tar.gz |
Merge pull request #215 from JWCook/backend-init
Improvements to backend initialization
-rw-r--r-- | requests_cache/backends/__init__.py | 124 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 25 | ||||
-rw-r--r-- | requests_cache/core.py | 20 | ||||
-rwxr-xr-x[-rw-r--r--] | runtests.sh | 0 | ||||
-rw-r--r-- | tests/conftest.py | 2 | ||||
-rw-r--r-- | tests/unit/test_cache.py | 26 | ||||
-rw-r--r-- | tests/unit/test_monkey_patch.py | 13 |
7 files changed, 111 insertions, 99 deletions
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py index f7bde00..9f9af90 100644 --- a/requests_cache/backends/__init__.py +++ b/requests_cache/backends/__init__.py @@ -1,20 +1,21 @@ """Classes and functions for cache persistence""" # flake8: noqa: F401 -from ..cache_keys import normalize_dict -from .base import BaseCache +from logging import getLogger +from typing import Type, Union -# All backend-specific keyword arguments combined -BACKEND_KWARGS = [ +from .base import BaseCache, BaseStorage + +# Backend-specific keyword arguments equivalent to 'cache_name' +CACHE_NAME_KWARGS = ['db_path', 'db_name', 'namespace', 'table_name'] + +# All backend-specific keyword arguments +BACKEND_KWARGS = CACHE_NAME_KWARGS + [ 'connection', - 'db_name', - 'endpont_url', - 'extension', + 'endpoint_url', 'fast_save', 'ignored_parameters', 'include_get_headers', - 'location', 'name', - 'namespace', 'read_capacity_units', 'region_name', 'salt', @@ -23,71 +24,76 @@ BACKEND_KWARGS = [ 'write_capacity_units', ] -registry = { - 'memory': BaseCache, -} +BackendSpecifier = Union[str, BaseCache, Type[BaseCache], None] +logger = getLogger(__name__) -_backend_dependencies = { - 'sqlite': 'sqlite3', - 'mongo': 'pymongo', - 'redis': 'redis', - 'dynamodb': 'dynamodb', -} -try: - # Heroku doesn't allow the SQLite3 module to be installed - from .sqlite import DbCache +def get_placeholder_backend(original_exception: Exception = None) -> Type[BaseCache]: + """Create a placeholder type for a backend class that does not have dependencies installed. + This allows delaying ImportErrors until init time, rather than at import time. + """ - registry['sqlite'] = DbCache -except ImportError: - DbCache = None + class PlaceholderBackend(BaseCache): + def __init__(*args, **kwargs): + msg = 'Dependencies are not installed for this backend' + logger.error(msg) + raise original_exception or ImportError(msg) -try: - from .mongo import MongoCache - - registry['mongo'] = registry['mongodb'] = MongoCache -except ImportError: - MongoCache = None + return PlaceholderBackend +# Import all backends for which dependencies are installed +try: + from .dynamodb import DynamoDbCache +except ImportError as e: + DynamoDbCache = get_placeholder_backend(e) # type: ignore try: from .gridfs import GridFSCache - - registry['gridfs'] = GridFSCache -except ImportError: - GridFSCache = None - +except ImportError as e: + GridFSCache = get_placeholder_backend(e) # type: ignore +try: + from .mongo import MongoCache +except ImportError as e: + MongoCache = get_placeholder_backend(e) # type: ignore try: from .redis import RedisCache +except ImportError as e: + RedisCache = get_placeholder_backend(e) # type: ignore +try: + # Note: Heroku doesn't support SQLite due to ephemeral storage + from .sqlite import DbCache +except ImportError as e: + DbCache = get_placeholder_backend(e) # type: ignore - registry['redis'] = RedisCache -except ImportError: - RedisCache = None -try: - from .dynamodb import DynamoDbCache +BACKEND_CLASSES = { + 'dynamodb': DynamoDbCache, + 'gridfs': GridFSCache, + 'memory': BaseCache, + 'mongo': MongoCache, + 'redis': RedisCache, + 'sqlite': DbCache, +} - registry['dynamodb'] = DynamoDbCache -except ImportError: - DynamoDbCache = None +def init_backend(backend: BackendSpecifier, *args, **kwargs) -> BaseCache: + """Initialize a backend given a name, class, or instance""" + logger.info(f'Initializing backend: {backend}') -def create_backend(backend_name, cache_name, kwargs): - if isinstance(backend_name, BaseCache): - return backend_name + # Omit 'cache_name' positional arg if an equivalent backend-specific kwarg is specified + if any([k in kwargs for k in CACHE_NAME_KWARGS]): + args = tuple() - if backend_name is None: - backend_name = _get_default_backend_name() - try: - return registry[backend_name](cache_name, **kwargs) - except KeyError: - if backend_name in _backend_dependencies: - raise ImportError('You must install the python package: %s' % _backend_dependencies[backend_name]) - else: - raise ValueError('Unsupported backend "%s" try one of: %s' % (backend_name, ', '.join(registry.keys()))) + # Determine backend class + if isinstance(backend, BaseCache): + return backend + elif isinstance(backend, type): + return backend(*args, **kwargs) + elif not backend: + backend = 'sqlite' if BACKEND_CLASSES['sqlite'] else 'memory' + backend = str(backend).lower().replace('mongodb', 'mongo') + if backend not in BACKEND_CLASSES: + raise ValueError(f'Invalid backend: {backend}. Choose from: {BACKEND_CLASSES.keys()}') -def _get_default_backend_name(): - if 'sqlite' in registry: - return 'sqlite' - return 'memory' + return BACKEND_CLASSES[backend](*args, **kwargs) diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 39b80ce..c2e21ac 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -2,7 +2,7 @@ import sqlite3 import threading from contextlib import contextmanager from logging import getLogger -from os.path import expanduser +from os.path import basename, expanduser from .base import BaseCache, BaseStorage @@ -15,17 +15,20 @@ class DbCache(BaseCache): Reading is fast, saving is a bit slower. It can store big amount of data with low memory usage. Args: - location: database file path (without extension) - extension: Database file extension + db_path: Database file path fast_save: Speedup cache saving up to 50 times but with possibility of data loss. See :ref:`backends.DbDict <backends_dbdict>` for more info timeout: Timeout for acquiring a database lock """ - def __init__(self, location='http_cache', extension='.sqlite', fast_save=False, **kwargs): + def __init__(self, db_path: str = 'http_cache', fast_save: bool = False, **kwargs): super().__init__(**kwargs) kwargs.setdefault('suppress_warnings', True) - db_path = expanduser(str(location) + extension) + # Allow paths with user directories (~/*), and add file extension if not specified + db_path = expanduser(str(db_path)) + if '.' not in basename(db_path): + db_path += '.sqlite' + self.responses = DbPickleDict(db_path, table_name='responses', fast_save=fast_save, **kwargs) self.redirects = DbDict(db_path, table_name='redirects', **kwargs) @@ -49,17 +52,17 @@ class DbDict(BaseStorage): All data will be stored in separate tables in the file ``test.sqlite``. Args: - filename: Filename for database (without extension) + db_path: Database file path table_name: Table name fast_save: Use `"PRAGMA synchronous = 0;" <http://www.sqlite.org/pragma.html#pragma_synchronous>`_ to speed up cache saving, but with the potential for data loss timeout: Timeout for acquiring a database lock """ - def __init__(self, filename, table_name='http_cache', fast_save=False, timeout=5.0, **kwargs): + def __init__(self, db_path, table_name='http_cache', fast_save=False, timeout=5.0, **kwargs): super().__init__(**kwargs) + self.db_path = db_path self.fast_save = fast_save - self.filename = filename self.table_name = table_name self.timeout = timeout @@ -72,14 +75,14 @@ class DbDict(BaseStorage): @contextmanager def connection(self, commit_on_success=False): - logger.debug(f'Opening connection to {self.filename}:{self.table_name}') + logger.debug(f'Opening connection to {self.db_path}:{self.table_name}') with self._lock: if self._bulk_commit: if self._pending_connection is None: - self._pending_connection = sqlite3.connect(self.filename, timeout=self.timeout) + self._pending_connection = sqlite3.connect(self.db_path, timeout=self.timeout) con = self._pending_connection else: - con = sqlite3.connect(self.filename, timeout=self.timeout) + con = sqlite3.connect(self.db_path, timeout=self.timeout) try: if self.fast_save: con.execute("PRAGMA synchronous = 0;") diff --git a/requests_cache/core.py b/requests_cache/core.py index 3a13b21..393c4f4 100644 --- a/requests_cache/core.py +++ b/requests_cache/core.py @@ -2,14 +2,14 @@ from contextlib import contextmanager from fnmatch import fnmatch from logging import getLogger -from typing import Any, Callable, Dict, Iterable, Type +from typing import Any, Callable, Dict, Iterable, Optional, Type import requests from requests import PreparedRequest from requests import Session as OriginalSession from requests.hooks import dispatch_hook -from . import backends +from .backends import BACKEND_KWARGS, BackendSpecifier, BaseCache, init_backend from .cache_keys import normalize_dict from .response import AnyResponse, ExpirationTime, set_response_defaults @@ -24,8 +24,8 @@ class CacheMixin: def __init__( self, - cache_name: str = 'cache', - backend: str = None, + cache_name: str = 'http-cache', + backend: BackendSpecifier = None, expire_after: ExpirationTime = -1, urls_expire_after: Dict[str, ExpirationTime] = None, allowable_codes: Iterable[int] = (200,), @@ -34,7 +34,7 @@ class CacheMixin: old_data_on_error: bool = False, **kwargs, ): - self.cache = backends.create_backend(backend, cache_name, kwargs) + self.cache = init_backend(backend, cache_name, **kwargs) self.allowable_codes = allowable_codes self.allowable_methods = allowable_methods self.expire_after = expire_after @@ -47,7 +47,7 @@ class CacheMixin: self._disabled = False # Remove any requests-cache-specific kwargs before passing along to superclass - session_kwargs = {k: v for k, v in kwargs.items() if k not in backends.BACKEND_KWARGS} + session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS} super().__init__(**session_kwargs) def request( @@ -82,7 +82,7 @@ class CacheMixin: 2. :py:meth:`.CachedSession.request` 3. :py:meth:`requests.Session.request` 4. :py:meth:`.CachedSession.send` - 5. :py:meth:`.backends.BaseCache.get_response` + 5. :py:meth:`.BaseCache.get_response` 6. :py:meth:`requests.Session.send` (if not cached) """ with self.request_expire_after(expire_after): @@ -225,8 +225,8 @@ class CachedSession(CacheMixin, OriginalSession): Args: cache_name: Cache prefix or namespace, depending on backend - backend: Cache backend name; one of ``['sqlite', 'mongodb', 'gridfs', 'redis', 'dynamodb', 'memory']``. - Default behavior is to use ``'sqlite'`` if available, otherwise fallback to ``'memory'``. + backend: Cache backend name, class, or instance; name may be one of + ``['sqlite', 'mongodb', 'gridfs', 'redis', 'dynamodb', 'memory']``. expire_after: Time after which cached items will expire urls_expire_after: Expiration times to apply for different URL patterns allowable_codes: Only cache responses with one of these codes @@ -352,7 +352,7 @@ def enabled(*args, **kwargs): uninstall_cache() -def get_cache() -> backends.BaseCache: +def get_cache() -> Optional[BaseCache]: """Returns internal cache object from globally installed ``CachedSession``""" return requests.Session().cache if is_installed() else None diff --git a/runtests.sh b/runtests.sh index 7614288..7614288 100644..100755 --- a/runtests.sh +++ b/runtests.sh diff --git a/tests/conftest.py b/tests/conftest.py index 2bc21d8..466dd41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,7 @@ def mock_session() -> CachedSession: """ with NamedTemporaryFile(suffix='.db') as temp: session = CachedSession( - cache_name=temp.name, + db_path=temp.name, backend='sqlite', allowable_methods=ALL_METHODS, suppress_warnings=True, diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index e38f7e5..28f038e 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -1,5 +1,6 @@ """CachedSession + BaseCache tests that use mocked responses only""" # flake8: noqa: F841 +# TODO: This could be split up into some smaller test modules import json import pickle import pytest @@ -16,6 +17,7 @@ from itsdangerous.serializer import Serializer from requests.structures import CaseInsensitiveDict from requests_cache import ALL_METHODS, CachedSession +from requests_cache.backends import BACKEND_CLASSES, BaseCache, get_placeholder_backend from requests_cache.backends.sqlite import DbDict, DbPickleDict from tests.conftest import ( MOCKED_URL, @@ -26,17 +28,31 @@ from tests.conftest import ( ) -def test_unregistered_backend(): +def test_init_unregistered_backend(): with pytest.raises(ValueError): CachedSession(backend='nonexistent') -@patch('requests_cache.backends.registry') -def test_missing_backend_dependency(mocked_registry): +@patch.dict(BACKEND_CLASSES, {'mongo': get_placeholder_backend()}) +def test_init_missing_backend_dependency(): """Test that the correct error is thrown when a user does not have a dependency installed""" - mocked_registry.__getitem__.side_effect = KeyError with pytest.raises(ImportError): - CachedSession(backend='redis') + CachedSession(backend='mongo') + + +class MyCache(BaseCache): + pass + + +def test_init_backend_instance(): + backend = MyCache() + session = CachedSession(backend=backend) + assert session.cache is backend + + +def test_init_backend_class(): + session = CachedSession(backend=MyCache) + assert isinstance(session.cache, MyCache) @pytest.mark.parametrize('method', ALL_METHODS) diff --git a/tests/unit/test_monkey_patch.py b/tests/unit/test_monkey_patch.py index 9d43aa5..54c5c33 100644 --- a/tests/unit/test_monkey_patch.py +++ b/tests/unit/test_monkey_patch.py @@ -44,19 +44,6 @@ def test_inheritance_after_monkey_patch(installed_session): assert "new_one" in s.__attrs__ -def test_passing_backend_instance_support(): - class MyCache(BaseCache): - pass - - backend = MyCache() - requests_cache.install_cache(name=CACHE_NAME, backend=backend) - assert requests.Session().cache is backend - requests_cache.uninstall_cache() - - session = CachedSession(backend=backend) - assert session.cache is backend - - @patch.object(BaseCache, 'clear') def test_clear(mock_clear, installed_session): requests_cache.clear() |