summaryrefslogtreecommitdiff
path: root/numpy/lib/tests
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests')
-rw-r--r--numpy/lib/tests/test_shape_base.py23
1 files changed, 22 insertions, 1 deletions
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 4d06001f4..14406fe21 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function
import numpy as np
+import warnings
+
from numpy.lib.shape_base import (
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
- vsplit, dstack, column_stack, kron, tile
+ vsplit, dstack, column_stack, kron, tile, expand_dims,
)
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
@@ -182,6 +184,25 @@ class TestApplyOverAxes(TestCase):
assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
+class TestExpandDims(TestCase):
+ def test_functionality(self):
+ s = (2, 3, 4, 5)
+ a = np.empty(s)
+ for axis in range(-5, 4):
+ b = expand_dims(a, axis)
+ assert_(b.shape[axis] == 1)
+ assert_(np.squeeze(b).shape == s)
+
+ def test_deprecations(self):
+ # 2017-05-17, 1.13.0
+ s = (2, 3, 4, 5)
+ a = np.empty(s)
+ with warnings.catch_warnings():
+ warnings.simplefilter("always")
+ assert_warns(DeprecationWarning, expand_dims, a, -6)
+ assert_warns(DeprecationWarning, expand_dims, a, 5)
+
+
class TestArraySplit(TestCase):
def test_integer_0_split(self):
a = np.arange(10)