summaryrefslogtreecommitdiff
path: root/numpy/core/_internal.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r--numpy/core/_internal.py57
1 files changed, 56 insertions, 1 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 8942955f6..9a1787dde 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -10,7 +10,7 @@ import sys
import platform
import warnings
-from .multiarray import dtype, array, ndarray
+from .multiarray import dtype, array, ndarray, promote_types
try:
import ctypes
except ImportError:
@@ -433,6 +433,61 @@ def _copy_fields(ary):
'formats': [dt.fields[name][0] for name in dt.names]}
return array(ary, dtype=copy_dtype, copy=True)
+def _promote_fields(dt1, dt2):
+ """ Perform type promotion for two structured dtypes.
+
+ Parameters
+ ----------
+ dt1 : structured dtype
+ First dtype.
+ dt2 : structured dtype
+ Second dtype.
+
+ Returns
+ -------
+ out : dtype
+ The promoted dtype
+
+ Notes
+ -----
+ If one of the inputs is aligned, the result will be. The titles of
+ both descriptors must match (point to the same field).
+ """
+ # Both must be structured and have the same names in the same order
+ if (dt1.names is None or dt2.names is None) or dt1.names != dt2.names:
+ raise TypeError("invalid type promotion")
+
+ # if both are identical, we can (maybe!) just return the same dtype.
+ identical = dt1 is dt2
+ new_fields = []
+ for name in dt1.names:
+ field1 = dt1.fields[name]
+ field2 = dt2.fields[name]
+ new_descr = promote_types(field1[0], field2[0])
+ identical = identical and new_descr is field1[0]
+
+ # Check that the titles match (if given):
+ if field1[2:] != field2[2:]:
+ raise TypeError("invalid type promotion")
+ if len(field1) == 2:
+ new_fields.append((name, new_descr))
+ else:
+ new_fields.append(((field1[2], name), new_descr))
+
+ res = dtype(new_fields, align=dt1.isalignedstruct or dt2.isalignedstruct)
+
+ # Might as well preserve identity (and metadata) if the dtype is identical
+ # and the itemsize, offsets are also unmodified. This could probably be
+ # sped up, but also probably just be removed entirely.
+ if identical and res.itemsize == dt1.itemsize:
+ for name in dt1.names:
+ if dt1.fields[name][1] != res.fields[name][1]:
+ return res # the dtype changed.
+ return dt1
+
+ return res
+
+
def _getfield_is_safe(oldtype, newtype, offset):
""" Checks safety of getfield for object arrays.