summaryrefslogtreecommitdiff
path: root/docs/examples/tutorial/parallelization/median.py
diff options
context:
space:
mode:
Diffstat (limited to 'docs/examples/tutorial/parallelization/median.py')
-rw-r--r--docs/examples/tutorial/parallelization/median.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/docs/examples/tutorial/parallelization/median.py b/docs/examples/tutorial/parallelization/median.py
new file mode 100644
index 000000000..535a2b136
--- /dev/null
+++ b/docs/examples/tutorial/parallelization/median.py
@@ -0,0 +1,35 @@
+# distutils: language = c++
+
+from cython.parallel import parallel, prange
+from cython.cimports.libc.stdlib import malloc, free
+from cython.cimports.libcpp.algorithm import nth_element
+import cython
+from cython.operator import dereference
+
+import numpy as np
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def median_along_axis0(x: cython.double[:,:]):
+ out: cython.double[::1] = np.empty(x.shape[1])
+ i: cython.Py_ssize_t
+ j: cython.Py_ssize_t
+ scratch: cython.pointer(cython.double)
+ median_it: cython.pointer(cython.double)
+ with cython.nogil, parallel():
+ # allocate scratch space per loop
+ scratch = cython.cast(
+ cython.pointer(cython.double),
+ malloc(cython.sizeof(cython.double)*x.shape[0]))
+ try:
+ for i in prange(x.shape[1]):
+ # copy row into scratch space
+ for j in range(x.shape[0]):
+ scratch[j] = x[j, i]
+ median_it = scratch + x.shape[0]//2
+ nth_element(scratch, median_it, scratch + x.shape[0])
+ # for the sake of a simple example, don't handle even lengths...
+ out[i] = dereference(median_it)
+ finally:
+ free(scratch)
+ return np.asarray(out)