summaryrefslogtreecommitdiff
path: root/numpy/core/oldnumeric.py
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-01-04 21:05:36 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-01-04 21:05:36 +0000
commit490712cd35dcecfc9423de4bde0b29cb012dda25 (patch)
tree56b6ccaac48afc370a189c596d5e9e90ac0254d4 /numpy/core/oldnumeric.py
parent7ff852162596a8eaa02ef87730474285b080d594 (diff)
downloadnumpy-490712cd35dcecfc9423de4bde0b29cb012dda25.tar.gz
More numpy fixes...
Diffstat (limited to 'numpy/core/oldnumeric.py')
-rw-r--r--numpy/core/oldnumeric.py32
1 files changed, 30 insertions, 2 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py
index 9cf87218e..833a2176e 100644
--- a/numpy/core/oldnumeric.py
+++ b/numpy/core/oldnumeric.py
@@ -23,7 +23,8 @@ __all__ = ['asarray', 'array', 'concatenate',
'resize', 'diagonal', 'trace', 'ravel', 'nonzero', 'shape',
'compress', 'clip', 'sum', 'product', 'prod', 'sometrue', 'alltrue',
'any', 'all', 'cumsum', 'cumproduct', 'cumprod', 'ptp', 'ndim',
- 'rank', 'size', 'around', 'mean', 'std', 'var', 'squeeze', 'amax', 'amin'
+ 'rank', 'size', 'around', 'round_', 'mean', 'std', 'var', 'squeeze',
+ 'amax', 'amin'
]
import multiarray as mu
@@ -420,7 +421,34 @@ def size (a, axis=None):
except AttributeError:
return asarray(a).shape[axis]
-from function_base import round_ as around
+def round_(a, decimals=0):
+ """Round 'a' to the given number of decimal places. Rounding
+ behaviour is equivalent to Python.
+
+ Return 'a' if the array is not floating point. Round both the real
+ and imaginary parts separately if the array is complex.
+ """
+ a = asarray(a)
+ if not issubclass(a.dtype, _nx.inexact):
+ return a
+ if issubclass(a.dtype, _nx.complexfloating):
+ return round_(a.real, decimals) + 1j*round_(a.imag, decimals)
+ if decimals is not 0:
+ decimals = asarray(decimals)
+ s = sign(a)
+ if decimals is not 0:
+ a = absolute(multiply(a, 10.**decimals))
+ else:
+ a = absolute(a)
+ rem = a-asarray(a).astype(_nx.intp)
+ a = _nx.where(_nx.less(rem, 0.5), _nx.floor(a), _nx.ceil(a))
+ # convert back
+ if decimals is not 0:
+ return multiply(a, s/(10.**decimals))
+ else:
+ return multiply(a, s)
+
+around = round_
def mean(a, axis=0, dtype=None):
return asarray(a).mean(axis, dtype)