summaryrefslogtreecommitdiff
path: root/numpy/core/numeric.py
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2017-10-16 09:36:32 -0600
committerGitHub <noreply@github.com>2017-10-16 09:36:32 -0600
commit36e716adf3647d2c6ec7703a52a39f38718cee78 (patch)
tree5cb5c2e44eb8a12cdca32ea319d3631831f5396b /numpy/core/numeric.py
parent075b162fbbcde754bef4ce711fb118789df6e026 (diff)
parentfb168b8a5ee222ff352d20bfc1efab9009d68347 (diff)
downloadnumpy-36e716adf3647d2c6ec7703a52a39f38718cee78.tar.gz
Merge pull request #9849 from eric-wieser/cleanup-count_nonzero
MAINT: Fix all special-casing of dtypes in `count_nonzero`
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r--numpy/core/numeric.py32
1 files changed, 6 insertions, 26 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 5b10361fe..6d29785da 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -411,33 +411,13 @@ def count_nonzero(a, axis=None):
a = asanyarray(a)
- if a.dtype == bool:
- return a.sum(axis=axis, dtype=np.intp)
-
- if issubdtype(a.dtype, np.number):
- return (a != 0).sum(axis=axis, dtype=np.intp)
-
- if issubdtype(a.dtype, np.character):
- nullstr = a.dtype.type('')
- return (a != nullstr).sum(axis=axis, dtype=np.intp)
-
- axis = asarray(normalize_axis_tuple(axis, a.ndim))
- counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
-
- if axis.size == 1:
- return counts.astype(np.intp, copy=False)
+ # TODO: this works around .astype(bool) not working properly (gh-9847)
+ if np.issubdtype(a.dtype, np.character):
+ a_bool = a != a.dtype.type()
else:
- # for subsequent axis numbers, that number decreases
- # by one in this new 'counts' array if it was larger
- # than the first axis upon which 'count_nonzero' was
- # applied but remains unchanged if that number was
- # smaller than that first axis
- #
- # this trick enables us to perform counts on object-like
- # elements across multiple axes very quickly because integer
- # addition is very well optimized
- return counts.sum(axis=tuple(axis[1:] - (
- axis[1:] > axis[0])), dtype=np.intp)
+ a_bool = a.astype(np.bool_, copy=False)
+
+ return a_bool.sum(axis=axis, dtype=np.intp)
def asarray(a, dtype=None, order=None):