diff options
author | warren <warren.weckesser@gmail.com> | 2022-06-16 12:19:49 -0400 |
---|---|---|
committer | warren <warren.weckesser@gmail.com> | 2022-06-16 12:27:03 -0400 |
commit | 979d288077ba00b487213b683ec47d38f4a31648 (patch) | |
tree | e0df228e538021a8f5d4e1b67a7d18434bbb1c5f /numpy/random | |
parent | c8b5124fb614be196303ee622bc4936944652e34 (diff) | |
download | numpy-979d288077ba00b487213b683ec47d38f4a31648.tar.gz |
MAINT: random: Update to disallowing complex inputs to multivariate_normal.
* Disallow both mean and cov from being complex.
* Raise a TypeError instead of a NotImplementedError if mean or cov is
complex.
* Expand and fix the unit test.
Diffstat (limited to 'numpy/random')
-rw-r--r-- | numpy/random/_generator.pyx | 5 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 6 |
2 files changed, 8 insertions, 3 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index c346c4943..0019c4bcd 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3661,8 +3661,9 @@ cdef class Generator: mean = np.array(mean) cov = np.array(cov) - if np.issubdtype(cov.dtype, np.complexfloating): - raise NotImplementedError("Complex gaussians are not supported.") + if (np.issubdtype(mean.dtype, np.complexfloating) or + np.issubdtype(cov.dtype, np.complexfloating)): + raise TypeError("mean and cov must not be complex") if size is None: shape = [] diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 925ac9e2b..fa55ac0ee 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1453,7 +1453,11 @@ class TestRandomDist: assert_raises(ValueError, random.multivariate_normal, mu, np.eye(3)) - assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]]) + @pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])]) + def test_multivariate_normal_disallow_complex(self, mean, cov): + random = Generator(MT19937(self.seed)) + with pytest.raises(TypeError, match="must not be complex"): + random.multivariate_normal(mean, cov) @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) def test_multivariate_normal_basic_stats(self, method): |