diff options
author | Frederic <nouiz@nouiz.org> | 2012-06-11 16:23:17 -0400 |
---|---|---|
committer | Frederic <nouiz@nouiz.org> | 2012-06-11 16:23:17 -0400 |
commit | 69c33bf74bcdc1d9781bd5db27f942f6d676c032 (patch) | |
tree | 27cfdd7953dbe36b3ab9657d0078a06a97223475 /numpy/lib | |
parent | d0f520a30990c018114672f24197866452a2d088 (diff) | |
download | numpy-69c33bf74bcdc1d9781bd5db27f942f6d676c032.tar.gz |
fix the wrapping problem of fill_diagonal with tall matrix.
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/index_tricks.py | 5 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 22 |
2 files changed, 26 insertions, 1 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c29f3a6d3..e248bfaea 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -719,10 +719,13 @@ def fill_diagonal(a, val): """ if a.ndim < 2: raise ValueError("array must be at least 2-d") + end = None if a.ndim == 2: # Explicit, fast formula for the common case. For 2-d arrays, we # accept rectangular ones. step = a.shape[1] + 1 + #This is needed to don't have tall matrix have the diagonal wrap. + end = a.shape[1] * a.shape[1] else: # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. @@ -731,7 +734,7 @@ def fill_diagonal(a, val): step = 1 + (cumprod(a.shape[:-1])).sum() # Write the value out into the diagonal. - a.flat[::step] = val + a.flat[:end:step] = val def diag_indices(n, ndim=2): diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 2c6500a57..aaedd83ea 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -158,6 +158,28 @@ def test_fill_diagonal(): array([[5, 0, 0], [0, 5, 0], [0, 0, 5]])) + #Test tall matrix + a = zeros((10, 3),int) + fill_diagonal(a, 5) + yield (assert_array_equal, a, + array([[5, 0, 0], + [0, 5, 0], + [0, 0, 5], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]])) + + #Test wide matrix + a = zeros((3, 10),int) + fill_diagonal(a, 5) + yield (assert_array_equal, a, + array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 5, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 5, 0, 0, 0, 0, 0, 0, 0]])) # The same function can operate on a 4-d array: a = zeros((3, 3, 3, 3), int) |