summaryrefslogtreecommitdiff
path: root/tests/unit/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/utils.py')
-rw-r--r--tests/unit/utils.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index c149abf..3cbb160 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -12,11 +12,16 @@
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import sys
from requests import RequestException
from time import sleep
import testtools
+import mock
+import six
from six.moves import reload_module
from swiftclient import client as c
+from swiftclient import shell as s
def fake_get_auth_keystone(os_options, exc=None, **kwargs):
@@ -213,3 +218,70 @@ class MockHttpTest(testtools.TestCase):
def tearDown(self):
super(MockHttpTest, self).tearDown()
reload_module(c)
+
+
+class CaptureStream(object):
+
+ def __init__(self, stream):
+ self.stream = stream
+ self._capture = six.StringIO()
+ self.streams = [self.stream, self._capture]
+
+ def write(self, *args, **kwargs):
+ for stream in self.streams:
+ stream.write(*args, **kwargs)
+
+ def writelines(self, *args, **kwargs):
+ for stream in self.streams:
+ stream.writelines(*args, **kwargs)
+
+ def getvalue(self):
+ return self._capture.getvalue()
+
+
+class CaptureOutput(object):
+
+ def __init__(self):
+ self._out = CaptureStream(sys.stdout)
+ self._err = CaptureStream(sys.stderr)
+
+ WrappedOutputManager = functools.partial(s.OutputManager,
+ print_stream=self._out,
+ error_stream=self._err)
+ self.patchers = [
+ mock.patch('swiftclient.shell.OutputManager',
+ WrappedOutputManager),
+ mock.patch('sys.stdout', self._out),
+ mock.patch('sys.stderr', self._err),
+ ]
+
+ def __enter__(self):
+ for patcher in self.patchers:
+ patcher.start()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ for patcher in self.patchers:
+ patcher.stop()
+
+ @property
+ def out(self):
+ return self._out.getvalue()
+
+ @property
+ def err(self):
+ return self._err.getvalue()
+
+ # act like the string captured by stdout
+
+ def __str__(self):
+ return self.out
+
+ def __len__(self):
+ return len(self.out)
+
+ def __eq__(self, other):
+ return self.out == other
+
+ def __getattr__(self, name):
+ return getattr(self.out, name)