summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test_stride_tricks.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/tests/test_stride_tricks.py')
-rw-r--r--numpy/lib/tests/test_stride_tricks.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
index 39a76c2f6..3c2ca8b87 100644
--- a/numpy/lib/tests/test_stride_tricks.py
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -1,13 +1,13 @@
from __future__ import division, absolute_import, print_function
import numpy as np
+from numpy.core._rational_tests import rational
from numpy.testing import (
- run_module_suite, assert_equal, assert_array_equal,
- assert_raises, assert_
+ assert_equal, assert_array_equal, assert_raises, assert_
)
from numpy.lib.stride_tricks import (
as_strided, broadcast_arrays, _broadcast_shape, broadcast_to
-)
+ )
def assert_shapes_correct(input_shapes, expected_shape):
# Broadcast a list of arrays with the given input shapes and check the
@@ -317,6 +317,13 @@ def test_as_strided():
a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
assert_equal(a.dtype, a_view.dtype)
+ # Custom dtypes should not be lost (gh-9161)
+ r = [rational(i) for i in range(4)]
+ a = np.array(r, dtype=rational)
+ a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
+ assert_equal(a.dtype, a_view.dtype)
+ assert_array_equal([r] * 3, a_view)
+
def as_strided_writeable():
arr = np.ones(10)
view = as_strided(arr, writeable=False)
@@ -407,7 +414,7 @@ def test_writeable():
_, result = broadcast_arrays(0, original)
assert_equal(result.flags.writeable, False)
- # regresssion test for GH6491
+ # regression test for GH6491
shape = (2,)
strides = [0]
tricky_array = as_strided(np.array(0), shape, strides)
@@ -424,7 +431,3 @@ def test_reference_types():
actual, _ = broadcast_arrays(input_array, np.ones(3))
assert_array_equal(expected, actual)
-
-
-if __name__ == "__main__":
- run_module_suite()