summaryrefslogtreecommitdiff
path: root/numpy/random/tests
diff options
context:
space:
mode:
authorwarren <warren.weckesser@gmail.com>2022-06-16 12:19:49 -0400
committerwarren <warren.weckesser@gmail.com>2022-06-16 12:27:03 -0400
commit979d288077ba00b487213b683ec47d38f4a31648 (patch)
treee0df228e538021a8f5d4e1b67a7d18434bbb1c5f /numpy/random/tests
parentc8b5124fb614be196303ee622bc4936944652e34 (diff)
downloadnumpy-979d288077ba00b487213b683ec47d38f4a31648.tar.gz
MAINT: random: Update to disallowing complex inputs to multivariate_normal.
* Disallow both mean and cov from being complex. * Raise a TypeError instead of a NotImplementedError if mean or cov is complex. * Expand and fix the unit test.
Diffstat (limited to 'numpy/random/tests')
-rw-r--r--numpy/random/tests/test_generator_mt19937.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 925ac9e2b..fa55ac0ee 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -1453,7 +1453,11 @@ class TestRandomDist:
assert_raises(ValueError, random.multivariate_normal,
mu, np.eye(3))
- assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]])
+ @pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])])
+ def test_multivariate_normal_disallow_complex(self, mean, cov):
+ random = Generator(MT19937(self.seed))
+ with pytest.raises(TypeError, match="must not be complex"):
+ random.multivariate_normal(mean, cov)
@pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
def test_multivariate_normal_basic_stats(self, method):