diff options
Diffstat (limited to 'numpy/lib/tests/test_nanfunctions.py')
-rw-r--r-- | numpy/lib/tests/test_nanfunctions.py | 105 |
1 files changed, 54 insertions, 51 deletions
diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index e69d9dd7d..55877e602 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -113,42 +113,46 @@ class TestNanFunctions_MinMax(object): for f in self.nanfuncs: assert_(f(0.) == 0.) - def test_matrices(self): + def test_subclass(self): + class MyNDArray(np.ndarray): + pass + # Check that it works and that type and # shape are preserved - mat = np.matrix(np.eye(3)) + mine = np.eye(3).view(MyNDArray) for f in self.nanfuncs: - res = f(mat, axis=0) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (1, 3)) - res = f(mat, axis=1) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (3, 1)) - res = f(mat) - assert_(np.isscalar(res)) + res = f(mine, axis=0) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == (3,)) + res = f(mine, axis=1) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == (3,)) + res = f(mine) + assert_(res.shape == ()) + # check that rows of nan are dealt with for subclasses (#4628) - mat[1] = np.nan + mine[1] = np.nan for f in self.nanfuncs: with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - res = f(mat, axis=0) - assert_(isinstance(res, np.matrix)) + res = f(mine, axis=0) + assert_(isinstance(res, MyNDArray)) assert_(not np.any(np.isnan(res))) assert_(len(w) == 0) with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - res = f(mat, axis=1) - assert_(isinstance(res, np.matrix)) - assert_(np.isnan(res[1, 0]) and not np.isnan(res[0, 0]) - and not np.isnan(res[2, 0])) + res = f(mine, axis=1) + assert_(isinstance(res, MyNDArray)) + assert_(np.isnan(res[1]) and not np.isnan(res[0]) + and not np.isnan(res[2])) assert_(len(w) == 1, 'no warning raised') assert_(issubclass(w[0].category, RuntimeWarning)) with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') - res = f(mat) - assert_(np.isscalar(res)) + res = f(mine) + assert_(res.shape == ()) assert_(res != np.nan) assert_(len(w) == 0) @@ -209,19 +213,22 @@ class TestNanFunctions_ArgminArgmax(object): for f in self.nanfuncs: assert_(f(0.) == 0.) - def test_matrices(self): + def test_subclass(self): + class MyNDArray(np.ndarray): + pass + # Check that it works and that type and # shape are preserved - mat = np.matrix(np.eye(3)) + mine = np.eye(3).view(MyNDArray) for f in self.nanfuncs: - res = f(mat, axis=0) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (1, 3)) - res = f(mat, axis=1) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (3, 1)) - res = f(mat) - assert_(np.isscalar(res)) + res = f(mine, axis=0) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == (3,)) + res = f(mine, axis=1) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == (3,)) + res = f(mine) + assert_(res.shape == ()) class TestNanFunctions_IntTypes(object): @@ -381,19 +388,27 @@ class SharedNanFunctionsTestsMixin(object): for f in self.nanfuncs: assert_(f(0.) == 0.) - def test_matrices(self): + def test_subclass(self): + class MyNDArray(np.ndarray): + pass + # Check that it works and that type and # shape are preserved - mat = np.matrix(np.eye(3)) + array = np.eye(3) + mine = array.view(MyNDArray) for f in self.nanfuncs: - res = f(mat, axis=0) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (1, 3)) - res = f(mat, axis=1) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (3, 1)) - res = f(mat) - assert_(np.isscalar(res)) + expected_shape = f(array, axis=0).shape + res = f(mine, axis=0) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == expected_shape) + expected_shape = f(array, axis=1).shape + res = f(mine, axis=1) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == expected_shape) + expected_shape = f(array).shape + res = f(mine) + assert_(isinstance(res, MyNDArray)) + assert_(res.shape == expected_shape) class TestNanFunctions_SumProd(SharedNanFunctionsTestsMixin): @@ -481,18 +496,6 @@ class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin): res = f(d, axis=axis) assert_equal(res.shape, (3, 5, 7, 11)) - def test_matrices(self): - # Check that it works and that type and - # shape are preserved - mat = np.matrix(np.eye(3)) - for f in self.nanfuncs: - for axis in np.arange(2): - res = f(mat, axis=axis) - assert_(isinstance(res, np.matrix)) - assert_(res.shape == (3, 3)) - res = f(mat) - assert_(res.shape == (1, 3*3)) - def test_result_values(self): for axis in (-2, -1, 0, 1, None): tgt = np.cumprod(_ndat_ones, axis=axis) |