diff options
author | yuki <drsuaimqjgar@gmail.com> | 2023-03-25 00:49:20 +0000 |
---|---|---|
committer | yuki <drsuaimqjgar@gmail.com> | 2023-03-25 01:49:09 +0000 |
commit | 13086bdf5e3cafb64c265f8de475b618a6a0252f (patch) | |
tree | a69d2dddbe70e5eb4d8def09a28a345a918c84d3 /numpy/ma/extras.py | |
parent | 03a87196e9f51d87b7c28d726b795b8fd0547a98 (diff) | |
download | numpy-13086bdf5e3cafb64c265f8de475b618a6a0252f.tar.gz |
MAINT: move `mask_rowcols` to `ma/extras.py`
Diffstat (limited to 'numpy/ma/extras.py')
-rw-r--r-- | numpy/ma/extras.py | 92 |
1 files changed, 90 insertions, 2 deletions
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 4abe2107a..8a6246c36 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -27,8 +27,7 @@ from . import core as ma from .core import ( MaskedArray, MAError, add, array, asarray, concatenate, filled, count, getmask, getmaskarray, make_mask_descr, masked, masked_array, mask_or, - nomask, ones, sort, zeros, getdata, get_masked_subclass, dot, - mask_rowcols + nomask, ones, sort, zeros, getdata, get_masked_subclass, dot ) import numpy as np @@ -955,6 +954,95 @@ def compress_cols(a): return compress_rowcols(a, 1) +def mask_rowcols(a, axis=None): + """ + Mask rows and/or columns of a 2D array that contain masked values. + + Mask whole rows and/or columns of a 2D array that contain + masked values. The masking behavior is selected using the + `axis` parameter. + + - If `axis` is None, rows *and* columns are masked. + - If `axis` is 0, only rows are masked. + - If `axis` is 1 or -1, only columns are masked. + + Parameters + ---------- + a : array_like, MaskedArray + The array to mask. If not a MaskedArray instance (or if no array + elements are masked), the result is a MaskedArray with `mask` set + to `nomask` (False). Must be a 2D array. + axis : int, optional + Axis along which to perform the operation. If None, applies to a + flattened version of the array. + + Returns + ------- + a : MaskedArray + A modified version of the input array, masked depending on the value + of the `axis` parameter. + + Raises + ------ + NotImplementedError + If input array `a` is not 2D. + + See Also + -------- + mask_rows : Mask rows of a 2D array that contain masked values. + mask_cols : Mask cols of a 2D array that contain masked values. + masked_where : Mask where a condition is met. + + Notes + ----- + The input array's mask is modified by this function. + + Examples + -------- + >>> import numpy.ma as ma + >>> a = np.zeros((3, 3), dtype=int) + >>> a[1, 1] = 1 + >>> a + array([[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]) + >>> a = ma.masked_equal(a, 1) + >>> a + masked_array( + data=[[0, 0, 0], + [0, --, 0], + [0, 0, 0]], + mask=[[False, False, False], + [False, True, False], + [False, False, False]], + fill_value=1) + >>> ma.mask_rowcols(a) + masked_array( + data=[[0, --, 0], + [--, --, --], + [0, --, 0]], + mask=[[False, True, False], + [ True, True, True], + [False, True, False]], + fill_value=1) + + """ + a = array(a, subok=False) + if a.ndim != 2: + raise NotImplementedError("mask_rowcols works for 2D arrays only.") + m = getmask(a) + # Nothing is masked: return a + if m is nomask or not m.any(): + return a + maskedval = m.nonzero() + a._mask = a._mask.copy() + if not axis: + a[np.unique(maskedval[0])] = masked + if axis in [None, 1, -1]: + a[:, np.unique(maskedval[1])] = masked + return a + + def mask_rows(a, axis=np._NoValue): """ Mask rows of a 2D array that contain masked values. |