diff options
-rw-r--r-- | .github/workflows/deploy.yml | 4 | ||||
-rw-r--r-- | noxfile.py | 2 | ||||
-rw-r--r-- | tests/integration/base_cache_test.py | 38 |
3 files changed, 27 insertions, 17 deletions
diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 1ff4976..686241b 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -74,8 +74,8 @@ jobs: run: | source $VENV pytest -x ${{ env.XDIST_ARGS }} tests/unit - pytest -x ${{ env.XDIST_ARGS }} tests/integration -k 'not multithreaded' - STRESS_TEST_MULTIPLIER=5 pytest tests/integration -k 'multithreaded' + pytest -x ${{ env.XDIST_ARGS }} tests/integration -k 'not concurrency' + STRESS_TEST_MULTIPLIER=5 pytest tests/integration -k 'concurrency' # Deploy stable builds on tags only, and pre-release builds from manual trigger ("workflow_dispatch") release: @@ -57,7 +57,7 @@ def coverage(session): def stress_test(session): """Run concurrency tests with a higher stress test multiplier""" multiplier = session.posargs[0] if session.posargs else 5 - cmd = f'pytest {INTEGRATION_TESTS} -rs -k multithreaded' + cmd = f'pytest {INTEGRATION_TESTS} -rs -k concurrency' session.run(*cmd.split(' '), env={'STRESS_TEST_MULTIPLIER': str(multiplier)}) diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py index 48f1516..750c31f 100644 --- a/tests/integration/base_cache_test.py +++ b/tests/integration/base_cache_test.py @@ -1,7 +1,8 @@ """Common tests to run for all backends (BaseCache subclasses)""" import json -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from datetime import datetime +from functools import partial from io import BytesIO from logging import getLogger from random import randint @@ -305,31 +306,28 @@ class BaseCacheTest: assert not session.cache.has_url(httpbin('redirect/1')) assert not any([session.cache.has_url(httpbin(f)) for f in HTTPBIN_FORMATS]) + @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor]) @pytest.mark.parametrize('iteration', range(N_ITERATIONS)) - def test_multithreaded(self, iteration): - """Run a multi-threaded stress test for each backend. + def test_concurrency(self, iteration, executor_class): + """Run multithreaded and multiprocess stress tests for each backend. The number of workers (thread/processes), iterations, and requests per iteration can be increased via the `STRESS_TEST_MULTIPLIER` environment variable. """ - session = self.init_session() start = time() url = httpbin('anything') - # Use fewer unique requests/cached responses than total iterations, so we get some cache hits - n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4) - - def send_request(_): - i = randint(1, n_unique_responses) - return session.get(url, params={f'key_{i}': f'value_{i}'}) - - with ThreadPoolExecutor(max_workers=N_WORKERS) as executor: - _ = list(executor.map(send_request, range(N_REQUESTS_PER_ITERATION))) + session_factory = partial(self.init_session, clear=False) + request_func = partial(_send_request, session_factory, url) + with ProcessPoolExecutor(max_workers=N_WORKERS) as executor: + _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION))) + # Some logging for debug purposes elapsed = time() - start average = (elapsed * 1000) / (N_ITERATIONS * N_WORKERS) + worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes' logger.info( f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with ' - f'{N_WORKERS} threads in {elapsed} s\n' + f'{N_WORKERS} {worker_type} in {elapsed} s\n' f'Average time per request: {average} ms' ) @@ -370,3 +368,15 @@ class BaseCacheTest: elif post_type == 'json': body = json.loads(response.request.body) assert "api_key" not in body + + +def _send_request(session_factory, url, _=None): + """Concurrent request function for stress tests. Defined in module scope so it can be serialized + to multiple processes. + """ + # Use fewer unique requests/cached responses than total iterations, so we get some cache hits + n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4) + i = randint(1, n_unique_responses) + + session = session_factory() + return session.get(url, params={f'key_{i}': f'value_{i}'}) |