diff options
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r-- | requests_cache/backends/sqlite.py | 79 |
1 files changed, 58 insertions, 21 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 502c0dd..f936a77 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,7 +17,9 @@ from typing import Collection, Iterator, List, Tuple, Type, Union from platformdirs import user_cache_dir +from requests_cache.backends.base import VT from requests_cache.models.response import CachedResponse +from requests_cache.policy import ExpirationTime from .._utils import chunkify, get_valid_kwargs from . import BaseCache, BaseStorage @@ -55,8 +57,8 @@ class SQLiteCache(BaseCache): return self.responses.db_path def clear(self): - """Clear the cache. If this fails due to a corrupted cache or other I/O error, this will - attempt to delete the cache file and re-initialize. + """Delete all items from the cache. If this fails due to a corrupted cache or other I/O + error, this will attempt to delete the cache file and re-initialize. """ try: super().clear() @@ -67,13 +69,13 @@ class SQLiteCache(BaseCache): self.responses.init_db() self.redirects.init_db() + # A more efficient SQLite implementation of :py:meth:`BaseCache.delete` def delete( self, *keys: str, expired: bool = False, **kwargs, ): - """A more efficient SQLite implementation of :py:meth:`BaseCache.delete`""" if keys: self.responses.bulk_delete(keys) if expired: @@ -109,19 +111,31 @@ class SQLiteCache(BaseCache): ')' ) - def filter( # type: ignore - self, valid: bool = True, expired: bool = True, **kwargs - ) -> Iterator[CachedResponse]: - """A more efficient implementation of :py:meth:`BaseCache.filter`, in the case where we want - to get **only** expired responses + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired + + Args: + expired: Set to ``False`` to count only unexpired responses """ - if expired and not valid and not kwargs: - return self.responses.sorted(expired=True) + return self.responses.count(expired=expired) + + # A more efficient implementation of :py:meth:`BaseCache.filter` to make use of indexes + def filter( + self, + valid: bool = True, + expired: bool = True, + invalid: bool = False, + older_than: ExpirationTime = None, + ) -> Iterator[CachedResponse]: + if valid and not invalid: + return self.responses.sorted(expired=expired) else: - return super().filter(valid, expired, **kwargs) + return super().filter( + valid=valid, expired=expired, invalid=invalid, older_than=older_than + ) + # A more efficient implementation of :py:meth:`BaseCache.recreate_keys def recreate_keys(self): - """A more efficient implementation of :py:meth:`BaseCache.recreate_keys`""" with self.responses.bulk_commit(): super().recreate_keys() @@ -162,6 +176,8 @@ class SQLiteDict(BaseStorage): self._local_context = threading.local() self._lock = threading.RLock() self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) + if use_memory: + self.connection_kwargs['uri'] = True self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory) self.fast_save = fast_save self.table_name = table_name @@ -172,9 +188,9 @@ class SQLiteDict(BaseStorage): """Initialize the database, if it hasn't already been""" self.close() with self._lock, self.connection() as con: - # Add new column to tables created before 0.10 + # Add new column to tables created before 1.0 try: - con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT') + con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER') except sqlite3.OperationalError: pass @@ -247,10 +263,8 @@ class SQLiteDict(BaseStorage): def __setitem__(self, key, value): # If available, set expiration as a timestamp in unix format - expires = value.expires_unix if getattr(value, 'expires_unix', None) else None + expires = getattr(value, 'expires_unix', None) value = self.serialize(value) - if isinstance(value, bytes): - value = sqlite3.Binary(value) with self.connection(commit=True) as con: con.execute( f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)', @@ -263,8 +277,7 @@ class SQLiteDict(BaseStorage): yield row[0] def __len__(self): - with self.connection() as con: - return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0] + return self.count() def bulk_delete(self, keys=None, values=None): """Delete multiple items from the cache, without raising errors for any missing items. @@ -288,6 +301,22 @@ class SQLiteDict(BaseStorage): self.init_db() self.vacuum() + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired""" + filter_expr = '' + params: Tuple = () + if not expired: + filter_expr = 'WHERE expires is null or expires > ?' + params = (time(),) + query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}' + + with self.connection() as con: + return con.execute(query, params).fetchone()[0] + + def serialize(self, value: VT): + value = super().serialize(value) + return sqlite3.Binary(value) if isinstance(value, bytes) else value + def size(self) -> int: """Return the size of the database, in bytes. For an in-memory database, this will be an estimate based on page size. @@ -325,11 +354,19 @@ class SQLiteDict(BaseStorage): with self.connection(commit=True) as con: for row in con.execute( - f'SELECT value FROM {self.table_name} {filter_expr}' + f'SELECT key, value FROM {self.table_name} {filter_expr}' f' ORDER BY {key} {direction} {limit_expr}', params, ): - yield self.deserialize(row[0]) + result = self.deserialize(row[1]) + # Set cache key, if it's a response object + try: + result.cache_key = row[0] + except AttributeError: + pass + # Omit any results that can't be deserialized + if result: + yield result def vacuum(self): with self.connection(commit=True) as con: |