summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py47
-rw-r--r--numpy/linalg/tests/test_linalg.py59
2 files changed, 83 insertions, 23 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 92fa6cb73..17e84be2d 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -26,7 +26,7 @@ from numpy.core import (
add, multiply, sqrt, fastCopyAndTranspose, sum, isfinite,
finfo, errstate, geterrobj, moveaxis, amin, amax, product, abs,
atleast_2d, intp, asanyarray, object_, matmul,
- swapaxes, divide, count_nonzero, isnan
+ swapaxes, divide, count_nonzero, isnan, sign
)
from numpy.core.multiarray import normalize_axis_index
from numpy.core.overrides import set_module
@@ -1461,12 +1461,12 @@ def eigh(a, UPLO='L'):
# Singular value decomposition
-def _svd_dispatcher(a, full_matrices=None, compute_uv=None):
+def _svd_dispatcher(a, full_matrices=None, compute_uv=None, hermitian=None):
return (a,)
@array_function_dispatch(_svd_dispatcher)
-def svd(a, full_matrices=True, compute_uv=True):
+def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
"""
Singular Value Decomposition.
@@ -1504,6 +1504,12 @@ def svd(a, full_matrices=True, compute_uv=True):
size as those of the input `a`. The size of the last two dimensions
depends on the value of `full_matrices`. Only returned when
`compute_uv` is True.
+ hermitian : bool, optional
+ If True, `a` is assumed to be Hermitian (symmetric if real-valued),
+ enabling a more efficient method for finding singular values.
+ Defaults to False.
+
+ ..versionadded:: 1.17.0
Raises
------
@@ -1590,6 +1596,24 @@ def svd(a, full_matrices=True, compute_uv=True):
"""
a, wrap = _makearray(a)
+
+ if hermitian:
+ # note: lapack returns eigenvalues in reverse order to our contract.
+ # reversing is cheap by design in numpy, so we do so to be consistent
+ if compute_uv:
+ s, u = eigh(a)
+ s = s[..., ::-1]
+ u = u[..., ::-1]
+ # singular values are unsigned, move the sign into v
+ vt = transpose(u * sign(s)[..., None, :]).conjugate()
+ s = abs(s)
+ return wrap(u), s, wrap(vt)
+ else:
+ s = eigvalsh(a)
+ s = s[..., ::-1]
+ s = abs(s)
+ return s
+
_assertRankAtLeast2(a)
t, result_t = _commonType(a)
@@ -1844,10 +1868,7 @@ def matrix_rank(M, tol=None, hermitian=False):
M = asarray(M)
if M.ndim < 2:
return int(not all(M==0))
- if hermitian:
- S = abs(eigvalsh(M))
- else:
- S = svd(M, compute_uv=False)
+ S = svd(M, compute_uv=False, hermitian=hermitian)
if tol is None:
tol = S.max(axis=-1, keepdims=True) * max(M.shape[-2:]) * finfo(S.dtype).eps
else:
@@ -1857,12 +1878,12 @@ def matrix_rank(M, tol=None, hermitian=False):
# Generalized inverse
-def _pinv_dispatcher(a, rcond=None):
+def _pinv_dispatcher(a, rcond=None, hermitian=None):
return (a,)
@array_function_dispatch(_pinv_dispatcher)
-def pinv(a, rcond=1e-15):
+def pinv(a, rcond=1e-15, hermitian=False):
"""
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
@@ -1882,6 +1903,12 @@ def pinv(a, rcond=1e-15):
Singular values smaller (in modulus) than
`rcond` * largest_singular_value (again, in modulus)
are set to zero. Broadcasts against the stack of matrices
+ hermitian : bool, optional
+ If True, `a` is assumed to be Hermitian (symmetric if real-valued),
+ enabling a more efficient method for finding singular values.
+ Defaults to False.
+
+ ..versionadded:: 1.17.0
Returns
-------
@@ -1935,7 +1962,7 @@ def pinv(a, rcond=1e-15):
res = empty(a.shape[:-2] + (n, m), dtype=a.dtype)
return wrap(res)
a = a.conjugate()
- u, s, vt = svd(a, full_matrices=False)
+ u, s, vt = svd(a, full_matrices=False, hermitian=hermitian)
# discard small singular values
cutoff = rcond[..., newaxis] * amax(s, axis=-1, keepdims=True)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index f95909cd1..831c059d0 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -633,6 +633,20 @@ class TestEig(EigCases):
assert_(isinstance(a, np.ndarray))
+class SVDBaseTests(object):
+ hermitian = False
+
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ u, s, vh = linalg.svd(x)
+ assert_equal(u.dtype, dtype)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+ assert_equal(vh.dtype, dtype)
+ s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+
+
class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):
@@ -644,32 +658,37 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
assert_(consistent_subclass(vt, a))
-class TestSVD(SVDCases):
- @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
- def test_types(self, dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- u, s, vh = linalg.svd(x)
- assert_equal(u.dtype, dtype)
- assert_equal(s.dtype, get_real_dtype(dtype))
- assert_equal(vh.dtype, dtype)
- s = linalg.svd(x, compute_uv=False)
- assert_equal(s.dtype, get_real_dtype(dtype))
-
+class TestSVD(SVDCases, SVDBaseTests):
def test_empty_identity(self):
""" Empty input should put an identity matrix in u or vh """
x = np.empty((4, 0))
- u, s, vh = linalg.svd(x, compute_uv=True)
+ u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
assert_equal(u.shape, (4, 4))
assert_equal(vh.shape, (0, 0))
assert_equal(u, np.eye(4))
x = np.empty((0, 4))
- u, s, vh = linalg.svd(x, compute_uv=True)
+ u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
assert_equal(u.shape, (0, 0))
assert_equal(vh.shape, (4, 4))
assert_equal(vh, np.eye(4))
+class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
+
+ def do(self, a, b, tags):
+ u, s, vt = linalg.svd(a, 0, hermitian=True)
+ assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
+ np.asarray(vt)),
+ rtol=get_rtol(u.dtype))
+ assert_(consistent_subclass(u, a))
+ assert_(consistent_subclass(vt, a))
+
+
+class TestSVDHermitian(SVDHermitianCases, SVDBaseTests):
+ hermitian = True
+
+
class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
# cond(x, p) for p in (None, 2, -2)
@@ -797,6 +816,20 @@ class TestPinv(PinvCases):
pass
+class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
+
+ def do(self, a, b, tags):
+ a_ginv = linalg.pinv(a, hermitian=True)
+ # `a @ a_ginv == I` does not hold if a is singular
+ dot = dot_generalized
+ assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
+ assert_(consistent_subclass(a_ginv, a))
+
+
+class TestPinvHermitian(PinvHermitianCases):
+ pass
+
+
class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):