summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-08-03 15:02:39 -0500
committerGitHub <noreply@github.com>2016-08-03 15:02:39 -0500
commit0e8d3bb76aa5d854942b584bc6b508c3be225e01 (patch)
treede5bc6867287a990e3fb627f6f8871ebbf670ff7
parent7606c7794be6d346fe56e0f06734cb2ca2039082 (diff)
parent8f847006642a02636c7a3c6c2f54d3446908606d (diff)
downloadnumpy-0e8d3bb76aa5d854942b584bc6b508c3be225e01.tar.gz
Merge pull request #7894 from charris/fixup-7790
fixup-7790, BUG: construct ma.array from np.array which contains padding
-rw-r--r--numpy/ma/core.py65
-rw-r--r--numpy/ma/tests/test_core.py6
2 files changed, 56 insertions, 15 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 29b818c06..37836a0a0 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -371,24 +371,61 @@ def maximum_fill_value(obj):
raise TypeError(errmsg)
-def _recursive_set_default_fill_value(dtypedescr):
+def _recursive_set_default_fill_value(dt):
+ """
+ Create the default fill value for a structured dtype.
+
+ Parameters
+ ----------
+ dt: dtype
+ The structured dtype for which to create the fill value.
+
+ Returns
+ -------
+ val: tuple
+ A tuple of values corresponding to the default structured fill value.
+
+ """
deflist = []
- for currentdescr in dtypedescr:
- currenttype = currentdescr[1]
- if isinstance(currenttype, list):
+ for name in dt.names:
+ currenttype = dt[name]
+ if currenttype.subdtype:
+ currenttype = currenttype.subdtype[0]
+
+ if currenttype.names:
deflist.append(
tuple(_recursive_set_default_fill_value(currenttype)))
else:
- deflist.append(default_fill_value(np.dtype(currenttype)))
+ deflist.append(default_fill_value(currenttype))
return tuple(deflist)
-def _recursive_set_fill_value(fillvalue, dtypedescr):
- fillvalue = np.resize(fillvalue, len(dtypedescr))
+def _recursive_set_fill_value(fillvalue, dt):
+ """
+ Create a fill value for a structured dtype.
+
+ Parameters
+ ----------
+ fillvalue: scalar or array_like
+ Scalar or array representing the fill value. If it is of shorter
+ length than the number of fields in dt, it will be resized.
+ dt: dtype
+ The structured dtype for which to create the fill value.
+
+ Returns
+ -------
+ val: tuple
+ A tuple of values corresponding to the structured fill value.
+
+ """
+ fillvalue = np.resize(fillvalue, len(dt.names))
output_value = []
- for (fval, descr) in zip(fillvalue, dtypedescr):
- cdtype = descr[1]
- if isinstance(cdtype, list):
+ for (fval, name) in zip(fillvalue, dt.names):
+ cdtype = dt[name]
+ if cdtype.subdtype:
+ cdtype = cdtype.subdtype[0]
+
+ if cdtype.names:
output_value.append(tuple(_recursive_set_fill_value(fval, cdtype)))
else:
output_value.append(np.array(fval, dtype=cdtype).item())
@@ -411,9 +448,8 @@ def _check_fill_value(fill_value, ndtype):
fields = ndtype.fields
if fill_value is None:
if fields:
- descr = ndtype.descr
- fill_value = np.array(_recursive_set_default_fill_value(descr),
- dtype=ndtype,)
+ fill_value = np.array(_recursive_set_default_fill_value(ndtype),
+ dtype=ndtype)
else:
fill_value = default_fill_value(ndtype)
elif fields:
@@ -425,9 +461,8 @@ def _check_fill_value(fill_value, ndtype):
err_msg = "Unable to transform %s to dtype %s"
raise ValueError(err_msg % (fill_value, fdtype))
else:
- descr = ndtype.descr
fill_value = np.asarray(fill_value, dtype=object)
- fill_value = np.array(_recursive_set_fill_value(fill_value, descr),
+ fill_value = np.array(_recursive_set_fill_value(fill_value, ndtype),
dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'OSVU'):
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 5c7ae4356..b3965000d 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -212,6 +212,12 @@ class TestMaskedArray(TestCase):
assert_equal(data, [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]])
self.assertTrue(data.mask is nomask)
+ def test_creation_from_ndarray_with_padding(self):
+ x = np.array([('A', 0)], dtype={'names':['f0','f1'],
+ 'formats':['S4','i8'],
+ 'offsets':[0,8]})
+ data = array(x) # used to fail due to 'V' padding field in x.dtype.descr
+
def test_asarray(self):
(x, y, a10, m1, m2, xm, ym, z, zm, xf) = self.d
xm.fill_value = -9999