diff options
Diffstat (limited to 'test/unit/test_multithreading.py')
-rw-r--r-- | test/unit/test_multithreading.py | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/test/unit/test_multithreading.py b/test/unit/test_multithreading.py new file mode 100644 index 0000000..8944d48 --- /dev/null +++ b/test/unit/test_multithreading.py @@ -0,0 +1,240 @@ +# Copyright (c) 2010-2013 OpenStack, LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest +import threading +import six + +from concurrent.futures import as_completed +from six.moves.queue import Queue, Empty +from time import sleep + +from swiftclient import multithreading as mt +from .utils import CaptureStream + + +class ThreadTestCase(unittest.TestCase): + def setUp(self): + super(ThreadTestCase, self).setUp() + self.got_items = Queue() + self.got_args_kwargs = Queue() + self.starting_thread_count = threading.active_count() + + def _func(self, conn, item, *args, **kwargs): + self.got_items.put((conn, item)) + self.got_args_kwargs.put((args, kwargs)) + + if item == 'sleep': + sleep(1) + if item == 'go boom': + raise Exception('I went boom!') + + return 'success' + + def _create_conn(self): + return "This is a connection" + + def _create_conn_fail(self): + raise Exception("This is a failed connection") + + def assertQueueContains(self, queue, expected_contents): + got_contents = [] + try: + while True: + got_contents.append(queue.get(timeout=0.1)) + except Empty: + pass + if isinstance(expected_contents, set): + got_contents = set(got_contents) + self.assertEqual(expected_contents, got_contents) + + +class TestConnectionThreadPoolExecutor(ThreadTestCase): + def setUp(self): + super(TestConnectionThreadPoolExecutor, self).setUp() + self.input_queue = Queue() + self.stored_results = [] + + def tearDown(self): + super(TestConnectionThreadPoolExecutor, self).tearDown() + + def test_submit_good_connection(self): + ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn, 1) + with ctpe as pool: + # Try submitting a job that should succeed + f = pool.submit(self._func, "succeed") + f.result() + self.assertQueueContains( + self.got_items, + [("This is a connection", "succeed")] + ) + + # Now a job that fails + try: + f = pool.submit(self._func, "go boom") + f.result() + except Exception as e: + self.assertEqual('I went boom!', str(e)) + else: + self.fail('I never went boom!') + + # Has the connection been returned to the pool? + f = pool.submit(self._func, "succeed") + f.result() + self.assertQueueContains( + self.got_items, + [ + ("This is a connection", "go boom"), + ("This is a connection", "succeed") + ] + ) + + def test_submit_bad_connection(self): + ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn_fail, 1) + with ctpe as pool: + # Now a connection that fails + try: + f = pool.submit(self._func, "succeed") + f.result() + except Exception as e: + self.assertEqual('This is a failed connection', str(e)) + else: + self.fail('The connection did not fail') + + # Make sure we don't lock up on failed connections + try: + f = pool.submit(self._func, "go boom") + f.result() + except Exception as e: + self.assertEqual('This is a failed connection', str(e)) + else: + self.fail('The connection did not fail') + + def test_lazy_connections(self): + ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn, 10) + with ctpe as pool: + # Submit multiple jobs sequentially - should only use 1 conn + f = pool.submit(self._func, "succeed") + f.result() + f = pool.submit(self._func, "succeed") + f.result() + f = pool.submit(self._func, "succeed") + f.result() + + expected_connections = [(0, "This is a connection")] + expected_connections.extend([(x, None) for x in range(1, 10)]) + + self.assertQueueContains( + pool._connections, expected_connections + ) + + ctpe = mt.ConnectionThreadPoolExecutor(self._create_conn, 10) + with ctpe as pool: + fs = [] + f1 = pool.submit(self._func, "sleep") + f2 = pool.submit(self._func, "sleep") + f3 = pool.submit(self._func, "sleep") + fs.extend([f1, f2, f3]) + + expected_connections = [ + (0, "This is a connection"), + (1, "This is a connection"), + (2, "This is a connection") + ] + expected_connections.extend([(x, None) for x in range(3, 10)]) + + for f in as_completed(fs): + f.result() + + self.assertQueueContains( + pool._connections, expected_connections + ) + + +class TestOutputManager(unittest.TestCase): + + def test_instantiation(self): + output_manager = mt.OutputManager() + + self.assertEqual(sys.stdout, output_manager.print_stream) + self.assertEqual(sys.stderr, output_manager.error_stream) + + def test_printers(self): + out_stream = CaptureStream(sys.stdout) + err_stream = CaptureStream(sys.stderr) + starting_thread_count = threading.active_count() + + with mt.OutputManager( + print_stream=out_stream, + error_stream=err_stream) as thread_manager: + + # Sanity-checking these gives power to the previous test which + # looked at the default values of thread_manager.print/error_stream + self.assertEqual(out_stream, thread_manager.print_stream) + self.assertEqual(err_stream, thread_manager.error_stream) + + # No printing has happened yet, so no new threads + self.assertEqual(starting_thread_count, + threading.active_count()) + + thread_manager.print_msg('one-argument') + thread_manager.print_msg('one %s, %d fish', 'fish', 88) + thread_manager.error('I have %d problems, but a %s is not one', + 99, u'\u062A\u062A') + thread_manager.print_msg('some\n%s\nover the %r', 'where', + u'\u062A\u062A') + thread_manager.error('one-error-argument') + thread_manager.error('Sometimes\n%.1f%% just\ndoes not\nwork!', + 3.14159) + thread_manager.print_raw( + u'some raw bytes: \u062A\u062A'.encode('utf-8')) + + thread_manager.print_items([ + ('key', 'value'), + ('object', u'O\u0308bject'), + ]) + + thread_manager.print_raw(b'\xffugly\xffraw') + + # Now we have a thread for error printing and a thread for + # normal print messages + self.assertEqual(starting_thread_count + 2, + threading.active_count()) + + # The threads should have been cleaned up + self.assertEqual(starting_thread_count, threading.active_count()) + + if six.PY3: + over_the = "over the '\u062a\u062a'\n" + else: + over_the = "over the u'\\u062a\\u062a'\n" + # We write to the CaptureStream so no decoding is performed + self.assertEqual(''.join([ + 'one-argument\n', + 'one fish, 88 fish\n', + 'some\n', 'where\n', + over_the, + u'some raw bytes: \u062a\u062a', + ' key: value\n', + u' object: O\u0308bject\n' + ]).encode('utf8') + b'\xffugly\xffraw', out_stream.getvalue()) + + self.assertEqual(''.join([ + u'I have 99 problems, but a \u062A\u062A is not one\n', + 'one-error-argument\n', + 'Sometimes\n', '3.1% just\n', 'does not\n', 'work!\n' + ]), err_stream.getvalue().decode('utf8')) + + self.assertEqual(3, thread_manager.error_count) |