diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2008-04-27 18:19:12 +0000 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2008-04-27 18:19:12 +0000 |
commit | 8d915a55c5ecbca15ebaf13584b0c255d22768a1 (patch) | |
tree | d52d327f96ba75883d48ec3e11470ebb4d68435c /numpy | |
parent | 82909417beb52eea0568fa04f54188ada227d66e (diff) | |
download | numpy-8d915a55c5ecbca15ebaf13584b0c255d22768a1.tar.gz |
Add tests for matrix return types.
Fix cond computations for matrices.
lstsq is currently broken for matrices, will fix shortly.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 7 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 44 |
2 files changed, 42 insertions, 9 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index d31ea527e..a557af875 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -27,7 +27,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \ isfinite, size from numpy.lib import triu from numpy.linalg import lapack_lite -from numpy.core.defmatrix import matrix_power +from numpy.core.defmatrix import matrix_power, matrix fortran_int = intc @@ -983,7 +983,7 @@ def svd(a, full_matrices=1, compute_uv=1): else: return wrap(s) -def cond(x,p=None): +def cond(x, p=None): """Compute the condition number of a matrix. The condition number of x is the norm of x times the norm @@ -1014,6 +1014,7 @@ def cond(x,p=None): c : float The condition number of the matrix. May be infinite. """ + x = asarray(x) # in case we have a matrix if p is None: s = svd(x,compute_uv=False) return s[0]/s[-1] @@ -1146,7 +1147,7 @@ def lstsq(a, b, rcond=-1): """ import math - a = asarray(a) + a = _makearray(a) b, wrap = _makearray(b) one_eq = len(b.shape) == 1 if one_eq: diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index b8a169a58..9a6e8bfc9 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -3,13 +3,19 @@ from numpy.testing import * set_package_path() -from numpy import array, single, double, csingle, cdouble, dot, identity, \ - multiply, atleast_2d, inf, asarray +from numpy import array, single, double, csingle, cdouble, dot, identity +from numpy import multiply, atleast_2d, inf, asarray, matrix from numpy import linalg from linalg import matrix_power restore_path() +def ifthen(a, b): + return not a or b + old_assert_almost_equal = assert_almost_equal +def imply(a, b): + return not a or b + def assert_almost_equal(a, b, **kw): if asarray(a).dtype.type in (single, csingle): decimal = 6 @@ -52,41 +58,63 @@ class LinalgTestCase(NumpyTestCase): b = [2, 1] self.do(a,b) + def check_matrix_b_only(self): + """Check that matrix type is preserved.""" + a = array([[1.,2.], [3.,4.]]) + b = matrix([2., 1.]).T + self.do(a, b) + + def check_matrix_a_and_b(self): + """Check that matrix type is preserved.""" + a = matrix([[1.,2.], [3.,4.]]) + b = matrix([2., 1.]).T + self.do(a, b) + class TestSolve(LinalgTestCase): def do(self, a, b): x = linalg.solve(a, b) assert_almost_equal(b, dot(a, x)) + assert imply(isinstance(b, matrix), isinstance(x, matrix)) class TestInv(LinalgTestCase): def do(self, a, b): a_inv = linalg.inv(a) assert_almost_equal(dot(a, a_inv), identity(asarray(a).shape[0])) + assert imply(isinstance(a, matrix), isinstance(a_inv, matrix)) class TestEigvals(LinalgTestCase): def do(self, a, b): ev = linalg.eigvals(a) evalues, evectors = linalg.eig(a) assert_almost_equal(ev, evalues) + assert imply(isinstance(a, matrix), isinstance(ev, matrix)) class TestEig(LinalgTestCase): def do(self, a, b): evalues, evectors = linalg.eig(a) - assert_almost_equal(dot(a, evectors), evectors*evalues) + assert_almost_equal(dot(a, evectors), multiply(evectors, evalues)) + assert imply(isinstance(a, matrix), isinstance(evalues, matrix)) + assert imply(isinstance(a, matrix), isinstance(evectors, matrix)) class TestSVD(LinalgTestCase): def do(self, a, b): u, s, vt = linalg.svd(a, 0) - assert_almost_equal(a, dot(u*s, vt)) + assert_almost_equal(a, dot(multiply(u, s), vt)) + assert imply(isinstance(a, matrix), isinstance(u, matrix)) + assert imply(isinstance(a, matrix), isinstance(s, matrix)) + assert imply(isinstance(a, matrix), isinstance(vt, matrix)) class TestCondSVD(LinalgTestCase): def do(self, a, b): - s = linalg.svd(a, compute_uv=False) + c = asarray(a) # a might be a matrix + s = linalg.svd(c, compute_uv=False) old_assert_almost_equal(s[0]/s[-1], linalg.cond(a), decimal=5) class TestCond2(LinalgTestCase): def do(self, a, b): - s = linalg.svd(a, compute_uv=False) + c = asarray(a) # a might be a matrix + s = linalg.svd(c, compute_uv=False) old_assert_almost_equal(s[0]/s[-1], linalg.cond(a,2), decimal=5) class TestCondInf(NumpyTestCase): @@ -98,6 +126,7 @@ class TestPinv(LinalgTestCase): def do(self, a, b): a_ginv = linalg.pinv(a) assert_almost_equal(dot(a, a_ginv), identity(asarray(a).shape[0])) + assert imply(isinstance(a, matrix), isinstance(a_ginv, matrix)) class TestDet(LinalgTestCase): def do(self, a, b): @@ -116,6 +145,9 @@ class TestLstsq(LinalgTestCase): assert_almost_equal(b, dot(a, x)) assert_equal(rank, asarray(a).shape[0]) assert_almost_equal(sv, s) + assert imply(isinstance(b, matrix), isinstance(x, matrix)) + assert imply(isinstance(b, matrix), isinstance(residuals, matrix)) + assert imply(isinstance(b, matrix), isinstance(sv, matrix)) class TestMatrixPower(ParametricTestCase): R90 = array([[0,1],[-1,0]]) |