diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2017-10-16 09:36:32 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-10-16 09:36:32 -0600 |
commit | 36e716adf3647d2c6ec7703a52a39f38718cee78 (patch) | |
tree | 5cb5c2e44eb8a12cdca32ea319d3631831f5396b /numpy/core/numeric.py | |
parent | 075b162fbbcde754bef4ce711fb118789df6e026 (diff) | |
parent | fb168b8a5ee222ff352d20bfc1efab9009d68347 (diff) | |
download | numpy-36e716adf3647d2c6ec7703a52a39f38718cee78.tar.gz |
Merge pull request #9849 from eric-wieser/cleanup-count_nonzero
MAINT: Fix all special-casing of dtypes in `count_nonzero`
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 32 |
1 files changed, 6 insertions, 26 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 5b10361fe..6d29785da 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -411,33 +411,13 @@ def count_nonzero(a, axis=None): a = asanyarray(a) - if a.dtype == bool: - return a.sum(axis=axis, dtype=np.intp) - - if issubdtype(a.dtype, np.number): - return (a != 0).sum(axis=axis, dtype=np.intp) - - if issubdtype(a.dtype, np.character): - nullstr = a.dtype.type('') - return (a != nullstr).sum(axis=axis, dtype=np.intp) - - axis = asarray(normalize_axis_tuple(axis, a.ndim)) - counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) - - if axis.size == 1: - return counts.astype(np.intp, copy=False) + # TODO: this works around .astype(bool) not working properly (gh-9847) + if np.issubdtype(a.dtype, np.character): + a_bool = a != a.dtype.type() else: - # for subsequent axis numbers, that number decreases - # by one in this new 'counts' array if it was larger - # than the first axis upon which 'count_nonzero' was - # applied but remains unchanged if that number was - # smaller than that first axis - # - # this trick enables us to perform counts on object-like - # elements across multiple axes very quickly because integer - # addition is very well optimized - return counts.sum(axis=tuple(axis[1:] - ( - axis[1:] > axis[0])), dtype=np.intp) + a_bool = a.astype(np.bool_, copy=False) + + return a_bool.sum(axis=axis, dtype=np.intp) def asarray(a, dtype=None, order=None): |