diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/ma/core.py | 20 | ||||
-rw-r--r-- | numpy/ma/tests/test_core.py | 9 |
2 files changed, 22 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py index f6e445c2f..9cc1a1272 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -2758,13 +2758,19 @@ class MaskedArray(ndarray): _data._sharedmask = True else: # Case 2. : With a mask in input. - # Read the mask with the current mdtype - try: - mask = np.array(mask, copy=copy, dtype=mdtype) - # Or assume it's a sequence of bool/int - except TypeError: - mask = np.array([tuple([m] * len(mdtype)) for m in mask], - dtype=mdtype) + # If mask is boolean, create an array of True or False + if mask is True and mdtype == MaskType: + mask = np.ones(_data.shape, dtype=mdtype) + elif mask is False and mdtype == MaskType: + mask = np.zeros(_data.shape, dtype=mdtype) + else: + # Read the mask with the current mdtype + try: + mask = np.array(mask, copy=copy, dtype=mdtype) + # Or assume it's a sequence of bool/int + except TypeError: + mask = np.array([tuple([m] * len(mdtype)) for m in mask], + dtype=mdtype) # Make sure the mask and the data have the same shape if mask.shape != _data.shape: (nd, nm) = (_data.size, mask.size) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 4e6a20ad9..2fdd00484 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -191,6 +191,15 @@ class TestMaskedArray(TestCase): dma_3 = MaskedArray(dma_1, mask=[1, 0, 0, 0] * 6) fail_if_equal(dma_3.mask, dma_1.mask) + x = array([1, 2, 3], mask=True) + assert_equal(x._mask, [True, True, True]) + x = array([1, 2, 3], mask=False) + assert_equal(x._mask, [False, False, False]) + y = array([1, 2, 3], mask=x._mask, copy=False) + assert_(np.may_share_memory(x.mask, y.mask)) + y = array([1, 2, 3], mask=x._mask, copy=True) + assert_(not np.may_share_memory(x.mask, y.mask)) + def test_creation_with_list_of_maskedarrays(self): # Tests creaating a masked array from alist of masked arrays. x = array(np.arange(5), mask=[1, 0, 0, 0, 0]) |