author      Jordan Cook <JWCook@users.noreply.github.com>    2021-04-19 17:42:00 -0500
committer   GitHub <noreply@github.com>                      2021-04-19 17:42:00 -0500
commit      ed0e286d2b5702b04ee7dd62d2c85138d4c79652 (patch)
tree        36627a6f9c46657bac58d28d2c17ff311c1d679c
parent      714814d0c23271fb357e9ee094e8754ef9818029 (diff)
parent      f0916900e11778f9ff0a33b148edbf2f51f6f600 (diff)
Merge pull request #237 from JWCook/backend-kwargs
Allow passing any valid backend connection kwargs via BaseCache
-rw-r--r--  docs/conf.py                          |  8
-rw-r--r--  requests_cache/backends/__init__.py   | 10
-rw-r--r--  requests_cache/backends/base.py       | 11
-rw-r--r--  requests_cache/backends/dynamodb.py   | 38
-rw-r--r--  requests_cache/backends/gridfs.py     | 45
-rw-r--r--  requests_cache/backends/mongo.py      | 52
-rw-r--r--  requests_cache/backends/redis.py      | 21
-rw-r--r--  requests_cache/backends/sqlite.py     | 29
-rw-r--r--  requests_cache/session.py             |  6
-rw-r--r--  tests/integration/test_dynamodb.py    |  8
-rw-r--r--  tests/integration/test_gridfs.py      | 17
-rw-r--r--  tests/integration/test_mongodb.py     | 16
-rw-r--r--  tests/integration/test_redis.py       | 11
-rw-r--r--  tests/integration/test_sqlite.py      |  6

14 files changed, 169 insertions, 109 deletions
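
In practical terms, the change merged here lets connection-level options be passed straight through the public API to the backend's underlying client. A minimal sketch, assuming the Redis backend and a locally reachable Redis server (the specific kwargs are examples, not an exhaustive list):

    import requests_cache

    # 'host' and 'port' match redis.Redis() and are now forwarded to it
    requests_cache.install_cache('demo_cache', backend='redis', host='localhost', port=6379)
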
diff --git a/docs/conf.py b/docs/conf.py
index 6d92cad..68eeb16 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -26,6 +26,7 @@ templates_path = ['_templates']
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosectionlabel',
+ 'sphinx.ext.extlinks',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
@@ -40,10 +41,17 @@ exclude_patterns = ['_build', 'modules/requests_cache.rst']
# Enable automatic links to other projects' Sphinx docs
intersphinx_mapping = {
+ 'boto3': ('https://boto3.amazonaws.com/v1/documentation/api/latest/', None),
+ 'botocore': ('http://botocore.readthedocs.io/en/latest/', None),
+ 'pymongo': ('https://pymongo.readthedocs.io/en/stable/', None),
'python': ('https://docs.python.org/3', None),
+ 'redis': ('https://redis-py.readthedocs.io/en/stable/', None),
'requests': ('https://docs.python-requests.org/en/master/', None),
'urllib3': ('https://urllib3.readthedocs.io/en/latest/', None),
}
+extlinks = {
+ 'boto3': ('https://boto3.amazonaws.com/v1/documentation/api/latest/reference/%s', None),
+}
# Enable Google-style docstrings
napoleon_google_docstring = True
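
The extlinks entry added above gives docstrings a :boto3: role; Sphinx substitutes the role target into the %s placeholder of the base URL, roughly like this (illustrative only, not Sphinx internals):

    # How a ':boto3:' role target expands into a full documentation URL
    base_url = 'https://boto3.amazonaws.com/v1/documentation/api/latest/reference/%s'
    target = 'services/dynamodb.html#DynamoDB.ServiceResource'
    full_url = base_url % target
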
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py
index 78c935a..f64f051 100644
--- a/requests_cache/backends/__init__.py
+++ b/requests_cache/backends/__init__.py
@@ -1,7 +1,8 @@
"""Classes and functions for cache persistence"""
# flake8: noqa: F401
+from inspect import signature
from logging import getLogger
-from typing import Type, Union
+from typing import Callable, Dict, Iterable, Type, Union
from .base import BaseCache, BaseStorage
@@ -42,6 +43,13 @@ def get_placeholder_backend(original_exception: Exception = None) -> Type[BaseCa
return PlaceholderBackend
+def get_valid_kwargs(func: Callable, kwargs: Dict, extras: Iterable[str] = None) -> Dict:
+ """Get the subset of non-None ``kwargs`` that are valid params for ``func``"""
+ params = list(signature(func).parameters)
+ params.extend(extras or [])
+ return {k: v for k, v in kwargs.items() if k in params and v is not None}
+
+
# Import all backend classes for which dependencies are installed
try:
from .dynamodb import DynamoDbCache, DynamoDbDict
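
A standalone sketch of the new helper's behavior; the connect() stand-in below is hypothetical, but the filtering logic is exactly the one defined above:

    from requests_cache.backends import get_valid_kwargs

    def connect(host='localhost', port=6379, password=None):
        """Hypothetical stand-in for a backend client constructor"""

    # Only kwargs that appear in connect()'s signature and are not None are kept
    get_valid_kwargs(connect, {'host': '0.0.0.0', 'password': None, 'verify': True})
    # -> {'host': '0.0.0.0'}
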
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index f0abb70..e912879 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -2,7 +2,7 @@ import pickle
import warnings
from abc import ABC
from collections.abc import MutableMapping
-from logging import getLogger
+from logging import DEBUG, WARNING, getLogger
from typing import Iterable, List, Tuple, Union
import requests
@@ -177,14 +177,9 @@ class BaseStorage(MutableMapping, ABC):
self._serializer = serializer or self._get_serializer(secret_key, salt)
logger.debug(f'Initializing {type(self).__name__} with serializer: {self._serializer}')
- # Show a warning instead of an exception if there are remaining unused kwargs
- kwargs.pop('include_get_headers', None)
- kwargs.pop('ignored_params', None)
- if kwargs:
- logger.warning(f'Unrecognized keyword arguments: {kwargs}')
if not secret_key:
- warn_func = logger.debug if suppress_warnings else warnings.warn
- warn_func('Using a secret key to sign cached items is recommended for this backend')
+ level = DEBUG if suppress_warnings else WARNING
+ logger.log(level, 'Using a secret key is recommended for this backend')
def serialize(self, item: ResponseOrKey) -> bytes:
"""Serialize a URL or response into bytes"""
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index 0f939d2..cad3ae6 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -1,7 +1,8 @@
import boto3
+from boto3.resources.base import ServiceResource
from botocore.exceptions import ClientError
-from .base import BaseCache, BaseStorage
+from . import BaseCache, BaseStorage, get_valid_kwargs
class DynamoDbCache(BaseCache):
@@ -10,15 +11,15 @@ class DynamoDbCache(BaseCache):
Args:
table_name: DynamoDb table name
namespace: Name of DynamoDb hash map
- connection: DynamoDb Resource object (``boto3.resource('dynamodb')``) to use instead of
- creating a new one
+ connection: :boto3:`DynamoDb Resource <services/dynamodb.html#DynamoDB.ServiceResource>`
+ object to use instead of creating a new one
+ kwargs: Additional keyword arguments for :py:meth:`~boto3.session.Session.resource`
"""
- def __init__(self, table_name='http_cache', **kwargs):
+ def __init__(self, table_name: str = 'http_cache', connection: ServiceResource = None, **kwargs):
super().__init__(**kwargs)
- self.responses = DynamoDbDict(table_name, namespace='responses', **kwargs)
- kwargs['connection'] = self.responses.connection
- self.redirects = DynamoDbDict(table_name, namespace='redirects', **kwargs)
+ self.responses = DynamoDbDict(table_name, 'responses', connection=connection, **kwargs)
+ self.redirects = DynamoDbDict(table_name, 'redirects', connection=self.responses.connection, **kwargs)
class DynamoDbDict(BaseStorage):
@@ -32,9 +33,9 @@ class DynamoDbDict(BaseStorage):
Args:
table_name: DynamoDb table name
namespace: Name of DynamoDb hash map
- connection: DynamoDb Resource object (``boto3.resource('dynamodb')``) to use instead of
- creating a new one
- endpoint_url: Alternative URL of dynamodb server.
+ connection: :boto3:`DynamoDb Resource <services/dynamodb.html#DynamoDB.ServiceResource>`
+ object to use instead of creating a new one
+ kwargs: Additional keyword arguments for :py:meth:`~boto3.session.Session.resource`
"""
def __init__(
@@ -42,27 +43,14 @@ class DynamoDbDict(BaseStorage):
table_name,
namespace='http_cache',
connection=None,
- endpoint_url=None,
- region_name='us-east-1',
- aws_access_key_id=None,
- aws_secret_access_key=None,
read_capacity_units=1,
write_capacity_units=1,
**kwargs,
):
super().__init__(**kwargs)
+ connection_kwargs = get_valid_kwargs(boto3.Session, kwargs, extras=['endpoint_url'])
+ self.connection = connection or boto3.resource('dynamodb', **connection_kwargs)
self._self_key = namespace
- if connection is not None:
- self.connection = connection
- else:
- # TODO: Use inspection to get any valid resource arguments from **kwargs
- self.connection = boto3.resource(
- 'dynamodb',
- endpoint_url=endpoint_url,
- region_name=region_name,
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- )
try:
self.connection.create_table(
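
With the rewrite above, any kwargs that match boto3.Session's signature (plus the whitelisted endpoint_url) reach boto3.resource(). A sketch, assuming a local DynamoDB instance listening on port 8000:

    from requests_cache.backends import DynamoDbCache

    # region_name matches boto3.Session; endpoint_url is allowed via extras=['endpoint_url']
    cache = DynamoDbCache(
        table_name='http_cache',
        region_name='us-east-2',
        endpoint_url='http://localhost:8000',
    )
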
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index 5c551a2..73f258e 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -1,6 +1,7 @@
from gridfs import GridFS
from pymongo import MongoClient
+from . import get_valid_kwargs
from .base import BaseCache, BaseStorage
from .mongo import MongoDict
@@ -9,19 +10,21 @@ class GridFSCache(BaseCache):
"""GridFS cache backend.
Use this backend to store documents greater than 16MB.
- Usage:
- requests_cache.install_cache(backend='gridfs')
+ Example:
- Or:
- from pymongo import MongoClient
- requests_cache.install_cache(backend='gridfs', connection=MongoClient('another-host.local'))
+ >>> requests_cache.install_cache(backend='gridfs')
+
+ Or:
+ >>> from pymongo import MongoClient
+ >>> requests_cache.install_cache(backend='gridfs', connection=MongoClient('alternate-host'))
+
+ Args:
+ db_name: Database name
+ connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
+ kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- def __init__(self, db_name, **kwargs):
- """
- :param db_name: database name
- :param connection: (optional) ``pymongo.Connection``
- """
+ def __init__(self, db_name: str, **kwargs):
super().__init__(**kwargs)
self.responses = GridFSPickleDict(db_name, **kwargs)
kwargs['connection'] = self.responses.connection
@@ -29,21 +32,19 @@ class GridFSCache(BaseCache):
class GridFSPickleDict(BaseStorage):
- """A dictionary-like interface for a GridFS collection"""
+ """A dictionary-like interface for a GridFS database
+
+ Args:
+ db_name: Database name
+ collection_name: Ignored; GridFS internally uses collections 'fs.files' and 'fs.chunks'
+ connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
+ kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
+ """
def __init__(self, db_name, collection_name=None, connection=None, **kwargs):
- """
- :param db_name: database name (be careful with production databases)
- :param connection: ``pymongo.Connection`` instance. If it's ``None``
- (default) new connection with default options will
- be created
- """
super().__init__(**kwargs)
- if connection is not None:
- self.connection = connection
- else:
- self.connection = MongoClient()
-
+ connection_kwargs = get_valid_kwargs(MongoClient, kwargs)
+ self.connection = connection or MongoClient(**connection_kwargs)
self.db = self.connection[db_name]
self.fs = GridFS(self.db)
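
As with the other backends, kwargs matching pymongo.MongoClient's named parameters are forwarded to it. A sketch, assuming a local MongoDB server:

    from requests_cache.backends import GridFSCache

    # 'host' and 'port' are named MongoClient() parameters and pass through
    cache = GridFSCache('http_cache', host='localhost', port=27017)
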
diff --git a/requests_cache/backends/mongo.py b/requests_cache/backends/mongo.py
index 0c9bc61..b479a47 100644
--- a/requests_cache/backends/mongo.py
+++ b/requests_cache/backends/mongo.py
@@ -1,40 +1,44 @@
from pymongo import MongoClient
-from .base import BaseCache, BaseStorage
+from . import BaseCache, BaseStorage, get_valid_kwargs
class MongoCache(BaseCache):
- """MongoDB cache backend"""
+ """MongoDB cache backend
- def __init__(self, db_name='http_cache', **kwargs):
- """
- :param db_name: database name (default: ``'requests-cache'``)
- :param connection: (optional) ``pymongo.Connection``
- """
+ Args:
+ db_name: Database name
+ connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
+ kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
+ """
+
+ def __init__(self, db_name: str = 'http_cache', connection: MongoClient = None, **kwargs):
+ """"""
super().__init__(**kwargs)
- self.responses = MongoPickleDict(db_name, collection_name='responses', **kwargs)
- kwargs['connection'] = self.responses.connection
- self.redirects = MongoDict(db_name, collection_name='redirects', **kwargs)
+ self.responses = MongoPickleDict(db_name, 'responses', connection=connection, **kwargs)
+ self.redirects = MongoDict(
+ db_name,
+ collection_name='redirects',
+ connection=self.responses.connection,
+ **kwargs,
+ )
class MongoDict(BaseStorage):
- """A dictionary-like interface for a MongoDB collection"""
+ """A dictionary-like interface for a MongoDB collection
+
+ Args:
+ db_name: Database name
+ collection_name: Collection name
+ connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
+ kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
+ """
def __init__(self, db_name, collection_name='http_cache', connection=None, **kwargs):
- """
- :param db_name: database name (be careful with production databases)
- :param collection_name: collection name (default: mongo_dict_data)
- :param connection: ``pymongo.Connection`` instance. If it's ``None``
- (default) new connection with default options will
- be created
- """
super().__init__(**kwargs)
- if connection is not None:
- self.connection = connection
- else:
- self.connection = MongoClient()
- self.db = self.connection[db_name]
- self.collection = self.db[collection_name]
+ connection_kwargs = get_valid_kwargs(MongoClient, kwargs)
+ self.connection = connection or MongoClient(**connection_kwargs)
+ self.collection = self.connection[db_name][collection_name]
def __getitem__(self, key):
result = self.collection.find_one({'_id': key})
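
The same pattern applies at the cache level for the plain MongoDB backend; a sketch, again assuming a local MongoDB server:

    from requests_cache.backends import MongoCache

    # 'host' and 'port' match MongoClient() and are forwarded to it
    cache = MongoCache('http_cache', host='localhost', port=27017)
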
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index cacb7c2..33236d5 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -1,6 +1,6 @@
-from redis import Redis
+from redis import Redis, StrictRedis
-from .base import BaseCache, BaseStorage
+from . import BaseCache, BaseStorage, get_valid_kwargs
class RedisCache(BaseCache):
@@ -8,14 +8,14 @@ class RedisCache(BaseCache):
Args:
namespace: redis namespace (default: ``'requests-cache'``)
- connection: (optional) Redis connection instance to use instead of creating a new one
+ connection: Redis connection instance to use instead of creating a new one
+ kwargs: Additional keyword arguments for :py:class:`redis.client.Redis`
"""
- def __init__(self, namespace='http_cache', **kwargs):
+ def __init__(self, namespace='http_cache', connection: Redis = None, **kwargs):
super().__init__(**kwargs)
- self.responses = RedisDict(namespace, collection_name='responses', **kwargs)
- kwargs['connection'] = self.responses.connection
- self.redirects = RedisDict(namespace, collection_name='redirects', **kwargs)
+ self.responses = RedisDict(namespace, 'responses', connection=connection, **kwargs)
+ self.redirects = RedisDict(namespace, 'redirects', connection=self.responses.connection, **kwargs)
class RedisDict(BaseStorage):
@@ -29,14 +29,13 @@ class RedisDict(BaseStorage):
namespace: Redis namespace
collection_name: Name of the Redis hash map
connection: (optional) Redis connection instance to use instead of creating a new one
+ kwargs: Additional keyword arguments for :py:class:`redis.client.Redis`
"""
def __init__(self, namespace, collection_name='http_cache', connection=None, **kwargs):
super().__init__(**kwargs)
- if connection is not None:
- self.connection = connection
- else:
- self.connection = Redis()
+ connection_kwargs = get_valid_kwargs(Redis, kwargs)
+ self.connection = connection or StrictRedis(**connection_kwargs)
self._self_key = ':'.join([namespace, collection_name])
def __getitem__(self, key):
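
RedisDict now builds its client from whatever kwargs match redis.Redis's signature. A sketch, assuming a local Redis server and redis-py's parameter names:

    from requests_cache.backends import RedisCache

    # host/port/db are named redis.Redis parameters and are forwarded
    cache = RedisCache(namespace='http_cache', host='localhost', port=6379, db=1)
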
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index 45c3f79..e80497c 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -5,9 +5,9 @@ from logging import getLogger
from os import makedirs
from os.path import abspath, basename, dirname, expanduser
from pathlib import Path
-from typing import Union
+from typing import Type, Union
-from .base import BaseCache, BaseStorage
+from . import BaseCache, BaseStorage, get_valid_kwargs
logger = getLogger(__name__)
@@ -15,18 +15,16 @@ logger = getLogger(__name__)
class DbCache(BaseCache):
"""SQLite cache backend.
- Reading is fast, saving is a bit slower. It can store big amount of data with low memory usage.
Args:
db_path: Database file path (expands user paths and creates parent dirs)
fast_save: Speedup cache saving up to 50 times but with possibility of data loss.
See :py:class:`.DbDict` for more info
- timeout: Timeout for acquiring a database lock
+ kwargs: Additional keyword arguments for :py:func:`sqlite3.connect`
"""
def __init__(self, db_path: Union[Path, str] = 'http_cache', fast_save: bool = False, **kwargs):
super().__init__(**kwargs)
- kwargs.setdefault('suppress_warnings', True)
db_path = _get_db_path(db_path)
self.responses = DbPickleDict(db_path, table_name='responses', fast_save=fast_save, **kwargs)
self.redirects = DbDict(db_path, table_name='redirects', **kwargs)
@@ -58,12 +56,13 @@ class DbDict(BaseStorage):
timeout: Timeout for acquiring a database lock
"""
- def __init__(self, db_path, table_name='http_cache', fast_save=False, timeout=5.0, **kwargs):
+ def __init__(self, db_path, table_name='http_cache', fast_save=False, **kwargs):
+ kwargs.setdefault('suppress_warnings', True)
super().__init__(**kwargs)
+ self.connection_kwargs = get_valid_kwargs(sqlite_template, kwargs)
self.db_path = db_path
self.fast_save = fast_save
self.table_name = table_name
- self.timeout = timeout
self._bulk_commit = False
self._can_commit = True
@@ -78,10 +77,10 @@ class DbDict(BaseStorage):
with self._lock:
if self._bulk_commit:
if self._pending_connection is None:
- self._pending_connection = sqlite3.connect(self.db_path, timeout=self.timeout)
+ self._pending_connection = sqlite3.connect(self.db_path, **self.connection_kwargs)
con = self._pending_connection
else:
- con = sqlite3.connect(self.db_path, timeout=self.timeout)
+ con = sqlite3.connect(self.db_path, **self.connection_kwargs)
try:
if self.fast_save:
con.execute("PRAGMA synchronous = 0;")
@@ -188,3 +187,15 @@ def _get_db_path(db_path):
# Make sure parent dirs exist
makedirs(dirname(db_path), exist_ok=True)
return db_path
+
+
+def sqlite_template(
+ timeout: float = 5.0,
+ detect_types: int = 0,
+ isolation_level: str = None,
+ check_same_thread: bool = True,
+ factory: Type = None,
+ cached_statements: int = 100,
+ uri: bool = False,
+):
+ """Template function to get an accurate signature for the builtin :py:func:`sqlite3.connect`"""
diff --git a/requests_cache/session.py b/requests_cache/session.py
index 031a481..65484ca 100644
--- a/requests_cache/session.py
+++ b/requests_cache/session.py
@@ -9,7 +9,7 @@ from requests import PreparedRequest
from requests import Session as OriginalSession
from requests.hooks import dispatch_hook
-from .backends import BACKEND_KWARGS, BackendSpecifier, init_backend
+from .backends import BackendSpecifier, get_valid_kwargs, init_backend
from .cache_keys import normalize_dict
from .response import AnyResponse, ExpirationTime, set_response_defaults
@@ -47,8 +47,8 @@ class CacheMixin:
self._disabled = False
self._lock = RLock()
- # Remove any requests-cache-specific kwargs before passing along to superclass
- session_kwargs = {k: v for k, v in kwargs.items() if k not in BACKEND_KWARGS}
+ # If the superclass is custom Session, pass along valid kwargs (if any)
+ session_kwargs = get_valid_kwargs(super().__init__, kwargs)
super().__init__(**session_kwargs)
def request(
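
Filtering leftover kwargs against the superclass's __init__ signature (rather than a fixed blocklist) lets CacheMixin cooperate with Session subclasses that take extra constructor arguments. A sketch with a hypothetical CustomSession; the cache_name/backend parameter names are assumed per requests-cache 0.6:

    from requests import Session
    from requests_cache.session import CacheMixin

    class CustomSession(Session):
        def __init__(self, base_url=None, **kwargs):  # hypothetical extra argument
            self.base_url = base_url
            super().__init__(**kwargs)

    class CachedCustomSession(CacheMixin, CustomSession):
        pass

    # 'base_url' matches CustomSession.__init__ and is passed along by the mixin;
    # cache-specific kwargs (cache_name, backend, ...) are consumed by CacheMixin
    session = CachedCustomSession(cache_name='demo_cache', backend='sqlite', base_url='https://example.com')
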
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index a79118b..13e40dd 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -1,5 +1,6 @@
import pytest
import unittest
+from unittest.mock import patch
from requests_cache.backends import DynamoDbDict
from tests.conftest import fail_if_no_connection
@@ -39,3 +40,10 @@ class DynamoDbTestCase(BaseStorageTestCase, unittest.TestCase):
picklable=True,
**kwargs,
)
+
+
+@patch('requests_cache.backends.dynamodb.boto3.resource')
+def test_connection_kwargs(mock_resource):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ DynamoDbDict('test', region_name='us-east-2', invalid_kwarg='???')
+ mock_resource.assert_called_with('dynamodb', region_name='us-east-2')
diff --git a/tests/integration/test_gridfs.py b/tests/integration/test_gridfs.py
index 128356b..aae9948 100644
--- a/tests/integration/test_gridfs.py
+++ b/tests/integration/test_gridfs.py
@@ -1,7 +1,10 @@
import pytest
import unittest
+from unittest.mock import patch
-from requests_cache.backends import GridFSPickleDict
+from pymongo import MongoClient
+
+from requests_cache.backends import GridFSPickleDict, get_valid_kwargs
from tests.conftest import fail_if_no_connection
from tests.integration.test_backends import BaseStorageTestCase
@@ -29,3 +32,15 @@ class GridFSPickleDictTestCase(BaseStorageTestCase, unittest.TestCase):
with pytest.raises(KeyError):
d1[4]
+
+
+@patch('requests_cache.backends.gridfs.GridFS')
+@patch('requests_cache.backends.gridfs.MongoClient')
+@patch(
+ 'requests_cache.backends.gridfs.get_valid_kwargs',
+ side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs),
+)
+def test_connection_kwargs(mock_get_valid_kwargs, mock_client, mock_gridfs):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ GridFSPickleDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
+ mock_client.assert_called_with(host='http://0.0.0.0', port=1234)
diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py
index c781b52..05e61a2 100644
--- a/tests/integration/test_mongodb.py
+++ b/tests/integration/test_mongodb.py
@@ -1,7 +1,10 @@
import pytest
import unittest
+from unittest.mock import patch
-from requests_cache.backends import MongoDict, MongoPickleDict
+from pymongo import MongoClient
+
+from requests_cache.backends import MongoDict, MongoPickleDict, get_valid_kwargs
from tests.conftest import fail_if_no_connection
from tests.integration.test_backends import BaseStorageTestCase
@@ -24,3 +27,14 @@ class MongoDictTestCase(BaseStorageTestCase, unittest.TestCase):
class MongoPickleDictTestCase(BaseStorageTestCase, unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, storage_class=MongoPickleDict, picklable=True, **kwargs)
+
+
+@patch('requests_cache.backends.mongo.MongoClient')
+@patch(
+ 'requests_cache.backends.mongo.get_valid_kwargs',
+ side_effect=lambda cls, kwargs: get_valid_kwargs(MongoClient, kwargs),
+)
+def test_connection_kwargs(mock_get_valid_kwargs, mock_client):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ MongoDict('test', host='http://0.0.0.0', port=1234, invalid_kwarg='???')
+ mock_client.assert_called_with(host='http://0.0.0.0', port=1234)
diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py
index 4bf5fd8..5edf559 100644
--- a/tests/integration/test_redis.py
+++ b/tests/integration/test_redis.py
@@ -1,7 +1,8 @@
import pytest
import unittest
+from unittest.mock import patch
-from requests_cache.backends.redis import RedisDict
+from requests_cache.backends.redis import RedisCache, RedisDict
from tests.conftest import fail_if_no_connection
from tests.integration.test_backends import BaseStorageTestCase
@@ -18,3 +19,11 @@ def ensure_connection():
class RedisTestCase(BaseStorageTestCase, unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, storage_class=RedisDict, picklable=True, **kwargs)
+
+
+# @patch.object(Redis, '__init__', Redis.__init__)
+@patch('requests_cache.backends.redis.StrictRedis')
+def test_connection_kwargs(mock_redis):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ RedisCache('test', username='user', password='pass', invalid_kwarg='???')
+ mock_redis.assert_called_with(username='user', password='pass')
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index bc883bf..18ddefd 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -89,7 +89,7 @@ class DbPickleDictTestCase(SQLiteTestCase, unittest.TestCase):
@patch('requests_cache.backends.sqlite.sqlite3')
-def test_timeout(mock_sqlite):
- """Just make sure the optional 'timeout' param gets passed to sqlite3.connect"""
- DbDict('test', timeout=0.5)
+def test_connection_kwargs(mock_sqlite):
+ """A spot check to make sure optional connection kwargs gets passed to connection"""
+ DbDict('test', timeout=0.5, invalid_kwarg='???')
mock_sqlite.connect.assert_called_with('test', timeout=0.5)