summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index f6e445c2f..9cc1a1272 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -2758,13 +2758,19 @@ class MaskedArray(ndarray):
_data._sharedmask = True
else:
# Case 2. : With a mask in input.
- # Read the mask with the current mdtype
- try:
- mask = np.array(mask, copy=copy, dtype=mdtype)
- # Or assume it's a sequence of bool/int
- except TypeError:
- mask = np.array([tuple([m] * len(mdtype)) for m in mask],
- dtype=mdtype)
+ # If mask is boolean, create an array of True or False
+ if mask is True and mdtype == MaskType:
+ mask = np.ones(_data.shape, dtype=mdtype)
+ elif mask is False and mdtype == MaskType:
+ mask = np.zeros(_data.shape, dtype=mdtype)
+ else:
+ # Read the mask with the current mdtype
+ try:
+ mask = np.array(mask, copy=copy, dtype=mdtype)
+ # Or assume it's a sequence of bool/int
+ except TypeError:
+ mask = np.array([tuple([m] * len(mdtype)) for m in mask],
+ dtype=mdtype)
# Make sure the mask and the data have the same shape
if mask.shape != _data.shape:
(nd, nm) = (_data.size, mask.size)