diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2015-09-25 17:54:24 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2015-09-25 18:05:28 +0200 |
commit | ae56c58db4207bd11100a9d24c9edf7694e34d67 (patch) | |
tree | ba9886cc3ef9c3d48553d3dfa31ed134853dc047 | |
parent | 1765438b5f68eeb5c9b920e8df2760dc8e908cae (diff) | |
download | numpy-ae56c58db4207bd11100a9d24c9edf7694e34d67.tar.gz |
BUG,ENH: allow linalg.cond to work on a stack of matrices
This was buggy, because the underlying functions supported it
partially but cond was not aware of this.
Closes gh-6351
-rw-r--r-- | numpy/linalg/linalg.py | 14 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 14 |
2 files changed, 19 insertions, 9 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index a2405c180..f5cb3cb77 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1012,9 +1012,9 @@ def eig(a): w : (..., M) array The eigenvalues, each repeated according to its multiplicity. The eigenvalues are not necessarily ordered. The resulting - array will be of complex type, unless the imaginary part is - zero in which case it will be cast to a real type. When `a` - is real the resulting eigenvalues will be real (0 imaginary + array will be of complex type, unless the imaginary part is + zero in which case it will be cast to a real type. When `a` + is real the resulting eigenvalues will be real (0 imaginary part) or occur in conjugate pairs v : (..., M, M) array @@ -1382,7 +1382,7 @@ def cond(x, p=None): Parameters ---------- - x : (M, N) array_like + x : (..., M, N) array_like The matrix whose condition number is sought. p : {None, 1, -1, 2, -2, inf, -inf, 'fro'}, optional Order of the norm: @@ -1451,12 +1451,12 @@ def cond(x, p=None): 0.70710678118654746 """ - x = asarray(x) # in case we have a matrix + x = asarray(x) # in case we have a matrix if p is None: s = svd(x, compute_uv=False) - return s[0]/s[-1] + return s[..., 0]/s[..., -1] else: - return norm(x, p)*norm(inv(x), p) + return norm(x, p, axis=(-2, -1)) * norm(inv(x), p, axis=(-2, -1)) def matrix_rank(M, tol=None): diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index aedcc6a95..7c577d86f 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -556,7 +556,12 @@ class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase): def do(self, a, b): c = asarray(a) # a might be a matrix s = linalg.svd(c, compute_uv=False) - old_assert_almost_equal(s[0] / s[-1], linalg.cond(a), decimal=5) + old_assert_almost_equal( + s[..., 0] / s[..., -1], linalg.cond(a), decimal=5) + + def test_stacked_arrays_explicitly(self): + A = np.array([[1., 2., 1.], [0, -2., 0], [6., 2., 3.]]) + assert_equal(linalg.cond(A), linalg.cond(A[None, ...])[0]) class TestCond2(LinalgTestCase): @@ -564,7 +569,12 @@ class TestCond2(LinalgTestCase): def do(self, a, b): c = asarray(a) # a might be a matrix s = linalg.svd(c, compute_uv=False) - old_assert_almost_equal(s[0] / s[-1], linalg.cond(a, 2), decimal=5) + old_assert_almost_equal( + s[..., 0] / s[..., -1], linalg.cond(a, 2), decimal=5) + + def test_stacked_arrays_explicitly(self): + A = np.array([[1., 2., 1.], [0, -2., 0], [6., 2., 3.]]) + assert_equal(linalg.cond(A, 2), linalg.cond(A[None, ...], 2)[0]) class TestCondInf(object): |