diff options
author | Félix Hartmann <felix.hartmann@crans.org> | 2013-07-09 15:49:32 +0200 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2013-08-02 14:36:24 -0600 |
commit | ea768739dab69c0b67488179ffa67d57d63d59f8 (patch) | |
tree | a2a043f988f655b5fa3e8093c6420897c1c554d1 /numpy | |
parent | 496813f1a23363bbd50a62a60c37f6bd4e10649b (diff) | |
download | numpy-ea768739dab69c0b67488179ffa67d57d63d59f8.tar.gz |
BUG: Make np.insert check for out of bounds axis arguments.
Also add test for IndexError exception when axis is out of bounds.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/lib/function_base.py | 7 | ||||
-rw-r--r-- | numpy/lib/tests/test_function_base.py | 4 |
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index 9f1757d7c..17f99a065 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -3696,6 +3696,11 @@ def insert(arr, obj, values, axis=None): arr = arr.ravel() ndim = arr.ndim axis = ndim-1 + else: + if ndim > 0 and (axis < -ndim or axis >= ndim): + raise IndexError("axis %i is out of bounds for an array " + "of dimension %i" % (axis, ndim)) + if (axis < 0): axis += ndim if (ndim == 0): warnings.warn("in the future the special handling of scalars " "will be removed from insert and raise an error", @@ -3742,7 +3747,7 @@ def insert(arr, obj, values, axis=None): # broadcasting is very different here, since a[:,0,:] = ... behaves # very different from a[:,[0],:] = ...! This changes values so that # it works likes the second case. (here a[:,0:1,:]) - values = np.rollaxis(values, 0, axis % ndim + 1) + values = np.rollaxis(values, 0, axis + 1) numnew = values.shape[axis] newshape[axis] += numnew new = empty(newshape, arr.dtype, arr.flags.fnc) diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index 13f907d5a..de561e55a 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -218,6 +218,10 @@ class TestInsert(TestCase): assert_equal(insert(a, 1, a[:,2,:], axis=-2), insert(a, 1, a[:,2,:], axis=1)) + # invalid axis value + assert_raises(IndexError, insert, a, 1, a[:,2,:], axis=3) + assert_raises(IndexError, insert, a, 1, a[:,2,:], axis=-4) + def test_0d(self): # This is an error in the future a = np.array(1) |