diff options
-rw-r--r-- | doc/release/1.17.0-notes.rst | 9 | ||||
-rw-r--r-- | numpy/core/code_generators/generate_umath.py | 1 | ||||
-rw-r--r-- | numpy/core/src/umath/matmul.c.src | 79 | ||||
-rw-r--r-- | numpy/core/src/umath/matmul.h.src | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 54 |
5 files changed, 134 insertions, 11 deletions
diff --git a/doc/release/1.17.0-notes.rst b/doc/release/1.17.0-notes.rst index 3f8095440..1d9bb8f1a 100644 --- a/doc/release/1.17.0-notes.rst +++ b/doc/release/1.17.0-notes.rst @@ -289,6 +289,15 @@ methods when called on object arrays, making them compatible with In general, this handles object arrays more gracefully, and avoids floating- point operations if exact arithmetic types are used. +Support of object arrays in ``np.matmul`` +----------------------------------------- +It is now possible to use ``np.matmul`` (or the ``@`` operator) with object arrays. +For instance, it is now possible to do:: + + from fractions import Fraction + a = np.array([[Fraction(1, 2), Fraction(1, 3)], [Fraction(1, 3), Fraction(1, 2)]]) + b = a @ a + Changes ======= diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index c58690069..bf1747272 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -931,6 +931,7 @@ defdict = { docstrings.get('numpy.core.umath.matmul'), "PyUFunc_SimpleUniformOperationTypeResolver", TD(notimes_or_obj), + TD(O), signature='(n?,k),(k,m?)->(n?,m?)', ), } diff --git a/numpy/core/src/umath/matmul.c.src b/numpy/core/src/umath/matmul.c.src index 0cb3c82ad..480c0c72f 100644 --- a/numpy/core/src/umath/matmul.c.src +++ b/numpy/core/src/umath/matmul.c.src @@ -267,19 +267,88 @@ NPY_NO_EXPORT void /**end repeat**/ + +NPY_NO_EXPORT void +OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n, + void *_ip2, npy_intp is2_n, npy_intp is2_p, + void *_op, npy_intp os_m, npy_intp os_p, + npy_intp dm, npy_intp dn, npy_intp dp) +{ + char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op; + + npy_intp ib1_n = is1_n * dn; + npy_intp ib2_n = is2_n * dn; + npy_intp ib2_p = is2_p * dp; + npy_intp ob_p = os_p * dp; + + PyObject *product, *sum_of_products = NULL; + + for (npy_intp m = 0; m < dm; m++) { + for (npy_intp p = 0; p < dp; p++) { + if ( 0 == dn ) { + sum_of_products = PyLong_FromLong(0); + if (sum_of_products == NULL) { + return; + } + } + + for (npy_intp n = 0; n < dn; n++) { + PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2; + if (obj1 == NULL) { + obj1 = Py_None; + } + if (obj2 == NULL) { + obj2 = Py_None; + } + + product = PyNumber_Multiply(obj1, obj2); + if (product == NULL) { + Py_XDECREF(sum_of_products); + return; + } + + if (n == 0) { + sum_of_products = product; + } + else { + Py_SETREF(sum_of_products, PyNumber_Add(sum_of_products, product)); + Py_DECREF(product); + if (sum_of_products == NULL) { + return; + } + } + + ip2 += is2_n; + ip1 += is1_n; + } + + *((PyObject **)op) = sum_of_products; + ip1 -= ib1_n; + ip2 -= ib2_n; + op += os_p; + ip2 += is2_p; + } + op -= ob_p; + ip2 -= ib2_p; + ip1 += is1_m; + op += os_m; + } +} + + /**begin repeat * #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF, * CFLOAT, CDOUBLE, CLONGDOUBLE, * UBYTE, USHORT, UINT, ULONG, ULONGLONG, * BYTE, SHORT, INT, LONG, LONGLONG, - * BOOL# + * BOOL, OBJECT# * #typ = npy_float,npy_double,npy_longdouble, npy_half, * npy_cfloat, npy_cdouble, npy_clongdouble, * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, * npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_bool# - * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*11# - * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*12# + * npy_bool,npy_object# + * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*12# + * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13# */ @@ -398,5 +467,3 @@ NPY_NO_EXPORT void } /**end repeat**/ - - diff --git a/numpy/core/src/umath/matmul.h.src b/numpy/core/src/umath/matmul.h.src index 16be7675b..a664b1b4e 100644 --- a/numpy/core/src/umath/matmul.h.src +++ b/numpy/core/src/umath/matmul.h.src @@ -3,7 +3,7 @@ * CFLOAT, CDOUBLE, CLONGDOUBLE, * UBYTE, USHORT, UINT, ULONG, ULONGLONG, * BYTE, SHORT, INT, LONG, LONGLONG, - * BOOL# + * BOOL, OBJECT# **/ NPY_NO_EXPORT void @TYPE@_matmul(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 59c19aa1b..a2dd47c92 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -5756,7 +5756,7 @@ class MatmulCommon(object): """ # Should work with these types. Will want to add # "O" at some point - types = "?bhilqBHILQefdgFDG" + types = "?bhilqBHILQefdgFDGO" def test_exceptions(self): dims = [ @@ -5807,8 +5807,9 @@ class MatmulCommon(object): assert_(res.dtype == dt) # vector vector returns scalars - res = self.matmul(v, v) - assert_(type(res) is np.dtype(dt).type) + if dt != "O": + res = self.matmul(v, v) + assert_(type(res) is np.dtype(dt).type) def test_scalar_output(self): vec1 = np.array([2]) @@ -6059,7 +6060,52 @@ class TestMatmul(MatmulCommon): r3 = np.matmul(args[0].copy(), args[1].copy()) assert_equal(r1, r3) - + + def test_matmul_object(self): + import fractions + + f = np.vectorize(fractions.Fraction) + def random_ints(): + return np.random.randint(1, 1000, size=(10, 3, 3)) + M1 = f(random_ints(), random_ints()) + M2 = f(random_ints(), random_ints()) + + M3 = self.matmul(M1, M2) + + [N1, N2, N3] = [a.astype(float) for a in [M1, M2, M3]] + + assert_allclose(N3, self.matmul(N1, N2)) + + def test_matmul_object_type_scalar(self): + from fractions import Fraction as F + v = np.array([F(2,3), F(5,7)]) + res = self.matmul(v, v) + assert_(type(res) is F) + + def test_matmul_empty(self): + a = np.empty((3, 0), dtype=object) + b = np.empty((0, 3), dtype=object) + c = np.zeros((3, 3)) + assert_array_equal(np.matmul(a, b), c) + + def test_matmul_exception_multiply(self): + # test that matmul fails if `__mul__` is missing + class add_not_multiply(): + def __add__(self, other): + return self + a = np.full((3,3), add_not_multiply()) + with assert_raises(TypeError): + b = np.matmul(a, a) + + def test_matmul_exception_add(self): + # test that matmul fails if `__add__` is missing + class multiply_not_add(): + def __mul__(self, other): + return self + a = np.full((3,3), multiply_not_add()) + with assert_raises(TypeError): + b = np.matmul(a, a) + if sys.version_info[:2] >= (3, 5): |