diff options
Diffstat (limited to 'numpy/linalg/tests/test_linalg.py')
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index aedcc6a95..7c577d86f 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -556,7 +556,12 @@ class TestCondSVD(LinalgTestCase, LinalgGeneralizedTestCase): def do(self, a, b): c = asarray(a) # a might be a matrix s = linalg.svd(c, compute_uv=False) - old_assert_almost_equal(s[0] / s[-1], linalg.cond(a), decimal=5) + old_assert_almost_equal( + s[..., 0] / s[..., -1], linalg.cond(a), decimal=5) + + def test_stacked_arrays_explicitly(self): + A = np.array([[1., 2., 1.], [0, -2., 0], [6., 2., 3.]]) + assert_equal(linalg.cond(A), linalg.cond(A[None, ...])[0]) class TestCond2(LinalgTestCase): @@ -564,7 +569,12 @@ class TestCond2(LinalgTestCase): def do(self, a, b): c = asarray(a) # a might be a matrix s = linalg.svd(c, compute_uv=False) - old_assert_almost_equal(s[0] / s[-1], linalg.cond(a, 2), decimal=5) + old_assert_almost_equal( + s[..., 0] / s[..., -1], linalg.cond(a, 2), decimal=5) + + def test_stacked_arrays_explicitly(self): + A = np.array([[1., 2., 1.], [0, -2., 0], [6., 2., 3.]]) + assert_equal(linalg.cond(A, 2), linalg.cond(A[None, ...], 2)[0]) class TestCondInf(object): |