diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2014-04-04 11:47:40 -0400 |
---|---|---|
committer | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2015-04-22 19:06:52 -0400 |
commit | 7a84c5660539bb210746ba6b9b8e38d82d9fd330 (patch) | |
tree | a7d02e08b4a9ae144addeb051845782ff60e6fa8 /numpy/ma/tests/test_subclassing.py | |
parent | 02b858326dac217607a83ed0bf4d7d51d5bfbfbe (diff) | |
download | numpy-7a84c5660539bb210746ba6b9b8e38d82d9fd330.tar.gz |
ENH: Let MaskedArray getter, setter respect baseclass overrides
Diffstat (limited to 'numpy/ma/tests/test_subclassing.py')
-rw-r--r-- | numpy/ma/tests/test_subclassing.py | 93 |
1 files changed, 88 insertions, 5 deletions
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py index ade5c59da..07fc8fdd6 100644 --- a/numpy/ma/tests/test_subclassing.py +++ b/numpy/ma/tests/test_subclassing.py @@ -84,20 +84,71 @@ mmatrix = MMatrix # also a subclass that overrides __str__, __repr__ and __setitem__, disallowing # setting to non-class values (and thus np.ma.core.masked_print_option) +class CSAIterator(object): + """ + Flat iterator object that uses its own setter/getter + (works around ndarray.flat not propagating subclass setters/getters + see https://github.com/numpy/numpy/issues/4564) + roughly following MaskedIterator + """ + def __init__(self, a): + self._original = a + self._dataiter = a.view(np.ndarray).flat + + def __iter__(self): + return self + + def __getitem__(self, indx): + out = self._dataiter.__getitem__(indx) + if not isinstance(out, np.ndarray): + out = out.__array__() + out = out.view(type(self._original)) + return out + + def __setitem__(self, index, value): + self._dataiter[index] = self._original._validate_input(value) + + def __next__(self): + return next(self._dataiter).__array__().view(type(self._original)) + + next = __next__ + + class ComplicatedSubArray(SubArray): + def __str__(self): - return 'myprefix {0} mypostfix'.format( - super(ComplicatedSubArray, self).__str__()) + return 'myprefix {0} mypostfix'.format(self.view(SubArray)) def __repr__(self): # Return a repr that does not start with 'name(' return '<{0} {1}>'.format(self.__class__.__name__, self) - def __setitem__(self, item, value): - # this ensures direct assignment to masked_print_option will fail + def _validate_input(self, value): if not isinstance(value, ComplicatedSubArray): raise ValueError("Can only set to MySubArray values") - super(ComplicatedSubArray, self).__setitem__(item, value) + return value + + def __setitem__(self, item, value): + # validation ensures direct assignment with ndarray or + # masked_print_option will fail + super(ComplicatedSubArray, self).__setitem__( + item, self._validate_input(value)) + + def __getitem__(self, item): + # ensure getter returns our own class also for scalars + value = super(ComplicatedSubArray, self).__getitem__(item) + if not isinstance(value, np.ndarray): # scalar + value = value.__array__().view(ComplicatedSubArray) + return value + + @property + def flat(self): + return CSAIterator(self) + + @flat.setter + def flat(self, value): + y = self.ravel() + y[:] = value class TestSubclassing(TestCase): @@ -205,6 +256,38 @@ class TestSubclassing(TestCase): assert_equal(mxsub.info, xsub.info) assert_equal(mxsub._mask, m) + def test_subclass_items(self): + """test that getter and setter go via baseclass""" + x = np.arange(5) + xcsub = ComplicatedSubArray(x) + mxcsub = masked_array(xcsub, mask=[True, False, True, False, False]) + # getter should return a ComplicatedSubArray, even for single item + # first check we wrote ComplicatedSubArray correctly + self.assertTrue(isinstance(xcsub[1], ComplicatedSubArray)) + self.assertTrue(isinstance(xcsub[1:4], ComplicatedSubArray)) + # now that it propagates inside the MaskedArray + self.assertTrue(isinstance(mxcsub[1], ComplicatedSubArray)) + self.assertTrue(mxcsub[0] is masked) + self.assertTrue(isinstance(mxcsub[1:4].data, ComplicatedSubArray)) + # also for flattened version (which goes via MaskedIterator) + self.assertTrue(isinstance(mxcsub.flat[1].data, ComplicatedSubArray)) + self.assertTrue(mxcsub[0] is masked) + self.assertTrue(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray)) + + # setter should only work with ComplicatedSubArray input + # first check we wrote ComplicatedSubArray correctly + assert_raises(ValueError, xcsub.__setitem__, 1, x[4]) + # now that it propagates inside the MaskedArray + assert_raises(ValueError, mxcsub.__setitem__, 1, x[4]) + assert_raises(ValueError, mxcsub.__setitem__, slice(1, 4), x[1:4]) + mxcsub[1] = xcsub[4] + mxcsub[1:4] = xcsub[1:4] + # also for flattened version (which goes via MaskedIterator) + assert_raises(ValueError, mxcsub.flat.__setitem__, 1, x[4]) + assert_raises(ValueError, mxcsub.flat.__setitem__, slice(1, 4), x[1:4]) + mxcsub.flat[1] = xcsub[4] + mxcsub.flat[1:4] = xcsub[1:4] + def test_subclass_repr(self): """test that repr uses the name of the subclass and 'array' for np.ndarray""" |