summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/code_generators/generate_umath.py6
-rw-r--r--numpy/core/code_generators/multiarray_api_order.txt1
-rw-r--r--numpy/core/include/numpy/arrayobject.h2
-rw-r--r--numpy/core/ma.py1
-rw-r--r--numpy/core/src/arraymethods.c16
-rw-r--r--numpy/core/src/arrayobject.c10
-rw-r--r--numpy/core/src/multiarraymodule.c70
-rw-r--r--numpy/core/src/umathmodule.c.src8
-rw-r--r--numpy/core/tests/test_multiarray.py9
9 files changed, 113 insertions, 10 deletions
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index 727e0e6c4..2332cc914 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -414,6 +414,12 @@ defdict = {
TD(flts, f='floor'),
TD(M, f='floor'),
),
+'rint' :
+ Ufunc(1, 1, None,
+ 'round x elementwise to the nearest integer, round halfway cases away from zero',
+ TD(flts, f='rint'),
+ TD(M, f='rint'),
+ ),
'arctan2' :
Ufunc(2, 1, None,
'a safe and correct arctan(x1/x2)',
diff --git a/numpy/core/code_generators/multiarray_api_order.txt b/numpy/core/code_generators/multiarray_api_order.txt
index d328ad7fc..c558b636b 100644
--- a/numpy/core/code_generators/multiarray_api_order.txt
+++ b/numpy/core/code_generators/multiarray_api_order.txt
@@ -65,4 +65,5 @@ PyArray_ArangeObj
PyArray_SortkindConverter
PyArray_LexSort
PyArray_GetNDArrayCVersion
+PyArray_Round
diff --git a/numpy/core/include/numpy/arrayobject.h b/numpy/core/include/numpy/arrayobject.h
index be0640c5f..e02fd3df0 100644
--- a/numpy/core/include/numpy/arrayobject.h
+++ b/numpy/core/include/numpy/arrayobject.h
@@ -79,7 +79,7 @@ extern "C" CONFUSE_EMACS
#define PY_SUCCEED 1
/* Helpful to distinguish what is installed */
-#define NDARRAY_VERSION 0x00090501
+#define NDARRAY_VERSION 0x00090502
/* Some platforms don't define bool, long long, or long double.
Handle that here.
diff --git a/numpy/core/ma.py b/numpy/core/ma.py
index ae10ca681..b0a9bd5da 100644
--- a/numpy/core/ma.py
+++ b/numpy/core/ma.py
@@ -2180,6 +2180,7 @@ array.trace = _m(not_implemented)
array.transpose = _m(transpose)
array.var = _m(not_implemented)
array.view = _m(not_implemented)
+array.round = _m(around)
del _m, MethodType, not_implemented
diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c
index fe87009e0..559617016 100644
--- a/numpy/core/src/arraymethods.c
+++ b/numpy/core/src/arraymethods.c
@@ -1451,6 +1451,20 @@ array_ravel(PyArrayObject *self, PyObject *args)
return PyArray_Ravel(self, fortran);
}
+static char doc_round[] = "a.round(decimals=0)";
+
+static PyObject *
+array_round(PyArrayObject *self, PyObject *args, PyObject *kwds)
+{
+ int decimals = 0;
+ static char *kwlist[] = {"decimals", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|i", kwlist,
+ &decimals))
+ return NULL;
+
+ return _ARET(PyArray_Round(self, decimals));
+}
static char doc_setflags[] = "a.setflags(write=None, align=None, uic=None)";
@@ -1643,6 +1657,8 @@ static PyMethodDef array_methods[] = {
METH_VARARGS, doc_flatten},
{"ravel", (PyCFunction)array_ravel,
METH_VARARGS, doc_ravel},
+ {"round", (PyCFunction)array_round,
+ METH_VARARGS|METH_KEYWORDS, doc_round},
{"setflags", (PyCFunction)array_setflags,
METH_VARARGS|METH_KEYWORDS, doc_setflags},
{"newbyteorder", (PyCFunction)array_newbyteorder,
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 3d11d42ac..afca2b618 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2334,14 +2334,12 @@ typedef struct {
*floor,
*ceil,
*maximum,
- *minimum;
+ *minimum,
+ *rint;
} NumericOps;
-static NumericOps n_ops = {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
- NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
- NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
- NULL, NULL, NULL, NULL, NULL};
+static NumericOps n_ops; /* NB: static objects inlitialized to zero */
/* Dictionary can contain any of the numeric operations, by name.
Those not present will not be changed
@@ -2391,6 +2389,7 @@ PyArray_SetNumericOps(PyObject *dict)
SET(ceil);
SET(maximum);
SET(minimum);
+ SET(rint);
return 0;
}
@@ -2436,6 +2435,7 @@ PyArray_GetNumericOps(void)
GET(ceil);
GET(maximum);
GET(minimum);
+ GET(rint);
return dict;
fail:
diff --git a/numpy/core/src/multiarraymodule.c b/numpy/core/src/multiarraymodule.c
index 265786a71..b5ae008eb 100644
--- a/numpy/core/src/multiarraymodule.c
+++ b/numpy/core/src/multiarraymodule.c
@@ -195,6 +195,76 @@ PyArray_Ravel(PyArrayObject *a, int fortran)
return PyArray_Flatten(a, fortran);
}
+double
+power_of_ten(int n)
+{
+ static const double p10[] = {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8};
+ double ret;
+ if (n < 9)
+ ret = p10[n];
+ else {
+ ret = 1e9;
+ while (n-- > 9)
+ ret *= 10.;
+ }
+ return ret;
+}
+
+/*MULTIARRAY_API
+ Round
+*/
+static PyObject *
+PyArray_Round(PyArrayObject *a, int decimals)
+{
+ /* do the most common case first */
+ if (decimals == 0) {
+ if (PyTypeNum_ISINTEGER(PyArray_TYPE(a))) {
+ Py_INCREF(a);
+ return (PyObject *)a;
+ }
+ return PyArray_GenericUnaryFunction((PyAO *)a, n_ops.rint);
+ }
+ if (decimals > 0) {
+ PyObject *f, *ret;
+ if (PyTypeNum_ISINTEGER(PyArray_TYPE(a))) {
+ Py_INCREF(a);
+ return (PyObject *)a;
+ }
+ f = PyFloat_FromDouble(power_of_ten(decimals));
+ ret = PyNumber_Multiply((PyObject *)a, f);
+ if (PyArray_IsScalar(ret, Generic)) {
+ /* array scalars cannot be modified inplace */
+ PyObject *tmp;
+ tmp = PyObject_CallFunction(n_ops.rint, "O", ret);
+ Py_DECREF(ret);
+ ret = PyObject_CallFunction(n_ops.divide, "OO", tmp, f);
+ } else {
+ PyObject_CallFunction(n_ops.rint, "OO", ret, ret);
+ PyObject_CallFunction(n_ops.divide, "OOO", ret, f, ret);
+ }
+ Py_DECREF(f);
+ return ret;
+ } else {
+ /* remaining case: decimals < 0 */
+ PyObject *f, *ret;
+ f = PyFloat_FromDouble(power_of_ten(-decimals));
+ ret = PyNumber_Divide((PyObject *)a, f);
+ if (PyArray_IsScalar(ret, Generic)) {
+ /* array scalars cannot be modified inplace */
+ PyObject *tmp;
+ tmp = PyObject_CallFunction(n_ops.rint, "O", ret);
+ Py_DECREF(ret);
+ ret = PyObject_CallFunction(n_ops.multiply, "OO", tmp, f);
+ } else {
+ PyObject_CallFunction(n_ops.rint, "OO", ret, ret);
+ PyObject_CallFunction(n_ops.multiply, "OOO", ret, f, ret);
+ }
+ Py_DECREF(f);
+ return ret;
+ }
+}
+
+
/*MULTIARRAY_API
Flatten
*/
diff --git a/numpy/core/src/umathmodule.c.src b/numpy/core/src/umathmodule.c.src
index 68a848dac..9bff659bf 100644
--- a/numpy/core/src/umathmodule.c.src
+++ b/numpy/core/src/umathmodule.c.src
@@ -357,10 +357,10 @@ double hypot(double x, double y)
/**begin repeat
-#kind=(sin,cos,tan,sinh,cosh,tanh,fabs,floor,ceil,sqrt,log10,log,exp,asin,acos,atan)*2#
-#typ=longdouble*16, float*16#
-#c=l*16,f*16#
-#TYPE=LONGDOUBLE*16, FLOAT*16#
+#kind=(sin,cos,tan,sinh,cosh,tanh,fabs,floor,ceil,sqrt,log10,log,exp,asin,acos,atan,rint)*2#
+#typ=longdouble*17, float*17#
+#c=l*17,f*17#
+#TYPE=LONGDOUBLE*17, FLOAT*17#
*/
#ifndef HAVE_@TYPE@_FUNCS
#ifdef @kind@@c@
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 72672bf0d..e026ee5a0 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -191,6 +191,15 @@ class test_bool(ScipyTestCase):
self.failUnless(array([True])[0] is a1)
self.failUnless(array(True)[...] is a1)
+
+class test_methods(ScipyTestCase):
+ def check_test_round(self):
+ assert_equal(array([1.2,1.5]).round(), [1,2])
+ assert_equal(array(1.5).round(), 2)
+ assert_equal(array([12.2,15.5]).round(-1), [10,20])
+ assert_equal(array([12.15,15.51]).round(1), [12.2,15.5])
+
+
# Import tests from unicode
set_local_path()
from test_unicode import *