summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-04-10 20:55:40 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-04-10 20:56:44 -0400
commita7eef2de430edd3ee066da6194e6a38d58f9eec2 (patch)
tree3db2cdd47dabf579d456b3d38630980b67dc6e0c /numpy
parentdfebb5a9a6664148b32444497bb792ecbb2f56f5 (diff)
downloadnumpy-a7eef2de430edd3ee066da6194e6a38d58f9eec2.tar.gz
Ensure single record items also work correctly with MaskedIterator; tests
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py14
-rw-r--r--numpy/ma/tests/test_core.py30
2 files changed, 37 insertions, 7 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index e969abd5e..3e12d22e4 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -2470,7 +2470,6 @@ class _arraymethod(object):
return result
-
class MaskedIterator(object):
"""
Flat iterator object to iterate over masked arrays.
@@ -2535,9 +2534,10 @@ class MaskedIterator(object):
if self.maskiter is not None:
_mask = self.maskiter.__getitem__(indx)
if isinstance(_mask, ndarray):
- _mask.shape = result.shape
result._mask = _mask
- elif _mask:
+ elif isinstance(_mask, np.void):
+ return mvoid(result, mask=_mask, hardmask=self.ma._hardmask)
+ elif _mask: # Just a scalar, masked
return masked
return result
@@ -2570,8 +2570,12 @@ class MaskedIterator(object):
"""
d = next(self.dataiter)
- if self.maskiter is not None and next(self.maskiter):
- d = masked
+ if self.maskiter is not None:
+ m = next(self.maskiter)
+ if isinstance(m, np.void):
+ return mvoid(d, mask=m, hardmask=self.ma._hardmask)
+ elif m: # Just a scalar, masked
+ return masked
return d
next = __next__
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 177e5670e..65311313b 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1315,14 +1315,40 @@ class TestMaskedArrayAttributes(TestCase):
test.flat = masked_array([3, 2, 1], mask=[1, 0, 0])
control = masked_array(np.matrix([[3, 2, 1]]), mask=[1, 0, 0])
assert_equal(test, control)
- #
+ # Test setting
test = masked_array(np.matrix([[1, 2, 3]]), mask=[0, 0, 1])
testflat = test.flat
testflat[:] = testflat[[2, 1, 0]]
assert_equal(test, control)
testflat[0] = 9
assert_equal(test[0, 0], 9)
-
+ # test 2-D record array
+ # ... on structured array w/ masked records
+ x = array([[(1, 1.1, 'one'), (2, 2.2, 'two'), (3, 3.3, 'thr')],
+ [(4, 4.4, 'fou'), (5, 5.5, 'fiv'), (6, 6.6, 'six')]],
+ dtype=[('a', int), ('b', float), ('c', '|S8')])
+ x['a'][0, 1] = masked
+ x['b'][1, 0] = masked
+ x['c'][0, 2] = masked
+ x[-1, -1] = masked
+ xflat = x.flat
+ assert_equal(xflat[0], x[0, 0])
+ assert_equal(xflat[1], x[0, 1])
+ assert_equal(xflat[2], x[0, 2])
+ assert_equal(xflat[:3], x[0])
+ assert_equal(xflat[3], x[1, 0])
+ assert_equal(xflat[4], x[1, 1])
+ assert_equal(xflat[5], x[1, 2])
+ assert_equal(xflat[3:], x[1])
+ assert_equal(xflat[-1], x[-1, -1])
+ i = 0
+ j = 0
+ for xf in xflat:
+ assert_equal(xf, x[j, i])
+ i += 1
+ if i >= x.shape[-1]:
+ i = 0
+ j += 1
#------------------------------------------------------------------------------
class TestFillingValues(TestCase):