summaryrefslogtreecommitdiff
path: root/tests/integration/base_cache_test.py
blob: 44d45eb7133a200e18a8f5ba4bd789109bfb475a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
"""Common tests to run for all backends (BaseCache subclasses)"""
import json
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import datetime
from functools import partial
from io import BytesIO
from logging import getLogger
from random import randint
from time import sleep, time
from typing import Dict, Type
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, urlparse

import pytest
import requests
from requests import ConnectionError, PreparedRequest, Session

from requests_cache import ALL_METHODS, CachedResponse, CachedSession
from requests_cache.backends import BaseCache
from tests.conftest import (
    CACHE_NAME,
    ETAG,
    EXPIRED_DT,
    HTTPBIN_FORMATS,
    HTTPBIN_METHODS,
    HTTPDATE_STR,
    LAST_MODIFIED,
    MOCKED_URL_JSON,
    MOCKED_URL_JSON_LIST,
    N_ITERATIONS,
    N_REQUESTS_PER_ITERATION,
    N_WORKERS,
    USE_PYTEST_HTTPBIN,
    assert_delta_approx_equal,
    httpbin,
    mount_mock_adapter,
    skip_pypy,
)

# Module-level logger; used by the concurrency stress test to report timing info
logger = getLogger(__name__)


# Response validator headers (one dict per validator type), used to parametrize the
# conditional request tests below
VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]


class BaseCacheTest:
    """Base class for testing cache backend classes.

    The tests defined here are backend-agnostic: ``init_session()`` builds the session under test
    from ``backend_class`` and ``init_kwargs``, so a backend-specific test class only needs to
    override those two attributes.
    """

    # Backend class under test; instantiated as ``backend_class(cache_name, **kwargs)``
    backend_class: Type[BaseCache] = None
    # Default keyword arguments merged into every ``init_session()`` call; kwargs passed
    # explicitly to ``init_session()`` take precedence over these
    init_kwargs: Dict = {}

    def init_session(self, cache_name=CACHE_NAME, clear=True, **kwargs) -> CachedSession:
        """Create a session that uses a new instance of the backend under test.

        Args:
            cache_name: Name passed as the first argument to ``backend_class``
            clear: Whether to clear the backend before the session is created
            kwargs: Extra keyword arguments; these override ``init_kwargs`` and are passed to
                both the backend and the session. ``allowable_methods`` defaults to
                ``ALL_METHODS`` if not given.

        Returns:
            A session using the newly created backend
        """
        session_kwargs = dict(self.init_kwargs)
        session_kwargs.update(kwargs)
        if 'allowable_methods' not in session_kwargs:
            session_kwargs['allowable_methods'] = ALL_METHODS

        cache = self.backend_class(cache_name, **session_kwargs)
        if clear:
            cache.clear()
        return CachedSession(backend=cache, **session_kwargs)

    @classmethod
    def teardown_class(cls):
        """Clear the cache one last time and close the session used to do so"""
        cleanup_session = cls().init_session(clear=True)
        cleanup_session.close()

    @pytest.mark.parametrize('method', HTTPBIN_METHODS)
    @pytest.mark.parametrize('field', ['params', 'data', 'json'])
    def test_all_methods(self, field, method, serializer=None):
        """Test all relevant combinations of methods X data fields.
        Each distinct set of request params, data, or json should get its own cache key, so the
        first request for each payload is a cache miss and the repeat is a cache hit.

        Note: Serializer combinations are only tested for Filesystem backend.
        """
        url = httpbin(method.lower())
        session = self.init_session(serializer=serializer)
        payloads = ({'param_1': 1}, {'param_1': 2}, {'param_2': 2})

        for payload in payloads:
            request_kwargs = {field: payload}
            first = session.request(method, url, **request_kwargs)
            assert first.from_cache is False
            repeat = session.request(method, url, **request_kwargs)
            assert repeat.from_cache is True

    @pytest.mark.parametrize('response_format', HTTPBIN_FORMATS)
    def test_all_response_formats(self, response_format, serializer=None):
        """Test all relevant combinations of response formats X serializers"""
        session = self.init_session(serializer=serializer)
        # Workaround for this issue: https://github.com/kevin1024/pytest-httpbin/issues/60
        if USE_PYTEST_HTTPBIN and response_format == 'json':
            session.settings.allowable_codes = (200, 404)

        url = httpbin(response_format)
        fresh = session.get(url)
        cached = session.get(url)
        assert fresh.from_cache is False
        assert cached.from_cache is True

        # JSON bodies are compared after parsing, since variations like whitespace won't be
        # preserved; everything else is compared byte-for-byte
        if not fresh.text.startswith('{'):
            assert fresh.content == cached.content
        else:
            assert fresh.json() == cached.json()

    def test_response_no_duplicate_read(self):
        """Ensure that response data is read only once per request, whether it's cached or not"""
        session = self.init_session()
        responses_type = type(session.cache.responses)
        url = httpbin('get')

        def empty_response(key):
            return CachedResponse()

        # Replace __getitem__ on the storage class with a mock that returns an empty
        # CachedResponse, so the number of reads from storage can be counted
        with patch.object(responses_type, '__getitem__', side_effect=empty_response) as getitem:
            for expected_reads in (1, 2):
                session.get(url)
                assert getitem.call_count == expected_reads

    @pytest.mark.parametrize('n_redirects', range(1, 5))
    @pytest.mark.parametrize('endpoint', ['redirect', 'absolute-redirect', 'relative-redirect'])
    def test_redirect_history(self, endpoint, n_redirects):
        """Test redirect caching (in separate `redirects` cache) for every redirect endpoint type
        and for different numbers of consecutive redirects
        """
        session = self.init_session()
        redirect_url = httpbin(f'{endpoint}/{n_redirects}')
        session.get(redirect_url)

        # After the redirect chain has been followed, a direct request to /get is a cache hit
        final_response = session.get(httpbin('get'))
        assert final_response.from_cache is True
        assert len(session.cache.redirects) == n_redirects

    @pytest.mark.parametrize('endpoint', ['redirect', 'absolute-redirect', 'relative-redirect'])
    def test_redirect_responses(self, endpoint):
        """Test redirect caching (in main `responses` cache) with all types of redirect endpoints"""
        session = self.init_session(allowable_codes=(200, 302))
        url = httpbin(f'{endpoint}/2')
        fresh = session.head(url)
        cached = session.head(url)

        assert cached.from_cache is True
        assert len(session.cache.redirects) == 0
        # Both the fresh and the cached response should point at the next hop of the redirect
        for response in (fresh, cached):
            assert isinstance(response.next, PreparedRequest)
            assert response.next.url.endswith('redirect/1')

    def test_cookies(self):
        """Test that cookies reported by the server are consistent across cached and uncached
        requests
        """
        session = self.init_session()

        def fetch_json(url):
            response = session.get(url)
            return json.loads(response.text)

        initial_cookies = fetch_json(httpbin('cookies/set/test1/test2'))
        with session.cache_disabled():
            assert fetch_json(httpbin('cookies')) == initial_cookies

        # From cache
        cached_cookies = fetch_json(httpbin('cookies'))
        assert cached_cookies == fetch_json(httpbin('cookies'))

        # Not from cache
        with session.cache_disabled():
            updated_cookies = fetch_json(httpbin('cookies/set/test3/test4'))
            assert updated_cookies == fetch_json(httpbin('cookies'))

    @pytest.mark.parametrize(
        'cache_control, request_headers, expected_expiration',
        [
            (True, {}, 60),
            (True, {'Cache-Control': 'max-age=360'}, 60),
            (True, {'Cache-Control': 'no-store'}, None),
            (True, {'Expires': HTTPDATE_STR, 'Cache-Control': 'max-age=360'}, 60),
            (False, {}, None),
            (False, {'Cache-Control': 'max-age=360'}, 360),
            (False, {'Cache-Control': 'no-store'}, None),
            (False, {'Expires': HTTPDATE_STR, 'Cache-Control': 'max-age=360'}, 360),
        ],
    )
    def test_cache_control_expiration(self, cache_control, request_headers, expected_expiration):
        """Test cache headers for both requests and responses. The `/cache/{seconds}` endpoint returns
        Cache-Control headers, which should be used unless request headers are sent.
        No headers should be used if `cache_control=False`.

        ``expected_expiration`` is the expected time-to-live in seconds, or ``None`` if the
        response is expected to have no expiration.
        """
        session = self.init_session(cache_control=cache_control)
        url = httpbin('cache/60')

        # Send the same request twice, and check expiration on the second response
        session.get(url, headers=request_headers)
        response = session.get(url, headers=request_headers)

        if expected_expiration is not None:
            assert_delta_approx_equal(datetime.utcnow(), response.expires, expected_expiration)
        else:
            assert response.expires is None

    @pytest.mark.parametrize(
        'cached_response_headers, expected_from_cache',
        [
            ({}, False),
            ({'ETag': ETAG}, True),
            ({'Last-Modified': LAST_MODIFIED}, True),
            ({'ETag': ETAG, 'Last-Modified': LAST_MODIFIED}, True),
        ],
    )
    def test_conditional_request(self, cached_response_headers, expected_from_cache):
        """Test behavior of ETag and Last-Modified headers and 304 responses.

        If a cached response has either of these headers, the matching conditional request
        headers should be sent. The `/cache` endpoint answers such requests with a 304, in which
        case the previously cached response should be returned.
        """
        url = httpbin('cache')

        # Fetch a real response without caching, then swap in the headers under test
        seeded_response = requests.get(url)
        seeded_response.headers = cached_response_headers

        # Save it as an already-expired cache entry
        session = self.init_session(cache_control=True)
        session.cache.save_response(seeded_response, expires=EXPIRED_DT)

        result = session.get(url)
        assert result.from_cache == expected_from_cache

    @pytest.mark.parametrize('validator_headers', VALIDATOR_HEADERS)
    @pytest.mark.parametrize(
        'cache_headers',
        [
            {'Cache-Control': 'max-age=0'},
            {'Cache-Control': 'max-age=0,must-revalidate'},
            {'Cache-Control': 'no-cache'},
            {'Expires': '0'},
        ],
    )
    def test_conditional_request__response_headers(self, cache_headers, validator_headers):
        """Test response headers that can initiate revalidation before a cached response expires"""
        url = httpbin('response-headers')
        response_headers = dict(cache_headers)
        response_headers.update(validator_headers)
        session = self.init_session(cache_control=True)

        # This endpoint returns request params as response headers
        session.get(url, params=response_headers)

        # It doesn't respond to conditional requests, but let's just pretend it does
        not_modified = MagicMock(status_code=304)
        with patch.object(Session, 'send', return_value=not_modified):
            revalidated = session.get(url, params=response_headers)

        assert revalidated.from_cache is True

    @pytest.mark.parametrize('validator_headers', VALIDATOR_HEADERS)
    @pytest.mark.parametrize('cache_headers', [{'Cache-Control': 'max-age=0'}])
    def test_conditional_request__refreshes_expire_date(self, cache_headers, validator_headers):
        """Test that a revalidation attempt answered with a 304 makes a stale entry fresh again,
        taking the Cache-Control header of the 304 response into account.
        """
        url = httpbin('response-headers')
        initial_headers = dict(cache_headers)
        initial_headers.update(validator_headers)
        session = self.init_session(cache_control=True)

        # This endpoint returns request params as response headers
        session.get(url, params=initial_headers)

        # Mock a 304 response from session.send() that carries a different Cache-Control header
        refreshed_headers = dict(initial_headers)
        refreshed_headers['Cache-Control'] = 'max-age=60'
        not_modified = MagicMock(status_code=304, headers=refreshed_headers)
        with patch.object(Session, 'send', return_value=not_modified):
            revalidated = session.get(url, params=initial_headers)
        assert revalidated.from_cache is True
        assert revalidated.is_expired is False

        # An immediate follow-up request must be served from the cache without sending anything;
        # any call to send() would raise
        with patch.object(Session, 'send', side_effect=AssertionError):
            followup = session.get(url, params=initial_headers)
        assert followup.from_cache is True
        assert followup.is_expired is False

    def test_decode_gzip_response(self):
        """Test that gzip-compressed responses read decompressed content with decode_content=True"""
        marker = b'gzipped'
        session = self.init_session()
        original = session.get(httpbin('gzip'))
        assert marker in original.content

        # Replace the raw stream's underlying file object with an in-memory copy of the content
        original.raw._fp = BytesIO(original.content)

        restored = CachedResponse.from_response(original)
        assert marker in restored.content
        raw_bytes = restored.raw.read(amt=None, decode_content=True)
        assert marker in raw_bytes

    @pytest.mark.parametrize('decode_content', [True, False])
    @pytest.mark.parametrize('url', [MOCKED_URL_JSON, MOCKED_URL_JSON_LIST])
    def test_decode_json_response(self, decode_content, url):
        """Test that JSON responses (with either a dict or a list at the root) come back from the
        cache with the same parsed content, regardless of the `decode_content` setting
        """
        base_session = self.init_session(decode_content=decode_content)
        session = mount_mock_adapter(base_session)

        first = session.get(url)
        second = session.get(url)
        assert first.json() == second.json()

    def test_multipart_upload(self):
        """Test that repeated multipart file uploads with identical content are cache hits"""
        session = self.init_session()
        url = httpbin('post')
        file_content = b'10' * 1024

        def upload():
            # A new file object is created for every request
            return session.post(url, files={'file1': BytesIO(file_content)})

        upload()
        for _ in range(5):
            repeat = upload()
            assert repeat.from_cache

    def test_delete__expired(self):
        """Test that ``delete(expired=True)`` removes only expired responses and their redirects"""
        session = self.init_session(expire_after=1)

        # Populate the cache with several responses that expire after one second, then wait
        for path in [*HTTPBIN_FORMATS, 'redirect/1']:
            session.get(httpbin(path))
        sleep(1)

        # Cache a response and some redirects, which should be the only non-expired cache items
        for path in ('get', 'redirect/3'):
            session.get(httpbin(path), expire_after=-1)
        cache = session.cache
        assert len(cache.redirects.keys()) == 4
        cache.delete(expired=True)

        assert len(cache.responses.keys()) == 2
        assert len(cache.redirects.keys()) == 3
        assert not cache.contains(url=httpbin('redirect/1'))
        for response_format in HTTPBIN_FORMATS:
            assert not cache.contains(url=httpbin(response_format))

    @pytest.mark.parametrize('method', HTTPBIN_METHODS)
    def test_filter_request_headers(self, method):
        """Test that an ignored request header doesn't prevent a cache hit, and is redacted in
        the cached request
        """
        url = httpbin(method.lower())
        session = self.init_session(ignored_parameters=['Authorization'])

        # First request is a cache miss; the identical second request is a cache hit
        for expect_cached in (False, True):
            response = session.request(method, url, headers={"Authorization": "<Secret Key>"})
            assert response.from_cache is expect_cached

        assert response.request.headers.get('Authorization') == 'REDACTED'

    @pytest.mark.parametrize('method', HTTPBIN_METHODS)
    def test_filter_request_query_parameters(self, method):
        """Test that an ignored query parameter doesn't prevent a cache hit, and is redacted in
        the cached request URL
        """
        url = httpbin(method.lower())
        session = self.init_session(ignored_parameters=['api_key'])

        # First request is a cache miss; the identical second request is a cache hit
        for expect_cached in (False, True):
            response = session.request(method, url, params={"api_key": "<Secret Key>"})
            assert response.from_cache is expect_cached

        cached_query = parse_qs(urlparse(response.request.url).query)
        assert cached_query['api_key'] == ['REDACTED']

    @skip_pypy
    @pytest.mark.parametrize('post_type', ['data', 'json'])
    def test_filter_request_post_data(self, post_type):
        """Test that an ignored parameter in a JSON request body is redacted in the cached
        request, whether the body is sent via ``data`` (pre-serialized) or ``json``
        """
        url = httpbin('post')
        secret_payload = {"api_key": "<Secret Key>"}
        if post_type == 'data':
            # Serialize the body manually and declare its content type
            request_kwargs = {
                'headers': {'Content-Type': 'application/json'},
                post_type: json.dumps(secret_payload),
            }
        else:
            request_kwargs = {'headers': {}, post_type: secret_payload}
        session = self.init_session(ignored_parameters=['api_key'])

        session.request('POST', url, **request_kwargs)
        response = session.request('POST', url, **request_kwargs)
        assert response.from_cache is True

        cached_body = json.loads(response.request.body)
        assert cached_body['api_key'] == 'REDACTED'

    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
    def test_concurrency(self, iteration, executor_class):
        """Run multithreaded and multiprocess stress tests for each backend.
        The number of workers (thread/processes), iterations, and requests per iteration can be
        increased via the `STRESS_TEST_MULTIPLIER` environment variable.

        Args:
            iteration: Unused; parametrized only to repeat the test ``N_ITERATIONS`` times
            executor_class: Executor type used to send the requests concurrently
        """
        start = time()
        url = httpbin('anything')

        # For multithreading, we can share a session object, but we can't for multiprocessing.
        # The session is created in both cases, which clears the cache before the run.
        session = self.init_session(clear=True, expire_after=1)
        if executor_class is ProcessPoolExecutor:
            session = None
        session_factory = partial(self.init_session, clear=False, expire_after=1)

        request_func = partial(_send_request, session, session_factory, url)
        with executor_class(max_workers=N_WORKERS) as executor:
            # Consume the result iterator so that any exception raised in a worker is raised here
            _ = list(executor.map(request_func, range(N_REQUESTS_PER_ITERATION)))

        # Some logging for debug purposes. `elapsed` covers a single iteration, so the average
        # is taken over the number of requests sent in this iteration.
        elapsed = time() - start
        average = (elapsed * 1000) / max(N_REQUESTS_PER_ITERATION, 1)
        worker_type = 'threads' if executor_class is ThreadPoolExecutor else 'processes'
        logger.info(
            f'{self.backend_class.__name__}: Ran {N_REQUESTS_PER_ITERATION} requests with '
            f'{N_WORKERS} {worker_type} in {elapsed} s\n'
            f'Average time per request: {average} ms'
        )


def _send_request(session, session_factory, url, _=None):
    """Concurrent request function for stress tests. Defined in module scope so it can be serialized
    to multiple processes.
    """
    # Use fewer unique requests/cached responses than total iterations, so we get some cache hits
    n_unique_responses = int(N_REQUESTS_PER_ITERATION / 4)
    i = randint(1, n_unique_responses)

    # Threads can share a session object, but processes will create their own session because it
    # can't be serialized
    if session is None:
        session = session_factory()

    sleep(0.01)
    try:
        return session.get(url, params={f'key_{i}': f'value_{i}'})
    # Sometimes the local http server is the bottleneck here; just retry once
    except ConnectionError:
        sleep(0.1)
        return session.get(url, params={f'key_{i}': f'value_{i}'})