summaryrefslogtreecommitdiff
path: root/numpy/lib/index_tricks.py
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2009-07-04 12:14:59 +0000
committerStefan van der Walt <stefan@sun.ac.za>2009-07-04 12:14:59 +0000
commitae6da8374a055b5dfe738c03c41ff8e001f51180 (patch)
tree1c041e2f94fc9d4f23b23a88c448959da2778ac4 /numpy/lib/index_tricks.py
parent99df3daf134808115b458d90c4c6fa676a02e6f2 (diff)
downloadnumpy-ae6da8374a055b5dfe738c03c41ff8e001f51180.tar.gz
Fix diag_indices_from and add test.
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r--numpy/lib/index_tricks.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index 737fc0a60..3388decf0 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -669,10 +669,11 @@ s_ = IndexExpression(maketuple=False)
# End contribution from Konrad.
-# I'm not sure this is the best place in numpy for these functions, but since
-# they handle multidimensional arrays, it seemed better than twodim_base.
-def fill_diagonal(a,val):
+# The following functions complement those in twodim_base, but are
+# applicable to N-dimensions.
+
+def fill_diagonal(a, val):
"""Fill the main diagonal of the given array of any dimensionality.
For an array with ndim > 2, the diagonal is the list of locations with
@@ -737,13 +738,13 @@ def fill_diagonal(a,val):
# all dimensions equal, so we check first.
if not alltrue(diff(a.shape)==0):
raise ValueError("All dimensions of input must be of equal length")
- step = cumprod((1,)+a.shape[:-1]).sum()
+ step = cumprod((1,) + a.shape[:-1]).sum()
# Write the value out into the diagonal.
a.flat[::step] = val
-def diag_indices(n,ndim=2):
+def diag_indices(n, ndim=2):
"""Return the indices to access the main diagonal of an array.
This returns a tuple of indices that can be used to access the main
@@ -758,7 +759,7 @@ def diag_indices(n,ndim=2):
indices can be used.
ndim : int, optional
- The number of dimensions
+ The number of dimensions.
Examples
--------
@@ -794,10 +795,10 @@ def diag_indices(n,ndim=2):
See also
--------
- diag_indices_from: create the indices based on the shape of an existing
- array.
+ array.
"""
idx = arange(n)
- return (idx,)*ndim
+ return (idx,) * ndim
def diag_indices_from(arr):
@@ -814,7 +815,7 @@ def diag_indices_from(arr):
raise ValueError("input array must be at least 2-d")
# For more than d=2, the strided formula is only valid for arrays with
# all dimensions equal, so we check first.
- if not alltrue(diff(a.shape)==0):
+ if not alltrue(diff(arr.shape) == 0):
raise ValueError("All dimensions of input must be of equal length")
- return diag_indices(a.shape[0],a.ndim)
+ return diag_indices(arr.shape[0], arr.ndim)