diff options
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 6 | ||||
-rw-r--r-- | numpy/core/code_generators/multiarray_api_order.txt | 1 | ||||
-rw-r--r-- | numpy/core/include/numpy/arrayobject.h | 2 | ||||
-rw-r--r-- | numpy/core/ma.py | 1 | ||||
-rw-r--r-- | numpy/core/src/arraymethods.c | 16 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 10 | ||||
-rw-r--r-- | numpy/core/src/multiarraymodule.c | 70 | ||||
-rw-r--r-- | numpy/core/src/umathmodule.c.src | 8 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 9 |
9 files changed, 113 insertions, 10 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 727e0e6c4..2332cc914 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -414,6 +414,12 @@ defdict = { TD(flts, f='floor'), TD(M, f='floor'), ), +'rint' : + Ufunc(1, 1, None, + 'round x elementwise to the nearest integer, round halfway cases away from zero', + TD(flts, f='rint'), + TD(M, f='rint'), + ), 'arctan2' : Ufunc(2, 1, None, 'a safe and correct arctan(x1/x2)', diff --git a/numpy/core/code_generators/multiarray_api_order.txt b/numpy/core/code_generators/multiarray_api_order.txt index d328ad7fc..c558b636b 100644 --- a/numpy/core/code_generators/multiarray_api_order.txt +++ b/numpy/core/code_generators/multiarray_api_order.txt @@ -65,4 +65,5 @@ PyArray_ArangeObj PyArray_SortkindConverter PyArray_LexSort PyArray_GetNDArrayCVersion +PyArray_Round diff --git a/numpy/core/include/numpy/arrayobject.h b/numpy/core/include/numpy/arrayobject.h index be0640c5f..e02fd3df0 100644 --- a/numpy/core/include/numpy/arrayobject.h +++ b/numpy/core/include/numpy/arrayobject.h @@ -79,7 +79,7 @@ extern "C" CONFUSE_EMACS #define PY_SUCCEED 1 /* Helpful to distinguish what is installed */ -#define NDARRAY_VERSION 0x00090501 +#define NDARRAY_VERSION 0x00090502 /* Some platforms don't define bool, long long, or long double. Handle that here. diff --git a/numpy/core/ma.py b/numpy/core/ma.py index ae10ca681..b0a9bd5da 100644 --- a/numpy/core/ma.py +++ b/numpy/core/ma.py @@ -2180,6 +2180,7 @@ array.trace = _m(not_implemented) array.transpose = _m(transpose) array.var = _m(not_implemented) array.view = _m(not_implemented) +array.round = _m(around) del _m, MethodType, not_implemented diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c index fe87009e0..559617016 100644 --- a/numpy/core/src/arraymethods.c +++ b/numpy/core/src/arraymethods.c @@ -1451,6 +1451,20 @@ array_ravel(PyArrayObject *self, PyObject *args) return PyArray_Ravel(self, fortran); } +static char doc_round[] = "a.round(decimals=0)"; + +static PyObject * +array_round(PyArrayObject *self, PyObject *args, PyObject *kwds) +{ + int decimals = 0; + static char *kwlist[] = {"decimals", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|i", kwlist, + &decimals)) + return NULL; + + return _ARET(PyArray_Round(self, decimals)); +} static char doc_setflags[] = "a.setflags(write=None, align=None, uic=None)"; @@ -1643,6 +1657,8 @@ static PyMethodDef array_methods[] = { METH_VARARGS, doc_flatten}, {"ravel", (PyCFunction)array_ravel, METH_VARARGS, doc_ravel}, + {"round", (PyCFunction)array_round, + METH_VARARGS|METH_KEYWORDS, doc_round}, {"setflags", (PyCFunction)array_setflags, METH_VARARGS|METH_KEYWORDS, doc_setflags}, {"newbyteorder", (PyCFunction)array_newbyteorder, diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 3d11d42ac..afca2b618 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2334,14 +2334,12 @@ typedef struct { *floor, *ceil, *maximum, - *minimum; + *minimum, + *rint; } NumericOps; -static NumericOps n_ops = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL}; +static NumericOps n_ops; /* NB: static objects inlitialized to zero */ /* Dictionary can contain any of the numeric operations, by name. Those not present will not be changed @@ -2391,6 +2389,7 @@ PyArray_SetNumericOps(PyObject *dict) SET(ceil); SET(maximum); SET(minimum); + SET(rint); return 0; } @@ -2436,6 +2435,7 @@ PyArray_GetNumericOps(void) GET(ceil); GET(maximum); GET(minimum); + GET(rint); return dict; fail: diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c index 265786a71..b5ae008eb 100644 --- a/numpy/core/src/multiarraymodule.c +++ b/numpy/core/src/multiarraymodule.c @@ -195,6 +195,76 @@ PyArray_Ravel(PyArrayObject *a, int fortran) return PyArray_Flatten(a, fortran); } +double +power_of_ten(int n) +{ + static const double p10[] = {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8}; + double ret; + if (n < 9) + ret = p10[n]; + else { + ret = 1e9; + while (n-- > 9) + ret *= 10.; + } + return ret; +} + +/*MULTIARRAY_API + Round +*/ +static PyObject * +PyArray_Round(PyArrayObject *a, int decimals) +{ + /* do the most common case first */ + if (decimals == 0) { + if (PyTypeNum_ISINTEGER(PyArray_TYPE(a))) { + Py_INCREF(a); + return (PyObject *)a; + } + return PyArray_GenericUnaryFunction((PyAO *)a, n_ops.rint); + } + if (decimals > 0) { + PyObject *f, *ret; + if (PyTypeNum_ISINTEGER(PyArray_TYPE(a))) { + Py_INCREF(a); + return (PyObject *)a; + } + f = PyFloat_FromDouble(power_of_ten(decimals)); + ret = PyNumber_Multiply((PyObject *)a, f); + if (PyArray_IsScalar(ret, Generic)) { + /* array scalars cannot be modified inplace */ + PyObject *tmp; + tmp = PyObject_CallFunction(n_ops.rint, "O", ret); + Py_DECREF(ret); + ret = PyObject_CallFunction(n_ops.divide, "OO", tmp, f); + } else { + PyObject_CallFunction(n_ops.rint, "OO", ret, ret); + PyObject_CallFunction(n_ops.divide, "OOO", ret, f, ret); + } + Py_DECREF(f); + return ret; + } else { + /* remaining case: decimals < 0 */ + PyObject *f, *ret; + f = PyFloat_FromDouble(power_of_ten(-decimals)); + ret = PyNumber_Divide((PyObject *)a, f); + if (PyArray_IsScalar(ret, Generic)) { + /* array scalars cannot be modified inplace */ + PyObject *tmp; + tmp = PyObject_CallFunction(n_ops.rint, "O", ret); + Py_DECREF(ret); + ret = PyObject_CallFunction(n_ops.multiply, "OO", tmp, f); + } else { + PyObject_CallFunction(n_ops.rint, "OO", ret, ret); + PyObject_CallFunction(n_ops.multiply, "OOO", ret, f, ret); + } + Py_DECREF(f); + return ret; + } +} + + /*MULTIARRAY_API Flatten */ diff --git a/numpy/core/src/umathmodule.c.src b/numpy/core/src/umathmodule.c.src index 68a848dac..9bff659bf 100644 --- a/numpy/core/src/umathmodule.c.src +++ b/numpy/core/src/umathmodule.c.src @@ -357,10 +357,10 @@ double hypot(double x, double y) /**begin repeat -#kind=(sin,cos,tan,sinh,cosh,tanh,fabs,floor,ceil,sqrt,log10,log,exp,asin,acos,atan)*2# -#typ=longdouble*16, float*16# -#c=l*16,f*16# -#TYPE=LONGDOUBLE*16, FLOAT*16# +#kind=(sin,cos,tan,sinh,cosh,tanh,fabs,floor,ceil,sqrt,log10,log,exp,asin,acos,atan,rint)*2# +#typ=longdouble*17, float*17# +#c=l*17,f*17# +#TYPE=LONGDOUBLE*17, FLOAT*17# */ #ifndef HAVE_@TYPE@_FUNCS #ifdef @kind@@c@ diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 72672bf0d..e026ee5a0 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -191,6 +191,15 @@ class test_bool(ScipyTestCase): self.failUnless(array([True])[0] is a1) self.failUnless(array(True)[...] is a1) + +class test_methods(ScipyTestCase): + def check_test_round(self): + assert_equal(array([1.2,1.5]).round(), [1,2]) + assert_equal(array(1.5).round(), 2) + assert_equal(array([12.2,15.5]).round(-1), [10,20]) + assert_equal(array([12.15,15.51]).round(1), [12.2,15.5]) + + # Import tests from unicode set_local_path() from test_unicode import * |