diff options
Diffstat (limited to 'numpy/random/mtrand/mtrand.pyx')
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 50 |
1 files changed, 33 insertions, 17 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index f115dccdd..2d4d904bd 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -124,6 +124,7 @@ cdef extern from "initarray.h": import_array() import numpy as np +import operator cdef object cont0_array(rk_state *state, rk_cont0 func, object size): cdef double *array_data @@ -994,21 +995,24 @@ cdef class RandomState: """ # Format and Verify input - if isinstance(a, int): - if a > 0: - pop_size = a #population size - else: + a = np.array(a, copy=False) + if a.ndim == 0: + try: + # __index__ must return an integer by python rules. + pop_size = operator.index(a.item()) + except TypeError: + raise ValueError("a must be 1-dimensional or an integer") + if pop_size <= 0: raise ValueError("a must be greater than 0") + elif a.ndim != 1: + raise ValueError("a must be 1-dimensional") else: - a = np.array(a, ndmin=1, copy=0) - if a.ndim != 1: - raise ValueError("a must be 1-dimensional") - pop_size = a.size + pop_size = a.shape[0] if pop_size is 0: raise ValueError("a must be non-empty") if None != p: - p = np.array(p, dtype=np.double, ndmin=1, copy=0) + p = np.array(p, dtype=np.double, ndmin=1, copy=False) if p.ndim != 1: raise ValueError("p must be 1-dimensional") if p.size != pop_size: @@ -1019,7 +1023,10 @@ cdef class RandomState: raise ValueError("probabilities do not sum to 1") shape = size - size = 1 if shape is None else np.prod(shape, dtype=np.intp) + if shape is not None: + size = np.prod(shape, dtype=np.intp) + else: + size = 1 # Actual sampling if replace: @@ -1060,18 +1067,27 @@ cdef class RandomState: idx = self.permutation(pop_size)[:size] if shape is not None: idx.shape = shape + if shape is None and isinstance(idx, np.ndarray): # In most cases a scalar will have been made an array idx = idx.item(0) + #Use samples as indices for a if a is array-like - if isinstance(a, int): + if a.ndim == 0: return idx - res = a[idx] - # Note when introducing an axis argument a copy should be ensured. - if res.ndim == 0 and shape is not None: - # the result here is not a scalar but an array. - return np.array(res) - return res + + if shape is not None and idx.ndim == 0: + # If size == () then the user requested a 0-d array as opposed to + # a scalar object when size is None. However a[idx] is always a + # scalar and not an array. So this makes sure the result is an + # array, taking into account that np.array(item) may not work + # for object arrays. + res = np.empty((), dtype=a.dtype) + res[()] = a[idx] + return res + + return a[idx] + def uniform(self, low=0.0, high=1.0, size=None): """ |