summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2022-05-27 09:05:54 -0700
committerSebastian Berg <sebastian@sipsolutions.net>2022-05-27 09:17:58 -0700
commit8fced79a8c60d86aaaaf997aa861589336f7899c (patch)
treee843b375e29e9d9b06a9af48c4824970d6f45ad5 /numpy
parent7e15fd77ca6c09989f219acf432c98a1036d14c5 (diff)
downloadnumpy-8fced79a8c60d86aaaaf997aa861589336f7899c.tar.gz
MAINT: Fortify masked in-place ops against promotion warnings
These warnings are probably optional in the future. They should not matter much (since the following is an in-place op), but the `np.where` could upcast currently!
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ma/core.py42
1 files changed, 24 insertions, 18 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index ed17b1b22..78333ed02 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -4293,8 +4293,9 @@ class MaskedArray(ndarray):
else:
if m is not nomask:
self._mask += m
- self._data.__iadd__(np.where(self._mask, self.dtype.type(0),
- getdata(other)))
+ other_data = getdata(other)
+ other_data = np.where(self._mask, other_data.dtype.type(0), other_data)
+ self._data.__iadd__(other_data)
return self
def __isub__(self, other):
@@ -4309,8 +4310,9 @@ class MaskedArray(ndarray):
self._mask += m
elif m is not nomask:
self._mask += m
- self._data.__isub__(np.where(self._mask, self.dtype.type(0),
- getdata(other)))
+ other_data = getdata(other)
+ other_data = np.where(self._mask, other_data.dtype.type(0), other_data)
+ self._data.__isub__(other_data)
return self
def __imul__(self, other):
@@ -4325,8 +4327,9 @@ class MaskedArray(ndarray):
self._mask += m
elif m is not nomask:
self._mask += m
- self._data.__imul__(np.where(self._mask, self.dtype.type(1),
- getdata(other)))
+ other_data = getdata(other)
+ other_data = np.where(self._mask, other_data.dtype.type(1), other_data)
+ self._data.__imul__(other_data)
return self
def __idiv__(self, other):
@@ -4338,13 +4341,14 @@ class MaskedArray(ndarray):
dom_mask = _DomainSafeDivide().__call__(self._data, other_data)
other_mask = getmask(other)
new_mask = mask_or(other_mask, dom_mask)
- # The following 3 lines control the domain filling
+ # The following 4 lines control the domain filling
if dom_mask.any():
(_, fval) = ufunc_fills[np.divide]
- other_data = np.where(dom_mask, fval, other_data)
+ other_data = np.where(
+ dom_mask, other_data.dtype.type(fval), other_data)
self._mask |= new_mask
- self._data.__idiv__(np.where(self._mask, self.dtype.type(1),
- other_data))
+ other_data = np.where(self._mask, other_data.dtype.type(1), other_data)
+ self._data.__idiv__(other_data)
return self
def __ifloordiv__(self, other):
@@ -4359,10 +4363,11 @@ class MaskedArray(ndarray):
# The following 3 lines control the domain filling
if dom_mask.any():
(_, fval) = ufunc_fills[np.floor_divide]
- other_data = np.where(dom_mask, fval, other_data)
+ other_data = np.where(
+ dom_mask, other_data.dtype.type(fval), other_data)
self._mask |= new_mask
- self._data.__ifloordiv__(np.where(self._mask, self.dtype.type(1),
- other_data))
+ other_data = np.where(self._mask, other_data.dtype.type(1), other_data)
+ self._data.__ifloordiv__(other_data)
return self
def __itruediv__(self, other):
@@ -4377,10 +4382,11 @@ class MaskedArray(ndarray):
# The following 3 lines control the domain filling
if dom_mask.any():
(_, fval) = ufunc_fills[np.true_divide]
- other_data = np.where(dom_mask, fval, other_data)
+ other_data = np.where(
+ dom_mask, other_data.dtype.type(fval), other_data)
self._mask |= new_mask
- self._data.__itruediv__(np.where(self._mask, self.dtype.type(1),
- other_data))
+ other_data = np.where(self._mask, other_data.dtype.type(1), other_data)
+ self._data.__itruediv__(other_data)
return self
def __ipow__(self, other):
@@ -4389,10 +4395,10 @@ class MaskedArray(ndarray):
"""
other_data = getdata(other)
+ other_data = np.where(self._mask, other_data.dtype.type(1), other_data)
other_mask = getmask(other)
with np.errstate(divide='ignore', invalid='ignore'):
- self._data.__ipow__(np.where(self._mask, self.dtype.type(1),
- other_data))
+ self._data.__ipow__(other_data)
invalid = np.logical_not(np.isfinite(self._data))
if invalid.any():
if self._mask is not nomask: