diff options
author | Matti Picus <matti.picus@gmail.com> | 2021-08-29 09:57:22 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-08-29 09:57:22 +0300 |
commit | 5ae53e93b2be9ccf47bf72a85d71ea15d15a2eed (patch) | |
tree | b374818c65a8a4ca02df2a1b5528e30e0d09b269 /numpy | |
parent | a90677a0b36d63912597c96b887807bbc74d1483 (diff) | |
parent | 580d83ff127f330470867c009c7c6b17847f4287 (diff) | |
download | numpy-5ae53e93b2be9ccf47bf72a85d71ea15d15a2eed.tar.gz |
Merge pull request #19715 from yashasvimisra2798/casting_patch1
BUG: Casting bool_ to float16
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/lowlevel_strided_loops.c.src | 4 | ||||
-rw-r--r-- | numpy/core/tests/test_casting_unittests.py | 9 |
2 files changed, 12 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src index e533e4932..e38873746 100644 --- a/numpy/core/src/multiarray/lowlevel_strided_loops.c.src +++ b/numpy/core/src/multiarray/lowlevel_strided_loops.c.src @@ -819,6 +819,10 @@ NPY_NO_EXPORT PyArrayMethod_StridedLoop * # define _CONVERT_FN(x) npy_floatbits_to_halfbits(x) # elif @is_double1@ # define _CONVERT_FN(x) npy_doublebits_to_halfbits(x) +# elif @is_half1@ +# define _CONVERT_FN(x) (x) +# elif @is_bool1@ +# define _CONVERT_FN(x) npy_float_to_half((float)(x!=0)) # else # define _CONVERT_FN(x) npy_float_to_half((float)x) # endif diff --git a/numpy/core/tests/test_casting_unittests.py b/numpy/core/tests/test_casting_unittests.py index 3f67f1832..a13e807e2 100644 --- a/numpy/core/tests/test_casting_unittests.py +++ b/numpy/core/tests/test_casting_unittests.py @@ -695,6 +695,13 @@ class TestCasting: expected = arr_normal.astype(dtype) except TypeError: with pytest.raises(TypeError): - arr_NULLs.astype(dtype) + arr_NULLs.astype(dtype), else: assert_array_equal(expected, arr_NULLs.astype(dtype)) + + def test_float_to_bool(self): + # test case corresponding to gh-19514 + # simple test for casting bool_ to float16 + res = np.array([0, 3, -7], dtype=np.int8).view(bool) + expected = [0, 1, 1] + assert_array_equal(res, expected) |