diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-01-04 21:05:36 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-01-04 21:05:36 +0000 |
commit | 490712cd35dcecfc9423de4bde0b29cb012dda25 (patch) | |
tree | 56b6ccaac48afc370a189c596d5e9e90ac0254d4 /numpy/core/oldnumeric.py | |
parent | 7ff852162596a8eaa02ef87730474285b080d594 (diff) | |
download | numpy-490712cd35dcecfc9423de4bde0b29cb012dda25.tar.gz |
More numpy fixes...
Diffstat (limited to 'numpy/core/oldnumeric.py')
-rw-r--r-- | numpy/core/oldnumeric.py | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py index 9cf87218e..833a2176e 100644 --- a/numpy/core/oldnumeric.py +++ b/numpy/core/oldnumeric.py @@ -23,7 +23,8 @@ __all__ = ['asarray', 'array', 'concatenate', 'resize', 'diagonal', 'trace', 'ravel', 'nonzero', 'shape', 'compress', 'clip', 'sum', 'product', 'prod', 'sometrue', 'alltrue', 'any', 'all', 'cumsum', 'cumproduct', 'cumprod', 'ptp', 'ndim', - 'rank', 'size', 'around', 'mean', 'std', 'var', 'squeeze', 'amax', 'amin' + 'rank', 'size', 'around', 'round_', 'mean', 'std', 'var', 'squeeze', + 'amax', 'amin' ] import multiarray as mu @@ -420,7 +421,34 @@ def size (a, axis=None): except AttributeError: return asarray(a).shape[axis] -from function_base import round_ as around +def round_(a, decimals=0): + """Round 'a' to the given number of decimal places. Rounding + behaviour is equivalent to Python. + + Return 'a' if the array is not floating point. Round both the real + and imaginary parts separately if the array is complex. + """ + a = asarray(a) + if not issubclass(a.dtype, _nx.inexact): + return a + if issubclass(a.dtype, _nx.complexfloating): + return round_(a.real, decimals) + 1j*round_(a.imag, decimals) + if decimals is not 0: + decimals = asarray(decimals) + s = sign(a) + if decimals is not 0: + a = absolute(multiply(a, 10.**decimals)) + else: + a = absolute(a) + rem = a-asarray(a).astype(_nx.intp) + a = _nx.where(_nx.less(rem, 0.5), _nx.floor(a), _nx.ceil(a)) + # convert back + if decimals is not 0: + return multiply(a, s/(10.**decimals)) + else: + return multiply(a, s) + +around = round_ def mean(a, axis=0, dtype=None): return asarray(a).mean(axis, dtype) |