summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_shape_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-01-02 23:04:34 +0000
committerEric Wieser <wieser.eric@gmail.com>2017-02-11 21:08:58 +0000
commit52988ea98120c10a927325401257eb3715a51b15 (patch)
treef78899a7943bf381c34e78561a3826b396ced71d /numpy/lib/tests/test_shape_base.py
parent1b483c2398ce2de49f72b9005dfb92859e389914 (diff)
downloadnumpy-52988ea98120c10a927325401257eb3715a51b15.tar.gz
BUG: Fix crash on 0d return value in apply_along_axis
Also: ENH: Support arbitrary dimensionality of return value MAINT: remove special casing
Diffstat (limited to 'numpy/lib/tests/test_shape_base.py')
-rw-r--r--numpy/lib/tests/test_shape_base.py68
1 files changed, 62 insertions, 6 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index a716d3b38..7bf2b4a81 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -32,7 +32,7 @@ class TestApplyAlongAxis(TestCase):
return row * 2
m = np.matrix([[0, 1], [2, 3]])
result = apply_along_axis(double, 0, m)
- assert isinstance(result, np.matrix)
+ assert_(isinstance(result, np.matrix))
assert_array_equal(
result, np.matrix([[0, 2], [4, 6]])
)
@@ -50,13 +50,69 @@ class TestApplyAlongAxis(TestCase):
apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1])
)
- def test_scalar_array(self):
+ def test_scalar_array(self, cls=np.ndarray):
+ a = np.ones((6, 3)).view(cls)
+ res = apply_along_axis(np.sum, 0, a)
+ assert_(isinstance(res, cls))
+ assert_array_equal(res, np.array([6, 6, 6]).view(cls))
+
+ def test_0d_array(self, cls=np.ndarray):
+ def sum_to_0d(x):
+ """ Sum x, returning a 0d array of the same class """
+ assert_equal(x.ndim, 1)
+ return np.squeeze(np.sum(x, keepdims=True))
+ a = np.ones((6, 3)).view(cls)
+ res = apply_along_axis(sum_to_0d, 0, a)
+ assert_(isinstance(res, cls))
+ assert_array_equal(res, np.array([6, 6, 6]).view(cls))
+
+ res = apply_along_axis(sum_to_0d, 1, a)
+ assert_(isinstance(res, cls))
+ assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls))
+
+ def test_axis_insertion(self, cls=np.ndarray):
+ def f1to2(x):
+ """produces an assymmetric non-square matrix from x"""
+ assert_equal(x.ndim, 1)
+ return (x[::-1] * x[1:,None]).view(cls)
+
+ a2d = np.arange(6*3).reshape((6, 3))
+
+ # 2d insertion along first axis
+ actual = apply_along_axis(f1to2, 0, a2d)
+ expected = np.stack([
+ f1to2(a2d[:,i]) for i in range(a2d.shape[1])
+ ], axis=-1).view(cls)
+ assert_equal(type(actual), type(expected))
+ assert_equal(actual, expected)
+
+ # 2d insertion along last axis
+ actual = apply_along_axis(f1to2, 1, a2d)
+ expected = np.stack([
+ f1to2(a2d[i,:]) for i in range(a2d.shape[0])
+ ], axis=0).view(cls)
+ assert_equal(type(actual), type(expected))
+ assert_equal(actual, expected)
+
+ # 3d insertion along middle axis
+ a3d = np.arange(6*5*3).reshape((6, 5, 3))
+
+ actual = apply_along_axis(f1to2, 1, a3d)
+ expected = np.stack([
+ np.stack([
+ f1to2(a3d[i,:,j]) for i in range(a3d.shape[0])
+ ], axis=0)
+ for j in range(a3d.shape[2])
+ ], axis=-1).view(cls)
+ assert_equal(type(actual), type(expected))
+ assert_equal(actual, expected)
+
+ def test_subclass_preservation(self):
class MinimalSubclass(np.ndarray):
pass
- a = np.ones((6, 3)).view(MinimalSubclass)
- res = apply_along_axis(np.sum, 0, a)
- assert isinstance(res, MinimalSubclass)
- assert_array_equal(res, np.array([6, 6, 6]).view(MinimalSubclass))
+ self.test_scalar_array(MinimalSubclass)
+ self.test_0d_array(MinimalSubclass)
+ self.test_axis_insertion(MinimalSubclass)
def test_tuple_func1d(self):
def sample_1d(x):