summary refs log tree commit diff
diff options
context:
space:
mode:
author Charles Harris <charlesr.harris@gmail.com> 2014-01-10 11:36:50 -0800
committer Charles Harris <charlesr.harris@gmail.com> 2014-01-10 11:36:50 -0800
commit d400b0422a016cf9ca75cc8e9375ea3c5324e88a (patch)
tree 87fb69446f24699037400005baccce9b9b986806
parent d1dbf8e796ab6bcdc4f3b71252f3921ab2a62269 (diff)
parent b1b0ea8030cad32d6fce2e6e6b5068e54bd6b7a7 (diff)
download numpy-d400b0422a016cf9ca75cc8e9375ea3c5324e88a.tar.gz
Merge pull request #4183 from charris/gh-4099
ENH: Remove unnecessary broadcasting notation restrictions in einsum.
-rw-r--r-- doc/release/1.9.0-notes.rst 7
-rw-r--r-- numpy/add_newdocs.py 14
-rw-r--r-- numpy/core/src/multiarray/einsum.c.src 193
-rw-r--r-- numpy/core/tests/test_einsum.py 39
4 files changed, 91 insertions, 162 deletions
diff --git a/doc/release/1.9.0-notes.rst b/doc/release/1.9.0-notes.rst
index aaa2415f2..45d50fc05 100644
--- a/doc/release/1.9.0-notes.rst
+++ b/doc/release/1.9.0-notes.rst
@@ -79,7 +79,12 @@ The `out` argument to `np.argmin` and `np.argmax` and their equivalent
C-API functions is now checked to match the desired output shape exactly.
If the check fails a `ValueError` instead of `TypeError` is raised.
-
+Einsum
+~~~~~~
+Remove unnecessary broadcasting notation restrictions.
+np.einsum('ijk,j->ijk', A, B) can also be written as
+np.einsum('ij...,j->ij...', A, B) (ellipsis is no longer required on 'j')
+
C-API
~~~~~
diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py
index 62ca6dca3..62e8898c9 100644
--- a/numpy/add_newdocs.py
+++ b/numpy/add_newdocs.py
@@ -2078,6 +2078,8 @@ add_newdoc('numpy.core', 'einsum',
array([ 30, 80, 130, 180, 230])
>>> np.dot(a, b)
array([ 30, 80, 130, 180, 230])
+ >>> np.einsum('...j,j', a, b)
+ array([ 30, 80, 130, 180, 230])
>>> np.einsum('ji', c)
array([[0, 3],
@@ -2147,6 +2149,18 @@ add_newdoc('numpy.core', 'einsum',
[ 4796., 5162.],
[ 4928., 5306.]])
+ >>> a = np.arange(6).reshape((3,2))
+ >>> b = np.arange(12).reshape((4,3))
+ >>> np.einsum('ki,jk->ij', a, b)
+ array([[10, 28, 46, 64],
+ [13, 40, 67, 94]])
+ >>> np.einsum('ki,...k->i...', a, b)
+ array([[10, 28, 46, 64],
+ [13, 40, 67, 94]])
+ >>> np.einsum('k...,jk', a, b)
+ array([[10, 28, 46, 64],
+ [13, 40, 67, 94]])
+
""")
add_newdoc('numpy.core', 'alterdot',
diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src
index 7a94c9305..d143bd626 100644
--- a/numpy/core/src/multiarray/einsum.c.src
+++ b/numpy/core/src/multiarray/einsum.c.src
@@ -64,13 +64,6 @@
#endif
/**********************************************/
-typedef enum {
- BROADCAST_NONE,
- BROADCAST_LEFT,
- BROADCAST_RIGHT,
- BROADCAST_MIDDLE
-} EINSUM_BROADCAST;
-
/**begin repeat
* #name = byte, short, int, long, longlong,
* ubyte, ushort, uint, ulong, ulonglong,
@@ -1802,8 +1795,7 @@ parse_operand_subscripts(char *subscripts, int length,
char *out_label_counts,
int *out_min_label,
int *out_max_label,
- int *out_num_labels,
- EINSUM_BROADCAST *out_broadcast)
+ int *out_num_labels)
{
int i, idim, ndim_left, label;
int left_labels = 0, right_labels = 0, ellipsis = 0;
@@ -1947,19 +1939,6 @@ parse_operand_subscripts(char *subscripts, int length,
}
}
- if (!ellipsis) {
- *out_broadcast = BROADCAST_NONE;
- }
- else if (left_labels && right_labels) {
- *out_broadcast = BROADCAST_MIDDLE;
- }
- else if (!left_labels) {
- *out_broadcast = BROADCAST_RIGHT;
- }
- else {
- *out_broadcast = BROADCAST_LEFT;
- }
-
return 1;
}
@@ -1972,8 +1951,7 @@ static int
parse_output_subscripts(char *subscripts, int length,
int ndim_broadcast,
const char *label_counts,
- char *out_labels,
- EINSUM_BROADCAST *out_broadcast)
+ char *out_labels)
{
int i, nlabels, label, idim, ndim, ndim_left;
int left_labels = 0, right_labels = 0, ellipsis = 0;
@@ -2097,19 +2075,6 @@ parse_output_subscripts(char *subscripts, int length,
out_labels[idim++] = 0;
}
- if (!ellipsis) {
- *out_broadcast = BROADCAST_NONE;
- }
- else if (left_labels && right_labels) {
- *out_broadcast = BROADCAST_MIDDLE;
- }
- else if (!left_labels) {
- *out_broadcast = BROADCAST_RIGHT;
- }
- else {
- *out_broadcast = BROADCAST_LEFT;
- }
-
return ndim;
}
@@ -2328,136 +2293,44 @@ get_combined_dims_view(PyArrayObject *op, int iop, char *labels)
static int
prepare_op_axes(int ndim, int iop, char *labels, int *axes,
- int ndim_iter, char *iter_labels, EINSUM_BROADCAST broadcast)
+ int ndim_iter, char *iter_labels)
{
int i, label, ibroadcast;
- /* Regular broadcasting */
- if (broadcast == BROADCAST_RIGHT) {
- /* broadcast dimensions get placed in rightmost position */
- ibroadcast = ndim-1;
- for (i = ndim_iter-1; i >= 0; --i) {
- label = iter_labels[i];
- /*
- * If it's an unlabeled broadcast dimension, choose
- * the next broadcast dimension from the operand.
- */
- if (label == 0) {
- while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
- --ibroadcast;
- }
- /*
- * If we used up all the operand broadcast dimensions,
- * extend it with a "newaxis"
- */
- if (ibroadcast < 0) {
- axes[i] = -1;
- }
- /* Otherwise map to the broadcast axis */
- else {
- axes[i] = ibroadcast;
- --ibroadcast;
- }
- }
- /* It's a labeled dimension, find the matching one */
- else {
- char *match = memchr(labels, label, ndim);
- /* If the op doesn't have the label, broadcast it */
- if (match == NULL) {
- axes[i] = -1;
- }
- /* Otherwise use it */
- else {
- axes[i] = match - labels;
- }
+ ibroadcast = ndim-1;
+ for (i = ndim_iter-1; i >= 0; --i) {
+ label = iter_labels[i];
+ /*
+ * If it's an unlabeled broadcast dimension, choose
+ * the next broadcast dimension from the operand.
+ */
+ if (label == 0) {
+ while (ibroadcast >= 0 && labels[ibroadcast] != 0) {
+ --ibroadcast;
}
- }
- }
- /* Reverse broadcasting */
- else if (broadcast == BROADCAST_LEFT) {
- /* broadcast dimensions get placed in leftmost position */
- ibroadcast = 0;
- for (i = 0; i < ndim_iter; ++i) {
- label = iter_labels[i];
/*
- * If it's an unlabeled broadcast dimension, choose
- * the next broadcast dimension from the operand.
+ * If we used up all the operand broadcast dimensions,
+ * extend it with a "newaxis"
*/
- if (label == 0) {
- while (ibroadcast < ndim && labels[ibroadcast] != 0) {
- ++ibroadcast;
- }
- /*
- * If we used up all the operand broadcast dimensions,
- * extend it with a "newaxis"
- */
- if (ibroadcast >= ndim) {
- axes[i] = -1;
- }
- /* Otherwise map to the broadcast axis */
- else {
- axes[i] = ibroadcast;
- ++ibroadcast;
- }
+ if (ibroadcast < 0) {
+ axes[i] = -1;
}
- /* It's a labeled dimension, find the matching one */
+ /* Otherwise map to the broadcast axis */
else {
- char *match = memchr(labels, label, ndim);
- /* If the op doesn't have the label, broadcast it */
- if (match == NULL) {
- axes[i] = -1;
- }
- /* Otherwise use it */
- else {
- axes[i] = match - labels;
- }
+ axes[i] = ibroadcast;
+ --ibroadcast;
}
}
- }
- /* Middle or None broadcasting */
- else {
- /* broadcast dimensions get placed in leftmost position */
- ibroadcast = 0;
- for (i = 0; i < ndim_iter; ++i) {
- label = iter_labels[i];
- /*
- * If it's an unlabeled broadcast dimension, choose
- * the next broadcast dimension from the operand.
- */
- if (label == 0) {
- while (ibroadcast < ndim && labels[ibroadcast] != 0) {
- ++ibroadcast;
- }
- /*
- * If we used up all the operand broadcast dimensions,
- * it's an error
- */
- if (ibroadcast >= ndim) {
- PyErr_Format(PyExc_ValueError,
- "operand %d did not have enough dimensions "
- "to match the broadcasting, and couldn't be "
- "extended because einstein sum subscripts "
- "were specified at both the start and end",
- iop);
- return 0;
- }
- /* Otherwise map to the broadcast axis */
- else {
- axes[i] = ibroadcast;
- ++ibroadcast;
- }
+ /* It's a labeled dimension, find the matching one */
+ else {
+ char *match = memchr(labels, label, ndim);
+ /* If the op doesn't have the label, broadcast it */
+ if (match == NULL) {
+ axes[i] = -1;
}
- /* It's a labeled dimension, find the matching one */
+ /* Otherwise use it */
else {
- char *match = memchr(labels, label, ndim);
- /* If the op doesn't have the label, broadcast it */
- if (match == NULL) {
- axes[i] = -1;
- }
- /* Otherwise use it */
- else {
- axes[i] = match - labels;
- }
+ axes[i] = match - labels;
}
}
}
@@ -2737,7 +2610,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
char output_labels[NPY_MAXDIMS], *iter_labels;
int idim, ndim_output, ndim_broadcast, ndim_iter;
- EINSUM_BROADCAST broadcast[NPY_MAXARGS];
PyArrayObject *op[NPY_MAXARGS], *ret = NULL;
PyArray_Descr *op_dtypes_array[NPY_MAXARGS], **op_dtypes;
@@ -2783,8 +2655,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
if (!parse_operand_subscripts(subscripts, length,
PyArray_NDIM(op_in[iop]),
iop, op_labels[iop], label_counts,
- &min_label, &max_label, &num_labels,
- &broadcast[iop])) {
+ &min_label, &max_label, &num_labels)) {
return NULL;
}
@@ -2845,7 +2716,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
/* Parse the output subscript string */
ndim_output = parse_output_subscripts(outsubscripts, length,
ndim_broadcast, label_counts,
- output_labels, &broadcast[nop]);
+ output_labels);
}
else {
if (subscripts[0] != '-' || subscripts[1] != '>') {
@@ -2859,7 +2730,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
/* Parse the output subscript string */
ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
ndim_broadcast, label_counts,
- output_labels, &broadcast[nop]);
+ output_labels);
}
if (ndim_output < 0) {
return NULL;
@@ -2961,7 +2832,7 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
op_axes[iop] = op_axes_arrays[iop];
if (!prepare_op_axes(PyArray_NDIM(op[iop]), iop, op_labels[iop],
- op_axes[iop], ndim_iter, iter_labels, broadcast[iop])) {
+ op_axes[iop], ndim_iter, iter_labels)) {
goto fail;
}
}
diff --git a/numpy/core/tests/test_einsum.py b/numpy/core/tests/test_einsum.py
index 31f94bf07..7d6a9b30c 100644
--- a/numpy/core/tests/test_einsum.py
+++ b/numpy/core/tests/test_einsum.py
@@ -498,5 +498,44 @@ class TestEinSum(TestCase):
[[[1, 3], [3, 9], [5, 15], [7, 21]],
[[8, 16], [16, 32], [24, 48], [32, 64]]])
+ def test_einsum_broadcast(self):
+ # Issue #2455 change in handling ellipsis
+ # remove the 'middle broadcast' error
+ # only use the 'RIGHT' iteration in prepare_op_axes
+ # adds auto broadcast on left where it belongs
+ # broadcast on right has to be explicit
+
+ A = np.arange(2*3*4).reshape(2,3,4)
+ B = np.arange(3)
+ ref = np.einsum('ijk,j->ijk',A, B)
+ assert_equal(np.einsum('ij...,j...->ij...',A, B), ref)
+ assert_equal(np.einsum('ij...,...j->ij...',A, B), ref)
+ assert_equal(np.einsum('ij...,j->ij...',A, B), ref) # used to raise error
+
+ A = np.arange(12).reshape((4,3))
+ B = np.arange(6).reshape((3,2))
+ ref = np.einsum('ik,kj->ij', A, B)
+ assert_equal(np.einsum('ik...,k...->i...', A, B), ref)
+ assert_equal(np.einsum('ik...,...kj->i...j', A, B), ref)
+ assert_equal(np.einsum('...k,kj', A, B), ref) # used to raise error
+ assert_equal(np.einsum('ik,k...->i...', A, B), ref) # used to raise error
+
+ dims=[2,3,4,5];
+ a = np.arange(np.prod(dims)).reshape(dims)
+ v = np.arange(dims[2])
+ ref = np.einsum('ijkl,k->ijl', a, v)
+ assert_equal(np.einsum('ijkl,k', a, v), ref)
+ assert_equal(np.einsum('...kl,k', a, v), ref) # used to raise error
+ assert_equal(np.einsum('...kl,k...', a, v), ref)
+ # no real diff from 1st
+
+ J,K,M=160,160,120;
+ A=np.arange(J*K*M).reshape(1,1,1,J,K,M)
+ B=np.arange(J*K*M*3).reshape(J,K,M,3)
+ ref = np.einsum('...lmn,...lmno->...o', A, B)
+ assert_equal(np.einsum('...lmn,lmno->...o', A, B), ref) # used to raise error
+
+
+
if __name__ == "__main__":
run_module_suite()