summaryrefslogtreecommitdiff
path: root/numpy/ma
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2008-03-27 21:56:44 +0000
committerpierregm <pierregm@localhost>2008-03-27 21:56:44 +0000
commit3474b706e27054927ddafb39e88a0b0691c89227 (patch)
tree54fb94942081d6a7ae2762ff4c848b246678252a /numpy/ma
parent738f70ab57a85eae61b0da2045b9d16bd84667e5 (diff)
downloadnumpy-3474b706e27054927ddafb39e88a0b0691c89227.tar.gz
new methods : round
new functions : frombuffer, fromfunction, identity, indices, trace to fix : fromfile/tofile raise a NotImplementedError. For now.
Diffstat (limited to 'numpy/ma')
-rw-r--r--numpy/ma/core.py62
-rw-r--r--numpy/ma/tests/test_core.py15
2 files changed, 74 insertions, 3 deletions
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index c38e887ba..7a29c0a36 100644
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -31,9 +31,10 @@ __all__ = ['MAError', 'MaskType', 'MaskedArray',
'default_fill_value', 'diagonal', 'divide', 'dump', 'dumps',
'empty', 'empty_like', 'equal', 'exp',
'fabs', 'fmod', 'filled', 'floor', 'floor_divide','fix_invalid',
+ 'frombuffer', 'fromfunction',
'getdata','getmask', 'getmaskarray', 'greater', 'greater_equal',
'hypot',
- 'ids', 'inner', 'innerproduct',
+ 'identity', 'ids', 'indices', 'inner', 'innerproduct',
'isMA', 'isMaskedArray', 'is_mask', 'is_masked', 'isarray',
'left_shift', 'less', 'less_equal', 'load', 'loads', 'log', 'log10',
'logical_and', 'logical_not', 'logical_or', 'logical_xor',
@@ -2212,6 +2213,16 @@ masked_%(name)s(data = %(data)s,
if axis is not None or dvar is not masked:
dvar = sqrt(dvar)
return dvar
+
+ #............................................
+ def round(self, decimals=0, out=None):
+ result = self._data.round(decimals).view(type(self))
+ result._mask = self._mask
+ if out is None:
+ return result
+ out[:] = result
+ return
+ round.__doc__ = ndarray.round.__doc__
#............................................
def argsort(self, axis=None, fill_value=None, kind='quicksort',
@@ -2540,6 +2551,10 @@ masked_%(name)s(data = %(data)s,
"""
return self.filled(fill_value).tostring(order=order)
+ #........................
+ def tofile(self, fid, sep="", format="%s"):
+ raise NotImplementedError("Not implemented yet, sorry...")
+
#--------------------------------------------
# Pickling
def __getstate__(self):
@@ -2791,10 +2806,12 @@ product = _frommethod('prod')
ptp = _frommethod('ptp')
ravel = _frommethod('ravel')
repeat = _frommethod('repeat')
+round = _frommethod('round')
std = _frommethod('std')
sum = _frommethod('sum')
swapaxes = _frommethod('swapaxes')
take = _frommethod('take')
+trace = _frommethod('trace')
var = _frommethod('var')
compress = _frommethod('compress')
@@ -3147,7 +3164,7 @@ def round_(a, decimals=0, out=None):
"""
if out is None:
- result = fromnumeric.round_(getdata(a), decimals, out)
+ result = numpy.round_(getdata(a), decimals, out)
if isinstance(a,MaskedArray):
result = result.view(type(a))
result._mask = a._mask
@@ -3155,7 +3172,7 @@ def round_(a, decimals=0, out=None):
result = result.view(MaskedArray)
return result
else:
- fromnumeric.round_(getdata(a), decimals, out)
+ numpy.round_(getdata(a), decimals, out)
if hasattr(out, '_mask'):
out._mask = getmask(a)
return out
@@ -3312,4 +3329,43 @@ def loads(strg):
return cPickle.loads(strg)
################################################################################
+def fromfile(file, dtype=float, count=-1, sep=''):
+ raise NotImplementedError("Not yet implemented. Sorry")
+
+class _convert2ma:
+ """Convert functions from numpy to numpy.ma.
+
+ Parameters
+ ----------
+ _methodname : string
+ Name of the method to transform.
+
+ """
+ __doc__ = None
+ def __init__(self, funcname):
+ self._func = getattr(numpy, funcname)
+ self.__doc__ = self.getdoc()
+ def getdoc(self):
+ "Return the doc of the function (from the doc of the method)."
+ return self._func.__doc__
+ def __call__(self, a, *args, **params):
+ return self._func.__call__(a, *args, **params).view(MaskedArray)
+
+frombuffer = _convert2ma('frombuffer')
+fromfunction = _convert2ma('fromfunction')
+identity = _convert2ma('identity')
+indices = numpy.indices
+
+###############################################################################
+if 1:
+ if 1:
+ a = array([1.23456, 2.34567, 3.45678, 4.56789, 5.67890],
+ mask=[0,1,0,0,0])
+ assert(all(a.round() == array([1., 2., 3., 5., 6.])))
+ assert(all(a.round(1) == array([1.2, 2.3, 3.5, 4.6, 5.7])))
+ assert(all(a.round(3) == array([1.235, 2.346, 3.457, 4.568, 5.679])))
+ b = empty_like(a)
+ a.round(out=b)
+ assert(all(b == array([1., 2., 3., 5., 6.])))
+ print "OK"
diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py
index 5c47f5611..a8f371241 100644
--- a/numpy/ma/tests/test_core.py
+++ b/numpy/ma/tests/test_core.py
@@ -1461,6 +1461,21 @@ class TestMiscFunctions(NumpyTestCase):
y = masked_where(False,x)
assert_equal(y,[1,2])
assert_equal(y[1],2)
+ #
+ def test_round(self):
+ a = array([1.23456, 2.34567, 3.45678, 4.56789, 5.67890],
+ mask=[0,1,0,0,0])
+ assert_equal(a.round(), [1., 2., 3., 5., 6.])
+ assert_equal(a.round(1), [1.2, 2.3, 3.5, 4.6, 5.7])
+ assert_equal(a.round(3), [1.235, 2.346, 3.457, 4.568, 5.679])
+ b = empty_like(a)
+ a.round(out=b)
+ assert_equal(b, [1., 2., 3., 5., 6.])
+ #
+ def test_identity(self):
+ a = identity(5)
+ assert(isinstance(a, MaskedArray))
+ assert_equal(a, numpy.identity(5))
###############################################################################