summaryrefslogtreecommitdiff
path: root/numpy/lib
diff options
context:
space:
mode:
authorFrederic <nouiz@nouiz.org>2012-06-11 16:23:17 -0400
committerFrederic <nouiz@nouiz.org>2012-06-11 16:23:17 -0400
commit69c33bf74bcdc1d9781bd5db27f942f6d676c032 (patch)
tree27cfdd7953dbe36b3ab9657d0078a06a97223475 /numpy/lib
parentd0f520a30990c018114672f24197866452a2d088 (diff)
downloadnumpy-69c33bf74bcdc1d9781bd5db27f942f6d676c032.tar.gz
fix the wrapping problem of fill_diagonal with tall matrix.
Diffstat (limited to 'numpy/lib')
-rw-r--r--numpy/lib/index_tricks.py5
-rw-r--r--numpy/lib/tests/test_index_tricks.py22
2 files changed, 26 insertions, 1 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index c29f3a6d3..e248bfaea 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -719,10 +719,13 @@ def fill_diagonal(a, val):
"""
if a.ndim < 2:
raise ValueError("array must be at least 2-d")
+ end = None
if a.ndim == 2:
# Explicit, fast formula for the common case. For 2-d arrays, we
# accept rectangular ones.
step = a.shape[1] + 1
+ #This is needed to don't have tall matrix have the diagonal wrap.
+ end = a.shape[1] * a.shape[1]
else:
# For more than d=2, the strided formula is only valid for arrays with
# all dimensions equal, so we check first.
@@ -731,7 +734,7 @@ def fill_diagonal(a, val):
step = 1 + (cumprod(a.shape[:-1])).sum()
# Write the value out into the diagonal.
- a.flat[::step] = val
+ a.flat[:end:step] = val
def diag_indices(n, ndim=2):
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 2c6500a57..aaedd83ea 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -158,6 +158,28 @@ def test_fill_diagonal():
array([[5, 0, 0],
[0, 5, 0],
[0, 0, 5]]))
+ #Test tall matrix
+ a = zeros((10, 3),int)
+ fill_diagonal(a, 5)
+ yield (assert_array_equal, a,
+ array([[5, 0, 0],
+ [0, 5, 0],
+ [0, 0, 5],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]]))
+
+ #Test wide matrix
+ a = zeros((3, 10),int)
+ fill_diagonal(a, 5)
+ yield (assert_array_equal, a,
+ array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 5, 0, 0, 0, 0, 0, 0, 0]]))
# The same function can operate on a 4-d array:
a = zeros((3, 3, 3, 3), int)