summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Haritonov <reclosedev@gmail.com>2014-10-13 22:09:11 +0400
committerRoman Haritonov <reclosedev@gmail.com>2014-10-13 22:09:11 +0400
commitbdd43c43351a3d1cb16483c3a4fb5589b33a2b03 (patch)
tree3e789052a9c814e64d599b57d8c65af3443538b0
parent2eef606d4c9eadaa1020e0fe09db06fb8d2a1ba7 (diff)
downloadrequests-cache-bdd43c43351a3d1cb16483c3a4fb5589b33a2b03.tar.gz
Normalize (sort) parameters passed as builtin dict #29
-rw-r--r--requests_cache/core.py25
-rw-r--r--tests/test_cache.py29
2 files changed, 49 insertions, 5 deletions
diff --git a/requests_cache/core.py b/requests_cache/core.py
index 477d16d..1c3ee6c 100644
--- a/requests_cache/core.py
+++ b/requests_cache/core.py
@@ -8,6 +8,7 @@
"""
from contextlib import contextmanager
from datetime import datetime, timedelta
+from operator import itemgetter
import requests
from requests import Session as OriginalSession
@@ -103,11 +104,15 @@ class CachedSession(OriginalSession):
cookies=None, files=None, auth=None, timeout=None,
allow_redirects=True, proxies=None, hooks=None, stream=None,
verify=None, cert=None):
- response = super(CachedSession, self).request(method, url, params, data,
- headers, cookies, files,
- auth, timeout,
- allow_redirects, proxies,
- hooks, stream, verify, cert)
+ response = super(CachedSession, self).request(
+ method, url,
+ _normalize_parameters(params),
+ _normalize_parameters(data),
+ headers, cookies, files,
+ auth, timeout,
+ allow_redirects, proxies,
+ hooks, stream, verify, cert
+ )
if self._is_cache_disabled:
return response
@@ -230,3 +235,13 @@ def clear():
def _patch_session_factory(session_factory=CachedSession):
requests.Session = requests.sessions.Session = session_factory
+
+
+def _normalize_parameters(params):
+ """ If builtin dict is passed as parameter, returns sorted list
+ of key-value pairs
+ """
+
+ if type(params) is dict:
+ return sorted(params.items(), key=itemgetter(0))
+ return params
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 6ddd2b0..68fd0f5 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -253,6 +253,35 @@ class CacheTestCase(unittest.TestCase):
r = self.s.get(httpbin("get"))
r.close()
+ def test_get_parameters_normalization(self):
+ url = httpbin("get")
+ params = {"a": "a", "b": ["1", "2", "3"], "c": "4"}
+
+ self.assertFalse(self.s.get(url, params=params).from_cache)
+ r = self.s.get(url, params=params)
+ self.assertTrue(r.from_cache)
+ self.assertEquals(r.json()["args"], params)
+ self.assertFalse(self.s.get(url, params={"a": "b"}).from_cache)
+ self.assertTrue(self.s.get(url, params=sorted(params.items())).from_cache)
+
+ class UserSubclass(dict):
+ def items(self):
+ return sorted(super(UserSubclass, self).items(), reverse=True)
+
+ custom_dict = UserSubclass(params)
+ self.assertFalse(self.s.get(url, params=custom_dict).from_cache)
+ self.assertTrue(self.s.get(url, params=custom_dict).from_cache)
+
+ def test_post_parameters_normalization(self):
+ params = {"a": "a", "b": ["1", "2", "3"], "c": "4"}
+ url = httpbin("post")
+ s = CachedSession(CACHE_NAME, CACHE_BACKEND,
+ allowable_methods=('GET', 'POST'))
+ self.assertFalse(s.post(url, data=params).from_cache)
+ self.assertTrue(s.post(url, data=params).from_cache)
+ self.assertTrue(s.post(url, data=sorted(params.items())).from_cache)
+ self.assertFalse(s.post(url, data=sorted(params.items(), reverse=True)).from_cache)
+
if __name__ == '__main__':
unittest.main()