summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook.git@proton.me>2023-01-13 16:08:35 -0600
committerJordan Cook <jordan.cook.git@proton.me>2023-03-01 15:22:07 -0600
commit048b772fa751f860ebe73e268a8719c9bc8bc35c (patch)
treede8960f9c5d3436eacbbe27e891b81d641e8eac4
parentddacec4f0e3c1a9f16ec4aad44803a0b14224e12 (diff)
downloadrequests-cache-048b772fa751f860ebe73e268a8719c9bc8bc35c.tar.gz
Change DynamoDB table to use cache key as partition key
-rw-r--r--requests_cache/backends/dynamodb.py83
-rw-r--r--requests_cache/backends/mongodb.py2
-rw-r--r--tests/integration/base_storage_test.py4
-rw-r--r--tests/integration/test_dynamodb.py14
4 files changed, 49 insertions, 54 deletions
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index 41021c3..f5d9838 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -4,16 +4,18 @@
:classes-only:
:nosignatures:
"""
-from typing import Dict, Iterable, Optional
+from typing import Iterable, Optional
import boto3
from boto3.dynamodb.types import Binary
from boto3.resources.base import ServiceResource
from botocore.exceptions import ClientError
+from requests_cache.backends.base import VT
+
from .._utils import get_valid_kwargs
from ..serializers import SerializerType, dynamodb_document_serializer
-from . import BaseCache, BaseStorage
+from . import BaseCache, BaseStorage, DictStorage
class DynamoDbCache(BaseCache):
@@ -22,7 +24,6 @@ class DynamoDbCache(BaseCache):
Args:
table_name: DynamoDB table name
- namespace: Name of DynamoDB hash map
connection: :boto3:`DynamoDB Resource <services/dynamodb.html#DynamoDB.ServiceResource>`
object to use instead of creating a new one
ttl: Use DynamoDB TTL to automatically remove expired items
@@ -32,6 +33,7 @@ class DynamoDbCache(BaseCache):
def __init__(
self,
table_name: str = 'http_cache',
+ *,
ttl: bool = True,
connection: Optional[ServiceResource] = None,
decode_content: bool = True,
@@ -42,20 +44,13 @@ class DynamoDbCache(BaseCache):
skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
self.responses = DynamoDbDict(
table_name,
- namespace='responses',
ttl=ttl,
connection=connection,
decode_content=decode_content,
**skwargs,
)
- self.redirects = DynamoDbDict(
- table_name,
- namespace='redirects',
- ttl=False,
- connection=self.responses.connection,
- serialzier=None,
- **kwargs,
- )
+ # Redirects will be only stored in memory and not persisted
+ self.redirects: BaseStorage[str, str] = DictStorage()
class DynamoDbDict(BaseStorage):
@@ -63,7 +58,6 @@ class DynamoDbDict(BaseStorage):
Args:
table_name: DynamoDB table name
- namespace: Name of DynamoDB hash map
connection: :boto3:`DynamoDB Resource <services/dynamodb.html#DynamoDB.ServiceResource>`
object to use instead of creating a new one
ttl: Use DynamoDB TTL to automatically remove expired items
@@ -73,7 +67,6 @@ class DynamoDbDict(BaseStorage):
def __init__(
self,
table_name: str,
- namespace: str,
ttl: bool = True,
connection: Optional[ServiceResource] = None,
serializer: Optional[SerializerType] = dynamodb_document_serializer,
@@ -84,7 +77,6 @@ class DynamoDbDict(BaseStorage):
boto3.Session.__init__, kwargs, extras=['endpoint_url']
)
self.connection = connection or boto3.resource('dynamodb', **connection_kwargs)
- self.namespace = namespace
self.table_name = table_name
self.ttl = ttl
@@ -98,13 +90,11 @@ class DynamoDbDict(BaseStorage):
try:
self.connection.create_table(
AttributeDefinitions=[
- {'AttributeName': 'namespace', 'AttributeType': 'S'},
{'AttributeName': 'key', 'AttributeType': 'S'},
],
TableName=self.table_name,
KeySchema=[
- {'AttributeName': 'namespace', 'KeyType': 'HASH'},
- {'AttributeName': 'key', 'KeyType': 'RANGE'},
+ {'AttributeName': 'key', 'KeyType': 'HASH'},
],
BillingMode='PAY_PER_REQUEST',
)
@@ -126,31 +116,14 @@ class DynamoDbDict(BaseStorage):
if e.response['Error']['Code'] != 'ValidationException':
raise
- def _composite_key(self, key: str) -> Dict[str, str]:
- return {'namespace': self.namespace, 'key': str(key)}
-
- def _scan(self):
- expression_attribute_values = {':Namespace': self.namespace}
- expression_attribute_names = {'#N': 'namespace'}
- key_condition_expression = '#N = :Namespace'
- return self._table.query(
- ExpressionAttributeValues=expression_attribute_values,
- ExpressionAttributeNames=expression_attribute_names,
- KeyConditionExpression=key_condition_expression,
- )
-
def __getitem__(self, key):
- result = self._table.get_item(Key=self._composite_key(key))
+ result = self._table.get_item(Key={'key': key})
if 'Item' not in result:
raise KeyError
-
- # With a custom serializer, the value may be a Binary object
- raw_value = result['Item']['value']
- value = raw_value.value if isinstance(raw_value, Binary) else raw_value
- return self.deserialize(key, value)
+ return self.deserialize(key, result['Item']['value'])
def __setitem__(self, key, value):
- item = {**self._composite_key(key), 'value': self.serialize(value)}
+ item = {'key': key, 'value': self.serialize(value)}
# If enabled, set TTL value as a timestamp in unix format
if self.ttl and getattr(value, 'expires_unix', None):
@@ -159,28 +132,42 @@ class DynamoDbDict(BaseStorage):
self._table.put_item(Item=item)
def __delitem__(self, key):
- response = self._table.delete_item(Key=self._composite_key(key), ReturnValues='ALL_OLD')
+ response = self._table.delete_item(Key={'key': key}, ReturnValues='ALL_OLD')
if 'Attributes' not in response:
raise KeyError
def __iter__(self):
- response = self._scan()
- for item in response['Items']:
+ # Alias 'key' attribute since it's a reserved keyword
+ results = self._table.scan(
+ ProjectionExpression='#k',
+ ExpressionAttributeNames={'#k': 'key'},
+ )
+ for item in results['Items']:
yield item['key']
def __len__(self):
- return self._table.query(
- Select='COUNT',
- ExpressionAttributeNames={'#N': 'namespace'},
- ExpressionAttributeValues={':Namespace': self.namespace},
- KeyConditionExpression='#N = :Namespace',
- )['Count']
+ """Get the number of items in the table.
+
+ **Note:** This is an estimate, and is updated every 6 hours. A full table scan will use up
+ your provisioned throughput, so it's not recommended.
+ """
+ return self._table.item_count
def bulk_delete(self, keys: Iterable[str]):
"""Delete multiple keys from the cache. Does not raise errors for missing keys."""
with self._table.batch_writer() as batch:
for key in keys:
- batch.delete_item(Key=self._composite_key(key))
+ batch.delete_item(Key={'key': key})
def clear(self):
self.bulk_delete((k for k in self))
+
+ def deserialize(self, key, value: VT):
+ """Handle Binary objects from a custom serializer"""
+ serialized_value = value.value if isinstance(value, Binary) else value
+ return super().deserialize(key, serialized_value)
+
+ # TODO: Support pagination
+ def values(self):
+ for item in self._table.scan()['Items']:
+ yield self.deserialize(item['key'], item['value'])
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index 07671e7..76b33be 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -50,7 +50,7 @@ class MongoCache(BaseCache):
db_name,
collection_name='redirects',
connection=self.responses.connection,
- serialzier=None,
+ serializer=None,
**kwargs,
)
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index a96ffd9..ce881e8 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -135,8 +135,8 @@ class BaseStorageTest:
cache_2 = self.init_cache(connection=getattr(cache_1, 'connection', None))
for i in range(5):
- cache_1[i] = f'value_{i}'
- cache_2[i] = f'value_{i}'
+ cache_1[f'key_{i}'] = f'value_{i}'
+ cache_2[f'key_{i}'] = f'value_{i}'
assert len(cache_1) == len(cache_2) == 5
cache_1.clear()
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index 42e22eb..af54e67 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest
from requests_cache.backends import DynamoDbCache, DynamoDbDict
-from tests.conftest import fail_if_no_connection
+from tests.conftest import CACHE_NAME, fail_if_no_connection
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import BaseStorageTest
@@ -31,10 +31,18 @@ class TestDynamoDbDict(BaseStorageTest):
storage_class = DynamoDbDict
init_kwargs = AWS_OPTIONS
+ def init_cache(self, cache_name=CACHE_NAME, index=0, clear=True, **kwargs):
+ """For tests that use multiple tables, make index part of the table name"""
+ kwargs = {**self.init_kwargs, **kwargs}
+ cache = self.storage_class(f'{cache_name}_{index}', **kwargs)
+ if clear:
+ cache.clear()
+ return cache
+
@patch('requests_cache.backends.dynamodb.boto3.resource')
def test_connection_kwargs(self, mock_resource):
"""A spot check to make sure optional connection kwargs gets passed to connection"""
- DynamoDbDict('test_table', 'namespace', region_name='us-east-2', invalid_kwarg='???')
+ DynamoDbDict('test_table', region_name='us-east-2', invalid_kwarg='???')
mock_resource.assert_called_with('dynamodb', region_name='us-east-2')
def test_create_table_error(self):
@@ -69,7 +77,7 @@ class TestDynamoDbDict(BaseStorageTest):
# 'ttl' is a reserved word, so to retrieve it we need to alias it
item = cache._table.get_item(
- Key=cache._composite_key('key'),
+ Key={'key': 'key'},
ProjectionExpression='#t',
ExpressionAttributeNames={'#t': 'ttl'},
)