diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-01-02 23:04:34 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-02-11 21:08:58 +0000 |
commit | 52988ea98120c10a927325401257eb3715a51b15 (patch) | |
tree | f78899a7943bf381c34e78561a3826b396ced71d /numpy/lib/tests/test_shape_base.py | |
parent | 1b483c2398ce2de49f72b9005dfb92859e389914 (diff) | |
download | numpy-52988ea98120c10a927325401257eb3715a51b15.tar.gz |
BUG: Fix crash on 0d return value in apply_along_axis
Also:
ENH: Support arbitrary dimensionality of return value
MAINT: remove special casing
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r-- | numpy/lib/tests/test_shape_base.py | 68 |
1 files changed, 62 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py index a716d3b38..7bf2b4a81 100644 --- a/numpy/lib/tests/test_shape_base.py +++ b/numpy/lib/tests/test_shape_base.py @@ -32,7 +32,7 @@ class TestApplyAlongAxis(TestCase): return row * 2 m = np.matrix([[0, 1], [2, 3]]) result = apply_along_axis(double, 0, m) - assert isinstance(result, np.matrix) + assert_(isinstance(result, np.matrix)) assert_array_equal( result, np.matrix([[0, 2], [4, 6]]) ) @@ -50,13 +50,69 @@ class TestApplyAlongAxis(TestCase): apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1]) ) - def test_scalar_array(self): + def test_scalar_array(self, cls=np.ndarray): + a = np.ones((6, 3)).view(cls) + res = apply_along_axis(np.sum, 0, a) + assert_(isinstance(res, cls)) + assert_array_equal(res, np.array([6, 6, 6]).view(cls)) + + def test_0d_array(self, cls=np.ndarray): + def sum_to_0d(x): + """ Sum x, returning a 0d array of the same class """ + assert_equal(x.ndim, 1) + return np.squeeze(np.sum(x, keepdims=True)) + a = np.ones((6, 3)).view(cls) + res = apply_along_axis(sum_to_0d, 0, a) + assert_(isinstance(res, cls)) + assert_array_equal(res, np.array([6, 6, 6]).view(cls)) + + res = apply_along_axis(sum_to_0d, 1, a) + assert_(isinstance(res, cls)) + assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls)) + + def test_axis_insertion(self, cls=np.ndarray): + def f1to2(x): + """produces an assymmetric non-square matrix from x""" + assert_equal(x.ndim, 1) + return (x[::-1] * x[1:,None]).view(cls) + + a2d = np.arange(6*3).reshape((6, 3)) + + # 2d insertion along first axis + actual = apply_along_axis(f1to2, 0, a2d) + expected = np.stack([ + f1to2(a2d[:,i]) for i in range(a2d.shape[1]) + ], axis=-1).view(cls) + assert_equal(type(actual), type(expected)) + assert_equal(actual, expected) + + # 2d insertion along last axis + actual = apply_along_axis(f1to2, 1, a2d) + expected = np.stack([ + f1to2(a2d[i,:]) for i in range(a2d.shape[0]) + ], axis=0).view(cls) + assert_equal(type(actual), type(expected)) + assert_equal(actual, expected) + + # 3d insertion along middle axis + a3d = np.arange(6*5*3).reshape((6, 5, 3)) + + actual = apply_along_axis(f1to2, 1, a3d) + expected = np.stack([ + np.stack([ + f1to2(a3d[i,:,j]) for i in range(a3d.shape[0]) + ], axis=0) + for j in range(a3d.shape[2]) + ], axis=-1).view(cls) + assert_equal(type(actual), type(expected)) + assert_equal(actual, expected) + + def test_subclass_preservation(self): class MinimalSubclass(np.ndarray): pass - a = np.ones((6, 3)).view(MinimalSubclass) - res = apply_along_axis(np.sum, 0, a) - assert isinstance(res, MinimalSubclass) - assert_array_equal(res, np.array([6, 6, 6]).view(MinimalSubclass)) + self.test_scalar_array(MinimalSubclass) + self.test_0d_array(MinimalSubclass) + self.test_axis_insertion(MinimalSubclass) def test_tuple_func1d(self): def sample_1d(x): |