summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/linalg/tests/test_linalg.py43
1 files changed, 23 insertions, 20 deletions
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index 29e1f3480..ca59aa566 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -882,26 +882,29 @@ class _TestNorm(object):
# Matrix norms.
B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
-
- for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro', 'nuc']:
- assert_almost_equal(norm(A, ord=order), norm(A, ord=order,
- axis=(0, 1)))
-
- n = norm(B, ord=order, axis=(1, 2))
- expected = [norm(B[k], ord=order) for k in range(B.shape[0])]
- assert_almost_equal(n, expected)
-
- n = norm(B, ord=order, axis=(2, 1))
- expected = [norm(B[k].T, ord=order) for k in range(B.shape[0])]
- assert_almost_equal(n, expected)
-
- n = norm(B, ord=order, axis=(0, 2))
- expected = [norm(B[:, k,:], ord=order) for k in range(B.shape[1])]
- assert_almost_equal(n, expected)
-
- n = norm(B, ord=order, axis=(0, 1))
- expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])]
- assert_almost_equal(n, expected)
+ nd = B.ndim
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
+ for axis in itertools.combinations(range(-nd, nd), 2):
+ row_axis, col_axis = axis
+ if row_axis < 0:
+ row_axis += nd
+ if col_axis < 0:
+ col_axis += nd
+ if row_axis == col_axis:
+ assert_raises(ValueError, norm, B, ord=order, axis=axis)
+ else:
+ n = norm(B, ord=order, axis=axis)
+
+ # The logic using k_index only works for nd = 3.
+ # This has to be changed if nd is increased.
+ k_index = nd - (row_axis + col_axis)
+ if row_axis < col_axis:
+ expected = [norm(B[:].take(k, axis=k_index), ord=order)
+ for k in range(B.shape[k_index])]
+ else:
+ expected = [norm(B[:].take(k, axis=k_index).T, ord=order)
+ for k in range(B.shape[k_index])]
+ assert_almost_equal(n, expected)
def test_keepdims(self):
A = np.arange(1,25, dtype=self.dt).reshape(2,3,4)