diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2014-03-24 10:20:28 -0600 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2014-03-24 10:20:28 -0600 |
commit | 3a4030c650a0510b8e673f34464f4ef64212b022 (patch) | |
tree | 1a3d33ac59bec9acf2fcbb657c0a9b79f71e93ae /numpy/lib | |
parent | a0bbdcfda546a112b36353f8b5c6dd2c0e07f916 (diff) | |
parent | 123b319be37f01e3c4f2e42552d4ca121b27ca38 (diff) | |
download | numpy-3a4030c650a0510b8e673f34464f4ef64212b022.tar.gz |
Merge pull request #4358 from seberg/fast-select
ENH: Speed improvements and deprecations for np.select
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/function_base.py | 90 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 63 |
2 files changed, 121 insertions, 32 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index edce15776..df5876715 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -771,29 +771,68 @@ def select(condlist, choicelist, default=0): array([ 0, 1, 2, 0, 0, 0, 36, 49, 64, 81]) """ - n = len(condlist) - n2 = len(choicelist) - if n2 != n: + # Check the size of condlist and choicelist are the same, or abort. + if len(condlist) != len(choicelist): raise ValueError( - "list of cases must be same length as list of conditions") - choicelist = [default] + choicelist - S = 0 - pfac = 1 - for k in range(1, n+1): - S += k * pfac * asarray(condlist[k-1]) - if k < n: - pfac *= (1-asarray(condlist[k-1])) - # handle special case of a 1-element condition but - # a multi-element choice - if type(S) in ScalarType or max(asarray(S).shape) == 1: - pfac = asarray(1) - for k in range(n2+1): - pfac = pfac + asarray(choicelist[k]) - if type(S) in ScalarType: - S = S*ones(asarray(pfac).shape, type(S)) - else: - S = S*ones(asarray(pfac).shape, S.dtype) - return choose(S, tuple(choicelist)) + 'list of cases must be same length as list of conditions') + + # Now that the dtype is known, handle the deprecated select([], []) case + if len(condlist) == 0: + warnings.warn("select with an empty condition list is not possible" + "and will be deprecated", + DeprecationWarning) + return np.asarray(default)[()] + + choicelist = [np.asarray(choice) for choice in choicelist] + choicelist.append(np.asarray(default)) + + # need to get the result type before broadcasting for correct scalar + # behaviour + dtype = np.result_type(*choicelist) + + # Convert conditions to arrays and broadcast conditions and choices + # as the shape is needed for the result. Doing it seperatly optimizes + # for example when all choices are scalars. + condlist = np.broadcast_arrays(*condlist) + choicelist = np.broadcast_arrays(*choicelist) + + # If cond array is not an ndarray in boolean format or scalar bool, abort. + deprecated_ints = False + for i in range(len(condlist)): + cond = condlist[i] + if cond.dtype.type is not np.bool_: + if np.issubdtype(cond.dtype, np.integer): + # A previous implementation accepted int ndarrays accidentally. + # Supported here deliberately, but deprecated. + condlist[i] = condlist[i].astype(bool) + deprecated_ints = True + else: + raise ValueError( + 'invalid entry in choicelist: should be boolean ndarray') + + if deprecated_ints: + msg = "select condlists containing integer ndarrays is deprecated " \ + "and will be removed in the future. Use `.astype(bool)` to " \ + "convert to bools." + warnings.warn(msg, DeprecationWarning) + + if choicelist[0].ndim == 0: + # This may be common, so avoid the call. + result_shape = condlist[0].shape + else: + result_shape = np.broadcast_arrays(condlist[0], choicelist[0])[0].shape + + result = np.full(result_shape, choicelist[-1], dtype) + + # Use np.copyto to burn each choicelist array onto result, using the + # corresponding condlist as a boolean mask. This is done in reverse + # order since the first choice should take precedence. + choicelist = choicelist[-2::-1] + condlist = condlist[::-1] + for choice, cond in zip(choicelist, condlist): + np.copyto(result, choice, where=cond) + + return result def copy(a, order='K'): @@ -3240,7 +3279,7 @@ def meshgrid(*xi, **kwargs): Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector fields over N-D grids, given one-dimensional coordinate arrays x1, x2,..., xn. - + .. versionchanged:: 1.9 1-D and 0-D cases are allowed. @@ -3291,9 +3330,8 @@ def meshgrid(*xi, **kwargs): for i in range(nx): for j in range(ny): # treat xv[j,i], yv[j,i] - - In the 1-D and 0-D case, the indexing and sparse keywords have no - effect. + + In the 1-D and 0-D case, the indexing and sparse keywords have no effect. See Also -------- diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 9a26ce5a3..399a5a308 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -150,6 +150,13 @@ class TestAverage(TestCase): class TestSelect(TestCase): + choices = [np.array([1, 2, 3]), + np.array([4, 5, 6]), + np.array([7, 8, 9])] + conditions = [np.array([False, False, False]), + np.array([False, True, False]), + np.array([False, False, True])] + def _select(self, cond, values, default=0): output = [] for m in range(len(cond)): @@ -157,18 +164,62 @@ class TestSelect(TestCase): return output def test_basic(self): - choices = [np.array([1, 2, 3]), - np.array([4, 5, 6]), - np.array([7, 8, 9])] - conditions = [np.array([0, 0, 0]), - np.array([0, 1, 0]), - np.array([0, 0, 1])] + choices = self.choices + conditions = self.conditions assert_array_equal(select(conditions, choices, default=15), self._select(conditions, choices, default=15)) assert_equal(len(choices), 3) assert_equal(len(conditions), 3) + def test_broadcasting(self): + conditions = [np.array(True), np.array([False, True, False])] + choices = [1, np.arange(12).reshape(4, 3)] + assert_array_equal(select(conditions, choices), np.ones((4, 3))) + # default can broadcast too: + assert_equal(select([True], [0], default=[0]).shape, (1,)) + + def test_return_dtype(self): + assert_equal(select(self.conditions, self.choices, 1j).dtype, + np.complex_) + # But the conditions need to be stronger then the scalar default + # if it is scalar. + choices = [choice.astype(np.int8) for choice in self.choices] + assert_equal(select(self.conditions, choices).dtype, np.int8) + + d = np.array([1, 2, 3, np.nan, 5, 7]) + m = np.isnan(d) + assert_equal(select([m], [d]), [0, 0, 0, np.nan, 0, 0]) + + def test_deprecated_empty(self): + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + assert_equal(select([], [], 3j), 3j) + + with warnings.catch_warnings(): + warnings.simplefilter("always") + assert_warns(DeprecationWarning, select, [], []) + warnings.simplefilter("error") + assert_raises(DeprecationWarning, select, [], []) + + def test_non_bool_deprecation(self): + choices = self.choices + conditions = self.conditions[:] + with warnings.catch_warnings(): + warnings.filterwarnings("always") + conditions[0] = conditions[0].astype(np.int_) + assert_warns(DeprecationWarning, select, conditions, choices) + conditions[0] = conditions[0].astype(np.uint8) + assert_warns(DeprecationWarning, select, conditions, choices) + warnings.filterwarnings("error") + assert_raises(DeprecationWarning, select, conditions, choices) + + def test_many_arguments(self): + # This used to be limited by NPY_MAXARGS == 32 + conditions = [np.array([False])] * 100 + choices = [np.array([1])] * 100 + select(conditions, choices) + class TestInsert(TestCase): def test_basic(self): |