diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numeric.py | 43 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 15 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 8 | ||||
-rw-r--r-- | numpy/oldnumeric/misc.py | 5 |
4 files changed, 57 insertions, 14 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index b77fed3db..b89fc0586 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -7,7 +7,7 @@ __all__ = ['newaxis', 'ndarray', 'flatiter', 'ufunc', 'asarray', 'asanyarray', 'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like', 'zeros_like', 'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot', - 'alterdot', 'restoredot', 'cross', + 'alterdot', 'restoredot', 'cross', 'tensordot', 'array2string', 'get_printoptions', 'set_printoptions', 'array_repr', 'array_str', 'set_string_function', 'little_endian', 'require', @@ -252,6 +252,47 @@ except ImportError: def restoredot(): pass +def tensordot(a, b, axes=(-1,0)) + """tensordot returns the product for any (ndim >= 1) arrays. + + r_{xxx, yyy} = \sum_k a_{xxx,k} b_{k,yyy} where + + the axes to be summed over are given by the axes argument. + the first element of the sequence determines the axis or axes + in arr1 to sum over and the second element in axes argument sequence + """ + axes_a, axes_b = axes + try: + na = len(axes_a) + except TypeError: + axes_a = [axes_a] + na = 1 + try: + nb = len(axes_b) + except TypeError: + axes_b = [axes_b] + nb = 1 + + a, b = asarray(a), asarray(b) + as = a.shape + bs = b.shape + equal = 1 + if (na != nb): equal = 0 + for k in xrange(na): + if as[axes_a[k]] != bs[axes_b[k]]: + equal = 0 + break + + if not equal: + raise ValueError, "shape-mismatch for sum" + + olda = [ for k in aa if k not in axes_a] + oldb = [k for k in bs if k not in axes_b] + + at = a.reshape(nd1, nd2) + res = dot(at, bt) + return res.reshape(olda + oldb) + def _move_axis_to_0(a, axis): if axis == 0: diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 3be679a03..9df31f463 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1,13 +1,13 @@ -__all__ = ['logspace', 'linspace', +u__all__ = ['logspace', 'linspace', 'select', 'piecewise', 'trim_zeros', 'copy', 'iterable', #'base_repr', 'binary_repr', 'diff', 'gradient', 'angle', 'unwrap', 'sort_complex', 'disp', - 'unique', 'extract', 'insert', 'nansum', 'nanmax', 'nanargmax', + 'unique', 'extract', 'place', 'nansum', 'nanmax', 'nanargmax', 'nanargmin', 'nanmin', 'vectorize', 'asarray_chkfinite', 'average', 'histogram', 'bincount', 'digitize', 'cov', 'corrcoef', 'msort', 'median', 'sinc', 'hamming', 'hanning', 'bartlett', 'blackman', 'kaiser', 'trapz', 'i0', 'add_newdoc', 'add_docstring', 'meshgrid', - 'deletefrom', 'insertinto', 'appendonto' + 'delete', 'insert', 'append' ] import types @@ -545,7 +545,7 @@ def extract(condition, arr): """ return _nx.take(ravel(arr), nonzero(ravel(condition))[0]) -def insert(arr, mask, vals): +def place(arr, mask, vals): """Similar to putmask arr[mask] = vals but the 1D array vals has the same number of elements as the non-zero values of mask. Inverse of extract. @@ -1011,7 +1011,7 @@ def meshgrid(x,y): Y = y.repeat(numCols, axis=1) return X, Y -def deletefrom(arr, obj, axis=None): +def delete(arr, obj, axis=None): """Return a new array with sub-arrays along an axis deleted. Return a new array with the sub-arrays (i.e. rows or columns) @@ -1117,7 +1117,7 @@ def deletefrom(arr, obj, axis=None): else: return new -def insertinto(arr, obj, values, axis=None): +def insert(arr, obj, values, axis=None): """Return a new array with values inserted along the given axis before the given indices @@ -1205,7 +1205,7 @@ def insertinto(arr, obj, values, axis=None): return wrap(new) return new -def appendonto(arr, values, axis=None): +def append(arr, values, axis=None): """Append to the end of an array along axis (ravel first if None) """ arr = asanyarray(arr) @@ -1215,3 +1215,4 @@ def appendonto(arr, values, axis=None): values = ravel(values) axis = arr.ndim-1 return concatenate((arr, values), axis=axis) + diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 9a1825e39..fdb2f270f 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -237,17 +237,17 @@ class test_extins(NumpyTestCase): a = array([1,3,2,1,2,3,3]) b = extract(a>1,a) assert_array_equal(b,[3,2,2,3,3]) - def check_insert(self): + def check_place(self): a = array([1,4,3,2,5,8,7]) - insert(a,[0,1,0,1,0,1,0],[2,4,6]) + place(a,[0,1,0,1,0,1,0],[2,4,6]) assert_array_equal(a,[1,2,3,4,5,6,7]) def check_both(self): a = rand(10) mask = a > 0.5 ac = a.copy() c = extract(mask, a) - insert(a,mask,0) - insert(a,mask,c) + place(a,mask,0) + place(a,mask,c) assert_array_equal(a,ac) class test_vectorize(NumpyTestCase): diff --git a/numpy/oldnumeric/misc.py b/numpy/oldnumeric/misc.py index a6c13d780..d7938fcac 100644 --- a/numpy/oldnumeric/misc.py +++ b/numpy/oldnumeric/misc.py @@ -9,7 +9,7 @@ __all__ = ['load', 'sort', 'copy_reg', 'clip', 'putmask', 'Unpickler', 'rank', 'searchsorted', 'put', 'fromfunction', 'copy', 'resize', 'array_repr', 'e', 'StringIO', 'pickle', 'argsort', 'convolve', 'loads', 'cross_correlate', - 'Pickler', 'dot', 'outerproduct', 'innerproduct'] + 'Pickler', 'dot', 'outerproduct', 'innerproduct', 'insert'] import types import StringIO @@ -23,7 +23,8 @@ from numpy import sort, clip, putmask, rank, sign, shape, allclose, size,\ choose, swapaxes, array_str, array_repr, e, pi, \ fromfunction, resize, around, concatenate, vdot, transpose, \ diagonal, searchsorted, put, argsort, convolve, dot, \ - outer as outerproduct, inner as innerproduct, correlate as cross_correlate + outer as outerproduct, inner as innerproduct, correlate as cross_correlate, \ + place as insert from array_printer import array2string |