summaryrefslogtreecommitdiff
path: root/numpy/lib/test_shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/test_shape_base.py')
-rw-r--r--numpy/lib/test_shape_base.py364
1 files changed, 364 insertions, 0 deletions
diff --git a/numpy/lib/test_shape_base.py b/numpy/lib/test_shape_base.py
new file mode 100644
index 000000000..005868e96
--- /dev/null
+++ b/numpy/lib/test_shape_base.py
@@ -0,0 +1,364 @@
+
+from scipy.testing import *
+set_package_path()
+import scipy.base;
+from scipy.base import *
+restore_path()
+
+class test_apply_along_axis(ScipyTestCase):
+ def check_simple(self):
+ a = ones((20,10),'d')
+ assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
+ def check_simple101(self,level=11):
+ # This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
+ # when enabled and shape(a)[1]>100. See Issue 202.
+ a = ones((10,101),'d')
+ assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
+
+class test_array_split(ScipyTestCase):
+ def check_integer_0_split(self):
+ a = arange(10)
+ try:
+ res = array_split(a,0)
+ assert(0) # it should have thrown a value error
+ except ValueError:
+ pass
+ def check_integer_split(self):
+ a = arange(10)
+ res = array_split(a,1)
+ desired = [arange(10)]
+ compare_results(res,desired)
+
+ res = array_split(a,2)
+ desired = [arange(5),arange(5,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,3)
+ desired = [arange(4),arange(4,7),arange(7,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,4)
+ desired = [arange(3),arange(3,6),arange(6,8),arange(8,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,5)
+ desired = [arange(2),arange(2,4),arange(4,6),arange(6,8),arange(8,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,6)
+ desired = [arange(2),arange(2,4),arange(4,6),arange(6,8),arange(8,9),
+ arange(9,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,7)
+ desired = [arange(2),arange(2,4),arange(4,6),arange(6,7),arange(7,8),
+ arange(8,9), arange(9,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,8)
+ desired = [arange(2),arange(2,4),arange(4,5),arange(5,6),arange(6,7),
+ arange(7,8), arange(8,9), arange(9,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,9)
+ desired = [arange(2),arange(2,3),arange(3,4),arange(4,5),arange(5,6),
+ arange(6,7), arange(7,8), arange(8,9), arange(9,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,10)
+ desired = [arange(1),arange(1,2),arange(2,3),arange(3,4),
+ arange(4,5),arange(5,6), arange(6,7), arange(7,8),
+ arange(8,9), arange(9,10)]
+ compare_results(res,desired)
+
+ res = array_split(a,11)
+ desired = [arange(1),arange(1,2),arange(2,3),arange(3,4),
+ arange(4,5),arange(5,6), arange(6,7), arange(7,8),
+ arange(8,9), arange(9,10),array([])]
+ compare_results(res,desired)
+ def check_integer_split_2D_rows(self):
+ a = array([arange(10),arange(10)])
+ res = array_split(a,3,axis=0)
+ desired = [array([arange(10)]),array([arange(10)]),array([])]
+ compare_results(res,desired)
+ def check_integer_split_2D_cols(self):
+ a = array([arange(10),arange(10)])
+ res = array_split(a,3,axis=-1)
+ desired = [array([arange(4),arange(4)]),
+ array([arange(4,7),arange(4,7)]),
+ array([arange(7,10),arange(7,10)])]
+ compare_results(res,desired)
+ def check_integer_split_2D_default(self):
+ """ This will fail if we change default axis
+ """
+ a = array([arange(10),arange(10)])
+ res = array_split(a,3)
+ desired = [array([arange(10)]),array([arange(10)]),array([])]
+ compare_results(res,desired)
+ #perhaps should check higher dimensions
+
+ def check_index_split_simple(self):
+ a = arange(10)
+ indices = [1,5,7]
+ res = array_split(a,indices,axis=-1)
+ desired = [arange(0,1),arange(1,5),arange(5,7),arange(7,10)]
+ compare_results(res,desired)
+
+ def check_index_split_low_bound(self):
+ a = arange(10)
+ indices = [0,5,7]
+ res = array_split(a,indices,axis=-1)
+ desired = [array([]),arange(0,5),arange(5,7),arange(7,10)]
+ compare_results(res,desired)
+ def check_index_split_high_bound(self):
+ a = arange(10)
+ indices = [0,5,7,10,12]
+ res = array_split(a,indices,axis=-1)
+ desired = [array([]),arange(0,5),arange(5,7),arange(7,10),
+ array([]),array([])]
+ compare_results(res,desired)
+
+class test_split(ScipyTestCase):
+ """* This function is essentially the same as array_split,
+ except that it test if splitting will result in an
+ equal split. Only test for this case.
+ *"""
+ def check_equal_split(self):
+ a = arange(10)
+ res = split(a,2)
+ desired = [arange(5),arange(5,10)]
+ compare_results(res,desired)
+
+ def check_unequal_split(self):
+ a = arange(10)
+ try:
+ res = split(a,3)
+ assert(0) # should raise an error
+ except ValueError:
+ pass
+
+class test_atleast_1d(ScipyTestCase):
+ def check_0D_array(self):
+ a = array(1); b = array(2);
+ res=map(atleast_1d,[a,b])
+ desired = [array([1]),array([2])]
+ assert_array_equal(res,desired)
+ def check_1D_array(self):
+ a = array([1,2]); b = array([2,3]);
+ res=map(atleast_1d,[a,b])
+ desired = [array([1,2]),array([2,3])]
+ assert_array_equal(res,desired)
+ def check_2D_array(self):
+ a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
+ res=map(atleast_1d,[a,b])
+ desired = [a,b]
+ assert_array_equal(res,desired)
+ def check_3D_array(self):
+ a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
+ a = array([a,a]);b = array([b,b]);
+ res=map(atleast_1d,[a,b])
+ desired = [a,b]
+ assert_array_equal(res,desired)
+ def check_r1array(self):
+ """ Test to make sure equivalent Travis O's r1array function
+ """
+ assert(atleast_1d(3).shape == (1,))
+ assert(atleast_1d(3j).shape == (1,))
+ assert(atleast_1d(3L).shape == (1,))
+ assert(atleast_1d(3.0).shape == (1,))
+ assert(atleast_1d([[2,3],[4,5]]).shape == (2,2))
+
+class test_atleast_2d(ScipyTestCase):
+ def check_0D_array(self):
+ a = array(1); b = array(2);
+ res=map(atleast_2d,[a,b])
+ desired = [array([[1]]),array([[2]])]
+ assert_array_equal(res,desired)
+ def check_1D_array(self):
+ a = array([1,2]); b = array([2,3]);
+ res=map(atleast_2d,[a,b])
+ desired = [array([[1,2]]),array([[2,3]])]
+ assert_array_equal(res,desired)
+ def check_2D_array(self):
+ a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
+ res=map(atleast_2d,[a,b])
+ desired = [a,b]
+ assert_array_equal(res,desired)
+ def check_3D_array(self):
+ a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
+ a = array([a,a]);b = array([b,b]);
+ res=map(atleast_2d,[a,b])
+ desired = [a,b]
+ assert_array_equal(res,desired)
+ def check_r2array(self):
+ """ Test to make sure equivalent Travis O's r2array function
+ """
+ assert(atleast_2d(3).shape == (1,1))
+ assert(atleast_2d([3j,1]).shape == (1,2))
+ assert(atleast_2d([[[3,1],[4,5]],[[3,5],[1,2]]]).shape == (2,2,2))
+
+class test_atleast_3d(ScipyTestCase):
+ def check_0D_array(self):
+ a = array(1); b = array(2);
+ res=map(atleast_3d,[a,b])
+ desired = [array([[[1]]]),array([[[2]]])]
+ assert_array_equal(res,desired)
+ def check_1D_array(self):
+ a = array([1,2]); b = array([2,3]);
+ res=map(atleast_3d,[a,b])
+ desired = [array([[[1],[2]]]),array([[[2],[3]]])]
+ assert_array_equal(res,desired)
+ def check_2D_array(self):
+ a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
+ res=map(atleast_3d,[a,b])
+ desired = [a[:,:,NewAxis],b[:,:,NewAxis]]
+ assert_array_equal(res,desired)
+ def check_3D_array(self):
+ a = array([[1,2],[1,2]]); b = array([[2,3],[2,3]]);
+ a = array([a,a]);b = array([b,b]);
+ res=map(atleast_3d,[a,b])
+ desired = [a,b]
+ assert_array_equal(res,desired)
+
+class test_hstack(ScipyTestCase):
+ def check_0D_array(self):
+ a = array(1); b = array(2);
+ res=hstack([a,b])
+ desired = array([1,2])
+ assert_array_equal(res,desired)
+ def check_1D_array(self):
+ a = array([1]); b = array([2]);
+ res=hstack([a,b])
+ desired = array([1,2])
+ assert_array_equal(res,desired)
+ def check_2D_array(self):
+ a = array([[1],[2]]); b = array([[1],[2]]);
+ res=hstack([a,b])
+ desired = array([[1,1],[2,2]])
+ assert_array_equal(res,desired)
+
+class test_vstack(ScipyTestCase):
+ def check_0D_array(self):
+ a = array(1); b = array(2);
+ res=vstack([a,b])
+ desired = array([[1],[2]])
+ assert_array_equal(res,desired)
+ def check_1D_array(self):
+ a = array([1]); b = array([2]);
+ res=vstack([a,b])
+ desired = array([[1],[2]])
+ assert_array_equal(res,desired)
+ def check_2D_array(self):
+ a = array([[1],[2]]); b = array([[1],[2]]);
+ res=vstack([a,b])
+ desired = array([[1],[2],[1],[2]])
+ assert_array_equal(res,desired)
+ def check_2D_array2(self):
+ a = array([1,2]); b = array([1,2]);
+ res=vstack([a,b])
+ desired = array([[1,2],[1,2]])
+ assert_array_equal(res,desired)
+
+class test_dstack(ScipyTestCase):
+ def check_0D_array(self):
+ a = array(1); b = array(2);
+ res=dstack([a,b])
+ desired = array([[[1,2]]])
+ assert_array_equal(res,desired)
+ def check_1D_array(self):
+ a = array([1]); b = array([2]);
+ res=dstack([a,b])
+ desired = array([[[1,2]]])
+ assert_array_equal(res,desired)
+ def check_2D_array(self):
+ a = array([[1],[2]]); b = array([[1],[2]]);
+ res=dstack([a,b])
+ desired = array([[[1,1]],[[2,2,]]])
+ assert_array_equal(res,desired)
+ def check_2D_array2(self):
+ a = array([1,2]); b = array([1,2]);
+ res=dstack([a,b])
+ desired = array([[[1,1],[2,2]]])
+ assert_array_equal(res,desired)
+
+""" array_split has more comprehensive test of splitting.
+ only do simple test on hsplit, vsplit, and dsplit
+"""
+class test_hsplit(ScipyTestCase):
+ """ only testing for integer splits.
+ """
+ def check_0D_array(self):
+ a= array(1)
+ try:
+ hsplit(a,2)
+ assert(0)
+ except ValueError:
+ pass
+ def check_1D_array(self):
+ a= array([1,2,3,4])
+ res = hsplit(a,2)
+ desired = [array([1,2]),array([3,4])]
+ compare_results(res,desired)
+ def check_2D_array(self):
+ a= array([[1,2,3,4],
+ [1,2,3,4]])
+ res = hsplit(a,2)
+ desired = [array([[1,2],[1,2]]),array([[3,4],[3,4]])]
+ compare_results(res,desired)
+
+class test_vsplit(ScipyTestCase):
+ """ only testing for integer splits.
+ """
+ def check_1D_array(self):
+ a= array([1,2,3,4])
+ try:
+ vsplit(a,2)
+ assert(0)
+ except ValueError:
+ pass
+ def check_2D_array(self):
+ a= array([[1,2,3,4],
+ [1,2,3,4]])
+ res = vsplit(a,2)
+ desired = [array([[1,2,3,4]]),array([[1,2,3,4]])]
+ compare_results(res,desired)
+
+class test_dsplit(ScipyTestCase):
+ """ only testing for integer splits.
+ """
+ def check_2D_array(self):
+ a= array([[1,2,3,4],
+ [1,2,3,4]])
+ try:
+ dsplit(a,2)
+ assert(0)
+ except ValueError:
+ pass
+ def check_3D_array(self):
+ a= array([[[1,2,3,4],
+ [1,2,3,4]],
+ [[1,2,3,4],
+ [1,2,3,4]]])
+ res = dsplit(a,2)
+ desired = [array([[[1,2],[1,2]],[[1,2],[1,2]]]),
+ array([[[3,4],[3,4]],[[3,4],[3,4]]])]
+ compare_results(res,desired)
+
+class test_squeeze(ScipyTestCase):
+ def check_basic(self):
+ a = rand(20,10,10,1,1)
+ b = rand(20,1,10,1,20)
+ c = rand(1,1,20,10)
+ assert_array_equal(squeeze(a),reshape(a,(20,10,10)))
+ assert_array_equal(squeeze(b),reshape(b,(20,10,20)))
+ assert_array_equal(squeeze(c),reshape(c,(20,10)))
+
+# Utility
+
+def compare_results(res,desired):
+ for i in range(len(desired)):
+ assert_array_equal(res[i],desired[i])
+
+
+if __name__ == "__main__":
+ ScipyTest().run()