blob: 535a2b1366ff05c05639e8179562fd2c9d7ef35f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
# distutils: language = c++
import cython
from cython.cimports.libc.stdlib import abort, free, malloc
from cython.cimports.libcpp.algorithm import nth_element
from cython.operator import dereference
from cython.parallel import parallel, prange

import numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
def median_along_axis0(x: cython.double[:,:]):
    """Return the median of each column of the 2-D array *x*.

    Columns are distributed over threads with ``prange`` (parallel only when
    the extension is built with OpenMP). Each thread allocates one scratch
    buffer of ``x.shape[0]`` doubles and reuses it for every column it
    handles. ``std::nth_element`` then partially orders that buffer, which
    avoids a full sort.

    Results follow ``np.median(x, axis=0)``:

    * odd row count: the middle order statistic;
    * even row count: the mean of the two middle order statistics;
    * a column containing NaN yields NaN;
    * zero rows yields an all-NaN result (NumPy additionally warns).

    Parameters
    ----------
    x : 2-D memoryview of double
        Input data. Any strides are accepted. Because the memoryview is not
        ``const``, Cython requires a writable buffer.

    Returns
    -------
    numpy.ndarray
        1-D float64 array of length ``x.shape[1]``.
    """
    out: cython.double[::1] = np.empty(x.shape[1])
    # Read-only inside the parallel block, so it is shared by all threads.
    n: cython.Py_ssize_t = x.shape[0]
    i: cython.Py_ssize_t
    j: cython.Py_ssize_t
    k: cython.Py_ssize_t
    nan_at: cython.Py_ssize_t
    lo: cython.double
    hi: cython.double
    scratch: cython.pointer(cython.double)
    median_it: cython.pointer(cython.double)

    if n == 0:
        # No samples means no median. Without this guard, median_it would be
        # dereferenced on a zero-length buffer. Return NaN, as np.median does.
        return np.full(x.shape[1], np.nan)

    with cython.nogil, parallel():
        # Thread-local scratch space, reused across this thread's columns.
        scratch = cython.cast(
            cython.pointer(cython.double),
            malloc(cython.sizeof(cython.double) * n))
        if scratch == cython.NULL:
            # Follows the Cython parallelism docs' pattern for a failed
            # thread-local malloc; writing through NULL would be worse.
            abort()
        try:
            for i in prange(x.shape[1]):
                # Copy column i into scratch, remembering any NaN position.
                nan_at = -1
                for j in range(n):
                    scratch[j] = x[j, i]
                    if scratch[j] != scratch[j]:  # only NaN is unequal to itself
                        nan_at = j
                if nan_at >= 0:
                    # NaN propagates, as in np.median. It also must not reach
                    # nth_element, whose comparisons would then violate the
                    # required strict weak ordering.
                    out[i] = scratch[nan_at]
                else:
                    median_it = scratch + n // 2
                    nth_element(scratch, median_it, scratch + n)
                    hi = dereference(median_it)
                    if n % 2 == 1:
                        out[i] = hi
                    else:
                        # nth_element leaves every element before median_it
                        # <= *median_it, so the lower middle value is the
                        # maximum of scratch[0 : n // 2].
                        lo = scratch[0]
                        for k in range(1, n // 2):
                            if scratch[k] > lo:
                                lo = scratch[k]
                        out[i] = 0.5 * (lo + hi)
        finally:
            free(scratch)
    return np.asarray(out)
|