From 81929b18b11c72dbf7d95b7358b9cb1e0b395c2c Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Fri, 28 Oct 2022 12:46:19 -0500 Subject: Add SQLite method to count unexpired responses in SQL --- HISTORY.md | 1 + requests_cache/backends/sqlite.py | 34 ++++++++++++++++++++++++++++------ tests/integration/test_sqlite.py | 10 ++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f602e7a..040c0f1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -51,6 +51,7 @@ * Add `ttl_offset` argument to add a delay between cache expiration and deletion * **SQLite**: * Improve performance for removing expired responses with `delete()` + * Add `count()` method to count responses, with option to exclude expired responses (performs a fast indexed count instead of slower in-memory filtering) * Add `size()` method to get estimated size of the database (including in-memory databases) * Add `sorted()` method with sorting and other query options * Add `wal` parameter to enable write-ahead logging diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 87a09e0..907e027 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,6 +17,7 @@ from typing import Collection, Iterator, List, Tuple, Type, Union from platformdirs import user_cache_dir +from requests_cache.backends.base import VT from requests_cache.models.response import CachedResponse from .._utils import chunkify, get_valid_kwargs @@ -109,6 +110,14 @@ class SQLiteCache(BaseCache): ')' ) + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired + + Args: + expired: Set to ``False`` to count only unexpired responses + """ + return self.responses.count(expired=expired) + def filter( # type: ignore self, valid: bool = True, expired: bool = True, **kwargs ) -> Iterator[CachedResponse]: @@ -176,7 +185,7 @@ class SQLiteDict(BaseStorage): with self._lock, self.connection() as con: # Add new column to tables created before 1.0 try: - con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires TEXT') + con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER') except sqlite3.OperationalError: pass @@ -249,10 +258,8 @@ class SQLiteDict(BaseStorage): def __setitem__(self, key, value): # If available, set expiration as a timestamp in unix format - expires = value.expires_unix if getattr(value, 'expires_unix', None) else None + expires = getattr(value, 'expires_unix', None) value = self.serialize(value) - if isinstance(value, bytes): - value = sqlite3.Binary(value) with self.connection(commit=True) as con: con.execute( f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)', @@ -265,8 +272,7 @@ class SQLiteDict(BaseStorage): yield row[0] def __len__(self): - with self.connection() as con: - return con.execute(f'SELECT COUNT(key) FROM {self.table_name}').fetchone()[0] + return self.count() def bulk_delete(self, keys=None, values=None): """Delete multiple items from the cache, without raising errors for any missing items. @@ -290,6 +296,22 @@ class SQLiteDict(BaseStorage): self.init_db() self.vacuum() + def count(self, expired: bool = True) -> int: + """Count number of responses, optionally excluding expired""" + filter_expr = '' + params: Tuple = () + if not expired: + filter_expr = 'WHERE expires is null or expires > ?' + params = (time(),) + query = f'SELECT COUNT(key) FROM {self.table_name} {filter_expr}' + + with self.connection() as con: + return con.execute(query, params).fetchone()[0] + + def serialize(self, value: VT): + value = super().serialize(value) + return sqlite3.Binary(value) if isinstance(value, bytes) else value + def size(self) -> int: """Return the size of the database, in bytes. For an in-memory database, this will be an estimate based on page size. diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index 7e31634..9aeaa3f 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -297,6 +297,16 @@ class TestSQLiteCache(BaseCacheTest): session = self.init_session() assert session.cache.db_path == session.cache.responses.db_path + def test_count(self): + """count() should work the same as len(), but with the option to exclude expired responses""" + session = self.init_session() + now = datetime.utcnow() + session.cache.responses['key_1'] = CachedResponse(expires=now + timedelta(1)) + session.cache.responses['key_2'] = CachedResponse(expires=now - timedelta(1)) + + assert session.cache.count() == 2 + assert session.cache.count(expired=False) == 1 + @patch.object(SQLiteDict, 'sorted') def test_filter__expired_only(self, mock_sorted): """Filtering by expired only should use a more efficient SQL query""" -- cgit v1.2.1