summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authoraarchiba <peridot.faceted@gmail.com>2008-04-07 18:12:09 +0000
committeraarchiba <peridot.faceted@gmail.com>2008-04-07 18:12:09 +0000
commita5574c3dfef52b36e50dc060c46fbb491f2b2aa2 (patch)
treea1cc000f3ca6412f36577b5d61fca3891e322f26 /numpy/ma
parent99cbf12fd2b9c2c91b1e64bcc1cef19fcf120f6f (diff)
downloadnumpy-a5574c3dfef52b36e50dc060c46fbb491f2b2aa2.tar.gz
Fix maskedarray's std and var of complex arrays, with test. Add test for ddof.
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py7
-rw-r--r--numpy/ma/tests/test_core.py54
2 files changed, 59 insertions, 2 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index e2352b944..50fc87f2a 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -68,7 +68,7 @@ import numpy.core.umath as umath
import numpy.core.fromnumeric as fromnumeric
import numpy.core.numeric as numeric
import numpy.core.numerictypes as ntypes
-from numpy import bool_, dtype, typecodes, amax, amin, ndarray
+from numpy import bool_, dtype, typecodes, amax, amin, ndarray, iscomplexobj
from numpy import expand_dims as n_expand_dims
from numpy import array as narray
import warnings
@@ -2180,7 +2180,10 @@ masked_%(name)s(data = %(data)s,
else:
cnt = self.count(axis=axis)-ddof
danom = self.anom(axis=axis, dtype=dtype)
- danom *= danom
+ if iscomplexobj(self):
+ danom = umath.absolute(danom)**2
+ else:
+ danom *= danom
dvar = narray(danom.sum(axis) / cnt).view(type(self))
if axis is not None:
dvar._mask = mask_or(self._mask.all(axis), (cnt==1))
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 5bed298be..e1d4048a8 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1006,6 +1006,10 @@ class TestArrayMethods(NumpyTestCase):
(x,X,XX,m,mx,mX,mXX,m2x,m2X,m2XX) = self.d
assert_almost_equal(mX.var(axis=None),mX.compressed().var())
assert_almost_equal(mX.std(axis=None),mX.compressed().std())
+ assert_almost_equal(mX.std(axis=None,ddof=1),
+ mX.compressed().std(ddof=1))
+ assert_almost_equal(mX.var(axis=None,ddof=1),
+ mX.compressed().var(ddof=1))
assert_equal(mXX.var(axis=3).shape,XX.var(axis=3).shape)
assert_equal(mX.var().shape,X.var().shape)
(mXvar0,mXvar1) = (mX.var(axis=0), mX.var(axis=1))
@@ -1453,6 +1457,56 @@ class TestArrayMethods(NumpyTestCase):
assert_equal(b.shape, a.shape)
assert_equal(b.fill_value, a.fill_value)
+class TestArrayMethodsComplex(NumpyTestCase):
+ "Test class for miscellaneous MaskedArrays methods."
+ def setUp(self):
+ "Base data definition."
+ x = numpy.array([ 8.375j, 7.545j, 8.828j, 8.5j , 1.757j, 5.928,
+ 8.43 , 7.78 , 9.865, 5.878, 8.979, 4.732,
+ 3.012, 6.022, 5.095, 3.116, 5.238, 3.957,
+ 6.04 , 9.63 , 7.712, 3.382, 4.489, 6.479j,
+ 7.189j, 9.645, 5.395, 4.961, 9.894, 2.893,
+ 7.357, 9.828, 6.272, 3.758, 6.693, 0.993j])
+ X = x.reshape(6,6)
+ XX = x.reshape(3,2,2,3)
+
+ m = numpy.array([0, 1, 0, 1, 0, 0,
+ 1, 0, 1, 1, 0, 1,
+ 0, 0, 0, 1, 0, 1,
+ 0, 0, 0, 1, 1, 1,
+ 1, 0, 0, 1, 0, 0,
+ 0, 0, 1, 0, 1, 0])
+ mx = array(data=x,mask=m)
+ mX = array(data=X,mask=m.reshape(X.shape))
+ mXX = array(data=XX,mask=m.reshape(XX.shape))
+
+ m2 = numpy.array([1, 1, 0, 1, 0, 0,
+ 1, 1, 1, 1, 0, 1,
+ 0, 0, 1, 1, 0, 1,
+ 0, 0, 0, 1, 1, 1,
+ 1, 0, 0, 1, 1, 0,
+ 0, 0, 1, 0, 1, 1])
+ m2x = array(data=x,mask=m2)
+ m2X = array(data=X,mask=m2.reshape(X.shape))
+ m2XX = array(data=XX,mask=m2.reshape(XX.shape))
+ self.d = (x,X,XX,m,mx,mX,mXX,m2x,m2X,m2XX)
+
+ #------------------------------------------------------
+ def test_varstd(self):
+ "Tests var & std on MaskedArrays."
+ (x,X,XX,m,mx,mX,mXX,m2x,m2X,m2XX) = self.d
+ assert_almost_equal(mX.var(axis=None),mX.compressed().var())
+ assert_almost_equal(mX.std(axis=None),mX.compressed().std())
+ assert_equal(mXX.var(axis=3).shape,XX.var(axis=3).shape)
+ assert_equal(mX.var().shape,X.var().shape)
+ (mXvar0,mXvar1) = (mX.var(axis=0), mX.var(axis=1))
+ assert_almost_equal(mX.var(axis=None,ddof=2),mX.compressed().var(ddof=2))
+ assert_almost_equal(mX.std(axis=None,ddof=2),mX.compressed().std(ddof=2))
+ for k in range(6):
+ assert_almost_equal(mXvar1[k],mX[k].compressed().var())
+ assert_almost_equal(mXvar0[k],mX[:,k].compressed().var())
+ assert_almost_equal(numpy.sqrt(mXvar0[k]), mX[:,k].compressed().std())
+
#..............................................................................
class TestMiscFunctions(NumpyTestCase):