diff options
author | Jordan Cook <jordan.cook.git@proton.me> | 2022-10-28 17:12:44 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook.git@proton.me> | 2022-10-28 17:38:28 -0500 |
commit | b9dbc008be7274b6cc0d45078b3334249df1a530 (patch) | |
tree | 5024cc33e50abd4786a3bd73eb16bb32f9cb5de6 | |
parent | b4e270f6b129b75003d754987acea9e2d8670a5e (diff) | |
parent | cfe381fdae57d1e2f93d5f2732d3ed57fce05b66 (diff) | |
download | requests-cache-b9dbc008be7274b6cc0d45078b3334249df1a530.tar.gz |
Merge pull request #721 from requests-cache/sqlite-improvements
Sqlite improvements
-rw-r--r-- | HISTORY.md | 1 | ||||
-rw-r--r-- | requests_cache/backends/base.py | 31 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 79 | ||||
-rw-r--r-- | tests/integration/base_cache_test.py | 4 | ||||
-rw-r--r-- | tests/integration/test_sqlite.py | 56 | ||||
-rw-r--r-- | tests/unit/test_base_cache.py | 51 | ||||
-rw-r--r-- | tests/unit/test_session.py | 7 |
7 files changed, 165 insertions, 64 deletions
@@ -51,6 +51,7 @@ * Add `ttl_offset` argument to add a delay between cache expiration and deletion * **SQLite**: * Improve performance for removing expired responses with `delete()` + * Add `count()` method to count responses, with option to exclude expired responses (performs a fast indexed count instead of slower in-memory filtering) * Add `size()` method to get estimated size of the database (including in-memory databases) * Add `sorted()` method with sorting and other query options * Add `wal` parameter to enable write-ahead logging diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index bc57ca1..efbd214 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -70,11 +70,7 @@ class BaseCache: response = self.responses[self.redirects[key]] response.cache_key = key return response - except KeyError: - return default - except DESERIALIZE_ERRORS as e: - logger.error(f'Unable to deserialize response {key}: {str(e)}') - logger.debug(e, exc_info=True) + except (AttributeError, KeyError): return default def save_response(self, response: Response, cache_key: str = None, expires: datetime = None): @@ -296,8 +292,11 @@ class BaseCache: DeprecationWarning, ) yield from self.redirects.keys() - for response in self.filter(expired=not check_expiry): - yield response.cache_key + if not check_expiry: + yield from self.responses.keys() + else: + for response in self.filter(expired=False): + yield response.cache_key def response_count(self, check_expiry: bool = False) -> int: warn( @@ -394,16 +393,28 @@ class BaseStorage(MutableMapping[KT, VT], ABC): """Close any open backend connections""" def serialize(self, value: VT): - """Serialize value, if a serializer is available""" + """Serialize a value, if a serializer is available""" if TYPE_CHECKING: assert hasattr(self.serializer, 'dumps') return self.serializer.dumps(value) if self.serializer else value def deserialize(self, value: VT): - """Deserialize value, if a serializer is available""" + """Deserialize a value, if a serializer is available. + + If deserialization fails (usually due to a value saved in an older requests-cache version), + ``None`` will be returned. + """ + if not self.serializer: + return value if TYPE_CHECKING: assert hasattr(self.serializer, 'loads') - return self.serializer.loads(value) if self.serializer else value + + try: + return self.serializer.loads(value) + except DESERIALIZE_ERRORS as e: + logger.error(f'Unable to deserialize response: {str(e)}') + logger.debug(e, exc_info=True) + return None def __str__(self): return str(list(self.keys())) diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 502c0dd..f936a77 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,7 +17,9 @@ from typing import Collection, Iterator, List, Tuple, Type, Union from platformdirs import user_cache_dir +from requests_cache.backends.base import VT from requests_cache.models.response import CachedResponse +from requests_cache.policy import ExpirationTime from .._utils import chunkify, get_valid_kwargs from . import BaseCache, BaseStorage @@ -55,8 +57,8 @@ class SQLiteCache(BaseCache): return self.responses.db_path def clear(self): - """Clear the cache. If this fails due to a corrupted cache or other I/O error, this will - attempt to delete the cache file and re-initialize. + """Delete all items from the cache. If this fails due to a corrupted cache or other I/O + error, this will attempt to delete the cache file and re-initialize. """ try: super().clear() @@ -67,13 +69,13 @@ class SQLiteCache(BaseCache): self.responses.init_db() self.redirects.init_db() + # A more efficient SQLite implementation of :py:meth:`BaseCache.delete` def delete( self, *keys: str, expired: bool = False, **kwargs, ): - """A more efficient SQLite implementation of :py:meth:`BaseCache.delete`""" if keys: self.responses.bulk_delete(keys) if expired: @@ -109,19 +111,31 @@ class SQLiteCache(BaseCache): ')' ) - def filter( # type: ignore - self, valid: bool = True, expired: bool = True, **kwargs - ) -> Iterator[CachedResponse]: - """A more efficient implementation of :py:meth:`BaseCache.filter`, in the case where we want - to get **only** expired responses + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired + + Args: + expired: Set to ``False`` to count only unexpired responses """ - if expired and not valid and not kwargs: - return self.responses.sorted(expired=True) + return self.responses.count(expired=expired) + + # A more efficient implementation of :py:meth:`BaseCache.filter` to make use of indexes + def filter( + self, + valid: bool = True, + expired: bool = True, + invalid: bool = False, + older_than: ExpirationTime = None, + ) -> Iterator[CachedResponse]: + if valid and not invalid: + return self.responses.sorted(expired=expired) else: - return super().filter(valid, expired, **kwargs) + return super().filter( + valid=valid, expired=expired, invalid=invalid, older_than=older_than + ) + # A more efficient implementation of :py:meth:`BaseCache.recreate_keys def recreate_keys(self): - """A more efficient implementation of :py:meth:`BaseCache.recreate_keys`""" with self.responses.bulk_commit(): super().recreate_keys() @@ -162,6 +176,8 @@ class SQLiteDict(BaseStorage): self._local_context = threading.local() self._lock = threading.RLock() self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) + if use_memory: + self.connection_kwargs['uri'] = True self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory) self.fast_save = fast_save self.table_name = table_name @@ -172,9 +188,9 @@ class SQLiteDict(BaseStorage): """Initialize the database, if it hasn't already been""" self.close() with self._lock, self.connection() as con: - # Add new column to tables created before 0.10 + # Add new column to tables created before 1.0 try: - con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT') + con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER') except sqlite3.OperationalError: pass @@ -247,10 +263,8 @@ class SQLiteDict(BaseStorage): def __setitem__(self, key, value): # If available, set expiration as a timestamp in unix format - expires = value.expires_unix if getattr(value, 'expires_unix', None) else None + expires = getattr(value, 'expires_unix', None) value = self.serialize(value) - if isinstance(value, bytes): - value = sqlite3.Binary(value) with self.connection(commit=True) as con: con.execute( f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)', @@ -263,8 +277,7 @@ class SQLiteDict(BaseStorage): yield row[0] def __len__(self): - with self.connection() as con: - return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0] + return self.count() def bulk_delete(self, keys=None, values=None): """Delete multiple items from the cache, without raising errors for any missing items. @@ -288,6 +301,22 @@ class SQLiteDict(BaseStorage): self.init_db() self.vacuum() + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired""" + filter_expr = '' + params: Tuple = () + if not expired: + filter_expr = 'WHERE expires is null or expires > ?' + params = (time(),) + query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}' + + with self.connection() as con: + return con.execute(query, params).fetchone()[0] + + def serialize(self, value: VT): + value = super().serialize(value) + return sqlite3.Binary(value) if isinstance(value, bytes) else value + def size(self) -> int: """Return the size of the database, in bytes. For an in-memory database, this will be an estimate based on page size. @@ -325,11 +354,19 @@ class SQLiteDict(BaseStorage): with self.connection(commit=True) as con: for row in con.execute( - f'SELECT value FROM {self.table_name} {filter_expr}' + f'SELECT key, value FROM {self.table_name} {filter_expr}' f' ORDER BY {key} {direction} {limit_expr}', params, ): - yield self.deserialize(row[0]) + result = self.deserialize(row[1]) + # Set cache key, if it's a response object + try: + result.cache_key = row[0] + except AttributeError: + pass + # Omit any results that can't be deserialized + if result: + yield result def vacuum(self): with self.connection(commit=True) as con: diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py index 4a55504..a7d2412 100644 --- a/tests/integration/base_cache_test.py +++ b/tests/integration/base_cache_test.py @@ -16,7 +16,7 @@ import requests from requests import PreparedRequest, Session from requests_cache import ALL_METHODS, CachedResponse, CachedSession -from requests_cache.backends.base import BaseCache +from requests_cache.backends import BaseCache from tests.conftest import ( CACHE_NAME, ETAG, @@ -99,7 +99,7 @@ class BaseCacheTest: # Patch storage class to track number of times getitem is called, without changing behavior with patch.object( - storage_class, '__getitem__', side_effect=storage_class.__getitem__ + storage_class, '__getitem__', side_effect=lambda k: CachedResponse() ) as getitem: session.get(httpbin('get')) assert getitem.call_count == 1 diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index 580948b..06b17cb 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -1,4 +1,5 @@ import os +import pickle from datetime import datetime, timedelta from os.path import join from tempfile import NamedTemporaryFile, gettempdir @@ -8,9 +9,9 @@ from unittest.mock import patch import pytest from platformdirs import user_cache_dir -from requests_cache.backends.base import BaseCache -from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict -from requests_cache.models.response import CachedResponse +from requests_cache.backends import BaseCache, SQLiteCache, SQLiteDict +from requests_cache.backends.sqlite import MEMORY_URI +from requests_cache.models import CachedResponse from tests.integration.base_cache_test import BaseCacheTest from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest @@ -60,7 +61,7 @@ class TestSQLiteDict(BaseStorageTest): assert len(cache) == 0 def test_use_memory__uri(self): - self.init_cache(':memory:').db_path == ':memory:' + assert self.init_cache(':memory:').db_path == ':memory:' def test_non_dir_parent_exists(self): """Expect a custom error message if a parent path already exists but isn't a directory""" @@ -219,6 +220,30 @@ class TestSQLiteDict(BaseStorageTest): assert prev_item is None or prev_item.expires < item.expires assert item.status_code % 2 == 0 + def test_sorted__error(self): + """sorted() should handle deserialization errors and not return invalid responses""" + + class BadSerializer: + def loads(self, value): + response = pickle.loads(value) + if response.cache_key == 'key_42': + raise pickle.PickleError() + return response + + def dumps(self, value): + return pickle.dumps(value) + + cache = self.init_cache(serializer=BadSerializer()) + + for i in range(100): + response = CachedResponse(status_code=i) + response.cache_key = f'key_{i}' + cache[f'key_{i}'] = response + + # Items should only include unexpired (even numbered) items, and still be in sorted order + items = list(cache.sorted()) + assert len(items) == 99 + @pytest.mark.parametrize( 'db_path, use_temp', [ @@ -272,12 +297,26 @@ class TestSQLiteCache(BaseCacheTest): session = self.init_session() assert session.cache.db_path == session.cache.responses.db_path + def test_count(self): + """count() should work the same as len(), but with the option to exclude expired responses""" + session = self.init_session() + now = datetime.utcnow() + session.cache.responses['key_1'] = CachedResponse(expires=now + timedelta(1)) + session.cache.responses['key_2'] = CachedResponse(expires=now - timedelta(1)) + + assert session.cache.count() == 2 + assert session.cache.count(expired=False) == 1 + @patch.object(SQLiteDict, 'sorted') - def test_filter__expired_only(self, mock_sorted): - """Filtering by expired only should use a more efficient SQL query""" + def test_filter__expired(self, mock_sorted): + """Filtering by expired should use a more efficient SQL query""" session = self.init_session() - session.cache.filter(valid=False, expired=True) - mock_sorted.assert_called_once_with(expired=True) + + session.cache.filter() + mock_sorted.assert_called_with(expired=True) + + session.cache.filter(expired=False) + mock_sorted.assert_called_with(expired=False) def test_sorted(self): """Test wrapper method for SQLiteDict.sorted(), with all arguments combined""" @@ -300,4 +339,5 @@ class TestSQLiteCache(BaseCacheTest): prev_item = None for i, item in enumerate(items): assert prev_item is None or prev_item.expires < item.expires + assert item.cache_key assert not item.is_expired diff --git a/tests/unit/test_base_cache.py b/tests/unit/test_base_cache.py index e0b6b0e..b57ed09 100644 --- a/tests/unit/test_base_cache.py +++ b/tests/unit/test_base_cache.py @@ -1,4 +1,5 @@ """BaseCache tests that use mocked responses only""" +import pickle from datetime import datetime, timedelta from logging import getLogger from pickle import PickleError @@ -8,9 +9,10 @@ from unittest.mock import patch import pytest from requests import Request -from requests_cache.backends import BaseCache, SQLiteDict +from requests_cache.backends import BaseCache, SQLiteCache, SQLiteDict from requests_cache.cache_keys import create_key from requests_cache.models import CachedRequest, CachedResponse +from requests_cache.session import CachedSession from tests.conftest import ( MOCKED_URL, MOCKED_URL_ETAG, @@ -18,6 +20,7 @@ from tests.conftest import ( MOCKED_URL_JSON, MOCKED_URL_REDIRECT, ignore_deprecation, + mount_mock_adapter, patch_normalize_url, ) @@ -109,19 +112,28 @@ def test_delete__expired__per_request(mock_session): assert len(mock_session.cache.responses) == 1 -def test_delete__invalid(mock_session): +def test_delete__invalid(tempfile_path): + class BadSerialzier: + def dumps(self, value): + return pickle.dumps(value) + + def loads(self, value): + response = pickle.loads(value) + if response.url.endswith('/json'): + raise PickleError + return response + + mock_session = CachedSession( + cache_name=tempfile_path, backend='sqlite', serializer=BadSerialzier() + ) + mock_session = mount_mock_adapter(mock_session) + # Start with two cached responses, one of which will raise an error response_1 = mock_session.get(MOCKED_URL) response_2 = mock_session.get(MOCKED_URL_JSON) - def error_on_key(key): - if key == response_2.cache_key: - raise PickleError - return CachedResponse.from_response(response_1) - # Use the generic BaseCache implementation, not the SQLite-specific one - with patch.object(SQLiteDict, '__getitem__', side_effect=error_on_key): - BaseCache.delete(mock_session.cache, expired=True, invalid=True) + BaseCache.delete(mock_session.cache, expired=True, invalid=True) assert len(mock_session.cache.responses) == 1 assert mock_session.get(MOCKED_URL).from_cache is True @@ -220,7 +232,7 @@ def test_recreate_keys__same_key_fn(mock_session): def test_reset_expiration__extend_expiration(mock_session): # Start with an expired response - mock_session.settings.expire_after = datetime.utcnow() - timedelta(seconds=0.01) + mock_session.settings.expire_after = datetime.utcnow() - timedelta(seconds=1) mock_session.get(MOCKED_URL) # Set expiration in the future @@ -236,7 +248,7 @@ def test_reset_expiration__shorten_expiration(mock_session): mock_session.get(MOCKED_URL) # Set expiration in the past - mock_session.cache.reset_expiration(datetime.utcnow() - timedelta(seconds=0.01)) + mock_session.cache.reset_expiration(datetime.utcnow() - timedelta(seconds=1)) response = mock_session.get(MOCKED_URL) assert response.is_expired is False and response.from_cache is False @@ -277,8 +289,8 @@ def test_urls(mock_normalize_url, mock_session): def test_urls__error(mock_session): responses = [mock_session.get(url) for url in [MOCKED_URL, MOCKED_URL_JSON, MOCKED_URL_HTTPS]] - responses[2] = AttributeError - with patch.object(SQLiteDict, '__getitem__', side_effect=responses): + responses[2] = None + with patch.object(SQLiteDict, 'deserialize', side_effect=responses): expected_urls = [MOCKED_URL_JSON, MOCKED_URL] assert mock_session.cache.urls() == expected_urls @@ -346,6 +358,7 @@ def test_keys(mock_session): response_keys = set(mock_session.cache.responses.keys()) redirect_keys = set(mock_session.cache.redirects.keys()) assert set(mock_session.cache.keys()) == response_keys | redirect_keys + assert len(list(mock_session.cache.keys(check_expiry=True))) == 5 def test_remove_expired_responses(mock_session): @@ -385,16 +398,14 @@ def test_values(mock_session): assert all([isinstance(response, CachedResponse) for response in responses]) -@pytest.mark.parametrize('check_expiry, expected_count', [(True, 1), (False, 2)]) -def test_values__with_invalid_responses(check_expiry, expected_count, mock_session): - """values() should always exclude invalid responses, and optionally exclude expired responses""" +def test_values__with_invalid_responses(mock_session): responses = [mock_session.get(url) for url in [MOCKED_URL, MOCKED_URL_JSON, MOCKED_URL_HTTPS]] - responses[1] = AttributeError + responses[1] = None responses[2] = CachedResponse(expires=YESTERDAY, url='test') - with ignore_deprecation(), patch.object(SQLiteDict, '__getitem__', side_effect=responses): - values = mock_session.cache.values(check_expiry=check_expiry) - assert len(list(values)) == expected_count + with ignore_deprecation(), patch.object(SQLiteCache, 'filter', side_effect=responses): + values = mock_session.cache.values(check_expiry=True) + assert len(list(values)) == 1 # The invalid response should be skipped, but remain in the cache for now assert len(mock_session.cache.responses.keys()) == 3 diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 60e8bb4..e7495cc 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -17,7 +17,7 @@ from requests.structures import CaseInsensitiveDict from requests_cache import ALL_METHODS, CachedSession from requests_cache._utils import get_placeholder_class -from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict +from requests_cache.backends import BACKEND_CLASSES, BaseCache from requests_cache.backends.base import DESERIALIZE_ERRORS from requests_cache.policy.expiration import DO_NOT_CACHE, EXPIRE_IMMEDIATELY, NEVER_EXPIRE from tests.conftest import ( @@ -351,7 +351,8 @@ def test_include_get_headers(): def test_cache_error(exception_cls, mock_session): """If there is an error while fetching a cached response, a new one should be fetched""" mock_session.get(MOCKED_URL) - with patch.object(SQLiteDict, '__getitem__', side_effect=exception_cls): + + with patch.object(mock_session.cache.responses.serializer, 'loads', side_effect=exception_cls): assert mock_session.get(MOCKED_URL).from_cache is False @@ -440,7 +441,7 @@ def test_unpickle_errors(mock_session): """If there is an error during deserialization, the request should be made again""" assert mock_session.get(MOCKED_URL_JSON).from_cache is False - with patch.object(SQLiteDict, '__getitem__', side_effect=PickleError): + with patch.object(mock_session.cache.responses.serializer, 'loads', side_effect=PickleError): resp = mock_session.get(MOCKED_URL_JSON) assert resp.from_cache is False assert resp.json()['message'] == 'mock json response' |