diff options
author | Mark Florisson <markflorisson88@gmail.com> | 2011-09-19 22:56:09 +0100 |
---|---|---|
committer | Mark Florisson <markflorisson88@gmail.com> | 2011-09-19 23:00:26 +0100 |
commit | 86234dee59e2222373e317f38a23e46f8d488474 (patch) | |
tree | c2e83aed7d68b4f053ffc23113b737f63c6afb48 | |
parent | c54e21df64fa5df4ee872b28cec948465d4fefbf (diff) | |
download | cython-86234dee59e2222373e317f38a23e46f8d488474.tar.gz |
Fix num_thread for prange() without parallel() + more error checks0.15.1
-rw-r--r-- | Cython/Compiler/Nodes.py | 33 | ||||
-rw-r--r-- | tests/errors/e_invalid_num_threads.pyx | 26 | ||||
-rw-r--r-- | tests/run/sequential_parallel.pyx | 12 |
3 files changed, 58 insertions, 13 deletions
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 69b5b6ef4..638aeea2b 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -5915,7 +5915,11 @@ class ParallelStatNode(StatNode, ParallelNode): self.body.analyse_declarations(env) if self.kwargs: - self.kwargs = self.kwargs.compile_time_value(env) + try: + self.kwargs = self.kwargs.compile_time_value(env) + except Exception, e: + error(self.kwargs.pos, "Only compile-time values may be " + "supplied as keyword arguments") else: self.kwargs = {} @@ -5929,6 +5933,17 @@ class ParallelStatNode(StatNode, ParallelNode): self.body.analyse_expressions(env) self.analyse_sharing_attributes(env) + if self.num_threads is not None: + if self.parent and self.parent.num_threads is not None: + error(self.pos, + "num_threads already declared in outer section") + elif not isinstance(self.num_threads, (int, long)): + error(self.pos, + "Invalid value for num_threads argument, expected an int") + elif self.num_threads <= 0: + error(self.pos, + "argument to num_threads must be greater than 0") + def analyse_sharing_attributes(self, env): """ Analyse the privates for this block and set them in self.privates. @@ -6068,11 +6083,8 @@ class ParallelStatNode(StatNode, ParallelNode): Write self.num_threads if set as the num_threads OpenMP directive """ if self.num_threads is not None: - if isinstance(self.num_threads, (int, long)): - code.put(" num_threads(%d)" % (self.num_threads,)) - else: - error(self.pos, "Invalid value for num_threads argument, " - "expected an int") + code.put(" num_threads(%d)" % (self.num_threads,)) + def declare_closure_privates(self, code): """ @@ -6727,11 +6739,11 @@ class ParallelRangeNode(ParallelStatNode): if not self.is_parallel: code.put("#pragma omp for") self.privatization_insertion_point = code.insertion_point() - # reduction_codepoint = self.parent.privatization_insertion_point + reduction_codepoint = self.parent.privatization_insertion_point else: code.put("#pragma omp parallel") self.privatization_insertion_point = code.insertion_point() - # reduction_codepoint = self.privatization_insertion_point + reduction_codepoint = self.privatization_insertion_point code.putln("") code.putln("#endif /* _OPENMP */") @@ -6743,11 +6755,6 @@ class ParallelRangeNode(ParallelStatNode): code.putln("#ifdef _OPENMP") code.put("#pragma omp for") - # Nested parallelism is not supported, so we can put reductions on the - # for and not on the parallel (but would be valid, but gcc45 bugs on - # the former) - reduction_codepoint = code - for entry, (op, lastprivate) in self.privates.iteritems(): # Don't declare the index variable as a reduction if op and op in "+*-&^|" and entry != self.target.entry: diff --git a/tests/errors/e_invalid_num_threads.pyx b/tests/errors/e_invalid_num_threads.pyx new file mode 100644 index 000000000..641ea4f17 --- /dev/null +++ b/tests/errors/e_invalid_num_threads.pyx @@ -0,0 +1,26 @@ +# mode: error + +from cython.parallel cimport parallel, prange + +cdef int i + +# valid +with nogil, parallel(num_threads=None): + pass + +# invalid +with nogil, parallel(num_threads=0): + pass + +with nogil, parallel(num_threads=i): + pass + +with nogil, parallel(num_threads=2): + for i in prange(10, num_threads=2): + pass + +_ERRORS = u""" +e_invalid_num_threads.pyx:12:20: argument to num_threads must be greater than 0 +e_invalid_num_threads.pyx:15:20: Invalid value for num_threads argument, expected an int +e_invalid_num_threads.pyx:19:19: num_threads already declared in outer section +""" diff --git a/tests/run/sequential_parallel.pyx b/tests/run/sequential_parallel.pyx index 3ce50697c..68dce2a7c 100644 --- a/tests/run/sequential_parallel.pyx +++ b/tests/run/sequential_parallel.pyx @@ -720,3 +720,15 @@ def test_nogil_cdef_except_clause(): for i in prange(10, nogil=True): nogil_cdef_except_clause() nogil_cdef_except_star() + +def test_num_threads_compile(): + cdef int i + for i in prange(10, nogil=True, num_threads=2): + pass + + with nogil, cython.parallel.parallel(num_threads=2): + pass + + with nogil, cython.parallel.parallel(): + for i in prange(10, num_threads=2): + pass |