summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2018-01-31 09:38:14 -0800
committerEric Wieser <wieser.eric@gmail.com>2018-02-03 11:36:30 -0800
commitef70f13177a53266fd8547da6e00bc252a057893 (patch)
treeb836e830ac216985e27c3f6e63e582df2bbab91c /numpy
parent2854d508d1c6d211f2ce99e8747eda1cb427a78a (diff)
downloadnumpy-ef70f13177a53266fd8547da6e00bc252a057893.tar.gz
MAINT: Use AxisError in swapaxes
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/common.h13
-rw-r--r--numpy/core/src/multiarray/shape.c14
-rw-r--r--numpy/core/tests/test_multiarray.py8
-rw-r--r--numpy/lib/arraysetops.py8
-rw-r--r--numpy/lib/tests/test_arraysetops.py4
5 files changed, 26 insertions, 21 deletions
diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h
index ae9b960c8..4b670b851 100644
--- a/numpy/core/src/multiarray/common.h
+++ b/numpy/core/src/multiarray/common.h
@@ -2,6 +2,7 @@
#define _NPY_PRIVATE_COMMON_H_
#include <numpy/npy_common.h>
#include <numpy/npy_cpu.h>
+#include <numpy/npy_3kcompat.h>
#include <numpy/ndarraytypes.h>
#include <limits.h>
@@ -181,6 +182,18 @@ check_and_adjust_axis(int *axis, int ndim)
{
return check_and_adjust_axis_msg(axis, ndim, Py_None);
}
+static NPY_INLINE int
+check_and_adjust_axis_cmsg(int *axis, int ndim, char const *cmsg)
+{
+ int ret;
+ PyObject *msg = PyUString_FromString(cmsg);
+ if (msg == NULL) {
+ return -1;
+ }
+ ret = check_and_adjust_axis_msg(axis, ndim, msg);
+ Py_DECREF(msg);
+ return ret;
+}
/*
diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c
index 61908e95e..21f901755 100644
--- a/numpy/core/src/multiarray/shape.c
+++ b/numpy/core/src/multiarray/shape.c
@@ -648,20 +648,10 @@ PyArray_SwapAxes(PyArrayObject *ap, int a1, int a2)
int n = PyArray_NDIM(ap);
int i;
- if (a1 < 0) {
- a1 += n;
- }
- if (a2 < 0) {
- a2 += n;
- }
- if ((a1 < 0) || (a1 >= n)) {
- PyErr_SetString(PyExc_ValueError,
- "bad axis1 argument to swapaxes");
+ if (check_and_adjust_axis_cmsg(&a1, n, "axis1") < 0) {
return NULL;
}
- if ((a2 < 0) || (a2 >= n)) {
- PyErr_SetString(PyExc_ValueError,
- "bad axis2 argument to swapaxes");
+ if (check_and_adjust_axis_cmsg(&a2, n, "axis2") < 0) {
return NULL;
}
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b768f7a65..3ab1b971e 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2864,10 +2864,10 @@ class TestMethods(object):
assert_(a.flags['OWNDATA'])
b = a.copy()
# check exceptions
- assert_raises(ValueError, a.swapaxes, -5, 0)
- assert_raises(ValueError, a.swapaxes, 4, 0)
- assert_raises(ValueError, a.swapaxes, 0, -5)
- assert_raises(ValueError, a.swapaxes, 0, 4)
+ assert_raises(np.AxisError, a.swapaxes, -5, 0)
+ assert_raises(np.AxisError, a.swapaxes, 4, 0)
+ assert_raises(np.AxisError, a.swapaxes, 0, -5)
+ assert_raises(np.AxisError, a.swapaxes, 0, 4)
for i in range(-4, 4):
for j in range(-4, 4):
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index e6ff5bf38..7b103ef3e 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -223,10 +223,12 @@ def unique(ar, return_index=False, return_inverse=False,
ret = _unique1d(ar, return_index, return_inverse, return_counts)
return _unpack_tuple(ret)
- if not (-ar.ndim <= axis < ar.ndim):
- raise ValueError('Invalid axis kwarg specified for unique')
+ try:
+ ar = np.swapaxes(ar, axis, 0)
+ except np.AxisError:
+ # this removes the "axis1" or "axis2" prefix from the error message
+ raise np.AxisError(axis, ar.ndim)
- ar = np.swapaxes(ar, axis, 0)
orig_shape, orig_dtype = ar.shape, ar.dtype
# Must reshape to a contiguous 2D array for this to work...
ar = ar.reshape(orig_shape[0], -1)
diff --git a/numpy/lib/tests/test_arraysetops.py b/numpy/lib/tests/test_arraysetops.py
index c2ba7ac86..17415d8fe 100644
--- a/numpy/lib/tests/test_arraysetops.py
+++ b/numpy/lib/tests/test_arraysetops.py
@@ -409,8 +409,8 @@ class TestUnique(object):
assert_raises(TypeError, self._run_axis_tests,
[('a', int), ('b', object)])
- assert_raises(ValueError, unique, np.arange(10), axis=2)
- assert_raises(ValueError, unique, np.arange(10), axis=-2)
+ assert_raises(np.AxisError, unique, np.arange(10), axis=2)
+ assert_raises(np.AxisError, unique, np.arange(10), axis=-2)
def test_unique_axis_list(self):
msg = "Unique failed on list of lists"