summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorTravis E. Oliphant <teoliphant@gmail.com>2012-07-17 19:26:39 -0700
committerTravis E. Oliphant <teoliphant@gmail.com>2012-07-17 19:26:39 -0700
commit578a4199a81e7464011661fcf8d46a8af2235db2 (patch)
tree58e1650c914153acceb0229fea16a546db0e46e0 /numpy/lib
parent0b2bfa9c13070b08b3632f15a3aa327146994cc4 (diff)
parent2c04244da264cb1665d6162ae119d2f05ad65150 (diff)
downloadnumpy-578a4199a81e7464011661fcf8d46a8af2235db2.tar.gz
Merge pull request #352 from HackerSchool12/bugfix808
BF bug #808
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/function_base.py20
-rw-r--r--numpy/lib/tests/test_function_base.py3
2 files changed, 9 insertions, 14 deletions
diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py
index 1edfc4c6b..247f16560 100644
--- a/numpy/lib/function_base.py
+++ b/numpy/lib/function_base.py
@@ -3592,24 +3592,18 @@ def insert(arr, obj, values, axis=None):
N = arr.shape[axis]
newshape = list(arr.shape)
if isinstance(obj, (int, long, integer)):
+
if (obj < 0): obj += N
if obj < 0 or obj > N:
raise ValueError(
"index (%d) out of range (0<=index<=%d) "\
"in dimension %d" % (obj, N, axis))
- newshape[axis] += 1;
- new = empty(newshape, arr.dtype, arr.flags.fnc)
- slobj[axis] = slice(None, obj)
- new[slobj] = arr[slobj]
- slobj[axis] = obj
- new[slobj] = values
- slobj[axis] = slice(obj+1,None)
- slobj2 = [slice(None)]*ndim
- slobj2[axis] = slice(obj,None)
- new[slobj] = arr[slobj2]
- if wrap:
- return wrap(new)
- return new
+
+ if isinstance(values, (int, long, integer)):
+ obj = [obj]
+ else:
+ obj = [obj] * len(values)
+
elif isinstance(obj, slice):
# turn it into a range object
diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py
index 81892e634..db242153d 100644
--- a/numpy/lib/tests/test_function_base.py
+++ b/numpy/lib/tests/test_function_base.py
@@ -145,7 +145,8 @@ class TestInsert(TestCase):
assert_equal(insert(a, 0, 1), [1, 1, 2, 3])
assert_equal(insert(a, 3, 1), [1, 2, 3, 1])
assert_equal(insert(a, [1, 1, 1], [1, 2, 3]), [1, 1, 2, 3, 2, 3])
-
+ assert_equal(insert(a, 1,[1,2,3]), [1, 1, 2, 3, 2, 3])
+ assert_equal(insert(a,[1,2,3],9),[1,9,2,9,3,9])
class TestAmax(TestCase):
def test_basic(self):