diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 16 | ||||
-rw-r--r-- | numpy/core/src/ufuncobject.c | 22 |
2 files changed, 30 insertions, 8 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index f3236cdb1..af4f3e6f8 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -395,6 +395,14 @@ def allclose (a, b, rtol=1.e-5, atol=1.e-8): return d.ravel().all() +class ufunc_values_obj(object): + def __init__(self, obj): + self._val_obj = obj + def __del__(self): + umath.seterrobj(self._val_obj) + del self._val_obj + + _errdict = {"ignore":ERR_IGNORE, "warn":ERR_WARN, "raise":ERR_RAISE, @@ -415,9 +423,9 @@ def seterr(divide="ignore", over="ignore", under="ignore", pyvals = umath.geterrobj() old = pyvals[:] - pyvals[1] = maskvalue + pyvals[1] = maskvalue umath.seterrobj(pyvals) - return old + return ufunc_values_obj(old) def geterr(): maskvalue = umath.geterrobj()[1] @@ -441,7 +449,7 @@ def setbufsize(size): old = pyvals[:] pyvals[0] = size umath.seterrobj(pyvals) - return old + return ufunc_values_obj(old) def getbufsize(): return umath.geterrobj()[0] @@ -453,7 +461,7 @@ def seterrcall(func): old = pyvals[:] pyvals[2] = func umath.seterrobj(pyvals) - return old + return ufunc_values_obj(old) def geterrcall(): return umath.geterrobj()[2] diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index ee2314ca1..721ee2f42 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -2768,6 +2768,7 @@ ufunc_geterr(PyObject *dummy, PyObject *args) return res; } /* Construct list of defaults */ + fprintf(stderr, "Nothing found... return defaults.\n"); res = PyList_New(3); if (res == NULL) return NULL; PyList_SET_ITEM(res, 0, PyInt_FromLong(PyArray_BUFSIZE)); @@ -2799,12 +2800,24 @@ ufunc_seterr(PyObject *dummy, PyObject *args) PyObject *thedict; int res; PyObject *val; + static char *msg = "Error object must be a list of length 3"; - if (!PyArg_ParseTuple(args, "O!", &PyList_Type, &val)) return NULL; + if (!PyArg_ParseTuple(args, "O", &val)) return NULL; + + if (!PyList_CheckExact(val)) { + PyObject *new; + new = PyObject_GetAttrString(val, "_val_obj"); + if (new == NULL) { + PyErr_SetString(PyExc_ValueError, msg); + return NULL; + } + val = new; + } + else Py_INCREF(val); - if (PyList_GET_SIZE(val) < 3) { - PyErr_SetString(PyExc_ValueError, - "Error object Must be a list of length 3"); + if (!PyList_CheckExact(val) || PyList_GET_SIZE(val) != 3) { + PyErr_SetString(PyExc_ValueError, msg); + Py_DECREF(val); return NULL; } if (PyUFunc_PYVALS_NAME == NULL) { @@ -2815,6 +2828,7 @@ ufunc_seterr(PyObject *dummy, PyObject *args) thedict = PyEval_GetBuiltins(); } res = PyDict_SetItem(thedict, PyUFunc_PYVALS_NAME, val); + Py_DECREF(val); if (res < 0) return NULL; if (ufunc_update_use_defaults() < 0) return NULL; Py_INCREF(Py_None); |