diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2014-01-10 11:36:50 -0800 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2014-01-10 11:36:50 -0800 |
commit | d400b0422a016cf9ca75cc8e9375ea3c5324e88a (patch) | |
tree | 87fb69446f24699037400005baccce9b9b986806 | |
parent | d1dbf8e796ab6bcdc4f3b71252f3921ab2a62269 (diff) | |
parent | b1b0ea8030cad32d6fce2e6e6b5068e54bd6b7a7 (diff) | |
download | numpy-d400b0422a016cf9ca75cc8e9375ea3c5324e88a.tar.gz |
Merge pull request #4183 from charris/gh-4099
ENH: Remove unnecessary broadcasting notation restrictions in einsum.
-rw-r--r-- | doc/release/1.9.0-notes.rst | 7
-rw-r--r-- | numpy/add_newdocs.py | 14
-rw-r--r-- | numpy/core/src/multiarray/einsum.c.src | 193
-rw-r--r-- | numpy/core/tests/test_einsum.py | 39
4 files changed, 91 insertions, 162 deletions
diff --git a/doc/release/1.9.0-notes.rst b/doc/release/1.9.0-notes.rst index aaa2415f2..45d50fc05 100644 --- a/doc/release/1.9.0-notes.rst +++ b/doc/release/1.9.0-notes.rst @@ -79,7 +79,12 @@ The `out` argument to `np.argmin` and `np.argmax` and their equivalent C-API functions is now checked to match the desired output shape exactly. If the check fails a `ValueError` instead of `TypeError` is raised. - +Einsum +~~~~~~ +Remove unnecessary broadcasting notation restrictions. +np.einsum('ijk,j->ijk', A, B) can also be written as +np.einsum('ij...,j->ij...', A, B) (ellipsis is no longer required on 'j') + C-API ~~~~~ diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 62ca6dca3..62e8898c9 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -2078,6 +2078,8 @@ add_newdoc('numpy.core', 'einsum', array([ 30, 80, 130, 180, 230]) >>> np.dot(a, b) array([ 30, 80, 130, 180, 230]) + >>> np.einsum('...j,j', a, b) + array([ 30, 80, 130, 180, 230]) >>> np.einsum('ji', c) array([[0, 3], @@ -2147,6 +2149,18 @@ add_newdoc('numpy.core', 'einsum', [ 4796., 5162.], [ 4928., 5306.]]) + >>> a = np.arange(6).reshape((3,2)) + >>> b = np.arange(12).reshape((4,3)) + >>> np.einsum('ki,jk->ij', a, b) + array([[10, 28, 46, 64], + [13, 40, 67, 94]]) + >>> np.einsum('ki,...k->i...', a, b) + array([[10, 28, 46, 64], + [13, 40, 67, 94]]) + >>> np.einsum('k...,jk', a, b) + array([[10, 28, 46, 64], + [13, 40, 67, 94]]) + """) add_newdoc('numpy.core', 'alterdot', diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 7a94c9305..d143bd626 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -64,13 +64,6 @@ #endif /**********************************************/ -typedef enum { - BROADCAST_NONE, - BROADCAST_LEFT, - BROADCAST_RIGHT, - BROADCAST_MIDDLE -} EINSUM_BROADCAST; - /**begin repeat * #name = byte, short, int, long, longlong, * ubyte, ushort, uint, ulong, ulonglong, @@ -1802,8 +1795,7 
@@ parse_operand_subscripts(char *subscripts, int length, char *out_label_counts, int *out_min_label, int *out_max_label, - int *out_num_labels, - EINSUM_BROADCAST *out_broadcast) + int *out_num_labels) { int i, idim, ndim_left, label; int left_labels = 0, right_labels = 0, ellipsis = 0; @@ -1947,19 +1939,6 @@ parse_operand_subscripts(char *subscripts, int length, } } - if (!ellipsis) { - *out_broadcast = BROADCAST_NONE; - } - else if (left_labels && right_labels) { - *out_broadcast = BROADCAST_MIDDLE; - } - else if (!left_labels) { - *out_broadcast = BROADCAST_RIGHT; - } - else { - *out_broadcast = BROADCAST_LEFT; - } - return 1; } @@ -1972,8 +1951,7 @@ static int parse_output_subscripts(char *subscripts, int length, int ndim_broadcast, const char *label_counts, - char *out_labels, - EINSUM_BROADCAST *out_broadcast) + char *out_labels) { int i, nlabels, label, idim, ndim, ndim_left; int left_labels = 0, right_labels = 0, ellipsis = 0; @@ -2097,19 +2075,6 @@ parse_output_subscripts(char *subscripts, int length, out_labels[idim++] = 0; } - if (!ellipsis) { - *out_broadcast = BROADCAST_NONE; - } - else if (left_labels && right_labels) { - *out_broadcast = BROADCAST_MIDDLE; - } - else if (!left_labels) { - *out_broadcast = BROADCAST_RIGHT; - } - else { - *out_broadcast = BROADCAST_LEFT; - } - return ndim; } @@ -2328,136 +2293,44 @@ get_combined_dims_view(PyArrayObject *op, int iop, char *labels) static int prepare_op_axes(int ndim, int iop, char *labels, int *axes, - int ndim_iter, char *iter_labels, EINSUM_BROADCAST broadcast) + int ndim_iter, char *iter_labels) { int i, label, ibroadcast; - /* Regular broadcasting */ - if (broadcast == BROADCAST_RIGHT) { - /* broadcast dimensions get placed in rightmost position */ - ibroadcast = ndim-1; - for (i = ndim_iter-1; i >= 0; --i) { - label = iter_labels[i]; - /* - * If it's an unlabeled broadcast dimension, choose - * the next broadcast dimension from the operand. 
- */ - if (label == 0) { - while (ibroadcast >= 0 && labels[ibroadcast] != 0) { - --ibroadcast; - } - /* - * If we used up all the operand broadcast dimensions, - * extend it with a "newaxis" - */ - if (ibroadcast < 0) { - axes[i] = -1; - } - /* Otherwise map to the broadcast axis */ - else { - axes[i] = ibroadcast; - --ibroadcast; - } - } - /* It's a labeled dimension, find the matching one */ - else { - char *match = memchr(labels, label, ndim); - /* If the op doesn't have the label, broadcast it */ - if (match == NULL) { - axes[i] = -1; - } - /* Otherwise use it */ - else { - axes[i] = match - labels; - } + ibroadcast = ndim-1; + for (i = ndim_iter-1; i >= 0; --i) { + label = iter_labels[i]; + /* + * If it's an unlabeled broadcast dimension, choose + * the next broadcast dimension from the operand. + */ + if (label == 0) { + while (ibroadcast >= 0 && labels[ibroadcast] != 0) { + --ibroadcast; } - } - } - /* Reverse broadcasting */ - else if (broadcast == BROADCAST_LEFT) { - /* broadcast dimensions get placed in leftmost position */ - ibroadcast = 0; - for (i = 0; i < ndim_iter; ++i) { - label = iter_labels[i]; /* - * If it's an unlabeled broadcast dimension, choose - * the next broadcast dimension from the operand. 
+ * If we used up all the operand broadcast dimensions, + * extend it with a "newaxis" */ - if (label == 0) { - while (ibroadcast < ndim && labels[ibroadcast] != 0) { - ++ibroadcast; - } - /* - * If we used up all the operand broadcast dimensions, - * extend it with a "newaxis" - */ - if (ibroadcast >= ndim) { - axes[i] = -1; - } - /* Otherwise map to the broadcast axis */ - else { - axes[i] = ibroadcast; - ++ibroadcast; - } + if (ibroadcast < 0) { + axes[i] = -1; } - /* It's a labeled dimension, find the matching one */ + /* Otherwise map to the broadcast axis */ else { - char *match = memchr(labels, label, ndim); - /* If the op doesn't have the label, broadcast it */ - if (match == NULL) { - axes[i] = -1; - } - /* Otherwise use it */ - else { - axes[i] = match - labels; - } + axes[i] = ibroadcast; + --ibroadcast; } } - } - /* Middle or None broadcasting */ - else { - /* broadcast dimensions get placed in leftmost position */ - ibroadcast = 0; - for (i = 0; i < ndim_iter; ++i) { - label = iter_labels[i]; - /* - * If it's an unlabeled broadcast dimension, choose - * the next broadcast dimension from the operand. 
- */ - if (label == 0) { - while (ibroadcast < ndim && labels[ibroadcast] != 0) { - ++ibroadcast; - } - /* - * If we used up all the operand broadcast dimensions, - * it's an error - */ - if (ibroadcast >= ndim) { - PyErr_Format(PyExc_ValueError, - "operand %d did not have enough dimensions " - "to match the broadcasting, and couldn't be " - "extended because einstein sum subscripts " - "were specified at both the start and end", - iop); - return 0; - } - /* Otherwise map to the broadcast axis */ - else { - axes[i] = ibroadcast; - ++ibroadcast; - } + /* It's a labeled dimension, find the matching one */ + else { + char *match = memchr(labels, label, ndim); + /* If the op doesn't have the label, broadcast it */ + if (match == NULL) { + axes[i] = -1; } - /* It's a labeled dimension, find the matching one */ + /* Otherwise use it */ else { - char *match = memchr(labels, label, ndim); - /* If the op doesn't have the label, broadcast it */ - if (match == NULL) { - axes[i] = -1; - } - /* Otherwise use it */ - else { - axes[i] = match - labels; - } + axes[i] = match - labels; } } } @@ -2737,7 +2610,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, char output_labels[NPY_MAXDIMS], *iter_labels; int idim, ndim_output, ndim_broadcast, ndim_iter; - EINSUM_BROADCAST broadcast[NPY_MAXARGS]; PyArrayObject *op[NPY_MAXARGS], *ret = NULL; PyArray_Descr *op_dtypes_array[NPY_MAXARGS], **op_dtypes; @@ -2783,8 +2655,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, if (!parse_operand_subscripts(subscripts, length, PyArray_NDIM(op_in[iop]), iop, op_labels[iop], label_counts, - &min_label, &max_label, &num_labels, - &broadcast[iop])) { + &min_label, &max_label, &num_labels)) { return NULL; } @@ -2845,7 +2716,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, /* Parse the output subscript string */ ndim_output = parse_output_subscripts(outsubscripts, length, ndim_broadcast, label_counts, - output_labels, &broadcast[nop]); + output_labels); } else { if (subscripts[0] != 
'-' || subscripts[1] != '>') { @@ -2859,7 +2730,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, /* Parse the output subscript string */ ndim_output = parse_output_subscripts(subscripts, strlen(subscripts), ndim_broadcast, label_counts, - output_labels, &broadcast[nop]); + output_labels); } if (ndim_output < 0) { return NULL; @@ -2961,7 +2832,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop, op_axes[iop] = op_axes_arrays[iop]; if (!prepare_op_axes(PyArray_NDIM(op[iop]), iop, op_labels[iop], - op_axes[iop], ndim_iter, iter_labels, broadcast[iop])) { + op_axes[iop], ndim_iter, iter_labels)) { goto fail; } } diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py index 31f94bf07..7d6a9b30c 100644 --- a/numpy/core/tests/test_einsum.py +++ b/numpy/core/tests/test_einsum.py @@ -498,5 +498,44 @@ class TestEinSum(TestCase): [[[1, 3], [3, 9], [5, 15], [7, 21]], [[8, 16], [16, 32], [24, 48], [32, 64]]]) + def test_einsum_broadcast(self): + # Issue #2455 change in handling ellipsis + # remove the 'middle broadcast' error + # only use the 'RIGHT' iteration in prepare_op_axes + # adds auto broadcast on left where it belongs + # broadcast on right has to be explicit + + A = np.arange(2*3*4).reshape(2,3,4) + B = np.arange(3) + ref = np.einsum('ijk,j->ijk',A, B) + assert_equal(np.einsum('ij...,j...->ij...',A, B), ref) + assert_equal(np.einsum('ij...,...j->ij...',A, B), ref) + assert_equal(np.einsum('ij...,j->ij...',A, B), ref) # used to raise error + + A = np.arange(12).reshape((4,3)) + B = np.arange(6).reshape((3,2)) + ref = np.einsum('ik,kj->ij', A, B) + assert_equal(np.einsum('ik...,k...->i...', A, B), ref) + assert_equal(np.einsum('ik...,...kj->i...j', A, B), ref) + assert_equal(np.einsum('...k,kj', A, B), ref) # used to raise error + assert_equal(np.einsum('ik,k...->i...', A, B), ref) # used to raise error + + dims=[2,3,4,5]; + a = np.arange(np.prod(dims)).reshape(dims) + v = np.arange(dims[2]) + ref = np.einsum('ijkl,k->ijl', a, v) + 
assert_equal(np.einsum('ijkl,k', a, v), ref) + assert_equal(np.einsum('...kl,k', a, v), ref) # used to raise error + assert_equal(np.einsum('...kl,k...', a, v), ref) + # no real diff from 1st + + J,K,M=160,160,120; + A=np.arange(J*K*M).reshape(1,1,1,J,K,M) + B=np.arange(J*K*M*3).reshape(J,K,M,3) + ref = np.einsum('...lmn,...lmno->...o', A, B) + assert_equal(np.einsum('...lmn,lmno->...o', A, B), ref) # used to raise error + + + if __name__ == "__main__": run_module_suite() |