summaryrefslogtreecommitdiff
path: root/numpy/random/mtrand/mtrand.pyx
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/mtrand/mtrand.pyx')
-rw-r--r--numpy/random/mtrand/mtrand.pyx50
1 files changed, 33 insertions, 17 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index f115dccdd..2d4d904bd 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -124,6 +124,7 @@ cdef extern from "initarray.h":
import_array()
import numpy as np
+import operator
cdef object cont0_array(rk_state *state, rk_cont0 func, object size):
cdef double *array_data
@@ -994,21 +995,24 @@ cdef class RandomState:
"""
# Format and Verify input
- if isinstance(a, int):
- if a > 0:
- pop_size = a #population size
- else:
+ a = np.array(a, copy=False)
+ if a.ndim == 0:
+ try:
+ # __index__ must return an integer by python rules.
+ pop_size = operator.index(a.item())
+ except TypeError:
+ raise ValueError("a must be 1-dimensional or an integer")
+ if pop_size <= 0:
raise ValueError("a must be greater than 0")
+ elif a.ndim != 1:
+ raise ValueError("a must be 1-dimensional")
else:
- a = np.array(a, ndmin=1, copy=0)
- if a.ndim != 1:
- raise ValueError("a must be 1-dimensional")
- pop_size = a.size
+ pop_size = a.shape[0]
if pop_size is 0:
raise ValueError("a must be non-empty")
if None != p:
- p = np.array(p, dtype=np.double, ndmin=1, copy=0)
+ p = np.array(p, dtype=np.double, ndmin=1, copy=False)
if p.ndim != 1:
raise ValueError("p must be 1-dimensional")
if p.size != pop_size:
@@ -1019,7 +1023,10 @@ cdef class RandomState:
raise ValueError("probabilities do not sum to 1")
shape = size
- size = 1 if shape is None else np.prod(shape, dtype=np.intp)
+ if shape is not None:
+ size = np.prod(shape, dtype=np.intp)
+ else:
+ size = 1
# Actual sampling
if replace:
@@ -1060,18 +1067,27 @@ cdef class RandomState:
idx = self.permutation(pop_size)[:size]
if shape is not None:
idx.shape = shape
+
if shape is None and isinstance(idx, np.ndarray):
# In most cases a scalar will have been made an array
idx = idx.item(0)
+
#Use samples as indices for a if a is array-like
- if isinstance(a, int):
+ if a.ndim == 0:
return idx
- res = a[idx]
- # Note when introducing an axis argument a copy should be ensured.
- if res.ndim == 0 and shape is not None:
- # the result here is not a scalar but an array.
- return np.array(res)
- return res
+
+ if shape is not None and idx.ndim == 0:
+ # If size == () then the user requested a 0-d array as opposed to
+ # a scalar object when size is None. However a[idx] is always a
+ # scalar and not an array. So this makes sure the result is an
+ # array, taking into account that np.array(item) may not work
+ # for object arrays.
+ res = np.empty((), dtype=a.dtype)
+ res[()] = a[idx]
+ return res
+
+ return a[idx]
+
def uniform(self, low=0.0, high=1.0, size=None):
"""