summaryrefslogtreecommitdiff
path: root/requests_cache/backends/gridfs.py
diff options
context:
space:
mode:
Diffstat (limited to 'requests_cache/backends/gridfs.py')
-rw-r--r--requests_cache/backends/gridfs.py27
1 files changed, 22 insertions, 5 deletions
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index 9f04a8a..aadb7e5 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -6,12 +6,14 @@
"""
from logging import getLogger
from threading import RLock
+from typing import Optional
from gridfs import GridFS
from gridfs.errors import CorruptGridFile, FileExists
from pymongo import MongoClient
from .._utils import get_valid_kwargs
+from ..serializers import SerializerType, pickle_serializer
from .base import BaseCache, BaseStorage
from .mongodb import MongoDict
@@ -27,14 +29,22 @@ class GridFSCache(BaseCache):
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- def __init__(self, db_name: str, decode_content: bool = False, **kwargs):
+ def __init__(
+ self,
+ db_name: str,
+ decode_content: bool = False,
+ serializer: Optional[SerializerType] = None,
+ **kwargs
+ ):
super().__init__(cache_name=db_name, **kwargs)
- self.responses = GridFSDict(db_name, decode_content=decode_content, **kwargs)
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
+
+ self.responses = GridFSDict(db_name, decode_content=decode_content, **skwargs)
self.redirects = MongoDict(
db_name,
collection_name='redirects',
connection=self.responses.connection,
- no_serializer=True,
+ serializer=None,
**kwargs
)
@@ -53,8 +63,15 @@ class GridFSDict(BaseStorage):
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- def __init__(self, db_name, collection_name=None, connection=None, **kwargs):
- super().__init__(**kwargs)
+ def __init__(
+ self,
+ db_name,
+ collection_name=None,
+ connection=None,
+ serializer: Optional[SerializerType] = pickle_serializer,
+ **kwargs
+ ):
+ super().__init__(serializer=serializer, **kwargs)
connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs)
self.connection = connection or MongoClient(**connection_kwargs)
self.db = self.connection[db_name]