diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 16 |
1 files changed, 7 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index e97338106..2bb11bf1e 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,7 @@ import numpy.core.numeric as _nx from numpy.core.numeric import ( asarray, ScalarType, array, alltrue, cumprod, arange ) -from numpy.core.numerictypes import find_common_type +from numpy.core.numerictypes import find_common_type, issubdtype from . import function_base import numpy.matrixlib as matrix @@ -71,17 +71,15 @@ def ix_(*args): """ out = [] nd = len(args) - baseshape = [1]*nd - for k in range(nd): - new = _nx.asarray(args[k]) + for k, new in enumerate(args): + # Explicitly type empty sequences to avoid float default + new = asarray(new, dtype=None if new else _nx.intp) if (new.ndim != 1): raise ValueError("Cross index must be 1 dimensional") - if issubclass(new.dtype.type, _nx.bool_): - new = new.nonzero()[0] - baseshape[k] = len(new) - new = new.reshape(tuple(baseshape)) + if issubdtype(new.dtype, _nx.bool_): + new, = new.nonzero() + new.shape = (1,)*k + (new.size,) + (1,)*(nd-k-1) out.append(new) - baseshape[k] = 1 return tuple(out) class nd_grid(object): |