summaryrefslogtreecommitdiff
path: root/docs/examples/tutorial/parallelization/median.py
blob: 535a2b1366ff05c05639e8179562fd2c9d7ef35f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# distutils: language = c++

from cython.parallel import parallel, prange
from cython.cimports.libc.stdlib import malloc, free
from cython.cimports.libcpp.algorithm import nth_element
import cython
from cython.operator import dereference

import numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
def median_along_axis0(x: cython.double[:,:]):
    """Return the median of every column of ``x`` (i.e. along axis 0).

    Columns are shared out between threads with ``prange``.  Each thread
    copies the column it is working on into its own scratch buffer and picks
    the middle element with ``nth_element``, so ``x`` itself is never
    reordered.

    Parameters
    ----------
    x : 2-D memoryview of doubles
        Input data, indexed as ``x[row, column]``.

    Returns
    -------
    numpy.ndarray
        1-D array with ``x.shape[1]`` entries.  For an even number of rows
        the element at sorted position ``x.shape[0] // 2`` (the upper of the
        two middle values) is returned; no averaging is done.  If ``x`` has
        no rows, every entry is NaN.
    """
    out: cython.double[::1] = np.empty(x.shape[1])
    i: cython.Py_ssize_t
    j: cython.Py_ssize_t
    scratch: cython.pointer(cython.double)
    median_it: cython.pointer(cython.double)
    if x.shape[0] == 0:
        # With no rows there is no middle element: the selection below would
        # read scratch[0] from a zero-byte allocation (bounds checks are off).
        # Report NaN for every column instead.
        return np.full(x.shape[1], np.nan)
    with cython.nogil, parallel():
        # allocate one scratch buffer per thread (this runs once in every
        # thread of the parallel block, not once per loop iteration)
        # NOTE(review): the malloc result is not checked for NULL; if the
        # allocation fails, the copy loop below writes through a null pointer.
        scratch = cython.cast(
            cython.pointer(cython.double),
            malloc(cython.sizeof(cython.double)*x.shape[0]))
        try:
            for i in prange(x.shape[1]):
                # copy column i into scratch space so it can be reordered
                for j in range(x.shape[0]):
                    scratch[j] = x[j, i]
                median_it = scratch + x.shape[0]//2
                # partial sort: afterwards *median_it holds the value that a
                # full sort would place at index x.shape[0]//2
                nth_element(scratch, median_it, scratch + x.shape[0])
                # for the sake of a simple example, don't handle even lengths...
                out[i] = dereference(median_it)
        finally:
            # release this thread's buffer even if the loop is left early
            free(scratch)
    return np.asarray(out)