author     Jordan Cook <JWCook@users.noreply.github.com>    2021-04-29 15:28:06 -0500
committer  Jordan Cook <jordan.cook@pioneer.com>            2021-04-29 15:33:05 -0500
commit     c70dfde93ab43326fca9093f226301e173ebd3ab (patch)
tree       8d5a4a6b8727899710348d75fa418e0a0e1c162c
parent     a2cd1afd4c4cfa9517a21f3ffc5ac04bbd3fe9af (diff)
parent     4cd4c7fb9e0c837201afa16d4f0aefb81b824777 (diff)
download   requests-cache-c70dfde93ab43326fca9093f226301e173ebd3ab.tar.gz
Merge pull request #254 from JWCook/refactor-delete-methods
Improve performance for remove_expired_responses(), mainly for SQLite backend
-rw-r--r--  .gitignore                            |   2
-rw-r--r--  HISTORY.md                            |   1
-rw-r--r--  examples/generate_test_db.py          | 110
-rw-r--r--  requests_cache/backends/base.py       |  53
-rw-r--r--  requests_cache/backends/sqlite.py     |  82
-rwxr-xr-x  requests_cache/response.py            |   1
-rw-r--r--  tests/conftest.py                     |   7
-rw-r--r--  tests/integration/base_cache_test.py  |  23
-rw-r--r--  tests/integration/test_sqlite.py      |   4
-rw-r--r--  tests/unit/test_cache.py              |  83
-rw-r--r--  tests/unit/test_patcher.py            |  15
11 files changed, 273 insertions(+), 108 deletions(-)
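
Note on the change: cleanup now collects expired keys during a read-only scan and removes them in bulk, rather than issuing one delete per response. A minimal usage sketch of the affected public API (cache name and expiration values here are illustrative, not taken from this commit):

from requests_cache import CachedSession

session = CachedSession('demo_cache', backend='sqlite', expire_after=60)
# Remove responses that are already expired (plus any that fail to deserialize)
session.remove_expired_responses()
# Or revalidate against a new expiration time, then remove whatever is expired under it
session.remove_expired_responses(expire_after=3600)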
diff --git a/.gitignore b/.gitignore
index ff5da6f..04cbefb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,7 @@
+*.db
*.py[cod]
*.sqlite
+*.sqlite-journal
*.egg
*.egg-info
build/
diff --git a/HISTORY.md b/HISTORY.md
index ab1e4f6..719b25f 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -9,6 +9,7 @@
* Update `BaseCache.urls` to only skip invalid responses, not delete them
* Update `old_data_on_error` option to also handle error response codes
* Use thread-local connections for SQLite backend
+* Improve performance for `remove_expired_responses()`, mainly for SQLite backend
* Fix `DynamoDbDict.__iter__` to return keys instead of values
### 0.6.3 (2021-04-21)
diff --git a/examples/generate_test_db.py b/examples/generate_test_db.py
new file mode 100644
index 0000000..e3ce60f
--- /dev/null
+++ b/examples/generate_test_db.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+"""An example of generating a test database with a large number of semi-randomized responses.
+This is useful, for example, for reproducing issues that only occur with large caches.
+"""
+import logging
+from os import urandom
+from os.path import getsize
+from random import random
+from time import perf_counter as time
+
+from rich.progress import Progress
+
+from requests_cache import ALL_METHODS, CachedSession, CachedResponse
+from requests_cache.response import format_file_size
+from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, httpbin
+
+
+CACHE_NAME = 'rubbish_bin.sqlite'
+HTTPBIN_EXTRA_ENDPOINTS = [
+ 'anything',
+ 'bytes/1024', 'cookies',
+ 'ip',
+ 'redirect/5',
+ 'stream-bytes/1024',
+]
+MAX_EXPIRE_AFTER = 30
+MAX_RESPONSE_SIZE = 10000
+N_RESPONSES = 100000
+N_INVALID_RESPONSES = 10
+
+logging.basicConfig(level='INFO')
+logger = logging.getLogger('requests_cache')
+
+
+class InvalidResponse(CachedResponse):
+ """Response that will raise an exception when deserialized"""
+
+ def __setstate__(self, d):
+ raise ValueError
+
+
+def populate_cache(progress, task):
+ session = CachedSession(CACHE_NAME, backend='sqlite', allowable_methods=ALL_METHODS)
+
+ # Cache a variety of different response formats, which may result in different behavior
+ urls = [('GET', httpbin(endpoint)) for endpoint in HTTPBIN_FORMATS + HTTPBIN_EXTRA_ENDPOINTS]
+ urls += [(method, httpbin(method.lower())) for method in HTTPBIN_METHODS]
+ for method, url in urls:
+ session.request(method, url)
+ progress.update(task, advance=1)
+
+ # Cache a large number of responses with randomized response content, which will expire at random times
+ response = session.get(httpbin('get'))
+ with session.cache.responses.bulk_commit():
+ for i in range(N_RESPONSES):
+ new_response = CachedResponse(response)
+ n_bytes = int(random() * MAX_RESPONSE_SIZE)
+ new_response._content = urandom(n_bytes)
+
+ new_response.request.url += f'/response_{i}'
+ expire_after = random() * MAX_EXPIRE_AFTER
+ session.cache.save_response(new_response, expire_after=expire_after)
+ progress.update(task, advance=1)
+
+ # Add some invalid responses
+ with session.cache.responses.bulk_commit():
+ for i in range(N_INVALID_RESPONSES):
+ new_response = InvalidResponse(response)
+ new_response.request.url += f'/invalid_response_{i}'
+ key = session.cache.create_key(new_response.request)
+ session.cache.responses[key] = new_response
+ progress.update(task, advance=1)
+
+
+def remove_expired_responses(expire_after=None):
+ logger.setLevel('DEBUG')
+ session = CachedSession(CACHE_NAME)
+ total_responses = len(session.cache.responses)
+
+ start = time()
+ session.remove_expired_responses(expire_after=expire_after)
+ elapsed = time() - start
+ n_removed = total_responses - len(session.cache.responses)
+ logger.info(
+ f'Removed {n_removed} expired/invalid responses in {elapsed:.2f} seconds '
+ f'(avg {(elapsed / n_removed) * 1000:.2f}ms per response)'
+ )
+
+
+def main():
+ total_responses = len(HTTPBIN_FORMATS + HTTPBIN_EXTRA_ENDPOINTS + HTTPBIN_METHODS)
+ total_responses += N_RESPONSES + N_INVALID_RESPONSES
+
+ with Progress() as progress:
+ task = progress.add_task('[cyan]Generating responses...', total=total_responses)
+ populate_cache(progress, task)
+
+ actual_total_responses = len(CachedSession(CACHE_NAME).cache.responses)
+ cache_file_size = format_file_size(getsize(CACHE_NAME))
+ logger.info(f'Generated cache with {actual_total_responses} responses ({cache_file_size})')
+
+
+if __name__ == '__main__':
+ main()
+
+ # Remove some responses (with randomized expiration)
+ # remove_expired_responses()
+
+ # Expire and remove all responses
+ # remove_expired_responses(expire_after=1)
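
A brief, hypothetical follow-up once the script above has been run and rubbish_bin.sqlite exists (assumes the repo root is on sys.path so the example module is importable):

from examples.generate_test_db import remove_expired_responses

# Remove only the responses whose randomized expiration has already passed
remove_expired_responses()
# Or revalidate everything with a 1-second expiration, which effectively removes all responses
remove_expired_responses(expire_after=1)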
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 0ac4164..008209a 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -84,28 +84,28 @@ class BaseCache:
return default
def delete(self, key: str):
- """Delete `key` from cache. Also deletes all responses from response history"""
- self.delete_history(key)
+ """Delete a response or redirect from the cache, as well any associated redirect history"""
+ # If it's a response key, first delete any associated redirect history
+ try:
+ for r in self.responses[key].history:
+ del self.redirects[create_key(r.request, self.ignored_parameters)]
+ except (KeyError, *DESERIALIZE_ERRORS):
+ pass
+ # Then delete the response itself, or just the redirect if it's a redirect key
for cache in [self.responses, self.redirects]:
- # Skip `contains` checks to reduce # of service calls
try:
del cache[key]
except (AttributeError, KeyError):
pass
- def delete_history(self, key: str):
- """Delete redirect history associated with a response, if any"""
- try:
- response = self.responses[key] or self.responses[self.redirects[key]]
- for r in response.history:
- del self.redirects[create_key(r.request, self.ignored_parameters)]
- except Exception:
- pass
+ # TODO: Implement backend-specific bulk delete methods
+ def delete_all(self, keys):
+ """Remove multiple responses and their associated redirects from the cache"""
+ for key in keys:
+ self.delete(key)
def delete_url(self, url: str):
- """Delete response + redirects associated with `url` from cache.
- Works only for GET requests.
- """
+ """Delete a cached response + redirects for ``GET <url>``"""
self.delete(url_to_key(url, self.ignored_parameters))
def clear(self):
@@ -115,19 +115,29 @@ class BaseCache:
self.redirects.clear()
def remove_expired_responses(self, expire_after: ExpirationTime = None):
- """Remove expired responses from the cache, optionally with revalidation
+ """Remove expired and invalid responses from the cache, optionally with revalidation
Args:
expire_after: A new expiration time used to revalidate the cache
"""
logger.info('Removing expired responses.' + (f' Revalidating with: {expire_after}' if expire_after else ''))
- # _get_valid_responses must be consumed before making any additional writes
- for key, response in list(self._get_valid_responses(delete_invalid=True)):
+ keys_to_update = {}
+ keys_to_delete = []
+
+ for key, response in self._get_valid_responses(delete_invalid=True):
# If we're revalidating and it's not yet expired, update the cached item's expiration
if expire_after is not None and not response.revalidate(expire_after):
- self.responses[key] = response
+ keys_to_update[key] = response
if response.is_expired:
- self.delete(key)
+ keys_to_delete.append(key)
+
+ # Delay updates & deletes until the end, to avoid conflicts with _get_valid_responses()
+ logger.debug(f'Deleting {len(keys_to_delete)} expired responses')
+ self.delete_all(keys_to_delete)
+ if expire_after is not None:
+ logger.debug(f'Updating {len(keys_to_update)} revalidated responses')
+ for key, response in keys_to_update.items():
+ self.responses[key] = response
def remove_old_entries(self, *args, **kwargs):
msg = 'BaseCache.remove_old_entries() is deprecated; please use CachedSession.remove_expired_responses()'
@@ -165,11 +175,10 @@ class BaseCache:
for key in self.responses.keys():
try:
yield key, self.responses[key]
- except DESERIALIZE_ERRORS as e:
- logger.debug(f'Unable to deserialize response with key {key}: {e}')
+ except DESERIALIZE_ERRORS:
keys_to_delete.append(key)
- # Delay deletion until the end, to slightly improve responsiveness when used as a generator
+ # Delay deletion until the end, to improve responsiveness when used as a generator
if delete_invalid:
logger.debug(f'Deleting {len(keys_to_delete)} invalid responses')
for key in keys_to_delete:
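
The refactored remove_expired_responses() above scans first and writes afterwards, so deletes can be batched and never interleave with the key iteration. A generic sketch of that collect-then-write pattern (a plain dict-like store, not the library's internal classes):

def purge_expired(store, is_expired):
    # Pass 1: read-only scan; no writes happen while iterating over the store
    expired_keys = [key for key, item in store.items() if is_expired(item)]
    # Pass 2: delete in one batch, which a backend can turn into a single bulk operation
    for key in expired_keys:
        del store[key]
    return len(expired_keys)

# Example with an in-memory store and integer "time-to-live" values
cache = {'a': 0, 'b': 5, 'c': 0}
assert purge_expired(cache, is_expired=lambda ttl: ttl == 0) == 2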
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 5ad100a..b311247 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -6,7 +6,7 @@ from os import makedirs
from os.path import abspath, basename, dirname, expanduser, isabs, join
from pathlib import Path
from tempfile import gettempdir
-from typing import Type, Union
+from typing import Iterable, Tuple, Type, Union
from . import BaseCache, BaseStorage, get_valid_kwargs
@@ -37,10 +37,12 @@ class DbCache(BaseCache):
self.responses = DbPickleDict(db_path, table_name='responses', use_temp=use_temp, fast_save=fast_save, **kwargs)
self.redirects = DbDict(db_path, table_name='redirects', use_temp=use_temp, **kwargs)
- def remove_expired_responses(self, *args, **kwargs):
- """Remove expired responses from the cache, with additional cleanup"""
- super().remove_expired_responses(*args, **kwargs)
+ def delete_all(self, keys):
+ """Remove multiple responses and their associated redirects from the cache, with additional cleanup"""
+ self.responses.delete_all(keys=keys)
self.responses.vacuum()
+ self.redirects.delete_all(keys=keys)
+ self.redirects.delete_all(values=keys)
self.redirects.vacuum()
@@ -59,7 +61,7 @@ class DbDict(BaseStorage):
Args:
db_path: Database file path
table_name: Table name
- fast_save: Use `"PRAGMA synchronous = 0;" <http://www.sqlite.org/pragma.html#pragma_synchronous>`_
+ fast_save: Use `'PRAGMA synchronous = 0;' <http://www.sqlite.org/pragma.html#pragma_synchronous>`_
to speed up cache saving, but with the potential for data loss
timeout: Timeout for acquiring a database lock
"""
@@ -75,28 +77,33 @@ class DbDict(BaseStorage):
self._can_commit = True
self._local_context = threading.local()
with sqlite3.connect(self.db_path, **self.connection_kwargs) as con:
- con.execute("create table if not exists `%s` (key PRIMARY KEY, value)" % self.table_name)
+ self._create_table(con)
+
+ # Initial CREATE TABLE must happen in shared connection; subsequent queries will use thread-local connections
+ def _create_table(self, connection):
+ connection.execute(f'CREATE TABLE IF NOT EXISTS {self.table_name} (key PRIMARY KEY, value)')
@contextmanager
- def connection(self, commit_on_success=False):
- if not hasattr(self._local_context, "con"):
+ def connection(self, commit=False) -> sqlite3.Connection:
+ """Get a thread-local database connection"""
+ if not hasattr(self._local_context, 'con'):
logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs)
if self.fast_save:
- self._local_context.con.execute("PRAGMA synchronous = 0;")
+ self._local_context.con.execute('PRAGMA synchronous = 0;')
yield self._local_context.con
- if commit_on_success and self._can_commit:
+ if commit and self._can_commit:
self._local_context.con.commit()
def __del__(self):
- if hasattr(self._local_context, "con"):
+ if hasattr(self._local_context, 'con'):
self._local_context.con.close()
@contextmanager
def bulk_commit(self):
- """
- Context manager used to speedup insertion of big number of records
- ::
+ """Context manager used to speed up insertion of a large number of records
+
+ Example:
>>> d1 = DbDict('test')
>>> with d1.bulk_commit():
@@ -107,50 +114,62 @@ class DbDict(BaseStorage):
self._can_commit = False
try:
yield
- if hasattr(self._local_context, "con"):
+ if hasattr(self._local_context, 'con'):
self._local_context.con.commit()
finally:
self._can_commit = True
def __getitem__(self, key):
with self.connection() as con:
- row = con.execute("select value from `%s` where key=?" % self.table_name, (key,)).fetchone()
+ row = con.execute(f'SELECT value FROM {self.table_name} WHERE key=?', (key,)).fetchone()
# raise error after the with block, otherwise the connection will be locked
if not row:
raise KeyError
return row[0]
def __setitem__(self, key, item):
- with self.connection(True) as con:
+ with self.connection(commit=True) as con:
con.execute(
- "insert or replace into `%s` (key,value) values (?,?)" % self.table_name,
+ f'INSERT OR REPLACE INTO {self.table_name} (key,value) VALUES (?,?)',
(key, item),
)
def __delitem__(self, key):
- with self.connection(True) as con:
- cur = con.execute("delete from `%s` where key=?" % self.table_name, (key,))
+ with self.connection(commit=True) as con:
+ cur = con.execute(f'DELETE FROM {self.table_name} WHERE key=?', (key,))
if not cur.rowcount:
raise KeyError
def __iter__(self):
with self.connection() as con:
- for row in con.execute("select key from `%s`" % self.table_name):
+ for row in con.execute(f'SELECT key FROM {self.table_name}'):
yield row[0]
def __len__(self):
with self.connection() as con:
- return con.execute("select count(key) from `%s`" % self.table_name).fetchone()[0]
+ return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0]
def clear(self):
- with self.connection(True) as con:
- con.execute("drop table if exists `%s`" % self.table_name)
- con.execute("create table `%s` (key PRIMARY KEY, value)" % self.table_name)
- con.execute("vacuum")
+ with self.connection(commit=True) as con:
+ con.execute(f'DROP TABLE IF EXISTS {self.table_name}')
+ self._create_table(con)
+ con.execute('VACUUM')
+
+ def delete_all(self, keys=None, values=None):
+ """Delete multiple records, either by keys or values"""
+ if not keys and not values:
+ return
+
+ column = 'key' if keys else 'value'
+ marks, args = _format_sequence(keys or values)
+ statement = f'DELETE FROM {self.table_name} WHERE {column} IN ({marks})'
+
+ with self.connection(commit=True) as con:
+ con.execute(statement, args)
def vacuum(self):
- with self.connection(True) as con:
- con.execute("vacuum")
+ with self.connection(commit=True) as con:
+ con.execute('VACUUM')
class DbPickleDict(DbDict):
@@ -163,6 +182,13 @@ class DbPickleDict(DbDict):
return self.deserialize(super().__getitem__(key))
+def _format_sequence(values) -> Tuple[str, Iterable]:
+ """Get SQL parameter marks for a sequence-based query, and ensure value is a sequence"""
+ if not isinstance(values, Iterable):
+ values = [values]
+ return ','.join(['?'] * len(values)), values
+
+
def _get_db_path(db_path: Union[Path, str], use_temp: bool) -> str:
"""Get resolved path for database file"""
# Save to a temp directory, if specified
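
DbDict.delete_all() above builds a single parameterized DELETE with one '?' placeholder per key. A self-contained sketch of the same pattern against a bare sqlite3 table (table name and data are illustrative); note that very large key lists may eventually hit SQLite's host-parameter limit, so huge batches would need chunking:

import sqlite3

def delete_many(con, table, keys):
    # One placeholder per key; values are bound, never interpolated into the SQL string
    marks = ','.join(['?'] * len(keys))
    con.execute(f'DELETE FROM {table} WHERE key IN ({marks})', list(keys))
    con.commit()

con = sqlite3.connect(':memory:')
con.execute('CREATE TABLE responses (key PRIMARY KEY, value)')
con.executemany('INSERT INTO responses VALUES (?, ?)', [(f'key_{i}', b'data') for i in range(5)])
delete_many(con, 'responses', ['key_1', 'key_3'])
assert con.execute('SELECT COUNT(*) FROM responses').fetchone()[0] == 3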
diff --git a/requests_cache/response.py b/requests_cache/response.py
index 9b09474..0c051bd 100755
--- a/requests_cache/response.py
+++ b/requests_cache/response.py
@@ -85,7 +85,6 @@ class CachedResponse(Response):
def _get_expiration_datetime(self, expire_after: ExpirationTime) -> Optional[datetime]:
"""Convert a time value or delta to an absolute datetime, if it's not already"""
- logger.debug(f'Determining expiration time based on: {expire_after}')
if expire_after is None or expire_after == -1:
return None
elif isinstance(expire_after, datetime):
diff --git a/tests/conftest.py b/tests/conftest.py
index 7f155d4..12cca1d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -131,6 +131,13 @@ def tempfile_session() -> CachedSession:
@pytest.fixture(scope='function')
+def tempfile_path() -> str:
+ """Get a tempfile path that will be deleted after the test finishes"""
+ with NamedTemporaryFile(suffix='.db') as temp:
+ yield temp.name
+
+
+@pytest.fixture(scope='function')
def installed_session() -> CachedSession:
"""Get a CachedSession using a temporary SQLite db, with global patching.
Installs cache before test and uninstalls after.
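
A minimal sketch of how a test might consume the new tempfile_path fixture (the test name is hypothetical; the session usage mirrors test_cache_signing further below):

from requests_cache import CachedSession

def test_with_tempfile_cache(tempfile_path):
    # tempfile_path is cleaned up automatically when the fixture's context manager exits
    session = CachedSession(tempfile_path)
    session.cache.responses['key'] = 'value'
    assert session.cache.responses['key'] == 'value'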
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 7745cc4..6097488 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -2,7 +2,7 @@
import json
import pytest
from threading import Thread
-from time import time
+from time import sleep, time
from typing import Dict, Type
from requests_cache import ALL_METHODS, CachedResponse, CachedSession
@@ -64,7 +64,7 @@ class BaseCacheTest:
def test_redirects(self, endpoint, n_redirects):
"""Test all types of redirect endpoints with different numbers of consecutive redirects"""
session = self.init_session()
- session.get(httpbin(f'redirect/{n_redirects}'))
+ session.get(httpbin(f'{endpoint}/{n_redirects}'))
r2 = session.get(httpbin('get'))
assert r2.from_cache is True
@@ -107,6 +107,25 @@ class BaseCacheTest:
assert b'gzipped' in res.content
assert b'gzipped' in res.raw.read(None, decode_content=True)
+ def test_remove_expired_responses(self):
+ session = self.init_session(expire_after=0.01)
+
+ # Populate the cache with several responses that should expire immediately
+ for response_format in HTTPBIN_FORMATS:
+ session.get(httpbin(response_format))
+ session.get(httpbin('redirect/1'))
+ sleep(0.01)
+
+ # Cache a response + redirects, which should be the only non-expired cache items
+ session.get(httpbin('get'), expire_after=-1)
+ session.get(httpbin('redirect/3'), expire_after=-1)
+ session.cache.remove_expired_responses()
+
+ assert len(session.cache.responses.keys()) == 2
+ assert len(session.cache.redirects.keys()) == 3
+ assert not session.cache.has_url(httpbin('redirect/1'))
+ assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS])
+
@pytest.mark.parametrize('iteration', range(N_ITERATIONS))
def test_multithreaded(self, iteration):
"""Run a multi-threaded stress test for each backend"""
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 6efba0b..703a9d1 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -14,7 +14,7 @@ class SQLiteTestCase(BaseStorageTest):
@classmethod
def teardown_class(cls):
try:
- os.unlink(CACHE_NAME)
+ os.unlink(f'{CACHE_NAME}.sqlite')
except Exception:
pass
@@ -79,7 +79,7 @@ class SQLiteTestCase(BaseStorageTest):
@patch('requests_cache.backends.sqlite.sqlite3')
def test_connection_kwargs(self, mock_sqlite):
"""A spot check to make sure optional connection kwargs gets passed to connection"""
- cache = self.storage_class('test', timeout=0.5, invalid_kwarg='???')
+ cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???')
mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5)
diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py
index f429b12..2b94033 100644
--- a/tests/unit/test_cache.py
+++ b/tests/unit/test_cache.py
@@ -123,16 +123,15 @@ def test_response_history(mock_session):
assert len(mock_session.cache.redirects) == 1
-def test_repr():
+def test_repr(mock_session):
"""Test session and cache string representations"""
- cache_name = 'requests_cache_test'
- session = CachedSession(cache_name=cache_name, backend='memory', expire_after=10)
- session.cache.responses['key'] = 'value'
- session.cache.redirects['key'] = 'value'
- session.cache.redirects['key_2'] = 'value'
+ mock_session.expire_after = 10.5
+ mock_session.cache.responses['key'] = 'value'
+ mock_session.cache.redirects['key'] = 'value'
+ mock_session.cache.redirects['key_2'] = 'value'
- assert cache_name in repr(session) and '10' in repr(session)
- assert 'redirects: 2' in str(session.cache) and 'responses: 1' in str(session.cache)
+ assert mock_session._cache_name in repr(mock_session) and '10.5' in repr(mock_session)
+ assert 'redirects: 2' in str(mock_session.cache) and 'responses: 1' in str(mock_session.cache)
def test_urls(mock_session):
@@ -330,7 +329,7 @@ def test_response_defaults(mock_session):
assert response_2.is_expired is response_3.is_expired is False
-def testinclude_get_headers(mock_session):
+def test_include_get_headers(mock_session):
"""With include_get_headers, requests with different headers should have different cache keys"""
mock_session.cache.include_get_headers = True
headers_list = [{'Accept': 'text/json'}, {'Accept': 'text/xml'}, {'Accept': 'custom'}, None]
@@ -339,7 +338,7 @@ def testinclude_get_headers(mock_session):
assert mock_session.get(MOCKED_URL, headers=headers).from_cache is True
-def testinclude_get_headers_normalize(mock_session):
+def test_include_get_headers_normalize(mock_session):
"""With include_get_headers, the same headers (in any order) should have the same cache key"""
mock_session.cache.include_get_headers = True
headers = {'Accept': 'text/json', 'Custom': 'abc'}
@@ -432,16 +431,14 @@ def test_cache_disabled__nested(mock_session):
('some_other_site.com', None),
],
)
-def test_urls_expire_after(url, expected_expire_after):
- session = CachedSession(
- urls_expire_after={
- '*.site_1.com': 60 * 60,
- 'site_2.com/resource_1': 60 * 60 * 2,
- 'site_2.com/resource_2': 60 * 60 * 24,
- 'site_2.com/static': -1,
- },
- )
- assert session._url_expire_after(url) == expected_expire_after
+def test_urls_expire_after(url, expected_expire_after, mock_session):
+ mock_session.urls_expire_after = {
+ '*.site_1.com': 60 * 60,
+ 'site_2.com/resource_1': 60 * 60 * 2,
+ 'site_2.com/resource_2': 60 * 60 * 24,
+ 'site_2.com/static': -1,
+ }
+ assert mock_session._url_expire_after(url) == expected_expire_after
@pytest.mark.parametrize(
@@ -453,26 +450,25 @@ def test_urls_expire_after(url, expected_expire_after):
('https://any_other_site.com', 1),
],
)
-def test_urls_expire_after__evaluation_order(url, expected_expire_after):
+def test_urls_expire_after__evaluation_order(url, expected_expire_after, mock_session):
"""If there are multiple matches, the first match should be used in the order defined"""
- session = CachedSession(
- urls_expire_after={
- '*.site_1.com/resource': 60 * 60 * 2,
- '*.site_1.com': 60 * 60,
- '*': 1,
- },
- )
- assert session._url_expire_after(url) == expected_expire_after
-
-
-def test_get_expiration_precedence():
- session = CachedSession(expire_after=1, urls_expire_after={'*.site_1.com': 60 * 60})
- assert session._get_expiration() == 1
- assert session._get_expiration('site_2.com') == 1
- assert session._get_expiration('img.site_1.com/image.jpg') == 60 * 60
- with session.request_expire_after(30):
- assert session._get_expiration() == 30
- assert session._get_expiration('img.site_1.com/image.jpg') == 30
+ mock_session.urls_expire_after = {
+ '*.site_1.com/resource': 60 * 60 * 2,
+ '*.site_1.com': 60 * 60,
+ '*': 1,
+ }
+ assert mock_session._url_expire_after(url) == expected_expire_after
+
+
+def test_get_expiration_precedence(mock_session):
+ mock_session.expire_after = 1
+ mock_session.urls_expire_after = {'*.site_1.com': 60 * 60}
+ assert mock_session._get_expiration() == 1
+ assert mock_session._get_expiration('site_2.com') == 1
+ assert mock_session._get_expiration('img.site_1.com/image.jpg') == 60 * 60
+ with mock_session.request_expire_after(30):
+ assert mock_session._get_expiration() == 30
+ assert mock_session._get_expiration('img.site_1.com/image.jpg') == 30
def test_remove_expired_responses(mock_session):
@@ -591,14 +587,13 @@ def test_unpickle_errors(mock_session):
assert resp.json()['message'] == 'mock json response'
-def test_cache_signing():
- # Without a secret key, plain pickle should be used
- session = CachedSession()
+def test_cache_signing(tempfile_path):
+ session = CachedSession(tempfile_path)
assert session.cache.responses._serializer == pickle
# With a secret key, itsdangerous should be used
secret_key = str(uuid4())
- session = CachedSession(secret_key=secret_key)
+ session = CachedSession(tempfile_path, secret_key=secret_key)
assert isinstance(session.cache.responses._serializer, Serializer)
# Simple serialize/deserialize round trip
@@ -606,6 +601,6 @@ def test_cache_signing():
assert session.cache.responses['key'] == 'value'
# Without the same signing key, the item shouldn't be considered safe to deserialize
- session = CachedSession(secret_key='a different key')
+ session = CachedSession(tempfile_path, secret_key='a different key')
with pytest.raises(BadSignature):
session.cache.responses['key']
diff --git a/tests/unit/test_patcher.py b/tests/unit/test_patcher.py
index 4a3e417..fb1c25a 100644
--- a/tests/unit/test_patcher.py
+++ b/tests/unit/test_patcher.py
@@ -6,15 +6,12 @@ from requests.sessions import Session as OriginalSession
import requests_cache
from requests_cache import CachedSession
from requests_cache.backends import BaseCache
-
-CACHE_NAME = 'requests_cache_test'
-CACHE_BACKEND = 'sqlite'
-FAST_SAVE = False
+from tests.conftest import CACHE_NAME
def test_install_uninstall():
for _ in range(2):
- requests_cache.install_cache(name=CACHE_NAME, backend=CACHE_BACKEND)
+ requests_cache.install_cache(name=CACHE_NAME, use_temp=True)
assert isinstance(requests.Session(), CachedSession)
assert isinstance(requests.sessions.Session(), CachedSession)
requests_cache.uninstall_cache()
@@ -61,8 +58,8 @@ def test_disabled(cached_request, original_request, installed_session):
@patch.object(OriginalSession, 'request')
@patch.object(CachedSession, 'request')
-def test_enabled(cached_request, original_request):
- with requests_cache.enabled():
+def test_enabled(cached_request, original_request, tempfile_path):
+ with requests_cache.enabled(tempfile_path):
for i in range(3):
requests.get('some_url')
assert cached_request.call_count == 3
@@ -70,8 +67,8 @@ def test_enabled(cached_request, original_request):
@patch.object(BaseCache, 'remove_expired_responses')
-def test_remove_expired_responses(remove_expired_responses):
- requests_cache.install_cache(expire_after=360)
+def test_remove_expired_responses(remove_expired_responses, tempfile_path):
+ requests_cache.install_cache(tempfile_path, expire_after=360)
requests_cache.remove_expired_responses()
assert remove_expired_responses.called is True
requests_cache.uninstall_cache()