diff options
author | Roman Haritonov <reclosedev@gmail.com> | 2014-10-13 22:09:11 +0400 |
---|---|---|
committer | Roman Haritonov <reclosedev@gmail.com> | 2014-10-13 22:09:11 +0400 |
commit | bdd43c43351a3d1cb16483c3a4fb5589b33a2b03 (patch) | |
tree | 3e789052a9c814e64d599b57d8c65af3443538b0 | |
parent | 2eef606d4c9eadaa1020e0fe09db06fb8d2a1ba7 (diff) | |
download | requests-cache-bdd43c43351a3d1cb16483c3a4fb5589b33a2b03.tar.gz |
Normalize (sort) parameters passed as builtin dict #29
-rw-r--r-- | requests_cache/core.py | 25 | ||||
-rw-r--r-- | tests/test_cache.py | 29 |
2 files changed, 49 insertions, 5 deletions
diff --git a/requests_cache/core.py b/requests_cache/core.py index 477d16d..1c3ee6c 100644 --- a/requests_cache/core.py +++ b/requests_cache/core.py @@ -8,6 +8,7 @@ """ from contextlib import contextmanager from datetime import datetime, timedelta +from operator import itemgetter import requests from requests import Session as OriginalSession @@ -103,11 +104,15 @@ class CachedSession(OriginalSession): cookies=None, files=None, auth=None, timeout=None, allow_redirects=True, proxies=None, hooks=None, stream=None, verify=None, cert=None): - response = super(CachedSession, self).request(method, url, params, data, - headers, cookies, files, - auth, timeout, - allow_redirects, proxies, - hooks, stream, verify, cert) + response = super(CachedSession, self).request( + method, url, + _normalize_parameters(params), + _normalize_parameters(data), + headers, cookies, files, + auth, timeout, + allow_redirects, proxies, + hooks, stream, verify, cert + ) if self._is_cache_disabled: return response @@ -230,3 +235,13 @@ def clear(): def _patch_session_factory(session_factory=CachedSession): requests.Session = requests.sessions.Session = session_factory + + +def _normalize_parameters(params): + """ If builtin dict is passed as parameter, returns sorted list + of key-value pairs + """ + + if type(params) is dict: + return sorted(params.items(), key=itemgetter(0)) + return params diff --git a/tests/test_cache.py b/tests/test_cache.py index 6ddd2b0..68fd0f5 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -253,6 +253,35 @@ class CacheTestCase(unittest.TestCase): r = self.s.get(httpbin("get")) r.close() + def test_get_parameters_normalization(self): + url = httpbin("get") + params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} + + self.assertFalse(self.s.get(url, params=params).from_cache) + r = self.s.get(url, params=params) + self.assertTrue(r.from_cache) + self.assertEquals(r.json()["args"], params) + self.assertFalse(self.s.get(url, params={"a": "b"}).from_cache) + self.assertTrue(self.s.get(url, params=sorted(params.items())).from_cache) + + class UserSubclass(dict): + def items(self): + return sorted(super(UserSubclass, self).items(), reverse=True) + + custom_dict = UserSubclass(params) + self.assertFalse(self.s.get(url, params=custom_dict).from_cache) + self.assertTrue(self.s.get(url, params=custom_dict).from_cache) + + def test_post_parameters_normalization(self): + params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} + url = httpbin("post") + s = CachedSession(CACHE_NAME, CACHE_BACKEND, + allowable_methods=('GET', 'POST')) + self.assertFalse(s.post(url, data=params).from_cache) + self.assertTrue(s.post(url, data=params).from_cache) + self.assertTrue(s.post(url, data=sorted(params.items())).from_cache) + self.assertFalse(s.post(url, data=sorted(params.items(), reverse=True)).from_cache) + if __name__ == '__main__': unittest.main() |