summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Charles Harris <charlesr.harris@gmail.com> 2020-12-23 15:02:01 -0700
committer: GitHub <noreply@github.com> 2020-12-23 15:02:01 -0700
commit: cd50d88237a95eed03536b18eb1faf60885233ee (patch)
tree: 313a485a8cfb0e64fd343c2b163ff1eebc528b3a
parent: 073b9b9435b1bcabc0d7641fd7c9d9f5333a578e (diff)
parent: 50dce51e9fcbf5d80bfdc1bfe2b41fc9a9e9f9cc (diff)
download: numpy-cd50d88237a95eed03536b18eb1faf60885233ee.tar.gz
Merge pull request #18052 from seberg/concat-with-string-dtype
BUG: Fix concatenation when the output is "S" or "U"
-rw-r--r-- numpy/core/src/multiarray/convert_datatype.c 67
-rw-r--r-- numpy/core/src/multiarray/convert_datatype.h 4
-rw-r--r-- numpy/core/src/multiarray/multiarraymodule.c 57
-rw-r--r-- numpy/core/tests/test_shape_base.py 28
4 files changed, 116 insertions, 40 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c
index f9dd35a73..5d5b69bd5 100644
--- a/numpy/core/src/multiarray/convert_datatype.c
+++ b/numpy/core/src/multiarray/convert_datatype.c
@@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
}
+/*
+ * Helper to find the target descriptor for multiple arrays given an input
+ * one that may be a DType class (e.g. "U" or "S").
+ * Works with arrays, since that is what `concatenate` works with. However,
+ * unlike `np.array(...)` or `arr.astype()` we will never inspect the array's
+ * content, which means that object arrays can only be cast to strings if a
+ * fixed width is provided (same for string -> generic datetime).
+ *
+ * As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should
+ * eventually be refactored to move the step to an earlier point.
+ */
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_FindConcatenationDescriptor(
+ npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype)
+{
+ if (requested_dtype == NULL) {
+ return PyArray_ResultType(n, arrays, 0, NULL);
+ }
+
+ PyArray_DTypeMeta *common_dtype;
+ PyArray_Descr *result = NULL;
+ if (PyArray_ExtractDTypeAndDescriptor(
+ requested_dtype, &result, &common_dtype) < 0) {
+ return NULL;
+ }
+ if (result != NULL) {
+ if (result->subarray != NULL) {
+ PyErr_Format(PyExc_TypeError,
+ "The dtype `%R` is not a valid dtype for concatenation "
+ "since it is a subarray dtype (the subarray dimensions "
+ "would be added as array dimensions).", result);
+ Py_DECREF(result);
+ return NULL;
+ }
+ goto finish;
+ }
+ assert(n > 0); /* concatenate requires at least one array input. */
+ PyArray_Descr *descr = PyArray_DESCR(arrays[0]);
+ result = PyArray_CastDescrToDType(descr, common_dtype);
+ if (result == NULL || n == 1) {
+ goto finish;
+ }
+ /*
+ * This could short-cut a bit, calling `common_instance` directly and/or
+ * returning the `default_descr()` directly. Avoiding that (for now) as
+ * it would duplicate code from `PyArray_PromoteTypes`.
+ */
+ for (npy_intp i = 1; i < n; i++) {
+ descr = PyArray_DESCR(arrays[i]);
+ PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype);
+ if (curr == NULL) {
+ Py_SETREF(result, NULL);
+ goto finish;
+ }
+ Py_SETREF(result, PyArray_PromoteTypes(result, curr));
+ Py_DECREF(curr);
+ if (result == NULL) {
+ goto finish;
+ }
+ }
+
+ finish:
+ Py_DECREF(common_dtype);
+ return result;
+}
+
+
/**
* This function defines the common DType operator.
*
diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h
index cc1930f77..97006b952 100644
--- a/numpy/core/src/multiarray/convert_datatype.h
+++ b/numpy/core/src/multiarray/convert_datatype.h
@@ -49,6 +49,10 @@ npy_set_invalid_cast_error(
NPY_NO_EXPORT PyArray_Descr *
PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType);
+NPY_NO_EXPORT PyArray_Descr *
+PyArray_FindConcatenationDescriptor(
+ npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype);
+
NPY_NO_EXPORT int
PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth);
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 870b633ed..e10fe39bd 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
/* Get the priority subtype for the array */
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
-
- if (dtype == NULL) {
- /* Get the resulting dtype from combining all the arrays */
- dtype = (PyArray_Descr *)PyArray_ResultType(
- narrays, arrays, 0, NULL);
- if (dtype == NULL) {
- return NULL;
- }
- }
- else {
- Py_INCREF(dtype);
+ PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
+ narrays, arrays, (PyObject *)dtype);
+ if (descr == NULL) {
+ return NULL;
}
/*
@@ -467,7 +460,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
* resolution rules matching that of the NpyIter.
*/
PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm);
- s = dtype->elsize;
+ s = descr->elsize;
for (idim = ndim-1; idim >= 0; --idim) {
int iperm = strideperm[idim];
strides[iperm] = s;
@@ -475,17 +468,13 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
}
/* Allocate the array for the result. This steals the 'dtype' reference. */
- ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
- dtype,
- ndim,
- shape,
- strides,
- NULL,
- 0,
- NULL);
+ ret = (PyArrayObject *)PyArray_NewFromDescr_int(
+ subtype, descr, ndim, shape, strides, NULL, 0, NULL,
+ NULL, 0, 1);
if (ret == NULL) {
return NULL;
}
+ assert(PyArray_DESCR(ret) == descr);
}
/*
@@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
/* Get the priority subtype for the array */
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
- if (dtype == NULL) {
- /* Get the resulting dtype from combining all the arrays */
- dtype = (PyArray_Descr *)PyArray_ResultType(
- narrays, arrays, 0, NULL);
- if (dtype == NULL) {
- return NULL;
- }
- }
- else {
- Py_INCREF(dtype);
+ PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
+ narrays, arrays, (PyObject *)dtype);
+ if (descr == NULL) {
+ return NULL;
}
- stride = dtype->elsize;
+ stride = descr->elsize;
/* Allocate the array for the result. This steals the 'dtype' reference. */
- ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
- dtype,
- 1,
- &shape,
- &stride,
- NULL,
- 0,
- NULL);
+ ret = (PyArrayObject *)PyArray_NewFromDescr_int(
+ subtype, descr, 1, &shape, &stride, NULL, 0, NULL,
+ NULL, 0, 1);
if (ret == NULL) {
return NULL;
}
+ assert(PyArray_DESCR(ret) == descr);
}
/*
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index 4e56ace90..9922c9173 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -343,7 +343,7 @@ class TestConcatenate:
concatenate((a, b), out=np.empty(4))
@pytest.mark.parametrize("axis", [None, 0])
- @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
+ @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"])
@pytest.mark.parametrize("casting",
['no', 'equiv', 'safe', 'same_kind', 'unsafe'])
def test_out_and_dtype(self, axis, out_dtype, casting):
@@ -369,6 +369,32 @@ class TestConcatenate:
with assert_raises(TypeError):
concatenate(to_concat, out=out, dtype=out_dtype, axis=axis)
+ @pytest.mark.parametrize("axis", [None, 0])
+ @pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"])
+ @pytest.mark.parametrize("arrs",
+ [([0.],), ([0.], [1]), ([0], ["string"], [1.])])
+ def test_dtype_with_promotion(self, arrs, string_dt, axis):
+ # Note that U0 and S0 should be deprecated eventually and changed to
+ # actually give the empty string result (together with `np.array`)
+ res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe")
+ assert res.dtype == np.promote_types("d", string_dt)
+
+ @pytest.mark.parametrize("axis", [None, 0])
+ def test_string_dtype_does_not_inspect(self, axis):
+ # The error here currently depends on NPY_USE_NEW_CASTINGIMPL as
+ # the new version rejects using the "default string length" of 64.
+ # The new behaviour is better, `np.array()` and `arr.astype()` would
+ # have to be used instead. (currently only raises due to unsafe cast)
+ with pytest.raises((ValueError, TypeError)):
+ np.concatenate(([None], [1]), dtype="S", axis=axis)
+ with pytest.raises((ValueError, TypeError)):
+ np.concatenate(([None], [1]), dtype="U", axis=axis)
+
+ @pytest.mark.parametrize("axis", [None, 0])
+ def test_subarray_error(self, axis):
+ with pytest.raises(TypeError, match=".*subarray dtype"):
+ np.concatenate(([1], [1]), dtype="(2,)i", axis=axis)
+
def test_stack():
# non-iterable input