summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_shape_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-02-28 14:03:27 +0000
committerEric Wieser <wieser.eric@gmail.com>2018-05-25 22:55:58 -0700
commit905e906d55fdcb8cc215de8aa287ea9654d1c95c (patch)
treeba7d36006b940a87d1dff93a1a8965f7f986c2ab /numpy/lib/tests/test_shape_base.py
parent84f582f25e168dbfd59b3be470bc8ebc46ee2d92 (diff)
downloadnumpy-905e906d55fdcb8cc215de8aa287ea9654d1c95c.tar.gz
ENH: Add (put|take)_along_axis as described in #8708
This is the reduced version that does not allow any insertion of extra dimensions
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r--numpy/lib/tests/test_shape_base.py92
1 files changed, 91 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index a35d90b70..c95894f94 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -2,16 +2,106 @@ from __future__ import division, absolute_import, print_function
import numpy as np
import warnings
+import functools
from numpy.lib.shape_base import (
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
- vsplit, dstack, column_stack, kron, tile, expand_dims,
+ vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis,
+ put_along_axis
)
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_raises, assert_warns
)
+def _add_keepdims(func):
+ """ hack in keepdims behavior into a function taking an axis """
+ @functools.wraps(func)
+ def wrapped(a, axis, **kwargs):
+ res = func(a, axis=axis, **kwargs)
+ if axis is None:
+ axis = 0 # res is now a scalar, so we can insert this anywhere
+ return np.expand_dims(res, axis=axis)
+ return wrapped
+
+
+class TestTakeAlongAxis(object):
+ def test_argequivalent(self):
+ """ Test it translates from arg<func> to <func> """
+ from numpy.random import rand
+ a = rand(3, 4, 5)
+
+ funcs = [
+ (np.sort, np.argsort, dict()),
+ (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()),
+ (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()),
+ (np.partition, np.argpartition, dict(kth=2)),
+ ]
+
+ for func, argfunc, kwargs in funcs:
+ for axis in list(range(a.ndim)) + [None]:
+ a_func = func(a, axis=axis, **kwargs)
+ ai_func = argfunc(a, axis=axis, **kwargs)
+ assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))
+
+ def test_invalid(self):
+ """ Test it errors when indices has too few dimensions """
+ a = np.ones((10, 10))
+ ai = np.ones((10, 2), dtype=np.intp)
+
+ # sanity check
+ take_along_axis(a, ai, axis=1)
+
+ # not enough indices
+ assert_raises(ValueError, take_along_axis, a, np.array(1), axis=1)
+ # bool arrays not allowed
+ assert_raises(IndexError, take_along_axis, a, ai.astype(bool), axis=1)
+ # float arrays not allowed
+ assert_raises(IndexError, take_along_axis, a, ai.astype(float), axis=1)
+ # invalid axis
+ assert_raises(np.AxisError, take_along_axis, a, ai, axis=10)
+
+ def test_empty(self):
+ """ Test everything is ok with empty results, even with inserted dims """
+ a = np.ones((3, 4, 5))
+ ai = np.ones((3, 0, 5), dtype=np.intp)
+
+ actual = take_along_axis(a, ai, axis=1)
+ assert_equal(actual.shape, ai.shape)
+
+ def test_broadcast(self):
+ """ Test that non-indexing dimensions are broadcast in both directions """
+ a = np.ones((3, 4, 1))
+ ai = np.ones((1, 2, 5), dtype=np.intp)
+ actual = take_along_axis(a, ai, axis=1)
+ assert_equal(actual.shape, (3, 2, 5))
+
+
+class TestPutAlongAxis(object):
+ def test_replace_max(self):
+ a_base = np.array([[10, 30, 20], [60, 40, 50]])
+
+ for axis in list(range(a_base.ndim)) + [None]:
+ # we mutate this in the loop
+ a = a_base.copy()
+
+ # replace the max with a small value
+ i_max = _add_keepdims(np.argmax)(a, axis=axis)
+ put_along_axis(a, i_max, -99, axis=axis)
+
+ # find the new minimum, which should max
+ i_min = _add_keepdims(np.argmin)(a, axis=axis)
+
+ assert_equal(i_min, i_max)
+
+ def test_broadcast(self):
+ """ Test that non-indexing dimensions are broadcast in both directions """
+ a = np.ones((3, 4, 1))
+ ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4
+ put_along_axis(a, ai, 20, axis=1)
+ assert_equal(take_along_axis(a, ai, axis=1), 20)
+
+
class TestApplyAlongAxis(object):
def test_simple(self):
a = np.ones((20, 10), 'd')