summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorwarren <warren.weckesser@gmail.com>2022-06-16 12:19:49 -0400
committerwarren <warren.weckesser@gmail.com>2022-06-16 12:27:03 -0400
commit979d288077ba00b487213b683ec47d38f4a31648 (patch)
treee0df228e538021a8f5d4e1b67a7d18434bbb1c5f /numpy/random
parentc8b5124fb614be196303ee622bc4936944652e34 (diff)
downloadnumpy-979d288077ba00b487213b683ec47d38f4a31648.tar.gz
MAINT: random: Update to disallowing complex inputs to multivariate_normal.
* Disallow both mean and cov from being complex. * Raise a TypeError instead of a NotImplementedError if mean or cov is complex. * Expand and fix the unit test.
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/_generator.pyx5
-rw-r--r--numpy/random/tests/test_generator_mt19937.py6
2 files changed, 8 insertions, 3 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index c346c4943..0019c4bcd 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -3661,8 +3661,9 @@ cdef class Generator:
mean = np.array(mean)
cov = np.array(cov)
- if np.issubdtype(cov.dtype, np.complexfloating):
- raise NotImplementedError("Complex gaussians are not supported.")
+ if (np.issubdtype(mean.dtype, np.complexfloating) or
+ np.issubdtype(cov.dtype, np.complexfloating)):
+ raise TypeError("mean and cov must not be complex")
if size is None:
shape = []
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 925ac9e2b..fa55ac0ee 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -1453,7 +1453,11 @@ class TestRandomDist:
assert_raises(ValueError, random.multivariate_normal,
mu, np.eye(3))
- assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]])
+ @pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])])
+ def test_multivariate_normal_disallow_complex(self, mean, cov):
+ random = Generator(MT19937(self.seed))
+ with pytest.raises(TypeError, match="must not be complex"):
+ random.multivariate_normal(mean, cov)
@pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
def test_multivariate_normal_basic_stats(self, method):