summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2014-03-24 10:20:28 -0600
committerCharles Harris <charlesr.harris@gmail.com>2014-03-24 10:20:28 -0600
commit3a4030c650a0510b8e673f34464f4ef64212b022 (patch)
tree1a3d33ac59bec9acf2fcbb657c0a9b79f71e93ae /numpy/lib
parenta0bbdcfda546a112b36353f8b5c6dd2c0e07f916 (diff)
parent123b319be37f01e3c4f2e42552d4ca121b27ca38 (diff)
downloadnumpy-3a4030c650a0510b8e673f34464f4ef64212b022.tar.gz
Merge pull request #4358 from seberg/fast-select
ENH: Speed improvements and deprecations for np.select
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py90
-rw-r--r--numpy/lib/tests/test_function_base.py63
2 files changed, 121 insertions, 32 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index edce15776..df5876715 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -771,29 +771,68 @@ def select(condlist, choicelist, default=0):
array([ 0, 1, 2, 0, 0, 0, 36, 49, 64, 81])
"""
- n = len(condlist)
- n2 = len(choicelist)
- if n2 != n:
+ # Check the size of condlist and choicelist are the same, or abort.
+ if len(condlist) != len(choicelist):
raise ValueError(
- "list of cases must be same length as list of conditions")
- choicelist = [default] + choicelist
- S = 0
- pfac = 1
- for k in range(1, n+1):
- S += k * pfac * asarray(condlist[k-1])
- if k < n:
- pfac *= (1-asarray(condlist[k-1]))
- # handle special case of a 1-element condition but
- # a multi-element choice
- if type(S) in ScalarType or max(asarray(S).shape) == 1:
- pfac = asarray(1)
- for k in range(n2+1):
- pfac = pfac + asarray(choicelist[k])
- if type(S) in ScalarType:
- S = S*ones(asarray(pfac).shape, type(S))
- else:
- S = S*ones(asarray(pfac).shape, S.dtype)
- return choose(S, tuple(choicelist))
+ 'list of cases must be same length as list of conditions')
+
+ # Now that the dtype is known, handle the deprecated select([], []) case
+ if len(condlist) == 0:
+ warnings.warn("select with an empty condition list is not possible"
+ "and will be deprecated",
+ DeprecationWarning)
+ return np.asarray(default)[()]
+
+ choicelist = [np.asarray(choice) for choice in choicelist]
+ choicelist.append(np.asarray(default))
+
+ # need to get the result type before broadcasting for correct scalar
+ # behaviour
+ dtype = np.result_type(*choicelist)
+
+ # Convert conditions to arrays and broadcast conditions and choices
+ # as the shape is needed for the result. Doing it seperatly optimizes
+ # for example when all choices are scalars.
+ condlist = np.broadcast_arrays(*condlist)
+ choicelist = np.broadcast_arrays(*choicelist)
+
+ # If cond array is not an ndarray in boolean format or scalar bool, abort.
+ deprecated_ints = False
+ for i in range(len(condlist)):
+ cond = condlist[i]
+ if cond.dtype.type is not np.bool_:
+ if np.issubdtype(cond.dtype, np.integer):
+ # A previous implementation accepted int ndarrays accidentally.
+ # Supported here deliberately, but deprecated.
+ condlist[i] = condlist[i].astype(bool)
+ deprecated_ints = True
+ else:
+ raise ValueError(
+ 'invalid entry in choicelist: should be boolean ndarray')
+
+ if deprecated_ints:
+ msg = "select condlists containing integer ndarrays is deprecated " \
+ "and will be removed in the future. Use `.astype(bool)` to " \
+ "convert to bools."
+ warnings.warn(msg, DeprecationWarning)
+
+ if choicelist[0].ndim == 0:
+ # This may be common, so avoid the call.
+ result_shape = condlist[0].shape
+ else:
+ result_shape = np.broadcast_arrays(condlist[0], choicelist[0])[0].shape
+
+ result = np.full(result_shape, choicelist[-1], dtype)
+
+ # Use np.copyto to burn each choicelist array onto result, using the
+ # corresponding condlist as a boolean mask. This is done in reverse
+ # order since the first choice should take precedence.
+ choicelist = choicelist[-2::-1]
+ condlist = condlist[::-1]
+ for choice, cond in zip(choicelist, condlist):
+ np.copyto(result, choice, where=cond)
+
+ return result
def copy(a, order='K'):
@@ -3240,7 +3279,7 @@ def meshgrid(*xi, **kwargs):
Make N-D coordinate arrays for vectorized evaluations of
N-D scalar/vector fields over N-D grids, given
one-dimensional coordinate arrays x1, x2,..., xn.
-
+
.. versionchanged:: 1.9
1-D and 0-D cases are allowed.
@@ -3291,9 +3330,8 @@ def meshgrid(*xi, **kwargs):
for i in range(nx):
for j in range(ny):
# treat xv[j,i], yv[j,i]
-
- In the 1-D and 0-D case, the indexing and sparse keywords have no
- effect.
+
+ In the 1-D and 0-D case, the indexing and sparse keywords have no effect.
See Also
--------
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 9a26ce5a3..399a5a308 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -150,6 +150,13 @@ class TestAverage(TestCase):
class TestSelect(TestCase):
+ choices = [np.array([1, 2, 3]),
+ np.array([4, 5, 6]),
+ np.array([7, 8, 9])]
+ conditions = [np.array([False, False, False]),
+ np.array([False, True, False]),
+ np.array([False, False, True])]
+
def _select(self, cond, values, default=0):
output = []
for m in range(len(cond)):
@@ -157,18 +164,62 @@ class TestSelect(TestCase):
return output
def test_basic(self):
- choices = [np.array([1, 2, 3]),
- np.array([4, 5, 6]),
- np.array([7, 8, 9])]
- conditions = [np.array([0, 0, 0]),
- np.array([0, 1, 0]),
- np.array([0, 0, 1])]
+ choices = self.choices
+ conditions = self.conditions
assert_array_equal(select(conditions, choices, default=15),
self._select(conditions, choices, default=15))
assert_equal(len(choices), 3)
assert_equal(len(conditions), 3)
+ def test_broadcasting(self):
+ conditions = [np.array(True), np.array([False, True, False])]
+ choices = [1, np.arange(12).reshape(4, 3)]
+ assert_array_equal(select(conditions, choices), np.ones((4, 3)))
+ # default can broadcast too:
+ assert_equal(select([True], [0], default=[0]).shape, (1,))
+
+ def test_return_dtype(self):
+ assert_equal(select(self.conditions, self.choices, 1j).dtype,
+ np.complex_)
+ # But the conditions need to be stronger then the scalar default
+ # if it is scalar.
+ choices = [choice.astype(np.int8) for choice in self.choices]
+ assert_equal(select(self.conditions, choices).dtype, np.int8)
+
+ d = np.array([1, 2, 3, np.nan, 5, 7])
+ m = np.isnan(d)
+ assert_equal(select([m], [d]), [0, 0, 0, np.nan, 0, 0])
+
+ def test_deprecated_empty(self):
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ assert_equal(select([], [], 3j), 3j)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("always")
+ assert_warns(DeprecationWarning, select, [], [])
+ warnings.simplefilter("error")
+ assert_raises(DeprecationWarning, select, [], [])
+
+ def test_non_bool_deprecation(self):
+ choices = self.choices
+ conditions = self.conditions[:]
+ with warnings.catch_warnings():
+ warnings.filterwarnings("always")
+ conditions[0] = conditions[0].astype(np.int_)
+ assert_warns(DeprecationWarning, select, conditions, choices)
+ conditions[0] = conditions[0].astype(np.uint8)
+ assert_warns(DeprecationWarning, select, conditions, choices)
+ warnings.filterwarnings("error")
+ assert_raises(DeprecationWarning, select, conditions, choices)
+
+ def test_many_arguments(self):
+ # This used to be limited by NPY_MAXARGS == 32
+ conditions = [np.array([False])] * 100
+ choices = [np.array([1])] * 100
+ select(conditions, choices)
+
class TestInsert(TestCase):
def test_basic(self):