diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-01-31 09:38:14 -0800 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-02-03 11:36:30 -0800 |
commit | ef70f13177a53266fd8547da6e00bc252a057893 (patch) | |
tree | b836e830ac216985e27c3f6e63e582df2bbab91c /numpy | |
parent | 2854d508d1c6d211f2ce99e8747eda1cb427a78a (diff) | |
download | numpy-ef70f13177a53266fd8547da6e00bc252a057893.tar.gz |
MAINT: Use AxisError in swapaxes
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/common.h | 13 | ||||
-rw-r--r-- | numpy/core/src/multiarray/shape.c | 14 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 8 | ||||
-rw-r--r-- | numpy/lib/arraysetops.py | 8 | ||||
-rw-r--r-- | numpy/lib/tests/test_arraysetops.py | 4 |
5 files changed, 26 insertions, 21 deletions
diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h index ae9b960c8..4b670b851 100644 --- a/numpy/core/src/multiarray/common.h +++ b/numpy/core/src/multiarray/common.h @@ -2,6 +2,7 @@ #define _NPY_PRIVATE_COMMON_H_ #include <numpy/npy_common.h> #include <numpy/npy_cpu.h> +#include <numpy/npy_3kcompat.h> #include <numpy/ndarraytypes.h> #include <limits.h> @@ -181,6 +182,18 @@ check_and_adjust_axis(int *axis, int ndim) { return check_and_adjust_axis_msg(axis, ndim, Py_None); } +static NPY_INLINE int +check_and_adjust_axis_cmsg(int *axis, int ndim, char const *cmsg) +{ + int ret; + PyObject *msg = PyUString_FromString(cmsg); + if (msg == NULL) { + return -1; + } + ret = check_and_adjust_axis_msg(axis, ndim, msg); + Py_DECREF(msg); + return ret; +} /* diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c index 61908e95e..21f901755 100644 --- a/numpy/core/src/multiarray/shape.c +++ b/numpy/core/src/multiarray/shape.c @@ -648,20 +648,10 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2) int n = PyArray_NDIM(ap); int i; - if (a1 < 0) { - a1 += n; - } - if (a2 < 0) { - a2 += n; - } - if ((a1 < 0) || (a1 >= n)) { - PyErr_SetString(PyExc_ValueError, - "bad axis1 argument to swapaxes"); + if (check_and_adjust_axis_cmsg(&a1, n, "axis1") < 0) { return NULL; } - if ((a2 < 0) || (a2 >= n)) { - PyErr_SetString(PyExc_ValueError, - "bad axis2 argument to swapaxes"); + if (check_and_adjust_axis_cmsg(&a2, n, "axis2") < 0) { return NULL; } diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b768f7a65..3ab1b971e 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2864,10 +2864,10 @@ class TestMethods(object): assert_(a.flags['OWNDATA']) b = a.copy() # check exceptions - assert_raises(ValueError, a.swapaxes, -5, 0) - assert_raises(ValueError, a.swapaxes, 4, 0) - assert_raises(ValueError, a.swapaxes, 0, -5) - assert_raises(ValueError, a.swapaxes, 0, 4) + assert_raises(np.AxisError, a.swapaxes, -5, 0) + assert_raises(np.AxisError, a.swapaxes, 4, 0) + assert_raises(np.AxisError, a.swapaxes, 0, -5) + assert_raises(np.AxisError, a.swapaxes, 0, 4) for i in range(-4, 4): for j in range(-4, 4): diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index e6ff5bf38..7b103ef3e 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -223,10 +223,12 @@ def unique(ar, return_index=False, return_inverse=False, ret = _unique1d(ar, return_index, return_inverse, return_counts) return _unpack_tuple(ret) - if not (-ar.ndim <= axis < ar.ndim): - raise ValueError('Invalid axis kwarg specified for unique') + try: + ar = np.swapaxes(ar, axis, 0) + except np.AxisError: + # this removes the "axis1" or "axis2" prefix from the error message + raise np.AxisError(axis, ar.ndim) - ar = np.swapaxes(ar, axis, 0) orig_shape, orig_dtype = ar.shape, ar.dtype # Must reshape to a contiguous 2D array for this to work... ar = ar.reshape(orig_shape[0], -1) diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py index c2ba7ac86..17415d8fe 100644 --- a/numpy/lib/tests/test_arraysetops.py +++ b/numpy/lib/tests/test_arraysetops.py @@ -409,8 +409,8 @@ class TestUnique(object): assert_raises(TypeError, self._run_axis_tests, [('a', int), ('b', object)]) - assert_raises(ValueError, unique, np.arange(10), axis=2) - assert_raises(ValueError, unique, np.arange(10), axis=-2) + assert_raises(np.AxisError, unique, np.arange(10), axis=2) + assert_raises(np.AxisError, unique, np.arange(10), axis=-2) def test_unique_axis_list(self): msg = "Unique failed on list of lists" |