summaryrefslogtreecommitdiff
path: root/numpy/random
diff options
context:
space:
mode:
authorOscar Villellas <oscar.villellas@continuum.io>2017-01-03 22:05:14 +0100
committerOscar Villellas <oscar.villellas@continuum.io>2017-01-03 22:05:14 +0100
commit6d7f14f60e12d200b02fd1f41d2315a5167cc859 (patch)
tree084ec5135971fc794147cf14a92ec600e9369354 /numpy/random
parentfde261788008fd830999a16dceb534a5168baa72 (diff)
downloadnumpy-6d7f14f60e12d200b02fd1f41d2315a5167cc859.tar.gz
Documentation fix and proper handling of tolerance
Diffstat (limited to 'numpy/random')
-rw-r--r--numpy/random/mtrand/mtrand.pyx25
1 files changed, 15 insertions, 10 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx
index a6698e1ee..982efe741 100644
--- a/numpy/random/mtrand/mtrand.pyx
+++ b/numpy/random/mtrand/mtrand.pyx
@@ -4356,9 +4356,9 @@ cdef class RandomState:
# Multivariate distributions:
def multivariate_normal(self, mean, cov, size=None, check_valid='warn',
- tol=1e-8):
+ rtol=1e-05, atol=1e-8):
"""
- multivariate_normal(mean, cov[, size])
+ multivariate_normal(mean, cov[, size, check_valid, rtol, atol])
Draw random samples from a multivariate normal distribution.
@@ -4381,10 +4381,14 @@ cdef class RandomState:
generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because
each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``.
If no shape is specified, a single (`N`-D) sample is returned.
- check_valid : 'warn', 'raise', 'ignore'
- Behavior when the covariance matrix is not Positive Semi-definite.
- tol : float
- Tolerance of the singular values in covariance matrix.
+ check_valid : { 'warn', 'raise', 'ignore' }, optional
+ Behavior when the covariance matrix is not positive semidefinite.
+ rtol : float, optional
+ Relative tolerance to use when checking the singular values in
+ covariance matrix.
+ atol : float, optional
+ Absolute tolerance to use when checking the singular values in
+ covariance matrix
Returns
-------
@@ -4500,15 +4504,16 @@ cdef class RandomState:
(u, s, v) = svd(cov)
if check_valid != 'ignore':
- psd = np.allclose(np.dot(v.T * s, v), cov)
+ if check_valid != 'warn' and check_valid != 'raise':
+ raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'")
+
+ psd = np.allclose(np.dot(v.T * s, v), cov, rtol=rtol, atol=atol)
if not psd:
if check_valid == 'warn':
warnings.warn("covariance is not positive-semidefinite.",
RuntimeWarning)
- elif check_valid == 'raise':
- raise ValueError("covariance is not positive-semidefinite.")
else:
- raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'")
+ raise ValueError("covariance is not positive-semidefinite.")
x = np.dot(x, np.sqrt(s)[:, None] * v)
x += mean