summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorJaime Fernandez <jaime.frio@gmail.com>2015-04-28 22:42:50 -0700
committerJaime Fernandez <jaime.frio@gmail.com>2015-05-04 20:43:01 -0700
commitc01165f43068fea96722c172eb23efed4ca99763 (patch)
tree72096386212fa2b1c8d101e848ad26fa702433ce /numpy/lib
parentf06b1210d7171b4a452d0c9c67cde7b1a130303e (diff)
downloadnumpy-c01165f43068fea96722c172eb23efed4ca99763.tar.gz
BUG: Fix handling of non-empty ndarrays
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/index_tricks.py8
-rw-r--r--numpy/lib/tests/test_index_tricks.py20
2 files changed, 17 insertions, 11 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 2bb11bf1e..752407f18 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -72,10 +72,12 @@ def ix_(*args):
out = []
nd = len(args)
for k, new in enumerate(args):
- # Explicitly type empty sequences to avoid float default
- new = asarray(new, dtype=None if new else _nx.intp)
- if (new.ndim != 1):
+ new = asarray(new)
+ if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
+ if new.size == 0:
+ # Explicitly type empty arrays to avoid float default
+ new = new.astype(_nx.intp)
if issubdtype(new.dtype, _nx.bool_):
new, = new.nonzero()
new.shape = (1,)*k + (new.size,) + (1,)*(nd-k-1)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index fc3b90900..0e3c98ee1 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -171,17 +171,21 @@ class TestIndexExpression(TestCase):
class TestIx_(TestCase):
def test_regression_1(self):
- # Empty inputs create ouputs of indexing type, gh-5804
- a, = np.ix_(range(0, 0))
- assert_equal(a.dtype, np.intp)
+ # Test empty inputs create ouputs of indexing type, gh-5804
+ # Test both lists and arrays
+ for func in (range, np.arange):
+ a, = np.ix_(func(0))
+ assert_equal(a.dtype, np.intp)
def test_shape_and_dtype(self):
sizes = (4, 5, 3, 2)
- arrays = np.ix_(*[range(sz) for sz in sizes])
- for k, (a, sz) in enumerate(zip(arrays, sizes)):
- assert_equal(a.shape[k], sz)
- assert_(all(sh == 1 for j, sh in enumerate(a.shape) if j != k))
- assert_(np.issubdtype(a.dtype, int))
+ # Test both lists and arrays
+ for func in (range, np.arange):
+ arrays = np.ix_(*[func(sz) for sz in sizes])
+ for k, (a, sz) in enumerate(zip(arrays, sizes)):
+ assert_equal(a.shape[k], sz)
+ assert_(all(sh == 1 for j, sh in enumerate(a.shape) if j != k))
+ assert_(np.issubdtype(a.dtype, int))
def test_bool(self):
bool_a = [True, False, True, True]