summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2023-02-28 19:05:55 -0600
committerJordan Cook <jordan.cook.git@proton.me>2023-03-01 15:22:18 -0600
commitc83e4e523008337a6f8f638e65ca217177d8ed5c (patch)
tree40a8b5a6c0773afe4e08bec93743a9650f555dd5 /tests
parent0f330b54e15966ff74582cfa7d794f6b844d324c (diff)
downloadrequests-cache-c83e4e523008337a6f8f638e65ca217177d8ed5c.tar.gz
Share SQLite connection objects among threads and use a lock for write operations instead of using thread-local connections
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py2
-rw-r--r--tests/integration/base_cache_test.py30
-rw-r--r--tests/integration/base_storage_test.py14
-rw-r--r--tests/integration/test_sqlite.py2
4 files changed, 37 insertions, 11 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
index c91d6db..c53d866 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,7 +43,7 @@ logger = getLogger(__name__)
# Allow running longer stress tests with an environment variable
STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
-N_WORKERS = 2 * STRESS_TEST_MULTIPLIER
+N_WORKERS = 5 * STRESS_TEST_MULTIPLIER
N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
N_REQUESTS_PER_ITERATION = 10 + 10 * STRESS_TEST_MULTIPLIER
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index cdc3a98..d37bfa7 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -13,7 +13,7 @@ from urllib.parse import parse_qs, urlparse
import pytest
import requests
-from requests import PreparedRequest, Session
+from requests import ConnectionError, PreparedRequest, Session
from requests_cache import ALL_METHODS, CachedResponse, CachedSession
from requests_cache.backends import BaseCache
@@ -36,6 +36,7 @@ from tests.conftest import (
logger = getLogger(__name__)
+
VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
@@ -355,10 +356,14 @@ class BaseCacheTest:
"""
start = time()
url = httpbin('anything')
- self.init_session(clear=True)
- session_factory = partial(self.init_session, clear=False)
- request_func = partial(_send_request, session_factory, url)
+ # For multithreading, we can share a session object, but we can't for multiprocessing
+ session = self.init_session(clear=True, expire_after=1)
+ if executor_class is ProcessPoolExecutor:
+ session = None
+ session_factory = partial(self.init_session, clear=False, expire_after=1)
+
+ request_func = partial(_send_request, session, session_factory, url)
with executor_class(max_workers=N_WORKERS) as executor:
_ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
@@ -373,7 +378,7 @@ class BaseCacheTest:
)
-def _send_request(session_factory, url, _=None):
+def _send_request(session, session_factory, url, _=None):
"""Concurrent request function for stress tests. Defined in module scope so it can be serialized
to multiple processes.
"""
@@ -381,6 +386,15 @@ def _send_request(session_factory, url, _=None):
n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
i = randint(1, n_unique_responses)
- session = session_factory()
- sleep(0.1)
- return session.get(url, params={f'key_{i}': f'value_{i}'})
+ # Threads can share a session object, but processes will create their own session because it
+ # can't be serialized
+ if session is None:
+ session = session_factory()
+
+ sleep(0.01)
+ try:
+ return session.get(url, params={f'key_{i}': f'value_{i}'})
+ # Sometimes the local http server is the bottleneck here; just retry once
+ except ConnectionError:
+ sleep(0.1)
+ return session.get(url, params={f'key_{i}': f'value_{i}'})
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index ce881e8..4870c02 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -1,4 +1,5 @@
"""Common tests to run for all backends (BaseStorage subclasses)"""
+from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, Type
@@ -7,7 +8,7 @@ from attrs import define, field
from requests_cache.backends import BaseStorage
from requests_cache.models import CachedResponse
-from tests.conftest import CACHE_NAME
+from tests.conftest import CACHE_NAME, N_ITERATIONS, N_REQUESTS_PER_ITERATION, N_WORKERS
class BaseStorageTest:
@@ -158,6 +159,17 @@ class BaseStorageTest:
for i in range(10):
assert f'key_{i}' in str(cache)
+ def test_concurrency(self):
+ """Test a large number of concurrent write operations for each backend"""
+ cache = self.init_cache()
+
+ def write(i):
+ cache[f'key_{i}'] = f'value_{i}'
+
+ n_iterations = N_ITERATIONS * N_REQUESTS_PER_ITERATION * 10
+ with ThreadPoolExecutor(max_workers=N_WORKERS * 2) as executor:
+ _ = list(executor.map(write, range(n_iterations)))
+
@define
class BasicDataclass:
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index af94610..04e4b61 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -32,7 +32,7 @@ class TestSQLiteDict(BaseStorageTest):
def test_connection_kwargs(self, mock_sqlite):
"""A spot check to make sure optional connection kwargs gets passed to connection"""
cache = self.storage_class('test', use_temp=True, timeout=0.5, invalid_kwarg='???')
- mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5)
+ mock_sqlite.connect.assert_called_with(cache.db_path, timeout=0.5, check_same_thread=False)
def test_use_cache_dir(self):
relative_path = self.storage_class(CACHE_NAME).db_path