1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
# tag: run
# tag: openmp
cimport cython.parallel
from cython.parallel import prange, threadid
cimport openmp
from libc.stdlib cimport malloc, free
# Allow parallel regions to nest (for example a prange started from inside
# another parallel region's body).
# NOTE(review): omp_set_nested is deprecated as of OpenMP 5.0 in favour of
# omp_set_max_active_levels -- consider migrating once supported everywhere.
openmp.omp_set_nested(1)
cdef int forward(int x) nogil:
    # Identity helper: returns its argument unchanged.  It exists so that
    # test_parallel can pass threadid() through a function call inside a
    # parallel block (see https://github.com/cython/cython/issues/3594).
    cdef int passthrough = x
    return passthrough
def test_parallel():
    """
    >>> test_parallel()
    """
    # Each thread of the parallel block stores its own thread id into the slot
    # indexed by that id; afterwards slot i must hold i.  buf is sized with
    # omp_get_max_threads(), which the test relies on as an upper bound for
    # the team size of the parallel block.
    cdef int maxthreads = openmp.omp_get_max_threads()
    cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)
    cdef int i

    if buf == NULL:
        raise MemoryError

    try:
        with nogil, cython.parallel.parallel():
            buf[threadid()] = threadid()
            # Recognise threadid() also when it's used in a function argument.
            # See https://github.com/cython/cython/issues/3594
            buf[forward(cython.parallel.threadid())] = forward(threadid())

        for i in range(maxthreads):
            assert buf[i] == i
    finally:
        # Release buf on every exit path, including a failed assertion, so a
        # failing run does not leak the allocation.
        free(buf)
cdef int get_num_threads() noexcept with gil:
    # Supplies the num_threads argument for test_num_threads.  It announces
    # every call on stdout, so the doctest output shows each time the
    # num_threads expression is evaluated.  It always requests 3 threads.
    cdef int requested = 3
    print("get_num_threads called")
    return requested
def test_num_threads():
    """
    >>> test_num_threads()
    1
    get_num_threads called
    3
    get_num_threads called
    3
    """
    # Remember the caller's dynamic-adjustment setting, then switch it off so
    # the team size is not adjusted away from the requested num_threads.
    cdef int dyn = openmp.omp_get_dynamic()
    cdef int num_threads
    # Writes inside the parallel blocks go through this pointer.  A direct
    # assignment to num_threads there would make the variable thread-private
    # (per the cython.parallel rules), and its value would not be visible
    # after the block.
    cdef int *p = &num_threads
    openmp.omp_set_dynamic(0)

    # Literal num_threads: a team of exactly one thread.
    with nogil, cython.parallel.parallel(num_threads=1):
        p[0] = openmp.omp_get_num_threads()
    print num_threads

    # num_threads given as a function call; the doctest shows it is called
    # once for this block.
    with nogil, cython.parallel.parallel(num_threads=get_num_threads()):
        p[0] = openmp.omp_get_num_threads()
    print num_threads

    cdef int i
    # Sentinel: if the prange body never ran, 0xbad would be returned instead.
    num_threads = 0xbad
    # Same check for prange; break leaves the loop after its single iteration.
    for i in prange(1, nogil=True, num_threads=get_num_threads()):
        p[0] = openmp.omp_get_num_threads()
        break

    # Restore the caller's dynamic-adjustment setting.
    openmp.omp_set_dynamic(dyn)
    return num_threads
# test_parallel_catch is disabled: it is wrapped in a module-level string
# literal, so it is neither compiled nor collected as a doctest.
# NOTE(review): the reason it was disabled is not recorded here -- TODO
# confirm whether catching an exception raised from a nested prange (inside
# another prange's `with gil` block) now works and the test can be re-enabled.
'''
def test_parallel_catch():
"""
>>> test_parallel_catch()
True
"""
cdef int i, j, num_threads
exceptions = []
for i in prange(100, nogil=True, num_threads=4):
num_threads = openmp.omp_get_num_threads()
with gil:
try:
for j in prange(100, nogil=True):
if i + j > 60:
with gil:
raise Exception("try and catch me if you can!")
except Exception, e:
exceptions.append(e)
break
print len(exceptions) == num_threads
assert len(exceptions) == num_threads, (len(exceptions), num_threads)
'''
# NOTE(review): presumably read by the included sequential_parallel.pyx so its
# shared tests know they are compiled with OpenMP enabled -- confirm there.
OPENMP_PARALLEL = True
# Textually include the (presumably shared) tests from sequential_parallel.pyx
# into this module.
include "sequential_parallel.pyx"
|