summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index e97338106..2bb11bf1e 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -7,7 +7,7 @@ import numpy.core.numeric as _nx
from numpy.core.numeric import (
asarray, ScalarType, array, alltrue, cumprod, arange
)
-from numpy.core.numerictypes import find_common_type
+from numpy.core.numerictypes import find_common_type, issubdtype
from . import function_base
import numpy.matrixlib as matrix
@@ -71,17 +71,15 @@ def ix_(*args):
"""
out = []
nd = len(args)
- baseshape = [1]*nd
- for k in range(nd):
- new = _nx.asarray(args[k])
+ for k, new in enumerate(args):
+ # Explicitly type empty sequences to avoid float default
+ new = asarray(new, dtype=None if new else _nx.intp)
if (new.ndim != 1):
raise ValueError("Cross index must be 1 dimensional")
- if issubclass(new.dtype.type, _nx.bool_):
- new = new.nonzero()[0]
- baseshape[k] = len(new)
- new = new.reshape(tuple(baseshape))
+ if issubdtype(new.dtype, _nx.bool_):
+ new, = new.nonzero()
+ new.shape = (1,)*k + (new.size,) + (1,)*(nd-k-1)
out.append(new)
- baseshape[k] = 1
return tuple(out)
class nd_grid(object):