summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2022-10-28 17:12:44 -0500
committerJordan Cook <jordan.cook.git@proton.me>2022-10-28 17:38:28 -0500
commitb9dbc008be7274b6cc0d45078b3334249df1a530 (patch)
tree5024cc33e50abd4786a3bd73eb16bb32f9cb5de6
parentb4e270f6b129b75003d754987acea9e2d8670a5e (diff)
parentcfe381fdae57d1e2f93d5f2732d3ed57fce05b66 (diff)
downloadrequests-cache-b9dbc008be7274b6cc0d45078b3334249df1a530.tar.gz
Merge pull request #721 from requests-cache/sqlite-improvements
Sqlite improvements
-rw-r--r--HISTORY.md1
-rw-r--r--requests_cache/backends/base.py31
-rw-r--r--requests_cache/backends/sqlite.py79
-rw-r--r--tests/integration/base_cache_test.py4
-rw-r--r--tests/integration/test_sqlite.py56
-rw-r--r--tests/unit/test_base_cache.py51
-rw-r--r--tests/unit/test_session.py7
7 files changed, 165 insertions, 64 deletions
diff --git a/HISTORY.md b/HISTORY.md
index f602e7a..040c0f1 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -51,6 +51,7 @@
* Add `ttl_offset` argument to add a delay between cache expiration and deletion
* **SQLite**:
* Improve performance for removing expired responses with `delete()`
+ * Add `count()` method to count responses, with option to exclude expired responses (performs a fast indexed count instead of slower in-memory filtering)
* Add `size()` method to get estimated size of the database (including in-memory databases)
* Add `sorted()` method with sorting and other query options
* Add `wal` parameter to enable write-ahead logging
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index bc57ca1..efbd214 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -70,11 +70,7 @@ class BaseCache:
response = self.responses[self.redirects[key]]
response.cache_key = key
return response
- except KeyError:
- return default
- except DESERIALIZE_ERRORS as e:
- logger.error(f'Unable to deserialize response {key}: {str(e)}')
- logger.debug(e, exc_info=True)
+ except (AttributeError, KeyError):
return default
def save_response(self, response: Response, cache_key: str = None, expires: datetime = None):
@@ -296,8 +292,11 @@ class BaseCache:
DeprecationWarning,
)
yield from self.redirects.keys()
- for response in self.filter(expired=not check_expiry):
- yield response.cache_key
+ if not check_expiry:
+ yield from self.responses.keys()
+ else:
+ for response in self.filter(expired=False):
+ yield response.cache_key
def response_count(self, check_expiry: bool = False) -> int:
warn(
@@ -394,16 +393,28 @@ class BaseStorage(MutableMapping[KT, VT], ABC):
"""Close any open backend connections"""
def serialize(self, value: VT):
- """Serialize value, if a serializer is available"""
+ """Serialize a value, if a serializer is available"""
if TYPE_CHECKING:
assert hasattr(self.serializer, 'dumps')
return self.serializer.dumps(value) if self.serializer else value
def deserialize(self, value: VT):
- """Deserialize value, if a serializer is available"""
+ """Deserialize a value, if a serializer is available.
+
+ If deserialization fails (usually due to a value saved in an older requests-cache version),
+ ``None`` will be returned.
+ """
+ if not self.serializer:
+ return value
if TYPE_CHECKING:
assert hasattr(self.serializer, 'loads')
- return self.serializer.loads(value) if self.serializer else value
+
+ try:
+ return self.serializer.loads(value)
+ except DESERIALIZE_ERRORS as e:
+ logger.error(f'Unable to deserialize response: {str(e)}')
+ logger.debug(e, exc_info=True)
+ return None
def __str__(self):
return str(list(self.keys()))
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 502c0dd..f936a77 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -17,7 +17,9 @@ from typing import Collection, Iterator, List, Tuple, Type, Union
from platformdirs import user_cache_dir
+from requests_cache.backends.base import VT
from requests_cache.models.response import CachedResponse
+from requests_cache.policy import ExpirationTime
from .._utils import chunkify, get_valid_kwargs
from . import BaseCache, BaseStorage
@@ -55,8 +57,8 @@ class SQLiteCache(BaseCache):
return self.responses.db_path
def clear(self):
- """Clear the cache. If this fails due to a corrupted cache or other I/O error, this will
- attempt to delete the cache file and re-initialize.
+ """Delete all items from the cache. If this fails due to a corrupted cache or other I/O
+ error, this will attempt to delete the cache file and re-initialize.
"""
try:
super().clear()
@@ -67,13 +69,13 @@ class SQLiteCache(BaseCache):
self.responses.init_db()
self.redirects.init_db()
+ # A more efficient SQLite implementation of :py:meth:`BaseCache.delete`
def delete(
self,
*keys: str,
expired: bool = False,
**kwargs,
):
- """A more efficient SQLite implementation of :py:meth:`BaseCache.delete`"""
if keys:
self.responses.bulk_delete(keys)
if expired:
@@ -109,19 +111,31 @@ class SQLiteCache(BaseCache):
')'
)
- def filter( # type: ignore
- self, valid: bool = True, expired: bool = True, **kwargs
- ) -> Iterator[CachedResponse]:
- """A more efficient implementation of :py:meth:`BaseCache.filter`, in the case where we want
- to get **only** expired responses
+ def count(self, expired: bool = True) -> int:
+ """Count number of responses, optionally excluding expired
+
+ Args:
+ expired: Set to ``False`` to count only unexpired responses
"""
- if expired and not valid and not kwargs:
- return self.responses.sorted(expired=True)
+ return self.responses.count(expired=expired)
+
+ # A more efficient implementation of :py:meth:`BaseCache.filter` to make use of indexes
+ def filter(
+ self,
+ valid: bool = True,
+ expired: bool = True,
+ invalid: bool = False,
+ older_than: ExpirationTime = None,
+ ) -> Iterator[CachedResponse]:
+ if valid and not invalid:
+ return self.responses.sorted(expired=expired)
else:
- return super().filter(valid, expired, **kwargs)
+ return super().filter(
+ valid=valid, expired=expired, invalid=invalid, older_than=older_than
+ )
+    # A more efficient implementation of :py:meth:`BaseCache.recreate_keys`
def recreate_keys(self):
- """A more efficient implementation of :py:meth:`BaseCache.recreate_keys`"""
with self.responses.bulk_commit():
super().recreate_keys()
@@ -162,6 +176,8 @@ class SQLiteDict(BaseStorage):
self._local_context = threading.local()
self._lock = threading.RLock()
self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
+ if use_memory:
+ self.connection_kwargs['uri'] = True
self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory)
self.fast_save = fast_save
self.table_name = table_name
@@ -172,9 +188,9 @@ class SQLiteDict(BaseStorage):
"""Initialize the database, if it hasn't already been"""
self.close()
with self._lock, self.connection() as con:
- # Add new column to tables created before 0.10
+ # Add new column to tables created before 1.0
try:
- con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT')
+ con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER')
except sqlite3.OperationalError:
pass
@@ -247,10 +263,8 @@ class SQLiteDict(BaseStorage):
def __setitem__(self, key, value):
# If available, set expiration as a timestamp in unix format
- expires = value.expires_unix if getattr(value, 'expires_unix', None) else None
+ expires = getattr(value, 'expires_unix', None)
value = self.serialize(value)
- if isinstance(value, bytes):
- value = sqlite3.Binary(value)
with self.connection(commit=True) as con:
con.execute(
f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)',
@@ -263,8 +277,7 @@ class SQLiteDict(BaseStorage):
yield row[0]
def __len__(self):
- with self.connection() as con:
- return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0]
+ return self.count()
def bulk_delete(self, keys=None, values=None):
"""Delete multiple items from the cache, without raising errors for any missing items.
@@ -288,6 +301,22 @@ class SQLiteDict(BaseStorage):
self.init_db()
self.vacuum()
+ def count(self, expired: bool = True) -> int:
+ """Count number of responses, optionally excluding expired"""
+ filter_expr = ''
+ params: Tuple = ()
+ if not expired:
+ filter_expr = 'WHERE expires is null or expires > ?'
+ params = (time(),)
+ query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}'
+
+ with self.connection() as con:
+ return con.execute(query, params).fetchone()[0]
+
+ def serialize(self, value: VT):
+ value = super().serialize(value)
+ return sqlite3.Binary(value) if isinstance(value, bytes) else value
+
def size(self) -> int:
"""Return the size of the database, in bytes. For an in-memory database, this will be an
estimate based on page size.
@@ -325,11 +354,19 @@ class SQLiteDict(BaseStorage):
with self.connection(commit=True) as con:
for row in con.execute(
- f'SELECT value FROM {self.table_name} {filter_expr}'
+ f'SELECT key, value FROM {self.table_name} {filter_expr}'
f' ORDER BY {key} {direction} {limit_expr}',
params,
):
- yield self.deserialize(row[0])
+ result = self.deserialize(row[1])
+ # Set cache key, if it's a response object
+ try:
+ result.cache_key = row[0]
+ except AttributeError:
+ pass
+ # Omit any results that can't be deserialized
+ if result:
+ yield result
def vacuum(self):
with self.connection(commit=True) as con:
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 4a55504..a7d2412 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -16,7 +16,7 @@ import requests
from requests import PreparedRequest, Session
from requests_cache import ALL_METHODS, CachedResponse, CachedSession
-from requests_cache.backends.base import BaseCache
+from requests_cache.backends import BaseCache
from tests.conftest import (
CACHE_NAME,
ETAG,
@@ -99,7 +99,7 @@ class BaseCacheTest:
# Patch storage class to track number of times getitem is called, without changing behavior
with patch.object(
- storage_class, '__getitem__', side_effect=storage_class.__getitem__
+ storage_class, '__getitem__', side_effect=lambda k: CachedResponse()
) as getitem:
session.get(httpbin('get'))
assert getitem.call_count == 1
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 580948b..06b17cb 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -1,4 +1,5 @@
import os
+import pickle
from datetime import datetime, timedelta
from os.path import join
from tempfile import NamedTemporaryFile, gettempdir
@@ -8,9 +9,9 @@ from unittest.mock import patch
import pytest
from platformdirs import user_cache_dir
-from requests_cache.backends.base import BaseCache
-from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict
-from requests_cache.models.response import CachedResponse
+from requests_cache.backends import BaseCache, SQLiteCache, SQLiteDict
+from requests_cache.backends.sqlite import MEMORY_URI
+from requests_cache.models import CachedResponse
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest
@@ -60,7 +61,7 @@ class TestSQLiteDict(BaseStorageTest):
assert len(cache) == 0
def test_use_memory__uri(self):
- self.init_cache(':memory:').db_path == ':memory:'
+ assert self.init_cache(':memory:').db_path == ':memory:'
def test_non_dir_parent_exists(self):
"""Expect a custom error message if a parent path already exists but isn't a directory"""
@@ -219,6 +220,30 @@ class TestSQLiteDict(BaseStorageTest):
assert prev_item is None or prev_item.expires < item.expires
assert item.status_code % 2 == 0
+ def test_sorted__error(self):
+ """sorted() should handle deserialization errors and not return invalid responses"""
+
+ class BadSerializer:
+ def loads(self, value):
+ response = pickle.loads(value)
+ if response.cache_key == 'key_42':
+ raise pickle.PickleError()
+ return response
+
+ def dumps(self, value):
+ return pickle.dumps(value)
+
+ cache = self.init_cache(serializer=BadSerializer())
+
+ for i in range(100):
+ response = CachedResponse(status_code=i)
+ response.cache_key = f'key_{i}'
+ cache[f'key_{i}'] = response
+
+    # Items should exclude the one response that fails to deserialize, and still be in sorted order
+ items = list(cache.sorted())
+ assert len(items) == 99
+
@pytest.mark.parametrize(
'db_path, use_temp',
[
@@ -272,12 +297,26 @@ class TestSQLiteCache(BaseCacheTest):
session = self.init_session()
assert session.cache.db_path == session.cache.responses.db_path
+ def test_count(self):
+ """count() should work the same as len(), but with the option to exclude expired responses"""
+ session = self.init_session()
+ now = datetime.utcnow()
+ session.cache.responses['key_1'] = CachedResponse(expires=now + timedelta(1))
+ session.cache.responses['key_2'] = CachedResponse(expires=now - timedelta(1))
+
+ assert session.cache.count() == 2
+ assert session.cache.count(expired=False) == 1
+
@patch.object(SQLiteDict, 'sorted')
- def test_filter__expired_only(self, mock_sorted):
- """Filtering by expired only should use a more efficient SQL query"""
+ def test_filter__expired(self, mock_sorted):
+ """Filtering by expired should use a more efficient SQL query"""
session = self.init_session()
- session.cache.filter(valid=False, expired=True)
- mock_sorted.assert_called_once_with(expired=True)
+
+ session.cache.filter()
+ mock_sorted.assert_called_with(expired=True)
+
+ session.cache.filter(expired=False)
+ mock_sorted.assert_called_with(expired=False)
def test_sorted(self):
"""Test wrapper method for SQLiteDict.sorted(), with all arguments combined"""
@@ -300,4 +339,5 @@ class TestSQLiteCache(BaseCacheTest):
prev_item = None
for i, item in enumerate(items):
assert prev_item is None or prev_item.expires < item.expires
+ assert item.cache_key
assert not item.is_expired
diff --git a/tests/unit/test_base_cache.py b/tests/unit/test_base_cache.py
index e0b6b0e..b57ed09 100644
--- a/tests/unit/test_base_cache.py
+++ b/tests/unit/test_base_cache.py
@@ -1,4 +1,5 @@
"""BaseCache tests that use mocked responses only"""
+import pickle
from datetime import datetime, timedelta
from logging import getLogger
from pickle import PickleError
@@ -8,9 +9,10 @@ from unittest.mock import patch
import pytest
from requests import Request
-from requests_cache.backends import BaseCache, SQLiteDict
+from requests_cache.backends import BaseCache, SQLiteCache, SQLiteDict
from requests_cache.cache_keys import create_key
from requests_cache.models import CachedRequest, CachedResponse
+from requests_cache.session import CachedSession
from tests.conftest import (
MOCKED_URL,
MOCKED_URL_ETAG,
@@ -18,6 +20,7 @@ from tests.conftest import (
MOCKED_URL_JSON,
MOCKED_URL_REDIRECT,
ignore_deprecation,
+ mount_mock_adapter,
patch_normalize_url,
)
@@ -109,19 +112,28 @@ def test_delete__expired__per_request(mock_session):
assert len(mock_session.cache.responses) == 1
-def test_delete__invalid(mock_session):
+def test_delete__invalid(tempfile_path):
+    class BadSerializer:
+        def dumps(self, value):
+            return pickle.dumps(value)
+
+        def loads(self, value):
+            response = pickle.loads(value)
+            if response.url.endswith('/json'):
+                raise PickleError
+            return response
+
+    mock_session = CachedSession(
+        cache_name=tempfile_path, backend='sqlite', serializer=BadSerializer()
+    )
+ )
+ mock_session = mount_mock_adapter(mock_session)
+
# Start with two cached responses, one of which will raise an error
response_1 = mock_session.get(MOCKED_URL)
response_2 = mock_session.get(MOCKED_URL_JSON)
- def error_on_key(key):
- if key == response_2.cache_key:
- raise PickleError
- return CachedResponse.from_response(response_1)
-
# Use the generic BaseCache implementation, not the SQLite-specific one
- with patch.object(SQLiteDict, '__getitem__', side_effect=error_on_key):
- BaseCache.delete(mock_session.cache, expired=True, invalid=True)
+ BaseCache.delete(mock_session.cache, expired=True, invalid=True)
assert len(mock_session.cache.responses) == 1
assert mock_session.get(MOCKED_URL).from_cache is True
@@ -220,7 +232,7 @@ def test_recreate_keys__same_key_fn(mock_session):
def test_reset_expiration__extend_expiration(mock_session):
# Start with an expired response
- mock_session.settings.expire_after = datetime.utcnow() - timedelta(seconds=0.01)
+ mock_session.settings.expire_after = datetime.utcnow() - timedelta(seconds=1)
mock_session.get(MOCKED_URL)
# Set expiration in the future
@@ -236,7 +248,7 @@ def test_reset_expiration__shorten_expiration(mock_session):
mock_session.get(MOCKED_URL)
# Set expiration in the past
- mock_session.cache.reset_expiration(datetime.utcnow() - timedelta(seconds=0.01))
+ mock_session.cache.reset_expiration(datetime.utcnow() - timedelta(seconds=1))
response = mock_session.get(MOCKED_URL)
assert response.is_expired is False and response.from_cache is False
@@ -277,8 +289,8 @@ def test_urls(mock_normalize_url, mock_session):
def test_urls__error(mock_session):
responses = [mock_session.get(url) for url in [MOCKED_URL, MOCKED_URL_JSON, MOCKED_URL_HTTPS]]
- responses[2] = AttributeError
- with patch.object(SQLiteDict, '__getitem__', side_effect=responses):
+ responses[2] = None
+ with patch.object(SQLiteDict, 'deserialize', side_effect=responses):
expected_urls = [MOCKED_URL_JSON, MOCKED_URL]
assert mock_session.cache.urls() == expected_urls
@@ -346,6 +358,7 @@ def test_keys(mock_session):
response_keys = set(mock_session.cache.responses.keys())
redirect_keys = set(mock_session.cache.redirects.keys())
assert set(mock_session.cache.keys()) == response_keys | redirect_keys
+ assert len(list(mock_session.cache.keys(check_expiry=True))) == 5
def test_remove_expired_responses(mock_session):
@@ -385,16 +398,14 @@ def test_values(mock_session):
assert all([isinstance(response, CachedResponse) for response in responses])
-@pytest.mark.parametrize('check_expiry, expected_count', [(True, 1), (False, 2)])
-def test_values__with_invalid_responses(check_expiry, expected_count, mock_session):
- """values() should always exclude invalid responses, and optionally exclude expired responses"""
+def test_values__with_invalid_responses(mock_session):
responses = [mock_session.get(url) for url in [MOCKED_URL, MOCKED_URL_JSON, MOCKED_URL_HTTPS]]
- responses[1] = AttributeError
+ responses[1] = None
responses[2] = CachedResponse(expires=YESTERDAY, url='test')
- with ignore_deprecation(), patch.object(SQLiteDict, '__getitem__', side_effect=responses):
- values = mock_session.cache.values(check_expiry=check_expiry)
- assert len(list(values)) == expected_count
+ with ignore_deprecation(), patch.object(SQLiteCache, 'filter', side_effect=responses):
+ values = mock_session.cache.values(check_expiry=True)
+ assert len(list(values)) == 1
# The invalid response should be skipped, but remain in the cache for now
assert len(mock_session.cache.responses.keys()) == 3
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 60e8bb4..e7495cc 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -17,7 +17,7 @@ from requests.structures import CaseInsensitiveDict
from requests_cache import ALL_METHODS, CachedSession
from requests_cache._utils import get_placeholder_class
-from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict
+from requests_cache.backends import BACKEND_CLASSES, BaseCache
from requests_cache.backends.base import DESERIALIZE_ERRORS
from requests_cache.policy.expiration import DO_NOT_CACHE, EXPIRE_IMMEDIATELY, NEVER_EXPIRE
from tests.conftest import (
@@ -351,7 +351,8 @@ def test_include_get_headers():
def test_cache_error(exception_cls, mock_session):
"""If there is an error while fetching a cached response, a new one should be fetched"""
mock_session.get(MOCKED_URL)
- with patch.object(SQLiteDict, '__getitem__', side_effect=exception_cls):
+
+ with patch.object(mock_session.cache.responses.serializer, 'loads', side_effect=exception_cls):
assert mock_session.get(MOCKED_URL).from_cache is False
@@ -440,7 +441,7 @@ def test_unpickle_errors(mock_session):
"""If there is an error during deserialization, the request should be made again"""
assert mock_session.get(MOCKED_URL_JSON).from_cache is False
- with patch.object(SQLiteDict, '__getitem__', side_effect=PickleError):
+ with patch.object(mock_session.cache.responses.serializer, 'loads', side_effect=PickleError):
resp = mock_session.get(MOCKED_URL_JSON)
assert resp.from_cache is False
assert resp.json()['message'] == 'mock json response'