diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/_internal.py | 3 | ||||
-rw-r--r-- | numpy/core/numeric.py | 4 | ||||
-rw-r--r-- | numpy/core/src/multiarray/common.h | 16 |
3 files changed, 20 insertions, 3 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 741c8bb5f..d73cdcc55 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -630,3 +630,6 @@ def _gcd(a, b): # Exception used in shares_memory() class TooHardError(RuntimeError): pass + +class AxisError(ValueError, IndexError): + pass diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index a5cb5bb31..e7307a870 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -27,7 +27,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, ERR_DEFAULT, PINF, NAN) from . import numerictypes from .numerictypes import longlong, intc, int_, float_, complex_, bool_ -from ._internal import TooHardError +from ._internal import TooHardError, AxisError bitwise_not = invert ufunc = type(sin) @@ -65,7 +65,7 @@ __all__ = [ 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', - 'TooHardError', + 'TooHardError', 'AxisError' ] diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h index 02522138a..625ca9d76 100644 --- a/numpy/core/src/multiarray/common.h +++ b/numpy/core/src/multiarray/common.h @@ -144,7 +144,21 @@ check_and_adjust_axis(int *axis, int ndim) { /* Check that index is valid, taking into account negative indices */ if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) { - PyErr_Format(PyExc_IndexError, + /* + * Load the exception type, if we don't already have it. Unfortunately + * we don't have access to npy_cache_import here + */ + static PyObject *AxisError_cls = NULL; + if (AxisError_cls == NULL) { + PyObject *mod = PyImport_ImportModule("numpy.core._internal"); + + if (mod != NULL) { + AxisError_cls = PyObject_GetAttrString(mod, "AxisError"); + Py_DECREF(mod); + } + } + + PyErr_Format(AxisError_cls, "axis %d is out of bounds for array of dimension %d", *axis, ndim); return -1; |