summaryrefslogtreecommitdiff
path: root/numpy/linalg/linalg.py
diff options
context:
space:
mode:
authorEric Moore <ewm@redtetrahedron.org>2014-10-16 11:59:03 -0400
committerEric Moore <ewm@redtetrahedron.org>2014-10-17 10:50:43 -0400
commit9b152aaaf3abd0f98b7a88ed66df7518a6a6c85b (patch)
tree7d6cec3fd0573068497a478ec0522dc4ccb602b4 /numpy/linalg/linalg.py
parent51f0976c1ca101a01d09e26ee5dfea5360f73c63 (diff)
downloadnumpy-9b152aaaf3abd0f98b7a88ed66df7518a6a6c85b.tar.gz
ENH: Add keepdims to linalg.norm
Diffstat (limited to 'numpy/linalg/linalg.py')
-rw-r--r--numpy/linalg/linalg.py45
1 files changed, 30 insertions, 15 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 6b2299fe7..e8f7a8ab1 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1921,7 +1921,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op):
return result
-def norm(x, ord=None, axis=None):
+def norm(x, ord=None, axis=None, keepdims=False):
"""
Matrix or vector norm.
@@ -1942,6 +1942,11 @@ def norm(x, ord=None, axis=None):
axes that hold 2-D matrices, and the matrix norms of these matrices
are computed. If `axis` is None then either a vector norm (when `x`
is 1-D) or a matrix norm (when `x` is 2-D) is returned.
+ keepdims : bool, optional
+ .. versionadded:: 1.10.0
+ If this is set to True, the axes which are normed over are left in the
+ result as dimensions with size one. With this option the result will
+ broadcast correctly against the original `x`.
Returns
-------
@@ -2053,12 +2058,16 @@ def norm(x, ord=None, axis=None):
# Check the default case first and handle it immediately.
if ord is None and axis is None:
+ ndim = x.ndim
x = x.ravel(order='K')
if isComplexType(x.dtype.type):
sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag)
else:
sqnorm = dot(x, x)
- return sqrt(sqnorm)
+ ret = sqrt(sqnorm)
+ if keepdims:
+ ret = ret.reshape(ndim*[1])
+ return ret
# Normalize the `axis` argument to a tuple.
nd = x.ndim
@@ -2069,19 +2078,19 @@ def norm(x, ord=None, axis=None):
if len(axis) == 1:
if ord == Inf:
- return abs(x).max(axis=axis)
+ return abs(x).max(axis=axis, keepdims=keepdims)
elif ord == -Inf:
- return abs(x).min(axis=axis)
+ return abs(x).min(axis=axis, keepdims=keepdims)
elif ord == 0:
# Zero norm
- return (x != 0).sum(axis=axis)
+ return (x != 0).sum(axis=axis, keepdims=keepdims)
elif ord == 1:
# special case for speedup
- return add.reduce(abs(x), axis=axis)
+ return add.reduce(abs(x), axis=axis, keepdims=keepdims)
elif ord is None or ord == 2:
# special case for speedup
s = (x.conj() * x).real
- return sqrt(add.reduce(s, axis=axis))
+ return sqrt(add.reduce(s, axis=axis, keepdims=keepdims))
else:
try:
ord + 1
@@ -2100,7 +2109,7 @@ def norm(x, ord=None, axis=None):
# if the type changed, we can safely overwrite absx
abs(absx, out=absx)
absx **= ord
- return add.reduce(absx, axis=axis) ** (1.0 / ord)
+ return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord)
elif len(axis) == 2:
row_axis, col_axis = axis
if not (-nd <= row_axis < nd and -nd <= col_axis < nd):
@@ -2109,28 +2118,34 @@ def norm(x, ord=None, axis=None):
if row_axis % nd == col_axis % nd:
raise ValueError('Duplicate axes given.')
if ord == 2:
- return _multi_svd_norm(x, row_axis, col_axis, amax)
+ ret = _multi_svd_norm(x, row_axis, col_axis, amax)
elif ord == -2:
- return _multi_svd_norm(x, row_axis, col_axis, amin)
+ ret = _multi_svd_norm(x, row_axis, col_axis, amin)
elif ord == 1:
if col_axis > row_axis:
col_axis -= 1
- return add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
+ ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
elif ord == Inf:
if row_axis > col_axis:
row_axis -= 1
- return add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
+ ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
elif ord == -1:
if col_axis > row_axis:
col_axis -= 1
- return add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
+ ret = add.reduce(abs(x), axis=row_axis).min(axis=col_axis)
elif ord == -Inf:
if row_axis > col_axis:
row_axis -= 1
- return add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
+ ret = add.reduce(abs(x), axis=col_axis).min(axis=row_axis)
elif ord in [None, 'fro', 'f']:
- return sqrt(add.reduce((x.conj() * x).real, axis=axis))
+ ret = sqrt(add.reduce((x.conj() * x).real, axis=axis))
else:
raise ValueError("Invalid norm order for matrices.")
+ if keepdims:
+ ret_shape = list(x.shape)
+ ret_shape[axis[0]] = 1
+ ret_shape[axis[1]] = 1
+ ret = ret.reshape(ret_shape)
+ return ret
else:
raise ValueError("Improper number of dimensions to norm.")