diff options
author | Albert <albert.jornet@ic3.cat> | 2014-03-06 16:17:52 +0100 |
---|---|---|
committer | jurnix <albert.jornet@ic3.cat> | 2014-03-27 12:27:33 +0100 |
commit | ddc95d379f410c7ff787c157ded6c40bb873215c (patch) | |
tree | 89f67ab367c93beaa7887790a7c81b981da8e0ee /numpy | |
parent | a6f9b782cd9b60fc1464e6a4a7ef9a7762fcf2d5 (diff) | |
download | numpy-ddc95d379f410c7ff787c157ded6c40bb873215c.tar.gz |
ENH: apply_along_axis accepts named arguments
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/shape_base.py | 10 | ||||
-rw-r--r-- | numpy/ma/tests/test_extras.py | 10 |
2 files changed, 16 insertions, 4 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 38b928d57..31232c989 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -12,7 +12,7 @@ from numpy.core.numeric import asarray, zeros, newaxis, outer, \ from numpy.core.fromnumeric import product, reshape from numpy.core import hstack, vstack, atleast_3d -def apply_along_axis(func1d,axis,arr,*args): +def apply_along_axis(func1d,axis,arr,*args,**kwargs): """ Apply a function to 1-D slices along the given axis. @@ -30,6 +30,8 @@ def apply_along_axis(func1d,axis,arr,*args): Input array. args : any Additional arguments to `func1d`. + kwargs: any + Additional named arguments to `func1d`. Returns ------- @@ -78,7 +80,7 @@ def apply_along_axis(func1d,axis,arr,*args): i[axis] = slice(None, None) outshape = asarray(arr.shape).take(indlist) i.put(indlist, ind) - res = func1d(arr[tuple(i.tolist())],*args) + res = func1d(arr[tuple(i.tolist())],*args,**kwargs) # if res is a number, then we have a smaller output array if isscalar(res): outarr = zeros(outshape, asarray(res).dtype) @@ -94,7 +96,7 @@ def apply_along_axis(func1d,axis,arr,*args): ind[n] = 0 n -= 1 i.put(indlist, ind) - res = func1d(arr[tuple(i.tolist())],*args) + res = func1d(arr[tuple(i.tolist())],*args,**kwargs) outarr[tuple(ind)] = res k += 1 return outarr @@ -115,7 +117,7 @@ def apply_along_axis(func1d,axis,arr,*args): ind[n] = 0 n -= 1 i.put(indlist, ind) - res = func1d(arr[tuple(i.tolist())],*args) + res = func1d(arr[tuple(i.tolist())],*args,**kwargs) outarr[tuple(i.tolist())] = res k += 1 return outarr diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index dc0f87b92..98fb0597e 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -479,6 +479,16 @@ class TestApplyAlongAxis(TestCase): xa = apply_along_axis(myfunc, 2, a) assert_equal(xa, [[1, 4], [7, 10]]) + # Tests kwargs functions + def test_3d_kwargs(self): + a = arange(12).reshape(2, 2, 3) + + def myfunc(b, offset=0): + return b[1+offset] + + xa = apply_along_axis(myfunc, 2, a, offset=1) + assert_equal(xa, [[2, 5], [8, 11]]) + class TestApplyOverAxes(TestCase): # Tests apply_over_axes |