From c83e4e523008337a6f8f638e65ca217177d8ed5c Mon Sep 17 00:00:00 2001
From: Jordan Cook
Date: Tue, 28 Feb 2023 19:05:55 -0600
Subject: Share SQLite connection objects among threads and use lock for write
 operations instead of using thread-local connections

---
 HISTORY.md                             |  2 ++
 requests_cache/backends/sqlite.py      | 46 +++++++++++++++++++---------------
 tests/conftest.py                      |  2 +-
 tests/integration/base_cache_test.py   | 30 ++++++++++++++++------
 tests/integration/base_storage_test.py | 14 ++++++++++-
 tests/integration/test_sqlite.py       |  2 +-
 6 files changed, 65 insertions(+), 31 deletions(-)

diff --git a/HISTORY.md b/HISTORY.md
index 6ffdfbe..4c294d2 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -55,6 +55,7 @@
 * Add `ttl_offset` argument to add a delay between cache expiration and deletion
 * **SQLite**:
   * Improve performance for removing expired responses with `delete()`
+  * Improve performance (slightly) with a large number of threads and a high request rate
   * Add `count()` method to count responses, with option to exclude expired responses (performs a fast indexed count instead of slower in-memory filtering)
   * Add `size()` method to get estimated size of the database (including in-memory databases)
   * Add `sorted()` method with sorting and other query options
@@ -110,6 +111,7 @@
 🪲 **Bugfixes:**
 * Fix usage of memory backend with `install_cache()`
+* Fix an uncommon `OperationalError: database is locked` in SQLite backend
 * Fix issue on Windows with occasional missing `CachedResponse.created_at` timestamp
 * Add `CachedRequest.path_url` property for compatibility with `RequestEncodingMixin`
 * Fix potential `AttributeError` due to undetected imports when requests-cache is bundled in a PyInstaller package
 
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 0a892ec..2b03d4a 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -100,7 +100,7 @@ class SQLiteCache(BaseCache):
 
     def _delete_expired(self):
         """A more efficient implementation deleting expired responses in SQL"""
-        with self.responses._lock, self.responses.connection(commit=True) as con:
+        with self.responses.connection(commit=True) as con:
             con.execute(
                 f'DELETE FROM {self.responses.table_name} WHERE expires <= ?', (round(time()),)
             )
@@ -181,9 +181,10 @@ class SQLiteDict(BaseStorage):
     ):
         super().__init__(serializer=serializer, **kwargs)
         self._can_commit = True
-        self._local_context = threading.local()
+        self._connection: Optional[sqlite3.Connection] = None
         self._lock = threading.RLock()
         self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
+        self.connection_kwargs.setdefault('check_same_thread', False)
         if use_memory:
             self.connection_kwargs['uri'] = True
         self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory)
@@ -195,7 +196,7 @@ class SQLiteDict(BaseStorage):
     def init_db(self):
         """Initialize the database, if it hasn't already been"""
         self.close()
-        with self._lock, self.connection() as con:
+        with self.connection(commit=True) as con:
             # Add new column to tables created before 1.0
             try:
                 con.execute(f'ALTER TABLE {self.table_name} ADD COLUMN expires INTEGER')
@@ -214,26 +215,31 @@
     @contextmanager
     def connection(self, commit=False) -> Iterator[sqlite3.Connection]:
-        """Get a thread-local database connection"""
-        if not getattr(self._local_context, 'con', None):
+        """Get a database connection, shared among all threads"""
+        if not self._connection:
             logger.debug(f'Opening connection to {self.db_path}:{self.table_name}')
-            self._local_context.con = sqlite3.connect(self.db_path, **self.connection_kwargs)
+            self._connection = sqlite3.connect(self.db_path, **self.connection_kwargs)
             if self.fast_save:
-                self._local_context.con.execute('PRAGMA synchronous = 0;')
+                self._connection.execute('PRAGMA synchronous = 0;')
             if self.wal:
-                self._local_context.con.execute('PRAGMA journal_mode = wal')
-        yield self._local_context.con
+                self._connection.execute('PRAGMA journal_mode = wal')
+
+        # Any write operations need to be run in serial
+        if commit and self._can_commit:
+            self._lock.acquire()
+        yield self._connection
         if commit and self._can_commit:
-            self._local_context.con.commit()
+            self._connection.commit()
+            self._lock.release()
 
     def close(self):
         """Close any active connections"""
-        if getattr(self._local_context, 'con', None):
-            self._local_context.con.close()
-            self._local_context.con = None
+        if self._connection:
+            self._connection.close()
+            self._connection = None
 
     @contextmanager
     def bulk_commit(self):
-        """Context manager used to speed up insertion of a large number of records
+        """Insert a large number of records within a single transaction
 
         Example:
 
@@ -243,13 +249,13 @@ class SQLiteDict(BaseStorage):
             ...     d1[i] = i * 2
 
         """
-        self._can_commit = False
-        try:
-            yield
-            if hasattr(self._local_context, 'con'):
-                self._local_context.con.commit()
-        finally:
-            self._can_commit = True
+        with self._lock:
+            self._can_commit = False
+            try:
+                yield
+                self._connection.commit()
+            finally:
+                self._can_commit = True
 
     def __del__(self):
         self.close()
diff --git a/tests/conftest.py b/tests/conftest.py
index c91d6db..c53d866 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,7 +43,7 @@ logger = getLogger(__name__)
 
 # Allow running longer stress tests with an environment variable
 STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
-N_WORKERS = 2 * STRESS_TEST_MULTIPLIER
+N_WORKERS = 5 * STRESS_TEST_MULTIPLIER
 N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
 N_REQUESTS_PER_ITERATION = 10 + 10 * STRESS_TEST_MULTIPLIER
 
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index cdc3a98..d37bfa7 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse
 
 import pytest
 import requests
-from requests import PreparedRequest, Session
+from requests import ConnectionError, PreparedRequest, Session
 
 from requests_cache import ALL_METHODS, CachedResponse, CachedSession
 from requests_cache.backends import BaseCache
@@ -36,6 +36,7 @@ from tests.conftest import (
 )
 
 logger = getLogger(__name__)
+
 
 VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
 
@@ -355,10 +356,14 @@ class BaseCacheTest:
         """
         start = time()
         url = httpbin('anything')
-        self.init_session(clear=True)
 
-        session_factory = partial(self.init_session, clear=False)
-        request_func = partial(_send_request, session_factory, url)
+        # For multithreading, we can share a session object, but we can't for multiprocessing
+        session = self.init_session(clear=True, expire_after=1)
+        if executor_class is ProcessPoolExecutor:
+            session = None
+        session_factory = partial(self.init_session, clear=False, expire_after=1)
+
+        request_func = partial(_send_request, session, session_factory, url)
         with executor_class(max_workers=N_WORKERS) as executor:
             _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
 
@@ -373,7 +378,7 @@ class BaseCacheTest:
     )
 
 
-def _send_request(session_factory, url, _=None):
+def _send_request(session, session_factory, url, _=None):
"""Concurrent request function for stress tests. Defined in module scope so it can be serialized to multiple processes. """ @@ -381,6 +386,15 @@ def _send_request(session_factory, url, _=None): n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4) i = randint(1, n_unique_responses) - session = session_factory() - sleep(0.1) - return session.get(url, params={f'key_{i}': f'value_{i}'}) + # Threads can share a session object, but processes will create their own session because it + # can't be serialized + if session is None: + session = session_factory() + + sleep(0.01) + try: + return session.get(url, params={f'key_{i}': f'value_{i}'}) + # Sometimes the local http server is the bottleneck here; just retry once + except ConnectionError: + sleep(0.1) + return session.get(url, params={f'key_{i}': f'value_{i}'}) diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py index ce881e8..4870c02 100644 --- a/tests/integration/base_storage_test.py +++ b/tests/integration/base_storage_test.py @@ -1,4 +1,5 @@ """Common tests to run for all backends (BaseStorage subclasses)""" +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Dict, Type @@ -7,7 +8,7 @@ from attrs import define, field from requests_cache.backends import BaseStorage from requests_cache.models import CachedResponse -from tests.conftest import CACHE_NAME +from tests.conftest import CACHE_NAME, N_ITERATIONS, N_REQUESTS_PER_ITERATION, N_WORKERS class BaseStorageTest: @@ -158,6 +159,17 @@ class BaseStorageTest: for i in range(10): assert f'key_{i}' in str(cache) + def test_concurrency(self): + """Test a large number of concurrent write operations for each backend""" + cache = self.init_cache() + + def write(i): + cache[f'key_{i}'] = f'value_{i}' + + n_iterations = N_ITERATIONS * N_REQUESTS_PER_ITERATION * 10 + with ThreadPoolExecutor(max_workers=N_WORKERS * 2) as executor: + _ = list(executor.map(write, range(n_iterations))) + @define class BasicDataclass: diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index af94610..04e4b61 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -32,7 +32,7 @@ class TestSQLiteDict(BaseStorageTest): def test_connection_kwargs(self, mock_sqlite): """A spot check to make sure optional connection kwargs gets passed to connection""" cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???') - mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5) + mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5, check_same_thread=False) def test_use_cache_dir(self): relative_path = self.storage_class(CACHE_NAME).db_path -- cgit v1.2.1