summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorgfyoung <gfyoung17@gmail.com>2016-02-19 05:14:49 +0000
committergfyoung <gfyoung17@gmail.com>2016-02-19 06:16:11 +0000
commit4f91544af2a74bc2dcecfce5db05b015f26721e4 (patch)
tree79744f6d06eda3e6a7203d94e657cfd433af2df7 /numpy/random
parent711b2a9ccfb54c09cdc483ee511035bda60f3f0b (diff)
downloadnumpy-4f91544af2a74bc2dcecfce5db05b015f26721e4.tar.gz
BUG: Make randint backwards compatible with pandas
The 'pandas' library expects Python integers to be returned, so this commit changes the API so that the default is 'np.int' which converts to native Python integer types when a singleton is being generated with this function. Closes gh-7284.
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/mtrand/mtrand.pyx12
-rw-r--r--numpy/random/tests/test_random.py21
2 files changed, 24 insertions, 9 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index 35bbcc505..b05d86747 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -1159,7 +1159,7 @@ cdef class RandomState:
"""
return disc0_array(self.internal_state, rk_long, size, self.lock)
- def randint(self, low, high=None, size=None, dtype='l'):
+ def randint(self, low, high=None, size=None, dtype=int):
"""
randint(low, high=None, size=None, dtype='l')
@@ -1186,7 +1186,7 @@ cdef class RandomState:
Desired dtype of the result. All dtypes are determined by their
name, i.e., 'int64', 'int', etc, so byteorder is not available
and a specific precision may have different C types depending
- on the platform. The default value is 'l' (C long).
+ on the platform. The default value is 'np.int'.
.. versionadded:: 1.11.0
@@ -1234,7 +1234,13 @@ cdef class RandomState:
raise ValueError("low >= high")
with self.lock:
- return randfunc(low, high - 1, size, self.state_address)
+ ret = randfunc(low, high - 1, size, self.state_address)
+
+ if size is None:
+ if dtype in (np.bool, np.int, np.long):
+ return dtype(ret)
+
+ return ret
def bytes(self, npy_intp length):
"""
diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py
index fac287b3f..a07fa52f5 100644
--- a/numpy/random/tests/test_random.py
+++ b/numpy/random/tests/test_random.py
@@ -137,7 +137,7 @@ class TestRandint(TestCase):
rfunc = np.random.randint
# valid integer/boolean types
- itype = [np.bool, np.int8, np.uint8, np.int16, np.uint16,
+ itype = [np.bool_, np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64]
def test_unsupported_type(self):
@@ -145,8 +145,8 @@ class TestRandint(TestCase):
def test_bounds_checking(self):
for dt in self.itype:
- lbnd = 0 if dt is np.bool else np.iinfo(dt).min
- ubnd = 2 if dt is np.bool else np.iinfo(dt).max + 1
+ lbnd = 0 if dt is np.bool_ else np.iinfo(dt).min
+ ubnd = 2 if dt is np.bool_ else np.iinfo(dt).max + 1
assert_raises(ValueError, self.rfunc, lbnd - 1, ubnd, dtype=dt)
assert_raises(ValueError, self.rfunc, lbnd, ubnd + 1, dtype=dt)
assert_raises(ValueError, self.rfunc, ubnd, lbnd, dtype=dt)
@@ -154,8 +154,8 @@ class TestRandint(TestCase):
def test_rng_zero_and_extremes(self):
for dt in self.itype:
- lbnd = 0 if dt is np.bool else np.iinfo(dt).min
- ubnd = 2 if dt is np.bool else np.iinfo(dt).max + 1
+ lbnd = 0 if dt is np.bool_ else np.iinfo(dt).min
+ ubnd = 2 if dt is np.bool_ else np.iinfo(dt).max + 1
tgt = ubnd - 1
assert_equal(self.rfunc(tgt, tgt + 1, size=1000, dtype=dt), tgt)
tgt = lbnd
@@ -211,11 +211,20 @@ class TestRandint(TestCase):
def test_respect_dtype_singleton(self):
# See gh-7203
for dt in self.itype:
+ lbnd = 0 if dt is np.bool_ else np.iinfo(dt).min
+ ubnd = 2 if dt is np.bool_ else np.iinfo(dt).max + 1
+
+ sample = self.rfunc(lbnd, ubnd, dtype=dt)
+ self.assertEqual(sample.dtype, np.dtype(dt))
+
+ for dt in (np.bool, np.int, np.long):
lbnd = 0 if dt is np.bool else np.iinfo(dt).min
ubnd = 2 if dt is np.bool else np.iinfo(dt).max + 1
+ # gh-7284: Ensure that we get Python data types
sample = self.rfunc(lbnd, ubnd, dtype=dt)
- self.assertEqual(sample.dtype, np.dtype(dt))
+ self.assertFalse(hasattr(sample, 'dtype'))
+ self.assertEqual(type(sample), dt)
class TestRandomDist(TestCase):