diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-02-28 14:03:27 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-05-25 22:55:58 -0700 |
commit | 905e906d55fdcb8cc215de8aa287ea9654d1c95c (patch) | |
tree | ba7d36006b940a87d1dff93a1a8965f7f986c2ab /numpy/lib/tests/test_shape_base.py | |
parent | 84f582f25e168dbfd59b3be470bc8ebc46ee2d92 (diff) | |
download | numpy-905e906d55fdcb8cc215de8aa287ea9654d1c95c.tar.gz |
ENH: Add (put|take)_along_axis as described in #8708
This is the reduced version that does not allow any insertion of extra dimensions
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 92 |
1 files changed, 91 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index a35d90b70..c95894f94 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -2,16 +2,106 @@ from __future__ import division, absolute_import, print_function import numpy as np import warnings +import functools from numpy.lib.shape_base import ( apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit, - vsplit, dstack, column_stack, kron, tile, expand_dims, + vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis, + put_along_axis ) from numpy.testing import ( assert_, assert_equal, assert_array_equal, assert_raises, assert_warns ) +def _add_keepdims(func): + """ hack in keepdims behavior into a function taking an axis """ + @functools.wraps(func) + def wrapped(a, axis, **kwargs): + res = func(a, axis=axis, **kwargs) + if axis is None: + axis = 0 # res is now a scalar, so we can insert this anywhere + return np.expand_dims(res, axis=axis) + return wrapped + + +class TestTakeAlongAxis(object): + def test_argequivalent(self): + """ Test it translates from arg<func> to <func> """ + from numpy.random import rand + a = rand(3, 4, 5) + + funcs = [ + (np.sort, np.argsort, dict()), + (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()), + (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()), + (np.partition, np.argpartition, dict(kth=2)), + ] + + for func, argfunc, kwargs in funcs: + for axis in list(range(a.ndim)) + [None]: + a_func = func(a, axis=axis, **kwargs) + ai_func = argfunc(a, axis=axis, **kwargs) + assert_equal(a_func, take_along_axis(a, ai_func, axis=axis)) + + def test_invalid(self): + """ Test it errors when indices has too few dimensions """ + a = np.ones((10, 10)) + ai = np.ones((10, 2), dtype=np.intp) + + # sanity check + take_along_axis(a, ai, axis=1) + + # not enough indices + assert_raises(ValueError, take_along_axis, a, np.array(1), axis=1) + # bool arrays not allowed + assert_raises(IndexError, take_along_axis, a, ai.astype(bool), axis=1) + # float arrays not allowed + assert_raises(IndexError, take_along_axis, a, ai.astype(float), axis=1) + # invalid axis + assert_raises(np.AxisError, take_along_axis, a, ai, axis=10) + + def test_empty(self): + """ Test everything is ok with empty results, even with inserted dims """ + a = np.ones((3, 4, 5)) + ai = np.ones((3, 0, 5), dtype=np.intp) + + actual = take_along_axis(a, ai, axis=1) + assert_equal(actual.shape, ai.shape) + + def test_broadcast(self): + """ Test that non-indexing dimensions are broadcast in both directions """ + a = np.ones((3, 4, 1)) + ai = np.ones((1, 2, 5), dtype=np.intp) + actual = take_along_axis(a, ai, axis=1) + assert_equal(actual.shape, (3, 2, 5)) + + +class TestPutAlongAxis(object): + def test_replace_max(self): + a_base = np.array([[10, 30, 20], [60, 40, 50]]) + + for axis in list(range(a_base.ndim)) + [None]: + # we mutate this in the loop + a = a_base.copy() + + # replace the max with a small value + i_max = _add_keepdims(np.argmax)(a, axis=axis) + put_along_axis(a, i_max, -99, axis=axis) + + # find the new minimum, which should max + i_min = _add_keepdims(np.argmin)(a, axis=axis) + + assert_equal(i_min, i_max) + + def test_broadcast(self): + """ Test that non-indexing dimensions are broadcast in both directions """ + a = np.ones((3, 4, 1)) + ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4 + put_along_axis(a, ai, 20, axis=1) + assert_equal(take_along_axis(a, ai, axis=1), 20) + + class TestApplyAlongAxis(object): def test_simple(self): a = np.ones((20, 10), 'd') |