diff options
-rw-r--r-- | numpy/lib/function_base.py | 15 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 6 |
2 files changed, 5 insertions, 16 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index cc0e8feeb..2a8a13caa 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -1679,21 +1679,10 @@ def gradient(f, *varargs, **kwargs): axes = kwargs.pop('axis', None) if axes is None: axes = tuple(range(N)) - # check axes to have correct type and no duplicate entries - if isinstance(axes, int): - axes = (axes,) - if not isinstance(axes, tuple): - raise TypeError("A tuple of integers or a single integer is required") - - # normalize axis values: - axes = tuple(x + N if x < 0 else x for x in axes) - if max(axes) >= N or min(axes) < 0: - raise ValueError("'axis' entry is out of bounds") + else: + axes = _nx._validate_axis(axes, N) len_axes = len(axes) - if len(set(axes)) != len_axes: - raise ValueError("duplicate value in 'axis'") - n = len(varargs) if n == 0: dx = [1.0] * len_axes diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 708b20482..6c6ed5941 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -902,9 +902,9 @@ class TestGradient(TestCase): # test maximal number of varargs assert_raises(TypeError, gradient, x, 1, 2, axis=1) - assert_raises(ValueError, gradient, x, axis=3) - assert_raises(ValueError, gradient, x, axis=-3) - assert_raises(TypeError, gradient, x, axis=[1,]) + assert_raises(np.AxisError, gradient, x, axis=3) + assert_raises(np.AxisError, gradient, x, axis=-3) + # assert_raises(TypeError, gradient, x, axis=[1,]) def test_timedelta64(self): # Make sure gradient() can handle special types like timedelta64 |