diff options
author | gfyoung <gfyoung@mit.edu> | 2016-01-12 02:14:38 +0000 |
---|---|---|
committer | gfyoung <gfyoung@mit.edu> | 2016-01-20 11:47:36 +0000 |
commit | 0b150b8cfafa4d8c75e47e6e9c8b3b23d7c0a2b6 (patch) | |
tree | 9daeac2517930231defc7651b2055d24b5683fd2 /numpy/random | |
parent | a38942a87f9b76251e0950ba330f96a8d76c6d36 (diff) | |
download | numpy-0b150b8cfafa4d8c75e47e6e9c8b3b23d7c0a2b6.tar.gz |
MAINT: Simplified mtrand.pyx helpers
Refactored methods that broadcast arguments
together by finding additional common ground
between code in the if...else branches that
involved a size parameter being passed in.
Diffstat (limited to 'numpy/random')
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 178 |
1 files changed, 74 insertions, 104 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index c8738cf6f..2f315c5d3 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -250,28 +250,23 @@ cdef object cont2_array(rk_state *state, rk_cont2 func, object size, cdef broadcast multi if size is None: - multi = <broadcast> PyArray_MultiIterNew(2, <void *>oa, <void *>ob) - array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_DOUBLE) - array_data = <double *>PyArray_DATA(array) - with lock, nogil: - for i from 0 <= i < multi.size: - oa_data = <double *>PyArray_MultiIter_DATA(multi, 0) - ob_data = <double *>PyArray_MultiIter_DATA(multi, 1) - array_data[i] = func(state, oa_data[0], ob_data[0]) - PyArray_MultiIter_NEXT(multi) + multi = <broadcast>np.broadcast(oa, ob) + array = <ndarray>np.empty(multi.shape, dtype=np.float64) else: - array = <ndarray>np.empty(size, np.float64) - array_data = <double *>PyArray_DATA(array) - multi = <broadcast>PyArray_MultiIterNew(3, <void*>array, <void *>oa, <void *>ob) - if (multi.size != PyArray_SIZE(array)): + array = <ndarray>np.empty(size, dtype=np.float64) + multi = <broadcast>np.broadcast(oa, ob, array) + if multi.shape != array.shape: raise ValueError("size is not compatible with inputs") - with lock, nogil: - for i from 0 <= i < multi.size: - oa_data = <double *>PyArray_MultiIter_DATA(multi, 1) - ob_data = <double *>PyArray_MultiIter_DATA(multi, 2) - array_data[i] = func(state, oa_data[0], ob_data[0]) - PyArray_MultiIter_NEXTi(multi, 1) - PyArray_MultiIter_NEXTi(multi, 2) + + array_data = <double *>PyArray_DATA(array) + + with lock, nogil: + for i in range(multi.size): + oa_data = <double *>PyArray_MultiIter_DATA(multi, 0) + ob_data = <double *>PyArray_MultiIter_DATA(multi, 1) + array_data[i] = func(state, oa_data[0], ob_data[0]) + PyArray_MultiIter_NEXT(multi) + return array cdef object cont3_array_sc(rk_state *state, rk_cont3 func, object size, double a, @@ -305,30 +300,24 @@ cdef object cont3_array(rk_state *state, rk_cont3 func, object size, cdef broadcast multi if size is None: - multi = <broadcast> PyArray_MultiIterNew(3, <void *>oa, <void *>ob, <void *>oc) - array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_DOUBLE) - array_data = <double *>PyArray_DATA(array) - with lock, nogil: - for i from 0 <= i < multi.size: - oa_data = <double *>PyArray_MultiIter_DATA(multi, 0) - ob_data = <double *>PyArray_MultiIter_DATA(multi, 1) - oc_data = <double *>PyArray_MultiIter_DATA(multi, 2) - array_data[i] = func(state, oa_data[0], ob_data[0], oc_data[0]) - PyArray_MultiIter_NEXT(multi) + multi = <broadcast>np.broadcast(oa, ob, oc) + array = <ndarray>np.empty(multi.shape, dtype=np.float64) else: - array = <ndarray>np.empty(size, np.float64) - array_data = <double *>PyArray_DATA(array) - multi = <broadcast>PyArray_MultiIterNew(4, <void*>array, <void *>oa, - <void *>ob, <void *>oc) - if (multi.size != PyArray_SIZE(array)): + array = <ndarray>np.empty(size, dtype=np.float64) + multi = <broadcast>np.broadcast(oa, ob, oc, array) + if multi.shape != array.shape: raise ValueError("size is not compatible with inputs") - with lock, nogil: - for i from 0 <= i < multi.size: - oa_data = <double *>PyArray_MultiIter_DATA(multi, 1) - ob_data = <double *>PyArray_MultiIter_DATA(multi, 2) - oc_data = <double *>PyArray_MultiIter_DATA(multi, 3) - array_data[i] = func(state, oa_data[0], ob_data[0], oc_data[0]) - PyArray_MultiIter_NEXT(multi) + + array_data = <double *>PyArray_DATA(array) + + with lock, nogil: + for i in range(multi.size): + oa_data = <double *>PyArray_MultiIter_DATA(multi, 0) + ob_data = <double *>PyArray_MultiIter_DATA(multi, 1) + oc_data = <double *>PyArray_MultiIter_DATA(multi, 2) + array_data[i] = func(state, oa_data[0], ob_data[0], oc_data[0]) + PyArray_MultiIter_NEXT(multi) + return array cdef object disc0_array(rk_state *state, rk_disc0 func, object size, object lock): @@ -376,28 +365,22 @@ cdef object discnp_array(rk_state *state, rk_discnp func, object size, cdef broadcast multi if size is None: - multi = <broadcast> PyArray_MultiIterNew(2, <void *>on, <void *>op) - array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_LONG) - array_data = <long *>PyArray_DATA(array) - with lock, nogil: - for i from 0 <= i < multi.size: - on_data = <long *>PyArray_MultiIter_DATA(multi, 0) - op_data = <double *>PyArray_MultiIter_DATA(multi, 1) - array_data[i] = func(state, on_data[0], op_data[0]) - PyArray_MultiIter_NEXT(multi) + multi = <broadcast>np.broadcast(on, op) + array = <ndarray>np.empty(multi.shape, dtype=int) else: - array = <ndarray>np.empty(size, int) - array_data = <long *>PyArray_DATA(array) - multi = <broadcast>PyArray_MultiIterNew(3, <void*>array, <void *>on, <void *>op) - if (multi.size != PyArray_SIZE(array)): + array = <ndarray>np.empty(size, dtype=int) + multi = <broadcast>np.broadcast(on, op, array) + if multi.shape != array.shape: raise ValueError("size is not compatible with inputs") - with lock, nogil: - for i from 0 <= i < multi.size: - on_data = <long *>PyArray_MultiIter_DATA(multi, 1) - op_data = <double *>PyArray_MultiIter_DATA(multi, 2) - array_data[i] = func(state, on_data[0], op_data[0]) - PyArray_MultiIter_NEXTi(multi, 1) - PyArray_MultiIter_NEXTi(multi, 2) + + array_data = <long *>PyArray_DATA(array) + + with lock, nogil: + for i in range(multi.size): + on_data = <long *>PyArray_MultiIter_DATA(multi, 0) + op_data = <double *>PyArray_MultiIter_DATA(multi, 1) + array_data[i] = func(state, on_data[0], op_data[0]) + PyArray_MultiIter_NEXT(multi) return array @@ -429,28 +412,22 @@ cdef object discdd_array(rk_state *state, rk_discdd func, object size, cdef broadcast multi if size is None: - multi = <broadcast> PyArray_MultiIterNew(2, <void *>on, <void *>op) - array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_LONG) - array_data = <long *>PyArray_DATA(array) - with lock, nogil: - for i from 0 <= i < multi.size: - on_data = <double *>PyArray_MultiIter_DATA(multi, 0) - op_data = <double *>PyArray_MultiIter_DATA(multi, 1) - array_data[i] = func(state, on_data[0], op_data[0]) - PyArray_MultiIter_NEXT(multi) + multi = <broadcast>np.broadcast(on, op) + array = <ndarray>np.empty(multi.shape, dtype=int) else: - array = <ndarray>np.empty(size, int) - array_data = <long *>PyArray_DATA(array) - multi = <broadcast>PyArray_MultiIterNew(3, <void*>array, <void *>on, <void *>op) - if (multi.size != PyArray_SIZE(array)): + array = <ndarray>np.empty(size, dtype=int) + multi = <broadcast>np.broadcast(on, op, array) + if multi.shape != array.shape: raise ValueError("size is not compatible with inputs") - with lock, nogil: - for i from 0 <= i < multi.size: - on_data = <double *>PyArray_MultiIter_DATA(multi, 1) - op_data = <double *>PyArray_MultiIter_DATA(multi, 2) - array_data[i] = func(state, on_data[0], op_data[0]) - PyArray_MultiIter_NEXTi(multi, 1) - PyArray_MultiIter_NEXTi(multi, 2) + + array_data = <long *>PyArray_DATA(array) + + with lock, nogil: + for i in range(multi.size): + on_data = <double *>PyArray_MultiIter_DATA(multi, 0) + op_data = <double *>PyArray_MultiIter_DATA(multi, 1) + array_data[i] = func(state, on_data[0], op_data[0]) + PyArray_MultiIter_NEXT(multi) return array @@ -483,30 +460,23 @@ cdef object discnmN_array(rk_state *state, rk_discnmN func, object size, cdef broadcast multi if size is None: - multi = <broadcast> PyArray_MultiIterNew(3, <void *>on, <void *>om, <void *>oN) - array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_LONG) - array_data = <long *>PyArray_DATA(array) - with lock, nogil: - for i from 0 <= i < multi.size: - on_data = <long *>PyArray_MultiIter_DATA(multi, 0) - om_data = <long *>PyArray_MultiIter_DATA(multi, 1) - oN_data = <long *>PyArray_MultiIter_DATA(multi, 2) - array_data[i] = func(state, on_data[0], om_data[0], oN_data[0]) - PyArray_MultiIter_NEXT(multi) + multi = <broadcast>np.broadcast(on, om, oN) + array = <ndarray>np.empty(multi.shape, dtype=int) else: - array = <ndarray>np.empty(size, int) - array_data = <long *>PyArray_DATA(array) - multi = <broadcast>PyArray_MultiIterNew(4, <void*>array, <void *>on, <void *>om, - <void *>oN) - if (multi.size != PyArray_SIZE(array)): + array = <ndarray>np.empty(size, dtype=int) + multi = <broadcast>np.broadcast(on, om, oN, array) + if multi.shape != array.shape: raise ValueError("size is not compatible with inputs") - with lock, nogil: - for i from 0 <= i < multi.size: - on_data = <long *>PyArray_MultiIter_DATA(multi, 1) - om_data = <long *>PyArray_MultiIter_DATA(multi, 2) - oN_data = <long *>PyArray_MultiIter_DATA(multi, 3) - array_data[i] = func(state, on_data[0], om_data[0], oN_data[0]) - PyArray_MultiIter_NEXT(multi) + + array_data = <long *>PyArray_DATA(array) + + with lock, nogil: + for i in range(multi.size): + on_data = <long *>PyArray_MultiIter_DATA(multi, 0) + om_data = <long *>PyArray_MultiIter_DATA(multi, 1) + oN_data = <long *>PyArray_MultiIter_DATA(multi, 2) + array_data[i] = func(state, on_data[0], om_data[0], oN_data[0]) + PyArray_MultiIter_NEXT(multi) return array |