summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorgfyoung <gfyoung@mit.edu>2016-01-12 02:14:38 +0000
committergfyoung <gfyoung@mit.edu>2016-01-20 11:47:36 +0000
commit0b150b8cfafa4d8c75e47e6e9c8b3b23d7c0a2b6 (patch)
tree9daeac2517930231defc7651b2055d24b5683fd2 /numpy/random
parenta38942a87f9b76251e0950ba330f96a8d76c6d36 (diff)
downloadnumpy-0b150b8cfafa4d8c75e47e6e9c8b3b23d7c0a2b6.tar.gz
MAINT: Simplified mtrand.pyx helpers
Refactored methods that broadcast arguments together by finding additional common ground between code in the if...else branches that involved a size parameter being passed in.
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/mtrand/mtrand.pyx178
1 files changed, 74 insertions, 104 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index c8738cf6f..2f315c5d3 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -250,28 +250,23 @@ cdef object cont2_array(rk_state *state, rk_cont2 func, object size,
cdef broadcast multi
if size is None:
- multi = <broadcast> PyArray_MultiIterNew(2, <void *>oa, <void *>ob)
- array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_DOUBLE)
- array_data = <double *>PyArray_DATA(array)
- with lock, nogil:
- for i from 0 <= i < multi.size:
- oa_data = <double *>PyArray_MultiIter_DATA(multi, 0)
- ob_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- array_data[i] = func(state, oa_data[0], ob_data[0])
- PyArray_MultiIter_NEXT(multi)
+ multi = <broadcast>np.broadcast(oa, ob)
+ array = <ndarray>np.empty(multi.shape, dtype=np.float64)
else:
- array = <ndarray>np.empty(size, np.float64)
- array_data = <double *>PyArray_DATA(array)
- multi = <broadcast>PyArray_MultiIterNew(3, <void*>array, <void *>oa, <void *>ob)
- if (multi.size != PyArray_SIZE(array)):
+ array = <ndarray>np.empty(size, dtype=np.float64)
+ multi = <broadcast>np.broadcast(oa, ob, array)
+ if multi.shape != array.shape:
raise ValueError("size is not compatible with inputs")
- with lock, nogil:
- for i from 0 <= i < multi.size:
- oa_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- ob_data = <double *>PyArray_MultiIter_DATA(multi, 2)
- array_data[i] = func(state, oa_data[0], ob_data[0])
- PyArray_MultiIter_NEXTi(multi, 1)
- PyArray_MultiIter_NEXTi(multi, 2)
+
+ array_data = <double *>PyArray_DATA(array)
+
+ with lock, nogil:
+ for i in range(multi.size):
+ oa_data = <double *>PyArray_MultiIter_DATA(multi, 0)
+ ob_data = <double *>PyArray_MultiIter_DATA(multi, 1)
+ array_data[i] = func(state, oa_data[0], ob_data[0])
+ PyArray_MultiIter_NEXT(multi)
+
return array
cdef object cont3_array_sc(rk_state *state, rk_cont3 func, object size, double a,
@@ -305,30 +300,24 @@ cdef object cont3_array(rk_state *state, rk_cont3 func, object size,
cdef broadcast multi
if size is None:
- multi = <broadcast> PyArray_MultiIterNew(3, <void *>oa, <void *>ob, <void *>oc)
- array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_DOUBLE)
- array_data = <double *>PyArray_DATA(array)
- with lock, nogil:
- for i from 0 <= i < multi.size:
- oa_data = <double *>PyArray_MultiIter_DATA(multi, 0)
- ob_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- oc_data = <double *>PyArray_MultiIter_DATA(multi, 2)
- array_data[i] = func(state, oa_data[0], ob_data[0], oc_data[0])
- PyArray_MultiIter_NEXT(multi)
+ multi = <broadcast>np.broadcast(oa, ob, oc)
+ array = <ndarray>np.empty(multi.shape, dtype=np.float64)
else:
- array = <ndarray>np.empty(size, np.float64)
- array_data = <double *>PyArray_DATA(array)
- multi = <broadcast>PyArray_MultiIterNew(4, <void*>array, <void *>oa,
- <void *>ob, <void *>oc)
- if (multi.size != PyArray_SIZE(array)):
+ array = <ndarray>np.empty(size, dtype=np.float64)
+ multi = <broadcast>np.broadcast(oa, ob, oc, array)
+ if multi.shape != array.shape:
raise ValueError("size is not compatible with inputs")
- with lock, nogil:
- for i from 0 <= i < multi.size:
- oa_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- ob_data = <double *>PyArray_MultiIter_DATA(multi, 2)
- oc_data = <double *>PyArray_MultiIter_DATA(multi, 3)
- array_data[i] = func(state, oa_data[0], ob_data[0], oc_data[0])
- PyArray_MultiIter_NEXT(multi)
+
+ array_data = <double *>PyArray_DATA(array)
+
+ with lock, nogil:
+ for i in range(multi.size):
+ oa_data = <double *>PyArray_MultiIter_DATA(multi, 0)
+ ob_data = <double *>PyArray_MultiIter_DATA(multi, 1)
+ oc_data = <double *>PyArray_MultiIter_DATA(multi, 2)
+ array_data[i] = func(state, oa_data[0], ob_data[0], oc_data[0])
+ PyArray_MultiIter_NEXT(multi)
+
return array
cdef object disc0_array(rk_state *state, rk_disc0 func, object size, object lock):
@@ -376,28 +365,22 @@ cdef object discnp_array(rk_state *state, rk_discnp func, object size,
cdef broadcast multi
if size is None:
- multi = <broadcast> PyArray_MultiIterNew(2, <void *>on, <void *>op)
- array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_LONG)
- array_data = <long *>PyArray_DATA(array)
- with lock, nogil:
- for i from 0 <= i < multi.size:
- on_data = <long *>PyArray_MultiIter_DATA(multi, 0)
- op_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- array_data[i] = func(state, on_data[0], op_data[0])
- PyArray_MultiIter_NEXT(multi)
+ multi = <broadcast>np.broadcast(on, op)
+ array = <ndarray>np.empty(multi.shape, dtype=int)
else:
- array = <ndarray>np.empty(size, int)
- array_data = <long *>PyArray_DATA(array)
- multi = <broadcast>PyArray_MultiIterNew(3, <void*>array, <void *>on, <void *>op)
- if (multi.size != PyArray_SIZE(array)):
+ array = <ndarray>np.empty(size, dtype=int)
+ multi = <broadcast>np.broadcast(on, op, array)
+ if multi.shape != array.shape:
raise ValueError("size is not compatible with inputs")
- with lock, nogil:
- for i from 0 <= i < multi.size:
- on_data = <long *>PyArray_MultiIter_DATA(multi, 1)
- op_data = <double *>PyArray_MultiIter_DATA(multi, 2)
- array_data[i] = func(state, on_data[0], op_data[0])
- PyArray_MultiIter_NEXTi(multi, 1)
- PyArray_MultiIter_NEXTi(multi, 2)
+
+ array_data = <long *>PyArray_DATA(array)
+
+ with lock, nogil:
+ for i in range(multi.size):
+ on_data = <long *>PyArray_MultiIter_DATA(multi, 0)
+ op_data = <double *>PyArray_MultiIter_DATA(multi, 1)
+ array_data[i] = func(state, on_data[0], op_data[0])
+ PyArray_MultiIter_NEXT(multi)
return array
@@ -429,28 +412,22 @@ cdef object discdd_array(rk_state *state, rk_discdd func, object size,
cdef broadcast multi
if size is None:
- multi = <broadcast> PyArray_MultiIterNew(2, <void *>on, <void *>op)
- array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_LONG)
- array_data = <long *>PyArray_DATA(array)
- with lock, nogil:
- for i from 0 <= i < multi.size:
- on_data = <double *>PyArray_MultiIter_DATA(multi, 0)
- op_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- array_data[i] = func(state, on_data[0], op_data[0])
- PyArray_MultiIter_NEXT(multi)
+ multi = <broadcast>np.broadcast(on, op)
+ array = <ndarray>np.empty(multi.shape, dtype=int)
else:
- array = <ndarray>np.empty(size, int)
- array_data = <long *>PyArray_DATA(array)
- multi = <broadcast>PyArray_MultiIterNew(3, <void*>array, <void *>on, <void *>op)
- if (multi.size != PyArray_SIZE(array)):
+ array = <ndarray>np.empty(size, dtype=int)
+ multi = <broadcast>np.broadcast(on, op, array)
+ if multi.shape != array.shape:
raise ValueError("size is not compatible with inputs")
- with lock, nogil:
- for i from 0 <= i < multi.size:
- on_data = <double *>PyArray_MultiIter_DATA(multi, 1)
- op_data = <double *>PyArray_MultiIter_DATA(multi, 2)
- array_data[i] = func(state, on_data[0], op_data[0])
- PyArray_MultiIter_NEXTi(multi, 1)
- PyArray_MultiIter_NEXTi(multi, 2)
+
+ array_data = <long *>PyArray_DATA(array)
+
+ with lock, nogil:
+ for i in range(multi.size):
+ on_data = <double *>PyArray_MultiIter_DATA(multi, 0)
+ op_data = <double *>PyArray_MultiIter_DATA(multi, 1)
+ array_data[i] = func(state, on_data[0], op_data[0])
+ PyArray_MultiIter_NEXT(multi)
return array
@@ -483,30 +460,23 @@ cdef object discnmN_array(rk_state *state, rk_discnmN func, object size,
cdef broadcast multi
if size is None:
- multi = <broadcast> PyArray_MultiIterNew(3, <void *>on, <void *>om, <void *>oN)
- array = <ndarray> PyArray_SimpleNew(multi.nd, multi.dimensions, NPY_LONG)
- array_data = <long *>PyArray_DATA(array)
- with lock, nogil:
- for i from 0 <= i < multi.size:
- on_data = <long *>PyArray_MultiIter_DATA(multi, 0)
- om_data = <long *>PyArray_MultiIter_DATA(multi, 1)
- oN_data = <long *>PyArray_MultiIter_DATA(multi, 2)
- array_data[i] = func(state, on_data[0], om_data[0], oN_data[0])
- PyArray_MultiIter_NEXT(multi)
+ multi = <broadcast>np.broadcast(on, om, oN)
+ array = <ndarray>np.empty(multi.shape, dtype=int)
else:
- array = <ndarray>np.empty(size, int)
- array_data = <long *>PyArray_DATA(array)
- multi = <broadcast>PyArray_MultiIterNew(4, <void*>array, <void *>on, <void *>om,
- <void *>oN)
- if (multi.size != PyArray_SIZE(array)):
+ array = <ndarray>np.empty(size, dtype=int)
+ multi = <broadcast>np.broadcast(on, om, oN, array)
+ if multi.shape != array.shape:
raise ValueError("size is not compatible with inputs")
- with lock, nogil:
- for i from 0 <= i < multi.size:
- on_data = <long *>PyArray_MultiIter_DATA(multi, 1)
- om_data = <long *>PyArray_MultiIter_DATA(multi, 2)
- oN_data = <long *>PyArray_MultiIter_DATA(multi, 3)
- array_data[i] = func(state, on_data[0], om_data[0], oN_data[0])
- PyArray_MultiIter_NEXT(multi)
+
+ array_data = <long *>PyArray_DATA(array)
+
+ with lock, nogil:
+ for i in range(multi.size):
+ on_data = <long *>PyArray_MultiIter_DATA(multi, 0)
+ om_data = <long *>PyArray_MultiIter_DATA(multi, 1)
+ oN_data = <long *>PyArray_MultiIter_DATA(multi, 2)
+ array_data[i] = func(state, on_data[0], om_data[0], oN_data[0])
+ PyArray_MultiIter_NEXT(multi)
return array