diff options
author | warren <warren.weckesser@gmail.com> | 2022-06-16 12:19:49 -0400 |
---|---|---|
committer | warren <warren.weckesser@gmail.com> | 2022-06-16 12:27:03 -0400 |
commit | 979d288077ba00b487213b683ec47d38f4a31648 (patch) | |
tree | e0df228e538021a8f5d4e1b67a7d18434bbb1c5f /numpy/random/tests | |
parent | c8b5124fb614be196303ee622bc4936944652e34 (diff) | |
download | numpy-979d288077ba00b487213b683ec47d38f4a31648.tar.gz |
MAINT: random: Update to disallowing complex inputs to multivariate_normal.
* Disallow both mean and cov from being complex.
* Raise a TypeError instead of a NotImplementedError if mean or cov is
complex.
* Expand and fix the unit test.
Diffstat (limited to 'numpy/random/tests')
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 925ac9e2b..fa55ac0ee 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1453,7 +1453,11 @@ class TestRandomDist: assert_raises(ValueError, random.multivariate_normal, mu, np.eye(3)) - assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]]) + @pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])]) + def test_multivariate_normal_disallow_complex(self, mean, cov): + random = Generator(MT19937(self.seed)) + with pytest.raises(TypeError, match="must not be complex"): + random.multivariate_normal(mean, cov) @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) def test_multivariate_normal_basic_stats(self, method): |