summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-12-23 23:43:43 +0000
committerpierregm <pierregm@localhost>2008-12-23 23:43:43 +0000
commit8e77ab879c8cb1f25f0245c74bc9bba402c40b51 (patch)
tree78a00afed9fd6742657bb6e1f444b21ec203d216 /numpy/ma/core.py
parented916ff07cf2f1fbb6056cdc01fe3ae82a027ed5 (diff)
downloadnumpy-8e77ab879c8cb1f25f0245c74bc9bba402c40b51.tar.gz
testutils:
* assert_equal : use assert_equal_array on records * assert_array_compare : prevent the common mask to be back-propagated to the initial input arrays. * assert_equal_array : use operator.__eq__ instead of ma.equal * assert_equal_less: use operator.__less__ instead of ma.less core: * Fixed _check_fill_value for nested flexible types * Add a ndtype option to _make_mask_descr * Fixed mask_or for nested flexible types * Fixed the printing of masked arrays w/ flexible types.
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py115
1 files changed, 78 insertions, 37 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 6b4dc98e6..8ee19778c 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -217,6 +217,28 @@ def maximum_fill_value(obj):
raise TypeError(errmsg)
+def _recursive_set_default_fill_value(dtypedescr):
+ deflist = []
+ for currentdescr in dtypedescr:
+ currenttype = currentdescr[1]
+ if isinstance(currenttype, list):
+ deflist.append(tuple(_recursive_set_default_fill_value(currenttype)))
+ else:
+ deflist.append(default_fill_value(np.dtype(currenttype)))
+ return tuple(deflist)
+
+def _recursive_set_fill_value(fillvalue, dtypedescr):
+ fillvalue = np.resize(fillvalue, len(dtypedescr))
+ output_value = []
+ for (fval, descr) in zip(fillvalue, dtypedescr):
+ cdtype = descr[1]
+ if isinstance(cdtype, list):
+ output_value.append(tuple(_recursive_set_fill_value(fval, cdtype)))
+ else:
+ output_value.append(np.array(fval, dtype=cdtype).item())
+ return tuple(output_value)
+
+
def _check_fill_value(fill_value, ndtype):
"""
Private function validating the given `fill_value` for the given dtype.
@@ -233,10 +255,9 @@ def _check_fill_value(fill_value, ndtype):
fields = ndtype.fields
if fill_value is None:
if fields:
- fdtype = [(_[0], _[1]) for _ in ndtype.descr]
- fill_value = np.array(tuple([default_fill_value(fields[n][0])
- for n in ndtype.names]),
- dtype=fdtype)
+ descr = ndtype.descr
+ fill_value = np.array(_recursive_set_default_fill_value(descr),
+ dtype=ndtype)
else:
fill_value = default_fill_value(ndtype)
elif fields:
@@ -248,10 +269,9 @@ def _check_fill_value(fill_value, ndtype):
err_msg = "Unable to transform %s to dtype %s"
raise ValueError(err_msg % (fill_value, fdtype))
else:
- fval = np.resize(fill_value, len(ndtype.descr))
- fill_value = [np.asarray(f).astype(desc[1]).item()
- for (f, desc) in zip(fval, ndtype.descr)]
- fill_value = np.array(tuple(fill_value), copy=False, dtype=fdtype)
+ descr = ndtype.descr
+ fill_value = np.array(_recursive_set_fill_value(fill_value, descr),
+ dtype=ndtype)
else:
if isinstance(fill_value, basestring) and (ndtype.char not in 'SV'):
fill_value = default_fill_value(ndtype)
@@ -831,35 +851,35 @@ mod = _DomainedBinaryOperation(umath.mod, _DomainSafeDivide(), 0, 1)
#####--------------------------------------------------------------------------
#---- --- Mask creation functions ---
#####--------------------------------------------------------------------------
+def _recursive_make_descr(datatype, newtype=bool_):
+ "Private function allowing recursion in make_descr."
+ # Do we have some name fields ?
+ if datatype.names:
+ descr = []
+ for name in datatype.names:
+ field = datatype.fields[name]
+ if len(field) == 3:
+ # Prepend the title to the name
+ name = (field[-1], name)
+ descr.append((name, _recursive_make_descr(field[0], newtype)))
+ return descr
+ # Is this some kind of composite a la (np.float,2)
+ elif datatype.subdtype:
+ mdescr = list(datatype.subdtype)
+ mdescr[0] = newtype
+ return tuple(mdescr)
+ else:
+ return newtype
def make_mask_descr(ndtype):
"""Constructs a dtype description list from a given dtype.
Each field is set to a bool.
"""
- def _make_descr(datatype):
- "Private function allowing recursion."
- # Do we have some name fields ?
- if datatype.names:
- descr = []
- for name in datatype.names:
- field = datatype.fields[name]
- if len(field) == 3:
- # Prepend the title to the name
- name = (field[-1], name)
- descr.append((name, _make_descr(field[0])))
- return descr
- # Is this some kind of composite a la (np.float,2)
- elif datatype.subdtype:
- mdescr = list(datatype.subdtype)
- mdescr[0] = np.dtype(bool)
- return tuple(mdescr)
- else:
- return np.bool
# Make sure we do have a dtype
if not isinstance(ndtype, np.dtype):
ndtype = np.dtype(ndtype)
- return np.dtype(_make_descr(ndtype))
+ return np.dtype(_recursive_make_descr(ndtype, np.bool))
def get_mask(a):
"""Return the mask of a, if any, or nomask.
@@ -988,7 +1008,17 @@ def mask_or (m1, m2, copy=False, shrink=True):
ValueError
If m1 and m2 have different flexible dtypes.
- """
+ """
+ def _recursive_mask_or(m1, m2, newmask):
+ names = m1.dtype.names
+ for name in names:
+ current1 = m1[name]
+ if current1.dtype.names:
+ _recursive_mask_or(current1, m2[name], newmask[name])
+ else:
+ umath.logical_or(current1, m2[name], newmask[name])
+ return
+ #
if (m1 is nomask) or (m1 is False):
dtype = getattr(m2, 'dtype', MaskType)
return make_mask(m2, copy=copy, shrink=shrink, dtype=dtype)
@@ -1002,8 +1032,7 @@ def mask_or (m1, m2, copy=False, shrink=True):
raise ValueError("Incompatible dtypes '%s'<>'%s'" % (dtype1, dtype2))
if dtype1.names:
newmask = np.empty_like(m1)
- for n in dtype1.names:
- newmask[n] = umath.logical_or(m1[n], m2[n])
+ _recursive_mask_or(m1, m2, newmask)
return newmask
return make_mask(umath.logical_or(m1, m2), copy=copy, shrink=shrink)
@@ -1291,6 +1320,22 @@ class _MaskedPrintOption:
#if you single index into a masked location you get this object.
masked_print_option = _MaskedPrintOption('--')
+
+def _recursive_printoption(result, mask, printopt):
+ """
+ Puts printoptions in result where mask is True.
+ Private function allowing for recursion
+ """
+ names = result.dtype.names
+ for name in names:
+ (curdata, curmask) = (result[name], mask[name])
+ if curdata.dtype.names:
+ _recursive_printoption(curdata, curmask, printopt)
+ else:
+ np.putmask(curdata, curmask, printopt)
+ return
+
+
#####--------------------------------------------------------------------------
#---- --- MaskedArray class ---
#####--------------------------------------------------------------------------
@@ -2184,13 +2229,9 @@ class MaskedArray(ndarray):
res = self._data.astype("|O8")
res[m] = f
else:
- rdtype = [list(_) for _ in self.dtype.descr]
- for r in rdtype:
- r[1] = '|O8'
- rdtype = [tuple(_) for _ in rdtype]
+ rdtype = _recursive_make_descr(self.dtype, "|O8")
res = self._data.astype(rdtype)
- for field in names:
- np.putmask(res[field], m[field], f)
+ _recursive_printoption(res, m, f)
else:
res = self.filled(self.fill_value)
return str(res)