summaryrefslogtreecommitdiff
path: root/numpy/lib/function_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/function_base.py')
-rw-r--r--numpy/lib/function_base.py32
1 files changed, 12 insertions, 20 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 8953eebd5..00bfab6ba 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -651,7 +651,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
The output is the same shape and type as x and is found by
calling the functions in `funclist` on the appropriate portions of `x`,
as defined by the boolean arrays in `condlist`. Portions not covered
- by any condition have undefined values.
+ by any condition have a default value of 0.
See Also
@@ -693,32 +693,24 @@ def piecewise(x, condlist, funclist, *args, **kw):
if (isscalar(condlist) or not (isinstance(condlist[0], list) or
isinstance(condlist[0], ndarray))):
condlist = [condlist]
- condlist = [asarray(c, dtype=bool) for c in condlist]
+ condlist = array(condlist, dtype=bool)
n = len(condlist)
- if n == n2 - 1: # compute the "otherwise" condition.
- totlist = condlist[0]
- for k in range(1, n):
- totlist |= condlist[k]
- condlist.append(~totlist)
- n += 1
- if (n != n2):
- raise ValueError(
- "function list and condition list must be the same")
- zerod = False
# This is a hack to work around problems with NumPy's
# handling of 0-d arrays and boolean indexing with
# numpy.bool_ scalars
+ zerod = False
if x.ndim == 0:
x = x[None]
zerod = True
- newcondlist = []
- for k in range(n):
- if condlist[k].ndim == 0:
- condition = condlist[k][None]
- else:
- condition = condlist[k]
- newcondlist.append(condition)
- condlist = newcondlist
+ if condlist.shape[-1] != 1:
+ condlist = condlist.T
+ if n == n2 - 1: # compute the "otherwise" condition.
+ totlist = np.logical_or.reduce(condlist, axis=0)
+ condlist = np.vstack([condlist, ~totlist])
+ n += 1
+ if (n != n2):
+ raise ValueError(
+ "function list and condition list must be the same")
y = zeros(x.shape, x.dtype)
for k in range(n):