summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2008-04-27 18:19:12 +0000
committerCharles Harris <charlesr.harris@gmail.com>2008-04-27 18:19:12 +0000
commit8d915a55c5ecbca15ebaf13584b0c255d22768a1 (patch)
treed52d327f96ba75883d48ec3e11470ebb4d68435c /numpy
parent82909417beb52eea0568fa04f54188ada227d66e (diff)
downloadnumpy-8d915a55c5ecbca15ebaf13584b0c255d22768a1.tar.gz
Add tests for matrix return types.
Fix cond computations for matrices. lstsq is currently broken for matrices, will fix shortly.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py7
-rw-r--r--numpy/linalg/tests/test_linalg.py44
2 files changed, 42 insertions, 9 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index d31ea527e..a557af875 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -27,7 +27,7 @@ from numpy.core import array, asarray, zeros, empty, transpose, \
isfinite, size
from numpy.lib import triu
from numpy.linalg import lapack_lite
-from numpy.core.defmatrix import matrix_power
+from numpy.core.defmatrix import matrix_power, matrix
fortran_int = intc
@@ -983,7 +983,7 @@ def svd(a, full_matrices=1, compute_uv=1):
else:
return wrap(s)
-def cond(x,p=None):
+def cond(x, p=None):
"""Compute the condition number of a matrix.
The condition number of x is the norm of x times the norm
@@ -1014,6 +1014,7 @@ def cond(x,p=None):
c : float
The condition number of the matrix. May be infinite.
"""
+ x = asarray(x) # in case we have a matrix
if p is None:
s = svd(x,compute_uv=False)
return s[0]/s[-1]
@@ -1146,7 +1147,7 @@ def lstsq(a, b, rcond=-1):
"""
import math
- a = asarray(a)
+ a = _makearray(a)
b, wrap = _makearray(b)
one_eq = len(b.shape) == 1
if one_eq:
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index b8a169a58..9a6e8bfc9 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -3,13 +3,19 @@
from numpy.testing import *
set_package_path()
-from numpy import array, single, double, csingle, cdouble, dot, identity, \
- multiply, atleast_2d, inf, asarray
+from numpy import array, single, double, csingle, cdouble, dot, identity
+from numpy import multiply, atleast_2d, inf, asarray, matrix
from numpy import linalg
from linalg import matrix_power
restore_path()
+def ifthen(a, b):
+ return not a or b
+
old_assert_almost_equal = assert_almost_equal
+def imply(a, b):
+ return not a or b
+
def assert_almost_equal(a, b, **kw):
if asarray(a).dtype.type in (single, csingle):
decimal = 6
@@ -52,41 +58,63 @@ class LinalgTestCase(NumpyTestCase):
b = [2, 1]
self.do(a,b)
+ def check_matrix_b_only(self):
+ """Check that matrix type is preserved."""
+ a = array([[1.,2.], [3.,4.]])
+ b = matrix([2., 1.]).T
+ self.do(a, b)
+
+ def check_matrix_a_and_b(self):
+ """Check that matrix type is preserved."""
+ a = matrix([[1.,2.], [3.,4.]])
+ b = matrix([2., 1.]).T
+ self.do(a, b)
+
class TestSolve(LinalgTestCase):
def do(self, a, b):
x = linalg.solve(a, b)
assert_almost_equal(b, dot(a, x))
+ assert imply(isinstance(b, matrix), isinstance(x, matrix))
class TestInv(LinalgTestCase):
def do(self, a, b):
a_inv = linalg.inv(a)
assert_almost_equal(dot(a, a_inv), identity(asarray(a).shape[0]))
+ assert imply(isinstance(a, matrix), isinstance(a_inv, matrix))
class TestEigvals(LinalgTestCase):
def do(self, a, b):
ev = linalg.eigvals(a)
evalues, evectors = linalg.eig(a)
assert_almost_equal(ev, evalues)
+ assert imply(isinstance(a, matrix), isinstance(ev, matrix))
class TestEig(LinalgTestCase):
def do(self, a, b):
evalues, evectors = linalg.eig(a)
- assert_almost_equal(dot(a, evectors), evectors*evalues)
+ assert_almost_equal(dot(a, evectors), multiply(evectors, evalues))
+ assert imply(isinstance(a, matrix), isinstance(evalues, matrix))
+ assert imply(isinstance(a, matrix), isinstance(evectors, matrix))
class TestSVD(LinalgTestCase):
def do(self, a, b):
u, s, vt = linalg.svd(a, 0)
- assert_almost_equal(a, dot(u*s, vt))
+ assert_almost_equal(a, dot(multiply(u, s), vt))
+ assert imply(isinstance(a, matrix), isinstance(u, matrix))
+ assert imply(isinstance(a, matrix), isinstance(s, matrix))
+ assert imply(isinstance(a, matrix), isinstance(vt, matrix))
class TestCondSVD(LinalgTestCase):
def do(self, a, b):
- s = linalg.svd(a, compute_uv=False)
+ c = asarray(a) # a might be a matrix
+ s = linalg.svd(c, compute_uv=False)
old_assert_almost_equal(s[0]/s[-1], linalg.cond(a), decimal=5)
class TestCond2(LinalgTestCase):
def do(self, a, b):
- s = linalg.svd(a, compute_uv=False)
+ c = asarray(a) # a might be a matrix
+ s = linalg.svd(c, compute_uv=False)
old_assert_almost_equal(s[0]/s[-1], linalg.cond(a,2), decimal=5)
class TestCondInf(NumpyTestCase):
@@ -98,6 +126,7 @@ class TestPinv(LinalgTestCase):
def do(self, a, b):
a_ginv = linalg.pinv(a)
assert_almost_equal(dot(a, a_ginv), identity(asarray(a).shape[0]))
+ assert imply(isinstance(a, matrix), isinstance(a_ginv, matrix))
class TestDet(LinalgTestCase):
def do(self, a, b):
@@ -116,6 +145,9 @@ class TestLstsq(LinalgTestCase):
assert_almost_equal(b, dot(a, x))
assert_equal(rank, asarray(a).shape[0])
assert_almost_equal(sv, s)
+ assert imply(isinstance(b, matrix), isinstance(x, matrix))
+ assert imply(isinstance(b, matrix), isinstance(residuals, matrix))
+ assert imply(isinstance(b, matrix), isinstance(sv, matrix))
class TestMatrixPower(ParametricTestCase):
R90 = array([[0,1],[-1,0]])