diff options
author | Jordan Cook <JWCook@users.noreply.github.com> | 2021-04-22 13:27:24 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-22 13:27:24 -0500 |
commit | f2bf9bbfd8d711e9dfc8ccce83f439d765d484cb (patch) | |
tree | fb617be7c7760390b7239899e6522d865956342f | |
parent | 38ddcf5425eadc6b174a8b053b2175c4dda00a1f (diff) | |
parent | 3ec85f0107caedbe3b3bd8080bd3dac81b7baa17 (diff) | |
download | requests-cache-f2bf9bbfd8d711e9dfc8ccce83f439d765d484cb.tar.gz |
Merge pull request #242 from JWCook/integration-tests
Refactor backend integration tests
-rw-r--r-- | requests_cache/backends/base.py | 3 | ||||
-rw-r--r-- | requests_cache/backends/dynamodb.py | 3 | ||||
-rw-r--r-- | requests_cache/backends/filesystem.py | 21 | ||||
-rw-r--r-- | requests_cache/backends/gridfs.py | 3 | ||||
-rw-r--r-- | requests_cache/backends/mongo.py | 3 | ||||
-rw-r--r-- | requests_cache/backends/redis.py | 3 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 6 | ||||
-rwxr-xr-x | runtests.sh | 4 | ||||
-rw-r--r-- | tests/conftest.py | 12 | ||||
-rw-r--r-- | tests/integration/test_backends.py | 210 | ||||
-rw-r--r-- | tests/integration/test_dynamodb.py | 34 | ||||
-rw-r--r-- | tests/integration/test_filesystem.py | 39 | ||||
-rw-r--r-- | tests/integration/test_gridfs.py | 43 | ||||
-rw-r--r-- | tests/integration/test_mongodb.py | 37 | ||||
-rw-r--r-- | tests/integration/test_redis.py | 23 | ||||
-rw-r--r-- | tests/integration/test_sqlite.py | 136 | ||||
-rw-r--r-- | tests/integration/test_thread_safety.py | 41 |
17 files changed, 308 insertions, 313 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index e912879..def4dc1 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -201,3 +201,6 @@ class BaseStorage(MutableMapping, ABC): return Serializer(secret_key, salt=salt, serializer=pickle) else: return pickle + + def __str__(self): + return str(list(self.keys())) diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py index ffb610c..bcd1e26 100644 --- a/requests_cache/backends/dynamodb.py +++ b/requests_cache/backends/dynamodb.py @@ -110,9 +110,6 @@ class DynamoDbDict(BaseStorage): composite_key = {'namespace': v['namespace'], 'key': v['key']} self._table.delete_item(Key=composite_key) - def __str__(self): - return str(dict(self.items())) - def __scan_table(self): expression_attribute_values = {':Namespace': self._self_key} expression_attribute_names = {'#N': 'namespace'} diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py index c85751f..c9fc515 100644 --- a/requests_cache/backends/filesystem.py +++ b/requests_cache/backends/filesystem.py @@ -1,7 +1,7 @@ # TODO: Add option for compression? from contextlib import contextmanager from os import listdir, makedirs, unlink -from os.path import abspath, expanduser, isabs, join +from os.path import abspath, dirname, expanduser, isabs, join from pathlib import Path from pickle import PickleError from shutil import rmtree @@ -13,8 +13,9 @@ from .sqlite import DbDict class FileCache(BaseCache): - """Backend that stores cached responses as files on the local filesystem. Response paths will be - in the format ``<cache_name>/<cache_key>``. Redirects are stored in a SQLite database. + """Backend that stores cached responses as files on the local filesystem. + Response paths will be in the format ``<cache_name>/responses/<cache_key>``. + Redirects are stored in a SQLite database, located at ``<cache_name>/redirects.sqlite``. Args: cache_name: Base directory for cache files @@ -24,18 +25,18 @@ class FileCache(BaseCache): def __init__(self, cache_name: Union[Path, str] = 'http_cache', use_temp: bool = False, **kwargs): super().__init__(**kwargs) - cache_dir = _get_cache_dir(cache_name, use_temp) - self.responses = FileDict(cache_dir, **kwargs) - self.redirects = DbDict(join(cache_dir, 'redirects.sqlite'), 'redirects', **kwargs) + self.responses = FileDict(cache_name, use_temp=use_temp, **kwargs) + db_path = join(dirname(self.responses.cache_dir), 'redirects.sqlite') + self.redirects = DbDict(db_path, 'redirects', **kwargs) class FileDict(BaseStorage): """A dictionary-like interface to files on the local filesystem""" - def __init__(self, cache_dir, **kwargs): + def __init__(self, cache_name, use_temp: bool = False, **kwargs): kwargs.setdefault('suppress_warnings', True) super().__init__(**kwargs) - self.cache_dir = cache_dir + self.cache_dir = _get_cache_dir(cache_name, use_temp) makedirs(self.cache_dir, exist_ok=True) @contextmanager @@ -70,7 +71,7 @@ class FileDict(BaseStorage): def clear(self): with self._try_io(ignore_errors=True): - rmtree(self.cache_dir) + rmtree(self.cache_dir, ignore_errors=True) makedirs(self.cache_dir) def paths(self): @@ -81,7 +82,7 @@ class FileDict(BaseStorage): def _get_cache_dir(cache_dir: Union[Path, str], use_temp: bool) -> str: if use_temp and not isabs(cache_dir): - cache_dir = join(gettempdir(), cache_dir) + cache_dir = join(gettempdir(), cache_dir, 'responses') cache_dir = abspath(expanduser(str(cache_dir))) makedirs(cache_dir, exist_ok=True) return cache_dir diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index 73f258e..82ca05c 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py @@ -77,6 +77,3 @@ class GridFSPickleDict(BaseStorage): def clear(self): self.db['fs.files'].drop() self.db['fs.chunks'].drop() - - def __str__(self): - return str(dict(self.items())) diff --git a/requests_cache/backends/mongo.py b/requests_cache/backends/mongo.py index b479a47..9ebadf4 100644 --- a/requests_cache/backends/mongo.py +++ b/requests_cache/backends/mongo.py @@ -70,9 +70,6 @@ class MongoDict(BaseStorage): def clear(self): self.collection.drop() - def __str__(self): - return str(dict(self.items())) - class MongoPickleDict(MongoDict): """Same as :class:`MongoDict`, but pickles values before saving""" diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py index 33236d5..5209aa6 100644 --- a/requests_cache/backends/redis.py +++ b/requests_cache/backends/redis.py @@ -60,6 +60,3 @@ class RedisDict(BaseStorage): def clear(self): self.connection.delete(self._self_key) - - def __str__(self): - return str(dict(self.items())) diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 470d00c..9c98fcd 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -25,7 +25,6 @@ class DbCache(BaseCache): def __init__(self, db_path: Union[Path, str] = 'http_cache', fast_save: bool = False, **kwargs): super().__init__(**kwargs) - db_path = _get_db_path(db_path) self.responses = DbPickleDict(db_path, table_name='responses', fast_save=fast_save, **kwargs) self.redirects = DbDict(db_path, table_name='redirects', **kwargs) @@ -60,7 +59,7 @@ class DbDict(BaseStorage): kwargs.setdefault('suppress_warnings', True) super().__init__(**kwargs) self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) - self.db_path = db_path + self.db_path = _get_db_path(db_path) self.fast_save = fast_save self.table_name = table_name @@ -144,9 +143,6 @@ class DbDict(BaseStorage): with self.connection(True) as con: con.execute("vacuum") - def __str__(self): - return str(dict(self.items())) - class DbPickleDict(DbDict): """Same as :class:`DbDict`, but pickles values before saving""" diff --git a/runtests.sh b/runtests.sh index f3be637..e4cd976 100755 --- a/runtests.sh +++ b/runtests.sh @@ -4,5 +4,5 @@ COVERAGE_ARGS='--cov --cov-report=term --cov-report=html' export STRESS_TEST_MULTIPLIER=2 # Run unit tests first (and with multiprocessing) to fail quickly if there are issues -pytest tests/unit --numprocesses=auto $COVERAGE_ARGS -pytest tests/integration --cov-append $COVERAGE_ARGS +pytest tests/unit -x --numprocesses=auto $COVERAGE_ARGS +pytest tests/integration -x --cov-append $COVERAGE_ARGS diff --git a/tests/conftest.py b/tests/conftest.py index 577cb78..a918bc7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,13 @@ from timeout_decorator import timeout import requests_cache from requests_cache.session import ALL_METHODS, CachedSession +CACHE_NAME = 'pytest_cache' + +# Allow running longer stress tests with an environment variable +STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1')) +N_THREADS = 2 * STRESS_TEST_MULTIPLIER +N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER + MOCKED_URL = 'http+mock://requests-cache.com/text' MOCKED_URL_HTTPS = 'https+mock://requests-cache.com/text' MOCKED_URL_JSON = 'http+mock://requests-cache.com/json' @@ -36,9 +43,10 @@ AWS_OPTIONS = { 'aws_secret_access_key': 'placeholder', } -# Configure logging to show debug output when tests fail (or with pytest -s) + +# Configure logging to show log output when tests fail (or with pytest -s) basicConfig(level='INFO') -getLogger('requests_cache').setLevel('DEBUG') +# getLogger('requests_cache').setLevel('DEBUG') logger = getLogger(__name__) diff --git a/tests/integration/test_backends.py b/tests/integration/test_backends.py index ca9c6e2..9849a6b 100644 --- a/tests/integration/test_backends.py +++ b/tests/integration/test_backends.py @@ -1,100 +1,166 @@ -# TODO: Refactor with pytest fixtures import pytest -from typing import Type +from threading import Thread +from time import time +from typing import Dict, Type -from requests_cache.backends.base import BaseStorage +from requests_cache.backends.base import BaseCache, BaseStorage +from requests_cache.session import CachedSession +from tests.conftest import CACHE_NAME, N_ITERATIONS, N_THREADS, httpbin -class BaseStorageTestCase: - """Base class for testing backends""" +class BaseStorageTest: + """Base class for testing cache storage dict-like interfaces""" - def __init__( - self, - *args, - storage_class: Type[BaseStorage], - picklable: bool = False, - **kwargs, - ): - self.storage_class = storage_class - self.picklable = picklable - super().__init__(*args, **kwargs) + storage_class: Type[BaseStorage] = None + init_kwargs: Dict = {} + picklable: bool = False + num_instances: int = 10 # Max number of cache instances to test - NAMESPACE = 'pytest-temp' - TABLES = ['table%s' % i for i in range(5)] + def init_cache(self, index=0, clear=True, **kwargs): + kwargs['suppress_warnings'] = True + cache = self.storage_class(CACHE_NAME, f'table_{index}', **self.init_kwargs, **kwargs) + if clear: + cache.clear() + return cache def tearDown(self): - for table in self.TABLES: - self.storage_class(self.NAMESPACE, table).clear() + for i in range(self.num_instances): + self.init_cache(i, clear=True) super().tearDown() - def test_set_get(self): - d1 = self.storage_class(self.NAMESPACE, self.TABLES[0]) - d2 = self.storage_class(self.NAMESPACE, self.TABLES[1]) - d3 = self.storage_class(self.NAMESPACE, self.TABLES[2]) - d1['key_1'] = 1 - d2['key_2'] = 2 - d3['key_3'] = 3 - assert list(d1.keys()) == ['key_1'] - assert list(d2.keys()) == ['key_2'] - assert list(d3.keys()) == ['key_3'] - - with pytest.raises(KeyError): - d1[4] - - def test_str(self): - d = self.storage_class(self.NAMESPACE) - d.clear() - d['key_1'] = 'value_1' - d['key_2'] = 'value_2' - assert dict(d) == {'key_1': 'value_1', 'key_2': 'value_2'} + def test_basic_methods(self): + """Test basic dict methods with multiple cache instances: + ``getitem, setitem, delitem, len, contains`` + """ + caches = [self.init_cache(i) for i in range(10)] + for i in range(self.num_instances): + caches[i][f'key_{i}'] = f'value_{i}' + caches[i][f'key_{i+1}'] = f'value_{i+1}' + + for i in range(self.num_instances): + cache = caches[i] + cache[f'key_{i}'] == f'value_{i}' + assert len(cache) == 2 + assert f'key_{i}' in cache and f'key_{i+1}' in cache + + del cache[f'key_{i}'] + assert f'key_{i}' not in cache + + def test_iterable_methods(self): + """Test iterable dict methods with multiple cache instances: + ``iter, keys, values, items`` + """ + caches = [self.init_cache(i) for i in range(self.num_instances)] + for i in range(self.num_instances): + caches[i][f'key_{i}'] = f'value_{i}' + + for i in range(self.num_instances): + cache = caches[i] + assert list(cache) == [f'key_{i}'] + assert list(cache.keys()) == [f'key_{i}'] + assert list(cache.values()) == [f'value_{i}'] + assert list(cache.items()) == [(f'key_{i}', f'value_{i}')] + assert dict(cache) == {f'key_{i}': f'value_{i}'} def test_del(self): - d = self.storage_class(self.NAMESPACE) - d.clear() + """Some more tests to ensure ``delitem`` deletes only the expected items""" + cache = self.init_cache() + for i in range(20): + cache[f'key_{i}'] = f'value_{i}' for i in range(5): - d[f'key_{i}'] = i - del d['key_0'] - del d['key_1'] - del d['key_2'] - assert set(d.keys()) == {f'key_{i}' for i in range(3, 5)} - assert set(d.values()) == set(range(3, 5)) + del cache[f'key_{i}'] + + assert len(cache) == 15 + assert set(cache.keys()) == {f'key_{i}' for i in range(5, 20)} + assert set(cache.values()) == {f'value_{i}' for i in range(5, 20)} + def test_keyerrors(self): + """Accessing or deleting a deleted item should raise a KeyError""" + cache = self.init_cache() + cache['key'] = 'value' + del cache['key'] + + with pytest.raises(KeyError): + del cache['key'] with pytest.raises(KeyError): - del d['key_0'] + cache['key'] def test_picklable_dict(self): if self.picklable: - d = self.storage_class(self.NAMESPACE) - d['key_1'] = Picklable() - d = self.storage_class(self.NAMESPACE) - assert d['key_1'].a == 1 - assert d['key_1'].b == 2 + cache = self.init_cache() + cache['key_1'] = Picklable() + assert cache['key_1'].attr_1 == 'value_1' + assert cache['key_1'].attr_2 == 'value_2' def test_clear_and_work_again(self): - d1 = self.storage_class(self.NAMESPACE) - d2 = self.storage_class(self.NAMESPACE, connection=getattr(d1, 'connection', None)) - d1.clear() - d2.clear() + cache_1 = self.init_cache() + cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None)) for i in range(5): - d1[i] = i - d2[i] = i + cache_1[i] = i + cache_2[i] = i - assert len(d1) == len(d2) == 5 - d1.clear() - d2.clear() - assert len(d1) == len(d2) == 0 + assert len(cache_1) == len(cache_2) == 5 + cache_1.clear() + cache_2.clear() + assert len(cache_1) == len(cache_2) == 0 def test_same_settings(self): - d1 = self.storage_class(self.NAMESPACE) - d2 = self.storage_class(self.NAMESPACE, connection=getattr(d1, 'connection', None)) - d1.clear() - d2.clear() - d1['key_1'] = 1 - d2['key_2'] = 2 - assert d1 == d2 + cache_1 = self.init_cache() + cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None)) + cache_1['key_1'] = 1 + cache_2['key_2'] = 2 + assert cache_1 == cache_2 + + def test_str(self): + """Not much to test for __str__ methods, just make sure they return keys in some format""" + cache = self.init_cache() + for i in range(10): + cache[f'key_{i}'] = f'value_{i}' + for i in range(10): + assert f'key_{i}' in str(cache) class Picklable: - a = 1 - b = 2 + attr_1 = 'value_1' + attr_2 = 'value_2' + + +class BaseCacheTest: + """Base class for testing cache backend classes""" + + backend_class: Type[BaseCache] = None + init_kwargs: Dict = {} + + def init_backend(self, clear=True, **kwargs): + kwargs['suppress_warnings'] = True + backend = self.backend_class(CACHE_NAME, **self.init_kwargs, **kwargs) + if clear: + backend.redirects.clear() + backend.responses.clear() + return backend + + @pytest.mark.parametrize('iteration', range(N_ITERATIONS)) + def test_caching_with_threads(self, iteration): + """Run a multi-threaded stress test for each backend""" + start = time() + session = CachedSession(backend=self.init_backend()) + url = httpbin('anything') + + def send_requests(): + for i in range(N_ITERATIONS): + session.get(url, params={f'key_{i}': f'value_{i}'}) + + threads = [Thread(target=send_requests) for i in range(N_THREADS)] + for t in threads: + t.start() + for t in threads: + t.join() + + elapsed = time() - start + average = (elapsed * 1000) / (N_ITERATIONS * N_THREADS) + print(f'{self.backend_class}: Ran {N_ITERATIONS} iterations with {N_THREADS} threads each in {elapsed} s') + print(f'Average time per request: {average} ms') + + for i in range(N_ITERATIONS): + assert session.cache.has_url(f'{url}?key_{i}=value_{i}') diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py index 3a4bd23..951d9c8 100644 --- a/tests/integration/test_dynamodb.py +++ b/tests/integration/test_dynamodb.py @@ -1,10 +1,9 @@ import pytest -import unittest from unittest.mock import patch -from requests_cache.backends import DynamoDbDict +from requests_cache.backends import DynamoDbCache, DynamoDbDict from tests.conftest import AWS_OPTIONS, fail_if_no_connection -from tests.integration.test_backends import BaseStorageTestCase +from tests.integration.test_backends import BaseCacheTest, BaseStorageTest # Run this test module last, since the DynamoDB container takes the longest to initialize pytestmark = pytest.mark.order(-1) @@ -20,23 +19,18 @@ def ensure_connection(): client.describe_limits() -class DynamoDbDictWrapper(DynamoDbDict): - def __init__(self, namespace, collection_name='dynamodb_dict_data', **options): - super().__init__(namespace, collection_name, **options, **AWS_OPTIONS) +class TestDynamoDbDict(BaseStorageTest): + storage_class = DynamoDbDict + init_kwargs = AWS_OPTIONS + picklable = True + @patch('requests_cache.backends.dynamodb.boto3.resource') + def test_connection_kwargs(self, mock_resource): + """A spot check to make sure optional connection kwargs gets passed to connection""" + DynamoDbDict('test', region_name='us-east-2', invalid_kwarg='???') + mock_resource.assert_called_with('dynamodb', region_name='us-east-2') -class DynamoDbTestCase(BaseStorageTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__( - *args, - storage_class=DynamoDbDictWrapper, - picklable=True, - **kwargs, - ) - -@patch('requests_cache.backends.dynamodb.boto3.resource') -def test_connection_kwargs(mock_resource): - """A spot check to make sure optional connection kwargs gets passed to connection""" - DynamoDbDict('test', region_name='us-east-2', invalid_kwarg='???') - mock_resource.assert_called_with('dynamodb', region_name='us-east-2') +class TestDynamoDbCache(BaseCacheTest): + backend_class = DynamoDbCache + init_kwargs = AWS_OPTIONS diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py index 72301c9..0a3caec 100644 --- a/tests/integration/test_filesystem.py +++ b/tests/integration/test_filesystem.py @@ -1,32 +1,33 @@ -import pytest -import unittest from os.path import isfile from shutil import rmtree -from requests_cache.backends import FileDict -from tests.integration.test_backends import BaseStorageTestCase +from requests_cache.backends import FileCache, FileDict +from tests.integration.test_backends import CACHE_NAME, BaseCacheTest, BaseStorageTest -class FilesystemTestCase(BaseStorageTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=FileDict, picklable=True, **kwargs) +class TestFileDict(BaseStorageTest): + storage_class = FileDict + picklable = True - def tearDown(self): - rmtree(self.NAMESPACE) + @classmethod + def teardown_class(cls): + rmtree(CACHE_NAME, ignore_errors=True) - def test_set_get(self): - cache = self.storage_class(self.NAMESPACE) - cache['key'] = 'value' - assert list(cache.keys()) == ['key'] - assert list(cache.values()) == ['value'] - - with pytest.raises(KeyError): - cache[4] + def init_cache(self, index=0, **kwargs): + cache = self.storage_class(f'{CACHE_NAME}_{index}', use_temp=True, **kwargs) + cache.clear() + return cache def test_paths(self): - cache = self.storage_class(self.NAMESPACE) - for i in range(10): + cache = self.storage_class(CACHE_NAME) + for i in range(self.num_instances): cache[f'key_{i}'] = f'value_{i}' + assert len(list(cache.paths())) == self.num_instances for path in cache.paths(): assert isfile(path) + + +class TestFileCache(BaseCacheTest): + backend_class = FileCache + init_kwargs = {'use_temp': True} diff --git a/tests/integration/test_gridfs.py b/tests/integration/test_gridfs.py index aae9948..d463824 100644 --- a/tests/integration/test_gridfs.py +++ b/tests/integration/test_gridfs.py @@ -1,12 +1,11 @@ import pytest -import unittest from unittest.mock import patch from pymongo import MongoClient -from requests_cache.backends import GridFSPickleDict, get_valid_kwargs +from requests_cache.backends import GridFSCache, GridFSPickleDict, get_valid_kwargs from tests.conftest import fail_if_no_connection -from tests.integration.test_backends import BaseStorageTestCase +from tests.integration.test_backends import BaseCacheTest, BaseStorageTest @pytest.fixture(scope='module', autouse=True) @@ -19,28 +18,22 @@ def ensure_connection(): client.server_info() -class GridFSPickleDictTestCase(BaseStorageTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=GridFSPickleDict, picklable=True, **kwargs) +class TestGridFSPickleDict(BaseStorageTest): + storage_class = GridFSPickleDict + picklable = True + num_instances = 1 # Only test a single collecton instead of multiple - def test_set_get(self): - """Override base test to test a single collecton instead of multiple""" - d1 = self.storage_class(self.NAMESPACE, self.TABLES[0]) - d1[1] = 1 - d1[2] = 2 - assert list(d1.keys()) == [1, 2] + @patch('requests_cache.backends.gridfs.GridFS') + @patch('requests_cache.backends.gridfs.MongoClient') + @patch( + 'requests_cache.backends.gridfs.get_valid_kwargs', + side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs), + ) + def test_connection_kwargs(self, mock_get_valid_kwargs, mock_client, mock_gridfs): + """A spot check to make sure optional connection kwargs gets passed to connection""" + GridFSPickleDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???') + mock_client.assert_called_with(host='http://0.0.0.0', port=1234) - with pytest.raises(KeyError): - d1[4] - -@patch('requests_cache.backends.gridfs.GridFS') -@patch('requests_cache.backends.gridfs.MongoClient') -@patch( - 'requests_cache.backends.gridfs.get_valid_kwargs', - side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs), -) -def test_connection_kwargs(mock_get_valid_kwargs, mock_client, mock_gridfs): - """A spot check to make sure optional connection kwargs gets passed to connection""" - GridFSPickleDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???') - mock_client.assert_called_with(host='http://0.0.0.0', port=1234) +class TestGridFSCache(BaseCacheTest): + backend_class = GridFSCache diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py index 05e61a2..a186ee7 100644 --- a/tests/integration/test_mongodb.py +++ b/tests/integration/test_mongodb.py @@ -1,12 +1,11 @@ import pytest -import unittest from unittest.mock import patch from pymongo import MongoClient -from requests_cache.backends import MongoDict, MongoPickleDict, get_valid_kwargs +from requests_cache.backends import MongoCache, MongoDict, MongoPickleDict, get_valid_kwargs from tests.conftest import fail_if_no_connection -from tests.integration.test_backends import BaseStorageTestCase +from tests.integration.test_backends import BaseCacheTest, BaseStorageTest @pytest.fixture(scope='module', autouse=True) @@ -19,22 +18,24 @@ def ensure_connection(): client.server_info() -class MongoDictTestCase(BaseStorageTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=MongoDict, **kwargs) +class TestMongoDict(BaseStorageTest): + storage_class = MongoDict -class MongoPickleDictTestCase(BaseStorageTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=MongoPickleDict, picklable=True, **kwargs) +class TestMongoPickleDict(BaseStorageTest): + storage_class = MongoPickleDict + picklable = True + @patch('requests_cache.backends.mongo.MongoClient') + @patch( + 'requests_cache.backends.mongo.get_valid_kwargs', + side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs), + ) + def test_connection_kwargs(self, mock_get_valid_kwargs, mock_client): + """A spot check to make sure optional connection kwargs gets passed to connection""" + MongoDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???') + mock_client.assert_called_with(host='http://0.0.0.0', port=1234) -@patch('requests_cache.backends.mongo.MongoClient') -@patch( - 'requests_cache.backends.mongo.get_valid_kwargs', - side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs), -) -def test_connection_kwargs(mock_get_valid_kwargs, mock_client): - """A spot check to make sure optional connection kwargs gets passed to connection""" - MongoDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???') - mock_client.assert_called_with(host='http://0.0.0.0', port=1234) + +class TestMongoCache(BaseCacheTest): + backend_class = MongoCache diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py index 5edf559..899d2b6 100644 --- a/tests/integration/test_redis.py +++ b/tests/integration/test_redis.py @@ -1,10 +1,9 @@ import pytest -import unittest from unittest.mock import patch from requests_cache.backends.redis import RedisCache, RedisDict from tests.conftest import fail_if_no_connection -from tests.integration.test_backends import BaseStorageTestCase +from tests.integration.test_backends import BaseCacheTest, BaseStorageTest @pytest.fixture(scope='module', autouse=True) @@ -16,14 +15,16 @@ def ensure_connection(): Redis().info() -class RedisTestCase(BaseStorageTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=RedisDict, picklable=True, **kwargs) +class TestRedisDict(BaseStorageTest): + storage_class = RedisDict + picklable = True + @patch('requests_cache.backends.redis.StrictRedis') + def test_connection_kwargs(self, mock_redis): + """A spot check to make sure optional connection kwargs gets passed to connection""" + RedisCache('test', username='user', password='pass', invalid_kwarg='???') + mock_redis.assert_called_with(username='user', password='pass') -# @patch.object(Redis, '__init__', Redis.__init__) -@patch('requests_cache.backends.redis.StrictRedis') -def test_connection_kwargs(mock_redis): - """A spot check to make sure optional connection kwargs gets passed to connection""" - RedisCache('test', username='user', password='pass', invalid_kwarg='???') - mock_redis.assert_called_with(username='user', password='pass') + +class TestRedisCache(BaseCacheTest): + backend_class = RedisCache diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index a6fd50d..63c67ec 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -1,110 +1,94 @@ import os -import unittest from threading import Thread from unittest.mock import patch -from requests_cache.backends.sqlite import DbDict, DbPickleDict -from tests.integration.test_backends import BaseStorageTestCase +from requests_cache.backends.sqlite import DbCache, DbDict, DbPickleDict +from tests.integration.test_backends import CACHE_NAME, BaseCacheTest, BaseStorageTest -class SQLiteTestCase(BaseStorageTestCase): - def tearDown(self): +class SQLiteTestCase(BaseStorageTest): + @classmethod + def teardown_class(cls): try: - os.unlink(self.NAMESPACE) + os.unlink(CACHE_NAME) except Exception: pass def test_bulk_commit(self): - d = self.storage_class(self.NAMESPACE, self.TABLES[0]) - with d.bulk_commit(): + cache = self.init_cache() + with cache.bulk_commit(): pass - d.clear() + n = 1000 - with d.bulk_commit(): + with cache.bulk_commit(): for i in range(n): - d[i] = i - assert list(d.keys()) == list(range(n)) + cache[i] = i + assert list(cache.keys()) == list(range(n)) def test_switch_commit(self): - d = self.storage_class(self.NAMESPACE) - d.clear() - d[1] = 1 - d = self.storage_class(self.NAMESPACE) - assert 1 in d + cache = self.init_cache() + cache.clear() + cache['key_1'] = 'value_1' + cache = self.init_cache(clear=False) + assert 'key_1' in cache - d._can_commit = False - d[2] = 2 + cache._can_commit = False + cache['key_2'] = 'value_2' - d = self.storage_class(self.NAMESPACE) - assert 2 not in d - assert d._can_commit is True + cache = self.init_cache(clear=False) + assert 2 not in cache + assert cache._can_commit is True def test_fast_save(self): - d1 = self.storage_class(self.NAMESPACE, fast_save=True) - d2 = self.storage_class(self.NAMESPACE, self.TABLES[1], fast_save=True) - d1.clear() + cache_1 = self.init_cache(1, fast_save=True) + cache_2 = self.init_cache(2, fast_save=True) + n = 1000 for i in range(n): - d1[i] = i - d2[i * 2] = i - # HACK if we will not sort, fast save can produce different order of records - assert sorted(d1.keys()) == list(range(n)) - assert sorted(d2.values()) == list(range(n)) - - def test_usage_with_threads(self): - def do_test_for(d, n_threads=5): - d.clear() - - def do_inserts(values): - for v in values: - d[v] = v - - def values(x, n): - return [i * x for i in range(n)] - - threads = [Thread(target=do_inserts, args=(values(i, n_threads),)) for i in range(n_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - for i in range(n_threads): - for x in values(i, n_threads): - assert d[x] == x - - do_test_for(self.storage_class(self.NAMESPACE)) - do_test_for(self.storage_class(self.NAMESPACE, fast_save=True), 20) - do_test_for(self.storage_class(self.NAMESPACE, fast_save=True)) - do_test_for(self.storage_class(self.NAMESPACE, self.TABLES[1], fast_save=True)) + cache_1[i] = i + cache_2[i * 2] = i + + assert set(cache_1.keys()) == set(range(n)) + assert set(cache_2.values()) == set(range(n)) def test_noop(self): - def do_noop_bulk(d): - with d.bulk_commit(): + def do_noop_bulk(cache): + with cache.bulk_commit(): pass - del d + del cache - d = self.storage_class(self.NAMESPACE) - t = Thread(target=do_noop_bulk, args=(d,)) - t.start() - t.join() + cache = self.init_cache() + thread = Thread(target=do_noop_bulk, args=(cache,)) + thread.start() + thread.join() # make sure connection is not closed by the thread - d[0] = 0 - assert str(d) == "{0: 0}" + cache['key_1'] = 'value_1' + assert list(cache.keys()) == ['key_1'] + + @patch('requests_cache.backends.sqlite.sqlite3') + def test_connection_kwargs(self, mock_sqlite): + """A spot check to make sure optional connection kwargs gets passed to connection""" + cache = self.storage_class('test', timeout=0.5, invalid_kwarg='???') + mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5) -class DbDictTestCase(SQLiteTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=DbDict, **kwargs) +class TestDbDict(SQLiteTestCase): + storage_class = DbDict -class DbPickleDictTestCase(SQLiteTestCase, unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, storage_class=DbPickleDict, picklable=True, **kwargs) +class TestDbPickleDict(SQLiteTestCase): + storage_class = DbPickleDict + picklable = True -@patch('requests_cache.backends.sqlite.sqlite3') -def test_connection_kwargs(mock_sqlite): - """A spot check to make sure optional connection kwargs gets passed to connection""" - DbDict('test', timeout=0.5, invalid_kwarg='???') - mock_sqlite.connect.assert_called_with('test', timeout=0.5) +class TestDbCache(BaseCacheTest): + backend_class = DbCache + init_kwargs = {'use_temp': True} + + @classmethod + def teardown_class(cls): + try: + os.unlink(CACHE_NAME) + except Exception: + pass diff --git a/tests/integration/test_thread_safety.py b/tests/integration/test_thread_safety.py deleted file mode 100644 index a62133a..0000000 --- a/tests/integration/test_thread_safety.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest -from os import getenv -from threading import Thread -from time import time - -from requests_cache.backends import BACKEND_CLASSES -from requests_cache.session import CachedSession -from tests.conftest import AWS_OPTIONS, httpbin - -# Allow running longer stress tests with an environment variable -MULTIPLIER = int(getenv('STRESS_TEST_MULTIPLIER', '1')) -N_THREADS = 2 * MULTIPLIER -N_ITERATIONS = 4 * MULTIPLIER - - -@pytest.mark.parametrize('iteration', range(N_ITERATIONS)) -@pytest.mark.parametrize('backend', BACKEND_CLASSES.keys()) -def test_caching_with_threads(backend, iteration): - """Run a multi-threaded stress test for each backend""" - start = time() - session = CachedSession(backend=backend, use_temp=True, **AWS_OPTIONS) - session.cache.clear() - url = httpbin('anything') - - def send_requests(): - for i in range(N_ITERATIONS): - session.get(url, params={f'key_{i}': f'value_{i}'}) - - threads = [Thread(target=send_requests) for i in range(N_THREADS)] - for t in threads: - t.start() - for t in threads: - t.join() - - elapsed = time() - start - average = (elapsed * 1000) / (N_ITERATIONS * N_THREADS) - print(f'{backend}: Ran {N_ITERATIONS} iterations with {N_THREADS} threads each in {elapsed} s') - print(f'Average time per request: {average} ms') - - for i in range(N_ITERATIONS): - assert session.cache.has_url(f'{url}?key_{i}=value_{i}') |