diff options
Diffstat (limited to 'requests_cache/backends/sqlite.py')
-rw-r--r-- | requests_cache/backends/sqlite.py | 24 |
1 files changed, 16 insertions, 8 deletions
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 02d9b70..0a892ec 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,12 +17,12 @@ from typing import Collection, Iterator, List, Optional, Tuple, Type, Union from platformdirs import user_cache_dir -from requests_cache.backends.base import VT -from requests_cache.models.response import CachedResponse -from requests_cache.policy import ExpirationTime - from .._utils import chunkify, get_valid_kwargs +from ..models.response import CachedResponse +from ..policy import ExpirationTime +from ..serializers import SerializerType, pickle_serializer from . import BaseCache, BaseStorage +from .base import VT MEMORY_URI = 'file::memory:?cache=shared' SQLITE_MAX_VARIABLE_NUMBER = 999 @@ -45,11 +45,18 @@ class SQLiteCache(BaseCache): kwargs: Additional keyword arguments for :py:func:`sqlite3.connect` """ - def __init__(self, db_path: AnyPath = 'http_cache', **kwargs): + def __init__( + self, + db_path: AnyPath = 'http_cache', + serializer: Optional[SerializerType] = None, + **kwargs, + ): super().__init__(cache_name=str(db_path), **kwargs) - self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **kwargs) + # Only override serializer if a non-None value is specified + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs + self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **skwargs) self.redirects: SQLiteDict = SQLiteDict( - db_path, table_name='redirects', no_serializer=True, **kwargs + db_path, table_name='redirects', serializer=None, **kwargs ) @property @@ -165,13 +172,14 @@ class SQLiteDict(BaseStorage): db_path, table_name='http_cache', fast_save=False, + serializer: Optional[SerializerType] = pickle_serializer, use_cache_dir: bool = False, use_memory: bool = False, use_temp: bool = False, wal: bool = False, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) self._can_commit = True self._local_context = threading.local() self._lock = threading.RLock() |