diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2016-08-03 15:02:39 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-03 15:02:39 -0500 |
commit | 0e8d3bb76aa5d854942b584bc6b508c3be225e01 (patch) | |
tree | de5bc6867287a990e3fb627f6f8871ebbf670ff7 /numpy | |
parent | 7606c7794be6d346fe56e0f06734cb2ca2039082 (diff) | |
parent | 8f847006642a02636c7a3c6c2f54d3446908606d (diff) | |
download | numpy-0e8d3bb76aa5d854942b584bc6b508c3be225e01.tar.gz |
Merge pull request #7894 from charris/fixup-7790
fixup-7790, BUG: construct ma.array from np.array which contains padding
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 65 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 6 |
2 files changed, 56 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 29b818c06..37836a0a0 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -371,24 +371,61 @@ def maximum_fill_value(obj): raise TypeError(errmsg) -def _recursive_set_default_fill_value(dtypedescr): +def _recursive_set_default_fill_value(dt): + """ + Create the default fill value for a structured dtype. + + Parameters + ---------- + dt: dtype + The structured dtype for which to create the fill value. + + Returns + ------- + val: tuple + A tuple of values corresponding to the default structured fill value. + + """ deflist = [] - for currentdescr in dtypedescr: - currenttype = currentdescr[1] - if isinstance(currenttype, list): + for name in dt.names: + currenttype = dt[name] + if currenttype.subdtype: + currenttype = currenttype.subdtype[0] + + if currenttype.names: deflist.append( tuple(_recursive_set_default_fill_value(currenttype))) else: - deflist.append(default_fill_value(np.dtype(currenttype))) + deflist.append(default_fill_value(currenttype)) return tuple(deflist) -def _recursive_set_fill_value(fillvalue, dtypedescr): - fillvalue = np.resize(fillvalue, len(dtypedescr)) +def _recursive_set_fill_value(fillvalue, dt): + """ + Create a fill value for a structured dtype. + + Parameters + ---------- + fillvalue: scalar or array_like + Scalar or array representing the fill value. If it is of shorter + length than the number of fields in dt, it will be resized. + dt: dtype + The structured dtype for which to create the fill value. + + Returns + ------- + val: tuple + A tuple of values corresponding to the structured fill value. + + """ + fillvalue = np.resize(fillvalue, len(dt.names)) output_value = [] - for (fval, descr) in zip(fillvalue, dtypedescr): - cdtype = descr[1] - if isinstance(cdtype, list): + for (fval, name) in zip(fillvalue, dt.names): + cdtype = dt[name] + if cdtype.subdtype: + cdtype = cdtype.subdtype[0] + + if cdtype.names: output_value.append(tuple(_recursive_set_fill_value(fval, cdtype))) else: output_value.append(np.array(fval, dtype=cdtype).item()) @@ -411,9 +448,8 @@ def _check_fill_value(fill_value, ndtype): fields = ndtype.fields if fill_value is None: if fields: - descr = ndtype.descr - fill_value = np.array(_recursive_set_default_fill_value(descr), - dtype=ndtype,) + fill_value = np.array(_recursive_set_default_fill_value(ndtype), + dtype=ndtype) else: fill_value = default_fill_value(ndtype) elif fields: @@ -425,9 +461,8 @@ def _check_fill_value(fill_value, ndtype): err_msg = "Unable to transform %s to dtype %s" raise ValueError(err_msg % (fill_value, fdtype)) else: - descr = ndtype.descr fill_value = np.asarray(fill_value, dtype=object) - fill_value = np.array(_recursive_set_fill_value(fill_value, descr), + fill_value = np.array(_recursive_set_fill_value(fill_value, ndtype), dtype=ndtype) else: if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'): diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 5c7ae4356..b3965000d 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -212,6 +212,12 @@ class TestMaskedArray(TestCase): assert_equal(data, [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]) self.assertTrue(data.mask is nomask) + def test_creation_from_ndarray_with_padding(self): + x = np.array([('A', 0)], dtype={'names':['f0','f1'], + 'formats':['S4','i8'], + 'offsets':[0,8]}) + data = array(x) # used to fail due to 'V' padding field in x.dtype.descr + def test_asarray(self): (x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d xm.fill_value = -9999 |