summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/private/ufunc_override.h221
-rw-r--r--numpy/core/src/umath/ufunc_object.c38
-rw-r--r--numpy/core/tests/test_umath.py129
3 files changed, 320 insertions, 68 deletions
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index ce2836834..6b0f73fcf 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -3,8 +3,154 @@
#include <npy_config.h>
#include "numpy/arrayobject.h"
#include "common.h"
+#include <string.h>
#include "numpy/ufuncobject.h"
+static void
+normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
+ PyObject **normal_args, PyObject **normal_kwds,
+ int nin)
+{
+ /* ufunc.__call__(*args, **kwds) */
+ int nargs = PyTuple_GET_SIZE(args);
+ PyObject *obj;
+
+ *normal_args = PyTuple_GetSlice(args, 0, nin);
+
+ /* If we have more args than nin, they must be the output variables.*/
+ if (nargs > nin) {
+ if ((nargs - nin) == 1) {
+ obj = PyTuple_GET_ITEM(args, nargs - 1);
+ PyDict_SetItemString(*normal_kwds, "out", obj);
+ }
+ else {
+ obj = PyTuple_GetSlice(args, nin, nargs);
+ PyDict_SetItemString(*normal_kwds, "out", obj);
+ }
+ }
+}
+
+static void
+normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args,
+ PyObject **normal_args, PyObject **normal_kwds)
+{
+ /* ufunc.reduce(a[, axis, dtype, out, keepdims]) */
+ int nargs = PyTuple_GET_SIZE(args);
+ int i;
+ PyObject *obj;
+
+ for (i = 0; i < nargs; i++) {
+ obj = PyTuple_GET_ITEM(args, i);
+ if (i == 0) {
+ *normal_args = PyTuple_GetSlice(args, 0, 1);
+ }
+ else if (i == 1) {
+ /* axis */
+ PyDict_SetItemString(*normal_kwds, "axis", obj);
+ }
+ else if (i == 2) {
+ /* dtype */
+ PyDict_SetItemString(*normal_kwds, "dtype", obj);
+ }
+ else if (i == 3) {
+ /* out */
+ PyDict_SetItemString(*normal_kwds, "out", obj);
+ }
+ else {
+ /* keepdims */
+ PyDict_SetItemString(*normal_kwds, "keepdims", obj);
+ }
+ }
+ return;
+}
+
+static void
+normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args,
+ PyObject **normal_args, PyObject **normal_kwds)
+{
+ /* ufunc.accumulate(a[, axis, dtype, out]) */
+ int nargs = PyTuple_GET_SIZE(args);
+ int i;
+ PyObject *obj;
+
+ for (i = 0; i < nargs; i++) {
+ obj = PyTuple_GET_ITEM(args, i);
+ if (i == 0) {
+ *normal_args = PyTuple_GetSlice(args, 0, 1);
+ }
+ else if (i == 1) {
+ /* axis */
+ PyDict_SetItemString(*normal_kwds, "axis", obj);
+ }
+ else if (i == 2) {
+ /* dtype */
+ PyDict_SetItemString(*normal_kwds, "dtype", obj);
+ }
+ else {
+ /* out */
+ PyDict_SetItemString(*normal_kwds, "out", obj);
+ }
+ }
+ return;
+}
+
+static void
+normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args,
+ PyObject **normal_args, PyObject **normal_kwds)
+{
+ /* ufunc.reduceat(a, indicies[, axis, dtype, out]) */
+ int i;
+ int nargs = PyTuple_GET_SIZE(args);
+ PyObject *obj;
+
+ for (i = 0; i < nargs; i++) {
+ obj = PyTuple_GET_ITEM(args, i);
+ if (i == 0) {
+ /* a and indicies */
+ *normal_args = PyTuple_GetSlice(args, 0, 2);
+ }
+ else if (i == 1) {
+ /* Handled above, when i == 0. */
+ continue;
+ }
+ else if (i == 2) {
+ /* axis */
+ PyDict_SetItemString(*normal_kwds, "axis", obj);
+ }
+ else if (i == 3) {
+ /* dtype */
+ PyDict_SetItemString(*normal_kwds, "dtype", obj);
+ }
+ else {
+ /* out */
+ PyDict_SetItemString(*normal_kwds, "out", obj);
+ }
+ }
+ return;
+}
+
+static void
+normalize_outer_args(PyUFuncObject *ufunc, PyObject *args,
+ PyObject **normal_args, PyObject **normal_kwds)
+{
+ /* ufunc.outer(A, B)
+ * This has no kwds so we don't need to do any kwd stuff.
+ */
+ *normal_args = PyTuple_GetSlice(args, 0, 2);
+ return;
+}
+
+static void
+normalize_at_args(PyUFuncObject *ufunc, PyObject *args,
+ PyObject **normal_args, PyObject **normal_kwds)
+{
+ /* ufunc.at(a, indices[, b]) */
+ int nargs = PyTuple_GET_SIZE(args);
+
+ *normal_args = PyTuple_GetSlice(args, 0, nargs);
+ return;
+}
+
/*
* Check a set of args for the `__numpy_ufunc__` method. If more than one of
* the input arguments implements `__numpy_ufunc__`, they are tried in the
@@ -18,7 +164,7 @@
*/
static int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *args, PyObject *kwds,
+ PyObject *args, PyObject *kwds,
PyObject **result,
int nin)
{
@@ -36,23 +182,23 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */
PyObject *normal_kwds = NULL;
- PyObject *with_override[NPY_MAXARGS];
+ PyObject *with_override[NPY_MAXARGS];
/* Pos of each override in args */
int with_override_pos[NPY_MAXARGS];
- /*
+ /*
* Check inputs
*/
if (!PyTuple_Check(args)) {
- PyErr_SetString(PyExc_ValueError,
+ PyErr_SetString(PyExc_ValueError,
"Internal Numpy error: call to PyUFunc_CheckOverride "
"with non-tuple");
goto fail;
}
if (PyTuple_GET_SIZE(args) > NPY_MAXARGS) {
- PyErr_SetString(PyExc_ValueError,
+ PyErr_SetString(PyExc_ValueError,
"Internal Numpy error: too many arguments in call "
"to PyUFunc_CheckOverride");
goto fail;
@@ -81,14 +227,15 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
return 0;
}
- /*
- * Normalize ufunc arguments.
- */
- normal_args = PyTuple_GetSlice(args, 0, nin);
- if (normal_args == NULL) {
+ method_name = PyUString_FromString(method);
+ if (method_name == NULL) {
goto fail;
}
+ /*
+ * Normalize ufunc arguments.
+ */
+
/* Build new kwds */
if (kwds && PyDict_CheckExact(kwds)) {
normal_kwds = PyDict_Copy(kwds);
@@ -100,21 +247,38 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
goto fail;
}
- /* If we have more args than nin, they must be the output variables.*/
- if (nargs > nin) {
- if ((nargs - nin) == 1) {
- obj = PyTuple_GET_ITEM(args, nargs - 1);
- PyDict_SetItemString(normal_kwds, "out", obj);
- }
- else {
- obj = PyTuple_GetSlice(args, nin, nargs);
- PyDict_SetItemString(normal_kwds, "out", obj);
- Py_DECREF(obj);
- }
+ /* decide what to do based on the method. */
+ /* ufunc.__call__ */
+ if (strcmp(method, "__call__") == 0) {
+ normalize___call___args(ufunc, args, &normal_args, &normal_kwds, nin);
}
- method_name = PyUString_FromString(method);
- if (method_name == NULL) {
+ /* ufunc.reduce */
+ else if (strcmp(method, "reduce") == 0) {
+ normalize_reduce_args(ufunc, args, &normal_args, &normal_kwds);
+ }
+
+ /* ufunc.accumulate */
+ else if (strcmp(method, "accumulate") == 0) {
+ normalize_accumulate_args(ufunc, args, &normal_args, &normal_kwds);
+ }
+
+ /* ufunc.reduceat */
+ else if (strcmp(method, "reduceat") == 0) {
+ normalize_reduceat_args(ufunc, args, &normal_args, &normal_kwds);
+ }
+
+ /* ufunc.outer */
+ else if (strcmp(method, "outer") == 0) {
+ normalize_outer_args(ufunc, args, &normal_args, &normal_kwds);
+ }
+
+ /* ufunc.at */
+ else if (strcmp(method, "at") == 0) {
+ normalize_at_args(ufunc, args, &normal_args, &normal_kwds);
+ }
+
+ if (normal_args == NULL) {
goto fail;
}
@@ -144,7 +308,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
for (j = i + 1; j < noa; j++) {
other_obj = with_override[j];
if (PyObject_Type(other_obj) != PyObject_Type(obj) &&
- PyObject_IsInstance(other_obj,
+ PyObject_IsInstance(other_obj,
PyObject_Type(override_obj))) {
override_obj = NULL;
break;
@@ -161,19 +325,19 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
/* Check if there is a method left to call */
if (!override_obj) {
/* No acceptable override found. */
- PyErr_SetString(PyExc_TypeError,
+ PyErr_SetString(PyExc_TypeError,
"__numpy_ufunc__ not implemented for this type.");
goto fail;
}
/* Call the override */
- numpy_ufunc = PyObject_GetAttrString(override_obj,
+ numpy_ufunc = PyObject_GetAttrString(override_obj,
"__numpy_ufunc__");
if (numpy_ufunc == NULL) {
goto fail;
}
- override_args = Py_BuildValue("OOiO", ufunc, method_name,
+ override_args = Py_BuildValue("OOiO", ufunc, method_name,
override_pos, normal_args);
if (override_args == NULL) {
Py_DECREF(numpy_ufunc);
@@ -181,7 +345,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
}
*result = PyObject_Call(numpy_ufunc, override_args, normal_kwds);
-
+
Py_DECREF(numpy_ufunc);
Py_DECREF(override_args);
@@ -212,5 +376,4 @@ fail:
Py_XDECREF(normal_kwds);
return 1;
}
-
#endif
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 767c68932..705359cb8 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4066,7 +4066,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
mps[i] = NULL;
}
- errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override,
+ errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override,
ufunc->nin);
if (errval) {
return NULL;
@@ -4751,6 +4751,8 @@ static PyObject *
ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
{
int i;
+ int errval;
+ PyObject *override = NULL;
PyObject *ret;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ap_new = NULL;
PyObject *new_args, *tmp;
@@ -4775,6 +4777,15 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
return NULL;
}
+ /* `nin`, the last arg, is unused. So we put 0. */
+ errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override, 0);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
+
tmp = PySequence_GetItem(args, 0);
if (tmp == NULL) {
return NULL;
@@ -4844,8 +4855,8 @@ ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
int errval;
PyObject *override = NULL;
- errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override,
- ufunc->nin);
+ /* `nin`, the last arg, is unused. So we put 0. */
+ errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override, 0);
if (errval) {
return NULL;
}
@@ -4861,8 +4872,8 @@ ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
int errval;
PyObject *override = NULL;
- errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override,
- ufunc->nin);
+ /* `nin`, the last arg, is unused. So we put 0. */
+ errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override, 0);
if (errval) {
return NULL;
}
@@ -4878,8 +4889,8 @@ ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
int errval;
PyObject *override = NULL;
- errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override,
- ufunc->nin);
+ /* `nin`, the last arg, is unused. So we put 0. */
+ errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override, 0);
if (errval) {
return NULL;
}
@@ -4931,6 +4942,10 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
int i;
int nop;
+ /* override vars */
+ int errval;
+ PyObject *override = NULL;
+
NpyIter *iter_buffer;
NpyIter_IterNextFunc *iternext;
npy_uint32 op_flags[NPY_MAXARGS];
@@ -4939,6 +4954,15 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
char * err_msg = NULL;
NPY_BEGIN_THREADS_DEF;
+ /* `nin`, the last arg, is unused. So we put 0. */
+ errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override, 0);
+ if (errval) {
+ return NULL;
+ }
+ else if (override) {
+ return override;
+ }
+
if (ufunc->nin > 2) {
PyErr_SetString(PyExc_ValueError,
"Only unary and binary ufuncs supported at this time");
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 3646fd2a9..0f232e2ff 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1026,39 +1026,105 @@ class TestSpecialMethods(TestCase):
def test_ufunc_override_methods(self):
class A(object):
def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
- if method == "__call__":
- return method
- if method == "reduce":
- return method
- if method == "accumulate":
- return method
- if method == "reduceat":
- return method
+ return self, ufunc, method, pos, inputs, kwargs
+ # __call__
a = A()
- res = np.multiply(1, a)
- assert_equal(res, "__call__")
-
- res = np.multiply.reduce(1, a)
- assert_equal(res, "reduce")
-
- res = np.multiply.accumulate(1, a)
- assert_equal(res, "accumulate")
-
- res = np.multiply.reduceat(1, a)
- assert_equal(res, "reduceat")
-
- res = np.multiply(a, 1)
- assert_equal(res, "__call__")
-
- res = np.multiply.reduce(a, 1)
- assert_equal(res, "reduce")
-
- res = np.multiply.accumulate(a, 1)
- assert_equal(res, "accumulate")
-
- res = np.multiply.reduceat(a, 1)
- assert_equal(res, "reduceat")
+ res = np.multiply.__call__(1, a, foo='bar', answer=42)
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], '__call__')
+ assert_equal(res[3], 1)
+ assert_equal(res[4], (1, a))
+ assert_equal(res[5], {'foo': 'bar', 'answer': 42})
+
+ # reduce, positional args
+ res = np.multiply.reduce(a, 'axis0', 'dtype0', 'out0', 'keep0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'reduce')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a,))
+ assert_equal(res[5], {'dtype':'dtype0',
+ 'out': 'out0',
+ 'keepdims': 'keep0',
+ 'axis': 'axis0'})
+
+ # reduce, kwargs
+ res = np.multiply.reduce(a, axis='axis0', dtype='dtype0', out='out0',
+ keepdims='keep0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'reduce')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a,))
+ assert_equal(res[5], {'dtype':'dtype0',
+ 'out': 'out0',
+ 'keepdims': 'keep0',
+ 'axis': 'axis0'})
+
+ # accumulate, pos args
+ res = np.multiply.accumulate(a, 'axis0', 'dtype0', 'out0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'accumulate')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a,))
+ assert_equal(res[5], {'dtype':'dtype0',
+ 'out': 'out0',
+ 'axis': 'axis0'})
+
+ # accumulate, kwargs
+ res = np.multiply.accumulate(a, axis='axis0', dtype='dtype0',
+ out='out0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'accumulate')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a,))
+ assert_equal(res[5], {'dtype':'dtype0',
+ 'out': 'out0',
+ 'axis': 'axis0'})
+
+ # reduceat, pos args
+ res = np.multiply.reduceat(a, [4, 2], 'axis0', 'dtype0', 'out0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'reduceat')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a, [4, 2]))
+ assert_equal(res[5], {'dtype':'dtype0',
+ 'out': 'out0',
+ 'axis': 'axis0'})
+
+ # reduceat, kwargs
+ res = np.multiply.reduceat(a, [4, 2], axis='axis0', dtype='dtype0',
+ out='out0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'reduceat')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a, [4, 2]))
+ assert_equal(res[5], {'dtype':'dtype0',
+ 'out': 'out0',
+ 'axis': 'axis0'})
+
+ # outer
+ res = np.multiply.outer(a, 42)
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'outer')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a, 42))
+ assert_equal(res[5], {})
+
+ # at
+ res = np.multiply.at(a, [4, 2], 'b0')
+ assert_equal(res[0], a)
+ assert_equal(res[1], np.multiply)
+ assert_equal(res[2], 'at')
+ assert_equal(res[3], 0)
+ assert_equal(res[4], (a, [4, 2], 'b0'))
def test_ufunc_override_out(self):
class A(object):
@@ -1094,7 +1160,6 @@ class TestSpecialMethods(TestCase):
assert_equal(res7['out'][0], 'out0')
assert_equal(res7['out'][1], 'out1')
-
def test_ufunc_override_exception(self):
class A(object):
def __numpy_ufunc__(self, *a, **kwargs):