diff options
-rw-r--r-- | swiftclient/multithreading.py | 10 | ||||
-rw-r--r-- | swiftclient/service.py | 122 | ||||
-rwxr-xr-x | swiftclient/shell.py | 37 | ||||
-rw-r--r-- | tests/unit/test_service.py | 4 | ||||
-rw-r--r-- | tests/unit/test_shell.py | 352 | ||||
-rw-r--r-- | tests/unit/utils.py | 82 |
6 files changed, 490 insertions, 117 deletions
diff --git a/swiftclient/multithreading.py b/swiftclient/multithreading.py index d8eb5f6..7ae82fa 100644 --- a/swiftclient/multithreading.py +++ b/swiftclient/multithreading.py @@ -107,10 +107,16 @@ class OutputManager(object): item = item.encode('utf8') print(item, file=stream) - def _print_error(self, item): - self.error_count += 1 + def _print_error(self, item, count=1): + self.error_count += count return self._print(item, stream=self.error_stream) + def warning(self, msg, *fmt_args): + # print to error stream but do not increment error count + if fmt_args: + msg = msg % fmt_args + self.error_print_pool.submit(self._print_error, msg, count=0) + class MultiThreadingManager(object): """ diff --git a/swiftclient/service.py b/swiftclient/service.py index b4ed675..d7a5795 100644 --- a/swiftclient/service.py +++ b/swiftclient/service.py @@ -109,41 +109,45 @@ def process_options(options): 'region_name': options['os_region_name'], } -_default_global_options = { - "snet": False, - "verbose": 1, - "debug": False, - "info": False, - "auth": environ.get('ST_AUTH'), - "auth_version": environ.get('ST_AUTH_VERSION', '1.0'), - "user": environ.get('ST_USER'), - "key": environ.get('ST_KEY'), - "retries": 5, - "os_username": environ.get('OS_USERNAME'), - "os_user_id": environ.get('OS_USER_ID'), - "os_user_domain_name": environ.get('OS_USER_DOMAIN_NAME'), - "os_user_domain_id": environ.get('OS_USER_DOMAIN_ID'), - "os_password": environ.get('OS_PASSWORD'), - "os_tenant_id": environ.get('OS_TENANT_ID'), - "os_tenant_name": environ.get('OS_TENANT_NAME'), - "os_project_name": environ.get('OS_PROJECT_NAME'), - "os_project_id": environ.get('OS_PROJECT_ID'), - "os_project_domain_name": environ.get('OS_PROJECT_DOMAIN_NAME'), - "os_project_domain_id": environ.get('OS_PROJECT_DOMAIN_ID'), - "os_auth_url": environ.get('OS_AUTH_URL'), - "os_auth_token": environ.get('OS_AUTH_TOKEN'), - "os_storage_url": environ.get('OS_STORAGE_URL'), - "os_region_name": environ.get('OS_REGION_NAME'), - "os_service_type": environ.get('OS_SERVICE_TYPE'), - "os_endpoint_type": environ.get('OS_ENDPOINT_TYPE'), - "os_cacert": environ.get('OS_CACERT'), - "insecure": config_true_value(environ.get('SWIFTCLIENT_INSECURE')), - "ssl_compression": False, - 'segment_threads': 10, - 'object_dd_threads': 10, - 'object_uu_threads': 10, - 'container_threads': 10 -} + +def _build_default_global_options(): + return { + "snet": False, + "verbose": 1, + "debug": False, + "info": False, + "auth": environ.get('ST_AUTH'), + "auth_version": environ.get('ST_AUTH_VERSION', '1.0'), + "user": environ.get('ST_USER'), + "key": environ.get('ST_KEY'), + "retries": 5, + "os_username": environ.get('OS_USERNAME'), + "os_user_id": environ.get('OS_USER_ID'), + "os_user_domain_name": environ.get('OS_USER_DOMAIN_NAME'), + "os_user_domain_id": environ.get('OS_USER_DOMAIN_ID'), + "os_password": environ.get('OS_PASSWORD'), + "os_tenant_id": environ.get('OS_TENANT_ID'), + "os_tenant_name": environ.get('OS_TENANT_NAME'), + "os_project_name": environ.get('OS_PROJECT_NAME'), + "os_project_id": environ.get('OS_PROJECT_ID'), + "os_project_domain_name": environ.get('OS_PROJECT_DOMAIN_NAME'), + "os_project_domain_id": environ.get('OS_PROJECT_DOMAIN_ID'), + "os_auth_url": environ.get('OS_AUTH_URL'), + "os_auth_token": environ.get('OS_AUTH_TOKEN'), + "os_storage_url": environ.get('OS_STORAGE_URL'), + "os_region_name": environ.get('OS_REGION_NAME'), + "os_service_type": environ.get('OS_SERVICE_TYPE'), + "os_endpoint_type": environ.get('OS_ENDPOINT_TYPE'), + "os_cacert": environ.get('OS_CACERT'), + "insecure": config_true_value(environ.get('SWIFTCLIENT_INSECURE')), + "ssl_compression": False, + 'segment_threads': 10, + 'object_dd_threads': 10, + 'object_uu_threads': 10, + 'container_threads': 10 + } + +_default_global_options = _build_default_global_options() _default_local_options = { 'sync_to': None, @@ -1177,11 +1181,6 @@ class SwiftService(object): except ValueError: raise SwiftError('Segment size should be an integer value') - # Does the account exist? - account_stat = self.stat(options=options) - if not account_stat["success"]: - raise account_stat["error"] - # Try to create the container, just in case it doesn't exist. If this # fails, it might just be because the user doesn't have container PUT # permissions, so we'll ignore any error. If there's really a problem, @@ -1208,28 +1207,29 @@ class SwiftService(object): seg_container = container + '_segments' if options['segment_container']: seg_container = options['segment_container'] - if not policy_header: - # Since no storage policy was specified on the command line, - # rather than just letting swift pick the default storage - # policy, we'll try to create the segments container with the - # same as the upload container - create_containers = [ - self.thread_manager.container_pool.submit( - self._create_container_job, seg_container, - policy_source=container - ) - ] - else: - create_containers = [ - self.thread_manager.container_pool.submit( - self._create_container_job, seg_container, - headers=policy_header - ) - ] + if seg_container != container: + if not policy_header: + # Since no storage policy was specified on the command + # line, rather than just letting swift pick the default + # storage policy, we'll try to create the segments + # container with the same policy as the upload container + create_containers = [ + self.thread_manager.container_pool.submit( + self._create_container_job, seg_container, + policy_source=container + ) + ] + else: + create_containers = [ + self.thread_manager.container_pool.submit( + self._create_container_job, seg_container, + headers=policy_header + ) + ] - for r in interruptable_as_completed(create_containers): - res = r.result() - yield res + for r in interruptable_as_completed(create_containers): + res = r.result() + yield res # We maintain a results queue here and a separate thread to monitor # the futures because we want to get results back from potential diff --git a/swiftclient/shell.py b/swiftclient/shell.py index a747625..ce779b3 100755 --- a/swiftclient/shell.py +++ b/swiftclient/shell.py @@ -818,16 +818,14 @@ def st_upload(parser, args, output_manager): ) else: error = r['error'] - if isinstance(error, SwiftError): - output_manager.error("%s" % error) - elif isinstance(error, ClientException): - if r['action'] == "create_container": - if 'X-Storage-Policy' in r['headers']: - output_manager.error( - 'Error trying to create container %s with ' - 'Storage Policy %s', container, - r['headers']['X-Storage-Policy'].strip() - ) + if 'action' in r and r['action'] == "create_container": + # it is not an error to be unable to create the + # container so print a warning and carry on + if isinstance(error, ClientException): + if (r['headers'] and + 'X-Storage-Policy' in r['headers']): + msg = ' with Storage Policy %s' % \ + r['headers']['X-Storage-Policy'].strip() else: msg = ' '.join(str(x) for x in ( error.http_status, error.http_reason) @@ -836,20 +834,15 @@ def st_upload(parser, args, output_manager): if msg: msg += ': ' msg += error.http_response_content[:60] - output_manager.error( - 'Error trying to create container %r: %s', - container, msg - ) + msg = ': %s' % msg else: - output_manager.error("%s" % error) + msg = ': %s' % error + output_manager.warning( + 'Warning: failed to create container ' + '%r%s', container, msg + ) else: - if r['action'] == "create_container": - output_manager.error( - 'Error trying to create container %r: %s', - container, error - ) - else: - output_manager.error("%s" % error) + output_manager.error("%s" % error) except SwiftError as e: output_manager.error("%s" % e) diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 9b69a31..0a0af89 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -385,11 +385,13 @@ class TestSwiftError(testtools.TestCase): self.assertEqual(str(se), '5 container:con object:obj segment:seg') -@mock.patch.dict(os.environ, clean_os_environ) class TestServiceUtils(testtools.TestCase): def setUp(self): super(TestServiceUtils, self).setUp() + with mock.patch.dict(swiftclient.service.environ, clean_os_environ): + swiftclient.service._default_global_options = \ + swiftclient.service._build_default_global_options() self.opts = swiftclient.service._default_global_options.copy() def test_process_options_defaults(self): diff --git a/tests/unit/test_shell.py b/tests/unit/test_shell.py index ccf0fe9..34473bf 100644 --- a/tests/unit/test_shell.py +++ b/tests/unit/test_shell.py @@ -12,7 +12,9 @@ # implied. # See the License for the specific language governing permissions and # limitations under the License. +from genericpath import getmtime +import hashlib import mock import os import tempfile @@ -83,6 +85,21 @@ def _make_env(opts, os_opts): return env +def _make_cmd(cmd, opts, os_opts, use_env=False, flags=None, cmd_args=None): + flags = flags or [] + if use_env: + # set up fake environment variables and make a minimal command line + env = _make_env(opts, os_opts) + args = _make_args(cmd, {}, {}, separator='-', flags=flags, + cmd_args=cmd_args) + else: + # set up empty environment and make full command line + env = {} + args = _make_args(cmd, opts, os_opts, separator='-', flags=flags, + cmd_args=cmd_args) + return args, env + + @mock.patch.dict(os.environ, mocked_os_environ) class TestShell(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -439,6 +456,28 @@ class TestShell(unittest.TestCase): response_dict={}) @mock.patch('swiftclient.service.Connection') + def test_upload_segments_to_same_container(self, connection): + # Upload in segments to same container + connection.return_value.head_object.return_value = { + 'content-length': '0'} + connection.return_value.attempts = 0 + argv = ["", "upload", "container", self.tmpfile, "-S", "10", + "-C", "container"] + with open(self.tmpfile, "wb") as fh: + fh.write(b'12345678901234567890') + swiftclient.shell.main(argv) + connection.return_value.put_container.assert_called_once_with( + 'container', {}, response_dict={}) + connection.return_value.put_object.assert_called_with( + 'container', + self.tmpfile.lstrip('/'), + '', + content_length=0, + headers={'x-object-manifest': mock.ANY, + 'x-object-meta-mtime': mock.ANY}, + response_dict={}) + + @mock.patch('swiftclient.service.Connection') def test_delete_account(self, connection): connection.return_value.get_account.side_effect = [ [None, [{'name': 'container'}]], @@ -742,10 +781,11 @@ class TestSubcommandHelp(unittest.TestCase): self.assertEqual(out.strip('\n'), expected) -class TestParsing(unittest.TestCase): - - def setUp(self): - super(TestParsing, self).setUp() +class TestBase(unittest.TestCase): + """ + Provide some common methods to subclasses + """ + def _remove_swift_env_vars(self): self._environ_vars = {} keys = list(os.environ.keys()) for k in keys: @@ -753,9 +793,20 @@ class TestParsing(unittest.TestCase): or k.startswith('OS_')): self._environ_vars[k] = os.environ.pop(k) - def tearDown(self): + def _replace_swift_env_vars(self): os.environ.update(self._environ_vars) + +class TestParsing(TestBase): + + def setUp(self): + super(TestParsing, self).setUp() + self._remove_swift_env_vars() + + def tearDown(self): + self._replace_swift_env_vars() + super(TestParsing, self).tearDown() + def _make_fake_command(self, result): def fake_command(parser, args, thread_manager): result[0], result[1] = swiftclient.shell.parse_args(parser, args) @@ -1389,3 +1440,294 @@ class TestAuth(MockHttpTest): 'x-auth-token': token + '_new', }), ]) + + +class TestCrossAccountObjectAccess(TestBase, MockHttpTest): + """ + Tests to verify use of --os-storage-url will actually + result in the object request being sent despite account + read/write access and container write access being denied. + """ + def setUp(self): + super(TestCrossAccountObjectAccess, self).setUp() + self._remove_swift_env_vars() + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file.file.write(b'01234567890123456789') + temp_file.file.flush() + self.obj = temp_file.name + self.url = 'http://alternate.com:8080/v1' + + # account tests will attempt to access + self.account = 'AUTH_alice' + + # keystone returns endpoint for another account + fake_ks = FakeKeystone(endpoint='http://example.com:8080/v1/AUTH_bob', + token='bob_token') + self.fake_ks_import = _make_fake_import_keystone_client(fake_ks) + + self.cont = 'c1' + self.cont_path = '/v1/%s/%s' % (self.account, self.cont) + self.obj_path = '%s%s' % (self.cont_path, self.obj) + + self.os_opts = {'username': 'bob', + 'password': 'password', + 'project-name': 'proj_bob', + 'auth-url': 'http://example.com:5000/v3', + 'storage-url': '%s/%s' % (self.url, self.account)} + self.opts = {'auth-version': '3'} + + def tearDown(self): + try: + os.remove(self.obj) + except OSError: + pass + self._replace_swift_env_vars() + super(TestCrossAccountObjectAccess, self).tearDown() + + def _make_cmd(self, cmd, cmd_args=None): + return _make_cmd(cmd, self.opts, self.os_opts, cmd_args=cmd_args) + + def _fake_cross_account_auth(self, read_ok, write_ok): + def on_request(method, path, *args, **kwargs): + """ + Modify response code to 200 if cross account permissions match. + """ + status = 403 + if (path.startswith('/v1/%s/%s' % (self.account, self.cont)) + and read_ok and method in ('GET', 'HEAD')): + status = 200 + elif (path.startswith('/v1/%s/%s%s' + % (self.account, self.cont, self.obj)) + and write_ok and method in ('PUT', 'POST', 'DELETE')): + status = 200 + return status + return on_request + + def test_upload_with_read_write_access(self): + req_handler = self._fake_cross_account_auth(True, True) + fake_conn = self.fake_http_connection(403, 403, + on_request=req_handler) + + args, env = self._make_cmd('upload', cmd_args=[self.cont, self.obj, + '--leave-segments']) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + except SystemExit as e: + self.fail('Unexpected SystemExit: %s' % e) + + self.assertRequests([('PUT', self.cont_path), + ('PUT', self.obj_path)]) + self.assertEqual(self.obj, out.strip()) + expected_err = 'Warning: failed to create container %r: 403 Fake' \ + % self.cont + self.assertEqual(expected_err, out.err.strip()) + + def test_upload_with_write_only_access(self): + req_handler = self._fake_cross_account_auth(False, True) + fake_conn = self.fake_http_connection(403, 403, + on_request=req_handler) + + args, env = self._make_cmd('upload', cmd_args=[self.cont, self.obj, + '--leave-segments']) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + except SystemExit as e: + self.fail('Unexpected SystemExit: %s' % e) + + self.assertRequests([('PUT', self.cont_path), + ('PUT', self.obj_path)]) + self.assertEqual(self.obj, out.strip()) + expected_err = 'Warning: failed to create container %r: 403 Fake' \ + % self.cont + self.assertEqual(expected_err, out.err.strip()) + + def test_segment_upload_with_write_only_access(self): + req_handler = self._fake_cross_account_auth(False, True) + fake_conn = self.fake_http_connection(403, 403, 403, 403, + on_request=req_handler) + + args, env = self._make_cmd('upload', + cmd_args=[self.cont, self.obj, + '--leave-segments', + '--segment-size=10', + '--segment-container=%s' + % self.cont]) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + except SystemExit as e: + self.fail('Unexpected SystemExit: %s' % e) + + segment_time = getmtime(self.obj) + segment_path_0 = '%s/%f/20/10/00000000' % (self.obj_path, segment_time) + segment_path_1 = '%s/%f/20/10/00000001' % (self.obj_path, segment_time) + # Note that the order of segment PUTs cannot be asserted, so test for + # existence in request log individually + self.assert_request(('PUT', self.cont_path)) + self.assert_request(('PUT', segment_path_0)) + self.assert_request(('PUT', segment_path_1)) + self.assert_request(('PUT', self.obj_path)) + self.assertTrue(self.obj in out.out) + expected_err = 'Warning: failed to create container %r: 403 Fake' \ + % self.cont + self.assertEqual(expected_err, out.err.strip()) + + def test_upload_with_no_access(self): + fake_conn = self.fake_http_connection(403, 403) + + args, env = self._make_cmd('upload', cmd_args=[self.cont, self.obj, + '--leave-segments']) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + self.fail('Expected SystemExit') + except SystemExit: + pass + + self.assertRequests([('PUT', self.cont_path), + ('PUT', self.obj_path)]) + expected_err = 'Object PUT failed: http://1.2.3.4%s 403 Fake' \ + % self.obj_path + self.assertTrue(expected_err in out.err) + self.assertEqual('', out) + + def test_download_with_read_write_access(self): + req_handler = self._fake_cross_account_auth(True, True) + empty_str_etag = 'd41d8cd98f00b204e9800998ecf8427e' + fake_conn = self.fake_http_connection(403, on_request=req_handler, + etags=[empty_str_etag]) + + args, env = self._make_cmd('download', cmd_args=[self.cont, + self.obj.lstrip('/'), + '--no-download']) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + except SystemExit as e: + self.fail('Unexpected SystemExit: %s' % e) + + self.assertRequests([('GET', self.obj_path)]) + self.assertTrue(out.out.startswith(self.obj.lstrip('/'))) + self.assertEqual('', out.err) + + def test_download_with_read_only_access(self): + req_handler = self._fake_cross_account_auth(True, False) + empty_str_etag = 'd41d8cd98f00b204e9800998ecf8427e' + fake_conn = self.fake_http_connection(403, on_request=req_handler, + etags=[empty_str_etag]) + + args, env = self._make_cmd('download', cmd_args=[self.cont, + self.obj.lstrip('/'), + '--no-download']) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + except SystemExit as e: + self.fail('Unexpected SystemExit: %s' % e) + + self.assertRequests([('GET', self.obj_path)]) + self.assertTrue(out.out.startswith(self.obj.lstrip('/'))) + self.assertEqual('', out.err) + + def test_download_with_no_access(self): + fake_conn = self.fake_http_connection(403) + args, env = self._make_cmd('download', cmd_args=[self.cont, + self.obj.lstrip('/'), + '--no-download']) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + self.fail('Expected SystemExit') + except SystemExit: + pass + + self.assertRequests([('GET', self.obj_path)]) + path = '%s%s' % (self.cont, self.obj) + expected_err = 'Error downloading object %r' % path + self.assertTrue(out.err.startswith(expected_err)) + self.assertEqual('', out) + + def test_list_with_read_access(self): + req_handler = self._fake_cross_account_auth(True, False) + resp_body = '{}' + m = hashlib.md5() + m.update(resp_body.encode()) + etag = m.hexdigest() + fake_conn = self.fake_http_connection(403, on_request=req_handler, + etags=[etag], + body=resp_body) + + args, env = self._make_cmd('download', cmd_args=[self.cont]) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + except SystemExit as e: + self.fail('Unexpected SystemExit: %s' % e) + + self.assertRequests([('GET', '%s?format=json' % self.cont_path)]) + self.assertEqual('', out) + self.assertEqual('', out.err) + + def test_list_with_no_access(self): + fake_conn = self.fake_http_connection(403) + + args, env = self._make_cmd('download', cmd_args=[self.cont]) + with mock.patch('swiftclient.client._import_keystone_client', + self.fake_ks_import): + with mock.patch('swiftclient.client.http_connection', fake_conn): + with mock.patch.dict(os.environ, env): + with CaptureOutput() as out: + try: + swiftclient.shell.main(args) + self.fail('Expected SystemExit') + except SystemExit: + pass + + self.assertRequests([('GET', '%s?format=json' % self.cont_path)]) + self.assertEqual('', out) + self.assertTrue(out.err.startswith('Container GET failed:')) + + +class TestCrossAccountObjectAccessUsingEnv(TestCrossAccountObjectAccess): + """ + Repeat super-class tests using environment variables rather than command + line to set options. + """ + + def _make_cmd(self, cmd, cmd_args=None): + return _make_cmd(cmd, self.opts, self.os_opts, cmd_args=cmd_args, + use_env=True) diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 873f4c6..2467ca6 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -216,6 +216,7 @@ class MockHttpTest(testtools.TestCase): storage_url = kwargs.get('storage_url') auth_token = kwargs.get('auth_token') exc = kwargs.get('exc') + on_request = kwargs.get('on_request') def wrapper(url, proxy=None, cacert=None, insecure=False, ssl_compression=True): @@ -245,6 +246,9 @@ class MockHttpTest(testtools.TestCase): conn.resp.has_been_read = True return _orig_read(*args, **kwargs) conn.resp.read = read + if on_request: + status = on_request(method, url, *args, **kwargs) + conn.resp.status = status if auth_token: headers = args[1] self.assertTrue('X-Auth-Token' in headers) @@ -258,7 +262,12 @@ class MockHttpTest(testtools.TestCase): if exc: raise exc return conn.resp + + def putrequest(path, data=None, headers=None, **kwargs): + request('PUT', path, data, headers, **kwargs) + conn.request = request + conn.putrequest = putrequest def getresponse(): return conn.resp @@ -288,6 +297,34 @@ class MockHttpTest(testtools.TestCase): orig_assertEqual = unittest.TestCase.assertEqual + def assert_request_equal(self, expected, real_request): + method, path = expected[:2] + if urlparse(path).scheme: + match_path = real_request['full_path'] + else: + match_path = real_request['path'] + self.assertEqual((method, path), (real_request['method'], + match_path)) + if len(expected) > 2: + body = expected[2] + real_request['expected'] = body + err_msg = 'Body mismatch for %(method)s %(path)s, ' \ + 'expected %(expected)r, and got %(body)r' % real_request + self.orig_assertEqual(body, real_request['body'], err_msg) + + if len(expected) > 3: + headers = expected[3] + for key, value in headers.items(): + real_request['key'] = key + real_request['expected_value'] = value + real_request['value'] = real_request['headers'].get(key) + err_msg = ( + 'Header mismatch on %(key)r, ' + 'expected %(expected_value)r and got %(value)r ' + 'for %(method)s %(path)s %(headers)r' % real_request) + self.orig_assertEqual(value, real_request['value'], + err_msg) + def assertRequests(self, expected_requests): """ Make sure some requests were made like you expected, provide a list of @@ -295,33 +332,26 @@ class MockHttpTest(testtools.TestCase): """ real_requests = self.iter_request_log() for expected in expected_requests: - method, path = expected[:2] real_request = next(real_requests) - if urlparse(path).scheme: - match_path = real_request['full_path'] - else: - match_path = real_request['path'] - self.assertEqual((method, path), (real_request['method'], - match_path)) - if len(expected) > 2: - body = expected[2] - real_request['expected'] = body - err_msg = 'Body mismatch for %(method)s %(path)s, ' \ - 'expected %(expected)r, and got %(body)r' % real_request - self.orig_assertEqual(body, real_request['body'], err_msg) - - if len(expected) > 3: - headers = expected[3] - for key, value in headers.items(): - real_request['key'] = key - real_request['expected_value'] = value - real_request['value'] = real_request['headers'].get(key) - err_msg = ( - 'Header mismatch on %(key)r, ' - 'expected %(expected_value)r and got %(value)r ' - 'for %(method)s %(path)s %(headers)r' % real_request) - self.orig_assertEqual(value, real_request['value'], - err_msg) + self.assert_request_equal(expected, real_request) + + def assert_request(self, expected_request): + """ + Make sure a request was made as expected. Provide the + expected request in the form of [(method, path), ...] + """ + real_requests = self.iter_request_log() + for real_request in real_requests: + try: + self.assert_request_equal(expected_request, real_request) + break + except AssertionError: + pass + else: + raise AssertionError( + "Expected request %s not found in actual requests %s" + % (expected_request, self.request_log) + ) def validateMockedRequestsConsumed(self): if not self.fake_connect: |