diff options
author | Eric Moore <ewm@redtetrahedron.org> | 2014-10-16 11:59:03 -0400 |
---|---|---|
committer | Eric Moore <ewm@redtetrahedron.org> | 2014-10-17 10:50:43 -0400 |
commit | 9b152aaaf3abd0f98b7a88ed66df7518a6a6c85b (patch) | |
tree | 7d6cec3fd0573068497a478ec0522dc4ccb602b4 /numpy/linalg/linalg.py | |
parent | 51f0976c1ca101a01d09e26ee5dfea5360f73c63 (diff) | |
download | numpy-9b152aaaf3abd0f98b7a88ed66df7518a6a6c85b.tar.gz |
ENH: Add keepdims to linalg.norm
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r-- | numpy/linalg/linalg.py | 45 |
1 files changed, 30 insertions, 15 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 6b2299fe7..e8f7a8ab1 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -1921,7 +1921,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op): return result -def norm(x, ord=None, axis=None): +def norm(x, ord=None, axis=None, keepdims=False): """ Matrix or vector norm. @@ -1942,6 +1942,11 @@ def norm(x, ord=None, axis=None): axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If `axis` is None then either a vector norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned. + keepdims : bool, optional + .. versionadded:: 1.10.0 + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `x`. Returns ------- @@ -2053,12 +2058,16 @@ def norm(x, ord=None, axis=None): # Check the default case first and handle it immediately. if ord is None and axis is None: + ndim = x.ndim x = x.ravel(order='K') if isComplexType(x.dtype.type): sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag) else: sqnorm = dot(x, x) - return sqrt(sqnorm) + ret = sqrt(sqnorm) + if keepdims: + ret = ret.reshape(ndim*[1]) + return ret # Normalize the `axis` argument to a tuple. nd = x.ndim @@ -2069,19 +2078,19 @@ def norm(x, ord=None, axis=None): if len(axis) == 1: if ord == Inf: - return abs(x).max(axis=axis) + return abs(x).max(axis=axis, keepdims=keepdims) elif ord == -Inf: - return abs(x).min(axis=axis) + return abs(x).min(axis=axis, keepdims=keepdims) elif ord == 0: # Zero norm - return (x != 0).sum(axis=axis) + return (x != 0).sum(axis=axis, keepdims=keepdims) elif ord == 1: # special case for speedup - return add.reduce(abs(x), axis=axis) + return add.reduce(abs(x), axis=axis, keepdims=keepdims) elif ord is None or ord == 2: # special case for speedup s = (x.conj() * x).real - return sqrt(add.reduce(s, axis=axis)) + return sqrt(add.reduce(s, axis=axis, keepdims=keepdims)) else: try: ord + 1 @@ -2100,7 +2109,7 @@ def norm(x, ord=None, axis=None): # if the type changed, we can safely overwrite absx abs(absx, out=absx) absx **= ord - return add.reduce(absx, axis=axis) ** (1.0 / ord) + return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord) elif len(axis) == 2: row_axis, col_axis = axis if not (-nd <= row_axis < nd and -nd <= col_axis < nd): @@ -2109,28 +2118,34 @@ def norm(x, ord=None, axis=None): if row_axis % nd == col_axis % nd: raise ValueError('Duplicate axes given.') if ord == 2: - return _multi_svd_norm(x, row_axis, col_axis, amax) + ret = _multi_svd_norm(x, row_axis, col_axis, amax) elif ord == -2: - return _multi_svd_norm(x, row_axis, col_axis, amin) + ret = _multi_svd_norm(x, row_axis, col_axis, amin) elif ord == 1: if col_axis > row_axis: col_axis -= 1 - return add.reduce(abs(x), axis=row_axis).max(axis=col_axis) + ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis) elif ord == Inf: if row_axis > col_axis: row_axis -= 1 - return add.reduce(abs(x), axis=col_axis).max(axis=row_axis) + ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis) elif ord == -1: if col_axis > row_axis: col_axis -= 1 - return add.reduce(abs(x), axis=row_axis).min(axis=col_axis) + ret = add.reduce(abs(x), axis=row_axis).min(axis=col_axis) elif ord == -Inf: if row_axis > col_axis: row_axis -= 1 - return add.reduce(abs(x), axis=col_axis).min(axis=row_axis) + ret = add.reduce(abs(x), axis=col_axis).min(axis=row_axis) elif ord in [None, 'fro', 'f']: - return sqrt(add.reduce((x.conj() * x).real, axis=axis)) + ret = sqrt(add.reduce((x.conj() * x).real, axis=axis)) else: raise ValueError("Invalid norm order for matrices.") + if keepdims: + ret_shape = list(x.shape) + ret_shape[axis[0]] = 1 + ret_shape[axis[1]] = 1 + ret = ret.reshape(ret_shape) + return ret else: raise ValueError("Improper number of dimensions to norm.") |