summaryrefslogtreecommitdiff
path: root/requests_cache/patcher.py
blob: aa5af76cc50d6e6cfe7446ea1f062d9763878425 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Utilities for patching ``requests``. See :ref:`patching` for general usage info.

.. warning:: These functions are not thread-safe. Use :py:class:`.CachedSession` if you want to use
    caching in a multi-threaded environment.

.. automodsumm:: requests_cache.patcher
   :functions-only:
   :nosignatures:
"""
from contextlib import contextmanager
from logging import getLogger
from typing import Optional, Type
from warnings import warn

import requests

from .backends import BackendSpecifier, BaseCache, init_backend
from .session import CachedSession, OriginalSession

logger = getLogger(__name__)


def install_cache(
    cache_name: str = 'http_cache',
    backend: Optional[BackendSpecifier] = None,
    session_factory: Type[OriginalSession] = CachedSession,
    **kwargs,
):
    """
    Monkey-patch :py:class:`requests.Session` so that all ``requests`` functions use the cache

    Example:

        >>> requests_cache.install_cache('demo_cache')

    Takes all the same parameters as :py:class:`.CachedSession`, plus:

    Args:
        session_factory: Session class to install. It must inherit from either
            :py:class:`.CachedSession` or :py:class:`.CacheMixin`
    """
    # The backend is initialized once here; every session built by the patched class reuses it
    cache_backend = init_backend(cache_name, backend, **kwargs)
    session_kwargs = {'cache_name': cache_name, 'backend': cache_backend, **kwargs}

    # ``requests`` instantiates its session class without arguments, so the configuration is
    # baked into a zero-argument subclass of the requested session class
    class _ConfiguredCachedSession(session_factory):  # type: ignore  # See mypy issue #5865
        def __init__(self):
            super().__init__(**session_kwargs)

    _patch_session_factory(_ConfiguredCachedSession)


def uninstall_cache():
    """Turn off caching by putting the original :py:class:`requests.Session` back in place"""
    _patch_session_factory(session_factory=OriginalSession)


@contextmanager
def disabled():
    """
    Context manager that temporarily turns off caching for all ``requests`` functions

    Example:

        >>> with requests_cache.disabled():
        ...     requests.get('https://httpbin.org/get')

    """
    # Remember whichever session class is active now so it can be reinstated on exit
    active_session_cls = requests.Session
    _patch_session_factory(OriginalSession)
    try:
        yield
    finally:
        _patch_session_factory(active_session_cls)


@contextmanager
def enabled(*args, **kwargs):
    """
    Context manager for temporarily enabling caching for all ``requests`` functions

    Example:

        >>> with requests_cache.enabled('cache.db'):
        ...     requests.get('https://httpbin.org/get')

    Accepts the same arguments as :py:class:`.CachedSession` and :py:func:`.install_cache`.

    On exit, the session class that was active before entering the block is restored. If no
    cache was installed beforehand this is the original :py:class:`requests.Session`; if a cache
    was already installed (for example via :py:func:`.install_cache`), that cache stays installed
    instead of being removed.
    """
    # Capture the active session class *before* patching, mirroring ``disabled()``, so that
    # nesting ``enabled()`` inside an existing installation does not discard it on exit
    previous = requests.Session
    install_cache(*args, **kwargs)
    try:
        yield
    finally:
        _patch_session_factory(previous)


def get_cache() -> Optional[BaseCache]:
    """Return the ``cache`` attribute of a session created from the currently patched
    ``requests.Session``, or ``None`` if that session has no such attribute
    """
    session = requests.Session()
    return getattr(session, 'cache', None)


def is_installed() -> bool:
    """Check whether ``requests.Session`` currently produces :py:class:`.CachedSession` objects"""
    session = requests.Session()
    return isinstance(session, CachedSession)


def clear():
    """Clear the currently installed cache (if any)

    Does nothing when no cache is installed.
    """
    # Look the cache up once: each ``get_cache()`` call instantiates a new session. Compare
    # against ``None`` explicitly so the decision does not depend on the cache's truthiness.
    cache = get_cache()
    if cache is not None:
        cache.clear()


def delete(*args, **kwargs):
    """Remove responses from the cache according to one or more conditions.
    See :py:meth:`.BaseCache.delete` for usage details.

    All arguments are passed through unchanged. Does nothing when no cache is installed.
    """
    # Use the same lookup as ``clear()`` so that any installed session class exposing a
    # ``cache`` attribute is handled, not only subclasses of ``CachedSession``
    cache = get_cache()
    if cache is not None:
        cache.delete(*args, **kwargs)


def remove_expired_responses():
    """Deprecated: remove expired responses from the cache. Use ``delete(expired=True)`` instead."""
    message = 'remove_expired_responses() is deprecated; please use delete() instead'
    # stacklevel=2 attributes the warning to the caller's line rather than to this function
    warn(message, DeprecationWarning, stacklevel=2)
    delete(expired=True)


def _patch_session_factory(session_factory: Type[OriginalSession] = CachedSession):
    """Rebind ``requests.Session`` and ``requests.sessions.Session`` to ``session_factory``"""
    logger.debug(f'Patching requests.Session with class: {session_factory.__name__}')
    # Both names are rebound, top-level alias first, so they always refer to the same class
    requests.Session = session_factory  # type: ignore
    requests.sessions.Session = session_factory  # type: ignore