 .github/workflows/deploy.yml         |  4 ++--
 noxfile.py                           |  2 +-
 tests/integration/base_cache_test.py | 38 ++++++++++++++++++++++++--------------
 3 files changed, 27 insertions(+), 17 deletions(-)
diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml
index 1ff4976..686241b 100644
--- a/.github/workflows/deploy.yml
+++ b/.github/workflows/deploy.yml
@@ -74,8 +74,8 @@ jobs:
         run: |
           source $VENV
           pytest -x ${{ env.XDIST_ARGS }} tests/unit
-          pytest -x ${{ env.XDIST_ARGS }} tests/integration -k 'not multithreaded'
-          STRESS_TEST_MULTIPLIER=5 pytest tests/integration -k 'multithreaded'
+          pytest -x ${{ env.XDIST_ARGS }} tests/integration -k 'not concurrency'
+          STRESS_TEST_MULTIPLIER=5 pytest tests/integration -k 'concurrency'
 
   # Deploy stable builds on tags only, and pre-release builds from manual trigger ("workflow_dispatch")
   release:
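
Context on the workflow change: pytest's -k option filters tests by matching a keyword expression against test names, so renaming test_multithreaded to test_concurrency (in the last file below) requires updating both filter expressions here. A minimal sketch of the equivalent selection via pytest's Python API, using the paths and multiplier value from the lines above:

    import os
    import pytest

    # Fast suite: deselect the stress tests (same as `-k 'not concurrency'` above)
    pytest.main(['-x', 'tests/integration', '-k', 'not concurrency'])

    # Stress suite: run only the concurrency tests with a higher multiplier,
    # as in `STRESS_TEST_MULTIPLIER=5 pytest tests/integration -k 'concurrency'`
    os.environ['STRESS_TEST_MULTIPLIER'] = '5'
    pytest.main(['tests/integration', '-k', 'concurrency'])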
diff --git a/noxfile.py b/noxfile.py
index 7773dcf..3b72f50 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -57,7 +57,7 @@ def coverage(session):
 def stress_test(session):
     """Run concurrency tests with a higher stress test multiplier"""
     multiplier = session.posargs[0] if session.posargs else 5
-    cmd = f'pytest {INTEGRATION_TESTS} -rs -k multithreaded'
+    cmd = f'pytest {INTEGRATION_TESTS} -rs -k concurrency'
     session.run(*cmd.split(' '), env={'STRESS_TEST_MULTIPLIER': str(multiplier)})
 
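
The nox session passes STRESS_TEST_MULTIPLIER to pytest via the environment. A sketch of how the test module might scale its constants from that variable; only the names N_WORKERS, N_ITERATIONS, and N_REQUESTS_PER_ITERATION come from the test code below, while the default of 1 and the scaling factors are assumptions for illustration:

    import os

    # Assumed scaling: a higher multiplier means more workers, iterations,
    # and requests per iteration during a stress test run
    STRESS_TEST_MULTIPLIER = int(os.getenv('STRESS_TEST_MULTIPLIER', '1'))
    N_WORKERS = 2 * STRESS_TEST_MULTIPLIER
    N_ITERATIONS = 4 * STRESS_TEST_MULTIPLIER
    N_REQUESTS_PER_ITERATION = 10 * STRESS_TEST_MULTIPLIER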
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 48f1516..750c31f 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -1,7 +1,8 @@
 """Common tests to run for all backends (BaseCache subclasses)"""
 import json
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from datetime import datetime
+from functools import partial
 from io import BytesIO
 from logging import getLogger
 from random import randint
@@ -305,31 +306,28 @@ class BaseCacheTest:
         assert not session.cache.has_url(httpbin('redirect/1'))
         assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS])
 
+    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
     @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
-    def test_multithreaded(self, iteration):
-        """Run a multi-threaded stress test for each backend.
+    def test_concurrency(self, iteration, executor_class):
+        """Run multithreaded and multiprocess stress tests for each backend.
 
         The number of workers (threads/processes), iterations, and requests per iteration can be
         increased via the `STRESS_TEST_MULTIPLIER` environment variable.
         """
-        session = self.init_session()
         start = time()
         url = httpbin('anything')
-        # Use fewer unique requests/cached responses than total iterations, so we get some cache hits
-        n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
-
-        def send_request(_):
-            i = randint(1, n_unique_responses)
-            return session.get(url, params={f'key_{i}': f'value_{i}'})
-
-        with ThreadPoolExecutor(max_workers=N_WORKERS) as executor:
-            _ = list(executor.map(send_request, range(N_REQUESTS_PER_ITERATION)))
+        session_factory = partial(self.init_session, clear=False)
+        request_func = partial(_send_request, session_factory, url)
+        with executor_class(max_workers=N_WORKERS) as executor:
+            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))
 
+        # Some logging for debug purposes
         elapsed = time() - start
         average = (elapsed * 1000) / (N_ITERATIONS * N_WORKERS)
+        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
         logger.info(
             f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
-            f'{N_WORKERS} threads in {elapsed} s\n'
+            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
             f'Average time per request: {average} ms'
         )
@@ -370,3 +368,15 @@ class BaseCacheTest:
         elif post_type == 'json':
             body = json.loads(response.request.body)
             assert "api_key" not in body
+
+
+def _send_request(session_factory, url, _=None):
+    """Concurrent request function for stress tests. Defined in module scope so it can be serialized
+    to multiple processes.
+    """
+    # Use fewer unique requests/cached responses than total iterations, so we get some cache hits
+    n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
+    i = randint(1, n_unique_responses)
+
+    session = session_factory()
+    return session.get(url, params={f'key_{i}': f'value_{i}'})
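
The reason _send_request moved to module scope: ProcessPoolExecutor pickles the work callable and its arguments to send them to worker processes, and nested functions and lambdas cannot be pickled. Binding per-test state with functools.partial around a module-level function keeps the callable serializable, and the same callable still works under ThreadPoolExecutor. A self-contained sketch of the pattern, where every name is a stand-in rather than code from the diff:

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from functools import partial


    def make_session():
        """Stand-in for init_session(clear=False); module-level, so picklable."""
        return 'session'


    def send_request(session_factory, url, _=None):
        """Module-level work function; a closure here would fail to pickle."""
        session = session_factory()
        return f'{session}: GET {url}'


    if __name__ == '__main__':  # Required for process pools on spawn-based platforms
        request_func = partial(send_request, make_session, 'https://example.com')
        # The same callable runs under both executor types, mirroring the
        # executor_class parametrization in the test above
        for executor_class in (ThreadPoolExecutor, ProcessPoolExecutor):
            with executor_class(max_workers=4) as executor:
                results = list(executor.map(request_func, range(8)))
                assert len(results) == 8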