diff options
author | Matti Picus <matti.picus@gmail.com> | 2019-01-13 10:12:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-01-13 10:12:28 +0200 |
commit | 1176ae80daf6afdd1703fdfed7754cf71f363ebb (patch) | |
tree | 40f49dbfab9dad4c90619f1b2d531bcb2c9ae8d2 /numpy/linalg/tests/test_linalg.py | |
parent | 4a16f56a7e517665c161c6f0a590149153008aab (diff) | |
parent | 9d8681b69020f1f8d63cdacc178bc858069abd4f (diff) | |
download | numpy-1176ae80daf6afdd1703fdfed7754cf71f363ebb.tar.gz |
Merge pull request #12693 from eric-wieser/gh-9436-hermitian
ENH: Add a hermitian argument to `pinv` and `svd`, matching `matrix_rank`
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 59 |
1 files changed, 46 insertions, 13 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index f95909cd1..831c059d0 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -633,6 +633,20 @@ class TestEig(EigCases): assert_(isinstance(a, np.ndarray)) +class SVDBaseTests(object): + hermitian = False + + @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) + def test_types(self, dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + u, s, vh = linalg.svd(x) + assert_equal(u.dtype, dtype) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(vh.dtype, dtype) + s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian) + assert_equal(s.dtype, get_real_dtype(dtype)) + + class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): def do(self, a, b, tags): @@ -644,32 +658,37 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): assert_(consistent_subclass(vt, a)) -class TestSVD(SVDCases): - @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble]) - def test_types(self, dtype): - x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) - u, s, vh = linalg.svd(x) - assert_equal(u.dtype, dtype) - assert_equal(s.dtype, get_real_dtype(dtype)) - assert_equal(vh.dtype, dtype) - s = linalg.svd(x, compute_uv=False) - assert_equal(s.dtype, get_real_dtype(dtype)) - +class TestSVD(SVDCases, SVDBaseTests): def test_empty_identity(self): """ Empty input should put an identity matrix in u or vh """ x = np.empty((4, 0)) - u, s, vh = linalg.svd(x, compute_uv=True) + u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian) assert_equal(u.shape, (4, 4)) assert_equal(vh.shape, (0, 0)) assert_equal(u, np.eye(4)) x = np.empty((0, 4)) - u, s, vh = linalg.svd(x, compute_uv=True) + u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian) assert_equal(u.shape, (0, 0)) assert_equal(vh.shape, (4, 4)) assert_equal(vh, np.eye(4)) +class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase): + + def do(self, a, b, tags): + u, s, vt = linalg.svd(a, 0, hermitian=True) + assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :], + np.asarray(vt)), + rtol=get_rtol(u.dtype)) + assert_(consistent_subclass(u, a)) + assert_(consistent_subclass(vt, a)) + + +class TestSVDHermitian(SVDHermitianCases, SVDBaseTests): + hermitian = True + + class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): # cond(x, p) for p in (None, 2, -2) @@ -797,6 +816,20 @@ class TestPinv(PinvCases): pass +class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase): + + def do(self, a, b, tags): + a_ginv = linalg.pinv(a, hermitian=True) + # `a @ a_ginv == I` does not hold if a is singular + dot = dot_generalized + assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11) + assert_(consistent_subclass(a_ginv, a)) + + +class TestPinvHermitian(PinvHermitianCases): + pass + + class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): def do(self, a, b, tags): |