summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorRittaNarita <narittan@gmail.com>2015-03-29 01:31:49 +0900
committerRittaNarita <narittan@gmail.com>2015-03-30 01:05:27 +0900
commit61caa22481e43da734cd964d19e36de471c2ca60 (patch)
treee57c8357e2e256706b66bf60a72b158e2f030a90 /numpy
parentfa372af5e82637cccd494f044517b6e48b9ca456 (diff)
downloadnumpy-61caa22481e43da734cd964d19e36de471c2ca60.tar.gz
TST: Make the test for linalg matrix norms coverage complete
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/tests/test_linalg.py43
1 files changed, 23 insertions, 20 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 29e1f3480..ca59aa566 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -882,26 +882,29 @@ class _TestNorm(object):
# Matrix norms.
B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
-
- for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro', 'nuc']:
- assert_almost_equal(norm(A, ord=order), norm(A, ord=order,
- axis=(0, 1)))
-
- n = norm(B, ord=order, axis=(1, 2))
- expected = [norm(B[k], ord=order) for k in range(B.shape[0])]
- assert_almost_equal(n, expected)
-
- n = norm(B, ord=order, axis=(2, 1))
- expected = [norm(B[k].T, ord=order) for k in range(B.shape[0])]
- assert_almost_equal(n, expected)
-
- n = norm(B, ord=order, axis=(0, 2))
- expected = [norm(B[:, k,:], ord=order) for k in range(B.shape[1])]
- assert_almost_equal(n, expected)
-
- n = norm(B, ord=order, axis=(0, 1))
- expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])]
- assert_almost_equal(n, expected)
+ nd = B.ndim
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
+ for axis in itertools.combinations(range(-nd, nd), 2):
+ row_axis, col_axis = axis
+ if row_axis < 0:
+ row_axis += nd
+ if col_axis < 0:
+ col_axis += nd
+ if row_axis == col_axis:
+ assert_raises(ValueError, norm, B, ord=order, axis=axis)
+ else:
+ n = norm(B, ord=order, axis=axis)
+
+ # The logic using k_index only works for nd = 3.
+ # This has to be changed if nd is increased.
+ k_index = nd - (row_axis + col_axis)
+ if row_axis < col_axis:
+ expected = [norm(B[:].take(k, axis=k_index), ord=order)
+ for k in range(B.shape[k_index])]
+ else:
+ expected = [norm(B[:].take(k, axis=k_index).T, ord=order)
+ for k in range(B.shape[k_index])]
+ assert_almost_equal(n, expected)
def test_keepdims(self):
A = np.arange(1,25, dtype=self.dt).reshape(2,3,4)