summaryrefslogtreecommitdiff
path: root/numpy/ma/core.py
diff options
context:
space:
mode:
authorahaldane <ealloc@gmail.com>2017-02-22 22:40:08 -0500
committerGitHub <noreply@github.com>2017-02-22 22:40:08 -0500
commit081a865d3d243c282dc452c28dbe89e80d8bccb0 (patch)
tree0240a1f56ef11a9a123ee1fbe6570d3bf1499cf1 /numpy/ma/core.py
parentf24da28e7bef4befa51354fff864bb7b5b3e07f9 (diff)
parent0bc42fc2194c4a7cbc8fa77165e25365e3e46012 (diff)
downloadnumpy-081a865d3d243c282dc452c28dbe89e80d8bccb0.tar.gz
Merge pull request #8665 from eric-wieser/better-ma-method-lookup
BUG: Look up methods on MaskedArray in _frommethod
Diffstat (limited to 'numpy/ma/core.py')
-rw-r--r--numpy/ma/core.py31
1 files changed, 14 insertions, 17 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 1b25725d1..30ef5dbfc 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -6372,21 +6372,16 @@ class _frommethod:
def __call__(self, a, *args, **params):
if self.reversed:
args = list(args)
- arr = args[0]
- args[0] = a
- a = arr
- # Get the method from the array (if possible)
+ a, args[0] = args[0], a
+
+ marr = asanyarray(a)
method_name = self.__name__
- method = getattr(a, method_name, None)
- if method is not None:
- return method(*args, **params)
- # Still here ? Then a is not a MaskedArray
- method = getattr(MaskedArray, method_name, None)
- if method is not None:
- return method(MaskedArray(a), *args, **params)
- # Still here ? OK, let's call the corresponding np function
- method = getattr(np, method_name)
- return method(a, *args, **params)
+ method = getattr(type(marr), method_name, None)
+ if method is None:
+ # use the corresponding np function
+ method = getattr(np, method_name)
+
+ return method(marr, *args, **params)
all = _frommethod('all')
@@ -6535,9 +6530,7 @@ def compressed(x):
Equivalent method.
"""
- if not isinstance(x, MaskedArray):
- x = asanyarray(x)
- return x.compressed()
+ return asanyarray(x).compressed()
def concatenate(arrays, axis=0):
@@ -7683,6 +7676,10 @@ def asanyarray(a, dtype=None):
<class 'numpy.ma.core.MaskedArray'>
"""
+ # workaround for #8666, to preserve identity. Ideally the bottom line
+ # would handle this for us.
+ if isinstance(a, MaskedArray) and (dtype is None or dtype == a.dtype):
+ return a
return masked_array(a, dtype=dtype, copy=False, keep_mask=True, subok=True)