summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-05-05 21:11:39 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-05-05 21:46:58 +0100
commit36e7513edd1114c3f928be66953d4349273122c0 (patch)
treec0c6d7c854a9958512315ee30f03d5de14ea268b /numpy
parentb2006cb2a7fe508bca8aa7039352731634869334 (diff)
downloadnumpy-36e7513edd1114c3f928be66953d4349273122c0.tar.gz
BUG: np.ma.mr_['r',...] does not return masked arrays
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/index_tricks.py6
-rw-r--r--numpy/ma/extras.py4
-rw-r--r--numpy/ma/tests/test_extras.py10
3 files changed, 17 insertions, 3 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index cece2b868..dc8eb1c4a 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -15,8 +15,6 @@ from .function_base import diff
from numpy.core.multiarray import ravel_multi_index, unravel_index
from numpy.lib.stride_tricks import as_strided
-makemat = matrixlib.matrix
-
__all__ = [
'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
@@ -238,6 +236,8 @@ class AxisConcatenator(object):
"""
# allow ma.mr_ to override this
concatenate = staticmethod(_nx.concatenate)
+ makemat = staticmethod(matrixlib.matrix)
+
def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
self.axis = axis
self.matrix = matrix
@@ -341,7 +341,7 @@ class AxisConcatenator(object):
if matrix:
oldndim = res.ndim
- res = makemat(res)
+ res = self.makemat(res)
if oldndim == 1 and col:
res = res.T
return res
diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py
index 10b9634a3..d55e0d1ea 100644
--- a/numpy/ma/extras.py
+++ b/numpy/ma/extras.py
@@ -1463,6 +1463,10 @@ class MAxisConcatenator(AxisConcatenator):
"""
concatenate = staticmethod(concatenate)
+ @staticmethod
+ def makemat(arr):
+ return array(arr.data.view(np.matrix), mask=arr.mask)
+
def __getitem__(self, key):
# matrix builder syntax, like 'a, b; c, d'
if isinstance(key, str):
diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py
index 7de21ff59..4b7fe07b6 100644
--- a/numpy/ma/tests/test_extras.py
+++ b/numpy/ma/tests/test_extras.py
@@ -308,6 +308,16 @@ class TestConcatenator(TestCase):
def test_matrix_builder(self):
assert_raises(np.ma.MAError, lambda: mr_['1, 2; 3, 4'])
+ def test_matrix(self):
+ actual = mr_['r', 1, 2, 3]
+ expected = np.ma.array(np.r_['r', 1, 2, 3])
+ assert_array_equal(actual, expected)
+
+ # outer type is masked array, inner type is matrix
+ assert_equal(type(actual), type(expected))
+ assert_equal(type(actual.data), type(expected.data))
+
+
class TestNotMasked(TestCase):
# Tests notmasked_edges and notmasked_contiguous.