summaryrefslogtreecommitdiff
path: root/Lib/test/lock_tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/lock_tests.py')
-rw-r--r--Lib/test/lock_tests.py320
1 files changed, 303 insertions, 17 deletions
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py
index 04f7422c8d..b6d818e4d4 100644
--- a/Lib/test/lock_tests.py
+++ b/Lib/test/lock_tests.py
@@ -4,7 +4,7 @@ Various tests for synchronization primitives.
import sys
import time
-from _thread import start_new_thread, get_ident
+from _thread import start_new_thread, get_ident, TIMEOUT_MAX
import threading
import unittest
@@ -62,6 +62,14 @@ class BaseTestCase(unittest.TestCase):
support.threading_cleanup(*self._threads)
support.reap_children()
+ def assertTimeout(self, actual, expected):
+ # The waiting and/or time.time() can be imprecise, which
+ # is why comparing to the expected value would sometimes fail
+ # (especially under Windows).
+ self.assertGreaterEqual(actual, expected * 0.6)
+ # Test nothing insane happened
+ self.assertLess(actual, expected * 10.0)
+
class BaseLockTests(BaseTestCase):
"""
@@ -143,6 +151,32 @@ class BaseLockTests(BaseTestCase):
Bunch(f, 15).wait_for_finished()
self.assertEqual(n, len(threading.enumerate()))
+ def test_timeout(self):
+ lock = self.locktype()
+ # Can't set timeout if not blocking
+ self.assertRaises(ValueError, lock.acquire, 0, 1)
+ # Invalid timeout values
+ self.assertRaises(ValueError, lock.acquire, timeout=-100)
+ self.assertRaises(OverflowError, lock.acquire, timeout=1e100)
+ self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1)
+ # TIMEOUT_MAX is ok
+ lock.acquire(timeout=TIMEOUT_MAX)
+ lock.release()
+ t1 = time.time()
+ self.assertTrue(lock.acquire(timeout=5))
+ t2 = time.time()
+ # Just a sanity test that it didn't actually wait for the timeout.
+ self.assertLess(t2 - t1, 5)
+ results = []
+ def f():
+ t1 = time.time()
+ results.append(lock.acquire(timeout=0.5))
+ t2 = time.time()
+ results.append(t2 - t1)
+ Bunch(f, 1).wait_for_finished()
+ self.assertFalse(results[0])
+ self.assertTimeout(results[1], 0.5)
+
class LockTests(BaseLockTests):
"""
@@ -284,14 +318,14 @@ class EventTests(BaseTestCase):
def f():
results1.append(evt.wait(0.0))
t1 = time.time()
- r = evt.wait(0.2)
+ r = evt.wait(0.5)
t2 = time.time()
results2.append((r, t2 - t1))
Bunch(f, N).wait_for_finished()
self.assertEqual(results1, [False] * N)
for r, dt in results2:
self.assertFalse(r)
- self.assertTrue(dt >= 0.2, dt)
+ self.assertTimeout(dt, 0.5)
# The event is set
results1 = []
results2 = []
@@ -341,13 +375,13 @@ class ConditionTests(BaseTestCase):
phase_num = 0
def f():
cond.acquire()
- cond.wait()
+ result = cond.wait()
cond.release()
- results1.append(phase_num)
+ results1.append((result, phase_num))
cond.acquire()
- cond.wait()
+ result = cond.wait()
cond.release()
- results2.append(phase_num)
+ results2.append((result, phase_num))
b = Bunch(f, N)
b.wait_for_started()
_wait()
@@ -360,7 +394,7 @@ class ConditionTests(BaseTestCase):
cond.release()
while len(results1) < 3:
_wait()
- self.assertEqual(results1, [1] * 3)
+ self.assertEqual(results1, [(True, 1)] * 3)
self.assertEqual(results2, [])
# Notify 5 threads: they might be in their first or second wait
cond.acquire()
@@ -370,8 +404,8 @@ class ConditionTests(BaseTestCase):
cond.release()
while len(results1) + len(results2) < 8:
_wait()
- self.assertEqual(results1, [1] * 3 + [2] * 2)
- self.assertEqual(results2, [2] * 3)
+ self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
+ self.assertEqual(results2, [(True, 2)] * 3)
# Notify all threads: they are all in their second wait
cond.acquire()
cond.notify_all()
@@ -380,8 +414,8 @@ class ConditionTests(BaseTestCase):
cond.release()
while len(results2) < 5:
_wait()
- self.assertEqual(results1, [1] * 3 + [2] * 2)
- self.assertEqual(results2, [2] * 3 + [3] * 2)
+ self.assertEqual(results1, [(True, 1)] * 3 + [(True,2)] * 2)
+ self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2)
b.wait_for_finished()
def test_notify(self):
@@ -397,14 +431,60 @@ class ConditionTests(BaseTestCase):
def f():
cond.acquire()
t1 = time.time()
- cond.wait(0.2)
+ result = cond.wait(0.5)
t2 = time.time()
cond.release()
- results.append(t2 - t1)
+ results.append((t2 - t1, result))
Bunch(f, N).wait_for_finished()
- self.assertEqual(len(results), 5)
- for dt in results:
- self.assertTrue(dt >= 0.2, dt)
+ self.assertEqual(len(results), N)
+ for dt, result in results:
+ self.assertTimeout(dt, 0.5)
+ # Note that conceptually (that"s the condition variable protocol)
+ # a wait() may succeed even if no one notifies us and before any
+ # timeout occurs. Spurious wakeups can occur.
+ # This makes it hard to verify the result value.
+ # In practice, this implementation has no spurious wakeups.
+ self.assertFalse(result)
+
+ def test_waitfor(self):
+ cond = self.condtype()
+ state = 0
+ def f():
+ with cond:
+ result = cond.wait_for(lambda : state==4)
+ self.assertTrue(result)
+ self.assertEqual(state, 4)
+ b = Bunch(f, 1)
+ b.wait_for_started()
+ for i in range(5):
+ time.sleep(0.01)
+ with cond:
+ state += 1
+ cond.notify()
+ b.wait_for_finished()
+
+ def test_waitfor_timeout(self):
+ cond = self.condtype()
+ state = 0
+ success = []
+ def f():
+ with cond:
+ dt = time.time()
+ result = cond.wait_for(lambda : state==4, timeout=0.1)
+ dt = time.time() - dt
+ self.assertFalse(result)
+ self.assertTimeout(dt, 0.1)
+ success.append(None)
+ b = Bunch(f, 1)
+ b.wait_for_started()
+ # Only increment 3 times, so state == 4 is never reached.
+ for i in range(3):
+ time.sleep(0.01)
+ with cond:
+ state += 1
+ cond.notify()
+ b.wait_for_finished()
+ self.assertEqual(len(success), 1)
class BaseSemaphoreTests(BaseTestCase):
@@ -487,6 +567,19 @@ class BaseSemaphoreTests(BaseTestCase):
# ordered.
self.assertEqual(sorted(results), [False] * 7 + [True] * 3 )
+ def test_acquire_timeout(self):
+ sem = self.semtype(2)
+ self.assertRaises(ValueError, sem.acquire, False, timeout=1.0)
+ self.assertTrue(sem.acquire(timeout=0.005))
+ self.assertTrue(sem.acquire(timeout=0.005))
+ self.assertFalse(sem.acquire(timeout=0.005))
+ sem.release()
+ self.assertTrue(sem.acquire(timeout=0.005))
+ t = time.time()
+ self.assertFalse(sem.acquire(timeout=0.5))
+ dt = time.time() - t
+ self.assertTimeout(dt, 0.5)
+
def test_default_value(self):
# The default initial value is 1.
sem = self.semtype()
@@ -544,3 +637,196 @@ class BoundedSemaphoreTests(BaseSemaphoreTests):
sem.acquire()
sem.release()
self.assertRaises(ValueError, sem.release)
+
+
+class BarrierTests(BaseTestCase):
+ """
+ Tests for Barrier objects.
+ """
+ N = 5
+ defaultTimeout = 2.0
+
+ def setUp(self):
+ self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)
+ def tearDown(self):
+ self.barrier.abort()
+
+ def run_threads(self, f):
+ b = Bunch(f, self.N-1)
+ f()
+ b.wait_for_finished()
+
+ def multipass(self, results, n):
+ m = self.barrier.parties
+ self.assertEqual(m, self.N)
+ for i in range(n):
+ results[0].append(True)
+ self.assertEqual(len(results[1]), i * m)
+ self.barrier.wait()
+ results[1].append(True)
+ self.assertEqual(len(results[0]), (i + 1) * m)
+ self.barrier.wait()
+ self.assertEqual(self.barrier.n_waiting, 0)
+ self.assertFalse(self.barrier.broken)
+
+ def test_barrier(self, passes=1):
+ """
+ Test that a barrier is passed in lockstep
+ """
+ results = [[],[]]
+ def f():
+ self.multipass(results, passes)
+ self.run_threads(f)
+
+ def test_barrier_10(self):
+ """
+ Test that a barrier works for 10 consecutive runs
+ """
+ return self.test_barrier(10)
+
+ def test_wait_return(self):
+ """
+ test the return value from barrier.wait
+ """
+ results = []
+ def f():
+ r = self.barrier.wait()
+ results.append(r)
+
+ self.run_threads(f)
+ self.assertEqual(sum(results), sum(range(self.N)))
+
+ def test_action(self):
+ """
+ Test the 'action' callback
+ """
+ results = []
+ def action():
+ results.append(True)
+ barrier = self.barriertype(self.N, action)
+ def f():
+ barrier.wait()
+ self.assertEqual(len(results), 1)
+
+ self.run_threads(f)
+
+ def test_abort(self):
+ """
+ Test that an abort will put the barrier in a broken state
+ """
+ results1 = []
+ results2 = []
+ def f():
+ try:
+ i = self.barrier.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ self.barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ self.barrier.abort()
+ pass
+
+ self.run_threads(f)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(self.barrier.broken)
+
+ def test_reset(self):
+ """
+ Test that a 'reset' on a barrier frees the waiting threads
+ """
+ results1 = []
+ results2 = []
+ results3 = []
+ def f():
+ i = self.barrier.wait()
+ if i == self.N//2:
+ # Wait until the other threads are all in the barrier.
+ while self.barrier.n_waiting < self.N-1:
+ time.sleep(0.001)
+ self.barrier.reset()
+ else:
+ try:
+ self.barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ # Now, pass the barrier again
+ self.barrier.wait()
+ results3.append(True)
+
+ self.run_threads(f)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+
+ def test_abort_and_reset(self):
+ """
+ Test that a barrier can be reset after being broken.
+ """
+ results1 = []
+ results2 = []
+ results3 = []
+ barrier2 = self.barriertype(self.N)
+ def f():
+ try:
+ i = self.barrier.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ self.barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ self.barrier.abort()
+ pass
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ if barrier2.wait() == self.N//2:
+ self.barrier.reset()
+ barrier2.wait()
+ self.barrier.wait()
+ results3.append(True)
+
+ self.run_threads(f)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+ def test_timeout(self):
+ """
+ Test wait(timeout)
+ """
+ def f():
+ i = self.barrier.wait()
+ if i == self.N // 2:
+ # One thread is late!
+ time.sleep(1.0)
+ # Default timeout is 2.0, so this is shorter.
+ self.assertRaises(threading.BrokenBarrierError,
+ self.barrier.wait, 0.5)
+ self.run_threads(f)
+
+ def test_default_timeout(self):
+ """
+ Test the barrier's default timeout
+ """
+ #create a barrier with a low default timeout
+ barrier = self.barriertype(self.N, timeout=0.1)
+ def f():
+ i = barrier.wait()
+ if i == self.N // 2:
+ # One thread is later than the default timeout of 0.1s.
+ time.sleep(1.0)
+ self.assertRaises(threading.BrokenBarrierError, barrier.wait)
+ self.run_threads(f)
+
+ def test_single_thread(self):
+ b = self.barriertype(1)
+ b.wait()
+ b.wait()