diff options
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r-- | numpy/lib/shape_base.py | 47 |
1 files changed, 29 insertions, 18 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index a6d391728..70fa3ab03 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -1,16 +1,21 @@ from __future__ import division, absolute_import, print_function -__all__ = ['column_stack', 'row_stack', 'dstack', 'array_split', 'split', 'hsplit', - 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims', - 'apply_along_axis', 'kron', 'tile', 'get_array_wrap'] - import warnings import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, zeros, newaxis, outer, \ - concatenate, isscalar, array, asanyarray +from numpy.core.numeric import ( + asarray, zeros, outer, concatenate, isscalar, array, asanyarray + ) from numpy.core.fromnumeric import product, reshape -from numpy.core import hstack, vstack, atleast_3d +from numpy.core import vstack, atleast_3d + + +__all__ = [ + 'column_stack', 'row_stack', 'dstack', 'array_split', 'split', + 'hsplit', 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims', + 'apply_along_axis', 'kron', 'tile', 'get_array_wrap' + ] + def apply_along_axis(func1d, axis, arr, *args, **kwargs): """ @@ -196,7 +201,8 @@ def apply_over_axes(func, a, axes): if array(axes).ndim == 0: axes = (axes,) for axis in axes: - if axis < 0: axis = N + axis + if axis < 0: + axis = N + axis args = (val, axis) res = func(*args) if res.ndim == val.ndim: @@ -368,7 +374,7 @@ def _replace_zero_by_x_arrays(sub_arys): sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype) return sub_arys -def array_split(ary,indices_or_sections,axis = 0): +def array_split(ary, indices_or_sections, axis=0): """ Split an array into multiple sub-arrays. @@ -392,23 +398,26 @@ def array_split(ary,indices_or_sections,axis = 0): Ntotal = ary.shape[axis] except AttributeError: Ntotal = len(ary) - try: # handle scalar case. + try: + # handle scalar case. Nsections = len(indices_or_sections) + 1 div_points = [0] + list(indices_or_sections) + [Ntotal] - except TypeError: #indices_or_sections is a scalar, not an array. + except TypeError: + # indices_or_sections is a scalar, not an array. Nsections = int(indices_or_sections) if Nsections <= 0: raise ValueError('number sections must be larger than 0.') Neach_section, extras = divmod(Ntotal, Nsections) - section_sizes = [0] + \ - extras * [Neach_section+1] + \ - (Nsections-extras) * [Neach_section] + section_sizes = ([0] + + extras * [Neach_section+1] + + (Nsections-extras) * [Neach_section]) div_points = _nx.array(section_sizes).cumsum() sub_arys = [] sary = _nx.swapaxes(ary, axis, 0) for i in range(Nsections): - st = div_points[i]; end = div_points[i+1] + st = div_points[i] + end = div_points[i + 1] sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0)) # This "kludge" was introduced here to replace arrays shaped (0, 10) @@ -488,12 +497,14 @@ def split(ary,indices_or_sections,axis=0): array([], dtype=float64)] """ - try: len(indices_or_sections) + try: + len(indices_or_sections) except TypeError: sections = indices_or_sections N = ary.shape[axis] if N % sections: - raise ValueError('array split does not result in an equal division') + raise ValueError( + 'array split does not result in an equal division') res = array_split(ary, indices_or_sections, axis) return res @@ -845,7 +856,7 @@ def tile(A, reps): if (d < c.ndim): tup = (1,)*(c.ndim-d) + tup for i, nrep in enumerate(tup): - if nrep!=1: + if nrep != 1: c = c.reshape(-1, n).repeat(nrep, 0) dim_in = shape[i] dim_out = dim_in*nrep |