From c83e4e523008337a6f8f638e65ca217177d8ed5c Mon Sep 17 00:00:00 2001
From: Jordan Cook
Date: Tue, 28 Feb 2023 19:05:55 -0600
Subject: Share SQLite connection objects among threads and use lock for write
 operations instead of using thread-local connections

---
 tests/conftest.py                      |  2 +-
 tests/integration/base_cache_test.py   | 30 ++++++++++++++++++++++--------
 tests/integration/base_storage_test.py | 14 +++++++++++++-
 tests/integration/test_sqlite.py       |  2 +-
 4 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index c91d6db..c53d866 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,7 +43,7 @@ logger = getLogger(__name__)
 
 # Allow running longer stress tests with an environment variable
 STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
-N_WORKERS = 2 * STRESS_TEST_MULTIPLIER
+N_WORKERS = 5 * STRESS_TEST_MULTIPLIER
 N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
 N_REQUESTS_PER_ITERATION = 10 + 10 * STRESS_TEST_MULTIPLIER
 
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index cdc3a98..d37bfa7 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse
 
 import pytest
 import requests
-from requests import PreparedRequest, Session
+from requests import ConnectionError, PreparedRequest, Session
 
 from requests_cache import ALL_METHODS, CachedResponse, CachedSession
 from requests_cache.backends import BaseCache
@@ -36,6 +36,7 @@ from tests.conftest import (
 
 logger = getLogger(__name__)
 
+
 VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
 
 
@@ -355,10 +356,14 @@ class BaseCacheTest:
         """
         start = time()
         url = httpbin('anything')
-        self.init_session(clear=True)
 
-        session_factory = partial(self.init_session, clear=False)
-        request_func = partial(_send_request, session_factory, url)
+        # For multithreading, we can share a session object, but we can't for multiprocessing
+        session = self.init_session(clear=True, expire_after=1)
+        if executor_class is ProcessPoolExecutor:
+            session = None
+        session_factory = partial(self.init_session, clear=False, expire_after=1)
+
+        request_func = partial(_send_request, session, session_factory, url)
         with executor_class(max_workers=N_WORKERS) as executor:
             _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
 
@@ -373,7 +378,7 @@ class BaseCacheTest:
         )
 
 
-def _send_request(session_factory, url, _=None):
+def _send_request(session, session_factory, url, _=None):
     """Concurrent request function for stress tests. Defined in module scope so it can be
     serialized to multiple processes.
     """
@@ -381,6 +386,15 @@ def _send_request(session_factory, url, _=None):
     n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
     i = randint(1, n_unique_responses)
 
-    session = session_factory()
-    sleep(0.1)
-    return session.get(url, params={f'key_{i}': f'value_{i}'})
+    # Threads can share a session object, but processes will create their own session because it
+    # can't be serialized
+    if session is None:
+        session = session_factory()
+
+    sleep(0.01)
+    try:
+        return session.get(url, params={f'key_{i}': f'value_{i}'})
+    # Sometimes the local http server is the bottleneck here; just retry once
+    except ConnectionError:
+        sleep(0.1)
+        return session.get(url, params={f'key_{i}': f'value_{i}'})
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index ce881e8..4870c02 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -1,4 +1,5 @@
 """Common tests to run for all backends (BaseStorage subclasses)"""
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import Dict, Type
 
@@ -7,7 +8,7 @@ from attrs import define, field
 
 from requests_cache.backends import BaseStorage
 from requests_cache.models import CachedResponse
-from tests.conftest import CACHE_NAME
+from tests.conftest import CACHE_NAME, N_ITERATIONS, N_REQUESTS_PER_ITERATION, N_WORKERS
 
 
 class BaseStorageTest:
@@ -158,6 +159,17 @@ class BaseStorageTest:
         for i in range(10):
             assert f'key_{i}' in str(cache)
 
+    def test_concurrency(self):
+        """Test a large number of concurrent write operations for each backend"""
+        cache = self.init_cache()
+
+        def write(i):
+            cache[f'key_{i}'] = f'value_{i}'
+
+        n_iterations = N_ITERATIONS * N_REQUESTS_PER_ITERATION * 10
+        with ThreadPoolExecutor(max_workers=N_WORKERS * 2) as executor:
+            _ = list(executor.map(write, range(n_iterations)))
+
 
 @define
 class BasicDataclass:
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index af94610..04e4b61 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -32,7 +32,7 @@ class TestSQLiteDict(BaseStorageTest):
     def test_connection_kwargs(self, mock_sqlite):
         """A spot check to make sure optional connection kwargs gets passed to connection"""
         cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???')
-        mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5)
+        mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5, check_same_thread=False)
 
     def test_use_cache_dir(self):
         relative_path = self.storage_class(CACHE_NAME).db_path