diff options
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 4c4e51748..760ca31b0 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2029,16 +2029,15 @@ def norm(x, ord=None, axis=None): # Check the default case first and handle it immediately. if ord is None and axis is None: - s = (x.conj() * x).real - return sqrt(add.reduce((x.conj() * x).ravel().real)) + return sqrt(add.reduce((x.conj() * x).real, axis=None)) # Normalize the `axis` argument to a tuple. + nd = x.ndim if axis is None: - axis = tuple(range(x.ndim)) + axis = tuple(range(nd)) elif not isinstance(axis, tuple): axis = (axis,) - nd = x.ndim if len(axis) == 1: if ord == Inf: return abs(x).max(axis=axis) @@ -2067,11 +2066,10 @@ def norm(x, ord=None, axis=None): return add.reduce(absx**ord, axis=axis)**(1.0/ord) elif len(axis) == 2: row_axis, col_axis = axis - if not (-x.ndim <= row_axis < x.ndim and - -x.ndim <= col_axis < x.ndim): + if not (-nd <= row_axis < nd and -nd <= col_axis < nd): raise ValueError('Invalid axis %r for an array with shape %r' % (axis, x.shape)) - if row_axis % x.ndim == col_axis % x.ndim: + if row_axis % nd == col_axis % nd: raise ValueError('Duplicate axes given.') if ord == 2: return _multi_svd_norm(x, row_axis, col_axis, amax) |