diff options
author | pierregm <pierregm@localhost> | 2008-12-23 23:43:43 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2008-12-23 23:43:43 +0000 |
commit | 8e77ab879c8cb1f25f0245c74bc9bba402c40b51 (patch) | |
tree | 78a00afed9fd6742657bb6e1f444b21ec203d216 /numpy/ma/core.py | |
parent | ed916ff07cf2f1fbb6056cdc01fe3ae82a027ed5 (diff) | |
download | numpy-8e77ab879c8cb1f25f0245c74bc9bba402c40b51.tar.gz |
testutils:
* assert_equal : use assert_equal_array on records
* assert_array_compare : prevent the common mask to be back-propagated to the initial input arrays.
* assert_equal_array : use operator.__eq__ instead of ma.equal
* assert_equal_less: use operator.__less__ instead of ma.less
core:
* Fixed _check_fill_value for nested flexible types
* Add a ndtype option to _make_mask_descr
* Fixed mask_or for nested flexible types
* Fixed the printing of masked arrays w/ flexible types.
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r-- | numpy/ma/core.py | 115 |
1 files changed, 78 insertions, 37 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 6b4dc98e6..8ee19778c 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -217,6 +217,28 @@ def maximum_fill_value(obj): raise TypeError(errmsg) +def _recursive_set_default_fill_value(dtypedescr): + deflist = [] + for currentdescr in dtypedescr: + currenttype = currentdescr[1] + if isinstance(currenttype, list): + deflist.append(tuple(_recursive_set_default_fill_value(currenttype))) + else: + deflist.append(default_fill_value(np.dtype(currenttype))) + return tuple(deflist) + +def _recursive_set_fill_value(fillvalue, dtypedescr): + fillvalue = np.resize(fillvalue, len(dtypedescr)) + output_value = [] + for (fval, descr) in zip(fillvalue, dtypedescr): + cdtype = descr[1] + if isinstance(cdtype, list): + output_value.append(tuple(_recursive_set_fill_value(fval, cdtype))) + else: + output_value.append(np.array(fval, dtype=cdtype).item()) + return tuple(output_value) + + def _check_fill_value(fill_value, ndtype): """ Private function validating the given `fill_value` for the given dtype. @@ -233,10 +255,9 @@ def _check_fill_value(fill_value, ndtype): fields = ndtype.fields if fill_value is None: if fields: - fdtype = [(_[0], _[1]) for _ in ndtype.descr] - fill_value = np.array(tuple([default_fill_value(fields[n][0]) - for n in ndtype.names]), - dtype=fdtype) + descr = ndtype.descr + fill_value = np.array(_recursive_set_default_fill_value(descr), + dtype=ndtype) else: fill_value = default_fill_value(ndtype) elif fields: @@ -248,10 +269,9 @@ def _check_fill_value(fill_value, ndtype): err_msg = "Unable to transform %s to dtype %s" raise ValueError(err_msg % (fill_value, fdtype)) else: - fval = np.resize(fill_value, len(ndtype.descr)) - fill_value = [np.asarray(f).astype(desc[1]).item() - for (f, desc) in zip(fval, ndtype.descr)] - fill_value = np.array(tuple(fill_value), copy=False, dtype=fdtype) + descr = ndtype.descr + fill_value = np.array(_recursive_set_fill_value(fill_value, descr), + dtype=ndtype) else: if isinstance(fill_value, basestring) and (ndtype.char not in 'SV'): fill_value = default_fill_value(ndtype) @@ -831,35 +851,35 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1) #####-------------------------------------------------------------------------- #---- --- Mask creation functions --- #####-------------------------------------------------------------------------- +def _recursive_make_descr(datatype, newtype=bool_): + "Private function allowing recursion in make_descr." + # Do we have some name fields ? + if datatype.names: + descr = [] + for name in datatype.names: + field = datatype.fields[name] + if len(field) == 3: + # Prepend the title to the name + name = (field[-1], name) + descr.append((name, _recursive_make_descr(field[0], newtype))) + return descr + # Is this some kind of composite a la (np.float,2) + elif datatype.subdtype: + mdescr = list(datatype.subdtype) + mdescr[0] = newtype + return tuple(mdescr) + else: + return newtype def make_mask_descr(ndtype): """Constructs a dtype description list from a given dtype. Each field is set to a bool. """ - def _make_descr(datatype): - "Private function allowing recursion." - # Do we have some name fields ? - if datatype.names: - descr = [] - for name in datatype.names: - field = datatype.fields[name] - if len(field) == 3: - # Prepend the title to the name - name = (field[-1], name) - descr.append((name, _make_descr(field[0]))) - return descr - # Is this some kind of composite a la (np.float,2) - elif datatype.subdtype: - mdescr = list(datatype.subdtype) - mdescr[0] = np.dtype(bool) - return tuple(mdescr) - else: - return np.bool # Make sure we do have a dtype if not isinstance(ndtype, np.dtype): ndtype = np.dtype(ndtype) - return np.dtype(_make_descr(ndtype)) + return np.dtype(_recursive_make_descr(ndtype, np.bool)) def get_mask(a): """Return the mask of a, if any, or nomask. @@ -988,7 +1008,17 @@ def mask_or (m1, m2, copy=False, shrink=True): ValueError If m1 and m2 have different flexible dtypes. - """ + """ + def _recursive_mask_or(m1, m2, newmask): + names = m1.dtype.names + for name in names: + current1 = m1[name] + if current1.dtype.names: + _recursive_mask_or(current1, m2[name], newmask[name]) + else: + umath.logical_or(current1, m2[name], newmask[name]) + return + # if (m1 is nomask) or (m1 is False): dtype = getattr(m2, 'dtype', MaskType) return make_mask(m2, copy=copy, shrink=shrink, dtype=dtype) @@ -1002,8 +1032,7 @@ def mask_or (m1, m2, copy=False, shrink=True): raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2)) if dtype1.names: newmask = np.empty_like(m1) - for n in dtype1.names: - newmask[n] = umath.logical_or(m1[n], m2[n]) + _recursive_mask_or(m1, m2, newmask) return newmask return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink) @@ -1291,6 +1320,22 @@ class _MaskedPrintOption: #if you single index into a masked location you get this object. masked_print_option = _MaskedPrintOption('--') + +def _recursive_printoption(result, mask, printopt): + """ + Puts printoptions in result where mask is True. + Private function allowing for recursion + """ + names = result.dtype.names + for name in names: + (curdata, curmask) = (result[name], mask[name]) + if curdata.dtype.names: + _recursive_printoption(curdata, curmask, printopt) + else: + np.putmask(curdata, curmask, printopt) + return + + #####-------------------------------------------------------------------------- #---- --- MaskedArray class --- #####-------------------------------------------------------------------------- @@ -2184,13 +2229,9 @@ class MaskedArray(ndarray): res = self._data.astype("|O8") res[m] = f else: - rdtype = [list(_) for _ in self.dtype.descr] - for r in rdtype: - r[1] = '|O8' - rdtype = [tuple(_) for _ in rdtype] + rdtype = _recursive_make_descr(self.dtype, "|O8") res = self._data.astype(rdtype) - for field in names: - np.putmask(res[field], m[field], f) + _recursive_printoption(res, m, f) else: res = self.filled(self.fill_value) return str(res) |