summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2022-10-28 12:46:19 -0500
committerJordan Cook <jordan.cook.git@proton.me>2022-10-28 17:38:28 -0500
commit81929b18b11c72dbf7d95b7358b9cb1e0b395c2c (patch)
tree2fe4e51cffe2135285b35aaed8d3c1688a048c0a
parent03a97079839a5683b202e3ea3f66e8ce54eedab4 (diff)
downloadrequests-cache-81929b18b11c72dbf7d95b7358b9cb1e0b395c2c.tar.gz
Add SQLite method to count unexpired responses in SQL
-rw-r--r--HISTORY.md1
-rw-r--r--requests_cache/backends/sqlite.py34
-rw-r--r--tests/integration/test_sqlite.py10
3 files changed, 39 insertions, 6 deletions
diff --git a/HISTORY.md b/HISTORY.md
index f602e7a..040c0f1 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -51,6 +51,7 @@
* Add `ttl_offset` argument to add a delay between cache expiration and deletion
* **SQLite**:
* Improve performance for removing expired responses with `delete()`
+ * Add `count()` method to count responses, with option to exclude expired responses (performs a fast indexed count instead of slower in-memory filtering)
* Add `size()` method to get estimated size of the database (including in-memory databases)
* Add `sorted()` method with sorting and other query options
* Add `wal` parameter to enable write-ahead logging
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 87a09e0..907e027 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -17,6 +17,7 @@ from typing import Collection, Iterator, List, Tuple, Type, Union
from platformdirs import user_cache_dir
+from requests_cache.backends.base import VT
from requests_cache.models.response import CachedResponse
from .._utils import chunkify, get_valid_kwargs
@@ -109,6 +110,14 @@ class SQLiteCache(BaseCache):
')'
)
+ def count(self, expired: bool = True) -> int:
+ """Count number of responses, optionally excluding expired
+
+ Args:
+ expired: Set to ``False`` to count only unexpired responses
+ """
+ return self.responses.count(expired=expired)
+
def filter( # type: ignore
self, valid: bool = True, expired: bool = True, **kwargs
) -> Iterator[CachedResponse]:
@@ -176,7 +185,7 @@ class SQLiteDict(BaseStorage):
with self._lock, self.connection() as con:
# Add new column to tables created before 1.0
try:
- con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT')
+ con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER')
except sqlite3.OperationalError:
pass
@@ -249,10 +258,8 @@ class SQLiteDict(BaseStorage):
def __setitem__(self, key, value):
# If available, set expiration as a timestamp in unix format
- expires = value.expires_unix if getattr(value, 'expires_unix', None) else None
+ expires = getattr(value, 'expires_unix', None)
value = self.serialize(value)
- if isinstance(value, bytes):
- value = sqlite3.Binary(value)
with self.connection(commit=True) as con:
con.execute(
f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)',
@@ -265,8 +272,7 @@ class SQLiteDict(BaseStorage):
yield row[0]
def __len__(self):
- with self.connection() as con:
- return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0]
+ return self.count()
def bulk_delete(self, keys=None, values=None):
"""Delete multiple items from the cache, without raising errors for any missing items.
@@ -290,6 +296,22 @@ class SQLiteDict(BaseStorage):
self.init_db()
self.vacuum()
+ def count(self, expired: bool = True) -> int:
+ """Count number of responses, optionally excluding expired"""
+ filter_expr = ''
+ params: Tuple = ()
+ if not expired:
+ filter_expr = 'WHERE expires is null or expires > ?'
+ params = (time(),)
+ query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}'
+
+ with self.connection() as con:
+ return con.execute(query, params).fetchone()[0]
+
+ def serialize(self, value: VT):
+ value = super().serialize(value)
+ return sqlite3.Binary(value) if isinstance(value, bytes) else value
+
def size(self) -> int:
"""Return the size of the database, in bytes. For an in-memory database, this will be an
estimate based on page size.
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 7e31634..9aeaa3f 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -297,6 +297,16 @@ class TestSQLiteCache(BaseCacheTest):
session = self.init_session()
assert session.cache.db_path == session.cache.responses.db_path
+ def test_count(self):
+ """count() should work the same as len(), but with the option to exclude expired responses"""
+ session = self.init_session()
+ now = datetime.utcnow()
+ session.cache.responses['key_1'] = CachedResponse(expires=now + timedelta(1))
+ session.cache.responses['key_2'] = CachedResponse(expires=now - timedelta(1))
+
+ assert session.cache.count() == 2
+ assert session.cache.count(expired=False) == 1
+
@patch.object(SQLiteDict, 'sorted')
def test_filter__expired_only(self, mock_sorted):
"""Filtering by expired only should use a more efficient SQL query"""