summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorFélix Hartmann <felix.hartmann@crans.org>2013-07-09 15:49:32 +0200
committerCharles Harris <charlesr.harris@gmail.com>2013-08-02 14:36:24 -0600
commitea768739dab69c0b67488179ffa67d57d63d59f8 (patch)
treea2a043f988f655b5fa3e8093c6420897c1c554d1 /numpy
parent496813f1a23363bbd50a62a60c37f6bd4e10649b (diff)
downloadnumpy-ea768739dab69c0b67488179ffa67d57d63d59f8.tar.gz
BUG: Make np.insert check for out of bounds axis arguments.
Also add test for IndexError exception when axis is out of bounds.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/function_base.py7
-rw-r--r--numpy/lib/tests/test_function_base.py4
2 files changed, 10 insertions, 1 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 9f1757d7c..17f99a065 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3696,6 +3696,11 @@ def insert(arr, obj, values, axis=None):
arr = arr.ravel()
ndim = arr.ndim
axis = ndim-1
+ else:
+ if ndim > 0 and (axis < -ndim or axis >= ndim):
+ raise IndexError("axis %i is out of bounds for an array "
+ "of dimension %i" % (axis, ndim))
+ if (axis < 0): axis += ndim
if (ndim == 0):
warnings.warn("in the future the special handling of scalars "
"will be removed from insert and raise an error",
@@ -3742,7 +3747,7 @@ def insert(arr, obj, values, axis=None):
# broadcasting is very different here, since a[:,0,:] = ... behaves
# very different from a[:,[0],:] = ...! This changes values so that
# it works likes the second case. (here a[:,0:1,:])
- values = np.rollaxis(values, 0, axis % ndim + 1)
+ values = np.rollaxis(values, 0, axis + 1)
numnew = values.shape[axis]
newshape[axis] += numnew
new = empty(newshape, arr.dtype, arr.flags.fnc)
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 13f907d5a..de561e55a 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -218,6 +218,10 @@ class TestInsert(TestCase):
assert_equal(insert(a, 1, a[:,2,:], axis=-2),
insert(a, 1, a[:,2,:], axis=1))
+ # invalid axis value
+ assert_raises(IndexError, insert, a, 1, a[:,2,:], axis=3)
+ assert_raises(IndexError, insert, a, 1, a[:,2,:], axis=-4)
+
def test_0d(self):
# This is an error in the future
a = np.array(1)