summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorHameer Abbasi <einstein.edison@gmail.com>2019-10-17 12:27:13 +0200
committerwarren <warren.weckesser@gmail.com>2022-06-16 12:26:52 -0400
commitc8b5124fb614be196303ee622bc4936944652e34 (patch)
treeef41ed41f5205f8139f75405e4471d7f17e42de7 /numpy/random
parent88d373cc5734e8cf47253f98c94403f124f1c017 (diff)
downloadnumpy-c8b5124fb614be196303ee622bc4936944652e34.tar.gz
MAINT: random: Disallow complex covariances in multivariate_normal
This commit disallows complex covariances in multivariate_normal as passing them can silently lead to incorrect results.
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/_generator.pyx4
-rw-r--r--numpy/random/tests/test_generator_mt19937.py2
2 files changed, 6 insertions, 0 deletions
diff --git a/numpy/random/_generator.pyx b/numpy/random/_generator.pyx
index b54fe3610..c346c4943 100644
--- a/numpy/random/_generator.pyx
+++ b/numpy/random/_generator.pyx
@@ -3660,6 +3660,10 @@ cdef class Generator:
# Check preconditions on arguments
mean = np.array(mean)
cov = np.array(cov)
+
+ if np.issubdtype(cov.dtype, np.complexfloating):
+ raise NotImplementedError("Complex gaussians are not supported.")
+
if size is None:
shape = []
elif isinstance(size, (int, long, np.integer)):
diff --git a/numpy/random/tests/test_generator_mt19937.py b/numpy/random/tests/test_generator_mt19937.py
index 3ccb9103c..925ac9e2b 100644
--- a/numpy/random/tests/test_generator_mt19937.py
+++ b/numpy/random/tests/test_generator_mt19937.py
@@ -1452,6 +1452,8 @@ class TestRandomDist:
mu, np.empty((3, 2)))
assert_raises(ValueError, random.multivariate_normal,
mu, np.eye(3))
+
+ assert_raises(NotImplementedError, np.random.multivariate_normal, [0], [[1+1j]])
@pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
def test_multivariate_normal_basic_stats(self, method):