summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Florisson <markflorisson88@gmail.com>2011-09-19 22:56:09 +0100
committerMark Florisson <markflorisson88@gmail.com>2011-09-19 23:00:26 +0100
commit86234dee59e2222373e317f38a23e46f8d488474 (patch)
treec2e83aed7d68b4f053ffc23113b737f63c6afb48
parentc54e21df64fa5df4ee872b28cec948465d4fefbf (diff)
downloadcython-86234dee59e2222373e317f38a23e46f8d488474.tar.gz
Fix num_thread for prange() without parallel() + more error checks0.15.1
-rw-r--r--Cython/Compiler/Nodes.py33
-rw-r--r--tests/errors/e_invalid_num_threads.pyx26
-rw-r--r--tests/run/sequential_parallel.pyx12
3 files changed, 58 insertions, 13 deletions
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index 69b5b6ef4..638aeea2b 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -5915,7 +5915,11 @@ class ParallelStatNode(StatNode, ParallelNode):
self.body.analyse_declarations(env)
if self.kwargs:
- self.kwargs = self.kwargs.compile_time_value(env)
+ try:
+ self.kwargs = self.kwargs.compile_time_value(env)
+ except Exception, e:
+ error(self.kwargs.pos, "Only compile-time values may be "
+ "supplied as keyword arguments")
else:
self.kwargs = {}
@@ -5929,6 +5933,17 @@ class ParallelStatNode(StatNode, ParallelNode):
self.body.analyse_expressions(env)
self.analyse_sharing_attributes(env)
+ if self.num_threads is not None:
+ if self.parent and self.parent.num_threads is not None:
+ error(self.pos,
+ "num_threads already declared in outer section")
+ elif not isinstance(self.num_threads, (int, long)):
+ error(self.pos,
+ "Invalid value for num_threads argument, expected an int")
+ elif self.num_threads <= 0:
+ error(self.pos,
+ "argument to num_threads must be greater than 0")
+
def analyse_sharing_attributes(self, env):
"""
Analyse the privates for this block and set them in self.privates.
@@ -6068,11 +6083,8 @@ class ParallelStatNode(StatNode, ParallelNode):
Write self.num_threads if set as the num_threads OpenMP directive
"""
if self.num_threads is not None:
- if isinstance(self.num_threads, (int, long)):
- code.put(" num_threads(%d)" % (self.num_threads,))
- else:
- error(self.pos, "Invalid value for num_threads argument, "
- "expected an int")
+ code.put(" num_threads(%d)" % (self.num_threads,))
+
def declare_closure_privates(self, code):
"""
@@ -6727,11 +6739,11 @@ class ParallelRangeNode(ParallelStatNode):
if not self.is_parallel:
code.put("#pragma omp for")
self.privatization_insertion_point = code.insertion_point()
- # reduction_codepoint = self.parent.privatization_insertion_point
+ reduction_codepoint = self.parent.privatization_insertion_point
else:
code.put("#pragma omp parallel")
self.privatization_insertion_point = code.insertion_point()
- # reduction_codepoint = self.privatization_insertion_point
+ reduction_codepoint = self.privatization_insertion_point
code.putln("")
code.putln("#endif /* _OPENMP */")
@@ -6743,11 +6755,6 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("#ifdef _OPENMP")
code.put("#pragma omp for")
- # Nested parallelism is not supported, so we can put reductions on the
- # for and not on the parallel (but would be valid, but gcc45 bugs on
- # the former)
- reduction_codepoint = code
-
for entry, (op, lastprivate) in self.privates.iteritems():
# Don't declare the index variable as a reduction
if op and op in "+*-&^|" and entry != self.target.entry:
diff --git a/tests/errors/e_invalid_num_threads.pyx b/tests/errors/e_invalid_num_threads.pyx
new file mode 100644
index 000000000..641ea4f17
--- /dev/null
+++ b/tests/errors/e_invalid_num_threads.pyx
@@ -0,0 +1,26 @@
+# mode: error
+
+from cython.parallel cimport parallel, prange
+
+cdef int i
+
+# valid
+with nogil, parallel(num_threads=None):
+ pass
+
+# invalid
+with nogil, parallel(num_threads=0):
+ pass
+
+with nogil, parallel(num_threads=i):
+ pass
+
+with nogil, parallel(num_threads=2):
+ for i in prange(10, num_threads=2):
+ pass
+
+_ERRORS = u"""
+e_invalid_num_threads.pyx:12:20: argument to num_threads must be greater than 0
+e_invalid_num_threads.pyx:15:20: Invalid value for num_threads argument, expected an int
+e_invalid_num_threads.pyx:19:19: num_threads already declared in outer section
+"""
diff --git a/tests/run/sequential_parallel.pyx b/tests/run/sequential_parallel.pyx
index 3ce50697c..68dce2a7c 100644
--- a/tests/run/sequential_parallel.pyx
+++ b/tests/run/sequential_parallel.pyx
@@ -720,3 +720,15 @@ def test_nogil_cdef_except_clause():
for i in prange(10, nogil=True):
nogil_cdef_except_clause()
nogil_cdef_except_star()
+
+def test_num_threads_compile():
+ cdef int i
+ for i in prange(10, nogil=True, num_threads=2):
+ pass
+
+ with nogil, cython.parallel.parallel(num_threads=2):
+ pass
+
+ with nogil, cython.parallel.parallel():
+ for i in prange(10, num_threads=2):
+ pass