summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2015-05-10 10:47:07 -0400
committerCharles Harris <charlesr.harris@gmail.com>2015-05-10 10:47:07 -0400
commitc65fc9e36b287461040615a452e0268bbca9f3ca (patch)
tree45e038f121684df8f41797df25fdaca8c415fb63 /numpy
parent59de9f290b9180e316d319d6b70f65554cd2ef4e (diff)
parentc7f121abb24e19c34246c8ad834dfdd5bd5ce8b2 (diff)
downloadnumpy-c65fc9e36b287461040615a452e0268bbca9f3ca.tar.gz
Merge pull request #5858 from jaimefrio/random_beta
BUG: random.beta with small parameters
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/mtrand/distributions.c15
-rw-r--r--numpy/random/tests/test_regression.py8
2 files changed, 22 insertions, 1 deletions
diff --git a/numpy/random/mtrand/distributions.c b/numpy/random/mtrand/distributions.c
index ff936fdd8..f5ee6d8c1 100644
--- a/numpy/random/mtrand/distributions.c
+++ b/numpy/random/mtrand/distributions.c
@@ -199,7 +199,20 @@ double rk_beta(rk_state *state, double a, double b)
if ((X + Y) <= 1.0)
{
- return X / (X + Y);
+ if (X +Y > 0)
+ {
+ return X / (X + Y);
+ }
+ else
+ {
+ double logX = log(U) / a;
+ double logY = log(V) / b;
+ double logM = logX > logY ? logX : logY;
+ logX -= logM;
+ logY -= logM;
+
+ return exp(logX - log(exp(logX) + exp(logY)));
+ }
}
}
}
diff --git a/numpy/random/tests/test_regression.py b/numpy/random/tests/test_regression.py
index 1a5854e82..52be0d4d5 100644
--- a/numpy/random/tests/test_regression.py
+++ b/numpy/random/tests/test_regression.py
@@ -93,5 +93,13 @@ class TestRegression(TestCase):
np.random.multivariate_normal([0], [[0]], size=np.int_(1))
np.random.multivariate_normal([0], [[0]], size=np.int64(1))
+ def test_beta_small_parameters(self):
+ # Test that beta with small a and b parameters does not produce
+ # NaNs due to roundoff errors causing 0 / 0, gh-5851
+ np.random.seed(1234567890)
+ x = np.random.beta(0.0001, 0.0001, size=100)
+ assert_(not np.any(np.isnan(x)), 'Nans in np.random.beta')
+
+
if __name__ == "__main__":
run_module_suite()