diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/stride_tricks.py | 6 | ||||
-rw-r--r-- | numpy/lib/tests/test_stride_tricks.py | 10 |
2 files changed, 13 insertions, 3 deletions
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 0f46ed335..a5f247abf 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -60,9 +60,9 @@ def _broadcast_to(array, shape, subok, readonly): if any(size < 0 for size in shape): raise ValueError('all elements of broadcast shape must be non-' 'negative') - broadcast = np.nditer((array,), flags=['multi_index', 'zerosize_ok'], - op_flags=['readonly'], itershape=shape, order='C' - ).itviews[0] + broadcast = np.nditer( + (array,), flags=['multi_index', 'refs_ok', 'zerosize_ok'], + op_flags=['readonly'], itershape=shape, order='C').itviews[0] result = _maybe_view_as_subclass(array, broadcast) if not readonly and array.flags.writeable: result.flags.writeable = True diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py index 0b73109bc..ef483921c 100644 --- a/numpy/lib/tests/test_stride_tricks.py +++ b/numpy/lib/tests/test_stride_tricks.py @@ -364,5 +364,15 @@ def test_writeable(): assert_equal(result.flags.writeable, False) +def test_reference_types(): + input_array = np.array('a', dtype=object) + expected = np.array(['a'] * 3, dtype=object) + actual = broadcast_to(input_array, (3,)) + assert_array_equal(expected, actual) + + actual, _ = broadcast_arrays(input_array, np.ones(3)) + assert_array_equal(expected, actual) + + if __name__ == "__main__": run_module_suite() |