summaryrefslogtreecommitdiff
path: root/requests_cache/backends/sqlite.py
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-08-29 21:43:16 -0500
committerJordan Cook <jordan.cook@pioneer.com>2021-08-29 22:02:13 -0500
commit3e3ae7d2279eeee44b3a1cb1e989bc12b3a28b70 (patch)
tree214293cf8c428560a4dd83040279144af8506a70 /requests_cache/backends/sqlite.py
parentf707ddeeee5c29881c0de96092a01e78a905a401 (diff)
downloadrequests-cache-3e3ae7d2279eeee44b3a1cb1e989bc12b3a28b70.tar.gz
Use pathlib.Path objects for all file paths in Filesystem and SQLite backends
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r--requests_cache/backends/sqlite.py39
1 files changed, 18 insertions, 21 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 2ae5e40..b6cf157 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -69,8 +69,8 @@ import sqlite3
import threading
from contextlib import contextmanager
from logging import getLogger
-from os import makedirs, unlink
-from os.path import abspath, basename, dirname, expanduser, isabs, isfile, join
+from os import unlink
+from os.path import isfile
from pathlib import Path
from tempfile import gettempdir
from typing import Collection, Iterable, Iterator, List, Tuple, Type, Union
@@ -81,6 +81,7 @@ from . import BaseCache, BaseStorage, get_valid_kwargs
MEMORY_URI = 'file::memory:?cache=shared'
SQLITE_MAX_VARIABLE_NUMBER = 999
+AnyPath = Union[Path, str]
logger = getLogger(__name__)
@@ -98,12 +99,12 @@ class SQLiteCache(BaseCache):
kwargs: Additional keyword arguments for :py:func:`sqlite3.connect`
"""
- def __init__(self, db_path: Union[Path, str] = 'http_cache', **kwargs):
+ def __init__(self, db_path: AnyPath = 'http_cache', **kwargs):
super().__init__(**kwargs)
self.responses = SQLitePickleDict(db_path, table_name='responses', **kwargs)
self.redirects = SQLiteDict(db_path, table_name='redirects', **kwargs)
- def db_path(self) -> str:
+ def db_path(self) -> AnyPath:
return self.responses.db_path
def bulk_delete(self, keys):
@@ -144,9 +145,7 @@ class SQLiteDict(BaseStorage):
):
super().__init__(**kwargs)
self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
- self.db_path = _get_sqlite_cache_path(
- db_path, use_cache_dir=use_cache_dir, use_temp=use_temp, use_memory=use_memory
- )
+ self.db_path = _get_sqlite_cache_path(db_path, use_cache_dir, use_temp, use_memory)
self.fast_save = fast_save
self.table_name = table_name
@@ -286,8 +285,8 @@ def _format_sequence(values: Collection) -> Tuple[str, List]:
def _get_sqlite_cache_path(
- db_path: Union[Path, str], use_cache_dir: bool, use_temp: bool, use_memory: bool = False
-) -> str:
+ db_path: AnyPath, use_cache_dir: bool, use_temp: bool, use_memory: bool = False
+) -> AnyPath:
"""Get a resolved path for a SQLite database file (or memory URI("""
# Use an in-memory database, if specified
db_path = str(db_path)
@@ -297,27 +296,25 @@ def _get_sqlite_cache_path(
return db_path
# Add file extension if not specified
- if '.' not in basename(db_path):
+ if not Path(db_path).suffix:
db_path += '.sqlite'
return get_cache_path(db_path, use_cache_dir, use_temp)
-def get_cache_path(
- db_path: Union[Path, str], use_cache_dir: bool = False, use_temp: bool = False
-) -> str:
+def get_cache_path(db_path: AnyPath, use_cache_dir: bool = False, use_temp: bool = False) -> Path:
"""Get a resolved cache path"""
- db_path = str(db_path)
+ db_path = Path(db_path)
# Save to platform-specific temp or user cache directory, if specified
- if use_cache_dir and not isabs(db_path):
- db_path = join(user_cache_dir(), db_path)
- elif use_temp and not isabs(db_path):
- db_path = join(gettempdir(), db_path)
+ if use_cache_dir and not db_path.is_absolute():
+ db_path = Path(user_cache_dir()) / db_path
+ elif use_temp and not db_path.is_absolute():
+ db_path = Path(gettempdir()) / db_path
# Expand relative and user paths (~/*), and make sure parent dirs exist
- db_path = abspath(expanduser(db_path))
- makedirs(dirname(db_path), exist_ok=True)
- return str(db_path)
+ db_path = db_path.absolute().expanduser()
+ db_path.parent.mkdir(parents=True, exist_ok=True)
+ return db_path
def sqlite_template(