diff options
Diffstat (limited to 'numpy/lib/index_tricks.py')
-rw-r--r-- | numpy/lib/index_tricks.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index 009e6d229..ff2e00d3e 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -1,5 +1,6 @@ from __future__ import division, absolute_import, print_function +import functools import sys import math @@ -9,13 +10,17 @@ from numpy.core.numeric import ( ) from numpy.core.numerictypes import find_common_type, issubdtype -from . import function_base import numpy.matrixlib as matrixlib from .function_base import diff from numpy.core.multiarray import ravel_multi_index, unravel_index +from numpy.core import overrides, linspace from numpy.lib.stride_tricks import as_strided +array_function_dispatch = functools.partial( + overrides.array_function_dispatch, module='numpy') + + __all__ = [ 'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_', 's_', 'index_exp', 'ix_', 'ndenumerate', 'ndindex', 'fill_diagonal', @@ -23,6 +28,11 @@ __all__ = [ ] +def _ix__dispatcher(*args): + return args + + +@array_function_dispatch(_ix__dispatcher) def ix_(*args): """ Construct an open mesh from multiple sequences. @@ -194,9 +204,6 @@ class nd_grid(object): else: return _nx.arange(start, stop, step) - def __len__(self): - return 0 - class MGridClass(nd_grid): """ @@ -338,7 +345,7 @@ class AxisConcatenator(object): step = 1 if isinstance(step, complex): size = int(abs(step)) - newobj = function_base.linspace(start, stop, num=size) + newobj = linspace(start, stop, num=size) else: newobj = _nx.arange(start, stop, step) if ndmin > 1: @@ -729,6 +736,12 @@ s_ = IndexExpression(maketuple=False) # The following functions complement those in twodim_base, but are # applicable to N-dimensions. + +def _fill_diagonal_dispatcher(a, val, wrap=None): + return (a,) + + +@array_function_dispatch(_fill_diagonal_dispatcher) def fill_diagonal(a, val, wrap=False): """Fill the main diagonal of the given array of any dimensionality. @@ -911,6 +924,11 @@ def diag_indices(n, ndim=2): return (idx,) * ndim +def _diag_indices_from(arr): + return (arr,) + + +@array_function_dispatch(_diag_indices_from) def diag_indices_from(arr): """ Return the indices to access the main diagonal of an n-dimensional array. |