diff options
Diffstat (limited to 'numpy/core/_internal.py')
-rw-r--r-- | numpy/core/_internal.py | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 8942955f6..9a1787dde 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -10,7 +10,7 @@ import sys import platform import warnings -from .multiarray import dtype, array, ndarray +from .multiarray import dtype, array, ndarray, promote_types try: import ctypes except ImportError: @@ -433,6 +433,61 @@ def _copy_fields(ary): 'formats': [dt.fields[name][0] for name in dt.names]} return array(ary, dtype=copy_dtype, copy=True) +def _promote_fields(dt1, dt2): + """ Perform type promotion for two structured dtypes. + + Parameters + ---------- + dt1 : structured dtype + First dtype. + dt2 : structured dtype + Second dtype. + + Returns + ------- + out : dtype + The promoted dtype + + Notes + ----- + If one of the inputs is aligned, the result will be. The titles of + both descriptors must match (point to the same field). + """ + # Both must be structured and have the same names in the same order + if (dt1.names is None or dt2.names is None) or dt1.names != dt2.names: + raise TypeError("invalid type promotion") + + # if both are identical, we can (maybe!) just return the same dtype. + identical = dt1 is dt2 + new_fields = [] + for name in dt1.names: + field1 = dt1.fields[name] + field2 = dt2.fields[name] + new_descr = promote_types(field1[0], field2[0]) + identical = identical and new_descr is field1[0] + + # Check that the titles match (if given): + if field1[2:] != field2[2:]: + raise TypeError("invalid type promotion") + if len(field1) == 2: + new_fields.append((name, new_descr)) + else: + new_fields.append(((field1[2], name), new_descr)) + + res = dtype(new_fields, align=dt1.isalignedstruct or dt2.isalignedstruct) + + # Might as well preserve identity (and metadata) if the dtype is identical + # and the itemsize, offsets are also unmodified. This could probably be + # sped up, but also probably just be removed entirely. + if identical and res.itemsize == dt1.itemsize: + for name in dt1.names: + if dt1.fields[name][1] != res.fields[name][1]: + return res # the dtype changed. + return dt1 + + return res + + def _getfield_is_safe(oldtype, newtype, offset): """ Checks safety of getfield for object arrays. |