diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2020-12-23 15:02:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-23 15:02:01 -0700 |
commit | cd50d88237a95eed03536b18eb1faf60885233ee (patch) | |
tree | 313a485a8cfb0e64fd343c2b163ff1eebc528b3a | |
parent | 073b9b9435b1bcabc0d7641fd7c9d9f5333a578e (diff) | |
parent | 50dce51e9fcbf5d80bfdc1bfe2b41fc9a9e9f9cc (diff) | |
download | numpy-cd50d88237a95eed03536b18eb1faf60885233ee.tar.gz |
Merge pull request #18052 from seberg/concat-with-string-dtype
BUG: Fix concatenation when the output is "S" or "U"
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 67 |
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.h | 4 |
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 57 |
-rw-r--r-- | numpy/core/tests/test_shape_base.py | 28 |
4 files changed, 116 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index f9dd35a73..5d5b69bd5 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) } +/* + * Helper to find the target descriptor for multiple arrays given an input + * one that may be a DType class (e.g. "U" or "S"). + * Works with arrays, since that is what `concatenate` works with. However, + * unlike `np.array(...)` or `arr.astype()` we will never inspect the array's + * content, which means that object arrays can only be cast to strings if a + * fixed width is provided (same for string -> generic datetime). + * + * As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should + * eventually be refactored to move the step to an earlier point. + */ +NPY_NO_EXPORT PyArray_Descr * +PyArray_FindConcatenationDescriptor( + npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype) +{ + if (requested_dtype == NULL) { + return PyArray_ResultType(n, arrays, 0, NULL); + } + + PyArray_DTypeMeta *common_dtype; + PyArray_Descr *result = NULL; + if (PyArray_ExtractDTypeAndDescriptor( + requested_dtype, &result, &common_dtype) < 0) { + return NULL; + } + if (result != NULL) { + if (result->subarray != NULL) { + PyErr_Format(PyExc_TypeError, + "The dtype `%R` is not a valid dtype for concatenation " + "since it is a subarray dtype (the subarray dimensions " + "would be added as array dimensions).", result); + Py_DECREF(result); + return NULL; + } + goto finish; + } + assert(n > 0); /* concatenate requires at least one array input. 
*/ + PyArray_Descr *descr = PyArray_DESCR(arrays[0]); + result = PyArray_CastDescrToDType(descr, common_dtype); + if (result == NULL || n == 1) { + goto finish; + } + /* + * This could short-cut a bit, calling `common_instance` directly and/or + * returning the `default_descr()` directly. Avoiding that (for now) as + * it would duplicate code from `PyArray_PromoteTypes`. + */ + for (npy_intp i = 1; i < n; i++) { + descr = PyArray_DESCR(arrays[i]); + PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype); + if (curr == NULL) { + Py_SETREF(result, NULL); + goto finish; + } + Py_SETREF(result, PyArray_PromoteTypes(result, curr)); + Py_DECREF(curr); + if (result == NULL) { + goto finish; + } + } + + finish: + Py_DECREF(common_dtype); + return result; +} + + /** * This function defines the common DType operator. * diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index cc1930f77..97006b952 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -49,6 +49,10 @@ npy_set_invalid_cast_error( NPY_NO_EXPORT PyArray_Descr * PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType); +NPY_NO_EXPORT PyArray_Descr * +PyArray_FindConcatenationDescriptor( + npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype); + NPY_NO_EXPORT int PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth); diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 870b633ed..e10fe39bd 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, /* Get the priority subtype for the array */ PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays); - - if (dtype == NULL) { - /* Get the resulting dtype from combining all the arrays */ - dtype = 
(PyArray_Descr *)PyArray_ResultType( - narrays, arrays, 0, NULL); - if (dtype == NULL) { - return NULL; - } - } - else { - Py_INCREF(dtype); + PyArray_Descr *descr = PyArray_FindConcatenationDescriptor( + narrays, arrays, (PyObject *)dtype); + if (descr == NULL) { + return NULL; } /* @@ -467,7 +460,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, * resolution rules matching that of the NpyIter. */ PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm); - s = dtype->elsize; + s = descr->elsize; for (idim = ndim-1; idim >= 0; --idim) { int iperm = strideperm[idim]; strides[iperm] = s; @@ -475,17 +468,13 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, } /* Allocate the array for the result. This steals the 'dtype' reference. */ - ret = (PyArrayObject *)PyArray_NewFromDescr(subtype, - dtype, - ndim, - shape, - strides, - NULL, - 0, - NULL); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + subtype, descr, ndim, shape, strides, NULL, 0, NULL, + NULL, 0, 1); if (ret == NULL) { return NULL; } + assert(PyArray_DESCR(ret) == descr); } /* @@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays, /* Get the priority subtype for the array */ PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays); - if (dtype == NULL) { - /* Get the resulting dtype from combining all the arrays */ - dtype = (PyArray_Descr *)PyArray_ResultType( - narrays, arrays, 0, NULL); - if (dtype == NULL) { - return NULL; - } - } - else { - Py_INCREF(dtype); + PyArray_Descr *descr = PyArray_FindConcatenationDescriptor( + narrays, arrays, (PyObject *)dtype); + if (descr == NULL) { + return NULL; } - stride = dtype->elsize; + stride = descr->elsize; /* Allocate the array for the result. This steals the 'dtype' reference. 
*/ - ret = (PyArrayObject *)PyArray_NewFromDescr(subtype, - dtype, - 1, - &shape, - &stride, - NULL, - 0, - NULL); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + subtype, descr, 1, &shape, &stride, NULL, 0, NULL, + NULL, 0, 1); if (ret == NULL) { return NULL; } + assert(PyArray_DESCR(ret) == descr); } /* diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index 4e56ace90..9922c9173 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -343,7 +343,7 @@ class TestConcatenate: concatenate((a, b), out=np.empty(4)) @pytest.mark.parametrize("axis", [None, 0]) - @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"]) + @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"]) @pytest.mark.parametrize("casting", ['no', 'equiv', 'safe', 'same_kind', 'unsafe']) def test_out_and_dtype(self, axis, out_dtype, casting): @@ -369,6 +369,32 @@ class TestConcatenate: with assert_raises(TypeError): concatenate(to_concat, out=out, dtype=out_dtype, axis=axis) + @pytest.mark.parametrize("axis", [None, 0]) + @pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"]) + @pytest.mark.parametrize("arrs", + [([0.],), ([0.], [1]), ([0], ["string"], [1.])]) + def test_dtype_with_promotion(self, arrs, string_dt, axis): + # Note that U0 and S0 should be deprecated eventually and changed to + # actually give the empty string result (together with `np.array`) + res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe") + assert res.dtype == np.promote_types("d", string_dt) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_string_dtype_does_not_inspect(self, axis): + # The error here currently depends on NPY_USE_NEW_CASTINGIMPL as + # the new version rejects using the "default string length" of 64. + # The new behaviour is better, `np.array()` and `arr.astype()` would + # have to be used instead. 
(currently only raises due to unsafe cast) + with pytest.raises((ValueError, TypeError)): + np.concatenate(([None], [1]), dtype="S", axis=axis) + with pytest.raises((ValueError, TypeError)): + np.concatenate(([None], [1]), dtype="U", axis=axis) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_subarray_error(self, axis): + with pytest.raises(TypeError, match=".*subarray dtype"): + np.concatenate(([1], [1]), dtype="(2,)i", axis=axis) + def test_stack(): # non-iterable input |