summaryrefslogtreecommitdiff
path: root/tests/integration/test_mongodb.py
blob: 1dfae8746b3e66c7a5610ceab667eff1886ae65f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from logging import getLogger
from time import sleep
from unittest.mock import patch

import pytest
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from requests_cache.backends import GridFSCache, GridFSDict, MongoCache, MongoDict
from requests_cache.policy import NEVER_EXPIRE
from tests.conftest import N_ITERATIONS, fail_if_no_connection, httpbin
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import BaseStorageTest

try:
    from pymongo.errors import ServerSelectionTimeoutError
except ImportError:
    # pymongo is optional. The name is still needed at import time, because it is
    # referenced by a ``@retry(...)`` decorator argument in this module; without a
    # fallback, a missing pymongo would surface as a NameError instead of letting
    # the connection fixture report the problem.
    class ServerSelectionTimeoutError(Exception):  # type: ignore[no-redef]
        """Placeholder for pymongo's exception when pymongo is not installed"""

logger = getLogger(__name__)


@pytest.fixture(scope='module', autouse=True)
@fail_if_no_connection(connect_timeout=2)
def ensure_connection():
    """Fail all tests in this module if MongoDB is not running

    Creates a client with a 2 second server selection timeout and requests the server info.
    Any error raised here is left for the ``fail_if_no_connection`` decorator to handle
    (presumably by failing the test run; see ``tests.conftest`` to confirm).
    """
    # Imported lazily so this module can still be imported when pymongo is not installed
    from pymongo import MongoClient

    client = MongoClient(serverSelectionTimeoutMS=2000)
    try:
        # Request server info to verify that a server is actually reachable
        client.server_info()
    finally:
        # Always release the client's resources, whether or not the check succeeded
        client.close()


class TestMongoDict(BaseStorageTest):
    """Storage tests inherited from BaseStorageTest, run with MongoDict as the storage class"""

    storage_class = MongoDict

    def test_connection_kwargs(self):
        """Spot check that optional connection kwargs are forwarded to the connection

        An unsupported keyword is supplied as well; creating the storage object is expected
        not to raise because of it.
        """
        connection_kwargs = {
            'host': 'mongodb://0.0.0.0',
            'port': 2222,
            'tz_aware': True,
            'connect': False,
            'invalid_kwarg': '???',
        }
        storage = MongoDict('test', **connection_kwargs)

        # MongoClient does not expose its private members (such as __init_kwargs),
        # so the settings are verified indirectly via the connection's repr
        connection_repr = repr(storage.connection)
        for expected_fragment in ("host=['0.0.0.0:2222']", "tz_aware=True"):
            assert expected_fragment in connection_repr


class TestMongoCache(BaseCacheTest):
    """Cache tests inherited from BaseCacheTest, run with MongoCache, plus TTL checks"""

    backend_class = MongoCache

    def test_ttl(self):
        """A cached response should disappear from the cache after its TTL has passed"""
        session = self.init_session()
        session.cache.set_ttl(1)

        # The first request populates the cache; the second should be served from it
        session.get(httpbin('get'))
        response = session.get(httpbin('get'))
        assert response.from_cache is True

        # Removal is done by a background process that apparently cannot be triggered
        # manually, so poll once per second until the key is gone (up to 70 checks).
        # NOTE(review): 70 is presumably chosen to cover the server's removal interval; confirm.
        for i in range(70):
            if response.cache_key in session.cache.responses:
                sleep(1)
                continue
            logger.debug(f'Removed {response.cache_key} after {i} seconds')
            break

        assert response.cache_key not in session.cache.responses

    def test_ttl__overwrite(self):
        """An existing TTL should only be replaced or removed when ``overwrite=True``"""
        session = self.init_session()
        cache = session.cache
        cache.set_ttl(60)

        # Without overwrite, a second call should leave the existing TTL untouched
        cache.set_ttl(360)
        assert cache.get_ttl() == 60

        # With overwrite, a new index should be created with the new value
        cache.set_ttl(360, overwrite=True)
        assert cache.get_ttl() == 360

        # A TTL of None should drop the index
        cache.set_ttl(None, overwrite=True)
        assert cache.get_ttl() is None

        # Dropping an index that no longer exists should be attempted, with the error ignored
        cache.set_ttl(NEVER_EXPIRE, overwrite=True)
        assert cache.get_ttl() is None

    @retry(
        retry=retry_if_exception_type(ServerSelectionTimeoutError),
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_fixed(5),
    )
    @pytest.mark.parametrize('executor_class', [ThreadPoolExecutor, ProcessPoolExecutor])
    @pytest.mark.parametrize('iteration', range(N_ITERATIONS))
    def test_concurrency(self, iteration, executor_class):
        """Run the inherited concurrency test, retrying on server selection timeouts

        On GitHub runners, the MongoDB container is sometimes not ready yet by the time this
        runs, so the test is retried up to 5 times, waiting 5 seconds between attempts.
        """
        super().test_concurrency(iteration, executor_class)


class TestGridFSDict(BaseStorageTest):
    """Storage tests inherited from BaseStorageTest, run with GridFSDict as the storage class"""

    storage_class = GridFSDict
    picklable = True
    num_instances = 1  # Only test a single collection instead of multiple

    def test_connection_kwargs(self):
        """Spot check that optional connection kwargs are forwarded to the connection

        An unsupported keyword is supplied as well; creating the storage object is expected
        not to raise because of it.
        """
        connection_kwargs = {
            'host': 'mongodb://0.0.0.0',
            'port': 2222,
            'tz_aware': True,
            'connect': False,
            'invalid_kwarg': '???',
        }
        storage = GridFSDict('test', **connection_kwargs)

        # Verified indirectly via the connection's repr
        connection_repr = repr(storage.connection)
        for expected_fragment in ("host=['0.0.0.0:2222']", "tz_aware=True"):
            assert expected_fragment in connection_repr

    def test_corrupt_file(self):
        """Reading a corrupted file should be handled and surface as a KeyError"""
        from gridfs import GridFS
        from gridfs.errors import CorruptGridFile

        storage = self.init_cache()
        storage['key'] = 'value'

        # Simulate corruption by making the GridFS lookup raise CorruptGridFile
        with pytest.raises(KeyError):
            with patch.object(GridFS, 'find_one', side_effect=CorruptGridFile):
                storage['key']

    def test_file_exists(self):
        """A write that hits FileExists should be ignored, leaving the key unset"""
        from gridfs import GridFS
        from gridfs.errors import FileExists

        storage = self.init_cache()

        # This write should just quietly fail
        with patch.object(GridFS, 'put', side_effect=FileExists):
            storage['key'] = 'value_1'

        assert 'key' not in storage


class TestGridFSCache(BaseCacheTest):
    """Cache tests inherited from BaseCacheTest, run with GridFSCache as the backend class"""

    backend_class = GridFSCache