diff options
author | Oscar Villellas <oscar.villellas@continuum.io> | 2017-01-03 22:05:14 +0100 |
---|---|---|
committer | Oscar Villellas <oscar.villellas@continuum.io> | 2017-01-03 22:05:14 +0100 |
commit | 6d7f14f60e12d200b02fd1f41d2315a5167cc859 (patch) | |
tree | 084ec5135971fc794147cf14a92ec600e9369354 /numpy/random | |
parent | fde261788008fd830999a16dceb534a5168baa72 (diff) | |
download | numpy-6d7f14f60e12d200b02fd1f41d2315a5167cc859.tar.gz |
Documentation fix and proper handling of tolerance
Diffstat (limited to 'numpy/random')
-rw-r--r-- | numpy/random/mtrand/mtrand.pyx | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/numpy/random/mtrand/mtrand.pyx b/numpy/random/mtrand/mtrand.pyx index a6698e1ee..982efe741 100644 --- a/numpy/random/mtrand/mtrand.pyx +++ b/numpy/random/mtrand/mtrand.pyx @@ -4356,9 +4356,9 @@ cdef class RandomState: # Multivariate distributions: def multivariate_normal(self, mean, cov, size=None, check_valid='warn', - tol=1e-8): + rtol=1e-05, atol=1e-8): """ - multivariate_normal(mean, cov[, size]) + multivariate_normal(mean, cov[, size, check_valid, rtol, atol]) Draw random samples from a multivariate normal distribution. @@ -4381,10 +4381,14 @@ cdef class RandomState: generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``. If no shape is specified, a single (`N`-D) sample is returned. - check_valid : 'warn', 'raise', 'ignore' - Behavior when the covariance matrix is not Positive Semi-definite. - tol : float - Tolerance of the singular values in covariance matrix. + check_valid : { 'warn', 'raise', 'ignore' }, optional + Behavior when the covariance matrix is not positive semidefinite. + rtol : float, optional + Relative tolerance to use when checking the singular values in + covariance matrix. + atol : float, optional + Absolute tolerance to use when checking the singular values in + covariance matrix Returns ------- @@ -4500,15 +4504,16 @@ cdef class RandomState: (u, s, v) = svd(cov) if check_valid != 'ignore': - psd = np.allclose(np.dot(v.T * s, v), cov) + if check_valid != 'warn' and check_valid != 'raise': + raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'") + + psd = np.allclose(np.dot(v.T * s, v), cov, rtol=rtol, atol=atol) if not psd: if check_valid == 'warn': warnings.warn("covariance is not positive-semidefinite.", RuntimeWarning) - elif check_valid == 'raise': - raise ValueError("covariance is not positive-semidefinite.") else: - raise ValueError("check_valid must equal 'warn', 'raise', or 'ignore'") + raise ValueError("covariance is not positive-semidefinite.") x = np.dot(x, np.sqrt(s)[:, None] * v) x += mean |