summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2008-04-27 15:27:30 +0000
committerCharles Harris <charlesr.harris@gmail.com>2008-04-27 15:27:30 +0000
commit82909417beb52eea0568fa04f54188ada227d66e (patch)
treeda3776d2b98db789c5a082c217d194f13d9f1c91
parenta5b626c2915f3af4897a5a5bd5e29f815810b5c8 (diff)
downloadnumpy-82909417beb52eea0568fa04f54188ada227d66e.tar.gz
Make functions in linalg.py accept nestes lists.
Use wrap to keep matrix environment intact. Base patch from nmb.
-rw-r--r--numpy/linalg/linalg.py91
-rw-r--r--numpy/linalg/tests/test_linalg.py22
2 files changed, 69 insertions, 44 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index f56005292..d31ea527e 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -148,10 +148,10 @@ def tensorsolve(a, b, axes=None):
Parameters
----------
- a : array, shape b.shape+Q
+ a : array-like, shape b.shape+Q
Coefficient tensor. Shape Q of the rightmost indices of a must
be such that a is 'square', ie., prod(Q) == prod(b.shape).
- b : array, any shape
+ b : array-like, any shape
Right-hand tensor.
axes : tuple of integers
Axes in a to reorder to the right, before inversion.
@@ -192,7 +192,7 @@ def tensorsolve(a, b, axes=None):
a = a.reshape(-1, prod)
b = b.ravel()
- res = solve(a, b)
+ res = wrap(solve(a, b))
res.shape = oldshape
return res
@@ -201,8 +201,8 @@ def solve(a, b):
Parameters
----------
- a : array, shape (M, M)
- b : array, shape (M,)
+ a : array-like, shape (M, M)
+ b : array-like, shape (M,)
Returns
-------
@@ -211,6 +211,8 @@ def solve(a, b):
Raises LinAlgError if a is singular or not square
"""
+ a, _ = _makearray(a)
+ b, wrap = _makearray(b)
one_eq = len(b.shape) == 1
if one_eq:
b = b[:, newaxis]
@@ -232,9 +234,9 @@ def solve(a, b):
if results['info'] > 0:
raise LinAlgError, 'Singular matrix'
if one_eq:
- return b.ravel().astype(result_t)
+ return wrap(b.ravel().astype(result_t))
else:
- return b.transpose().astype(result_t)
+ return wrap(b.transpose().astype(result_t))
def tensorinv(a, ind=2):
@@ -250,7 +252,7 @@ def tensorinv(a, ind=2):
Parameters
----------
- a : array
+ a : array-like
Tensor to 'invert'. Its shape must 'square', ie.,
prod(a.shape[:ind]) == prod(a.shape[ind:])
ind : integer > 0
@@ -340,12 +342,12 @@ def cholesky(a):
Parameters
----------
- a : array, shape (M, M)
+ a : array-like, shape (M, M)
Matrix to be decomposed
Returns
-------
- L : array, shape (M, M)
+ L : array-like, shape (M, M)
Lower-triangular Cholesky factor of A
Raises LinAlgError if decomposition fails
@@ -363,6 +365,7 @@ def cholesky(a):
[ 0.+2.j, 5.+0.j]])
"""
+ a, wrap = _makearray(a)
_assertRank2(a)
_assertSquareness(a)
t, result_t = _commonType(a)
@@ -379,8 +382,8 @@ def cholesky(a):
Cholesky decomposition cannot be computed'
s = triu(a, k=0).transpose()
if (s.dtype != result_t):
- return s.astype(result_t)
- return s
+ s = s.astype(result_t)
+ return wrap(s)
# QR decompostion
@@ -392,7 +395,7 @@ def qr(a, mode='full'):
Parameters
----------
- a : array, shape (M, N)
+ a : array-like, shape (M, N)
Matrix to be decomposed
mode : {'full', 'r', 'economic'}
Determines what information is to be returned. 'full' is the default.
@@ -413,6 +416,8 @@ def qr(a, mode='full'):
The diagonal and the upper triangle of A2 contains R,
while the rest of the matrix is undefined.
+ If a is a matrix, so are all the return values.
+
Raises LinAlgError if decomposition fails
Notes
@@ -435,6 +440,7 @@ def qr(a, mode='full'):
True
"""
+ a, wrap = _makearray(a)
_assertRank2(a)
m, n = a.shape
t, result_t = _commonType(a)
@@ -503,7 +509,7 @@ def qr(a, mode='full'):
q = _fastCopyAndTranspose(result_t, a[:mn,:])
- return q, r
+ return wrap(q), wrap(r)
# Eigenvalues
@@ -514,7 +520,7 @@ def eigvals(a):
Parameters
----------
- a : array, shape (M, M)
+ a : array-like, shape (M, M)
A complex or real matrix whose eigenvalues and eigenvectors
will be computed.
@@ -525,6 +531,8 @@ def eigvals(a):
They are not necessarily ordered, nor are they necessarily
real for real matrices.
+ If a is a matrix, so is w.
+
Raises LinAlgError if eigenvalue computation does not converge
See Also
@@ -545,6 +553,7 @@ def eigvals(a):
determinant and I is the identity matrix.
"""
+ a, wrap = _makearray(a)
_assertRank2(a)
_assertSquareness(a)
_assertFinite(a)
@@ -585,7 +594,7 @@ def eigvals(a):
result_t = _complexType(result_t)
if results['info'] > 0:
raise LinAlgError, 'Eigenvalues did not converge'
- return w.astype(result_t)
+ return wrap(w.astype(result_t))
def eigvalsh(a, UPLO='L'):
@@ -593,7 +602,7 @@ def eigvalsh(a, UPLO='L'):
Parameters
----------
- a : array, shape (M, M)
+ a : array-like, shape (M, M)
A complex or real matrix whose eigenvalues and eigenvectors
will be computed.
UPLO : {'L', 'U'}
@@ -627,6 +636,7 @@ def eigvalsh(a, UPLO='L'):
determinant and I is the identity matrix.
"""
+ a, wrap = _makearray(a)
_assertRank2(a)
_assertSquareness(a)
t, result_t = _commonType(a)
@@ -663,7 +673,7 @@ def eigvalsh(a, UPLO='L'):
iwork, liwork, 0)
if results['info'] > 0:
raise LinAlgError, 'Eigenvalues did not converge'
- return w.astype(result_t)
+ return wrap(w.astype(result_t))
def _convertarray(a):
t, result_t = _commonType(a)
@@ -679,7 +689,7 @@ def eig(a):
Parameters
----------
- a : array, shape (M, M)
+ a : array-like, shape (M, M)
A complex or real 2-d array whose eigenvalues and eigenvectors
will be computed.
@@ -693,6 +703,8 @@ def eig(a):
The normalized eigenvector corresponding to the eigenvalue w[i] is
the column v[:,i].
+ If a is a matrix, so are all the return values.
+
Raises LinAlgError if eigenvalue computation does not converge
See Also
@@ -774,7 +786,7 @@ def eig(a):
if results['info'] > 0:
raise LinAlgError, 'Eigenvalues did not converge'
vt = v.transpose().astype(result_t)
- return w.astype(result_t), wrap(vt)
+ return wrap(w.astype(result_t)), wrap(vt)
def eigh(a, UPLO='L'):
@@ -782,7 +794,7 @@ def eigh(a, UPLO='L'):
Parameters
----------
- a : array, shape (M, M)
+ a : array-like, shape (M, M)
A complex Hermitian or symmetric real matrix whose eigenvalues
and eigenvectors will be computed.
UPLO : {'L', 'U'}
@@ -798,6 +810,8 @@ def eigh(a, UPLO='L'):
The normalized eigenvector corresponding to the eigenvalue w[i] is
the column v[:,i].
+ If a is a matrix, then so are the return values.
+
Raises LinAlgError if eigenvalue computation does not converge
See Also
@@ -858,7 +872,7 @@ def eigh(a, UPLO='L'):
if results['info'] > 0:
raise LinAlgError, 'Eigenvalues did not converge'
at = a.transpose().astype(result_t)
- return w.astype(_realType(result_t)), wrap(at)
+ return wrap(w.astype(_realType(result_t))), wrap(at)
# Singular value decomposition
@@ -873,13 +887,13 @@ def svd(a, full_matrices=1, compute_uv=1):
Parameters
----------
- a : array, shape (M, N)
+ a : array-like, shape (M, N)
Matrix to decompose
full_matrices : boolean
If true, U, Vh are shaped (M,M), (N,N)
If false, the shapes are (M,K), (K,N) where K = min(M,N)
compute_uv : boolean
- Whether to compute also U, Vh in addition to s
+ Whether to compute U and Vh in addition to s
Returns
-------
@@ -889,7 +903,7 @@ def svd(a, full_matrices=1, compute_uv=1):
K = min(M, N)
Vh: array, shape (N,N) or (K,N) depending on full_matrices
- For compute_uv = False, only s is returned.
+ If a is a matrix, so are all the return values.
Raises LinAlgError if SVD computation does not converge
@@ -965,9 +979,9 @@ def svd(a, full_matrices=1, compute_uv=1):
if compute_uv:
u = u.transpose().astype(result_t)
vt = vt.transpose().astype(result_t)
- return wrap(u), s, wrap(vt)
+ return wrap(u), wrap(s), wrap(vt)
else:
- return s
+ return wrap(s)
def cond(x,p=None):
"""Compute the condition number of a matrix.
@@ -978,7 +992,7 @@ def cond(x,p=None):
Parameters
----------
- x : array, shape (M, N)
+ x : array-like, shape (M, N)
The matrix whose condition number is sought.
p : {None, 1, -1, 2, -2, inf, -inf, 'fro'}
Order of the norm:
@@ -1017,7 +1031,7 @@ def pinv(a, rcond=1e-15 ):
Parameters
----------
- a : array, shape (M, N)
+ a : array-like, shape (M, N)
Matrix to be pseudo-inverted
rcond : float
Cutoff for 'small' singular values.
@@ -1027,6 +1041,7 @@ def pinv(a, rcond=1e-15 ):
Returns
-------
B : array, shape (N, M)
+ If a is a matrix, then so is B.
Raises LinAlgError if SVD computation does not converge
@@ -1053,8 +1068,8 @@ def pinv(a, rcond=1e-15 ):
s[i] = 1./s[i]
else:
s[i] = 0.;
- return wrap(dot(transpose(vt),
- multiply(s[:, newaxis],transpose(u))))
+ res = dot(transpose(vt), multiply(s[:, newaxis],transpose(u)))
+ return wrap(res)
# Determinant
@@ -1063,7 +1078,7 @@ def det(a):
Parameters
----------
- a : array, shape (M, M)
+ a : array-like, shape (M, M)
Returns
-------
@@ -1104,8 +1119,8 @@ def lstsq(a, b, rcond=-1):
Parameters
----------
- a : array, shape (M, N)
- b : array, shape (M,) or (M, K)
+ a : array-like, shape (M, N)
+ b : array-like, shape (M,) or (M, K)
rcond : float
Cutoff for 'small' singular values.
Singular values smaller than rcond*largest_singular_value are
@@ -1125,6 +1140,10 @@ def lstsq(a, b, rcond=-1):
Rank of matrix a
s : array, shape (min(M,N),)
Singular values of a
+
+ If b is a matrix, then all results except the rank are also returned as
+ matrices.
+
"""
import math
a = asarray(a)
@@ -1188,14 +1207,14 @@ def lstsq(a, b, rcond=-1):
if results['rank'] == n and m > n:
resids = sum((transpose(bstar)[n:,:])**2, axis=0).astype(result_t)
st = s[:min(n, m)].copy().astype(_realType(result_t))
- return wrap(x), resids, results['rank'], st
+ return wrap(x), wrap(resids), results['rank'], wrap(st)
def norm(x, ord=None):
"""Matrix or vector norm.
Parameters
----------
- x : array, shape (M,) or (M, N)
+ x : array-like, shape (M,) or (M, N)
ord : number, or {None, 1, -1, 2, -2, inf, -inf, 'fro'}
Order of the norm:
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 7d2390980..b8a169a58 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -4,14 +4,14 @@
from numpy.testing import *
set_package_path()
from numpy import array, single, double, csingle, cdouble, dot, identity, \
- multiply, atleast_2d, inf
+ multiply, atleast_2d, inf, asarray
from numpy import linalg
from linalg import matrix_power
restore_path()
old_assert_almost_equal = assert_almost_equal
def assert_almost_equal(a, b, **kw):
- if a.dtype.type in (single, csingle):
+ if asarray(a).dtype.type in (single, csingle):
decimal = 6
else:
decimal = 12
@@ -47,6 +47,12 @@ class LinalgTestCase(NumpyTestCase):
except linalg.LinAlgError, e:
pass
+ def check_nonarray(self):
+ a = [[1,2], [3,4]]
+ b = [2, 1]
+ self.do(a,b)
+
+
class TestSolve(LinalgTestCase):
def do(self, a, b):
x = linalg.solve(a, b)
@@ -55,7 +61,7 @@ class TestSolve(LinalgTestCase):
class TestInv(LinalgTestCase):
def do(self, a, b):
a_inv = linalg.inv(a)
- assert_almost_equal(dot(a, a_inv), identity(a.shape[0]))
+ assert_almost_equal(dot(a, a_inv), identity(asarray(a).shape[0]))
class TestEigvals(LinalgTestCase):
def do(self, a, b):
@@ -91,15 +97,15 @@ class TestCondInf(NumpyTestCase):
class TestPinv(LinalgTestCase):
def do(self, a, b):
a_ginv = linalg.pinv(a)
- assert_almost_equal(dot(a, a_ginv), identity(a.shape[0]))
+ assert_almost_equal(dot(a, a_ginv), identity(asarray(a).shape[0]))
class TestDet(LinalgTestCase):
def do(self, a, b):
d = linalg.det(a)
- if a.dtype.type in (single, double):
- ad = a.astype(double)
+ if asarray(a).dtype.type in (single, double):
+ ad = asarray(a).astype(double)
else:
- ad = a.astype(cdouble)
+ ad = asarray(a).astype(cdouble)
ev = linalg.eigvals(ad)
assert_almost_equal(d, multiply.reduce(ev))
@@ -108,7 +114,7 @@ class TestLstsq(LinalgTestCase):
u, s, vt = linalg.svd(a, 0)
x, residuals, rank, sv = linalg.lstsq(a, b)
assert_almost_equal(b, dot(a, x))
- assert_equal(rank, a.shape[0])
+ assert_equal(rank, asarray(a).shape[0])
assert_almost_equal(sv, s)
class TestMatrixPower(ParametricTestCase):