summaryrefslogtreecommitdiff
path: root/numpy/random/tests/test_regression.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/random/tests/test_regression.py')
-rw-r--r--numpy/random/tests/test_regression.py30
1 files changed, 18 insertions, 12 deletions
diff --git a/numpy/random/tests/test_regression.py b/numpy/random/tests/test_regression.py
index 70858b049..9f7455fe5 100644
--- a/numpy/random/tests/test_regression.py
+++ b/numpy/random/tests/test_regression.py
@@ -1,7 +1,7 @@
from __future__ import division, absolute_import, print_function
-from numpy.testing import TestCase, run_module_suite, assert_,\
- assert_array_equal
+from numpy.testing import (TestCase, run_module_suite, assert_,
+ assert_array_equal)
from numpy import random
from numpy.compat import long
import numpy as np
@@ -10,21 +10,19 @@ import numpy as np
class TestRegression(TestCase):
def test_VonMises_range(self):
- """Make sure generated random variables are in [-pi, pi].
-
- Regression test for ticket #986.
- """
+ # Make sure generated random variables are in [-pi, pi].
+ # Regression test for ticket #986.
for mu in np.linspace(-7., 7., 5):
- r = random.mtrand.vonmises(mu,1,50)
+ r = random.mtrand.vonmises(mu, 1, 50)
assert_(np.all(r > -np.pi) and np.all(r <= np.pi))
- def test_hypergeometric_range(self) :
- """Test for ticket #921"""
+ def test_hypergeometric_range(self):
+ # Test for ticket #921
assert_(np.all(np.random.hypergeometric(3, 18, 11, size=10) < 4))
assert_(np.all(np.random.hypergeometric(18, 3, 11, size=10) > 0))
- def test_logseries_convergence(self) :
- """Test for ticket #923"""
+ def test_logseries_convergence(self):
+ # Test for ticket #923
N = 1000
np.random.seed(0)
rvsn = np.random.logseries(0.8, size=N)
@@ -56,7 +54,7 @@ class TestRegression(TestCase):
raise AssertionError
def test_shuffle_mixed_dimension(self):
- """Test for trac ticket #2074"""
+ # Test for trac ticket #2074
for t in [[1, 2, 3, None],
[(1, 1), (2, 2), (3, 3), None],
[1, (2, 2), (3, 3), None],
@@ -76,5 +74,13 @@ class TestRegression(TestCase):
# If m.state is not honored, the result will change
assert_array_equal(m.choice(10, size=10, p=np.ones(10.)/10), res)
+ def test_multivariate_normal_size_types(self):
+ # Test for multivariate_normal issue with 'size' argument.
+ # Check that the multivariate_normal size argument can be a
+ # numpy integer.
+ np.random.multivariate_normal([0], [[0]], size=1)
+ np.random.multivariate_normal([0], [[0]], size=np.int_(1))
+ np.random.multivariate_normal([0], [[0]], size=np.int64(1))
+
if __name__ == "__main__":
run_module_suite()