summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_internal.py3
-rw-r--r--numpy/core/numeric.py4
-rw-r--r--numpy/core/src/multiarray/common.h16
3 files changed, 20 insertions, 3 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py
index 741c8bb5f..d73cdcc55 100644
--- a/numpy/core/_internal.py
+++ b/numpy/core/_internal.py
@@ -630,3 +630,6 @@ def _gcd(a, b):
# Exception used in shares_memory()
class TooHardError(RuntimeError):
pass
+
+class AxisError(ValueError, IndexError):
+ pass
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index a5cb5bb31..e7307a870 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -27,7 +27,7 @@ from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
-from ._internal import TooHardError
+from ._internal import TooHardError, AxisError
bitwise_not = invert
ufunc = type(sin)
@@ -65,7 +65,7 @@ __all__ = [
'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
- 'TooHardError',
+ 'TooHardError', 'AxisError'
]
diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h
index 02522138a..625ca9d76 100644
--- a/numpy/core/src/multiarray/common.h
+++ b/numpy/core/src/multiarray/common.h
@@ -144,7 +144,21 @@ check_and_adjust_axis(int *axis, int ndim)
{
/* Check that index is valid, taking into account negative indices */
if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) {
- PyErr_Format(PyExc_IndexError,
+ /*
+ * Load the exception type, if we don't already have it. Unfortunately
+ * we don't have access to npy_cache_import here
+ */
+ static PyObject *AxisError_cls = NULL;
+ if (AxisError_cls == NULL) {
+ PyObject *mod = PyImport_ImportModule("numpy.core._internal");
+
+ if (mod != NULL) {
+ AxisError_cls = PyObject_GetAttrString(mod, "AxisError");
+ Py_DECREF(mod);
+ }
+ }
+
+ PyErr_Format(AxisError_cls,
"axis %d is out of bounds for array of dimension %d",
*axis, ndim);
return -1;