summaryrefslogtreecommitdiff
path: root/test/unit/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/unit/utils.py')
-rw-r--r--test/unit/utils.py582
1 files changed, 582 insertions, 0 deletions
diff --git a/test/unit/utils.py b/test/unit/utils.py
new file mode 100644
index 0000000..025a234
--- /dev/null
+++ b/test/unit/utils.py
@@ -0,0 +1,582 @@
+# Copyright (c) 2010-2012 OpenStack, LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+import sys
+from requests import RequestException
+from requests.structures import CaseInsensitiveDict
+from time import sleep
+import unittest
+import mock
+import six
+import os
+from six.moves import reload_module
+from six.moves.urllib.parse import urlparse, ParseResult
+from swiftclient import client as c
+from swiftclient import shell as s
+from swiftclient.utils import EMPTY_ETAG
+
+
+def fake_get_auth_keystone(expected_os_options=None, exc=None,
+ storage_url='http://url/', token='token',
+ **kwargs):
+ def fake_get_auth_keystone(auth_url,
+ user,
+ key,
+ actual_os_options, **actual_kwargs):
+ if exc:
+ raise exc('test')
+ # TODO: some way to require auth_url, user and key?
+ if expected_os_options:
+ for key, value in actual_os_options.items():
+ if value and value != expected_os_options.get(key):
+ return "", None
+ if 'required_kwargs' in kwargs:
+ for k, v in kwargs['required_kwargs'].items():
+ if v != actual_kwargs.get(k):
+ return "", None
+
+ if auth_url.startswith("https") and \
+ auth_url.endswith("invalid-certificate") and \
+ not actual_kwargs['insecure']:
+ from swiftclient import client as c
+ raise c.ClientException("invalid-certificate")
+ if auth_url.startswith("https") and \
+ auth_url.endswith("self-signed-certificate") and \
+ not actual_kwargs['insecure'] and \
+ actual_kwargs['cacert'] is None:
+ from swiftclient import client as c
+ raise c.ClientException("unverified-certificate")
+ if auth_url.startswith("https") and \
+ auth_url.endswith("client-certificate") and \
+ not (actual_kwargs['cert'] and actual_kwargs['cert_key']):
+ from swiftclient import client as c
+ raise c.ClientException("noclient-certificate")
+
+ return storage_url, token
+ return fake_get_auth_keystone
+
+
+class StubResponse(object):
+ """
+ Placeholder structure for use with fake_http_connect's code_iter to modify
+ response attributes (status, body, headers) on a per-request basis.
+ """
+
+ def __init__(self, status=200, body='', headers=None):
+ self.status = status
+ self.body = body
+ self.headers = headers or {}
+
+ def __repr__(self):
+ return '%s(%r, %r, %r)' % (self.__class__.__name__, self.status,
+ self.body, self.headers)
+
+
+def fake_http_connect(*code_iter, **kwargs):
+ """
+ Generate a callable which yields a series of stubbed responses. Because
+ swiftclient will reuse an HTTP connection across pipelined requests it is
+ not always the case that this fake is used strictly for mocking an HTTP
+ connection, but rather each HTTP response (i.e. each call to requests
+ get_response).
+ """
+
+ class FakeConn(object):
+
+ def __init__(self, status, etag=None, body='', timestamp='1',
+ headers=None):
+ self.status_code = self.status = status
+ self.reason = 'Fake'
+ self.scheme = 'http'
+ self.host = '1.2.3.4'
+ self.port = '1234'
+ self.sent = 0
+ self.received = 0
+ self.etag = etag
+ self.content = self.body = body
+ self.timestamp = timestamp
+ self.headers = headers or {}
+ self.request = None
+
+ def getresponse(self):
+ if kwargs.get('raise_exc'):
+ raise Exception('test')
+ return self
+
+ def getheaders(self):
+ if self.headers:
+ return self.headers.items()
+ headers = {'content-length': str(len(self.body)),
+ 'content-type': 'x-application/test',
+ 'x-timestamp': self.timestamp,
+ 'last-modified': self.timestamp,
+ 'x-object-meta-test': 'testing',
+ 'etag':
+ self.etag or '"%s"' % EMPTY_ETAG,
+ 'x-works': 'yes',
+ 'x-account-container-count': '12345'}
+ if not self.timestamp:
+ del headers['x-timestamp']
+ try:
+ if next(container_ts_iter) is False:
+ headers['x-container-timestamp'] = '1'
+ except StopIteration:
+ pass
+ if 'slow' in kwargs:
+ headers['content-length'] = '4'
+ if 'headers' in kwargs:
+ headers.update(kwargs['headers'])
+ if 'auth_v1' in kwargs:
+ headers.update(
+ {'x-storage-url': 'storageURL',
+ 'x-auth-token': 'someauthtoken'})
+ return headers.items()
+
+ def read(self, amt=None):
+ if 'slow' in kwargs:
+ if self.sent < 4:
+ self.sent += 1
+ sleep(0.1)
+ return ' '
+ rv = self.body[:amt]
+ if amt is not None:
+ self.body = self.body[amt:]
+ else:
+ self.body = ''
+ return rv
+
+ def send(self, amt=None):
+ if 'slow' in kwargs:
+ if self.received < 4:
+ self.received += 1
+ sleep(0.1)
+
+ def getheader(self, name, default=None):
+ return dict(self.getheaders()).get(name.lower(), default)
+
+ def close(self):
+ pass
+
+ timestamps_iter = iter(kwargs.get('timestamps') or ['1'] * len(code_iter))
+ etag_iter = iter(kwargs.get('etags') or [None] * len(code_iter))
+ x = kwargs.get('missing_container', [False] * len(code_iter))
+ if not isinstance(x, (tuple, list)):
+ x = [x] * len(code_iter)
+ container_ts_iter = iter(x)
+ code_iter = iter(code_iter)
+
+ def connect(*args, **ckwargs):
+ if 'give_content_type' in kwargs:
+ if len(args) >= 7 and 'Content-Type' in args[6]:
+ kwargs['give_content_type'](args[6]['Content-Type'])
+ else:
+ kwargs['give_content_type']('')
+ if 'give_connect' in kwargs:
+ kwargs['give_connect'](*args, **ckwargs)
+ status = next(code_iter)
+ if isinstance(status, StubResponse):
+ fake_conn = FakeConn(status.status, body=status.body,
+ headers=status.headers)
+ else:
+ etag = next(etag_iter)
+ timestamp = next(timestamps_iter)
+ fake_conn = FakeConn(status, etag, body=kwargs.get('body', ''),
+ timestamp=timestamp)
+ if fake_conn.status <= 0:
+ raise RequestException()
+ return fake_conn
+
+ connect.code_iter = code_iter
+ return connect
+
+
+class MockHttpTest(unittest.TestCase):
+
+ def setUp(self):
+ super(MockHttpTest, self).setUp()
+ self.fake_connect = None
+ self.request_log = []
+
+ # Capture output, since the test-runner stdout/stderr monkey-patching
+ # won't cover the references to sys.stdout/sys.stderr in
+ # swiftclient.multithreading
+ self.capture_output = CaptureOutput()
+ if 'SWIFTCLIENT_DEBUG' not in os.environ:
+ self.capture_output.__enter__()
+ self.addCleanup(self.capture_output.__exit__)
+
+ # since we're going to steal all stderr output globally; we should
+ # give the developer an escape hatch or risk scorn
+ def blowup_but_with_the_helpful(*args, **kwargs):
+ raise Exception(
+ "You tried to enter a debugger while stderr is "
+ "patched, you need to set SWIFTCLIENT_DEBUG=1 "
+ "and try again")
+ import pdb
+ pdb.set_trace = blowup_but_with_the_helpful
+
+ def fake_http_connection(*args, **kwargs):
+ self.validateMockedRequestsConsumed()
+ self.request_log = []
+ self.fake_connect = fake_http_connect(*args, **kwargs)
+ _orig_http_connection = c.http_connection
+ query_string = kwargs.get('query_string')
+ storage_url = kwargs.get('storage_url')
+ auth_token = kwargs.get('auth_token')
+ exc = kwargs.get('exc')
+ on_request = kwargs.get('on_request')
+
+ def wrapper(url, proxy=None, cacert=None, insecure=False,
+ cert=None, cert_key=None,
+ ssl_compression=True, timeout=None):
+ if storage_url:
+ self.assertEqual(storage_url, url)
+
+ parsed, _conn = _orig_http_connection(url, proxy=proxy)
+
+ class RequestsWrapper(object):
+ def close(self):
+ pass
+ conn = RequestsWrapper()
+
+ def request(method, path, *args, **kwargs):
+ try:
+ conn.resp = self.fake_connect()
+ except StopIteration:
+ self.fail('Unexpected %s request for %s' % (
+ method, path))
+ self.request_log.append((parsed, method, path, args,
+ kwargs, conn.resp))
+ conn.host = conn.resp.host
+ conn.resp.request = RequestsWrapper()
+ conn.resp.request.url = '%s://%s%s' % (
+ conn.resp.scheme, conn.resp.host, path)
+ conn.resp.has_been_read = False
+ _orig_read = conn.resp.read
+
+ def read(*args, **kwargs):
+ conn.resp.has_been_read = True
+ return _orig_read(*args, **kwargs)
+ conn.resp.read = read
+ if on_request:
+ status = on_request(method, path, *args, **kwargs)
+ conn.resp.status = status
+ if auth_token:
+ headers = args[1]
+ self.assertEqual(auth_token,
+ headers.get('X-Auth-Token'))
+ if query_string:
+ self.assertTrue(path.endswith('?' + query_string))
+ if path.endswith('invalid_cert') and not insecure:
+ from swiftclient import client as c
+ raise c.ClientException("invalid_certificate")
+ if exc:
+ raise exc
+ return conn.resp
+
+ def putrequest(path, data=None, headers=None, **kwargs):
+ request('PUT', path, data, headers, **kwargs)
+
+ conn.request = request
+ conn.putrequest = putrequest
+
+ def getresponse():
+ return conn.resp
+ conn.getresponse = getresponse
+
+ return parsed, conn
+ return wrapper
+ self.fake_http_connection = fake_http_connection
+
+ def iter_request_log(self):
+ for parsed, method, path, args, kwargs, resp in self.request_log:
+ parts = parsed._asdict()
+ parts['path'] = path
+ full_path = ParseResult(**parts).geturl()
+ args = list(args)
+ log = dict(zip(('body', 'headers'), args))
+ log.update({
+ 'method': method,
+ 'full_path': full_path,
+ 'parsed_path': urlparse(full_path),
+ 'path': path,
+ 'headers': CaseInsensitiveDict(log.get('headers')),
+ 'resp': resp,
+ 'status': resp.status,
+ })
+ yield log
+
+ orig_assertEqual = unittest.TestCase.assertEqual
+
+ def assert_request_equal(self, expected, real_request):
+ method, path = expected[:2]
+ if urlparse(path).scheme:
+ match_path = real_request['full_path']
+ else:
+ match_path = real_request['path']
+ self.assertEqual((method, path), (real_request['method'],
+ match_path))
+ if len(expected) > 2:
+ body = expected[2]
+ real_request['expected'] = body
+ err_msg = 'Body mismatch for %(method)s %(path)s, ' \
+ 'expected %(expected)r, and got %(body)r' % real_request
+ self.orig_assertEqual(body, real_request['body'], err_msg)
+
+ if len(expected) > 3:
+ headers = CaseInsensitiveDict(expected[3])
+ for key, value in headers.items():
+ real_request['key'] = key
+ real_request['expected_value'] = value
+ real_request['value'] = real_request['headers'].get(key)
+ err_msg = (
+ 'Header mismatch on %(key)r, '
+ 'expected %(expected_value)r and got %(value)r '
+ 'for %(method)s %(path)s %(headers)r' % real_request)
+ self.orig_assertEqual(value, real_request['value'],
+ err_msg)
+ real_request['extra_headers'] = dict(
+ (key, value) for key, value in real_request['headers'].items()
+ if key not in headers)
+ if real_request['extra_headers']:
+ self.fail('Received unexpected headers for %(method)s '
+ '%(path)s, got %(extra_headers)r' % real_request)
+
+ def assertRequests(self, expected_requests):
+ """
+ Make sure some requests were made like you expected, provide a list of
+ expected requests, typically in the form of [(method, path), ...]
+ or [(method, path, body, headers), ...]
+ """
+ real_requests = self.iter_request_log()
+ for expected in expected_requests:
+ real_request = next(real_requests)
+ self.assert_request_equal(expected, real_request)
+ try:
+ real_request = next(real_requests)
+ except StopIteration:
+ pass
+ else:
+ self.fail('At least one extra request received: %r' %
+ real_request)
+
+ def assert_request(self, expected_request):
+ """
+ Make sure a request was made as expected. Provide the
+ expected request in the form of [(method, path), ...]
+ """
+ real_requests = self.iter_request_log()
+ for real_request in real_requests:
+ try:
+ self.assert_request_equal(expected_request, real_request)
+ break
+ except AssertionError:
+ pass
+ else:
+ raise AssertionError(
+ "Expected request %s not found in actual requests %s"
+ % (expected_request, self.request_log)
+ )
+
+ def validateMockedRequestsConsumed(self):
+ if not self.fake_connect:
+ return
+ unused_responses = list(self.fake_connect.code_iter)
+ if unused_responses:
+ self.fail('Unused responses %r' % (unused_responses,))
+
+ def tearDown(self):
+ self.validateMockedRequestsConsumed()
+ super(MockHttpTest, self).tearDown()
+ # TODO: this nuke from orbit clean up seems to be encouraging
+ # un-hygienic mocking on the swiftclient.client module; which may lead
+ # to some unfortunate test order dependency bugs by way of the broken
+ # window theory if any other modules are similarly patched
+ reload_module(c)
+
+
+class CaptureStreamPrinter(object):
+ """
+ CaptureStreamPrinter is used for testing unicode writing for PY3. Anything
+ written here is encoded as utf-8 and written to the parent CaptureStream
+ """
+ def __init__(self, captured_stream):
+ self._captured_stream = captured_stream
+
+ def write(self, data):
+ # No encoding, just convert the raw bytes into a str for testing
+ # The below call also validates that we have a byte string.
+ self._captured_stream.write(
+ data if isinstance(data, six.binary_type) else data.encode('utf8'))
+
+
+class CaptureStream(object):
+
+ def __init__(self, stream):
+ self.stream = stream
+ self._buffer = six.BytesIO()
+ self._capture = CaptureStreamPrinter(self._buffer)
+ self.streams = [self._capture]
+
+ @property
+ def buffer(self):
+ if six.PY3:
+ return self._buffer
+ else:
+ raise AttributeError(
+ 'Output stream has no attribute "buffer" in Python2')
+
+ def flush(self):
+ pass
+
+ def write(self, *args, **kwargs):
+ for stream in self.streams:
+ stream.write(*args, **kwargs)
+
+ def writelines(self, *args, **kwargs):
+ for stream in self.streams:
+ stream.writelines(*args, **kwargs)
+
+ def getvalue(self):
+ return self._buffer.getvalue()
+
+ def clear(self):
+ self._buffer.truncate(0)
+ self._buffer.seek(0)
+
+
+class CaptureOutput(object):
+
+ def __init__(self, suppress_systemexit=False):
+ self._out = CaptureStream(sys.stdout)
+ self._err = CaptureStream(sys.stderr)
+ self.patchers = []
+
+ WrappedOutputManager = functools.partial(s.OutputManager,
+ print_stream=self._out,
+ error_stream=self._err)
+
+ if suppress_systemexit:
+ self.patchers += [
+ mock.patch('swiftclient.shell.OutputManager.get_error_count',
+ return_value=0)
+ ]
+
+ self.patchers += [
+ mock.patch('swiftclient.shell.OutputManager',
+ WrappedOutputManager),
+ mock.patch('sys.stdout', self._out),
+ mock.patch('sys.stderr', self._err),
+ ]
+
+ def __enter__(self):
+ for patcher in self.patchers:
+ patcher.start()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ for patcher in self.patchers:
+ patcher.stop()
+
+ @property
+ def out(self):
+ return self._out.getvalue().decode('utf8')
+
+ @property
+ def err(self):
+ return self._err.getvalue().decode('utf8')
+
+ def clear(self):
+ self._out.clear()
+ self._err.clear()
+
+ # act like the string captured by stdout
+
+ def __str__(self):
+ return self.out
+
+ def __len__(self):
+ return len(self.out)
+
+ def __eq__(self, other):
+ return self.out == other
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getattr__(self, name):
+ return getattr(self.out, name)
+
+
+class FakeKeystone(object):
+ '''
+ Fake keystone client module. Returns given endpoint url and auth token.
+ '''
+ def __init__(self, endpoint, token):
+ self.calls = []
+ self.auth_version = None
+ self.endpoint = endpoint
+ self.token = token
+
+ class _Client(object):
+ def __init__(self, endpoint, auth_token, **kwargs):
+ self.auth_token = auth_token
+ self.endpoint = endpoint
+ self.service_catalog = self.ServiceCatalog(endpoint)
+
+ class ServiceCatalog(object):
+ def __init__(self, endpoint):
+ self.calls = []
+ self.endpoint_url = endpoint
+
+ def url_for(self, **kwargs):
+ self.calls.append(kwargs)
+ return self.endpoint_url
+
+ def Client(self, **kwargs):
+ self.calls.append(kwargs)
+ self.client = self._Client(
+ endpoint=self.endpoint, auth_token=self.token, **kwargs)
+ return self.client
+
+ class Unauthorized(Exception):
+ pass
+
+ class AuthorizationFailure(Exception):
+ pass
+
+ class EndpointNotFound(Exception):
+ pass
+
+
+class FakeStream(object):
+ def __init__(self, size):
+ self.bytes_read = 0
+ self.size = size
+
+ def read(self, size=-1):
+ if self.bytes_read == self.size:
+ return b''
+
+ if size == -1 or size + self.bytes_read > self.size:
+ remaining = self.size - self.bytes_read
+ self.bytes_read = self.size
+ return b'A' * remaining
+
+ self.bytes_read += size
+ return b'A' * size
+
+ def __len__(self):
+ return self.size