summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/mtrand/mtrand.pyx24
-rw-r--r--numpy/random/mtrand/randint_helpers.pxi.in6
-rw-r--r--numpy/random/tests/test_random.py9
3 files changed, 24 insertions, 15 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 16c58cfce..ec759fdfb 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -22,8 +22,8 @@
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
include "Python.pxi"
-include "randint_helpers.pxi"
include "numpy.pxd"
+include "randint_helpers.pxi"
include "cpython/pycapsule.pxd"
from libc cimport string
@@ -988,9 +988,9 @@ cdef class RandomState:
raise ValueError("low is out of bounds for %s" % dtype)
if ihigh > highbnd:
raise ValueError("high is out of bounds for %s" % dtype)
- if ilow >= ihigh:
- raise ValueError("low >= high")
-
+ if ilow >= ihigh and np.prod(size) != 0:
+ raise ValueError("Range cannot be empty (low >= high) unless no samples are taken")
+
with self.lock:
ret = randfunc(ilow, ihigh - 1, size, self.state_address)
@@ -1114,15 +1114,15 @@ cdef class RandomState:
# __index__ must return an integer by python rules.
pop_size = operator.index(a.item())
except TypeError:
- raise ValueError("a must be 1-dimensional or an integer")
- if pop_size <= 0:
- raise ValueError("a must be greater than 0")
+ raise ValueError("'a' must be 1-dimensional or an integer")
+ if pop_size <= 0 and np.prod(size) != 0:
+ raise ValueError("'a' must be greater than 0 unless no samples are taken")
elif a.ndim != 1:
- raise ValueError("a must be 1-dimensional")
+ raise ValueError("'a' must be 1-dimensional")
else:
pop_size = a.shape[0]
- if pop_size is 0:
- raise ValueError("a must be non-empty")
+ if pop_size is 0 and np.prod(size) != 0:
+ raise ValueError("'a' cannot be empty unless no samples are taken")
if p is not None:
d = len(p)
@@ -1136,9 +1136,9 @@ cdef class RandomState:
pix = <double*>PyArray_DATA(p)
if p.ndim != 1:
- raise ValueError("p must be 1-dimensional")
+ raise ValueError("'p' must be 1-dimensional")
if p.size != pop_size:
- raise ValueError("a and p must have same size")
+ raise ValueError("'a' and 'p' must have same size")
if np.logical_or.reduce(p < 0):
raise ValueError("probabilities are not non-negative")
if abs(kahan_sum(pix, d) - 1.) > atol:
diff --git a/numpy/random/mtrand/randint_helpers.pxi.in b/numpy/random/mtrand/randint_helpers.pxi.in
index 4bd7cd356..894a25167 100644
--- a/numpy/random/mtrand/randint_helpers.pxi.in
+++ b/numpy/random/mtrand/randint_helpers.pxi.in
@@ -23,7 +23,7 @@ def get_dispatch(dtypes):
{{for npy_dt, npy_udt, np_dt in get_dispatch(dtypes)}}
-def _rand_{{npy_dt}}(low, high, size, rngstate):
+def _rand_{{npy_dt}}(npy_{{npy_dt}} low, npy_{{npy_dt}} high, size, rngstate):
"""
_rand_{{npy_dt}}(low, high, size, rngstate)
@@ -60,8 +60,8 @@ def _rand_{{npy_dt}}(low, high, size, rngstate):
cdef npy_intp cnt
cdef rk_state *state = <rk_state *>PyCapsule_GetPointer(rngstate, NULL)
- rng = <npy_{{npy_udt}}>(high - low)
- off = <npy_{{npy_udt}}>(<npy_{{npy_dt}}>low)
+ off = <npy_{{npy_udt}}>(low)
+ rng = <npy_{{npy_udt}}>(high) - <npy_{{npy_udt}}>(low)
if size is None:
rk_random_{{npy_udt}}(off, rng, 1, &buf, state)
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index 7b4f90839..2e0885024 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -440,6 +440,15 @@ class TestRandomDist(object):
assert_equal(np.random.choice(6, s, replace=False, p=p).shape, s)
assert_equal(np.random.choice(np.arange(6), s, replace=True).shape, s)
+ # Check zero-size
+ assert_equal(np.random.randint(0, 0, size=(3, 0, 4)).shape, (3, 0, 4))
+ assert_equal(np.random.randint(0, -10, size=0).shape, (0,))
+ assert_equal(np.random.randint(10, 10, size=0).shape, (0,))
+ assert_equal(np.random.choice(0, size=0).shape, (0,))
+ assert_equal(np.random.choice([], size=(0,)).shape, (0,))
+ assert_equal(np.random.choice(['a', 'b'], size=(3, 0, 4)).shape, (3, 0, 4))
+ assert_raises(ValueError, np.random.choice, [], 10)
+
def test_bytes(self):
np.random.seed(self.seed)
actual = np.random.bytes(10)