summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorPaulo Almeida <pcca24102002@gmail.com>2023-04-11 07:38:03 +0000
committerGitHub <noreply@github.com>2023-04-11 09:38:03 +0200
commitfe893743b83173b72648dc79d86eda905230d1e9 (patch)
tree1667f85fe190d1158fc7583ea96eb404a5db100e /numpy
parent34417a5d3470dc143ccd6935057b25b045306b49 (diff)
downloadnumpy-fe893743b83173b72648dc79d86eda905230d1e9.tar.gz
BUG: accept zeros on numpy.random dirichlet function (#23440)
Changed alpha value error to pass a null value. This way, dirichlet function (on the generator, not mtrand) won't raise a value exception at 0. Also added test.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/random/_generator.pyx6
-rw-r--r--numpy/random/tests/test_randomstate.py4
2 files changed, 7 insertions, 3 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index 1b19d00d9..a30d116c2 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -4327,7 +4327,7 @@ cdef class Generator:
Raises
------
ValueError
- If any value in ``alpha`` is less than or equal to zero
+ If any value in ``alpha`` is less than zero
Notes
-----
@@ -4406,8 +4406,8 @@ cdef class Generator:
alpha_arr = <np.ndarray>np.PyArray_FROMANY(
alpha, np.NPY_DOUBLE, 1, 1,
np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
- if np.any(np.less_equal(alpha_arr, 0)):
- raise ValueError('alpha <= 0')
+ if np.any(np.less(alpha_arr, 0)):
+ raise ValueError('alpha < 0')
alpha_data = <double*>np.PyArray_DATA(alpha_arr)
if size is None:
diff --git a/numpy/random/tests/test_randomstate.py b/numpy/random/tests/test_randomstate.py
index 8b911cb3a..3a2961098 100644
--- a/numpy/random/tests/test_randomstate.py
+++ b/numpy/random/tests/test_randomstate.py
@@ -812,6 +812,10 @@ class TestRandomDist:
alpha = np.array([5.4e-01, -1.0e-16])
assert_raises(ValueError, random.dirichlet, alpha)
+ def test_dirichlet_zero_alpha(self):
+ y = random.default_rng().dirichlet([5, 9, 0, 8])
+ assert_equal(y[2], 0)
+
def test_dirichlet_alpha_non_contiguous(self):
a = np.array([51.72840233779265162, -1.0, 39.74494232180943953])
alpha = a[::2]