From 9c4f76744a764f8898e5735d838f929c277821c0 Mon Sep 17 00:00:00 2001 From: Jordan Cook Date: Sun, 1 Jan 2023 12:40:47 -0600 Subject: Set default serializers for each backend using param defaults instead of 'default_serializer' class attributes --- requests_cache/backends/base.py | 34 ++++++---------------------------- requests_cache/backends/dynamodb.py | 13 +++++++------ requests_cache/backends/filesystem.py | 11 ++++++----- requests_cache/backends/gridfs.py | 27 ++++++++++++++++++++++----- requests_cache/backends/mongodb.py | 13 +++++++------ requests_cache/backends/redis.py | 20 +++++++++----------- requests_cache/backends/sqlite.py | 24 ++++++++++++++++-------- requests_cache/serializers/__init__.py | 23 ++++++++++++++++++++++- 8 files changed, 95 insertions(+), 70 deletions(-) (limited to 'requests_cache') diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 40651cd..6193483 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -19,7 +19,7 @@ from requests import Request, Response from ..cache_keys import create_key, redact_response from ..models import AnyRequest, CachedResponse from ..policy import DEFAULT_CACHE_NAME, CacheSettings, ExpirationTime -from ..serializers import SERIALIZERS, SerializerPipeline, SerializerType, pickle_serializer +from ..serializers import SerializerType, init_serializer # Specific exceptions that may be raised during deserialization DESERIALIZE_ERRORS = (AttributeError, ImportError, PickleError, TypeError, ValueError) @@ -350,38 +350,16 @@ class BaseStorage(MutableMapping[KT, VT], ABC): for details. Args: - serializer: Custom serializer that provides ``loads`` and ``dumps`` methods - no_serializer: Explicitly disable serialization, and write values as-is; this is to avoid - ambiguity with ``serializer=None`` - decode_content: Decode JSON or text response body into a human-readable format + serializer: Custom serializer that provides ``loads`` and ``dumps`` methods. + If not provided, values will be written as-is. + decode_content: Decode response body JSON or text into a human-readable format kwargs: Additional backend-specific keyword arguments """ - # Default serializer to use for responses, if one isn't specified; may be overridden by subclass - default_serializer: SerializerType = pickle_serializer - def __init__( - self, - serializer: Optional[SerializerType] = None, - no_serializer: bool = False, - decode_content: bool = False, - **kwargs, + self, serializer: Optional[SerializerType] = None, decode_content: bool = False, **kwargs ): - self.serializer: Optional[SerializerPipeline] = None - if not no_serializer: - self._init_serializer(serializer or self.default_serializer, decode_content) - - def _init_serializer(self, serializer: SerializerType, decode_content: bool): - # Look up a serializer by name, if needed - if isinstance(serializer, str): - serializer = SERIALIZERS[serializer] - - # Wrap in a SerializerPipeline, if needed - if not isinstance(serializer, SerializerPipeline): - serializer = SerializerPipeline([serializer], name=str(serializer)) - serializer.set_decode_content(decode_content) - - self.serializer = serializer + self.serializer = init_serializer(serializer, decode_content) logger.debug(f'Initialized {type(self).__name__} with serializer: {self.serializer}') def bulk_delete(self, keys: Iterable[KT]): diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py index 875c1f7..41021c3 100644 --- a/requests_cache/backends/dynamodb.py +++ b/requests_cache/backends/dynamodb.py @@ -12,7 +12,7 @@ from boto3.resources.base import ServiceResource from botocore.exceptions import ClientError from .._utils import get_valid_kwargs -from ..serializers import dynamodb_document_serializer +from ..serializers import SerializerType, dynamodb_document_serializer from . import BaseCache, BaseStorage @@ -35,23 +35,25 @@ class DynamoDbCache(BaseCache): ttl: bool = True, connection: Optional[ServiceResource] = None, decode_content: bool = True, + serializer: Optional[SerializerType] = None, **kwargs, ): super().__init__(cache_name=table_name, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses = DynamoDbDict( table_name, namespace='responses', ttl=ttl, connection=connection, decode_content=decode_content, - **kwargs, + **skwargs, ) self.redirects = DynamoDbDict( table_name, namespace='redirects', ttl=False, connection=self.responses.connection, - no_serializer=True, + serialzier=None, **kwargs, ) @@ -68,17 +70,16 @@ class DynamoDbDict(BaseStorage): kwargs: Additional keyword arguments for :py:meth:`~boto3.session.Session.resource` """ - default_serializer = dynamodb_document_serializer - def __init__( self, table_name: str, namespace: str, ttl: bool = True, connection: Optional[ServiceResource] = None, + serializer: Optional[SerializerType] = dynamodb_document_serializer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs( boto3.Session.__init__, kwargs, extras=['endpoint_url'] ) diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py index b68e2ee..ccda6c8 100644 --- a/requests_cache/backends/filesystem.py +++ b/requests_cache/backends/filesystem.py @@ -12,7 +12,7 @@ from shutil import rmtree from threading import RLock from typing import Iterator, Optional -from ..serializers import SERIALIZERS, json_serializer +from ..serializers import SERIALIZERS, SerializerType, json_serializer from . import BaseCache, BaseStorage from .sqlite import AnyPath, SQLiteDict, get_cache_path @@ -35,11 +35,13 @@ class FileCache(BaseCache): cache_name: AnyPath = 'http_cache', use_temp: bool = False, decode_content: bool = True, + serializer: Optional[SerializerType] = None, **kwargs, ): super().__init__(cache_name=str(cache_name), **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses: FileDict = FileDict( - cache_name, use_temp=use_temp, decode_content=decode_content, **kwargs + cache_name, use_temp=use_temp, decode_content=decode_content, **skwargs ) self.redirects: SQLiteDict = SQLiteDict( self.cache_dir / 'redirects.sqlite', 'redirects', no_serializer=True, **kwargs @@ -68,17 +70,16 @@ class FileCache(BaseCache): class FileDict(BaseStorage): """A dictionary-like interface to files on the local filesystem""" - default_serializer = json_serializer - def __init__( self, cache_name: AnyPath, use_temp: bool = False, use_cache_dir: bool = False, extension: Optional[str] = None, + serializer: Optional[SerializerType] = json_serializer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) self.cache_dir = get_cache_path(cache_name, use_cache_dir=use_cache_dir, use_temp=use_temp) self.extension = _get_extension(extension, self.serializer) self.is_binary = getattr(self.serializer, 'is_binary', False) diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index 9f04a8a..aadb7e5 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py @@ -6,12 +6,14 @@ """ from logging import getLogger from threading import RLock +from typing import Optional from gridfs import GridFS from gridfs.errors import CorruptGridFile, FileExists from pymongo import MongoClient from .._utils import get_valid_kwargs +from ..serializers import SerializerType, pickle_serializer from .base import BaseCache, BaseStorage from .mongodb import MongoDict @@ -27,14 +29,22 @@ class GridFSCache(BaseCache): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - def __init__(self, db_name: str, decode_content: bool = False, **kwargs): + def __init__( + self, + db_name: str, + decode_content: bool = False, + serializer: Optional[SerializerType] = None, + **kwargs + ): super().__init__(cache_name=db_name, **kwargs) - self.responses = GridFSDict(db_name, decode_content=decode_content, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs + + self.responses = GridFSDict(db_name, decode_content=decode_content, **skwargs) self.redirects = MongoDict( db_name, collection_name='redirects', connection=self.responses.connection, - no_serializer=True, + serializer=None, **kwargs ) @@ -53,8 +63,15 @@ class GridFSDict(BaseStorage): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - def __init__(self, db_name, collection_name=None, connection=None, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + db_name, + collection_name=None, + connection=None, + serializer: Optional[SerializerType] = pickle_serializer, + **kwargs + ): + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs) self.connection = connection or MongoClient(**connection_kwargs) self.db = self.connection[db_name] diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py index 90f8feb..07671e7 100644 --- a/requests_cache/backends/mongodb.py +++ b/requests_cache/backends/mongodb.py @@ -13,7 +13,7 @@ from pymongo.errors import OperationFailure from .._utils import get_valid_kwargs from ..policy.expiration import NEVER_EXPIRE, get_expiration_seconds -from ..serializers import bson_document_serializer +from ..serializers import SerializerType, bson_document_serializer from . import BaseCache, BaseStorage logger = getLogger(__name__) @@ -34,21 +34,23 @@ class MongoCache(BaseCache): db_name: str = 'http_cache', connection: MongoClient = None, decode_content: bool = True, + serializer: Optional[SerializerType] = None, **kwargs, ): super().__init__(cache_name=db_name, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses: MongoDict = MongoDict( db_name, collection_name='responses', connection=connection, decode_content=decode_content, - **kwargs, + **skwargs, ) self.redirects: MongoDict = MongoDict( db_name, collection_name='redirects', connection=self.responses.connection, - no_serializer=True, + serialzier=None, **kwargs, ) @@ -77,16 +79,15 @@ class MongoDict(BaseStorage): kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient` """ - default_serializer = bson_document_serializer - def __init__( self, db_name: str, collection_name: str = 'http_cache', connection: Optional[MongoClient] = None, + serializer: Optional[SerializerType] = bson_document_serializer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs) self.connection = connection or MongoClient(**connection_kwargs) self.collection = self.connection[db_name][collection_name] diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py index 95d1ed8..59b2920 100644 --- a/requests_cache/backends/redis.py +++ b/requests_cache/backends/redis.py @@ -11,7 +11,7 @@ from redis import Redis, StrictRedis from .._utils import get_valid_kwargs from ..cache_keys import decode, encode -from ..serializers import utf8_encoder +from ..serializers import SerializerType, pickle_serializer, utf8_encoder from . import BaseCache, BaseStorage DEFAULT_TTL_OFFSET = 3600 @@ -33,21 +33,18 @@ class RedisCache(BaseCache): self, namespace='http_cache', connection: Optional[Redis] = None, + serializer: Optional[SerializerType] = None, ttl: bool = True, ttl_offset: int = DEFAULT_TTL_OFFSET, **kwargs, ): super().__init__(cache_name=namespace, **kwargs) + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs self.responses = RedisDict( - namespace, connection=connection, ttl=ttl, ttl_offset=ttl_offset, **kwargs + namespace, connection=connection, ttl=ttl, ttl_offset=ttl_offset, **skwargs ) - kwargs.pop('serializer', None) self.redirects = RedisHashDict( - namespace, - 'redirects', - connection=self.responses.connection, - serializer=utf8_encoder, # Only needs encoding to/decoding from bytes - **kwargs, + namespace, collection_name='redirects', connection=self.responses.connection, **kwargs ) @@ -64,12 +61,12 @@ class RedisDict(BaseStorage): namespace: str, collection_name: Optional[str] = None, connection=None, + serializer: Optional[SerializerType] = pickle_serializer, ttl: bool = True, ttl_offset: int = DEFAULT_TTL_OFFSET, **kwargs, ): - - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(Redis.__init__, kwargs) self.connection = connection or StrictRedis(**connection_kwargs) self.namespace = namespace @@ -149,9 +146,10 @@ class RedisHashDict(BaseStorage): namespace: str = 'http_cache', collection_name: Optional[str] = None, connection=None, + serializer: Optional[SerializerType] = utf8_encoder, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) connection_kwargs = get_valid_kwargs(Redis, kwargs) self.connection = connection or StrictRedis(**connection_kwargs) self._hash_key = f'{namespace}-{collection_name}' diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py index 02d9b70..0a892ec 100644 --- a/requests_cache/backends/sqlite.py +++ b/requests_cache/backends/sqlite.py @@ -17,12 +17,12 @@ from typing import Collection, Iterator, List, Optional, Tuple, Type, Union from platformdirs import user_cache_dir -from requests_cache.backends.base import VT -from requests_cache.models.response import CachedResponse -from requests_cache.policy import ExpirationTime - from .._utils import chunkify, get_valid_kwargs +from ..models.response import CachedResponse +from ..policy import ExpirationTime +from ..serializers import SerializerType, pickle_serializer from . import BaseCache, BaseStorage +from .base import VT MEMORY_URI = 'file::memory:?cache=shared' SQLITE_MAX_VARIABLE_NUMBER = 999 @@ -45,11 +45,18 @@ class SQLiteCache(BaseCache): kwargs: Additional keyword arguments for :py:func:`sqlite3.connect` """ - def __init__(self, db_path: AnyPath = 'http_cache', **kwargs): + def __init__( + self, + db_path: AnyPath = 'http_cache', + serializer: Optional[SerializerType] = None, + **kwargs, + ): super().__init__(cache_name=str(db_path), **kwargs) - self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **kwargs) + # Only override serializer if a non-None value is specified + skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs + self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **skwargs) self.redirects: SQLiteDict = SQLiteDict( - db_path, table_name='redirects', no_serializer=True, **kwargs + db_path, table_name='redirects', serializer=None, **kwargs ) @property @@ -165,13 +172,14 @@ class SQLiteDict(BaseStorage): db_path, table_name='http_cache', fast_save=False, + serializer: Optional[SerializerType] = pickle_serializer, use_cache_dir: bool = False, use_memory: bool = False, use_temp: bool = False, wal: bool = False, **kwargs, ): - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) self._can_commit = True self._local_context = threading.local() self._lock = threading.RLock() diff --git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py index 085dfc2..c477787 100644 --- a/requests_cache/serializers/__init__.py +++ b/requests_cache/serializers/__init__.py @@ -19,7 +19,7 @@ For any optional libraries that aren't installed, the corresponding serializer w class that raises an ``ImportError`` at initialization time instead of at import time. """ # flake8: noqa: F401 -from typing import Union +from typing import Optional, Union from .cattrs import CattrStage from .pipeline import SerializerPipeline, Stage @@ -39,7 +39,9 @@ __all__ = [ 'SERIALIZERS', 'CattrStage', 'SerializerPipeline', + 'SerializerType', 'Stage', + 'init_serializer', 'bson_serializer', 'bson_document_serializer', 'dynamodb_document_serializer', @@ -59,3 +61,22 @@ SERIALIZERS = { } SerializerType = Union[str, SerializerPipeline, Stage] + + +def init_serializer( + serializer: Optional[SerializerType], decode_content: bool +) -> Optional[SerializerPipeline]: + """Intitialze a serializer by name or instance""" + if not serializer: + return None + + # Look up a serializer by name, if needed + if isinstance(serializer, str): + serializer = SERIALIZERS[serializer] + + # Wrap in a SerializerPipeline, if needed + if not isinstance(serializer, SerializerPipeline): + serializer = SerializerPipeline([serializer], name=str(serializer)) + serializer.set_decode_content(decode_content) + + return serializer -- cgit v1.2.1