diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2021-04-22 20:49:24 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-04-22 20:55:23 -0500 |
commit | 2698d5ebb81f7639ae1f2e60b80e7d4ab5d6eff4 (patch) | |
tree | 4d323683d00a3d83364ce9d9b1cf460cd3b0b247 | |
parent | ca49461153ee321b757fb171bb9c8f22e7840e75 (diff) | |
parent | 26f41c191c69a3d97e06cfa950b8c55867b7d9eb (diff) | |
download | requests-cache-2698d5ebb81f7639ae1f2e60b80e7d4ab5d6eff4.tar.gz |
Merge branch 'use-temp'
-rw-r--r-- | HISTORY.md | 3 | ||||
-rw-r--r-- | requests_cache/backends/filesystem.py | 3 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 26 | ||||
-rw-r--r-- | tests/integration/test_filesystem.py | 7 | ||||
-rw-r--r-- | tests/integration/test_sqlite.py | 9 |
5 files changed, 39 insertions, 9 deletions
@@ -1,8 +1,9 @@ # History ## 0.7.0 (2021-TBD) -* Add a filesystem cache +* Add a filesystem backend * Add option to manually cache response objects +* Add `use_temp` option to both SQLite and filesystem backends to store files in a temp directory * Use thread-local connections for SQLite backend * Fix `DynamoDbDict.__iter__` to return keys instead of values diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py index c9fc515..bf3d23c 100644 --- a/requests_cache/backends/filesystem.py +++ b/requests_cache/backends/filesystem.py @@ -81,8 +81,11 @@ class FileDict(BaseStorage): def _get_cache_dir(cache_dir: Union[Path, str], use_temp: bool) -> str: + # Save to a temp directory, if specified if use_temp and not isabs(cache_dir): cache_dir = join(gettempdir(), cache_dir, 'responses') + + # Expand relative and user paths (~/*), and make sure parent dirs exist cache_dir = abspath(expanduser(str(cache_dir))) makedirs(cache_dir, exist_ok=True) return cache_dir diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 9c98fcd..29ac083 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -3,8 +3,9 @@ import threading from contextlib import contextmanager from logging import getLogger from os import makedirs -from os.path import abspath, basename, dirname, expanduser +from os.path import abspath, basename, dirname, expanduser, isabs, join from pathlib import Path +from tempfile import gettempdir from typing import Type, Union from . import BaseCache, BaseStorage, get_valid_kwargs @@ -18,15 +19,19 @@ class DbCache(BaseCache): Args: db_path: Database file path (expands user paths and creates parent dirs) + use_temp: Store database in a temp directory (e.g., ``/tmp/http_cache.sqlite``). + Note: if ``db_path`` is an absolute path, this option will be ignored. fast_save: Speedup cache saving up to 50 times but with possibility of data loss. See :py:class:`.DbDict` for more info kwargs: Additional keyword arguments for :py:func:`sqlite3.connect` """ - def __init__(self, db_path: Union[Path, str] = 'http_cache', fast_save: bool = False, **kwargs): + def __init__( + self, db_path: Union[Path, str] = 'http_cache', use_temp: bool = False, fast_save: bool = False, **kwargs + ): super().__init__(**kwargs) - self.responses = DbPickleDict(db_path, table_name='responses', fast_save=fast_save, **kwargs) - self.redirects = DbDict(db_path, table_name='redirects', **kwargs) + self.responses = DbPickleDict(db_path, table_name='responses', use_temp=use_temp, fast_save=fast_save, **kwargs) + self.redirects = DbDict(db_path, table_name='redirects', use_temp=use_temp, **kwargs) def remove_expired_responses(self, *args, **kwargs): """Remove expired responses from the cache, with additional cleanup""" @@ -55,11 +60,11 @@ class DbDict(BaseStorage): timeout: Timeout for acquiring a database lock """ - def __init__(self, db_path, table_name='http_cache', fast_save=False, **kwargs): + def __init__(self, db_path, table_name='http_cache', fast_save=False, use_temp: bool = False, **kwargs): kwargs.setdefault('suppress_warnings', True) super().__init__(**kwargs) self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) - self.db_path = _get_db_path(db_path) + self.db_path = _get_db_path(db_path, use_temp) self.fast_save = fast_save self.table_name = table_name @@ -154,12 +159,17 @@ class DbPickleDict(DbDict): return self.deserialize(super().__getitem__(key)) -def _get_db_path(db_path): +def _get_db_path(db_path: Union[Path, str], use_temp: bool) -> str: """Get resolved path for database file""" - # Allow paths with user directories (~/*), and add file extension if not specified + # Save to a temp directory, if specified + if use_temp and not isabs(db_path): + db_path = join(gettempdir(), db_path) + + # Expand relative and user paths (~/*), and add file extension if not specified db_path = abspath(expanduser(str(db_path))) if '.' not in basename(db_path): db_path += '.sqlite' + # Make sure parent dirs exist makedirs(dirname(db_path), exist_ok=True) return db_path diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py index 7c329a4..b65ea8c 100644 --- a/tests/integration/test_filesystem.py +++ b/tests/integration/test_filesystem.py @@ -1,5 +1,6 @@ from os.path import isfile from shutil import rmtree +from tempfile import gettempdir from requests_cache.backends import FileCache, FileDict from tests.integration.base_cache_test import BaseCacheTest @@ -19,6 +20,12 @@ class TestFileDict(BaseStorageTest): cache.clear() return cache + def test_use_temp(self): + relative_path = self.storage_class(CACHE_NAME).cache_dir + temp_path = self.storage_class(CACHE_NAME, use_temp=True).cache_dir + assert not relative_path.startswith(gettempdir()) + assert temp_path.startswith(gettempdir()) + def test_paths(self): cache = self.storage_class(CACHE_NAME) for i in range(self.num_instances): diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py index fd463ac..6efba0b 100644 --- a/tests/integration/test_sqlite.py +++ b/tests/integration/test_sqlite.py @@ -1,4 +1,5 @@ import os +from tempfile import gettempdir from threading import Thread from unittest.mock import patch @@ -8,6 +9,8 @@ from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest class SQLiteTestCase(BaseStorageTest): + init_kwargs = {'use_temp': True} + @classmethod def teardown_class(cls): try: @@ -15,6 +18,12 @@ class SQLiteTestCase(BaseStorageTest): except Exception: pass + def test_use_temp(self): + relative_path = self.storage_class(CACHE_NAME).db_path + temp_path = self.storage_class(CACHE_NAME, use_temp=True).db_path + assert not relative_path.startswith(gettempdir()) + assert temp_path.startswith(gettempdir()) + def test_bulk_commit(self): cache = self.init_cache() with cache.bulk_commit(): |