summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/function_base.py7
-rw-r--r--numpy/lib/tests/test_function_base.py11
2 files changed, 17 insertions, 1 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 5e433d3ab..17f99a065 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3696,6 +3696,11 @@ def insert(arr, obj, values, axis=None):
arr = arr.ravel()
ndim = arr.ndim
axis = ndim-1
+ else:
+ if ndim > 0 and (axis < -ndim or axis >= ndim):
+ raise IndexError("axis %i is out of bounds for an array "
+ "of dimension %i" % (axis, ndim))
+ if (axis < 0): axis += ndim
if (ndim == 0):
warnings.warn("in the future the special handling of scalars "
"will be removed from insert and raise an error",
@@ -3742,7 +3747,7 @@ def insert(arr, obj, values, axis=None):
# broadcasting is very different here, since a[:,0,:] = ... behaves
# very different from a[:,[0],:] = ...! This changes values so that
# it works likes the second case. (here a[:,0:1,:])
- values = np.rollaxis(values, 0, axis+1)
+ values = np.rollaxis(values, 0, axis + 1)
numnew = values.shape[axis]
newshape[axis] += numnew
new = empty(newshape, arr.dtype, arr.flags.fnc)
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index a23e406e3..de561e55a 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -211,6 +211,17 @@ class TestInsert(TestCase):
assert_equal(insert(a[:,:1], 1, a[:,1], axis=1), a)
assert_equal(insert(a[:1,:], 1, a[1,:], axis=0), a)
+ # negative axis value
+ a = np.arange(24).reshape((2,3,4))
+ assert_equal(insert(a, 1, a[:,:,3], axis=-1),
+ insert(a, 1, a[:,:,3], axis=2))
+ assert_equal(insert(a, 1, a[:,2,:], axis=-2),
+ insert(a, 1, a[:,2,:], axis=1))
+
+ # invalid axis value
+ assert_raises(IndexError, insert, a, 1, a[:,2,:], axis=3)
+ assert_raises(IndexError, insert, a, 1, a[:,2,:], axis=-4)
+
def test_0d(self):
# This is an error in the future
a = np.array(1)