diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2021-08-29 21:43:16 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-08-29 22:02:13 -0500 |
commit | 3e3ae7d2279eeee44b3a1cb1e989bc12b3a28b70 (patch) | |
tree | 214293cf8c428560a4dd83040279144af8506a70 /requests_cache/backends/sqlite.py | |
parent | f707ddeeee5c29881c0de96092a01e78a905a401 (diff) | |
download | requests-cache-3e3ae7d2279eeee44b3a1cb1e989bc12b3a28b70.tar.gz |
Use pathlib.Path objects for all file paths in Filesystem and SQLite backends
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r-- | requests_cache/backends/sqlite.py | 39 |
1 files changed, 18 insertions, 21 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 2ae5e40..b6cf157 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -69,8 +69,8 @@ import sqlite3 import threading from contextlib import contextmanager from logging import getLogger -from os import makedirs, unlink -from os.path import abspath, basename, dirname, expanduser, isabs, isfile, join +from os import unlink +from os.path import isfile from pathlib import Path from tempfile import gettempdir from typing import Collection, Iterable, Iterator, List, Tuple, Type, Union @@ -81,6 +81,7 @@ from . import BaseCache, BaseStorage, get_valid_kwargs MEMORY_URI = 'file::memory:?cache=shared' SQLITE_MAX_VARIABLE_NUMBER = 999 +AnyPath = Union[Path, str] logger = getLogger(__name__) @@ -98,12 +99,12 @@ class SQLiteCache(BaseCache): kwargs: Additional keyword arguments for :py:func:`sqlite3.connect` """ - def __init__(self, db_path: Union[Path, str] = 'http_cache', **kwargs): + def __init__(self, db_path: AnyPath = 'http_cache', **kwargs): super().__init__(**kwargs) self.responses = SQLitePickleDict(db_path, table_name='responses', **kwargs) self.redirects = SQLiteDict(db_path, table_name='redirects', **kwargs) - def db_path(self) -> str: + def db_path(self) -> AnyPath: return self.responses.db_path def bulk_delete(self, keys): @@ -144,9 +145,7 @@ class SQLiteDict(BaseStorage): ): super().__init__(**kwargs) self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs) - self.db_path = _get_sqlite_cache_path( - db_path, use_cache_dir=use_cache_dir, use_temp=use_temp, use_memory=use_memory - ) + self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory) self.fast_save = fast_save self.table_name = table_name @@ -286,8 +285,8 @@ def _format_sequence(values: Collection) -> Tuple[str, List]: def _get_sqlite_cache_path( - db_path: Union[Path, str], use_cache_dir: bool, use_temp: bool, use_memory: bool = False -) -> str: + db_path: AnyPath, use_cache_dir: bool, use_temp: bool, use_memory: bool = False +) -> AnyPath: """Get a resolved path for a SQLite database file (or memory URI(""" # Use an in-memory database, if specified db_path = str(db_path) @@ -297,27 +296,25 @@ def _get_sqlite_cache_path( return db_path # Add file extension if not specified - if '.' not in basename(db_path): + if not Path(db_path).suffix: db_path += '.sqlite' return get_cache_path(db_path, use_cache_dir, use_temp) -def get_cache_path( - db_path: Union[Path, str], use_cache_dir: bool = False, use_temp: bool = False -) -> str: +def get_cache_path(db_path: AnyPath, use_cache_dir: bool = False, use_temp: bool = False) -> Path: """Get a resolved cache path""" - db_path = str(db_path) + db_path = Path(db_path) # Save to platform-specific temp or user cache directory, if specified - if use_cache_dir and not isabs(db_path): - db_path = join(user_cache_dir(), db_path) - elif use_temp and not isabs(db_path): - db_path = join(gettempdir(), db_path) + if use_cache_dir and not db_path.is_absolute(): + db_path = Path(user_cache_dir()) / db_path + elif use_temp and not db_path.is_absolute(): + db_path = Path(gettempdir()) / db_path # Expand relative and user paths (~/*), and make sure parent dirs exist - db_path = abspath(expanduser(db_path)) - makedirs(dirname(db_path), exist_ok=True) - return str(db_path) + db_path = db_path.absolute().expanduser() + db_path.parent.mkdir(parents=True, exist_ok=True) + return db_path def sqlite_template( |