diff options
author | Jordan Cook <jordan.cook.git@proton.me> | 2023-01-01 12:40:47 -0600 |
---|---|---|
committer | Jordan Cook <jordan.cook.git@proton.me> | 2023-01-13 14:22:31 -0600 |
commit | 9c4f76744a764f8898e5735d838f929c277821c0 (patch) | |
tree | 911e03121980c25eafdd8dbd95a82a2308e85514 | |
parent | 87ccbd2a03297121aee5f86d007f929f49ca7a1a (diff) | |
download | requests-cache-9c4f76744a764f8898e5735d838f929c277821c0.tar.gz |
Set default serializers for each backend using param defaults instead of 'default_serializer' class attributes
-rw-r--r-- | requests_cache/backends/base.py | 34 | ||||
-rw-r--r-- | requests_cache/backends/dynamodb.py | 13 | ||||
-rw-r--r-- | requests_cache/backends/filesystem.py | 11 | ||||
-rw-r--r-- | requests_cache/backends/gridfs.py | 27 | ||||
-rw-r--r-- | requests_cache/backends/mongodb.py | 13 | ||||
-rw-r--r-- | requests_cache/backends/redis.py | 20 | ||||
-rw-r--r-- | requests_cache/backends/sqlite.py | 24 | ||||
-rw-r--r-- | requests_cache/serializers/__init__.py | 23 | ||||
-rw-r--r-- | tests/integration/base_cache_test.py | 1 | ||||
-rw-r--r-- | tests/integration/base_storage_test.py | 9 | ||||
-rw-r--r-- | tests/integration/test_dynamodb.py | 8 | ||||
-rw-r--r-- | tests/integration/test_filesystem.py | 1 | ||||
-rw-r--r-- | tests/integration/test_mongodb.py | 1 | ||||
-rw-r--r-- | tests/integration/test_redis.py | 1 |
14 files changed, 102 insertions, 84 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 40651cd..6193483 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -19,7 +19,7 @@ from requests import Request, Response from ..cache_keys import create_key, redact_response from ..models import AnyRequest, CachedResponse from ..policy import DEFAULT_CACHE_NAME, CacheSettings, ExpirationTime -from ..serializers import SERIALIZERS, SerializerPipeline, SerializerType, pickle_serializer +from ..serializers import SerializerType, init_serializer # Specific exceptions that may be raised during deserialization DESERIALIZE_ERRORS = (AttributeError, ImportError, PickleError, TypeError, ValueError) @@ -350,38 +350,16 @@ class BaseStorage(MutableMapping[KT, VT], ABC): for details. Args: - serializer: Custom serializer that provides ``loads`` and ``dumps`` methods - no_serializer: Explicitly disable serialization, and write values as-is; this is to avoid - ambiguity with ``serializer=None`` - decode_content: Decode JSON or text response body into a human-readable format + serializer: Custom serializer that provides ``loads`` and ``dumps`` methods. + If not provided, values will be written as-is. 
+ decode_content: Decode response body JSON or text into a human-readable format kwargs: Additional backend-specific keyword arguments """ - # Default serializer to use for responses, if one isn't specified; may be overridden by subclass - default_serializer: SerializerType = pickle_serializer - def __init__( - self, - serializer: Optional[SerializerType] = None, - no_serializer: bool = False, - decode_content: bool = False, - **kwargs, + self, serializer: Optional[SerializerType] = None, decode_content: bool = False, **kwargs ): - self.serializer: Optional[SerializerPipeline] = None - if not no_serializer: - self._init_serializer(serializer or self.default_serializer, decode_content) - - def _init_serializer(self, serializer: SerializerType, decode_content: bool): - # Look up a serializer by name, if needed - if isinstance(serializer, str): - serializer = SERIALIZERS[serializer] - - # Wrap in a SerializerPipeline, if needed - if not isinstance(serializer, SerializerPipeline): - serializer = SerializerPipeline([serializer], name=str(serializer)) - serializer.set_decode_content(decode_content) - - self.serializer = serializer + self.serializer = init_serializer(serializer, decode_content) logger.debug(f'Initialized {type(self).__name__} with serializer: {self.serializer}') def bulk_delete(self, keys: Iterable[KT]): diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py index 875c1f7..41021c3 100644 --- a/requests_cache/backends/dynamodb.py +++ b/requests_cache/backends/dynamodb.py @@ -12,7 +12,7 @@ from boto3.resources.base import ServiceResource from botocore.exceptions import ClientError from .._utils import get_valid_kwargs -from ..serializers import dynamodb_document_serializer +from ..serializers import SerializerType, dynamodb_document_serializer from . 
import BaseCache, BaseStorage @@ -35,23 +35,25 @@ class DynamoDbCache(BaseCache): ttl: bool = True, connection: Optional[ServiceResource] = None, decode_content: bool = True, + serializer: Optional[SerializerType] = None, **kwargs, ): super().__init__(cache_name=table_name, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses = DynamoDbDict( table_name, namespace='responses', ttl=ttl, connection=connection, decode_content=decode_content, - **kwargs, + **skwargs, ) self.redirects = DynamoDbDict( table_name, namespace='redirects', ttl=False, connection=self.responses.connection, - no_serializer=True, + serialzier=None, **kwargs, ) @@ -68,17 +70,16 @@ class DynamoDbDict(BaseStorage): kwargs: Additional keyword arguments for :py:meth:`~boto3.session.Session.resource` """ - default_serializer = dynamodb_document_serializer - def __init__( self, table_name: str, namespace: str, ttl: bool = True, connection: Optional[ServiceResource] = None, + serializer: Optional[SerializerType] = dynamodb_document_serializer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs( boto3.Session.__init__, kwargs, extras=['endpoint_url'] ) diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py index b68e2ee..ccda6c8 100644 --- a/requests_cache/backends/filesystem.py +++ b/requests_cache/backends/filesystem.py @@ -12,7 +12,7 @@ from shutil import rmtree from threading import RLock from typing import Iterator, Optional -from ..serializers import SERIALIZERS, json_serializer +from ..serializers import SERIALIZERS, SerializerType, json_serializer from . 
import BaseCache, BaseStorage from .sqlite import AnyPath, SQLiteDict, get_cache_path @@ -35,11 +35,13 @@ class FileCache(BaseCache): cache_name: AnyPath = 'http_cache', use_temp: bool = False, decode_content: bool = True, + serializer: Optional[SerializerType] = None, **kwargs, ): super().__init__(cache_name=str(cache_name), **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses: FileDict = FileDict( - cache_name, use_temp=use_temp, decode_content=decode_content, **kwargs + cache_name, use_temp=use_temp, decode_content=decode_content, **skwargs ) self.redirects: SQLiteDict = SQLiteDict( self.cache_dir / 'redirects.sqlite', 'redirects', no_serializer=True, **kwargs @@ -68,17 +70,16 @@ class FileCache(BaseCache): class FileDict(BaseStorage): """A dictionary-like interface to files on the local filesystem""" - default_serializer = json_serializer - def __init__( self, cache_name: AnyPath, use_temp: bool = False, use_cache_dir: bool = False, extension: Optional[str] = None, + serializer: Optional[SerializerType] = json_serializer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) self.cache_dir = get_cache_path(cache_name, use_cache_dir=use_cache_dir, use_temp=use_temp) self.extension = _get_extension(extension, self.serializer) self.is_binary = getattr(self.serializer, 'is_binary', False) diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index 9f04a8a..aadb7e5 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py @@ -6,12 +6,14 @@ """ from logging import getLogger from threading import RLock +from typing import Optional from gridfs import GridFS from gridfs.errors import CorruptGridFile, FileExists from pymongo import MongoClient from .._utils import get_valid_kwargs +from ..serializers import SerializerType, pickle_serializer from .base import BaseCache, BaseStorage from .mongodb import MongoDict @@ -27,14 +29,22 
@@ class GridFSCache(BaseCache): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - def __init__(self, db_name: str, decode_content: bool = False, **kwargs): + def __init__( + self, + db_name: str, + decode_content: bool = False, + serializer: Optional[SerializerType] = None, + **kwargs + ): super().__init__(cache_name=db_name, **kwargs) - self.responses = GridFSDict(db_name, decode_content=decode_content, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs + + self.responses = GridFSDict(db_name, decode_content=decode_content, **skwargs) self.redirects = MongoDict( db_name, collection_name='redirects', connection=self.responses.connection, - no_serializer=True, + serializer=None, **kwargs ) @@ -53,8 +63,15 @@ class GridFSDict(BaseStorage): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - def __init__(self, db_name, collection_name=None, connection=None, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + db_name, + collection_name=None, + connection=None, + serializer: Optional[SerializerType] = pickle_serializer, + **kwargs + ): + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs) self.connection = connection or MongoClient(**connection_kwargs) self.db = self.connection[db_name] diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py index 90f8feb..07671e7 100644 --- a/requests_cache/backends/mongodb.py +++ b/requests_cache/backends/mongodb.py @@ -13,7 +13,7 @@ from pymongo.errors import OperationFailure from .._utils import get_valid_kwargs from ..policy.expiration import NEVER_EXPIRE, get_expiration_seconds -from ..serializers import bson_document_serializer +from ..serializers import SerializerType, bson_document_serializer from . 
import BaseCache, BaseStorage logger = getLogger(__name__) @@ -34,21 +34,23 @@ class MongoCache(BaseCache): db_name: str = 'http_cache', connection: MongoClient = None, decode_content: bool = True, + serializer: Optional[SerializerType] = None, **kwargs, ): super().__init__(cache_name=db_name, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses: MongoDict = MongoDict( db_name, collection_name='responses', connection=connection, decode_content=decode_content, - **kwargs, + **skwargs, ) self.redirects: MongoDict = MongoDict( db_name, collection_name='redirects', connection=self.responses.connection, - no_serializer=True, + serialzier=None, **kwargs, ) @@ -77,16 +79,15 @@ class MongoDict(BaseStorage): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - default_serializer = bson_document_serializer - def __init__( self, db_name: str, collection_name: str = 'http_cache', connection: Optional[MongoClient] = None, + serializer: Optional[SerializerType] = bson_document_serializer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs) self.connection = connection or MongoClient(**connection_kwargs) self.collection = self.connection[db_name][collection_name] diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py index 95d1ed8..59b2920 100644 --- a/requests_cache/backends/redis.py +++ b/requests_cache/backends/redis.py @@ -11,7 +11,7 @@ from redis import Redis, StrictRedis from .._utils import get_valid_kwargs from ..cache_keys import decode, encode -from ..serializers import utf8_encoder +from ..serializers import SerializerType, pickle_serializer, utf8_encoder from . 
import BaseCache, BaseStorage DEFAULT_TTL_OFFSET = 3600 @@ -33,21 +33,18 @@ class RedisCache(BaseCache): self, namespace='http_cache', connection: Optional[Redis] = None, + serializer: Optional[SerializerType] = None, ttl: bool = True, ttl_offset: int = DEFAULT_TTL_OFFSET, **kwargs, ): super().__init__(cache_name=namespace, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses = RedisDict( - namespace, connection=connection, ttl=ttl, ttl_offset=ttl_offset, **kwargs + namespace, connection=connection, ttl=ttl, ttl_offset=ttl_offset, **skwargs ) - kwargs.pop('serializer', None) self.redirects = RedisHashDict( - namespace, - 'redirects', - connection=self.responses.connection, - serializer=utf8_encoder, # Only needs encoding to/decoding from bytes - **kwargs, + namespace, collection_name='redirects', connection=self.responses.connection, **kwargs ) @@ -64,12 +61,12 @@ class RedisDict(BaseStorage): namespace: str, collection_name: Optional[str] = None, connection=None, + serializer: Optional[SerializerType] = pickle_serializer, ttl: bool = True, ttl_offset: int = DEFAULT_TTL_OFFSET, **kwargs, ): - - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(Redis.__init__, kwargs) self.connection = connection or StrictRedis(**connection_kwargs) self.namespace = namespace @@ -149,9 +146,10 @@ class RedisHashDict(BaseStorage): namespace: str = 'http_cache', collection_name: Optional[str] = None, connection=None, + serializer: Optional[SerializerType] = utf8_encoder, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(Redis, kwargs) self.connection = connection or StrictRedis(**connection_kwargs) self._hash_key = f'{namespace}-{collection_name}' diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 02d9b70..0a892ec 100644 --- 
a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,12 +17,12 @@ from typing import Collection, Iterator, List, Optional, Tuple, Type, Union from platformdirs import user_cache_dir -from requests_cache.backends.base import VT -from requests_cache.models.response import CachedResponse -from requests_cache.policy import ExpirationTime - from .._utils import chunkify, get_valid_kwargs +from ..models.response import CachedResponse +from ..policy import ExpirationTime +from ..serializers import SerializerType, pickle_serializer from . import BaseCache, BaseStorage +from .base import VT MEMORY_URI = 'file::memory:?cache=shared' SQLITE_MAX_VARIABLE_NUMBER = 999 @@ -45,11 +45,18 @@ class SQLiteCache(BaseCache): kwargs: Additional keyword arguments for :py:func:`sqlite3.connect` """ - def __init__(self, db_path: AnyPath = 'http_cache', **kwargs): + def __init__( + self, + db_path: AnyPath = 'http_cache', + serializer: Optional[SerializerType] = None, + **kwargs, + ): super().__init__(cache_name=str(db_path), **kwargs) - self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **kwargs) + # Only override serializer if a non-None value is specified + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs + self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **skwargs) self.redirects: SQLiteDict = SQLiteDict( - db_path, table_name='redirects', no_serializer=True, **kwargs + db_path, table_name='redirects', serializer=None, **kwargs ) @property @@ -165,13 +172,14 @@ class SQLiteDict(BaseStorage): db_path, table_name='http_cache', fast_save=False, + serializer: Optional[SerializerType] = pickle_serializer, use_cache_dir: bool = False, use_memory: bool = False, use_temp: bool = False, wal: bool = False, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) self._can_commit = True self._local_context = threading.local() self._lock = threading.RLock() diff 
--git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py index 085dfc2..c477787 100644 --- a/requests_cache/serializers/__init__.py +++ b/requests_cache/serializers/__init__.py @@ -19,7 +19,7 @@ For any optional libraries that aren't installed, the corresponding serializer w class that raises an ``ImportError`` at initialization time instead of at import time. """ # flake8: noqa: F401 -from typing import Union +from typing import Optional, Union from .cattrs import CattrStage from .pipeline import SerializerPipeline, Stage @@ -39,7 +39,9 @@ __all__ = [ 'SERIALIZERS', 'CattrStage', 'SerializerPipeline', + 'SerializerType', 'Stage', + 'init_serializer', 'bson_serializer', 'bson_document_serializer', 'dynamodb_document_serializer', @@ -59,3 +61,22 @@ SERIALIZERS = { } SerializerType = Union[str, SerializerPipeline, Stage] + + +def init_serializer( + serializer: Optional[SerializerType], decode_content: bool +) -> Optional[SerializerPipeline]: + """Initialize a serializer by name or instance""" + if not serializer: + return None + + # Look up a serializer by name, if needed + if isinstance(serializer, str): + serializer = SERIALIZERS[serializer] + + # Wrap in a SerializerPipeline, if needed + if not isinstance(serializer, SerializerPipeline): + serializer = SerializerPipeline([serializer], name=str(serializer)) + serializer.set_decode_content(decode_content) + + return serializer diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py index 385e8fe..cdc3a98 100644 --- a/tests/integration/base_cache_test.py +++ b/tests/integration/base_cache_test.py @@ -48,7 +48,6 @@ class BaseCacheTest: def init_session(self, cache_name=CACHE_NAME, clear=True, **kwargs) -> CachedSession: kwargs = {**self.init_kwargs, **kwargs} kwargs.setdefault('allowable_methods', ALL_METHODS) - kwargs.setdefault('serializer', 'pickle') backend = self.backend_class(cache_name, **kwargs) if clear: backend.clear() diff --git
a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py index 4070492..a96ffd9 100644 --- a/tests/integration/base_storage_test.py +++ b/tests/integration/base_storage_test.py @@ -19,7 +19,6 @@ class BaseStorageTest: def init_cache(self, cache_name=CACHE_NAME, index=0, clear=True, **kwargs): kwargs = {**self.init_kwargs, **kwargs} - kwargs.setdefault('serializer', 'pickle') cache = self.storage_class(cache_name, f'table_{index}', **kwargs) if clear: cache.clear() @@ -136,8 +135,8 @@ class BaseStorageTest: cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None)) for i in range(5): - cache_1[i] = i - cache_2[i] = i + cache_1[i] = f'value_{i}' + cache_2[i] = f'value_{i}' assert len(cache_1) == len(cache_2) == 5 cache_1.clear() @@ -147,8 +146,8 @@ class BaseStorageTest: def test_same_settings(self): cache_1 = self.init_cache() cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None)) - cache_1['key_1'] = 1 - cache_2['key_2'] = 2 + cache_1['key_1'] = 'value_1' + cache_2['key_2'] = 'value_2' assert cache_1 == cache_2 def test_str(self): diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py index 561c979..42e22eb 100644 --- a/tests/integration/test_dynamodb.py +++ b/tests/integration/test_dynamodb.py @@ -15,10 +15,6 @@ AWS_OPTIONS = { 'aws_access_key_id': 'placeholder', 'aws_secret_access_key': 'placeholder', } -DYNAMODB_OPTIONS = { - **AWS_OPTIONS, - 'serializer': None, # Use class default serializer -} @pytest.fixture(scope='module', autouse=True) @@ -33,7 +29,7 @@ def ensure_connection(): class TestDynamoDbDict(BaseStorageTest): storage_class = DynamoDbDict - init_kwargs = DYNAMODB_OPTIONS + init_kwargs = AWS_OPTIONS @patch('requests_cache.backends.dynamodb.boto3.resource') def test_connection_kwargs(self, mock_resource): @@ -87,4 +83,4 @@ class TestDynamoDbDict(BaseStorageTest): class TestDynamoDbCache(BaseCacheTest): backend_class = DynamoDbCache - init_kwargs = 
DYNAMODB_OPTIONS + init_kwargs = AWS_OPTIONS diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py index dfe9414..99625c4 100644 --- a/tests/integration/test_filesystem.py +++ b/tests/integration/test_filesystem.py @@ -39,7 +39,6 @@ class TestFileDict(BaseStorageTest): rmtree(CACHE_NAME, ignore_errors=True) def init_cache(self, index=0, clear=True, **kwargs): - kwargs.setdefault('serializer', 'pickle') cache = FileDict(f'{CACHE_NAME}_{index}', use_temp=True, **kwargs) if clear: cache.clear() diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py index 4a3a114..1dfae87 100644 --- a/tests/integration/test_mongodb.py +++ b/tests/integration/test_mongodb.py @@ -51,7 +51,6 @@ class TestMongoDict(BaseStorageTest): class TestMongoCache(BaseCacheTest): backend_class = MongoCache - init_kwargs = {'serializer': None} # Use class default serializer instead of pickle def test_ttl(self): session = self.init_session() diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py index 4c74d68..2a34899 100644 --- a/tests/integration/test_redis.py +++ b/tests/integration/test_redis.py @@ -33,6 +33,7 @@ class TestRedisHashDict(TestRedisDict): storage_class = RedisHashDict num_instances: int = 10 # Supports multiple instances, since this stores items under hash keys picklable = True + init_kwargs = {'serializer': 'pickle'} class TestRedisCache(BaseCacheTest): |