diff options
author | Stefan van der Walt <stefan@sun.ac.za> | 2009-07-04 12:14:59 +0000 |
---|---|---|
committer | Stefan van der Walt <stefan@sun.ac.za> | 2009-07-04 12:14:59 +0000 |
commit | ae6da8374a055b5dfe738c03c41ff8e001f51180 (patch) | |
tree | 1c041e2f94fc9d4f23b23a88c448959da2778ac4 /numpy/lib/index_tricks.py | |
parent | 99df3daf134808115b458d90c4c6fa676a02e6f2 (diff) | |
download | numpy-ae6da8374a055b5dfe738c03c41ff8e001f51180.tar.gz |
Fix diag_indices_from and add test.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 737fc0a60..3388decf0 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -669,10 +669,11 @@ s_ = IndexExpression(maketuple=False) # End contribution from Konrad. -# I'm not sure this is the best place in numpy for these functions, but since -# they handle multidimensional arrays, it seemed better than twodim_base. -def fill_diagonal(a,val): +# The following functions complement those in twodim_base, but are +# applicable to N-dimensions. + +def fill_diagonal(a, val): """Fill the main diagonal of the given array of any dimensionality. For an array with ndim > 2, the diagonal is the list of locations with @@ -737,13 +738,13 @@ def fill_diagonal(a,val): # all dimensions equal, so we check first. if not alltrue(diff(a.shape)==0): raise ValueError("All dimensions of input must be of equal length") - step = cumprod((1,)+a.shape[:-1]).sum() + step = cumprod((1,) + a.shape[:-1]).sum() # Write the value out into the diagonal. a.flat[::step] = val -def diag_indices(n,ndim=2): +def diag_indices(n, ndim=2): """Return the indices to access the main diagonal of an array. This returns a tuple of indices that can be used to access the main @@ -758,7 +759,7 @@ def diag_indices(n,ndim=2): indices can be used. ndim : int, optional - The number of dimensions + The number of dimensions. Examples -------- @@ -794,10 +795,10 @@ def diag_indices(n,ndim=2): See also -------- - diag_indices_from: create the indices based on the shape of an existing - array. + array. """ idx = arange(n) - return (idx,)*ndim + return (idx,) * ndim def diag_indices_from(arr): @@ -814,7 +815,7 @@ def diag_indices_from(arr): raise ValueError("input array must be at least 2-d") # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. - if not alltrue(diff(a.shape)==0): + if not alltrue(diff(arr.shape) == 0): raise ValueError("All dimensions of input must be of equal length") - return diag_indices(a.shape[0],a.ndim) + return diag_indices(arr.shape[0], arr.ndim) |