diff options
Diffstat (limited to 'requests_cache/backends')
-rw-r--r-- | requests_cache/backends/base.py | 31 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 79 |
2 files changed, 79 insertions, 31 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index bc57ca1..efbd214 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -70,11 +70,7 @@ class BaseCache: response = self.responses[self.redirects[key]] response.cache_key = key return response - except KeyError: - return default - except DESERIALIZE_ERRORS as e: - logger.error(f'Unable to deserialize response {key}: {str(e)}') - logger.debug(e, exc_info=True) + except (AttributeError, KeyError): return default def save_response(self, response: Response, cache_key: str = None, expires: datetime = None): @@ -296,8 +292,11 @@ class BaseCache: DeprecationWarning, ) yield from self.redirects.keys() - for response in self.filter(expired=not check_expiry): - yield response.cache_key + if not check_expiry: + yield from self.responses.keys() + else: + for response in self.filter(expired=False): + yield response.cache_key def response_count(self, check_expiry: bool = False) -> int: warn( @@ -394,16 +393,28 @@ class BaseStorage(MutableMapping[KT, VT], ABC): """Close any open backend connections""" def serialize(self, value: VT): - """Serialize value, if a serializer is available""" + """Serialize a value, if a serializer is available""" if TYPE_CHECKING: assert hasattr(self.serializer, 'dumps') return self.serializer.dumps(value) if self.serializer else value def deserialize(self, value: VT): - """Deserialize value, if a serializer is available""" + """Deserialize a value, if a serializer is available. + + If deserialization fails (usually due to a value saved in an older requests-cache version), + ``None`` will be returned. + """ + if not self.serializer: + return value if TYPE_CHECKING: assert hasattr(self.serializer, 'loads') - return self.serializer.loads(value) if self.serializer else value + + try: + return self.serializer.loads(value) + except DESERIALIZE_ERRORS as e: + logger.error(f'Unable to deserialize response: {str(e)}') + logger.debug(e, exc_info=True) + return None def __str__(self): return str(list(self.keys())) diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 502c0dd..f936a77 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,7 +17,9 @@ from typing import Collection, Iterator, List, Tuple, Type, Union from platformdirs import user_cache_dir +from requests_cache.backends.base import VT from requests_cache.models.response import CachedResponse +from requests_cache.policy import ExpirationTime from .._utils import chunkify, get_valid_kwargs from . import BaseCache, BaseStorage @@ -55,8 +57,8 @@ class SQLiteCache(BaseCache): return self.responses.db_path def clear(self): - """Clear the cache. If this fails due to a corrupted cache or other I/O error, this will - attempt to delete the cache file and re-initialize. + """Delete all items from the cache. If this fails due to a corrupted cache or other I/O + error, this will attempt to delete the cache file and re-initialize. """ try: super().clear() @@ -67,13 +69,13 @@ class SQLiteCache(BaseCache): self.responses.init_db() self.redirects.init_db() + # A more efficient SQLite implementation of :py:meth:`BaseCache.delete` def delete( self, *keys: str, expired: bool = False, **kwargs, ): - """A more efficient SQLite implementation of :py:meth:`BaseCache.delete`""" if keys: self.responses.bulk_delete(keys) if expired: @@ -109,19 +111,31 @@ class SQLiteCache(BaseCache): ')' ) - def filter( # type: ignore - self, valid: bool = True, expired: bool = True, **kwargs - ) -> Iterator[CachedResponse]: - """A more efficient implementation of :py:meth:`BaseCache.filter`, in the case where we want - to get **only** expired responses + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired + + Args: + expired: Set to ``False`` to count only unexpired responses """ - if expired and not valid and not kwargs: - return self.responses.sorted(expired=True) + return self.responses.count(expired=expired) + + # A more efficient implementation of :py:meth:`BaseCache.filter` to make use of indexes + def filter( + self, + valid: bool = True, + expired: bool = True, + invalid: bool = False, + older_than: ExpirationTime = None, + ) -> Iterator[CachedResponse]: + if valid and not invalid: + return self.responses.sorted(expired=expired) else: - return super().filter(valid, expired, **kwargs) + return super().filter( + valid=valid, expired=expired, invalid=invalid, older_than=older_than + ) + # A more efficient implementation of :py:meth:`BaseCache.recreate_keys def recreate_keys(self): - """A more efficient implementation of :py:meth:`BaseCache.recreate_keys`""" with self.responses.bulk_commit(): super().recreate_keys() @@ -162,6 +176,8 @@ class SQLiteDict(BaseStorage): self._local_context = threading.local() self._lock = threading.RLock() self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) + if use_memory: + self.connection_kwargs['uri'] = True self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory) self.fast_save = fast_save self.table_name = table_name @@ -172,9 +188,9 @@ class SQLiteDict(BaseStorage): """Initialize the database, if it hasn't already been""" self.close() with self._lock, self.connection() as con: - # Add new column to tables created before 0.10 + # Add new column to tables created before 1.0 try: - con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT') + con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER') except sqlite3.OperationalError: pass @@ -247,10 +263,8 @@ class SQLiteDict(BaseStorage): def __setitem__(self, key, value): # If available, set expiration as a timestamp in unix format - expires = value.expires_unix if getattr(value, 'expires_unix', None) else None + expires = getattr(value, 'expires_unix', None) value = self.serialize(value) - if isinstance(value, bytes): - value = sqlite3.Binary(value) with self.connection(commit=True) as con: con.execute( f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)', @@ -263,8 +277,7 @@ class SQLiteDict(BaseStorage): yield row[0] def __len__(self): - with self.connection() as con: - return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0] + return self.count() def bulk_delete(self, keys=None, values=None): """Delete multiple items from the cache, without raising errors for any missing items. @@ -288,6 +301,22 @@ class SQLiteDict(BaseStorage): self.init_db() self.vacuum() + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired""" + filter_expr = '' + params: Tuple = () + if not expired: + filter_expr = 'WHERE expires is null or expires > ?' + params = (time(),) + query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}' + + with self.connection() as con: + return con.execute(query, params).fetchone()[0] + + def serialize(self, value: VT): + value = super().serialize(value) + return sqlite3.Binary(value) if isinstance(value, bytes) else value + def size(self) -> int: """Return the size of the database, in bytes. For an in-memory database, this will be an estimate based on page size. @@ -325,11 +354,19 @@ class SQLiteDict(BaseStorage): with self.connection(commit=True) as con: for row in con.execute( - f'SELECT value FROM {self.table_name} {filter_expr}' + f'SELECT key, value FROM {self.table_name} {filter_expr}' f' ORDER BY {key} {direction} {limit_expr}', params, ): - yield self.deserialize(row[0]) + result = self.deserialize(row[1]) + # Set cache key, if it's a response object + try: + result.cache_key = row[0] + except AttributeError: + pass + # Omit any results that can't be deserialized + if result: + yield result def vacuum(self): with self.connection(commit=True) as con: |