summaryrefslogtreecommitdiff
path: root/requests_cache/backends/sqlite.py
diff options
context:
space:
mode:
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r--requests_cache/backends/sqlite.py79
1 files changed, 58 insertions, 21 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 502c0dd..f936a77 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -17,7 +17,9 @@ from typing import Collection, Iterator, List, Tuple, Type, Union
from platformdirs import user_cache_dir
+from requests_cache.backends.base import VT
from requests_cache.models.response import CachedResponse
+from requests_cache.policy import ExpirationTime
from .._utils import chunkify, get_valid_kwargs
from . import BaseCache, BaseStorage
@@ -55,8 +57,8 @@ class SQLiteCache(BaseCache):
return self.responses.db_path
def clear(self):
- """Clear the cache. If this fails due to a corrupted cache or other I/O error, this will
- attempt to delete the cache file and re-initialize.
+ """Delete all items from the cache. If this fails due to a corrupted cache or other I/O
+ error, this will attempt to delete the cache file and re-initialize.
"""
try:
super().clear()
@@ -67,13 +69,13 @@ class SQLiteCache(BaseCache):
self.responses.init_db()
self.redirects.init_db()
+ # A more efficient SQLite implementation of :py:meth:`BaseCache.delete`
def delete(
self,
*keys: str,
expired: bool = False,
**kwargs,
):
- """A more efficient SQLite implementation of :py:meth:`BaseCache.delete`"""
if keys:
self.responses.bulk_delete(keys)
if expired:
@@ -109,19 +111,31 @@ class SQLiteCache(BaseCache):
')'
)
- def filter( # type: ignore
- self, valid: bool = True, expired: bool = True, **kwargs
- ) -> Iterator[CachedResponse]:
- """A more efficient implementation of :py:meth:`BaseCache.filter`, in the case where we want
- to get **only** expired responses
+ def count(self, expired: bool = True) -> int:
+ """Count number of responses, optionally excluding expired
+
+ Args:
+ expired: Set to ``False`` to count only unexpired responses
"""
- if expired and not valid and not kwargs:
- return self.responses.sorted(expired=True)
+ return self.responses.count(expired=expired)
+
+ # A more efficient implementation of :py:meth:`BaseCache.filter` to make use of indexes
+ def filter(
+ self,
+ valid: bool = True,
+ expired: bool = True,
+ invalid: bool = False,
+ older_than: ExpirationTime = None,
+ ) -> Iterator[CachedResponse]:
+ if valid and not invalid:
+ return self.responses.sorted(expired=expired)
else:
- return super().filter(valid, expired, **kwargs)
+ return super().filter(
+ valid=valid, expired=expired, invalid=invalid, older_than=older_than
+ )
+ # A more efficient implementation of :py:meth:`BaseCache.recreate_keys
def recreate_keys(self):
- """A more efficient implementation of :py:meth:`BaseCache.recreate_keys`"""
with self.responses.bulk_commit():
super().recreate_keys()
@@ -162,6 +176,8 @@ class SQLiteDict(BaseStorage):
self._local_context = threading.local()
self._lock = threading.RLock()
self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
+ if use_memory:
+ self.connection_kwargs['uri'] = True
self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory)
self.fast_save = fast_save
self.table_name = table_name
@@ -172,9 +188,9 @@ class SQLiteDict(BaseStorage):
"""Initialize the database, if it hasn't already been"""
self.close()
with self._lock, self.connection() as con:
- # Add new column to tables created before 0.10
+ # Add new column to tables created before 1.0
try:
- con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT')
+ con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER')
except sqlite3.OperationalError:
pass
@@ -247,10 +263,8 @@ class SQLiteDict(BaseStorage):
def __setitem__(self, key, value):
# If available, set expiration as a timestamp in unix format
- expires = value.expires_unix if getattr(value, 'expires_unix', None) else None
+ expires = getattr(value, 'expires_unix', None)
value = self.serialize(value)
- if isinstance(value, bytes):
- value = sqlite3.Binary(value)
with self.connection(commit=True) as con:
con.execute(
f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)',
@@ -263,8 +277,7 @@ class SQLiteDict(BaseStorage):
yield row[0]
def __len__(self):
- with self.connection() as con:
- return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0]
+ return self.count()
def bulk_delete(self, keys=None, values=None):
"""Delete multiple items from the cache, without raising errors for any missing items.
@@ -288,6 +301,22 @@ class SQLiteDict(BaseStorage):
self.init_db()
self.vacuum()
+ def count(self, expired: bool = True) -> int:
+ """Count number of responses, optionally excluding expired"""
+ filter_expr = ''
+ params: Tuple = ()
+ if not expired:
+ filter_expr = 'WHERE expires is null or expires > ?'
+ params = (time(),)
+ query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}'
+
+ with self.connection() as con:
+ return con.execute(query, params).fetchone()[0]
+
+ def serialize(self, value: VT):
+ value = super().serialize(value)
+ return sqlite3.Binary(value) if isinstance(value, bytes) else value
+
def size(self) -> int:
"""Return the size of the database, in bytes. For an in-memory database, this will be an
estimate based on page size.
@@ -325,11 +354,19 @@ class SQLiteDict(BaseStorage):
with self.connection(commit=True) as con:
for row in con.execute(
- f'SELECT value FROM {self.table_name} {filter_expr}'
+ f'SELECT key, value FROM {self.table_name} {filter_expr}'
f' ORDER BY {key} {direction} {limit_expr}',
params,
):
- yield self.deserialize(row[0])
+ result = self.deserialize(row[1])
+ # Set cache key, if it's a response object
+ try:
+ result.cache_key = row[0]
+ except AttributeError:
+ pass
+ # Omit any results that can't be deserialized
+ if result:
+ yield result
def vacuum(self):
with self.connection(commit=True) as con: