summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2023-01-01 12:40:47 -0600
committerJordan Cook <jordan.cook.git@proton.me>2023-01-13 14:22:31 -0600
commit9c4f76744a764f8898e5735d838f929c277821c0 (patch)
tree911e03121980c25eafdd8dbd95a82a2308e85514
parent87ccbd2a03297121aee5f86d007f929f49ca7a1a (diff)
downloadrequests-cache-9c4f76744a764f8898e5735d838f929c277821c0.tar.gz
Set default serializers for each backend using param defaults instead of 'default_serializer' class attributes
-rw-r--r--requests_cache/backends/base.py34
-rw-r--r--requests_cache/backends/dynamodb.py13
-rw-r--r--requests_cache/backends/filesystem.py11
-rw-r--r--requests_cache/backends/gridfs.py27
-rw-r--r--requests_cache/backends/mongodb.py13
-rw-r--r--requests_cache/backends/redis.py20
-rw-r--r--requests_cache/backends/sqlite.py24
-rw-r--r--requests_cache/serializers/__init__.py23
-rw-r--r--tests/integration/base_cache_test.py1
-rw-r--r--tests/integration/base_storage_test.py9
-rw-r--r--tests/integration/test_dynamodb.py8
-rw-r--r--tests/integration/test_filesystem.py1
-rw-r--r--tests/integration/test_mongodb.py1
-rw-r--r--tests/integration/test_redis.py1
14 files changed, 102 insertions, 84 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 40651cd..6193483 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -19,7 +19,7 @@ from requests import Request, Response
from ..cache_keys import create_key, redact_response
from ..models import AnyRequest, CachedResponse
from ..policy import DEFAULT_CACHE_NAME, CacheSettings, ExpirationTime
-from ..serializers import SERIALIZERS, SerializerPipeline, SerializerType, pickle_serializer
+from ..serializers import SerializerType, init_serializer
# Specific exceptions that may be raised during deserialization
DESERIALIZE_ERRORS = (AttributeError, ImportError, PickleError, TypeError, ValueError)
@@ -350,38 +350,16 @@ class BaseStorage(MutableMapping[KT, VT], ABC):
for details.
Args:
- serializer: Custom serializer that provides ``loads`` and ``dumps`` methods
- no_serializer: Explicitly disable serialization, and write values as-is; this is to avoid
- ambiguity with ``serializer=None``
- decode_content: Decode JSON or text response body into a human-readable format
+ serializer: Custom serializer that provides ``loads`` and ``dumps`` methods.
+ If not provided, values will be written as-is.
+ decode_content: Decode response body JSON or text into a human-readable format
kwargs: Additional backend-specific keyword arguments
"""
- # Default serializer to use for responses, if one isn't specified; may be overridden by subclass
- default_serializer: SerializerType = pickle_serializer
-
def __init__(
- self,
- serializer: Optional[SerializerType] = None,
- no_serializer: bool = False,
- decode_content: bool = False,
- **kwargs,
+ self, serializer: Optional[SerializerType] = None, decode_content: bool = False, **kwargs
):
- self.serializer: Optional[SerializerPipeline] = None
- if not no_serializer:
- self._init_serializer(serializer or self.default_serializer, decode_content)
-
- def _init_serializer(self, serializer: SerializerType, decode_content: bool):
- # Look up a serializer by name, if needed
- if isinstance(serializer, str):
- serializer = SERIALIZERS[serializer]
-
- # Wrap in a SerializerPipeline, if needed
- if not isinstance(serializer, SerializerPipeline):
- serializer = SerializerPipeline([serializer], name=str(serializer))
- serializer.set_decode_content(decode_content)
-
- self.serializer = serializer
+ self.serializer = init_serializer(serializer, decode_content)
logger.debug(f'Initialized {type(self).__name__} with serializer: {self.serializer}')
def bulk_delete(self, keys: Iterable[KT]):
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index 875c1f7..41021c3 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -12,7 +12,7 @@ from boto3.resources.base import ServiceResource
from botocore.exceptions import ClientError
from .._utils import get_valid_kwargs
-from ..serializers import dynamodb_document_serializer
+from ..serializers import SerializerType, dynamodb_document_serializer
from . import BaseCache, BaseStorage
@@ -35,23 +35,25 @@ class DynamoDbCache(BaseCache):
ttl: bool = True,
connection: Optional[ServiceResource] = None,
decode_content: bool = True,
+ serializer: Optional[SerializerType] = None,
**kwargs,
):
super().__init__(cache_name=table_name, **kwargs)
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
self.responses = DynamoDbDict(
table_name,
namespace='responses',
ttl=ttl,
connection=connection,
decode_content=decode_content,
- **kwargs,
+ **skwargs,
)
self.redirects = DynamoDbDict(
table_name,
namespace='redirects',
ttl=False,
connection=self.responses.connection,
- no_serializer=True,
+ serialzier=None,
**kwargs,
)
@@ -68,17 +70,16 @@ class DynamoDbDict(BaseStorage):
kwargs: Additional keyword arguments for :py:meth:`~boto3.session.Session.resource`
"""
- default_serializer = dynamodb_document_serializer
-
def __init__(
self,
table_name: str,
namespace: str,
ttl: bool = True,
connection: Optional[ServiceResource] = None,
+ serializer: Optional[SerializerType] = dynamodb_document_serializer,
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(serializer=serializer, **kwargs)
connection_kwargs = get_valid_kwargs(
boto3.Session.__init__, kwargs, extras=['endpoint_url']
)
diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py
index b68e2ee..ccda6c8 100644
--- a/requests_cache/backends/filesystem.py
+++ b/requests_cache/backends/filesystem.py
@@ -12,7 +12,7 @@ from shutil import rmtree
from threading import RLock
from typing import Iterator, Optional
-from ..serializers import SERIALIZERS, json_serializer
+from ..serializers import SERIALIZERS, SerializerType, json_serializer
from . import BaseCache, BaseStorage
from .sqlite import AnyPath, SQLiteDict, get_cache_path
@@ -35,11 +35,13 @@ class FileCache(BaseCache):
cache_name: AnyPath = 'http_cache',
use_temp: bool = False,
decode_content: bool = True,
+ serializer: Optional[SerializerType] = None,
**kwargs,
):
super().__init__(cache_name=str(cache_name), **kwargs)
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
self.responses: FileDict = FileDict(
- cache_name, use_temp=use_temp, decode_content=decode_content, **kwargs
+ cache_name, use_temp=use_temp, decode_content=decode_content, **skwargs
)
self.redirects: SQLiteDict = SQLiteDict(
self.cache_dir / 'redirects.sqlite', 'redirects', no_serializer=True, **kwargs
@@ -68,17 +70,16 @@ class FileCache(BaseCache):
class FileDict(BaseStorage):
"""A dictionary-like interface to files on the local filesystem"""
- default_serializer = json_serializer
-
def __init__(
self,
cache_name: AnyPath,
use_temp: bool = False,
use_cache_dir: bool = False,
extension: Optional[str] = None,
+ serializer: Optional[SerializerType] = json_serializer,
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(serializer=serializer, **kwargs)
self.cache_dir = get_cache_path(cache_name, use_cache_dir=use_cache_dir, use_temp=use_temp)
self.extension = _get_extension(extension, self.serializer)
self.is_binary = getattr(self.serializer, 'is_binary', False)
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index 9f04a8a..aadb7e5 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -6,12 +6,14 @@
"""
from logging import getLogger
from threading import RLock
+from typing import Optional
from gridfs import GridFS
from gridfs.errors import CorruptGridFile, FileExists
from pymongo import MongoClient
from .._utils import get_valid_kwargs
+from ..serializers import SerializerType, pickle_serializer
from .base import BaseCache, BaseStorage
from .mongodb import MongoDict
@@ -27,14 +29,22 @@ class GridFSCache(BaseCache):
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- def __init__(self, db_name: str, decode_content: bool = False, **kwargs):
+ def __init__(
+ self,
+ db_name: str,
+ decode_content: bool = False,
+ serializer: Optional[SerializerType] = None,
+ **kwargs
+ ):
super().__init__(cache_name=db_name, **kwargs)
- self.responses = GridFSDict(db_name, decode_content=decode_content, **kwargs)
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
+
+ self.responses = GridFSDict(db_name, decode_content=decode_content, **skwargs)
self.redirects = MongoDict(
db_name,
collection_name='redirects',
connection=self.responses.connection,
- no_serializer=True,
+ serializer=None,
**kwargs
)
@@ -53,8 +63,15 @@ class GridFSDict(BaseStorage):
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- def __init__(self, db_name, collection_name=None, connection=None, **kwargs):
- super().__init__(**kwargs)
+ def __init__(
+ self,
+ db_name,
+ collection_name=None,
+ connection=None,
+ serializer: Optional[SerializerType] = pickle_serializer,
+ **kwargs
+ ):
+ super().__init__(serializer=serializer, **kwargs)
connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs)
self.connection = connection or MongoClient(**connection_kwargs)
self.db = self.connection[db_name]
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index 90f8feb..07671e7 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -13,7 +13,7 @@ from pymongo.errors import OperationFailure
from .._utils import get_valid_kwargs
from ..policy.expiration import NEVER_EXPIRE, get_expiration_seconds
-from ..serializers import bson_document_serializer
+from ..serializers import SerializerType, bson_document_serializer
from . import BaseCache, BaseStorage
logger = getLogger(__name__)
@@ -34,21 +34,23 @@ class MongoCache(BaseCache):
db_name: str = 'http_cache',
connection: MongoClient = None,
decode_content: bool = True,
+ serializer: Optional[SerializerType] = None,
**kwargs,
):
super().__init__(cache_name=db_name, **kwargs)
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
self.responses: MongoDict = MongoDict(
db_name,
collection_name='responses',
connection=connection,
decode_content=decode_content,
- **kwargs,
+ **skwargs,
)
self.redirects: MongoDict = MongoDict(
db_name,
collection_name='redirects',
connection=self.responses.connection,
- no_serializer=True,
+ serialzier=None,
**kwargs,
)
@@ -77,16 +79,15 @@ class MongoDict(BaseStorage):
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- default_serializer = bson_document_serializer
-
def __init__(
self,
db_name: str,
collection_name: str = 'http_cache',
connection: Optional[MongoClient] = None,
+ serializer: Optional[SerializerType] = bson_document_serializer,
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(serializer=serializer, **kwargs)
connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs)
self.connection = connection or MongoClient(**connection_kwargs)
self.collection = self.connection[db_name][collection_name]
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index 95d1ed8..59b2920 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -11,7 +11,7 @@ from redis import Redis, StrictRedis
from .._utils import get_valid_kwargs
from ..cache_keys import decode, encode
-from ..serializers import utf8_encoder
+from ..serializers import SerializerType, pickle_serializer, utf8_encoder
from . import BaseCache, BaseStorage
DEFAULT_TTL_OFFSET = 3600
@@ -33,21 +33,18 @@ class RedisCache(BaseCache):
self,
namespace='http_cache',
connection: Optional[Redis] = None,
+ serializer: Optional[SerializerType] = None,
ttl: bool = True,
ttl_offset: int = DEFAULT_TTL_OFFSET,
**kwargs,
):
super().__init__(cache_name=namespace, **kwargs)
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
self.responses = RedisDict(
- namespace, connection=connection, ttl=ttl, ttl_offset=ttl_offset, **kwargs
+ namespace, connection=connection, ttl=ttl, ttl_offset=ttl_offset, **skwargs
)
- kwargs.pop('serializer', None)
self.redirects = RedisHashDict(
- namespace,
- 'redirects',
- connection=self.responses.connection,
- serializer=utf8_encoder, # Only needs encoding to/decoding from bytes
- **kwargs,
+ namespace, collection_name='redirects', connection=self.responses.connection, **kwargs
)
@@ -64,12 +61,12 @@ class RedisDict(BaseStorage):
namespace: str,
collection_name: Optional[str] = None,
connection=None,
+ serializer: Optional[SerializerType] = pickle_serializer,
ttl: bool = True,
ttl_offset: int = DEFAULT_TTL_OFFSET,
**kwargs,
):
-
- super().__init__(**kwargs)
+ super().__init__(serializer=serializer, **kwargs)
connection_kwargs = get_valid_kwargs(Redis.__init__, kwargs)
self.connection = connection or StrictRedis(**connection_kwargs)
self.namespace = namespace
@@ -149,9 +146,10 @@ class RedisHashDict(BaseStorage):
namespace: str = 'http_cache',
collection_name: Optional[str] = None,
connection=None,
+ serializer: Optional[SerializerType] = utf8_encoder,
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(serializer=serializer, **kwargs)
connection_kwargs = get_valid_kwargs(Redis, kwargs)
self.connection = connection or StrictRedis(**connection_kwargs)
self._hash_key = f'{namespace}-{collection_name}'
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 02d9b70..0a892ec 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -17,12 +17,12 @@ from typing import Collection, Iterator, List, Optional, Tuple, Type, Union
from platformdirs import user_cache_dir
-from requests_cache.backends.base import VT
-from requests_cache.models.response import CachedResponse
-from requests_cache.policy import ExpirationTime
-
from .._utils import chunkify, get_valid_kwargs
+from ..models.response import CachedResponse
+from ..policy import ExpirationTime
+from ..serializers import SerializerType, pickle_serializer
from . import BaseCache, BaseStorage
+from .base import VT
MEMORY_URI = 'file::memory:?cache=shared'
SQLITE_MAX_VARIABLE_NUMBER = 999
@@ -45,11 +45,18 @@ class SQLiteCache(BaseCache):
kwargs: Additional keyword arguments for :py:func:`sqlite3.connect`
"""
- def __init__(self, db_path: AnyPath = 'http_cache', **kwargs):
+ def __init__(
+ self,
+ db_path: AnyPath = 'http_cache',
+ serializer: Optional[SerializerType] = None,
+ **kwargs,
+ ):
super().__init__(cache_name=str(db_path), **kwargs)
- self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **kwargs)
+ # Only override serializer if a non-None value is specified
+ skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
+ self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **skwargs)
self.redirects: SQLiteDict = SQLiteDict(
- db_path, table_name='redirects', no_serializer=True, **kwargs
+ db_path, table_name='redirects', serializer=None, **kwargs
)
@property
@@ -165,13 +172,14 @@ class SQLiteDict(BaseStorage):
db_path,
table_name='http_cache',
fast_save=False,
+ serializer: Optional[SerializerType] = pickle_serializer,
use_cache_dir: bool = False,
use_memory: bool = False,
use_temp: bool = False,
wal: bool = False,
**kwargs,
):
- super().__init__(**kwargs)
+ super().__init__(serializer=serializer, **kwargs)
self._can_commit = True
self._local_context = threading.local()
self._lock = threading.RLock()
diff --git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py
index 085dfc2..c477787 100644
--- a/requests_cache/serializers/__init__.py
+++ b/requests_cache/serializers/__init__.py
@@ -19,7 +19,7 @@ For any optional libraries that aren't installed, the corresponding serializer w
class that raises an ``ImportError`` at initialization time instead of at import time.
"""
# flake8: noqa: F401
-from typing import Union
+from typing import Optional, Union
from .cattrs import CattrStage
from .pipeline import SerializerPipeline, Stage
@@ -39,7 +39,9 @@ __all__ = [
'SERIALIZERS',
'CattrStage',
'SerializerPipeline',
+ 'SerializerType',
'Stage',
+ 'init_serializer',
'bson_serializer',
'bson_document_serializer',
'dynamodb_document_serializer',
@@ -59,3 +61,22 @@ SERIALIZERS = {
}
SerializerType = Union[str, SerializerPipeline, Stage]
+
+
+def init_serializer(
+ serializer: Optional[SerializerType], decode_content: bool
+) -> Optional[SerializerPipeline]:
+ """Intitialze a serializer by name or instance"""
+ if not serializer:
+ return None
+
+ # Look up a serializer by name, if needed
+ if isinstance(serializer, str):
+ serializer = SERIALIZERS[serializer]
+
+ # Wrap in a SerializerPipeline, if needed
+ if not isinstance(serializer, SerializerPipeline):
+ serializer = SerializerPipeline([serializer], name=str(serializer))
+ serializer.set_decode_content(decode_content)
+
+ return serializer
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 385e8fe..cdc3a98 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -48,7 +48,6 @@ class BaseCacheTest:
def init_session(self, cache_name=CACHE_NAME, clear=True, **kwargs) -> CachedSession:
kwargs = {**self.init_kwargs, **kwargs}
kwargs.setdefault('allowable_methods', ALL_METHODS)
- kwargs.setdefault('serializer', 'pickle')
backend = self.backend_class(cache_name, **kwargs)
if clear:
backend.clear()
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index 4070492..a96ffd9 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -19,7 +19,6 @@ class BaseStorageTest:
def init_cache(self, cache_name=CACHE_NAME, index=0, clear=True, **kwargs):
kwargs = {**self.init_kwargs, **kwargs}
- kwargs.setdefault('serializer', 'pickle')
cache = self.storage_class(cache_name, f'table_{index}', **kwargs)
if clear:
cache.clear()
@@ -136,8 +135,8 @@ class BaseStorageTest:
cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None))
for i in range(5):
- cache_1[i] = i
- cache_2[i] = i
+ cache_1[i] = f'value_{i}'
+ cache_2[i] = f'value_{i}'
assert len(cache_1) == len(cache_2) == 5
cache_1.clear()
@@ -147,8 +146,8 @@ class BaseStorageTest:
def test_same_settings(self):
cache_1 = self.init_cache()
cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None))
- cache_1['key_1'] = 1
- cache_2['key_2'] = 2
+ cache_1['key_1'] = 'value_1'
+ cache_2['key_2'] = 'value_2'
assert cache_1 == cache_2
def test_str(self):
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index 561c979..42e22eb 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -15,10 +15,6 @@ AWS_OPTIONS = {
'aws_access_key_id': 'placeholder',
'aws_secret_access_key': 'placeholder',
}
-DYNAMODB_OPTIONS = {
- **AWS_OPTIONS,
- 'serializer': None, # Use class default serializer
-}
@pytest.fixture(scope='module', autouse=True)
@@ -33,7 +29,7 @@ def ensure_connection():
class TestDynamoDbDict(BaseStorageTest):
storage_class = DynamoDbDict
- init_kwargs = DYNAMODB_OPTIONS
+ init_kwargs = AWS_OPTIONS
@patch('requests_cache.backends.dynamodb.boto3.resource')
def test_connection_kwargs(self, mock_resource):
@@ -87,4 +83,4 @@ class TestDynamoDbDict(BaseStorageTest):
class TestDynamoDbCache(BaseCacheTest):
backend_class = DynamoDbCache
- init_kwargs = DYNAMODB_OPTIONS
+ init_kwargs = AWS_OPTIONS
diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py
index dfe9414..99625c4 100644
--- a/tests/integration/test_filesystem.py
+++ b/tests/integration/test_filesystem.py
@@ -39,7 +39,6 @@ class TestFileDict(BaseStorageTest):
rmtree(CACHE_NAME, ignore_errors=True)
def init_cache(self, index=0, clear=True, **kwargs):
- kwargs.setdefault('serializer', 'pickle')
cache = FileDict(f'{CACHE_NAME}_{index}', use_temp=True, **kwargs)
if clear:
cache.clear()
diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py
index 4a3a114..1dfae87 100644
--- a/tests/integration/test_mongodb.py
+++ b/tests/integration/test_mongodb.py
@@ -51,7 +51,6 @@ class TestMongoDict(BaseStorageTest):
class TestMongoCache(BaseCacheTest):
backend_class = MongoCache
- init_kwargs = {'serializer': None} # Use class default serializer instead of pickle
def test_ttl(self):
session = self.init_session()
diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py
index 4c74d68..2a34899 100644
--- a/tests/integration/test_redis.py
+++ b/tests/integration/test_redis.py
@@ -33,6 +33,7 @@ class TestRedisHashDict(TestRedisDict):
storage_class = RedisHashDict
num_instances: int = 10 # Supports multiple instances, since this stores items under hash keys
picklable = True
+ init_kwargs = {'serializer': 'pickle'}
class TestRedisCache(BaseCacheTest):