summaryrefslogtreecommitdiff
path: root/tests/unit/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/utils.py')
-rw-r--r--tests/unit/utils.py82
1 files changed, 56 insertions, 26 deletions
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index 9d8aacc..201a8a8 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -216,6 +216,7 @@ class MockHttpTest(testtools.TestCase):
storage_url = kwargs.get('storage_url')
auth_token = kwargs.get('auth_token')
exc = kwargs.get('exc')
+ on_request = kwargs.get('on_request')
def wrapper(url, proxy=None, cacert=None, insecure=False,
ssl_compression=True):
@@ -245,6 +246,9 @@ class MockHttpTest(testtools.TestCase):
conn.resp.has_been_read = True
return _orig_read(*args, **kwargs)
conn.resp.read = read
+ if on_request:
+ status = on_request(method, url, *args, **kwargs)
+ conn.resp.status = status
if auth_token:
headers = args[1]
self.assertTrue('X-Auth-Token' in headers)
@@ -258,7 +262,12 @@ class MockHttpTest(testtools.TestCase):
if exc:
raise exc
return conn.resp
+
+ def putrequest(path, data=None, headers=None, **kwargs):
+ request('PUT', path, data, headers, **kwargs)
+
conn.request = request
+ conn.putrequest = putrequest
def getresponse():
return conn.resp
@@ -288,6 +297,34 @@ class MockHttpTest(testtools.TestCase):
orig_assertEqual = unittest.TestCase.assertEqual
+ def assert_request_equal(self, expected, real_request):
+ method, path = expected[:2]
+ if urlparse(path).scheme:
+ match_path = real_request['full_path']
+ else:
+ match_path = real_request['path']
+ self.assertEqual((method, path), (real_request['method'],
+ match_path))
+ if len(expected) > 2:
+ body = expected[2]
+ real_request['expected'] = body
+ err_msg = 'Body mismatch for %(method)s %(path)s, ' \
+ 'expected %(expected)r, and got %(body)r' % real_request
+ self.orig_assertEqual(body, real_request['body'], err_msg)
+
+ if len(expected) > 3:
+ headers = expected[3]
+ for key, value in headers.items():
+ real_request['key'] = key
+ real_request['expected_value'] = value
+ real_request['value'] = real_request['headers'].get(key)
+ err_msg = (
+ 'Header mismatch on %(key)r, '
+ 'expected %(expected_value)r and got %(value)r '
+ 'for %(method)s %(path)s %(headers)r' % real_request)
+ self.orig_assertEqual(value, real_request['value'],
+ err_msg)
+
def assertRequests(self, expected_requests):
"""
Make sure some requests were made like you expected, provide a list of
@@ -295,33 +332,26 @@ class MockHttpTest(testtools.TestCase):
"""
real_requests = self.iter_request_log()
for expected in expected_requests:
- method, path = expected[:2]
real_request = next(real_requests)
- if urlparse(path).scheme:
- match_path = real_request['full_path']
- else:
- match_path = real_request['path']
- self.assertEqual((method, path), (real_request['method'],
- match_path))
- if len(expected) > 2:
- body = expected[2]
- real_request['expected'] = body
- err_msg = 'Body mismatch for %(method)s %(path)s, ' \
- 'expected %(expected)r, and got %(body)r' % real_request
- self.orig_assertEqual(body, real_request['body'], err_msg)
-
- if len(expected) > 3:
- headers = expected[3]
- for key, value in headers.items():
- real_request['key'] = key
- real_request['expected_value'] = value
- real_request['value'] = real_request['headers'].get(key)
- err_msg = (
- 'Header mismatch on %(key)r, '
- 'expected %(expected_value)r and got %(value)r '
- 'for %(method)s %(path)s %(headers)r' % real_request)
- self.orig_assertEqual(value, real_request['value'],
- err_msg)
+ self.assert_request_equal(expected, real_request)
+
+ def assert_request(self, expected_request):
+ """
+ Make sure a request was made as expected. Provide the
+ expected request in the form of [(method, path), ...]
+ """
+ real_requests = self.iter_request_log()
+ for real_request in real_requests:
+ try:
+ self.assert_request_equal(expected_request, real_request)
+ break
+ except AssertionError:
+ pass
+ else:
+ raise AssertionError(
+ "Expected request %s not found in actual requests %s"
+ % (expected_request, self.request_log)
+ )
def validateMockedRequestsConsumed(self):
if not self.fake_connect: