diff options
Diffstat (limited to 'Lib/test/lock_tests.py')
-rw-r--r-- | Lib/test/lock_tests.py | 320 |
1 files changed, 303 insertions, 17 deletions
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py index 04f7422c8d..b6d818e4d4 100644 --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -4,7 +4,7 @@ Various tests for synchronization primitives. import sys import time -from _thread import start_new_thread, get_ident +from _thread import start_new_thread, get_ident, TIMEOUT_MAX import threading import unittest @@ -62,6 +62,14 @@ class BaseTestCase(unittest.TestCase): support.threading_cleanup(*self._threads) support.reap_children() + def assertTimeout(self, actual, expected): + # The waiting and/or time.time() can be imprecise, which + # is why comparing to the expected value would sometimes fail + # (especially under Windows). + self.assertGreaterEqual(actual, expected * 0.6) + # Test nothing insane happened + self.assertLess(actual, expected * 10.0) + class BaseLockTests(BaseTestCase): """ @@ -143,6 +151,32 @@ class BaseLockTests(BaseTestCase): Bunch(f, 15).wait_for_finished() self.assertEqual(n, len(threading.enumerate())) + def test_timeout(self): + lock = self.locktype() + # Can't set timeout if not blocking + self.assertRaises(ValueError, lock.acquire, 0, 1) + # Invalid timeout values + self.assertRaises(ValueError, lock.acquire, timeout=-100) + self.assertRaises(OverflowError, lock.acquire, timeout=1e100) + self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1) + # TIMEOUT_MAX is ok + lock.acquire(timeout=TIMEOUT_MAX) + lock.release() + t1 = time.time() + self.assertTrue(lock.acquire(timeout=5)) + t2 = time.time() + # Just a sanity test that it didn't actually wait for the timeout. + self.assertLess(t2 - t1, 5) + results = [] + def f(): + t1 = time.time() + results.append(lock.acquire(timeout=0.5)) + t2 = time.time() + results.append(t2 - t1) + Bunch(f, 1).wait_for_finished() + self.assertFalse(results[0]) + self.assertTimeout(results[1], 0.5) + class LockTests(BaseLockTests): """ @@ -284,14 +318,14 @@ class EventTests(BaseTestCase): def f(): results1.append(evt.wait(0.0)) t1 = time.time() - r = evt.wait(0.2) + r = evt.wait(0.5) t2 = time.time() results2.append((r, t2 - t1)) Bunch(f, N).wait_for_finished() self.assertEqual(results1, [False] * N) for r, dt in results2: self.assertFalse(r) - self.assertTrue(dt >= 0.2, dt) + self.assertTimeout(dt, 0.5) # The event is set results1 = [] results2 = [] @@ -341,13 +375,13 @@ class ConditionTests(BaseTestCase): phase_num = 0 def f(): cond.acquire() - cond.wait() + result = cond.wait() cond.release() - results1.append(phase_num) + results1.append((result, phase_num)) cond.acquire() - cond.wait() + result = cond.wait() cond.release() - results2.append(phase_num) + results2.append((result, phase_num)) b = Bunch(f, N) b.wait_for_started() _wait() @@ -360,7 +394,7 @@ class ConditionTests(BaseTestCase): cond.release() while len(results1) < 3: _wait() - self.assertEqual(results1, [1] * 3) + self.assertEqual(results1, [(True, 1)] * 3) self.assertEqual(results2, []) # Notify 5 threads: they might be in their first or second wait cond.acquire() @@ -370,8 +404,8 @@ class ConditionTests(BaseTestCase): cond.release() while len(results1) + len(results2) < 8: _wait() - self.assertEqual(results1, [1] * 3 + [2] * 2) - self.assertEqual(results2, [2] * 3) + self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2) + self.assertEqual(results2, [(True, 2)] * 3) # Notify all threads: they are all in their second wait cond.acquire() cond.notify_all() @@ -380,8 +414,8 @@ class ConditionTests(BaseTestCase): cond.release() while len(results2) < 5: _wait() - self.assertEqual(results1, [1] * 3 + [2] * 2) - self.assertEqual(results2, [2] * 3 + [3] * 2) + self.assertEqual(results1, [(True, 1)] * 3 + [(True,2)] * 2) + self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2) b.wait_for_finished() def test_notify(self): @@ -397,14 +431,60 @@ class ConditionTests(BaseTestCase): def f(): cond.acquire() t1 = time.time() - cond.wait(0.2) + result = cond.wait(0.5) t2 = time.time() cond.release() - results.append(t2 - t1) + results.append((t2 - t1, result)) Bunch(f, N).wait_for_finished() - self.assertEqual(len(results), 5) - for dt in results: - self.assertTrue(dt >= 0.2, dt) + self.assertEqual(len(results), N) + for dt, result in results: + self.assertTimeout(dt, 0.5) + # Note that conceptually (that"s the condition variable protocol) + # a wait() may succeed even if no one notifies us and before any + # timeout occurs. Spurious wakeups can occur. + # This makes it hard to verify the result value. + # In practice, this implementation has no spurious wakeups. + self.assertFalse(result) + + def test_waitfor(self): + cond = self.condtype() + state = 0 + def f(): + with cond: + result = cond.wait_for(lambda : state==4) + self.assertTrue(result) + self.assertEqual(state, 4) + b = Bunch(f, 1) + b.wait_for_started() + for i in range(5): + time.sleep(0.01) + with cond: + state += 1 + cond.notify() + b.wait_for_finished() + + def test_waitfor_timeout(self): + cond = self.condtype() + state = 0 + success = [] + def f(): + with cond: + dt = time.time() + result = cond.wait_for(lambda : state==4, timeout=0.1) + dt = time.time() - dt + self.assertFalse(result) + self.assertTimeout(dt, 0.1) + success.append(None) + b = Bunch(f, 1) + b.wait_for_started() + # Only increment 3 times, so state == 4 is never reached. + for i in range(3): + time.sleep(0.01) + with cond: + state += 1 + cond.notify() + b.wait_for_finished() + self.assertEqual(len(success), 1) class BaseSemaphoreTests(BaseTestCase): @@ -487,6 +567,19 @@ class BaseSemaphoreTests(BaseTestCase): # ordered. self.assertEqual(sorted(results), [False] * 7 + [True] * 3 ) + def test_acquire_timeout(self): + sem = self.semtype(2) + self.assertRaises(ValueError, sem.acquire, False, timeout=1.0) + self.assertTrue(sem.acquire(timeout=0.005)) + self.assertTrue(sem.acquire(timeout=0.005)) + self.assertFalse(sem.acquire(timeout=0.005)) + sem.release() + self.assertTrue(sem.acquire(timeout=0.005)) + t = time.time() + self.assertFalse(sem.acquire(timeout=0.5)) + dt = time.time() - t + self.assertTimeout(dt, 0.5) + def test_default_value(self): # The default initial value is 1. sem = self.semtype() @@ -544,3 +637,196 @@ class BoundedSemaphoreTests(BaseSemaphoreTests): sem.acquire() sem.release() self.assertRaises(ValueError, sem.release) + + +class BarrierTests(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 2.0 + + def setUp(self): + self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout) + def tearDown(self): + self.barrier.abort() + + def run_threads(self, f): + b = Bunch(f, self.N-1) + f() + b.wait_for_finished() + + def multipass(self, results, n): + m = self.barrier.parties + self.assertEqual(m, self.N) + for i in range(n): + results[0].append(True) + self.assertEqual(len(results[1]), i * m) + self.barrier.wait() + results[1].append(True) + self.assertEqual(len(results[0]), (i + 1) * m) + self.barrier.wait() + self.assertEqual(self.barrier.n_waiting, 0) + self.assertFalse(self.barrier.broken) + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [[],[]] + def f(): + self.multipass(results, passes) + self.run_threads(f) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + results = [] + def f(): + r = self.barrier.wait() + results.append(r) + + self.run_threads(f) + self.assertEqual(sum(results), sum(range(self.N))) + + def test_action(self): + """ + Test the 'action' callback + """ + results = [] + def action(): + results.append(True) + barrier = self.barriertype(self.N, action) + def f(): + barrier.wait() + self.assertEqual(len(results), 1) + + self.run_threads(f) + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = [] + results2 = [] + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = [] + results2 = [] + results3 = [] + def f(): + i = self.barrier.wait() + if i == self.N//2: + # Wait until the other threads are all in the barrier. + while self.barrier.n_waiting < self.N-1: + time.sleep(0.001) + self.barrier.reset() + else: + try: + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = [] + results2 = [] + results3 = [] + barrier2 = self.barriertype(self.N) + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == self.N//2: + self.barrier.reset() + barrier2.wait() + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + def test_timeout(self): + """ + Test wait(timeout) + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is late! + time.sleep(1.0) + # Default timeout is 2.0, so this is shorter. + self.assertRaises(threading.BrokenBarrierError, + self.barrier.wait, 0.5) + self.run_threads(f) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + #create a barrier with a low default timeout + barrier = self.barriertype(self.N, timeout=0.1) + def f(): + i = barrier.wait() + if i == self.N // 2: + # One thread is later than the default timeout of 0.1s. + time.sleep(1.0) + self.assertRaises(threading.BrokenBarrierError, barrier.wait) + self.run_threads(f) + + def test_single_thread(self): + b = self.barriertype(1) + b.wait() + b.wait() |