diff options
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py index 39a76c2f6..b2bd7da3e 100644 --- a/numpy/lib/tests/test_stride_tricks.py +++ b/numpy/lib/tests/test_stride_tricks.py @@ -1,13 +1,14 @@ from __future__ import division, absolute_import, print_function import numpy as np +from numpy.core._rational_tests import rational from numpy.testing import ( - run_module_suite, assert_equal, assert_array_equal, - assert_raises, assert_ + assert_equal, assert_array_equal, assert_raises, assert_, + assert_raises_regex ) from numpy.lib.stride_tricks import ( as_strided, broadcast_arrays, _broadcast_shape, broadcast_to -) + ) def assert_shapes_correct(input_shapes, expected_shape): # Broadcast a list of arrays with the given input shapes and check the @@ -57,6 +58,17 @@ def test_same(): assert_array_equal(x, bx) assert_array_equal(y, by) +def test_broadcast_kwargs(): + # ensure that a TypeError is appropriately raised when + # np.broadcast_arrays() is called with any keyword + # argument other than 'subok' + x = np.arange(10) + y = np.arange(10) + + with assert_raises_regex(TypeError, + r'broadcast_arrays\(\) got an unexpected keyword*'): + broadcast_arrays(x, y, dtype='float64') + def test_one_off(): x = np.array([[1, 2, 3]]) @@ -317,6 +329,13 @@ def test_as_strided(): a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize)) assert_equal(a.dtype, a_view.dtype) + # Custom dtypes should not be lost (gh-9161) + r = [rational(i) for i in range(4)] + a = np.array(r, dtype=rational) + a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize)) + assert_equal(a.dtype, a_view.dtype) + assert_array_equal([r] * 3, a_view) + def as_strided_writeable(): arr = np.ones(10) view = as_strided(arr, writeable=False) @@ -407,7 +426,7 @@ def test_writeable(): _, result = broadcast_arrays(0, original) assert_equal(result.flags.writeable, False) - # regresssion test for GH6491 + # regression test for GH6491 shape = (2,) strides = [0] tricky_array = as_strided(np.array(0), shape, strides) @@ -424,7 +443,3 @@ def test_reference_types(): actual, _ = broadcast_arrays(input_array, np.ones(3)) assert_array_equal(expected, actual) - - -if __name__ == "__main__": - run_module_suite() |