diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-05-05 21:11:39 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-05-05 21:46:58 +0100 |
commit | 36e7513edd1114c3f928be66953d4349273122c0 (patch) | |
tree | c0c6d7c854a9958512315ee30f03d5de14ea268b /numpy | |
parent | b2006cb2a7fe508bca8aa7039352731634869334 (diff) | |
download | numpy-36e7513edd1114c3f928be66953d4349273122c0.tar.gz |
BUG: np.ma.mr_['r',...] does not return masked arrays
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/index_tricks.py | 6 | ||||
-rw-r--r-- | numpy/ma/extras.py | 4 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 10 |
3 files changed, 17 insertions, 3 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index cece2b868..dc8eb1c4a 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -15,8 +15,6 @@ from .function_base import diff from numpy.core.multiarray import ravel_multi_index, unravel_index from numpy.lib.stride_tricks import as_strided -makemat = matrixlib.matrix - __all__ = [ 'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_', @@ -238,6 +236,8 @@ class AxisConcatenator(object): """ # allow ma.mr_ to override this concatenate = staticmethod(_nx.concatenate) + makemat = staticmethod(matrixlib.matrix) + def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1): self.axis = axis self.matrix = matrix @@ -341,7 +341,7 @@ class AxisConcatenator(object): if matrix: oldndim = res.ndim - res = makemat(res) + res = self.makemat(res) if oldndim == 1 and col: res = res.T return res diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 10b9634a3..d55e0d1ea 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -1463,6 +1463,10 @@ class MAxisConcatenator(AxisConcatenator): """ concatenate = staticmethod(concatenate) + @staticmethod + def makemat(arr): + return array(arr.data.view(np.matrix), mask=arr.mask) + def __getitem__(self, key): # matrix builder syntax, like 'a, b; c, d' if isinstance(key, str): diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index 7de21ff59..4b7fe07b6 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -308,6 +308,16 @@ class TestConcatenator(TestCase): def test_matrix_builder(self): assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4']) + def test_matrix(self): + actual = mr_['r', 1, 2, 3] + expected = np.ma.array(np.r_['r', 1, 2, 3]) + assert_array_equal(actual, expected) + + # outer type is masked array, inner type is matrix + assert_equal(type(actual), type(expected)) + assert_equal(type(actual.data), type(expected.data)) + + class TestNotMasked(TestCase): # Tests notmasked_edges and notmasked_contiguous. |