1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
|
from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import attr
from attr import define, field
from requests import PreparedRequest, Response
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict
from urllib3.response import HTTPHeaderDict
from ..cache_control import ExpirationTime, get_expiration_datetime
from . import CachedHTTPResponse, CachedRequest
# strftime pattern applied by format_datetime() when rendering datetimes for display
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S %Z'  # Format used for __str__ only
# Type alias for a list of (str, str) tuples; presumably (header name, header value) pairs
HeaderList = List[Tuple[str, str]]
# Module-level logger, named after this module
logger = getLogger(__name__)
@define(auto_attribs=False, slots=False)
class CachedResponse(Response):
    """A stand-in for :py:class:`requests.Response` whose state is declared as attrs fields,
    with some additional optimizations for serialization.

    ``cache_key`` is a plain class attribute rather than an attrs field, so it is not part of the
    generated ``__init__`` and is not serialized.
    """

    _content: bytes = field(default=None)
    _next: Optional[CachedRequest] = field(default=None)
    cache_key: Optional[str] = None  # Not serialized; set by BaseCache.get_response()
    cookies: RequestsCookieJar = field(factory=RequestsCookieJar)
    created_at: datetime = field(factory=datetime.utcnow)
    elapsed: timedelta = field(factory=timedelta)
    encoding: str = field(default=None)
    expires: Optional[datetime] = field(default=None)
    headers: CaseInsensitiveDict = field(factory=CaseInsensitiveDict)
    history: List['CachedResponse'] = field(factory=list)  # type: ignore
    raw: CachedHTTPResponse = field(factory=CachedHTTPResponse, repr=False)
    reason: str = field(default=None)
    request: CachedRequest = field(factory=CachedRequest)  # type: ignore
    status_code: int = field(default=0)
    url: str = field(default=None)

    def __attrs_post_init__(self):
        """Restore the raw response body (and its headers, if missing) after deserialization"""
        raw = self.raw
        body_missing = raw._body is None and self._content is not None
        if body_missing:
            raw.reset(self._content)
        if not raw.headers:
            raw.headers = HTTPHeaderDict(self.headers)

    @classmethod
    def from_response(
        cls, original_response: Union[Response, 'CachedResponse'], expires: datetime = None, **kwargs
    ):
        """Build a CachedResponse from an original Response, or from another CachedResponse.

        When given a CachedResponse, a copy with the new ``expires`` value is returned and
        ``kwargs`` are not used. Otherwise ``kwargs`` are passed to the constructor.
        """
        # Already a cached response: only the expiration changes
        if isinstance(original_response, CachedResponse):
            return attr.evolve(original_response, expires=expires)

        cached = cls(expires=expires, **kwargs)

        # Mirror every attribute that requests itself treats as response state
        for attr_name in Response.__attrs__:
            setattr(cached, attr_name, getattr(original_response, attr_name, None))

        # Replace the request and raw response with their cached equivalents
        cached.request = CachedRequest.from_request(original_response.request)
        cached.raw = CachedHTTPResponse.from_response(original_response)

        # Keep the next request of a redirect chain, if there is one
        upcoming = original_response.next
        cached._next = CachedRequest.from_request(upcoming) if upcoming else None

        # The body will have been read & decoded by requests.Response by now
        cached._content = original_response.content

        # Redirect history is only copied one level deep (redirects of redirects are skipped)
        # so that this method does not recurse indefinitely
        cached.history = []
        if not cached.is_redirect:
            cached.history.extend(
                cls.from_response(redirect) for redirect in original_response.history
            )
        return cached

    @property
    def _content_consumed(self) -> bool:
        """Always True for a cached response; present for compatibility with requests.Response"""
        return True

    @_content_consumed.setter
    def _content_consumed(self, value: bool):
        """Accept and ignore assignments; the getter always reports True"""

    @property
    def from_cache(self) -> bool:
        """Always True for this class"""
        return True

    @property
    def is_expired(self) -> bool:
        """Whether an expiration is set and the current UTC time has reached it"""
        if self.expires is None:
            return False
        return datetime.utcnow() >= self.expires

    @property
    def ttl(self) -> Optional[int]:
        """Whole seconds left until expiration, or None if there is no expiration or it has
        already passed
        """
        if self.expires is not None and not self.is_expired:
            remaining = self.expires - datetime.utcnow()
            return int(remaining.total_seconds())
        return None

    @property
    def next(self) -> Optional[PreparedRequest]:
        """The prepared next request in a redirect chain, or None if there isn't one"""
        if not self._next:
            return None
        return self._next.prepare()

    def revalidate(self, expire_after: ExpirationTime) -> bool:
        """Replace this response's expiration with one derived from ``expire_after``, then report
        whether the response is expired under the new value
        """
        self.expires = get_expiration_datetime(expire_after)
        return self.is_expired

    @property
    def size(self) -> int:
        """Length of the response body in bytes (0 when there is no body)"""
        body = self.content
        return len(body) if body else 0

    def __getstate__(self):
        """Pickle the instance ``__dict__`` as-is, replacing ``requests.Response.__getstate__``"""
        return self.__dict__

    def __setstate__(self, state):
        """Restore each pickled attribute, replacing ``requests.Response.__setstate__``"""
        for attr_name, attr_value in state.items():
            setattr(self, attr_name, attr_value)

    def __str__(self):
        freshness = "stale" if self.is_expired else "fresh"
        return (
            f'request: {self.request}, response: {self.status_code} '
            f'({format_file_size(self.size)}), created: {format_datetime(self.created_at)}, '
            f'expires: {format_datetime(self.expires)} ({freshness})'
        )
def format_datetime(value: Optional[datetime]) -> str:
    """Render a datetime as a display string in the local time zone.

    A falsy value (such as ``None``) produces ``"N/A"``. A datetime without tzinfo is treated as
    UTC before being converted to local time.
    """
    if value:
        # Naive datetimes are tagged as UTC so that astimezone() converts from the right offset
        aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
        return aware.astimezone().strftime(DATETIME_FORMAT)
    return "N/A"
def format_file_size(n_bytes: int) -> str:
    """Convert a file size in bytes into a human-readable string.

    Sizes below 1024 are shown as a whole number of ``bytes``. Larger sizes are shown with two
    decimal places in ``KiB``, ``MiB`` or ``GiB``; ``GiB`` is the largest unit, so anything bigger
    is still expressed in ``GiB``. A falsy input (``None`` or ``0``) is treated as 0 bytes.
    """
    size = float(n_bytes or 0)
    if size < 1024:
        return f'{int(size)} bytes'

    # Step up one unit at a time until the value fits below 1024
    for unit in ('KiB', 'MiB'):
        size /= 1024
        if size < 1024:
            return f'{size:.2f} {unit}'

    # Largest supported unit: always return here, so no type-checker-only fallback is needed
    return f'{size / 1024:.2f} GiB'
def set_response_defaults(
    response: Union[Response, CachedResponse], cache_key: str = None
) -> Union[Response, CachedResponse]:
    """Give a non-cached response the extra attributes that CachedResponse provides, so callers
    can expect them to always be present.

    A CachedResponse is returned unchanged. Any other response is modified in place:
    ``cache_key`` is set to the given value, ``created_at`` and ``expires`` to ``None``, and
    ``from_cache`` and ``is_expired`` to ``False``. The same object is returned.
    """
    if isinstance(response, CachedResponse):
        return response

    defaults = {
        'cache_key': cache_key,
        'created_at': None,
        'expires': None,
        'from_cache': False,
        'is_expired': False,
    }
    for attr_name, default in defaults.items():
        setattr(response, attr_name, default)
    return response
|