diff options
author | Jordan Cook <JWCook@users.noreply.github.com> | 2021-04-29 15:28:06 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-04-29 15:33:05 -0500 |
commit | c70dfde93ab43326fca9093f226301e173ebd3ab (patch) | |
tree | 8d5a4a6b8727899710348d75fa418e0a0e1c162c | |
parent | a2cd1afd4c4cfa9517a21f3ffc5ac04bbd3fe9af (diff) | |
parent | 4cd4c7fb9e0c837201afa16d4f0aefb81b824777 (diff) | |
download | requests-cache-c70dfde93ab43326fca9093f226301e173ebd3ab.tar.gz |
Merge pull request #254 from JWCook/refactor-delete-methods
Improve performance for remove_expired_responses(), mainly for SQLite backend
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | HISTORY.md | 1 | ||||
-rw-r--r-- | examples/generate_test_db.py | 110 | ||||
-rw-r--r-- | requests_cache/backends/base.py | 53 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 82 | ||||
-rwxr-xr-x | requests_cache/response.py | 1 | ||||
-rw-r--r-- | tests/conftest.py | 7 | ||||
-rw-r--r-- | tests/integration/base_cache_test.py | 23 | ||||
-rw-r--r-- | tests/integration/test_sqlite.py | 4 | ||||
-rw-r--r-- | tests/unit/test_cache.py | 83 | ||||
-rw-r--r-- | tests/unit/test_patcher.py | 15 |
11 files changed, 273 insertions, 108 deletions
@@ -1,5 +1,7 @@ +*.db *.py[cod] *.sqlite +*.sqlite-journal *.egg *.egg-info build/ @@ -9,6 +9,7 @@ * Update `BaseCache.urls` to only skip invalid responses, not delete them * Update `old_data_on_error` option to also handle error response codes * Use thread-local connections for SQLite backend +* Improve performance for `remove_expired_responses()`, mainly for SQLite backend * Fix `DynamoDbDict.__iter__` to return keys instead of values ### 0.6.3 (2021-04-21) diff --git a/examples/generate_test_db.py b/examples/generate_test_db.py new file mode 100644 index 0000000..e3ce60f --- /dev/null +++ b/examples/generate_test_db.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +"""An example of generating a test database with a large number of semi-randomized responses. +This is useful, for example, for reproducing issues that only occur with large caches. +""" +import logging +from os import urandom +from os.path import getsize +from random import random +from time import perf_counter as time + +from rich.progress import Progress + +from requests_cache import ALL_METHODS, CachedSession, CachedResponse +from requests_cache.response import format_file_size +from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, httpbin + + +CACHE_NAME = 'rubbish_bin.sqlite' +HTTPBIN_EXTRA_ENDPOINTS = [ + 'anything', + 'bytes/1024' 'cookies', + 'ip', + 'redirect/5', + 'stream-bytes/1024', +] +MAX_EXPIRE_AFTER = 30 +MAX_RESPONSE_SIZE = 10000 +N_RESPONSES = 100000 +N_INVALID_RESPONSES = 10 + +logging.basicConfig(level='INFO') +logger = logging.getLogger('requests_cache') + + +class InvalidResponse(CachedResponse): + """Response that will raise an exception when deserialized""" + + def __setstate__(self, d): + raise ValueError + + +def populate_cache(progress, task): + session = CachedSession(CACHE_NAME, backend='sqlite', allowable_methods=ALL_METHODS) + + # Cache a variety of different response formats, which may result in different behavior + urls = [('GET', httpbin(endpoint)) for endpoint in HTTPBIN_FORMATS + HTTPBIN_EXTRA_ENDPOINTS] + urls += [(method, httpbin(method.lower())) for method in HTTPBIN_METHODS] + for method, url in urls: + session.request(method, url) + progress.update(task, advance=1) + + # Cache a large number of responses with randomized response content, which will expire at random times + response = session.get(httpbin('get')) + with session.cache.responses.bulk_commit(): + for i in range(N_RESPONSES): + new_response = CachedResponse(response) + n_bytes = int(random() * MAX_RESPONSE_SIZE) + new_response._content = urandom(n_bytes) + + new_response.request.url += f'/response_{i}' + expire_after = random() * MAX_EXPIRE_AFTER + session.cache.save_response(new_response, expire_after=expire_after) + progress.update(task, advance=1) + + # Add some invalid responses + with session.cache.responses.bulk_commit(): + for i in range(N_INVALID_RESPONSES): + new_response = InvalidResponse(response) + new_response.request.url += f'/invalid_response_{i}' + key = session.cache.create_key(new_response.request) + session.cache.responses[key] = new_response + progress.update(task, advance=1) + + +def remove_expired_responses(expire_after=None): + logger.setLevel('DEBUG') + session = CachedSession(CACHE_NAME) + total_responses = len(session.cache.responses) + + start = time() + session.remove_expired_responses(expire_after=expire_after) + elapsed = time() - start + n_removed = total_responses - len(session.cache.responses) + logger.info( + f'Removed {n_removed} expired/invalid responses in {elapsed:.2f} seconds ' + f'(avg {(elapsed / n_removed) * 1000:.2f}ms per response)' + ) + + +def main(): + total_responses = len(HTTPBIN_FORMATS + HTTPBIN_EXTRA_ENDPOINTS + HTTPBIN_METHODS) + total_responses += N_RESPONSES + N_INVALID_RESPONSES + + with Progress() as progress: + task = progress.add_task('[cyan]Generating responses...', total=total_responses) + populate_cache(progress, task) + + actual_total_responses = len(CachedSession(CACHE_NAME).cache.responses) + cache_file_size = format_file_size(getsize(CACHE_NAME)) + logger.info(f'Generated cache with {actual_total_responses} responses ({cache_file_size})') + + +if __name__ == '__main__': + main() + + # Remove some responses (with randomized expiration) + # remove_expired_responses() + + # Expire and remove all responses + # remove_expired_responses(expire_after=1) diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 0ac4164..008209a 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -84,28 +84,28 @@ class BaseCache: return default def delete(self, key: str): - """Delete `key` from cache. Also deletes all responses from response history""" - self.delete_history(key) + """Delete a response or redirect from the cache, as well any associated redirect history""" + # If it's a response key, first delete any associated redirect history + try: + for r in self.responses[key].history: + del self.redirects[create_key(r.request, self.ignored_parameters)] + except (KeyError, *DESERIALIZE_ERRORS): + pass + # Then delete the response itself, or just the redirect if it's a redirect key for cache in [self.responses, self.redirects]: - # Skip `contains` checks to reduce # of service calls try: del cache[key] except (AttributeError, KeyError): pass - def delete_history(self, key: str): - """Delete redirect history associated with a response, if any""" - try: - response = self.responses[key] or self.responses[self.redirects[key]] - for r in response.history: - del self.redirects[create_key(r.request, self.ignored_parameters)] - except Exception: - pass + # TODO: Implement backend-specific bulk delete methods + def delete_all(self, keys): + """Remove multiple responses and their associated redirects from the cache""" + for key in keys: + self.delete(key) def delete_url(self, url: str): - """Delete response + redirects associated with `url` from cache. - Works only for GET requests. - """ + """Delete a cached response + redirects for ``GET <url>``""" self.delete(url_to_key(url, self.ignored_parameters)) def clear(self): @@ -115,19 +115,29 @@ class BaseCache: self.redirects.clear() def remove_expired_responses(self, expire_after: ExpirationTime = None): - """Remove expired responses from the cache, optionally with revalidation + """Remove expired and invalid responses from the cache, optionally with revalidation Args: expire_after: A new expiration time used to revalidate the cache """ logger.info('Removing expired responses.' + (f'Revalidating with: {expire_after}' if expire_after else '')) - # _get_valid_responses must be consumed before making any additional writes - for key, response in list(self._get_valid_responses(delete_invalid=True)): + keys_to_update = {} + keys_to_delete = [] + + for key, response in self._get_valid_responses(delete_invalid=True): # If we're revalidating and it's not yet expired, update the cached item's expiration if expire_after is not None and not response.revalidate(expire_after): - self.responses[key] = response + keys_to_update[key] = response if response.is_expired: - self.delete(key) + keys_to_delete.append(key) + + # Delay updates & deletes until the end, to avoid conflicts with _get_valid_responses() + logger.debug(f'Deleting {len(keys_to_delete)} expired responses') + self.delete_all(keys_to_delete) + if expire_after is not None: + logger.debug(f'Updating {len(keys_to_update)} revalidated responses') + for key, response in keys_to_update.items(): + self.responses[key] = response def remove_old_entries(self, *args, **kwargs): msg = 'BaseCache.remove_old_entries() is deprecated; please use CachedSession.remove_expired_responses()' @@ -165,11 +175,10 @@ class BaseCache: for key in self.responses.keys(): try: yield key, self.responses[key] - except DESERIALIZE_ERRORS as e: - logger.debug(f'Unable to deserialize response with key {key}: {e}') + except DESERIALIZE_ERRORS: keys_to_delete.append(key) - # Delay deletion until the end, to slightly improve responsiveness when used as a generator + # Delay deletion until the end, to improve responsiveness when used as a generator if delete_invalid: logger.debug(f'Deleting {len(keys_to_delete)} invalid responses') for key in keys_to_delete: diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 5ad100a..b311247 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -6,7 +6,7 @@ from os import makedirs from os.path import abspath, basename, dirname, expanduser, isabs, join from pathlib import Path from tempfile import gettempdir -from typing import Type, Union +from typing import Iterable, Tuple, Type, Union from . import BaseCache, BaseStorage, get_valid_kwargs @@ -37,10 +37,12 @@ class DbCache(BaseCache): self.responses = DbPickleDict(db_path, table_name='responses', use_temp=use_temp, fast_save=fast_save, **kwargs) self.redirects = DbDict(db_path, table_name='redirects', use_temp=use_temp, **kwargs) - def remove_expired_responses(self, *args, **kwargs): - """Remove expired responses from the cache, with additional cleanup""" - super().remove_expired_responses(*args, **kwargs) + def delete_all(self, keys): + """Remove multiple responses and their associated redirects from the cache, with additional cleanup""" + self.responses.delete_all(keys=keys) self.responses.vacuum() + self.redirects.delete_all(keys=keys) + self.redirects.delete_all(values=keys) self.redirects.vacuum() @@ -59,7 +61,7 @@ class DbDict(BaseStorage): Args: db_path: Database file path table_name: Table name - fast_save: Use `"PRAGMA synchronous = 0;" <http://www.sqlite.org/pragma.html#pragma_synchronous>`_ + fast_save: Use `'PRAGMA synchronous = 0;' <http://www.sqlite.org/pragma.html#pragma_synchronous>`_ to speed up cache saving, but with the potential for data loss timeout: Timeout for acquiring a database lock """ @@ -75,28 +77,33 @@ class DbDict(BaseStorage): self._can_commit = True self._local_context = threading.local() with sqlite3.connect(self.db_path, **self.connection_kwargs) as con: - con.execute("create table if not exists `%s` (key PRIMARY KEY, value)" % self.table_name) + self._create_table(con) + + # Initial CREATE TABLE must happen in shared connection; subsequent queries will use thread-local connections + def _create_table(self, connection): + connection.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)') @contextmanager - def connection(self, commit_on_success=False): - if not hasattr(self._local_context, "con"): + def connection(self, commit=False) -> sqlite3.Connection: + """Get a thread-local database connection""" + if not hasattr(self._local_context, 'con'): logger.debug(f'Opening connection to {self.db_path}:{self.table_name}') self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs) if self.fast_save: - self._local_context.con.execute("PRAGMA synchronous = 0;") + self._local_context.con.execute('PRAGMA synchronous = 0;') yield self._local_context.con - if commit_on_success and self._can_commit: + if commit and self._can_commit: self._local_context.con.commit() def __del__(self): - if hasattr(self._local_context, "con"): + if hasattr(self._local_context, 'con'): self._local_context.con.close() @contextmanager def bulk_commit(self): - """ - Context manager used to speedup insertion of big number of records - :: + """Context manager used to speed up insertion of a large number of records + + Example: >>> d1 = DbDict('test') >>> with d1.bulk_commit(): @@ -107,50 +114,62 @@ class DbDict(BaseStorage): self._can_commit = False try: yield - if hasattr(self._local_context, "con"): + if hasattr(self._local_context, 'con'): self._local_context.con.commit() finally: self._can_commit = True def __getitem__(self, key): with self.connection() as con: - row = con.execute("select value from `%s` where key=?" % self.table_name, (key,)).fetchone() + row = con.execute(f'SELECT value FROM {self.table_name} WHERE key=?', (key,)).fetchone() # raise error after the with block, otherwise the connection will be locked if not row: raise KeyError return row[0] def __setitem__(self, key, item): - with self.connection(True) as con: + with self.connection(commit=True) as con: con.execute( - "insert or replace into `%s` (key,value) values (?,?)" % self.table_name, + f'INSERT OR REPLACE INTO {self.table_name} (key,value) VALUES (?,?)', (key, item), ) def __delitem__(self, key): - with self.connection(True) as con: - cur = con.execute("delete from `%s` where key=?" % self.table_name, (key,)) + with self.connection(commit=True) as con: + cur = con.execute(f'DELETE FROM {self.table_name} WHERE key=?', (key,)) if not cur.rowcount: raise KeyError def __iter__(self): with self.connection() as con: - for row in con.execute("select key from `%s`" % self.table_name): + for row in con.execute(f'SELECT key FROM {self.table_name}'): yield row[0] def __len__(self): with self.connection() as con: - return con.execute("select count(key) from `%s`" % self.table_name).fetchone()[0] + return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0] def clear(self): - with self.connection(True) as con: - con.execute("drop table if exists `%s`" % self.table_name) - con.execute("create table `%s` (key PRIMARY KEY, value)" % self.table_name) - con.execute("vacuum") + with self.connection(commit=True) as con: + con.execute(f'DROP TABLE IF EXISTS {self.table_name}') + self._create_table(con) + con.execute('VACUUM') + + def delete_all(self, keys=None, values=None): + """Delete multiple records, either by keys or values""" + if not keys and not values: + return + + column = 'key' if keys else 'value' + marks, args = _format_sequence(keys or values) + statement = f'DELETE FROM {self.table_name} WHERE {column} IN ({marks})' + + with self.connection(commit=True) as con: + con.execute(statement, args) def vacuum(self): - with self.connection(True) as con: - con.execute("vacuum") + with self.connection(commit=True) as con: + con.execute('VACUUM') class DbPickleDict(DbDict): @@ -163,6 +182,13 @@ class DbPickleDict(DbDict): return self.deserialize(super().__getitem__(key)) +def _format_sequence(values) -> Tuple[str, Iterable]: + """Get SQL parameter marks for a sequence-based query, and ensure value is a sequence""" + if not isinstance(values, Iterable): + values = [values] + return ','.join(['?'] * len(values)), values + + def _get_db_path(db_path: Union[Path, str], use_temp: bool) -> str: """Get resolved path for database file""" # Save to a temp directory, if specified diff --git a/requests_cache/response.py b/requests_cache/response.py index 9b09474..0c051bd 100755 --- a/requests_cache/response.py +++ b/requests_cache/response.py @@ -85,7 +85,6 @@ class CachedResponse(Response): def _get_expiration_datetime(self, expire_after: ExpirationTime) -> Optional[datetime]: """Convert a time value or delta to an absolute datetime, if it's not already""" - logger.debug(f'Determining expiration time based on: {expire_after}') if expire_after is None or expire_after == -1: return None elif isinstance(expire_after, datetime): diff --git a/tests/conftest.py b/tests/conftest.py index 7f155d4..12cca1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -131,6 +131,13 @@ def tempfile_session() -> CachedSession: @pytest.fixture(scope='function') +def tempfile_path() -> CachedSession: + """Get a tempfile path that will be deleted after the test finished""" + with NamedTemporaryFile(suffix='.db') as temp: + yield temp.name + + +@pytest.fixture(scope='function') def installed_session() -> CachedSession: """Get a CachedSession using a temporary SQLite db, with global patching. Installs cache before test and uninstalls after. diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py index 7745cc4..6097488 100644 --- a/tests/integration/base_cache_test.py +++ b/tests/integration/base_cache_test.py @@ -2,7 +2,7 @@ import json import pytest from threading import Thread -from time import time +from time import sleep, time from typing import Dict, Type from requests_cache import ALL_METHODS, CachedResponse, CachedSession @@ -64,7 +64,7 @@ class BaseCacheTest: def test_redirects(self, endpoint, n_redirects): """Test all types of redirect endpoints with different numbers of consecutive redirects""" session = self.init_session() - session.get(httpbin(f'redirect/{n_redirects}')) + session.get(httpbin(f'{endpoint}/{n_redirects}')) r2 = session.get(httpbin('get')) assert r2.from_cache is True @@ -107,6 +107,25 @@ class BaseCacheTest: assert b'gzipped' in res.content assert b'gzipped' in res.raw.read(None, decode_content=True) + def test_remove_expired_responses(self): + session = self.init_session(expire_after=0.01) + + # Populate the cache with several responses that should expire immediately + for response_format in HTTPBIN_FORMATS: + session.get(httpbin(response_format)) + session.get(httpbin('redirect/1')) + sleep(0.01) + + # Cache a response + redirects, which should be the only non-expired cache items + session.get(httpbin('get'), expire_after=-1) + session.get(httpbin('redirect/3'), expire_after=-1) + session.cache.remove_expired_responses() + + assert len(session.cache.responses.keys()) == 2 + assert len(session.cache.redirects.keys()) == 3 + assert not session.cache.has_url(httpbin('redirect/1')) + assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS]) + @pytest.mark.parametrize('iteration', range(N_ITERATIONS)) def test_multithreaded(self, iteration): """Run a multi-threaded stress test for each backend""" diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index 6efba0b..703a9d1 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -14,7 +14,7 @@ class SQLiteTestCase(BaseStorageTest): @classmethod def teardown_class(cls): try: - os.unlink(CACHE_NAME) + os.unlink(f'{CACHE_NAME}.sqlite') except Exception: pass @@ -79,7 +79,7 @@ class SQLiteTestCase(BaseStorageTest): @patch('requests_cache.backends.sqlite.sqlite3') def test_connection_kwargs(self, mock_sqlite): """A spot check to make sure optional connection kwargs gets passed to connection""" - cache = self.storage_class('test', timeout=0.5, invalid_kwarg='???') + cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???') mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5) diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index f429b12..2b94033 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -123,16 +123,15 @@ def test_response_history(mock_session): assert len(mock_session.cache.redirects) == 1 -def test_repr(): +def test_repr(mock_session): """Test session and cache string representations""" - cache_name = 'requests_cache_test' - session = CachedSession(cache_name=cache_name, backend='memory', expire_after=10) - session.cache.responses['key'] = 'value' - session.cache.redirects['key'] = 'value' - session.cache.redirects['key_2'] = 'value' + mock_session.expire_after = 10.5 + mock_session.cache.responses['key'] = 'value' + mock_session.cache.redirects['key'] = 'value' + mock_session.cache.redirects['key_2'] = 'value' - assert cache_name in repr(session) and '10' in repr(session) - assert 'redirects: 2' in str(session.cache) and 'responses: 1' in str(session.cache) + assert mock_session._cache_name in repr(mock_session) and '10.5' in repr(mock_session) + assert 'redirects: 2' in str(mock_session.cache) and 'responses: 1' in str(mock_session.cache) def test_urls(mock_session): @@ -330,7 +329,7 @@ def test_response_defaults(mock_session): assert response_2.is_expired is response_3.is_expired is False -def testinclude_get_headers(mock_session): +def test_include_get_headers(mock_session): """With include_get_headers, requests with different headers should have different cache keys""" mock_session.cache.include_get_headers = True headers_list = [{'Accept': 'text/json'}, {'Accept': 'text/xml'}, {'Accept': 'custom'}, None] @@ -339,7 +338,7 @@ def testinclude_get_headers(mock_session): assert mock_session.get(MOCKED_URL, headers=headers).from_cache is True -def testinclude_get_headers_normalize(mock_session): +def test_include_get_headers_normalize(mock_session): """With include_get_headers, the same headers (in any order) should have the same cache key""" mock_session.cache.include_get_headers = True headers = {'Accept': 'text/json', 'Custom': 'abc'} @@ -432,16 +431,14 @@ def test_cache_disabled__nested(mock_session): ('some_other_site.com', None), ], ) -def test_urls_expire_after(url, expected_expire_after): - session = CachedSession( - urls_expire_after={ - '*.site_1.com': 60 * 60, - 'site_2.com/resource_1': 60 * 60 * 2, - 'site_2.com/resource_2': 60 * 60 * 24, - 'site_2.com/static': -1, - }, - ) - assert session._url_expire_after(url) == expected_expire_after +def test_urls_expire_after(url, expected_expire_after, mock_session): + mock_session.urls_expire_after = { + '*.site_1.com': 60 * 60, + 'site_2.com/resource_1': 60 * 60 * 2, + 'site_2.com/resource_2': 60 * 60 * 24, + 'site_2.com/static': -1, + } + assert mock_session._url_expire_after(url) == expected_expire_after @pytest.mark.parametrize( @@ -453,26 +450,25 @@ def test_urls_expire_after(url, expected_expire_after): ('https://any_other_site.com', 1), ], ) -def test_urls_expire_after__evaluation_order(url, expected_expire_after): +def test_urls_expire_after__evaluation_order(url, expected_expire_after, mock_session): """If there are multiple matches, the first match should be used in the order defined""" - session = CachedSession( - urls_expire_after={ - '*.site_1.com/resource': 60 * 60 * 2, - '*.site_1.com': 60 * 60, - '*': 1, - }, - ) - assert session._url_expire_after(url) == expected_expire_after - - -def test_get_expiration_precedence(): - session = CachedSession(expire_after=1, urls_expire_after={'*.site_1.com': 60 * 60}) - assert session._get_expiration() == 1 - assert session._get_expiration('site_2.com') == 1 - assert session._get_expiration('img.site_1.com/image.jpg') == 60 * 60 - with session.request_expire_after(30): - assert session._get_expiration() == 30 - assert session._get_expiration('img.site_1.com/image.jpg') == 30 + mock_session.urls_expire_after = { + '*.site_1.com/resource': 60 * 60 * 2, + '*.site_1.com': 60 * 60, + '*': 1, + } + assert mock_session._url_expire_after(url) == expected_expire_after + + +def test_get_expiration_precedence(mock_session): + mock_session.expire_after = 1 + mock_session.urls_expire_after = {'*.site_1.com': 60 * 60} + assert mock_session._get_expiration() == 1 + assert mock_session._get_expiration('site_2.com') == 1 + assert mock_session._get_expiration('img.site_1.com/image.jpg') == 60 * 60 + with mock_session.request_expire_after(30): + assert mock_session._get_expiration() == 30 + assert mock_session._get_expiration('img.site_1.com/image.jpg') == 30 def test_remove_expired_responses(mock_session): @@ -591,14 +587,13 @@ def test_unpickle_errors(mock_session): assert resp.json()['message'] == 'mock json response' -def test_cache_signing(): - # Without a secret key, plain pickle should be used - session = CachedSession() +def test_cache_signing(tempfile_path): + session = CachedSession(tempfile_path) assert session.cache.responses._serializer == pickle # With a secret key, itsdangerous should be used secret_key = str(uuid4()) - session = CachedSession(secret_key=secret_key) + session = CachedSession(tempfile_path, secret_key=secret_key) assert isinstance(session.cache.responses._serializer, Serializer) # Simple serialize/deserialize round trip @@ -606,6 +601,6 @@ def test_cache_signing(): assert session.cache.responses['key'] == 'value' # Without the same signing key, the item shouldn't be considered safe to deserialize - session = CachedSession(secret_key='a different key') + session = CachedSession(tempfile_path, secret_key='a different key') with pytest.raises(BadSignature): session.cache.responses['key'] diff --git a/tests/unit/test_patcher.py b/tests/unit/test_patcher.py index 4a3e417..fb1c25a 100644 --- a/tests/unit/test_patcher.py +++ b/tests/unit/test_patcher.py @@ -6,15 +6,12 @@ from requests.sessions import Session as OriginalSession import requests_cache from requests_cache import CachedSession from requests_cache.backends import BaseCache - -CACHE_NAME = 'requests_cache_test' -CACHE_BACKEND = 'sqlite' -FAST_SAVE = False +from tests.conftest import CACHE_NAME def test_install_uninstall(): for _ in range(2): - requests_cache.install_cache(name=CACHE_NAME, backend=CACHE_BACKEND) + requests_cache.install_cache(name=CACHE_NAME, use_temp=True) assert isinstance(requests.Session(), CachedSession) assert isinstance(requests.sessions.Session(), CachedSession) requests_cache.uninstall_cache() @@ -61,8 +58,8 @@ def test_disabled(cached_request, original_request, installed_session): @patch.object(OriginalSession, 'request') @patch.object(CachedSession, 'request') -def test_enabled(cached_request, original_request): - with requests_cache.enabled(): +def test_enabled(cached_request, original_request, tempfile_path): + with requests_cache.enabled(tempfile_path): for i in range(3): requests.get('some_url') assert cached_request.call_count == 3 @@ -70,8 +67,8 @@ def test_enabled(cached_request, original_request): @patch.object(BaseCache, 'remove_expired_responses') -def test_remove_expired_responses(remove_expired_responses): - requests_cache.install_cache(expire_after=360) +def test_remove_expired_responses(remove_expired_responses, tempfile_path): + requests_cache.install_cache(tempfile_path, expire_after=360) requests_cache.remove_expired_responses() assert remove_expired_responses.called is True requests_cache.uninstall_cache() |