diff options
author | Hameer Abbasi <einstein.edison@gmail.com> | 2019-10-17 12:27:13 +0200 |
---|---|---|
committer | warren <warren.weckesser@gmail.com> | 2022-06-16 12:26:52 -0400 |
commit | c8b5124fb614be196303ee622bc4936944652e34 (patch) | |
tree | ef41ed41f5205f8139f75405e4471d7f17e42de7 /numpy/random | |
parent | 88d373cc5734e8cf47253f98c94403f124f1c017 (diff) | |
download | numpy-c8b5124fb614be196303ee622bc4936944652e34.tar.gz |
MAINT: random: Disallow complex covariances in multivariate_normal
This commit disallows complex covariances in multivariate_normal
as passing them can silently lead to incorrect results.
Diffstat (limited to 'numpy/random')
-rw-r--r-- | numpy/random/_generator.pyx | 4 | ||||
-rw-r--r-- | numpy/random/tests/test_generator_mt19937.py | 2 |
2 files changed, 6 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx index b54fe3610..c346c4943 100644 --- a/numpy/random/_generator.pyx +++ b/numpy/random/_generator.pyx @@ -3660,6 +3660,10 @@ cdef class Generator: # Check preconditions on arguments mean = np.array(mean) cov = np.array(cov) + + if np.issubdtype(cov.dtype, np.complexfloating): + raise NotImplementedError("Complex gaussians are not supported.") + if size is None: shape = [] elif isinstance(size, (int, long, np.integer)): diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py index 3ccb9103c..925ac9e2b 100644 --- a/numpy/random/tests/test_generator_mt19937.py +++ b/numpy/random/tests/test_generator_mt19937.py @@ -1452,6 +1452,8 @@ class TestRandomDist: mu, np.empty((3, 2))) assert_raises(ValueError, random.multivariate_normal, mu, np.eye(3)) + + assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]]) @pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"]) def test_multivariate_normal_basic_stats(self, method): |