summaryrefslogtreecommitdiff
path: root/numpy/ma/tests/test_subclassing.py
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2014-04-04 11:47:40 -0400
committerMarten van Kerkwijk <mhvk@astro.utoronto.ca>2015-04-22 19:06:52 -0400
commit7a84c5660539bb210746ba6b9b8e38d82d9fd330 (patch)
treea7d02e08b4a9ae144addeb051845782ff60e6fa8 /numpy/ma/tests/test_subclassing.py
parent02b858326dac217607a83ed0bf4d7d51d5bfbfbe (diff)
downloadnumpy-7a84c5660539bb210746ba6b9b8e38d82d9fd330.tar.gz
ENH: Let MaskedArray getter, setter respect baseclass overrides
Diffstat (limited to 'numpy/ma/tests/test_subclassing.py')
-rw-r--r--numpy/ma/tests/test_subclassing.py93
1 files changed, 88 insertions, 5 deletions
diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py
index ade5c59da..07fc8fdd6 100644
--- a/numpy/ma/tests/test_subclassing.py
+++ b/numpy/ma/tests/test_subclassing.py
@@ -84,20 +84,71 @@ mmatrix = MMatrix
# also a subclass that overrides __str__, __repr__ and __setitem__, disallowing
# setting to non-class values (and thus np.ma.core.masked_print_option)
+class CSAIterator(object):
+ """
+ Flat iterator object that uses its own setter/getter
+ (works around ndarray.flat not propagating subclass setters/getters
+ see https://github.com/numpy/numpy/issues/4564)
+ roughly following MaskedIterator
+ """
+ def __init__(self, a):
+ self._original = a
+ self._dataiter = a.view(np.ndarray).flat
+
+ def __iter__(self):
+ return self
+
+ def __getitem__(self, indx):
+ out = self._dataiter.__getitem__(indx)
+ if not isinstance(out, np.ndarray):
+ out = out.__array__()
+ out = out.view(type(self._original))
+ return out
+
+ def __setitem__(self, index, value):
+ self._dataiter[index] = self._original._validate_input(value)
+
+ def __next__(self):
+ return next(self._dataiter).__array__().view(type(self._original))
+
+ next = __next__
+
+
class ComplicatedSubArray(SubArray):
+
def __str__(self):
- return 'myprefix {0} mypostfix'.format(
- super(ComplicatedSubArray, self).__str__())
+ return 'myprefix {0} mypostfix'.format(self.view(SubArray))
def __repr__(self):
# Return a repr that does not start with 'name('
return '<{0} {1}>'.format(self.__class__.__name__, self)
- def __setitem__(self, item, value):
- # this ensures direct assignment to masked_print_option will fail
+ def _validate_input(self, value):
if not isinstance(value, ComplicatedSubArray):
raise ValueError("Can only set to MySubArray values")
- super(ComplicatedSubArray, self).__setitem__(item, value)
+ return value
+
+ def __setitem__(self, item, value):
+ # validation ensures direct assignment with ndarray or
+ # masked_print_option will fail
+ super(ComplicatedSubArray, self).__setitem__(
+ item, self._validate_input(value))
+
+ def __getitem__(self, item):
+ # ensure getter returns our own class also for scalars
+ value = super(ComplicatedSubArray, self).__getitem__(item)
+ if not isinstance(value, np.ndarray): # scalar
+ value = value.__array__().view(ComplicatedSubArray)
+ return value
+
+ @property
+ def flat(self):
+ return CSAIterator(self)
+
+ @flat.setter
+ def flat(self, value):
+ y = self.ravel()
+ y[:] = value
class TestSubclassing(TestCase):
@@ -205,6 +256,38 @@ class TestSubclassing(TestCase):
assert_equal(mxsub.info, xsub.info)
assert_equal(mxsub._mask, m)
+ def test_subclass_items(self):
+ """test that getter and setter go via baseclass"""
+ x = np.arange(5)
+ xcsub = ComplicatedSubArray(x)
+ mxcsub = masked_array(xcsub, mask=[True, False, True, False, False])
+ # getter should return a ComplicatedSubArray, even for single item
+ # first check we wrote ComplicatedSubArray correctly
+ self.assertTrue(isinstance(xcsub[1], ComplicatedSubArray))
+ self.assertTrue(isinstance(xcsub[1:4], ComplicatedSubArray))
+ # now that it propagates inside the MaskedArray
+ self.assertTrue(isinstance(mxcsub[1], ComplicatedSubArray))
+ self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub[1:4].data, ComplicatedSubArray))
+ # also for flattened version (which goes via MaskedIterator)
+ self.assertTrue(isinstance(mxcsub.flat[1].data, ComplicatedSubArray))
+ self.assertTrue(mxcsub[0] is masked)
+ self.assertTrue(isinstance(mxcsub.flat[1:4].base, ComplicatedSubArray))
+
+ # setter should only work with ComplicatedSubArray input
+ # first check we wrote ComplicatedSubArray correctly
+ assert_raises(ValueError, xcsub.__setitem__, 1, x[4])
+ # now that it propagates inside the MaskedArray
+ assert_raises(ValueError, mxcsub.__setitem__, 1, x[4])
+ assert_raises(ValueError, mxcsub.__setitem__, slice(1, 4), x[1:4])
+ mxcsub[1] = xcsub[4]
+ mxcsub[1:4] = xcsub[1:4]
+ # also for flattened version (which goes via MaskedIterator)
+ assert_raises(ValueError, mxcsub.flat.__setitem__, 1, x[4])
+ assert_raises(ValueError, mxcsub.flat.__setitem__, slice(1, 4), x[1:4])
+ mxcsub.flat[1] = xcsub[4]
+ mxcsub.flat[1:4] = xcsub[1:4]
+
def test_subclass_repr(self):
"""test that repr uses the name of the subclass
and 'array' for np.ndarray"""