diff options
Diffstat (limited to 'requests_cache/backends/gridfs.py')
-rw-r--r-- | requests_cache/backends/gridfs.py | 27 |
1 files changed, 22 insertions, 5 deletions
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index 9f04a8a..aadb7e5 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py @@ -6,12 +6,14 @@ """ from logging import getLogger from threading import RLock +from typing import Optional from gridfs import GridFS from gridfs.errors import CorruptGridFile, FileExists from pymongo import MongoClient from .._utils import get_valid_kwargs +from ..serializers import SerializerType, pickle_serializer from .base import BaseCache, BaseStorage from .mongodb import MongoDict @@ -27,14 +29,22 @@ class GridFSCache(BaseCache): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - def __init__(self, db_name: str, decode_content: bool = False, **kwargs): + def __init__( + self, + db_name: str, + decode_content: bool = False, + serializer: Optional[SerializerType] = None, + **kwargs + ): super().__init__(cache_name=db_name, **kwargs) - self.responses = GridFSDict(db_name, decode_content=decode_content, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs + + self.responses = GridFSDict(db_name, decode_content=decode_content, **skwargs) self.redirects = MongoDict( db_name, collection_name='redirects', connection=self.responses.connection, - no_serializer=True, + serializer=None, **kwargs ) @@ -53,8 +63,15 @@ class GridFSDict(BaseStorage): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - def __init__(self, db_name, collection_name=None, connection=None, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + db_name, + collection_name=None, + connection=None, + serializer: Optional[SerializerType] = pickle_serializer, + **kwargs + ): + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs) self.connection = connection or MongoClient(**connection_kwargs) self.db = self.connection[db_name] |