diff options
author | Pauli Virtanen <pav@iki.fi> | 2013-04-12 18:51:04 +0300 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2013-04-12 19:00:06 +0300 |
commit | fb9b5bd2d71c92f64f151f812b4cee08a971eb2e (patch) | |
tree | c25258c33ff98a6183138628f5ef8ba5220ee241 | |
parent | aa8fde0f62a133319cfac8e8da208fcd8e224ef1 (diff) | |
download | numpy-fb9b5bd2d71c92f64f151f812b4cee08a971eb2e.tar.gz |
TST: linalg: test return types of generalized linalg routines
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index af9bf884d..84c95af10 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -27,6 +27,14 @@ def assert_almost_equal(a, b, **kw): decimal = 12 old_assert_almost_equal(a, b, decimal=decimal, **kw) +def get_real_dtype(dtype): + return {single: single, double: double, + csingle: single, cdouble: double}[dtype] + +def get_complex_dtype(dtype): + return {single: csingle, double: cdouble, + csingle: csingle, cdouble: cdouble}[dtype] + class LinalgTestCase(object): def test_single(self): @@ -85,6 +93,7 @@ class LinalgTestCase(object): b = matrix([2., 1.]).T self.do(a, b) + class LinalgNonsquareTestCase(object): def test_single_nsq_1(self): a = array([[1.,2.,3.], [3.,4.,6.]], dtype=single) @@ -187,6 +196,13 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): assert_almost_equal(b, dot_generalized(a, x)) assert_(imply(isinstance(b, matrix), isinstance(x, matrix))) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.solve(x, x).dtype, dtype) + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -195,6 +211,13 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): identity_like_generalized(a)) assert_(imply(isinstance(a, matrix), isinstance(a_inv, matrix))) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.inv(x).dtype, dtype) + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -202,6 +225,15 @@ class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(linalg.eigvals(x).dtype, dtype) + x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) + assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype)) + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class TestEig(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -212,6 +244,21 @@ class TestEig(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): assert_almost_equal(dot(a, evectors), multiply(evectors, evalues)) assert_(imply(isinstance(a, matrix), isinstance(evectors, matrix))) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w.dtype, dtype) + assert_equal(v.dtype, dtype) + + x = np.array([[1, 0.5], [-1, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w.dtype, get_complex_dtype(dtype)) + assert_equal(v.dtype, get_complex_dtype(dtype)) + + for dtype in [single, double, csingle, cdouble]: + yield dtype + class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -223,6 +270,19 @@ class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): assert_(imply(isinstance(a, matrix), isinstance(u, matrix))) assert_(imply(isinstance(a, matrix), isinstance(vt, matrix))) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + u, s, vh = linalg.svd(x) + assert_equal(u.dtype, dtype) + assert_equal(s.dtype, get_real_dtype(dtype)) + assert_equal(vh.dtype, dtype) + s = linalg.svd(x, compute_uv=False) + assert_equal(s.dtype, get_real_dtype(dtype)) + + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): def do(self, a, b): @@ -282,6 +342,13 @@ class TestDet(LinalgTestCase, LinalgGeneralizedTestCase, TestCase): assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble) assert_equal(type(linalg.slogdet([[0.0j]])[1]), double) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(np.linalg.det(x), get_real_dtype(dtype)) + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class TestLstsq(LinalgTestCase, LinalgNonsquareTestCase, TestCase): def do(self, a, b): @@ -418,6 +485,13 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase): evalues.sort(axis=-1) assert_almost_equal(ev, evalues) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + assert_equal(np.linalg.eigvalsh(x), get_real_dtype(dtype)) + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase): def do(self, a, b): @@ -429,6 +503,15 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase): evalues.sort(axis=-1) assert_almost_equal(ev, evalues) + def test_types(self): + def check(dtype): + x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype) + w, v = np.linalg.eig(x) + assert_equal(w, get_real_dtype(dtype)) + assert_equal(v, dtype) + for dtype in [single, double, csingle, cdouble]: + yield check, dtype + class _TestNorm(TestCase): dt = None |