summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2022-10-28 12:18:14 -0500
committerJordan Cook <jordan.cook.git@proton.me>2022-10-28 15:05:13 -0500
commit03a97079839a5683b202e3ea3f66e8ce54eedab4 (patch)
tree33540bb5443d9667b1b82df52e561e0f5437adb9
parenteae13531f0d2806068229d04621b6ec7e1fc406e (diff)
downloadrequests-cache-03a97079839a5683b202e3ea3f66e8ce54eedab4.tar.gz
Omit invalid responses and set response.cache_key in SQLiteCache.sorted()
-rw-r--r--requests_cache/backends/sqlite.py12
-rw-r--r--tests/integration/test_sqlite.py34
-rw-r--r--tests/unit/test_session.py2
3 files changed, 41 insertions, 7 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 87b581b..87a09e0 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -327,11 +327,19 @@ class SQLiteDict(BaseStorage):
with self.connection(commit=True) as con:
for row in con.execute(
- f'SELECT value FROM {self.table_name} {filter_expr}'
+ f'SELECT key, value FROM {self.table_name} {filter_expr}'
f' ORDER BY {key} {direction} {limit_expr}',
params,
):
- yield self.deserialize(row[0])
+ result = self.deserialize(row[1])
+ # Set cache key, if it's a response object
+ try:
+ result.cache_key = row[0]
+ except AttributeError:
+ pass
+ # Omit any results that can't be deserialized
+ if result:
+ yield result
def vacuum(self):
with self.connection(commit=True) as con:
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 580948b..7e31634 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -1,4 +1,5 @@
import os
+import pickle
from datetime import datetime, timedelta
from os.path import join
from tempfile import NamedTemporaryFile, gettempdir
@@ -8,9 +9,9 @@ from unittest.mock import patch
import pytest
from platformdirs import user_cache_dir
-from requests_cache.backends.base import BaseCache
-from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict
-from requests_cache.models.response import CachedResponse
+from requests_cache.backends import BaseCache, SQLiteCache, SQLiteDict
+from requests_cache.backends.sqlite import MEMORY_URI
+from requests_cache.models import CachedResponse
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest
@@ -60,7 +61,7 @@ class TestSQLiteDict(BaseStorageTest):
assert len(cache) == 0
def test_use_memory__uri(self):
- self.init_cache(':memory:').db_path == ':memory:'
+ assert self.init_cache(':memory:').db_path == ':memory:'
def test_non_dir_parent_exists(self):
"""Expect a custom error message if a parent path already exists but isn't a directory"""
@@ -219,6 +220,30 @@ class TestSQLiteDict(BaseStorageTest):
assert prev_item is None or prev_item.expires < item.expires
assert item.status_code % 2 == 0
+ def test_sorted__error(self):
+ """sorted() should handle deserialization errors and not return invalid responses"""
+
+ class BadSerializer:
+ def loads(self, value):
+ response = pickle.loads(value)
+ if response.cache_key == 'key_42':
+ raise pickle.PickleError()
+ return response
+
+ def dumps(self, value):
+ return pickle.dumps(value)
+
+ cache = self.init_cache(serializer=BadSerializer())
+
+ for i in range(100):
+ response = CachedResponse(status_code=i)
+ response.cache_key = f'key_{i}'
+ cache[f'key_{i}'] = response
+
+ # Items should only include unexpired (even numbered) items, and still be in sorted order
+ items = list(cache.sorted())
+ assert len(items) == 99
+
@pytest.mark.parametrize(
'db_path, use_temp',
[
@@ -300,4 +325,5 @@ class TestSQLiteCache(BaseCacheTest):
prev_item = None
for i, item in enumerate(items):
assert prev_item is None or prev_item.expires < item.expires
+ assert item.cache_key
assert not item.is_expired
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 116b74f..e7495cc 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -17,7 +17,7 @@ from requests.structures import CaseInsensitiveDict
from requests_cache import ALL_METHODS, CachedSession
from requests_cache._utils import get_placeholder_class
-from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict
+from requests_cache.backends import BACKEND_CLASSES, BaseCache
from requests_cache.backends.base import DESERIALIZE_ERRORS
from requests_cache.policy.expiration import DO_NOT_CACHE, EXPIRE_IMMEDIATELY, NEVER_EXPIRE
from tests.conftest import (