Diffstat (limited to 'examples/generate_test_db.py')
-rwxr-xr-x [-rw-r--r--]  examples/generate_test_db.py  37
1 file changed, 24 insertions, 13 deletions
diff --git a/examples/generate_test_db.py b/examples/generate_test_db.py
index e3ce60f..6137fff 100644..100755
--- a/examples/generate_test_db.py
+++ b/examples/generate_test_db.py
@@ -3,18 +3,21 @@
This is useful, for example, for reproducing issues that only occur with large caches.
"""
import logging
+from datetime import datetime, timedelta
from os import urandom
from os.path import getsize
from random import random
from time import perf_counter as time
+import requests
from rich.progress import Progress
from requests_cache import ALL_METHODS, CachedSession, CachedResponse
-from requests_cache.response import format_file_size
-from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, httpbin
+from requests_cache.models.response import format_file_size
+from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS
+BASE_RESPONSE = requests.get('https://httpbin.org/get')
CACHE_NAME = 'rubbish_bin.sqlite'
HTTPBIN_EXTRA_ENDPOINTS = [
    'anything',
@@ -41,37 +44,45 @@ class InvalidResponse(CachedResponse):
def populate_cache(progress, task):
    session = CachedSession(CACHE_NAME, backend='sqlite', allowable_methods=ALL_METHODS)
+    n_previous_responses = len(session.cache.responses)
    # Cache a variety of different response formats, which may result in different behavior
-    urls = [('GET', httpbin(endpoint)) for endpoint in HTTPBIN_FORMATS + HTTPBIN_EXTRA_ENDPOINTS]
-    urls += [(method, httpbin(method.lower())) for method in HTTPBIN_METHODS]
+    urls = [
+        ('GET', f'https://httpbin.org/{endpoint}')
+        for endpoint in HTTPBIN_FORMATS + HTTPBIN_EXTRA_ENDPOINTS
+    ]
+    urls += [(method, f'https://httpbin.org/{method.lower()}') for method in HTTPBIN_METHODS]
    for method, url in urls:
        session.request(method, url)
        progress.update(task, advance=1)
    # Cache a large number of responses with randomized response content, which will expire at random times
-    response = session.get(httpbin('get'))
    with session.cache.responses.bulk_commit():
        for i in range(N_RESPONSES):
-            new_response = CachedResponse(response)
-            n_bytes = int(random() * MAX_RESPONSE_SIZE)
-            new_response._content = urandom(n_bytes)
-
-            new_response.request.url += f'/response_{i}'
-            expire_after = random() * MAX_EXPIRE_AFTER
-            session.cache.save_response(new_response, expire_after=expire_after)
+            new_response = get_randomized_response(i + n_previous_responses)
+            expires = datetime.now() + timedelta(seconds=random() * MAX_EXPIRE_AFTER)
+            session.cache.save_response(new_response, expires=expires)
            progress.update(task, advance=1)
    # Add some invalid responses
    with session.cache.responses.bulk_commit():
        for i in range(N_INVALID_RESPONSES):
-            new_response = InvalidResponse(response)
+            new_response = InvalidResponse.from_response(BASE_RESPONSE)
            new_response.request.url += f'/invalid_response_{i}'
            key = session.cache.create_key(new_response.request)
            session.cache.responses[key] = new_response
            progress.update(task, advance=1)
+def get_randomized_response(i=0):
+    """Get a response with randomized content"""
+    new_response = CachedResponse.from_response(BASE_RESPONSE)
+    n_bytes = int(random() * MAX_RESPONSE_SIZE)
+    new_response._content = urandom(n_bytes)
+    new_response.request.url += f'/response_{i}'
+    return new_response
+
+
def remove_expired_responses(expire_after=None):
    logger.setLevel('DEBUG')
    session = CachedSession(CACHE_NAME)
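
Note: the API migration in this diff comes down to two changes: CachedResponse objects are now built with CachedResponse.from_response() instead of by passing a response to the constructor, and save_response() takes an absolute expires datetime rather than an expire_after offset in seconds. A minimal sketch of the new usage, using only calls that appear in the diff above (the URL, cache name, and expiration value are illustrative):

    from datetime import datetime, timedelta

    import requests
    from requests_cache import CachedResponse, CachedSession

    session = CachedSession('rubbish_bin.sqlite', backend='sqlite')

    # Build a cached response from a live response (replaces CachedResponse(response))
    original = requests.get('https://httpbin.org/get')
    cached = CachedResponse.from_response(original)
    cached.request.url += '/response_0'

    # Expiration is passed as an absolute datetime (replaces expire_after=<seconds>)
    expires = datetime.now() + timedelta(seconds=60)
    session.cache.save_response(cached, expires=expires)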