diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 47 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 59 |
2 files changed, 83 insertions, 23 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 92fa6cb73..17e84be2d 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -26,7 +26,7 @@ from numpy.core import ( add, multiply, sqrt, fastCopyAndTranspose, sum, isfinite, finfo, errstate, geterrobj, moveaxis, amin, amax, product, abs, atleast_2d, intp, asanyarray, object_, matmul, - swapaxes, divide, count_nonzero, isnan + swapaxes, divide, count_nonzero, isnan, sign ) from numpy.core.multiarray import normalize_axis_index from numpy.core.overrides import set_module @@ -1461,12 +1461,12 @@ def eigh(a, UPLO='L'): # Singular value decomposition -def _svd_dispatcher(a, full_matrices=None, compute_uv=None): +def _svd_dispatcher(a, full_matrices=None, compute_uv=None, hermitian=None): return (a,) @array_function_dispatch(_svd_dispatcher) -def svd(a, full_matrices=True, compute_uv=True): +def svd(a, full_matrices=True, compute_uv=True, hermitian=False): """ Singular Value Decomposition. @@ -1504,6 +1504,12 @@ def svd(a, full_matrices=True, compute_uv=True): size as those of the input `a`. The size of the last two dimensions depends on the value of `full_matrices`. Only returned when `compute_uv` is True. + hermitian : bool, optional + If True, `a` is assumed to be Hermitian (symmetric if real-valued), + enabling a more efficient method for finding singular values. + Defaults to False. + + ..versionadded:: 1.17.0 Raises ------ @@ -1590,6 +1596,24 @@ def svd(a, full_matrices=True, compute_uv=True): """ a, wrap = _makearray(a) + + if hermitian: + # note: lapack returns eigenvalues in reverse order to our contract. + # reversing is cheap by design in numpy, so we do so to be consistent + if compute_uv: + s, u = eigh(a) + s = s[..., ::-1] + u = u[..., ::-1] + # singular values are unsigned, move the sign into v + vt = transpose(u * sign(s)[..., None, :]).conjugate() + s = abs(s) + return wrap(u), s, wrap(vt) + else: + s = eigvalsh(a) + s = s[..., ::-1] + s = abs(s) + return s + _assertRankAtLeast2(a) t, result_t = _commonType(a) @@ -1844,10 +1868,7 @@ def matrix_rank(M, tol=None, hermitian=False): M = asarray(M) if M.ndim < 2: return int(not all(M==0)) - if hermitian: - S = abs(eigvalsh(M)) - else: - S = svd(M, compute_uv=False) + S = svd(M, compute_uv=False, hermitian=hermitian) if tol is None: tol = S.max(axis=-1, keepdims=True) * max(M.shape[-2:]) * finfo(S.dtype).eps else: @@ -1857,12 +1878,12 @@ def matrix_rank(M, tol=None, hermitian=False): # Generalized inverse -def _pinv_dispatcher(a, rcond=None): +def _pinv_dispatcher(a, rcond=None, hermitian=None): return (a,) @array_function_dispatch(_pinv_dispatcher) -def pinv(a, rcond=1e-15): +def pinv(a, rcond=1e-15, hermitian=False): """ Compute the (Moore-Penrose) pseudo-inverse of a matrix. @@ -1882,6 +1903,12 @@ def pinv(a, rcond=1e-15): Singular values smaller (in modulus) than `rcond` * largest_singular_value (again, in modulus) are set to zero. Broadcasts against the stack of matrices + hermitian : bool, optional + If True, `a` is assumed to be Hermitian (symmetric if real-valued), + enabling a more efficient method for finding singular values. + Defaults to False. + + ..versionadded:: 1.17.0 Returns ------- @@ -1935,7 +1962,7 @@ def pinv(a, rcond=1e-15): res = empty(a.shape[:-2] + (n, m), dtype=a.dtype) return wrap(res) a = a.conjugate() - u, s, vt = svd(a, full_matrices=False) + u, s, vt = svd(a, full_matrices=False, hermitian=hermitian) # discard small singular values cutoff = rcond[..., newaxis] * amax(s, axis=-1, keepdims=True) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index f95909cd1..831c059d0 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -633,6 +633,20 @@ class TestEig(EigCases): assert_(isinstance(a, np.ndarray)) +class SVDBaseTests(object): + hermitian = False + + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + u, s, vh = linalg.svd(x) + assert_equal(u.dtype, dtype) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(vh.dtype, dtype) + s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian) + assert_equal(s.dtype, get_real_dtype(dtype)) + + class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): def do(self, a, b, tags): @@ -644,32 +658,37 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): assert_(consistent_subclass(vt, a)) -class TestSVD(SVDCases): - @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) - def test_types(self, dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - u, s, vh = linalg.svd(x) - assert_equal(u.dtype, dtype) - assert_equal(s.dtype, get_real_dtype(dtype)) - assert_equal(vh.dtype, dtype) - s = linalg.svd(x, compute_uv=False) - assert_equal(s.dtype, get_real_dtype(dtype)) - +class TestSVD(SVDCases, SVDBaseTests): def test_empty_identity(self): """ Empty input should put an identity matrix in u or vh """ x = np.empty((4, 0)) - u, s, vh = linalg.svd(x, compute_uv=True) + u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian) assert_equal(u.shape, (4, 4)) assert_equal(vh.shape, (0, 0)) assert_equal(u, np.eye(4)) x = np.empty((0, 4)) - u, s, vh = linalg.svd(x, compute_uv=True) + u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian) assert_equal(u.shape, (0, 0)) assert_equal(vh.shape, (4, 4)) assert_equal(vh, np.eye(4)) +class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase): + + def do(self, a, b, tags): + u, s, vt = linalg.svd(a, 0, hermitian=True) + assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :], + np.asarray(vt)), + rtol=get_rtol(u.dtype)) + assert_(consistent_subclass(u, a)) + assert_(consistent_subclass(vt, a)) + + +class TestSVDHermitian(SVDHermitianCases, SVDBaseTests): + hermitian = True + + class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): # cond(x, p) for p in (None, 2, -2) @@ -797,6 +816,20 @@ class TestPinv(PinvCases): pass +class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase): + + def do(self, a, b, tags): + a_ginv = linalg.pinv(a, hermitian=True) + # `a @ a_ginv == I` does not hold if a is singular + dot = dot_generalized + assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11) + assert_(consistent_subclass(a_ginv, a)) + + +class TestPinvHermitian(PinvHermitianCases): + pass + + class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): def do(self, a, b, tags): |