author    Jordan Cook <JWCook@users.noreply.github.com>  2021-04-22 13:27:24 -0500
committer GitHub <noreply@github.com>  2021-04-22 13:27:24 -0500
commit    f2bf9bbfd8d711e9dfc8ccce83f439d765d484cb (patch)
tree      fb617be7c7760390b7239899e6522d865956342f
parent    38ddcf5425eadc6b174a8b053b2175c4dda00a1f (diff)
parent    3ec85f0107caedbe3b3bd8080bd3dac81b7baa17 (diff)
Merge pull request #242 from JWCook/integration-tests
Refactor backend integration tests
-rw-r--r--  requests_cache/backends/base.py            3
-rw-r--r--  requests_cache/backends/dynamodb.py        3
-rw-r--r--  requests_cache/backends/filesystem.py     21
-rw-r--r--  requests_cache/backends/gridfs.py          3
-rw-r--r--  requests_cache/backends/mongo.py           3
-rw-r--r--  requests_cache/backends/redis.py           3
-rw-r--r--  requests_cache/backends/sqlite.py          6
-rwxr-xr-x  runtests.sh                                4
-rw-r--r--  tests/conftest.py                         12
-rw-r--r--  tests/integration/test_backends.py       210
-rw-r--r--  tests/integration/test_dynamodb.py        34
-rw-r--r--  tests/integration/test_filesystem.py      39
-rw-r--r--  tests/integration/test_gridfs.py          43
-rw-r--r--  tests/integration/test_mongodb.py         37
-rw-r--r--  tests/integration/test_redis.py           23
-rw-r--r--  tests/integration/test_sqlite.py         136
-rw-r--r--  tests/integration/test_thread_safety.py   41
17 files changed, 308 insertions(+), 313 deletions(-)
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index e912879..def4dc1 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -201,3 +201,6 @@ class BaseStorage(MutableMapping, ABC):
return Serializer(secret_key, salt=salt, serializer=pickle)
else:
return pickle
+
+ def __str__(self):
+ return str(list(self.keys()))
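
With ``__str__`` consolidated into ``BaseStorage``, every backend now stringifies the same way: as a list of its keys. A minimal sketch of the new behavior, using the SQLite storage class (the path and table name below are placeholders):

    from requests_cache.backends.sqlite import DbDict

    d = DbDict('/tmp/str_demo.sqlite', table_name='demo')
    d['key_1'] = 'value_1'
    print(d)  # expected: ['key_1'], via the shared BaseStorage.__str__
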
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index ffb610c..bcd1e26 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -110,9 +110,6 @@ class DynamoDbDict(BaseStorage):
composite_key = {'namespace': v['namespace'], 'key': v['key']}
self._table.delete_item(Key=composite_key)
- def __str__(self):
- return str(dict(self.items()))
-
def __scan_table(self):
expression_attribute_values = {':Namespace': self._self_key}
expression_attribute_names = {'#N': 'namespace'}
diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py
index c85751f..c9fc515 100644
--- a/requests_cache/backends/filesystem.py
+++ b/requests_cache/backends/filesystem.py
@@ -1,7 +1,7 @@
# TODO: Add option for compression?
from contextlib import contextmanager
from os import listdir, makedirs, unlink
-from os.path import abspath, expanduser, isabs, join
+from os.path import abspath, dirname, expanduser, isabs, join
from pathlib import Path
from pickle import PickleError
from shutil import rmtree
@@ -13,8 +13,9 @@ from .sqlite import DbDict
class FileCache(BaseCache):
- """Backend that stores cached responses as files on the local filesystem. Response paths will be
- in the format ``<cache_name>/<cache_key>``. Redirects are stored in a SQLite database.
+ """Backend that stores cached responses as files on the local filesystem.
+ Response paths will be in the format ``<cache_name>/responses/<cache_key>``.
+ Redirects are stored in a SQLite database, located at ``<cache_name>/redirects.sqlite``.
Args:
cache_name: Base directory for cache files
@@ -24,18 +25,18 @@ class FileCache(BaseCache):
def __init__(self, cache_name: Union[Path, str] = 'http_cache', use_temp: bool = False, **kwargs):
super().__init__(**kwargs)
- cache_dir = _get_cache_dir(cache_name, use_temp)
- self.responses = FileDict(cache_dir, **kwargs)
- self.redirects = DbDict(join(cache_dir, 'redirects.sqlite'), 'redirects', **kwargs)
+ self.responses = FileDict(cache_name, use_temp=use_temp, **kwargs)
+ db_path = join(dirname(self.responses.cache_dir), 'redirects.sqlite')
+ self.redirects = DbDict(db_path, 'redirects', **kwargs)
class FileDict(BaseStorage):
"""A dictionary-like interface to files on the local filesystem"""
- def __init__(self, cache_dir, **kwargs):
+ def __init__(self, cache_name, use_temp: bool = False, **kwargs):
kwargs.setdefault('suppress_warnings', True)
super().__init__(**kwargs)
- self.cache_dir = cache_dir
+ self.cache_dir = _get_cache_dir(cache_name, use_temp)
makedirs(self.cache_dir, exist_ok=True)
@contextmanager
@@ -70,7 +71,7 @@ class FileDict(BaseStorage):
def clear(self):
with self._try_io(ignore_errors=True):
- rmtree(self.cache_dir)
+ rmtree(self.cache_dir, ignore_errors=True)
makedirs(self.cache_dir)
def paths(self):
@@ -81,7 +82,7 @@ class FileDict(BaseStorage):
def _get_cache_dir(cache_dir: Union[Path, str], use_temp: bool) -> str:
if use_temp and not isabs(cache_dir):
- cache_dir = join(gettempdir(), cache_dir)
+ cache_dir = join(gettempdir(), cache_dir, 'responses')
cache_dir = abspath(expanduser(str(cache_dir)))
makedirs(cache_dir, exist_ok=True)
return cache_dir
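
With responses and redirects now stored side by side, the on-disk layout is easy to inspect. A quick sketch, assuming a temp-dir cache (names are illustrative):

    from requests_cache.backends.filesystem import FileCache

    cache = FileCache('demo_cache', use_temp=True)
    cache.responses['some_key'] = 'some_value'
    print(cache.responses.cache_dir)      # e.g. /tmp/demo_cache/responses
    print(list(cache.responses.paths()))  # one file per cached response
    # redirects.sqlite sits next to the responses/ directory
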
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index 73f258e..82ca05c 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -77,6 +77,3 @@ class GridFSPickleDict(BaseStorage):
def clear(self):
self.db['fs.files'].drop()
self.db['fs.chunks'].drop()
-
- def __str__(self):
- return str(dict(self.items()))
diff --git a/requests_cache/backends/mongo.py b/requests_cache/backends/mongo.py
index b479a47..9ebadf4 100644
--- a/requests_cache/backends/mongo.py
+++ b/requests_cache/backends/mongo.py
@@ -70,9 +70,6 @@ class MongoDict(BaseStorage):
def clear(self):
self.collection.drop()
- def __str__(self):
- return str(dict(self.items()))
-
class MongoPickleDict(MongoDict):
"""Same as :class:`MongoDict`, but pickles values before saving"""
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index 33236d5..5209aa6 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -60,6 +60,3 @@ class RedisDict(BaseStorage):
def clear(self):
self.connection.delete(self._self_key)
-
- def __str__(self):
- return str(dict(self.items()))
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 470d00c..9c98fcd 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -25,7 +25,6 @@ class DbCache(BaseCache):
def __init__(self, db_path: Union[Path, str] = 'http_cache', fast_save: bool = False, **kwargs):
super().__init__(**kwargs)
- db_path = _get_db_path(db_path)
self.responses = DbPickleDict(db_path, table_name='responses', fast_save=fast_save, **kwargs)
self.redirects = DbDict(db_path, table_name='redirects', **kwargs)
@@ -60,7 +59,7 @@ class DbDict(BaseStorage):
kwargs.setdefault('suppress_warnings', True)
super().__init__(**kwargs)
self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
- self.db_path = db_path
+ self.db_path = _get_db_path(db_path)
self.fast_save = fast_save
self.table_name = table_name
@@ -144,9 +143,6 @@ class DbDict(BaseStorage):
with self.connection(True) as con:
con.execute("vacuum")
- def __str__(self):
- return str(dict(self.items()))
-
class DbPickleDict(DbDict):
"""Same as :class:`DbDict`, but pickles values before saving"""
diff --git a/runtests.sh b/runtests.sh
index f3be637..e4cd976 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -4,5 +4,5 @@ COVERAGE_ARGS='--cov --cov-report=term --cov-report=html'
export STRESS_TEST_MULTIPLIER=2
# Run unit tests first (and with multiprocessing) to fail quickly if there are issues
-pytest tests/unit --numprocesses=auto $COVERAGE_ARGS
-pytest tests/integration --cov-append $COVERAGE_ARGS
+pytest tests/unit -x --numprocesses=auto $COVERAGE_ARGS
+pytest tests/integration -x --cov-append $COVERAGE_ARGS
diff --git a/tests/conftest.py b/tests/conftest.py
index 577cb78..a918bc7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,13 @@ from timeout_decorator import timeout
import requests_cache
from requests_cache.session import ALL_METHODS, CachedSession
+CACHE_NAME = 'pytest_cache'
+
+# Allow running longer stress tests with an environment variable
+STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
+N_THREADS = 2 * STRESS_TEST_MULTIPLIER
+N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
+
MOCKED_URL = 'http+mock://requests-cache.com/text'
MOCKED_URL_HTTPS = 'https+mock://requests-cache.com/text'
MOCKED_URL_JSON = 'http+mock://requests-cache.com/json'
@@ -36,9 +43,10 @@ AWS_OPTIONS = {
'aws_secret_access_key': 'placeholder',
}
-# Configure logging to show debug output when tests fail (or with pytest -s)
+
+# Configure logging to show log output when tests fail (or with pytest -s)
basicConfig(level='INFO')
-getLogger('requests_cache').setLevel('DEBUG')
+# getLogger('requests_cache').setLevel('DEBUG')
logger = getLogger(__name__)
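
The thread and iteration counts now derive from a single environment variable, so scaling up a local stress run is one setting. The arithmetic, with an example multiplier:

    # With STRESS_TEST_MULTIPLIER=10 (example value), conftest computes:
    STRESS_TEST_MULTIPLIER = 10
    N_THREADS = 2 * STRESS_TEST_MULTIPLIER     # 20 threads per stress test
    N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER  # 40 requests per thread
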
diff --git a/tests/integration/test_backends.py b/tests/integration/test_backends.py
index ca9c6e2..9849a6b 100644
--- a/tests/integration/test_backends.py
+++ b/tests/integration/test_backends.py
@@ -1,100 +1,166 @@
-# TODO: Refactor with pytest fixtures
import pytest
-from typing import Type
+from threading import Thread
+from time import time
+from typing import Dict, Type
-from requests_cache.backends.base import BaseStorage
+from requests_cache.backends.base import BaseCache, BaseStorage
+from requests_cache.session import CachedSession
+from tests.conftest import CACHE_NAME, N_ITERATIONS, N_THREADS, httpbin
-class BaseStorageTestCase:
- """Base class for testing backends"""
+class BaseStorageTest:
+ """Base class for testing cache storage dict-like interfaces"""
- def __init__(
- self,
- *args,
- storage_class: Type[BaseStorage],
- picklable: bool = False,
- **kwargs,
- ):
- self.storage_class = storage_class
- self.picklable = picklable
- super().__init__(*args, **kwargs)
+ storage_class: Type[BaseStorage] = None
+ init_kwargs: Dict = {}
+ picklable: bool = False
+ num_instances: int = 10 # Max number of cache instances to test
- NAMESPACE = 'pytest-temp'
- TABLES = ['table%s' % i for i in range(5)]
+ def init_cache(self, index=0, clear=True, **kwargs):
+ kwargs['suppress_warnings'] = True
+ cache = self.storage_class(CACHE_NAME, f'table_{index}', **self.init_kwargs, **kwargs)
+ if clear:
+ cache.clear()
+ return cache
def tearDown(self):
- for table in self.TABLES:
- self.storage_class(self.NAMESPACE, table).clear()
+ for i in range(self.num_instances):
+ self.init_cache(i, clear=True)
super().tearDown()
- def test_set_get(self):
- d1 = self.storage_class(self.NAMESPACE, self.TABLES[0])
- d2 = self.storage_class(self.NAMESPACE, self.TABLES[1])
- d3 = self.storage_class(self.NAMESPACE, self.TABLES[2])
- d1['key_1'] = 1
- d2['key_2'] = 2
- d3['key_3'] = 3
- assert list(d1.keys()) == ['key_1']
- assert list(d2.keys()) == ['key_2']
- assert list(d3.keys()) == ['key_3']
-
- with pytest.raises(KeyError):
- d1[4]
-
- def test_str(self):
- d = self.storage_class(self.NAMESPACE)
- d.clear()
- d['key_1'] = 'value_1'
- d['key_2'] = 'value_2'
- assert dict(d) == {'key_1': 'value_1', 'key_2': 'value_2'}
+ def test_basic_methods(self):
+ """Test basic dict methods with multiple cache instances:
+ ``getitem, setitem, delitem, len, contains``
+ """
+ caches = [self.init_cache(i) for i in range(self.num_instances)]
+ for i in range(self.num_instances):
+ caches[i][f'key_{i}'] = f'value_{i}'
+ caches[i][f'key_{i+1}'] = f'value_{i+1}'
+
+ for i in range(self.num_instances):
+ cache = caches[i]
+ assert cache[f'key_{i}'] == f'value_{i}'
+ assert len(cache) == 2
+ assert f'key_{i}' in cache and f'key_{i+1}' in cache
+
+ del cache[f'key_{i}']
+ assert f'key_{i}' not in cache
+
+ def test_iterable_methods(self):
+ """Test iterable dict methods with multiple cache instances:
+ ``iter, keys, values, items``
+ """
+ caches = [self.init_cache(i) for i in range(self.num_instances)]
+ for i in range(self.num_instances):
+ caches[i][f'key_{i}'] = f'value_{i}'
+
+ for i in range(self.num_instances):
+ cache = caches[i]
+ assert list(cache) == [f'key_{i}']
+ assert list(cache.keys()) == [f'key_{i}']
+ assert list(cache.values()) == [f'value_{i}']
+ assert list(cache.items()) == [(f'key_{i}', f'value_{i}')]
+ assert dict(cache) == {f'key_{i}': f'value_{i}'}
def test_del(self):
- d = self.storage_class(self.NAMESPACE)
- d.clear()
+ """Some more tests to ensure ``delitem`` deletes only the expected items"""
+ cache = self.init_cache()
+ for i in range(20):
+ cache[f'key_{i}'] = f'value_{i}'
for i in range(5):
- d[f'key_{i}'] = i
- del d['key_0']
- del d['key_1']
- del d['key_2']
- assert set(d.keys()) == {f'key_{i}' for i in range(3, 5)}
- assert set(d.values()) == set(range(3, 5))
+ del cache[f'key_{i}']
+
+ assert len(cache) == 15
+ assert set(cache.keys()) == {f'key_{i}' for i in range(5, 20)}
+ assert set(cache.values()) == {f'value_{i}' for i in range(5, 20)}
+ def test_keyerrors(self):
+ """Accessing or deleting a deleted item should raise a KeyError"""
+ cache = self.init_cache()
+ cache['key'] = 'value'
+ del cache['key']
+
+ with pytest.raises(KeyError):
+ del cache['key']
with pytest.raises(KeyError):
- del d['key_0']
+ cache['key']
def test_picklable_dict(self):
if self.picklable:
- d = self.storage_class(self.NAMESPACE)
- d['key_1'] = Picklable()
- d = self.storage_class(self.NAMESPACE)
- assert d['key_1'].a == 1
- assert d['key_1'].b == 2
+ cache = self.init_cache()
+ cache['key_1'] = Picklable()
+ assert cache['key_1'].attr_1 == 'value_1'
+ assert cache['key_1'].attr_2 == 'value_2'
def test_clear_and_work_again(self):
- d1 = self.storage_class(self.NAMESPACE)
- d2 = self.storage_class(self.NAMESPACE, connection=getattr(d1, 'connection', None))
- d1.clear()
- d2.clear()
+ cache_1 = self.init_cache()
+ cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None))
for i in range(5):
- d1[i] = i
- d2[i] = i
+ cache_1[i] = i
+ cache_2[i] = i
- assert len(d1) == len(d2) == 5
- d1.clear()
- d2.clear()
- assert len(d1) == len(d2) == 0
+ assert len(cache_1) == len(cache_2) == 5
+ cache_1.clear()
+ cache_2.clear()
+ assert len(cache_1) == len(cache_2) == 0
def test_same_settings(self):
- d1 = self.storage_class(self.NAMESPACE)
- d2 = self.storage_class(self.NAMESPACE, connection=getattr(d1, 'connection', None))
- d1.clear()
- d2.clear()
- d1['key_1'] = 1
- d2['key_2'] = 2
- assert d1 == d2
+ cache_1 = self.init_cache()
+ cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None))
+ cache_1['key_1'] = 1
+ cache_2['key_2'] = 2
+ assert cache_1 == cache_2
+
+ def test_str(self):
+ """Not much to test for __str__ methods, just make sure they return keys in some format"""
+ cache = self.init_cache()
+ for i in range(10):
+ cache[f'key_{i}'] = f'value_{i}'
+ for i in range(10):
+ assert f'key_{i}' in str(cache)
class Picklable:
- a = 1
- b = 2
+ attr_1 = 'value_1'
+ attr_2 = 'value_2'
+
+
+class BaseCacheTest:
+ """Base class for testing cache backend classes"""
+
+ backend_class: Type[BaseCache] = None
+ init_kwargs: Dict = {}
+
+ def init_backend(self, clear=True, **kwargs):
+ kwargs['suppress_warnings'] = True
+ backend = self.backend_class(CACHE_NAME, **self.init_kwargs, **kwargs)
+ if clear:
+ backend.redirects.clear()
+ backend.responses.clear()
+ return backend
+
+ @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
+ def test_caching_with_threads(self, iteration):
+ """Run a multi-threaded stress test for each backend"""
+ start = time()
+ session = CachedSession(backend=self.init_backend())
+ url = httpbin('anything')
+
+ def send_requests():
+ for i in range(N_ITERATIONS):
+ session.get(url, params={f'key_{i}': f'value_{i}'})
+
+ threads = [Thread(target=send_requests) for i in range(N_THREADS)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ elapsed = time() - start
+ average = (elapsed * 1000) / (N_ITERATIONS * N_THREADS)
+ print(f'{self.backend_class}: Ran {N_ITERATIONS} iterations with {N_THREADS} threads each in {elapsed} s')
+ print(f'Average time per request: {average} ms')
+
+ for i in range(N_ITERATIONS):
+ assert session.cache.has_url(f'{url}?key_{i}=value_{i}')
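
With these base classes, integration coverage for another backend reduces to a subclass plus class attributes. A hypothetical example (the SQLite classes stand in for a new backend):

    from requests_cache.backends.sqlite import DbCache, DbDict
    from tests.integration.test_backends import BaseCacheTest, BaseStorageTest

    class TestMyDict(BaseStorageTest):
        storage_class = DbDict  # any BaseStorage subclass
        picklable = False       # True for backends that pickle values

    class TestMyCache(BaseCacheTest):
        backend_class = DbCache  # any BaseCache subclass
        init_kwargs = {'use_temp': True}
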
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index 3a4bd23..951d9c8 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -1,10 +1,9 @@
import pytest
-import unittest
from unittest.mock import patch
-from requests_cache.backends import DynamoDbDict
+from requests_cache.backends import DynamoDbCache, DynamoDbDict
from tests.conftest import AWS_OPTIONS, fail_if_no_connection
-from tests.integration.test_backends import BaseStorageTestCase
+from tests.integration.test_backends import BaseCacheTest, BaseStorageTest
# Run this test module last, since the DynamoDB container takes the longest to initialize
pytestmark = pytest.mark.order(-1)
@@ -20,23 +19,18 @@ def ensure_connection():
client.describe_limits()
-class DynamoDbDictWrapper(DynamoDbDict):
- def __init__(self, namespace, collection_name='dynamodb_dict_data', **options):
- super().__init__(namespace, collection_name, **options, **AWS_OPTIONS)
+class TestDynamoDbDict(BaseStorageTest):
+ storage_class = DynamoDbDict
+ init_kwargs = AWS_OPTIONS
+ picklable = True
+ @patch('requests_cache.backends.dynamodb.boto3.resource')
+ def test_connection_kwargs(self, mock_resource):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ DynamoDbDict('test', region_name='us-east-2', invalid_kwarg='???')
+ mock_resource.assert_called_with('dynamodb', region_name='us-east-2')
-class DynamoDbTestCase(BaseStorageTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(
- *args,
- storage_class=DynamoDbDictWrapper,
- picklable=True,
- **kwargs,
- )
-
-@patch('requests_cache.backends.dynamodb.boto3.resource')
-def test_connection_kwargs(mock_resource):
- """A spot check to make sure optional connection kwargs gets passed to connection"""
- DynamoDbDict('test', region_name='us-east-2', invalid_kwarg='???')
- mock_resource.assert_called_with('dynamodb', region_name='us-east-2')
+class TestDynamoDbCache(BaseCacheTest):
+ backend_class = DynamoDbCache
+ init_kwargs = AWS_OPTIONS
diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py
index 72301c9..0a3caec 100644
--- a/tests/integration/test_filesystem.py
+++ b/tests/integration/test_filesystem.py
@@ -1,32 +1,33 @@
-import pytest
-import unittest
from os.path import isfile
from shutil import rmtree
-from requests_cache.backends import FileDict
-from tests.integration.test_backends import BaseStorageTestCase
+from requests_cache.backends import FileCache, FileDict
+from tests.integration.test_backends import CACHE_NAME, BaseCacheTest, BaseStorageTest
-class FilesystemTestCase(BaseStorageTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=FileDict, picklable=True, **kwargs)
+class TestFileDict(BaseStorageTest):
+ storage_class = FileDict
+ picklable = True
- def tearDown(self):
- rmtree(self.NAMESPACE)
+ @classmethod
+ def teardown_class(cls):
+ rmtree(CACHE_NAME, ignore_errors=True)
- def test_set_get(self):
- cache = self.storage_class(self.NAMESPACE)
- cache['key'] = 'value'
- assert list(cache.keys()) == ['key']
- assert list(cache.values()) == ['value']
-
- with pytest.raises(KeyError):
- cache[4]
+ def init_cache(self, index=0, **kwargs):
+ cache = self.storage_class(f'{CACHE_NAME}_{index}', use_temp=True, **kwargs)
+ cache.clear()
+ return cache
def test_paths(self):
- cache = self.storage_class(self.NAMESPACE)
- for i in range(10):
+ cache = self.storage_class(CACHE_NAME)
+ for i in range(self.num_instances):
cache[f'key_{i}'] = f'value_{i}'
+ assert len(list(cache.paths())) == self.num_instances
for path in cache.paths():
assert isfile(path)
+
+
+class TestFileCache(BaseCacheTest):
+ backend_class = FileCache
+ init_kwargs = {'use_temp': True}
diff --git a/tests/integration/test_gridfs.py b/tests/integration/test_gridfs.py
index aae9948..d463824 100644
--- a/tests/integration/test_gridfs.py
+++ b/tests/integration/test_gridfs.py
@@ -1,12 +1,11 @@
import pytest
-import unittest
from unittest.mock import patch
from pymongo import MongoClient
-from requests_cache.backends import GridFSPickleDict, get_valid_kwargs
+from requests_cache.backends import GridFSCache, GridFSPickleDict, get_valid_kwargs
from tests.conftest import fail_if_no_connection
-from tests.integration.test_backends import BaseStorageTestCase
+from tests.integration.test_backends import BaseCacheTest, BaseStorageTest
@pytest.fixture(scope='module', autouse=True)
@@ -19,28 +18,22 @@ def ensure_connection():
client.server_info()
-class GridFSPickleDictTestCase(BaseStorageTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=GridFSPickleDict, picklable=True, **kwargs)
+class TestGridFSPickleDict(BaseStorageTest):
+ storage_class = GridFSPickleDict
+ picklable = True
+ num_instances = 1 # Only test a single collection instead of multiple
- def test_set_get(self):
- """Override base test to test a single collecton instead of multiple"""
- d1 = self.storage_class(self.NAMESPACE, self.TABLES[0])
- d1[1] = 1
- d1[2] = 2
- assert list(d1.keys()) == [1, 2]
+ @patch('requests_cache.backends.gridfs.GridFS')
+ @patch('requests_cache.backends.gridfs.MongoClient')
+ @patch(
+ 'requests_cache.backends.gridfs.get_valid_kwargs',
+ side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs),
+ )
+ def test_connection_kwargs(self, mock_get_valid_kwargs, mock_client, mock_gridfs):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ GridFSPickleDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
+ mock_client.assert_called_with(host='http://0.0.0.0', port=1234)
- with pytest.raises(KeyError):
- d1[4]
-
-@patch('requests_cache.backends.gridfs.GridFS')
-@patch('requests_cache.backends.gridfs.MongoClient')
-@patch(
- 'requests_cache.backends.gridfs.get_valid_kwargs',
- side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs),
-)
-def test_connection_kwargs(mock_get_valid_kwargs, mock_client, mock_gridfs):
- """A spot check to make sure optional connection kwargs gets passed to connection"""
- GridFSPickleDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
- mock_client.assert_called_with(host='http://0.0.0.0', port=1234)
+class TestGridFSCache(BaseCacheTest):
+ backend_class = GridFSCache
diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py
index 05e61a2..a186ee7 100644
--- a/tests/integration/test_mongodb.py
+++ b/tests/integration/test_mongodb.py
@@ -1,12 +1,11 @@
import pytest
-import unittest
from unittest.mock import patch
from pymongo import MongoClient
-from requests_cache.backends import MongoDict, MongoPickleDict, get_valid_kwargs
+from requests_cache.backends import MongoCache, MongoDict, MongoPickleDict, get_valid_kwargs
from tests.conftest import fail_if_no_connection
-from tests.integration.test_backends import BaseStorageTestCase
+from tests.integration.test_backends import BaseCacheTest, BaseStorageTest
@pytest.fixture(scope='module', autouse=True)
@@ -19,22 +18,24 @@ def ensure_connection():
client.server_info()
-class MongoDictTestCase(BaseStorageTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=MongoDict, **kwargs)
+class TestMongoDict(BaseStorageTest):
+ storage_class = MongoDict
-class MongoPickleDictTestCase(BaseStorageTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=MongoPickleDict, picklable=True, **kwargs)
+class TestMongoPickleDict(BaseStorageTest):
+ storage_class = MongoPickleDict
+ picklable = True
+ @patch('requests_cache.backends.mongo.MongoClient')
+ @patch(
+ 'requests_cache.backends.mongo.get_valid_kwargs',
+ side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs),
+ )
+ def test_connection_kwargs(self, mock_get_valid_kwargs, mock_client):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ MongoDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
+ mock_client.assert_called_with(host='http://0.0.0.0', port=1234)
-@patch('requests_cache.backends.mongo.MongoClient')
-@patch(
- 'requests_cache.backends.mongo.get_valid_kwargs',
- side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs),
-)
-def test_connection_kwargs(mock_get_valid_kwargs, mock_client):
- """A spot check to make sure optional connection kwargs gets passed to connection"""
- MongoDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
- mock_client.assert_called_with(host='http://0.0.0.0', port=1234)
+
+class TestMongoCache(BaseCacheTest):
+ backend_class = MongoCache
diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py
index 5edf559..899d2b6 100644
--- a/tests/integration/test_redis.py
+++ b/tests/integration/test_redis.py
@@ -1,10 +1,9 @@
import pytest
-import unittest
from unittest.mock import patch
from requests_cache.backends.redis import RedisCache, RedisDict
from tests.conftest import fail_if_no_connection
-from tests.integration.test_backends import BaseStorageTestCase
+from tests.integration.test_backends import BaseCacheTest, BaseStorageTest
@pytest.fixture(scope='module', autouse=True)
@@ -16,14 +15,16 @@ def ensure_connection():
Redis().info()
-class RedisTestCase(BaseStorageTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=RedisDict, picklable=True, **kwargs)
+class TestRedisDict(BaseStorageTest):
+ storage_class = RedisDict
+ picklable = True
+ @patch('requests_cache.backends.redis.StrictRedis')
+ def test_connection_kwargs(self, mock_redis):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ RedisCache('test', username='user', password='pass', invalid_kwarg='???')
+ mock_redis.assert_called_with(username='user', password='pass')
-# @patch.object(Redis, '__init__', Redis.__init__)
-@patch('requests_cache.backends.redis.StrictRedis')
-def test_connection_kwargs(mock_redis):
- """A spot check to make sure optional connection kwargs gets passed to connection"""
- RedisCache('test', username='user', password='pass', invalid_kwarg='???')
- mock_redis.assert_called_with(username='user', password='pass')
+
+class TestRedisCache(BaseCacheTest):
+ backend_class = RedisCache
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index a6fd50d..63c67ec 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -1,110 +1,94 @@
import os
-import unittest
from threading import Thread
from unittest.mock import patch
-from requests_cache.backends.sqlite import DbDict, DbPickleDict
-from tests.integration.test_backends import BaseStorageTestCase
+from requests_cache.backends.sqlite import DbCache, DbDict, DbPickleDict
+from tests.integration.test_backends import CACHE_NAME, BaseCacheTest, BaseStorageTest
-class SQLiteTestCase(BaseStorageTestCase):
- def tearDown(self):
+class SQLiteTestCase(BaseStorageTest):
+ @classmethod
+ def teardown_class(cls):
try:
- os.unlink(self.NAMESPACE)
+ os.unlink(CACHE_NAME)
except Exception:
pass
def test_bulk_commit(self):
- d = self.storage_class(self.NAMESPACE, self.TABLES[0])
- with d.bulk_commit():
+ cache = self.init_cache()
+ with cache.bulk_commit():
pass
- d.clear()
+
n = 1000
- with d.bulk_commit():
+ with cache.bulk_commit():
for i in range(n):
- d[i] = i
- assert list(d.keys()) == list(range(n))
+ cache[i] = i
+ assert list(cache.keys()) == list(range(n))
def test_switch_commit(self):
- d = self.storage_class(self.NAMESPACE)
- d.clear()
- d[1] = 1
- d = self.storage_class(self.NAMESPACE)
- assert 1 in d
+ cache = self.init_cache()
+ cache.clear()
+ cache['key_1'] = 'value_1'
+ cache = self.init_cache(clear=False)
+ assert 'key_1' in cache
- d._can_commit = False
- d[2] = 2
+ cache._can_commit = False
+ cache['key_2'] = 'value_2'
- d = self.storage_class(self.NAMESPACE)
- assert 2 not in d
- assert d._can_commit is True
+ cache = self.init_cache(clear=False)
+ assert 'key_2' not in cache
+ assert cache._can_commit is True
def test_fast_save(self):
- d1 = self.storage_class(self.NAMESPACE, fast_save=True)
- d2 = self.storage_class(self.NAMESPACE, self.TABLES[1], fast_save=True)
- d1.clear()
+ cache_1 = self.init_cache(1, fast_save=True)
+ cache_2 = self.init_cache(2, fast_save=True)
+
n = 1000
for i in range(n):
- d1[i] = i
- d2[i * 2] = i
- # HACK if we will not sort, fast save can produce different order of records
- assert sorted(d1.keys()) == list(range(n))
- assert sorted(d2.values()) == list(range(n))
-
- def test_usage_with_threads(self):
- def do_test_for(d, n_threads=5):
- d.clear()
-
- def do_inserts(values):
- for v in values:
- d[v] = v
-
- def values(x, n):
- return [i * x for i in range(n)]
-
- threads = [Thread(target=do_inserts, args=(values(i, n_threads),)) for i in range(n_threads)]
- for t in threads:
- t.start()
- for t in threads:
- t.join()
-
- for i in range(n_threads):
- for x in values(i, n_threads):
- assert d[x] == x
-
- do_test_for(self.storage_class(self.NAMESPACE))
- do_test_for(self.storage_class(self.NAMESPACE, fast_save=True), 20)
- do_test_for(self.storage_class(self.NAMESPACE, fast_save=True))
- do_test_for(self.storage_class(self.NAMESPACE, self.TABLES[1], fast_save=True))
+ cache_1[i] = i
+ cache_2[i * 2] = i
+
+ assert set(cache_1.keys()) == set(range(n))
+ assert set(cache_2.values()) == set(range(n))
def test_noop(self):
- def do_noop_bulk(d):
- with d.bulk_commit():
+ def do_noop_bulk(cache):
+ with cache.bulk_commit():
pass
- del d
+ del cache
- d = self.storage_class(self.NAMESPACE)
- t = Thread(target=do_noop_bulk, args=(d,))
- t.start()
- t.join()
+ cache = self.init_cache()
+ thread = Thread(target=do_noop_bulk, args=(cache,))
+ thread.start()
+ thread.join()
# make sure connection is not closed by the thread
- d[0] = 0
- assert str(d) == "{0: 0}"
+ cache['key_1'] = 'value_1'
+ assert list(cache.keys()) == ['key_1']
+
+ @patch('requests_cache.backends.sqlite.sqlite3')
+ def test_connection_kwargs(self, mock_sqlite):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ cache = self.storage_class('test', timeout=0.5, invalid_kwarg='???')
+ mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5)
-class DbDictTestCase(SQLiteTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=DbDict, **kwargs)
+class TestDbDict(SQLiteTestCase):
+ storage_class = DbDict
-class DbPickleDictTestCase(SQLiteTestCase, unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, storage_class=DbPickleDict, picklable=True, **kwargs)
+class TestDbPickleDict(SQLiteTestCase):
+ storage_class = DbPickleDict
+ picklable = True
-@patch('requests_cache.backends.sqlite.sqlite3')
-def test_connection_kwargs(mock_sqlite):
- """A spot check to make sure optional connection kwargs gets passed to connection"""
- DbDict('test', timeout=0.5, invalid_kwarg='???')
- mock_sqlite.connect.assert_called_with('test', timeout=0.5)
+class TestDbCache(BaseCacheTest):
+ backend_class = DbCache
+ init_kwargs = {'use_temp': True}
+
+ @classmethod
+ def teardown_class(cls):
+ try:
+ os.unlink(CACHE_NAME)
+ except Exception:
+ pass
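
``test_bulk_commit`` above exercises ``bulk_commit`` the same way a caller would: many writes batched into a single transaction. A standalone sketch (the path and table name are placeholders):

    from requests_cache.backends.sqlite import DbDict

    d = DbDict('/tmp/bulk_demo.sqlite', table_name='demo')
    with d.bulk_commit():  # one transaction for all 1000 writes
        for i in range(1000):
            d[i] = i
    assert list(d.keys()) == list(range(1000))
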
diff --git a/tests/integration/test_thread_safety.py b/tests/integration/test_thread_safety.py
deleted file mode 100644
index a62133a..0000000
--- a/tests/integration/test_thread_safety.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import pytest
-from os import getenv
-from threading import Thread
-from time import time
-
-from requests_cache.backends import BACKEND_CLASSES
-from requests_cache.session import CachedSession
-from tests.conftest import AWS_OPTIONS, httpbin
-
-# Allow running longer stress tests with an environment variable
-MULTIPLIER = int(getenv('STRESS_TEST_MULTIPLIER', '1'))
-N_THREADS = 2 * MULTIPLIER
-N_ITERATIONS = 4 * MULTIPLIER
-
-
-@pytest.mark.parametrize('iteration', range(N_ITERATIONS))
-@pytest.mark.parametrize('backend', BACKEND_CLASSES.keys())
-def test_caching_with_threads(backend, iteration):
- """Run a multi-threaded stress test for each backend"""
- start = time()
- session = CachedSession(backend=backend, use_temp=True, **AWS_OPTIONS)
- session.cache.clear()
- url = httpbin('anything')
-
- def send_requests():
- for i in range(N_ITERATIONS):
- session.get(url, params={f'key_{i}': f'value_{i}'})
-
- threads = [Thread(target=send_requests) for i in range(N_THREADS)]
- for t in threads:
- t.start()
- for t in threads:
- t.join()
-
- elapsed = time() - start
- average = (elapsed * 1000) / (N_ITERATIONS * N_THREADS)
- print(f'{backend}: Ran {N_ITERATIONS} iterations with {N_THREADS} threads each in {elapsed} s')
- print(f'Average time per request: {average} ms')
-
- for i in range(N_ITERATIONS):
- assert session.cache.has_url(f'{url}?key_{i}=value_{i}')