# Copyright (c) 2010-2012 OpenStack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import sys
from requests import RequestException
from requests.structures import CaseInsensitiveDict
from time import sleep
import unittest
import testtools
import mock
import six
from six.moves import reload_module
from six.moves.urllib.parse import urlparse, ParseResult
from swiftclient import client as c
from swiftclient import shell as s


def fake_get_auth_keystone(expected_os_options=None, exc=None,
                           storage_url='http://url/', token='token',
                           **kwargs):
    """
    Build a fake keystone-style auth callable for tests.

    The returned function accepts ``(auth_url, user, key, os_options,
    **kwargs)`` and returns ``(storage_url, token)`` when every configured
    expectation is met, or ``("", None)`` when one is not.

    :param expected_os_options: if truthy, the ``os_options`` the fake must
        receive; any other value makes it return ``("", None)``.
    :param exc: if set, an exception class the fake raises (with the message
        ``'test'``) on every call, before any other check.
    :param storage_url: storage URL returned on success.
    :param token: auth token returned on success.
    :param kwargs: may contain ``required_kwargs``, a dict of keyword
        arguments that must be passed to the fake with equal values.
    :returns: the fake auth function.
    """
    def fake_get_auth_keystone(auth_url,
                               user,
                               key,
                               actual_os_options, **actual_kwargs):
        if exc:
            raise exc('test')
        # TODO: some way to require auth_url, user and key?
        if expected_os_options and actual_os_options != expected_os_options:
            return "", None
        if 'required_kwargs' in kwargs:
            for k, v in kwargs['required_kwargs'].items():
                if v != actual_kwargs.get(k):
                    return "", None

        # Simulate certificate verification failures for specially named
        # https URLs.  When one of these URLs is used, ``insecure`` (and, for
        # the self-signed case, ``cacert``) must be passed as keyword
        # arguments, otherwise the lookups below raise KeyError.
        if auth_url.startswith("https") and \
           auth_url.endswith("invalid-certificate") and \
           not actual_kwargs['insecure']:
            raise c.ClientException("invalid-certificate")
        if auth_url.startswith("https") and \
           auth_url.endswith("self-signed-certificate") and \
           not actual_kwargs['insecure'] and \
           actual_kwargs['cacert'] is None:
            raise c.ClientException("unverified-certificate")

        return storage_url, token
    return fake_get_auth_keystone


class StubResponse(object):
    """
    Placeholder structure for use with fake_http_connect's code_iter to modify
    response attributes (status, body, headers) on a per-request basis.
    """

    def __init__(self, status=200, body='', headers=None):
        self.status = status
        self.body = body
        # When non-empty these headers are returned verbatim by the fake
        # connection instead of its default header set.
        self.headers = headers or {}


def fake_http_connect(*code_iter, **kwargs):
    """
    Generate a callable which yields a series of stubbed responses.  Because
    swiftclient will reuse an HTTP connection across pipelined requests it is
    not always the case that this fake is used strictly for mocking an HTTP
    connection, but rather each HTTP response (i.e. each call to requests
    get_response).

    :param code_iter: one entry per expected request; each is either an int
        status or a :class:`StubResponse`.  A status <= 0 makes the returned
        callable raise ``RequestException`` for that request.
    :param kwargs: optional knobs read by the fake:

        * ``timestamps`` / ``etags`` - per-response values, consumed only for
          int (non-StubResponse) entries.
        * ``missing_container`` - a bool, or a list/tuple of bools; one value
          is consumed each time default headers are built, and a ``False``
          value adds an ``x-container-timestamp`` header.
        * ``body`` - body used for int-status responses.
        * ``headers`` - extra headers merged into the default headers.
        * ``auth_v1`` - add ``x-storage-url`` / ``x-auth-token`` headers.
        * ``slow`` - make ``read``/``send`` sleep and dribble data.
        * ``raise_exc`` - make ``getresponse`` raise ``Exception('test')``.
        * ``give_content_type`` / ``give_connect`` - callbacks invoked with
          request details on every call.
    :returns: a callable producing one fake connection per call; its
        ``code_iter`` attribute holds the entries not yet consumed.
    """

    class FakeConn(object):
        """Fake connection/response object returned for a single request."""

        def __init__(self, status, etag=None, body='', timestamp='1',
                     headers=None):
            self.status = status
            self.reason = 'Fake'
            self.host = '1.2.3.4'
            self.port = '1234'
            self.sent = 0
            self.received = 0
            self.etag = etag
            self.body = body
            self.timestamp = timestamp
            self._is_closed = True
            self.headers = headers or {}

        def connect(self):
            self._is_closed = False

        def close(self):
            self._is_closed = True

        def isclosed(self):
            return self._is_closed

        def getresponse(self):
            if kwargs.get('raise_exc'):
                raise Exception('test')
            return self

        def getexpect(self):
            # NOTE(review): connect() below raises for status <= 0, so these
            # two branches are only reachable if status is changed after the
            # FakeConn has been handed out.
            if self.status == -2:
                raise RequestException()
            if self.status == -3:
                return FakeConn(507)
            return FakeConn(100)

        def getheaders(self):
            """Return explicit headers if given, else a default set."""
            if self.headers:
                return self.headers.items()
            headers = {'content-length': len(self.body),
                       'content-type': 'x-application/test',
                       'x-timestamp': self.timestamp,
                       'last-modified': self.timestamp,
                       'x-object-meta-test': 'testing',
                       'etag':
                       self.etag or '"d41d8cd98f00b204e9800998ecf8427e"',
                       'x-works': 'yes',
                       'x-account-container-count': 12345}
            if not self.timestamp:
                del headers['x-timestamp']
            try:
                if next(container_ts_iter) is False:
                    headers['x-container-timestamp'] = '1'
            except StopIteration:
                # missing_container values exhausted: add no header.
                pass
            if 'slow' in kwargs:
                headers['content-length'] = '4'
            if 'headers' in kwargs:
                headers.update(kwargs['headers'])
            if 'auth_v1' in kwargs:
                headers.update(
                    {'x-storage-url': 'storageURL',
                     'x-auth-token': 'someauthtoken'})
            return headers.items()

        def read(self, amt=None):
            if 'slow' in kwargs:
                # Dribble out four single-space chunks with a delay each.
                if self.sent < 4:
                    self.sent += 1
                    sleep(0.1)
                    return ' '
            rv = self.body[:amt]
            self.body = self.body[amt:]
            return rv

        def send(self, amt=None):
            if 'slow' in kwargs:
                if self.received < 4:
                    self.received += 1
                    sleep(0.1)

        def getheader(self, name, default=None):
            return dict(self.getheaders()).get(name.lower(), default)

    timestamps_iter = iter(kwargs.get('timestamps') or ['1'] * len(code_iter))
    etag_iter = iter(kwargs.get('etags') or [None] * len(code_iter))
    x = kwargs.get('missing_container', [False] * len(code_iter))
    if not isinstance(x, (tuple, list)):
        x = [x] * len(code_iter)
    container_ts_iter = iter(x)
    code_iter = iter(code_iter)

    def connect(*args, **ckwargs):
        if 'give_content_type' in kwargs:
            if len(args) >= 7 and 'Content-Type' in args[6]:
                kwargs['give_content_type'](args[6]['Content-Type'])
            else:
                kwargs['give_content_type']('')
        if 'give_connect' in kwargs:
            kwargs['give_connect'](*args, **ckwargs)
        # Raises StopIteration once every scripted response has been used.
        status = next(code_iter)
        if isinstance(status, StubResponse):
            fake_conn = FakeConn(status.status, body=status.body,
                                 headers=status.headers)
        else:
            etag = next(etag_iter)
            timestamp = next(timestamps_iter)
            fake_conn = FakeConn(status, etag, body=kwargs.get('body', ''),
                                 timestamp=timestamp)
        if fake_conn.status <= 0:
            raise RequestException()
        fake_conn.connect()
        return fake_conn

    # Exposed so MockHttpTest can detect scripted responses left unused.
    connect.code_iter = code_iter
    return connect


class MockHttpTest(testtools.TestCase):
    """
    TestCase base class that provides ``self.fake_http_connection`` (a
    factory for replacements of ``c.http_connection``) and records every
    request made through it so tests can assert on them.
    """

    def setUp(self):
        super(MockHttpTest, self).setUp()
        self.fake_connect = None
        self.request_log = []

        def fake_http_connection(*args, **kwargs):
            """
            Build a replacement for ``c.http_connection``.

            Positional args and kwargs are passed on to
            :func:`fake_http_connect` to script the responses.  These kwargs
            are also honoured: ``storage_url`` (asserted equal to the
            connection URL), ``query_string`` (asserted as the suffix of each
            request URL), ``auth_token`` (asserted equal to each request's
            X-Auth-Token header), ``exc`` (raised from every request) and
            ``on_request`` (callback whose return value becomes the response
            status).
            """
            # Responses scripted by a previous call must all have been used.
            self.validateMockedRequestsConsumed()
            self.request_log = []
            self.fake_connect = fake_http_connect(*args, **kwargs)
            _orig_http_connection = c.http_connection
            query_string = kwargs.get('query_string')
            storage_url = kwargs.get('storage_url')
            auth_token = kwargs.get('auth_token')
            exc = kwargs.get('exc')
            on_request = kwargs.get('on_request')

            def wrapper(url, proxy=None, cacert=None, insecure=False,
                        ssl_compression=True):
                """
                Stand-in for ``c.http_connection``; returns ``(parsed, conn)``
                where ``conn`` serves the scripted responses.  ``cacert`` and
                ``ssl_compression`` are accepted but ignored.
                """
                if storage_url:
                    self.assertEqual(storage_url, url)

                # Reuse the real implementation only for URL parsing; its
                # connection object is discarded.
                parsed, _conn = _orig_http_connection(url, proxy=proxy)

                class RequestsWrapper(object):
                    pass
                conn = RequestsWrapper()

                def request(method, url, *args, **kwargs):
                    try:
                        conn.resp = self.fake_connect()
                    except StopIteration:
                        self.fail('Unexpected %s request for %s' % (
                            method, url))
                    self.request_log.append((parsed, method, url, args,
                                             kwargs, conn.resp))
                    conn.host = conn.resp.host
                    conn.isclosed = conn.resp.isclosed
                    conn.resp.has_been_read = False
                    _orig_read = conn.resp.read

                    def read(*args, **kwargs):
                        # Track whether the caller consumed the body.
                        conn.resp.has_been_read = True
                        return _orig_read(*args, **kwargs)
                    conn.resp.read = read
                    if on_request:
                        status = on_request(method, url, *args, **kwargs)
                        conn.resp.status = status
                    if auth_token:
                        # args are (body, headers, ...) as passed by caller.
                        headers = args[1]
                        self.assertTrue('X-Auth-Token' in headers)
                        actual_token = headers.get('X-Auth-Token')
                        self.assertEqual(auth_token, actual_token)
                    if query_string:
                        self.assertTrue(url.endswith('?' + query_string))
                    if url.endswith('invalid_cert') and not insecure:
                        raise c.ClientException("invalid_certificate")
                    if exc:
                        raise exc
                    return conn.resp

                def putrequest(path, data=None, headers=None, **kwargs):
                    request('PUT', path, data, headers, **kwargs)
                conn.request = request
                conn.putrequest = putrequest

                def getresponse():
                    return conn.resp
                conn.getresponse = getresponse

                return parsed, conn
            return wrapper
        self.fake_http_connection = fake_http_connection

    def iter_request_log(self):
        """
        Yield one dict per recorded request with the keys ``method``,
        ``path``, ``full_path``, ``parsed_path``, ``headers`` (a
        CaseInsensitiveDict), ``resp``, ``status`` and, when they were passed
        positionally, ``body``.
        """
        for parsed, method, path, args, kwargs, resp in self.request_log:
            parts = parsed._asdict()
            parts['path'] = path
            full_path = ParseResult(**parts).geturl()
            args = list(args)
            log = dict(zip(('body', 'headers'), args))
            log.update({
                'method': method,
                'full_path': full_path,
                'parsed_path': urlparse(full_path),
                'path': path,
                'headers': CaseInsensitiveDict(log.get('headers')),
                'resp': resp,
                'status': resp.status,
            })
            yield log

    # The stdlib assertEqual, bypassing testtools' override (presumably for
    # its handling of the custom failure messages built below).
    orig_assertEqual = unittest.TestCase.assertEqual

    def assert_request_equal(self, expected, real_request):
        """
        Assert that one logged request matches ``expected``.

        :param expected: ``(method, path[, body[, headers]])``.  A path with a
            URL scheme is compared against the full URL, otherwise against
            the request path.  Only the headers listed are checked, and they
            are looked up case-insensitively.
        :param real_request: a dict produced by :meth:`iter_request_log`; it
            is annotated with extra keys used to format failure messages.
        """
        method, path = expected[:2]
        if urlparse(path).scheme:
            match_path = real_request['full_path']
        else:
            match_path = real_request['path']
        self.assertEqual((method, path), (real_request['method'],
                                          match_path))
        if len(expected) > 2:
            body = expected[2]
            real_request['expected'] = body
            err_msg = 'Body mismatch for %(method)s %(path)s, ' \
                'expected %(expected)r, and got %(body)r' % real_request
            self.orig_assertEqual(body, real_request['body'], err_msg)

        if len(expected) > 3:
            headers = expected[3]
            for key, value in headers.items():
                real_request['key'] = key
                real_request['expected_value'] = value
                real_request['value'] = real_request['headers'].get(key)
                err_msg = (
                    'Header mismatch on %(key)r, '
                    'expected %(expected_value)r and got %(value)r '
                    'for %(method)s %(path)s %(headers)r' % real_request)
                self.orig_assertEqual(value, real_request['value'],
                                      err_msg)

    def assertRequests(self, expected_requests):
        """
        Make sure some requests were made like you expected, provide a list of
        expected requests, typically in the form of [(method, path), ...]

        Expected requests are matched in order against the recorded ones;
        extra recorded requests beyond the expected ones are not checked.
        """
        real_requests = self.iter_request_log()
        for expected in expected_requests:
            real_request = next(real_requests, None)
            if real_request is None:
                # Fail cleanly rather than leaking StopIteration when fewer
                # requests were made than expected.
                self.fail('Expected request %r was not made; actual '
                          'requests: %r' % (expected, self.request_log))
            self.assert_request_equal(expected, real_request)

    def assert_request(self, expected_request):
        """
        Make sure a request was made as expected. Provide the expected
        request as a single tuple in the form ``(method, path[, body[,
        headers]])``; it passes if any recorded request matches.
        """
        real_requests = self.iter_request_log()
        for real_request in real_requests:
            try:
                self.assert_request_equal(expected_request, real_request)
                break
            except AssertionError:
                pass
        else:
            raise AssertionError(
                "Expected request %s not found in actual requests %s"
                % (expected_request, self.request_log)
            )

    def validateMockedRequestsConsumed(self):
        """Fail if any scripted response was never requested."""
        if not self.fake_connect:
            return
        unused_responses = list(self.fake_connect.code_iter)
        if unused_responses:
            self.fail('Unused responses %r' % (unused_responses,))

    def tearDown(self):
        try:
            self.validateMockedRequestsConsumed()
        finally:
            # Always run the base tearDown and undo module patching, even
            # when the validation above fails; otherwise a patched ``c``
            # would leak into later tests.
            super(MockHttpTest, self).tearDown()
            # TODO: this nuke from orbit clean up seems to be encouraging
            # un-hygienic mocking on the swiftclient.client module; which may
            # lead to some unfortunate test order dependency bugs by way of
            # the broken window theory if any other modules are similarly
            # patched
            reload_module(c)


class CaptureStreamBuffer(object):
    """
    CaptureStreamBuffer is used for testing raw byte writing for PY3.  Each
    byte written here is converted with ``chr`` to the character with the
    same code point (no UTF-8 decoding is done, so multi-byte sequences are
    not combined) and written to the parent CaptureStream.
    """
    def __init__(self, captured_stream):
        self._captured_stream = captured_stream

    def write(self, bytes_data):
        # No encoding, just convert the raw bytes into a str for testing
        # The below call also validates that we have a byte string.
        self._captured_stream.write(
            ''.join(map(chr, bytes_data))
        )


class CaptureStream(object):
    """
    File-like object that tees writes to a wrapped stream and to an
    in-memory buffer whose contents are available via :meth:`getvalue`.
    """

    def __init__(self, stream):
        self.stream = stream
        self._capture = six.StringIO()
        self._buffer = CaptureStreamBuffer(self)
        self.streams = [self.stream, self._capture]

    @property
    def buffer(self):
        if six.PY3:
            return self._buffer
        else:
            raise AttributeError(
                'Output stream has no attribute "buffer" in Python2')

    def flush(self):
        pass

    def write(self, *args, **kwargs):
        for stream in self.streams:
            stream.write(*args, **kwargs)

    def writelines(self, *args, **kwargs):
        # ``writelines`` accepts any iterable; materialise a one-shot
        # iterator (e.g. a generator) once so every stream receives all of
        # the lines instead of only the first stream draining it.
        if args:
            args = (list(args[0]),) + args[1:]
        for stream in self.streams:
            stream.writelines(*args, **kwargs)

    def getvalue(self):
        return self._capture.getvalue()

    def clear(self):
        """Discard captured text; the wrapped stream is unaffected."""
        self._capture.truncate(0)
        self._capture.seek(0)


class CaptureOutput(object):
    """
    Context manager capturing ``sys.stdout``/``sys.stderr`` and the output of
    ``swiftclient.shell.OutputManager``.  The instance itself behaves like
    the captured stdout string.

    :param suppress_systemexit: if true, also patch
        ``OutputManager.get_error_count`` to return 0 (presumably so the
        shell does not exit with an error status - verify against shell).
    """

    def __init__(self, suppress_systemexit=False):
        self._out = CaptureStream(sys.stdout)
        self._err = CaptureStream(sys.stderr)
        self.patchers = []

        WrappedOutputManager = functools.partial(s.OutputManager,
                                                 print_stream=self._out,
                                                 error_stream=self._err)

        if suppress_systemexit:
            self.patchers += [
                mock.patch('swiftclient.shell.OutputManager.get_error_count',
                           return_value=0)
            ]

        self.patchers += [
            mock.patch('swiftclient.shell.OutputManager',
                       WrappedOutputManager),
            mock.patch('sys.stdout', self._out),
            mock.patch('sys.stderr', self._err),
        ]

    def __enter__(self):
        for patcher in self.patchers:
            patcher.start()
        return self

    def __exit__(self, *args, **kwargs):
        for patcher in self.patchers:
            patcher.stop()

    @property
    def out(self):
        return self._out.getvalue()

    @property
    def err(self):
        return self._err.getvalue()

    def clear(self):
        self._out.clear()
        self._err.clear()

    # act like the string captured by stdout

    def __str__(self):
        return self.out

    def __len__(self):
        return len(self.out)

    def __eq__(self, other):
        return self.out == other

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # ``capture != 'text'`` would fall back to an identity comparison.
        return not self.__eq__(other)

    def __getattr__(self, name):
        return getattr(self.out, name)