summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2022-12-13 18:37:04 -0600
committerJordan Cook <jordan.cook.git@proton.me>2022-12-30 14:46:02 -0600
commit94a4449cdc05e1a1460a50f8d696f8032223558b (patch)
tree5442039d47d91723068776c1395996dc25a7037c
parente8ea4ca58845afeda358c31c15dcdb28b2b62896 (diff)
downloadrequests-cache-94a4449cdc05e1a1460a50f8d696f8032223558b.tar.gz
Make CachedResponse.cache_key available from all cache access methods
-rw-r--r--HISTORY.md4
-rw-r--r--requests_cache/backends/base.py15
-rw-r--r--requests_cache/backends/dynamodb.py2
-rw-r--r--requests_cache/backends/filesystem.py2
-rw-r--r--requests_cache/backends/gridfs.py2
-rw-r--r--requests_cache/backends/mongodb.py2
-rw-r--r--requests_cache/backends/redis.py12
-rw-r--r--requests_cache/backends/sqlite.py9
-rw-r--r--tests/integration/base_storage_test.py11
9 files changed, 39 insertions, 20 deletions
diff --git a/HISTORY.md b/HISTORY.md
index 7737f56..202a398 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -81,7 +81,7 @@
(e.g., to preserve cache data after an update that changes request matching behavior)
* Update `BaseCache.urls` into a method that takes optional filter params, and returns sorted unique URLs
-ℹ️ **Type hints:**
+ℹ️ **Response attributes and type hints:**
* Add `OriginalResponse` type, which adds type hints to `requests.Response` objects for extra attributes added by requests-cache:
* `cache_key`
* `created_at`
@@ -91,6 +91,8 @@
* `revalidated`
* `OriginalResponse.cache_key` and `expires` will be populated for any new response that was written to the cache
* Add request wrapper methods with return type hints for all HTTP methods (`CachedSession.get()`, `head()`, etc.)
+* Set `CachedResponse.cache_key` attribute for responses read from lower-level storage methods
+ (`items()`, `values()`, etc.)
🧩 **Compatibility fixes:**
* **PyInstaller:** Fix potential `AttributeError` due to undetected imports when requests-cache is bundled in a PyInstaller package
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index faa338a..40651cd 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -68,7 +68,6 @@ class BaseCache:
response = self.responses.get(key)
if response is None: # Note: bool(requests.Response) is False if status > 400
response = self.responses[self.redirects[key]]
- response.cache_key = key
return response
except (AttributeError, KeyError):
return default
@@ -406,7 +405,7 @@ class BaseStorage(MutableMapping[KT, VT], ABC):
assert hasattr(self.serializer, 'dumps')
return self.serializer.dumps(value) if self.serializer else value
- def deserialize(self, value: VT):
+ def deserialize(self, key, value: VT):
"""Deserialize a value, if a serializer is available.
If deserialization fails (usually due to a value saved in an older requests-cache version),
@@ -418,7 +417,13 @@ class BaseStorage(MutableMapping[KT, VT], ABC):
assert hasattr(self.serializer, 'loads')
try:
- return self.serializer.loads(value)
+ obj = self.serializer.loads(value)
+ # Set cache key, if it's a response object
+ try:
+ obj.cache_key = key
+ except AttributeError:
+ pass
+ return obj
except DESERIALIZE_ERRORS as e:
logger.error(f'Unable to deserialize response: {str(e)}')
logger.debug(e, exc_info=True)
@@ -450,4 +455,8 @@ class DictStorage(UserDict, BaseStorage):
item = super().__getitem__(key)
if getattr(item, 'raw', None):
item.raw.reset()
+ try:
+ item.cache_key = key
+ except AttributeError:
+ pass
return item
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index 721e256..875c1f7 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -146,7 +146,7 @@ class DynamoDbDict(BaseStorage):
# With a custom serializer, the value may be a Binary object
raw_value = result['Item']['value']
value = raw_value.value if isinstance(raw_value, Binary) else raw_value
- return self.deserialize(value)
+ return self.deserialize(key, value)
def __setitem__(self, key, value):
item = {**self._composite_key(key), 'value': self.serialize(value)}
diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py
index cc57597..b68e2ee 100644
--- a/requests_cache/backends/filesystem.py
+++ b/requests_cache/backends/filesystem.py
@@ -102,7 +102,7 @@ class FileDict(BaseStorage):
mode = 'rb' if self.is_binary else 'r'
with self._try_io():
with self._path(key).open(mode) as f:
- return self.deserialize(f.read())
+ return self.deserialize(key, f.read())
def __delitem__(self, key):
with self._try_io():
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index 6380578..9f04a8a 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -67,7 +67,7 @@ class GridFSDict(BaseStorage):
result = self.fs.find_one({'_id': key})
if result is None:
raise KeyError
- return self.deserialize(result.read())
+ return self.deserialize(key, result.read())
except CorruptGridFile as e:
logger.warning(e, exc_info=True)
raise KeyError
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index e15e331..90f8feb 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -119,7 +119,7 @@ class MongoDict(BaseStorage):
if result is None:
raise KeyError
value = result['data'] if 'data' in result else result
- return self.deserialize(value)
+ return self.deserialize(key, value)
def __setitem__(self, key, value):
"""If ``value`` is already a dict, its values will be stored under top-level keys.
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index fef488c..95d1ed8 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -90,7 +90,7 @@ class RedisDict(BaseStorage):
result = self.connection.get(self._bkey(key))
if result is None:
raise KeyError
- return self.deserialize(result)
+ return self.deserialize(key, result)
def __setitem__(self, key, item):
"""Save an item to the cache, optionally with TTL"""
@@ -132,7 +132,8 @@ class RedisDict(BaseStorage):
return [(k, self[k]) for k in self.keys()]
def values(self):
- return [self.deserialize(v) for v in self.connection.mget(*self._bkeys(self.keys()))]
+ for _, v in self.items():
+ yield v
class RedisHashDict(BaseStorage):
@@ -162,7 +163,7 @@ class RedisHashDict(BaseStorage):
result = self.connection.hget(self._hash_key, encode(key))
if result is None:
raise KeyError
- return self.deserialize(result)
+ return self.deserialize(key, result)
def __setitem__(self, key, item):
self.connection.hset(self._hash_key, encode(key), self.serialize(item))
@@ -191,10 +192,11 @@ class RedisHashDict(BaseStorage):
def items(self):
"""Get all ``(key, value)`` pairs in the hash"""
return [
- (decode(k), self.deserialize(v))
+ (decode(k), self.deserialize(decode(k), v))
for k, v in self.connection.hgetall(self._hash_key).items()
]
def values(self):
"""Get all values in the hash"""
- return [self.deserialize(v) for v in self.connection.hvals(self._hash_key)]
+ for _, v in self.items():
+ yield v
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 7a99a39..02d9b70 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -259,7 +259,7 @@ class SQLiteDict(BaseStorage):
if not row:
raise KeyError
- return self.deserialize(row[0])
+ return self.deserialize(key, row[0])
def __setitem__(self, key, value):
# If available, set expiration as a timestamp in unix format
@@ -362,12 +362,7 @@ class SQLiteDict(BaseStorage):
f' ORDER BY {key} {direction} {limit_expr}',
params,
):
- result = self.deserialize(row[1])
- # Set cache key, if it's a response object
- try:
- result.cache_key = row[0]
- except AttributeError:
- pass
+ result = self.deserialize(row[0], row[1])
# Omit any results that can't be deserialized
if result:
yield result
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index 4e0f217..4070492 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -6,6 +6,7 @@ import pytest
from attrs import define, field
from requests_cache.backends import BaseStorage
+from requests_cache.models import CachedResponse
from tests.conftest import CACHE_NAME
@@ -62,6 +63,16 @@ class BaseStorageTest:
assert list(cache.items()) == [(f'key_{i}', f'value_{i}')]
assert dict(cache) == {f'key_{i}': f'value_{i}'}
+ def test_cache_key(self):
+ """The cache_key attribute should be available on responses returned from all
+ mapping/collection methods
+ """
+ cache = self.init_cache()
+ cache['key'] = CachedResponse()
+ assert cache['key'].cache_key == 'key'
+ assert list(cache.values())[0].cache_key == 'key'
+ assert list(cache.items())[0][1].cache_key == 'key'
+
def test_del(self):
"""Some more tests to ensure ``delitem`` deletes only the expected items"""
cache = self.init_cache()