path: root/requests_cache/backends/mongodb.py
"""MongoDB cache backend. For usage details, see :ref:`Backends: MongoDB <mongodb>`.

.. automodsumm:: requests_cache.backends.mongodb
   :classes-only:
   :nosignatures:
"""
from datetime import timedelta
from logging import getLogger
from typing import Iterable, Mapping, Optional, Union

from pymongo import MongoClient
from pymongo.errors import OperationFailure

from .._utils import get_valid_kwargs
from ..policy.expiration import NEVER_EXPIRE, get_expiration_seconds
from ..serializers import SerializerType, bson_document_serializer
from . import BaseCache, BaseStorage

logger = getLogger(__name__)
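# A minimal usage sketch (not executed by this module): this backend is normally selected
# through CachedSession rather than instantiated directly. The cache name, host, and port
# below are illustrative assumptions; extra keyword arguments are passed through to
# pymongo.MongoClient.
#
#   from requests_cache import CachedSession
#
#   session = CachedSession('http_cache', backend='mongodb', host='localhost', port=27017)
#   session.get('https://example.com')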


class MongoCache(BaseCache):
    """MongoDB cache backend.
    By default, responses are only partially serialized into a MongoDB-compatible document format.

    Args:
        db_name: Database name
        connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
        decode_content: Decode JSON or text response body into a human-readable format
        serializer: Custom serializer to use in place of the default document serializer
        kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
    """

    def __init__(
        self,
        db_name: str = 'http_cache',
        connection: Optional[MongoClient] = None,
        decode_content: bool = True,
        serializer: Optional[SerializerType] = None,
        **kwargs,
    ):
        super().__init__(cache_name=db_name, **kwargs)
        skwargs = {'serializer': serializer, **kwargs} if serializer else kwargs
        self.responses: MongoDict = MongoDict(
            db_name,
            collection_name='responses',
            connection=connection,
            decode_content=decode_content,
            **skwargs,
        )
        self.redirects: MongoDict = MongoDict(
            db_name,
            collection_name='redirects',
            connection=self.responses.connection,
            serializer=None,
            **kwargs,
        )

    def get_ttl(self) -> Optional[int]:
        """Get the currently defined TTL value in seconds, if any"""
        return self.responses.get_ttl()

    def set_ttl(self, ttl: Union[int, timedelta], overwrite: bool = False):
        """Create or update a TTL index. Notes:

        * This will have no effect if TTL is already set
        * To overwrite an existing TTL index, use ``overwrite=True``
        * This may take some time to complete, depending on the size of your cache
        * Use ``ttl=None, overwrite=True`` to remove the TTL index
        """
        self.responses.set_ttl(ttl, overwrite=overwrite)

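# A usage sketch for constructing the backend directly (a sketch only; it assumes a MongoDB
# server is reachable with default connection settings, and the URL is just an example):
#
#   from requests_cache import CachedSession
#   from requests_cache.backends.mongodb import MongoCache
#
#   backend = MongoCache(db_name='http_cache')
#   backend.set_ttl(3600)  # Expire cached responses server-side after 1 hour
#   session = CachedSession(backend=backend)
#   session.get('https://example.com')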

class MongoDict(BaseStorage):
    """A dictionary-like interface for a MongoDB collection

    Args:
        db_name: Database name
        collection_name: Collection name
        connection: :py:class:`pymongo.MongoClient` object to reuse instead of creating a new one
        kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
    """

    def __init__(
        self,
        db_name: str,
        collection_name: str = 'http_cache',
        connection: Optional[MongoClient] = None,
        serializer: Optional[SerializerType] = bson_document_serializer,
        **kwargs,
    ):
        super().__init__(serializer=serializer, **kwargs)
        connection_kwargs = get_valid_kwargs(MongoClient.__init__, kwargs)
        self.connection = connection or MongoClient(**connection_kwargs)
        self.collection = self.connection[db_name][collection_name]

    def get_ttl(self) -> Optional[int]:
        """Get the currently defined TTL value in seconds, if any"""
        idx_info = self.collection.index_information().get('ttl_idx', {})
        return idx_info.get('expireAfterSeconds')

    def set_ttl(self, ttl: Union[int, timedelta], overwrite: bool = False):
        """Create or update a TTL index, and ignore and log any errors due to dropping a nonexistent
        index or attempting to overwrite without ```overwrite=True``.
        """
        try:
            self._set_ttl(get_expiration_seconds(ttl), overwrite=overwrite)
        except OperationFailure:
            logger.warning('Failed to update TTL index', exc_info=True)

    def _set_ttl(self, ttl: int, overwrite: bool = False):
        if overwrite:
            self.collection.drop_index('ttl_idx')
            logger.info('Dropped TTL index')

        if ttl and ttl != NEVER_EXPIRE:
            logger.info(f'Creating TTL index for {ttl} seconds')
            self.collection.create_index('created_at', name='ttl_idx', expireAfterSeconds=ttl)

    def __getitem__(self, key):
        result = self.collection.find_one({'_id': key})
        if result is None:
            raise KeyError
        value = result['data'] if 'data' in result else result
        return self.deserialize(key, value)

    def __setitem__(self, key, value):
        """If ``value`` is already a dict, its values will be stored under top-level keys.
        Otherwise, it will be stored under a 'data' key.
        """
        value = self.serialize(value)
        if not isinstance(value, Mapping):
            value = {'data': value}
        self.collection.replace_one({'_id': key}, value, upsert=True)

    def __delitem__(self, key):
        result = self.collection.find_one_and_delete({'_id': key}, {'_id': True})
        if result is None:
            raise KeyError

    def __len__(self):
        return self.collection.estimated_document_count()

    def __iter__(self):
        for d in self.collection.find({}, {'_id': True}):
            yield d['_id']

    def bulk_delete(self, keys: Iterable[str]):
        """Delete multiple keys from the cache. Does not raise errors for missing keys."""
        self.collection.delete_many({'_id': {'$in': list(keys)}})

    def clear(self):
        self.collection.drop()

    def close(self):
        self.connection.close()
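

# A sketch of using MongoDict directly as a dict-like store (assumes a local MongoDB server;
# the collection name, keys, and values are illustrative). As described in __setitem__,
# Mapping values are stored as top-level document fields, while other values are wrapped
# under a 'data' key:
#
#   d = MongoDict('http_cache', collection_name='demo', serializer=None)
#   d['key-1'] = {'status': 200, 'url': 'https://example.com'}  # Stored as document fields
#   d['key-2'] = b'raw bytes'                                   # Stored under a 'data' key
#   print(d['key-1'], len(d))
#   d.bulk_delete(['key-1', 'key-2'])
#   d.close()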