diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/linalg/linalg.py | 8 | ||||
-rw-r--r-- | numpy/linalg/tests/test_linalg.py | 43 |
2 files changed, 29 insertions, 22 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 23b493447..30180f24a 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2146,10 +2146,14 @@ def norm(x, ord=None, axis=None, keepdims=False): return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord) elif len(axis) == 2: row_axis, col_axis = axis - if not (-nd <= row_axis < nd and -nd <= col_axis < nd): + if row_axis < 0: + row_axis += nd + if col_axis < 0: + col_axis += nd + if not (0 <= row_axis < nd and 0 <= col_axis < nd): raise ValueError('Invalid axis %r for an array with shape %r' % (axis, x.shape)) - if row_axis % nd == col_axis % nd: + if row_axis == col_axis: raise ValueError('Duplicate axes given.') if ord == 2: ret = _multi_svd_norm(x, row_axis, col_axis, amax) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 29e1f3480..ca59aa566 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -882,26 +882,29 @@ class _TestNorm(object): # Matrix norms. B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4) - - for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro', 'nuc']: - assert_almost_equal(norm(A, ord=order), norm(A, ord=order, - axis=(0, 1))) - - n = norm(B, ord=order, axis=(1, 2)) - expected = [norm(B[k], ord=order) for k in range(B.shape[0])] - assert_almost_equal(n, expected) - - n = norm(B, ord=order, axis=(2, 1)) - expected = [norm(B[k].T, ord=order) for k in range(B.shape[0])] - assert_almost_equal(n, expected) - - n = norm(B, ord=order, axis=(0, 2)) - expected = [norm(B[:, k,:], ord=order) for k in range(B.shape[1])] - assert_almost_equal(n, expected) - - n = norm(B, ord=order, axis=(0, 1)) - expected = [norm(B[:,:, k], ord=order) for k in range(B.shape[2])] - assert_almost_equal(n, expected) + nd = B.ndim + for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']: + for axis in itertools.combinations(range(-nd, nd), 2): + row_axis, col_axis = axis + if row_axis < 0: + row_axis += nd + if col_axis < 0: + col_axis += nd + if row_axis == col_axis: + assert_raises(ValueError, norm, B, ord=order, axis=axis) + else: + n = norm(B, ord=order, axis=axis) + + # The logic using k_index only works for nd = 3. + # This has to be changed if nd is increased. + k_index = nd - (row_axis + col_axis) + if row_axis < col_axis: + expected = [norm(B[:].take(k, axis=k_index), ord=order) + for k in range(B.shape[k_index])] + else: + expected = [norm(B[:].take(k, axis=k_index).T, ord=order) + for k in range(B.shape[k_index])] + assert_almost_equal(n, expected) def test_keepdims(self): A = np.arange(1,25, dtype=self.dt).reshape(2,3,4) |