diff options
Diffstat (limited to 'tests/unit/utils.py')
-rw-r--r-- | tests/unit/utils.py | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/tests/unit/utils.py b/tests/unit/utils.py index c149abf..3cbb160 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -12,11 +12,16 @@ # implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools +import sys from requests import RequestException from time import sleep import testtools +import mock +import six from six.moves import reload_module from swiftclient import client as c +from swiftclient import shell as s def fake_get_auth_keystone(os_options, exc=None, **kwargs): @@ -213,3 +218,70 @@ class MockHttpTest(testtools.TestCase): def tearDown(self): super(MockHttpTest, self).tearDown() reload_module(c) + + +class CaptureStream(object): + + def __init__(self, stream): + self.stream = stream + self._capture = six.StringIO() + self.streams = [self.stream, self._capture] + + def write(self, *args, **kwargs): + for stream in self.streams: + stream.write(*args, **kwargs) + + def writelines(self, *args, **kwargs): + for stream in self.streams: + stream.writelines(*args, **kwargs) + + def getvalue(self): + return self._capture.getvalue() + + +class CaptureOutput(object): + + def __init__(self): + self._out = CaptureStream(sys.stdout) + self._err = CaptureStream(sys.stderr) + + WrappedOutputManager = functools.partial(s.OutputManager, + print_stream=self._out, + error_stream=self._err) + self.patchers = [ + mock.patch('swiftclient.shell.OutputManager', + WrappedOutputManager), + mock.patch('sys.stdout', self._out), + mock.patch('sys.stderr', self._err), + ] + + def __enter__(self): + for patcher in self.patchers: + patcher.start() + return self + + def __exit__(self, *args, **kwargs): + for patcher in self.patchers: + patcher.stop() + + @property + def out(self): + return self._out.getvalue() + + @property + def err(self): + return self._err.getvalue() + + # act like the string captured by stdout + + def __str__(self): + return self.out + + def __len__(self): + return len(self.out) + + def __eq__(self, other): + return self.out == other + + def __getattr__(self, name): + return getattr(self.out, name) |