summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-04-12 18:51:04 +0300
committerPauli Virtanen <pav@iki.fi>2013-04-12 19:00:06 +0300
commitfb9b5bd2d71c92f64f151f812b4cee08a971eb2e (patch)
treec25258c33ff98a6183138628f5ef8ba5220ee241
parentaa8fde0f62a133319cfac8e8da208fcd8e224ef1 (diff)
downloadnumpy-fb9b5bd2d71c92f64f151f812b4cee08a971eb2e.tar.gz
TST: linalg: test return types of generalized linalg routines
-rw-r--r--numpy/linalg/tests/test_linalg.py83
1 files changed, 83 insertions, 0 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index af9bf884d..84c95af10 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -27,6 +27,14 @@ def assert_almost_equal(a, b, **kw):
decimal = 12
old_assert_almost_equal(a, b, decimal=decimal, **kw)
+def get_real_dtype(dtype):
+ return {single: single, double: double,
+ csingle: single, cdouble: double}[dtype]
+
+def get_complex_dtype(dtype):
+ return {single: csingle, double: cdouble,
+ csingle: csingle, cdouble: cdouble}[dtype]
+
class LinalgTestCase(object):
def test_single(self):
@@ -85,6 +93,7 @@ class LinalgTestCase(object):
b = matrix([2., 1.]).T
self.do(a, b)
+
class LinalgNonsquareTestCase(object):
def test_single_nsq_1(self):
a = array([[1.,2.,3.], [3.,4.,6.]], dtype=single)
@@ -187,6 +196,13 @@ class TestSolve(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
assert_almost_equal(b, dot_generalized(a, x))
assert_(imply(isinstance(b, matrix), isinstance(x, matrix)))
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(linalg.solve(x, x).dtype, dtype)
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -195,6 +211,13 @@ class TestInv(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
identity_like_generalized(a))
assert_(imply(isinstance(a, matrix), isinstance(a_inv, matrix)))
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(linalg.inv(x).dtype, dtype)
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -202,6 +225,15 @@ class TestEigvals(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
evalues, evectors = linalg.eig(a)
assert_almost_equal(ev, evalues)
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(linalg.eigvals(x).dtype, dtype)
+ x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
+ assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class TestEig(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -212,6 +244,21 @@ class TestEig(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
assert_almost_equal(dot(a, evectors), multiply(evectors, evalues))
assert_(imply(isinstance(a, matrix), isinstance(evectors, matrix)))
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ w, v = np.linalg.eig(x)
+ assert_equal(w.dtype, dtype)
+ assert_equal(v.dtype, dtype)
+
+ x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
+ w, v = np.linalg.eig(x)
+ assert_equal(w.dtype, get_complex_dtype(dtype))
+ assert_equal(v.dtype, get_complex_dtype(dtype))
+
+ for dtype in [single, double, csingle, cdouble]:
+ yield dtype
+
class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -223,6 +270,19 @@ class TestSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
assert_(imply(isinstance(a, matrix), isinstance(u, matrix)))
assert_(imply(isinstance(a, matrix), isinstance(vt, matrix)))
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ u, s, vh = linalg.svd(x)
+ assert_equal(u.dtype, dtype)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+ assert_equal(vh.dtype, dtype)
+ s = linalg.svd(x, compute_uv=False)
+ assert_equal(s.dtype, get_real_dtype(dtype))
+
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -282,6 +342,13 @@ class TestDet(LinalgTestCase, LinalgGeneralizedTestCase, TestCase):
assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(np.linalg.det(x), get_real_dtype(dtype))
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class TestLstsq(LinalgTestCase, LinalgNonsquareTestCase, TestCase):
def do(self, a, b):
@@ -418,6 +485,13 @@ class TestEigvalsh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase):
evalues.sort(axis=-1)
assert_almost_equal(ev, evalues)
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ assert_equal(np.linalg.eigvalsh(x), get_real_dtype(dtype))
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase):
def do(self, a, b):
@@ -429,6 +503,15 @@ class TestEigh(HermitianTestCase, HermitianGeneralizedTestCase, TestCase):
evalues.sort(axis=-1)
assert_almost_equal(ev, evalues)
+ def test_types(self):
+ def check(dtype):
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
+ w, v = np.linalg.eig(x)
+ assert_equal(w, get_real_dtype(dtype))
+ assert_equal(v, dtype)
+ for dtype in [single, double, csingle, cdouble]:
+ yield check, dtype
+
class _TestNorm(TestCase):
dt = None