summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatti Picus <matti.picus@gmail.com>2019-05-15 22:38:32 -0700
committerGitHub <noreply@github.com>2019-05-15 22:38:32 -0700
commit8bf8d28fefba9743f7c2c9a7b2fda827c17fefbc (patch)
tree5e7c1bc6fc0a327a51e686dfaf2f5823ebd0e79f
parentd7e1fcbc2cd9e7a1eea5e04d8fee5909b07b8076 (diff)
parent1be0e6862196ce92f4b8a2257bad2e890c398cc1 (diff)
downloadnumpy-8bf8d28fefba9743f7c2c9a7b2fda827c17fefbc.tar.gz
Merge pull request #13503 from fruchart/matmul-object
ENH: Support object arrays in matmul
-rw-r--r--doc/release/1.17.0-notes.rst9
-rw-r--r--numpy/core/code_generators/generate_umath.py1
-rw-r--r--numpy/core/src/umath/matmul.c.src79
-rw-r--r--numpy/core/src/umath/matmul.h.src2
-rw-r--r--numpy/core/tests/test_multiarray.py54
5 files changed, 134 insertions, 11 deletions
diff --git a/doc/release/1.17.0-notes.rst b/doc/release/1.17.0-notes.rst
index 3f8095440..1d9bb8f1a 100644
--- a/doc/release/1.17.0-notes.rst
+++ b/doc/release/1.17.0-notes.rst
@@ -289,6 +289,15 @@ methods when called on object arrays, making them compatible with
In general, this handles object arrays more gracefully, and avoids floating-
point operations if exact arithmetic types are used.
+Support of object arrays in ``np.matmul``
+-----------------------------------------
+It is now possible to use ``np.matmul`` (or the ``@`` operator) with object arrays.
+For instance, it is now possible to do::
+
+ from fractions import Fraction
+ a = np.array([[Fraction(1, 2), Fraction(1, 3)], [Fraction(1, 3), Fraction(1, 2)]])
+ b = a @ a
+
Changes
=======
diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py
index c58690069..bf1747272 100644
--- a/numpy/core/code_generators/generate_umath.py
+++ b/numpy/core/code_generators/generate_umath.py
@@ -931,6 +931,7 @@ defdict = {
docstrings.get('numpy.core.umath.matmul'),
"PyUFunc_SimpleUniformOperationTypeResolver",
TD(notimes_or_obj),
+ TD(O),
signature='(n?,k),(k,m?)->(n?,m?)',
),
}
diff --git a/numpy/core/src/umath/matmul.c.src b/numpy/core/src/umath/matmul.c.src
index 0cb3c82ad..480c0c72f 100644
--- a/numpy/core/src/umath/matmul.c.src
+++ b/numpy/core/src/umath/matmul.c.src
@@ -267,19 +267,88 @@ NPY_NO_EXPORT void
/**end repeat**/
+
+NPY_NO_EXPORT void
+OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
+ void *_ip2, npy_intp is2_n, npy_intp is2_p,
+ void *_op, npy_intp os_m, npy_intp os_p,
+ npy_intp dm, npy_intp dn, npy_intp dp)
+{
+ char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
+
+ npy_intp ib1_n = is1_n * dn;
+ npy_intp ib2_n = is2_n * dn;
+ npy_intp ib2_p = is2_p * dp;
+ npy_intp ob_p = os_p * dp;
+
+ PyObject *product, *sum_of_products = NULL;
+
+ for (npy_intp m = 0; m < dm; m++) {
+ for (npy_intp p = 0; p < dp; p++) {
+ if ( 0 == dn ) {
+ sum_of_products = PyLong_FromLong(0);
+ if (sum_of_products == NULL) {
+ return;
+ }
+ }
+
+ for (npy_intp n = 0; n < dn; n++) {
+ PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2;
+ if (obj1 == NULL) {
+ obj1 = Py_None;
+ }
+ if (obj2 == NULL) {
+ obj2 = Py_None;
+ }
+
+ product = PyNumber_Multiply(obj1, obj2);
+ if (product == NULL) {
+ Py_XDECREF(sum_of_products);
+ return;
+ }
+
+ if (n == 0) {
+ sum_of_products = product;
+ }
+ else {
+ Py_SETREF(sum_of_products, PyNumber_Add(sum_of_products, product));
+ Py_DECREF(product);
+ if (sum_of_products == NULL) {
+ return;
+ }
+ }
+
+ ip2 += is2_n;
+ ip1 += is1_n;
+ }
+
+ *((PyObject **)op) = sum_of_products;
+ ip1 -= ib1_n;
+ ip2 -= ib2_n;
+ op += os_p;
+ ip2 += is2_p;
+ }
+ op -= ob_p;
+ ip2 -= ib2_p;
+ ip1 += is1_m;
+ op += os_m;
+ }
+}
+
+
/**begin repeat
* #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
* CFLOAT, CDOUBLE, CLONGDOUBLE,
* UBYTE, USHORT, UINT, ULONG, ULONGLONG,
* BYTE, SHORT, INT, LONG, LONGLONG,
- * BOOL#
+ * BOOL, OBJECT#
* #typ = npy_float,npy_double,npy_longdouble, npy_half,
* npy_cfloat, npy_cdouble, npy_clongdouble,
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
- * npy_bool#
- * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*11#
- * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*12#
+ * npy_bool,npy_object#
+ * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*12#
+ * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13#
*/
@@ -398,5 +467,3 @@ NPY_NO_EXPORT void
}
/**end repeat**/
-
-
diff --git a/numpy/core/src/umath/matmul.h.src b/numpy/core/src/umath/matmul.h.src
index 16be7675b..a664b1b4e 100644
--- a/numpy/core/src/umath/matmul.h.src
+++ b/numpy/core/src/umath/matmul.h.src
@@ -3,7 +3,7 @@
* CFLOAT, CDOUBLE, CLONGDOUBLE,
* UBYTE, USHORT, UINT, ULONG, ULONGLONG,
* BYTE, SHORT, INT, LONG, LONGLONG,
- * BOOL#
+ * BOOL, OBJECT#
**/
NPY_NO_EXPORT void
@TYPE@_matmul(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 59c19aa1b..a2dd47c92 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -5756,7 +5756,7 @@ class MatmulCommon(object):
"""
# Should work with these types. Will want to add
# "O" at some point
- types = "?bhilqBHILQefdgFDG"
+ types = "?bhilqBHILQefdgFDGO"
def test_exceptions(self):
dims = [
@@ -5807,8 +5807,9 @@ class MatmulCommon(object):
assert_(res.dtype == dt)
# vector vector returns scalars
- res = self.matmul(v, v)
- assert_(type(res) is np.dtype(dt).type)
+ if dt != "O":
+ res = self.matmul(v, v)
+ assert_(type(res) is np.dtype(dt).type)
def test_scalar_output(self):
vec1 = np.array([2])
@@ -6059,7 +6060,52 @@ class TestMatmul(MatmulCommon):
r3 = np.matmul(args[0].copy(), args[1].copy())
assert_equal(r1, r3)
-
+
+ def test_matmul_object(self):
+ import fractions
+
+ f = np.vectorize(fractions.Fraction)
+ def random_ints():
+ return np.random.randint(1, 1000, size=(10, 3, 3))
+ M1 = f(random_ints(), random_ints())
+ M2 = f(random_ints(), random_ints())
+
+ M3 = self.matmul(M1, M2)
+
+ [N1, N2, N3] = [a.astype(float) for a in [M1, M2, M3]]
+
+ assert_allclose(N3, self.matmul(N1, N2))
+
+ def test_matmul_object_type_scalar(self):
+ from fractions import Fraction as F
+ v = np.array([F(2,3), F(5,7)])
+ res = self.matmul(v, v)
+ assert_(type(res) is F)
+
+ def test_matmul_empty(self):
+ a = np.empty((3, 0), dtype=object)
+ b = np.empty((0, 3), dtype=object)
+ c = np.zeros((3, 3))
+ assert_array_equal(np.matmul(a, b), c)
+
+ def test_matmul_exception_multiply(self):
+ # test that matmul fails if `__mul__` is missing
+ class add_not_multiply():
+ def __add__(self, other):
+ return self
+ a = np.full((3,3), add_not_multiply())
+ with assert_raises(TypeError):
+ b = np.matmul(a, a)
+
+ def test_matmul_exception_add(self):
+ # test that matmul fails if `__add__` is missing
+ class multiply_not_add():
+ def __mul__(self, other):
+ return self
+ a = np.full((3,3), multiply_not_add())
+ with assert_raises(TypeError):
+ b = np.matmul(a, a)
+
if sys.version_info[:2] >= (3, 5):