summaryrefslogtreecommitdiff
path: root/numpy/linalg/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/linalg/tests')
-rw-r--r--numpy/linalg/tests/test_linalg.py59
1 files changed, 46 insertions, 13 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index f95909cd1..831c059d0 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -633,6 +633,20 @@ class TestEig(EigCases):
assert_(isinstance(a, np.ndarray))
+class SVDBaseTests(object):
+ hermitian = False
+
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
+ def test_types(self, dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ u, s, vh = linalg.svd(x)
+ assert_equal(u.dtype, dtype)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+ assert_equal(vh.dtype, dtype)
+ s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+
+
class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):
@@ -644,32 +658,37 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
assert_(consistent_subclass(vt, a))
-class TestSVD(SVDCases):
- @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
- def test_types(self, dtype):
- x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
- u, s, vh = linalg.svd(x)
- assert_equal(u.dtype, dtype)
- assert_equal(s.dtype, get_real_dtype(dtype))
- assert_equal(vh.dtype, dtype)
- s = linalg.svd(x, compute_uv=False)
- assert_equal(s.dtype, get_real_dtype(dtype))
-
+class TestSVD(SVDCases, SVDBaseTests):
def test_empty_identity(self):
""" Empty input should put an identity matrix in u or vh """
x = np.empty((4, 0))
- u, s, vh = linalg.svd(x, compute_uv=True)
+ u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
assert_equal(u.shape, (4, 4))
assert_equal(vh.shape, (0, 0))
assert_equal(u, np.eye(4))
x = np.empty((0, 4))
- u, s, vh = linalg.svd(x, compute_uv=True)
+ u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
assert_equal(u.shape, (0, 0))
assert_equal(vh.shape, (4, 4))
assert_equal(vh, np.eye(4))
+class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
+
+ def do(self, a, b, tags):
+ u, s, vt = linalg.svd(a, 0, hermitian=True)
+ assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
+ np.asarray(vt)),
+ rtol=get_rtol(u.dtype))
+ assert_(consistent_subclass(u, a))
+ assert_(consistent_subclass(vt, a))
+
+
+class TestSVDHermitian(SVDHermitianCases, SVDBaseTests):
+ hermitian = True
+
+
class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
# cond(x, p) for p in (None, 2, -2)
@@ -797,6 +816,20 @@ class TestPinv(PinvCases):
pass
+class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
+
+ def do(self, a, b, tags):
+ a_ginv = linalg.pinv(a, hermitian=True)
+ # `a @ a_ginv == I` does not hold if a is singular
+ dot = dot_generalized
+ assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
+ assert_(consistent_subclass(a_ginv, a))
+
+
+class TestPinvHermitian(PinvHermitianCases):
+ pass
+
+
class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
def do(self, a, b, tags):