summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorseberg <sebastian@sipsolutions.net>2013-01-23 02:09:35 -0800
committerseberg <sebastian@sipsolutions.net>2013-01-23 02:09:35 -0800
commitdce10183bc8f3d243bd5fc70140f5ad71179d05c (patch)
tree4c4d9ef759b460d571c8ac40648a9e5f1fae8cf7 /numpy
parent963c4e46dc56020ebea05bee10ceaa0feb61f022 (diff)
parentcde76b4d11a580e6d25eebdcce373bc5d8c850f5 (diff)
downloadnumpy-dce10183bc8f3d243bd5fc70140f5ad71179d05c.tar.gz
Merge pull request #2725 from seberg/take_0d
ENH: Allow 0-d indexes in np.take
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/fromnumeric.py9
-rw-r--r--numpy/core/src/multiarray/item_selection.c4
-rw-r--r--numpy/core/tests/test_item_selection.py43
3 files changed, 53 insertions, 3 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index d73f1313c..7596e3707 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -57,6 +57,10 @@ def take(a, indices, axis=None, out=None, mode='raise'):
The source array.
indices : array_like
The indices of the values to extract.
+
+ .. versionadded:: 1.8.0
+
+ Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
@@ -96,6 +100,11 @@ def take(a, indices, axis=None, out=None, mode='raise'):
>>> a[indices]
array([4, 3, 6])
+ If `indices` is not one dimensional, the output also has these dimensions.
+
+ >>> np.take(a, [[0, 1], [2, 3]])
+ array([[4, 3],
+ [5, 7]])
"""
try:
take = a.take
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index ab8480857..0da921b62 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -44,13 +44,11 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
}
indices = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
NPY_INTP,
- 1, 0);
+ 0, 0);
if (indices == NULL) {
goto fail;
}
-
-
n = m = chunk = 1;
nd = PyArray_NDIM(self) + PyArray_NDIM(indices) - 1;
for (i = 0; i < nd; i++) {
diff --git a/numpy/core/tests/test_item_selection.py b/numpy/core/tests/test_item_selection.py
new file mode 100644
index 000000000..f35e04c4f
--- /dev/null
+++ b/numpy/core/tests/test_item_selection.py
@@ -0,0 +1,43 @@
+import numpy as np
+from numpy.testing import *
+import sys, warnings
+
+def test_take():
+ a = [[1, 2], [3, 4]]
+ a_str = [['1','2'],['3','4']]
+ modes = ['raise', 'wrap', 'clip']
+ indices = [-1, 4]
+ index_arrays = [np.empty(0, dtype=np.intp),
+ np.empty(tuple(), dtype=np.intp),
+ np.empty((1,1), dtype=np.intp)]
+ real_indices = {}
+ real_indices['raise'] = {-1:1, 4:IndexError}
+ real_indices['wrap'] = {-1:1, 4:0}
+ real_indices['clip'] = {-1:0, 4:1}
+ # Currently all types but object, use the same function generation.
+ # So it should not be necessary to test all, but the code does support it.
+ types = np.int, np.object
+ for t in types:
+ ta = np.array(a if issubclass(t, np.number) else a_str, dtype=t)
+ tresult = list(ta.T.copy())
+ for index_array in index_arrays:
+ if index_array.size != 0:
+ tresult[0].shape = (2,) + index_array.shape
+ tresult[1].shape = (2,) + index_array.shape
+ for mode in modes:
+ for index in indices:
+ real_index = real_indices[mode][index]
+ if real_index is IndexError and index_array.size != 0:
+ index_array.put(0, index)
+ assert_raises(IndexError, ta.take, index_array,
+ mode=mode, axis=1)
+ elif index_array.size != 0:
+ index_array.put(0, index)
+ res = ta.take(index_array, mode=mode, axis=1)
+ assert_array_equal(res, tresult[real_index])
+ else:
+ res = ta.take(index_array, mode=mode, axis=1)
+ assert_(res.shape == (2,) + index_array.shape)
+
+if __name__ == "__main__":
+ run_module_suite()