diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-02-08 22:05:11 +0000 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-02-20 22:03:05 +0000 |
commit | 370b6506f128460371484a50c813d66e64582f44 (patch) | |
tree | 31470f73a17e41608865d900228796800c4bb988 /numpy/core/numeric.py | |
parent | 763589d5adbda6230b30ba054cda7206dd14d379 (diff) | |
download | numpy-370b6506f128460371484a50c813d66e64582f44.tar.gz |
MAINT: Use normalize_axis_index in all python axis checking
As a result, some exceptions change from ValueError to IndexError
This also changes the exception types raised in places where
normalize_axis_index is not quite appropriate
Diffstat (limited to 'numpy/core/numeric.py')
-rw-r--r-- | numpy/core/numeric.py | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 97d19f008..a5cb5bb31 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -17,7 +17,7 @@ from .multiarray import ( inner, int_asbuffer, lexsort, matmul, may_share_memory, min_scalar_type, ndarray, nditer, nested_iters, promote_types, putmask, result_type, set_numeric_ops, shares_memory, vdot, where, - zeros) + zeros, normalize_axis_index) if sys.version_info[0] < 3: from .multiarray import newbuffer, getbuffer @@ -1527,15 +1527,12 @@ def rollaxis(a, axis, start=0): """ n = a.ndim - if axis < 0: - axis += n + axis = normalize_axis_index(axis, n) if start < 0: start += n msg = "'%s' arg requires %d <= %s < %d, but %d was passed in" - if not (0 <= axis < n): - raise ValueError(msg % ('axis', -n, 'axis', n, axis)) if not (0 <= start < n + 1): - raise ValueError(msg % ('start', -n, 'start', n + 1, start)) + raise IndexError(msg % ('start', -n, 'start', n + 1, start)) if axis < start: # it's been removed start -= 1 @@ -1554,7 +1551,7 @@ def _validate_axis(axis, ndim, argname): axis = list(axis) axis = [a + ndim if a < 0 else a for a in axis] if not builtins.all(0 <= a < ndim for a in axis): - raise ValueError('invalid axis for this array in `%s` argument' % + raise IndexError('invalid axis for this array in `%s` argument' % argname) if len(set(axis)) != len(axis): raise ValueError('repeated axis in `%s` argument' % argname) |