summaryrefslogtreecommitdiff
path: root/numpy/lib/shape_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/shape_base.py')
-rw-r--r--numpy/lib/shape_base.py47
1 files changed, 29 insertions, 18 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index a6d391728..70fa3ab03 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -1,16 +1,21 @@
from __future__ import division, absolute_import, print_function
-__all__ = ['column_stack', 'row_stack', 'dstack', 'array_split', 'split', 'hsplit',
- 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims',
- 'apply_along_axis', 'kron', 'tile', 'get_array_wrap']
-
import warnings
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, zeros, newaxis, outer, \
- concatenate, isscalar, array, asanyarray
+from numpy.core.numeric import (
+ asarray, zeros, outer, concatenate, isscalar, array, asanyarray
+ )
from numpy.core.fromnumeric import product, reshape
-from numpy.core import hstack, vstack, atleast_3d
+from numpy.core import vstack, atleast_3d
+
+
+__all__ = [
+ 'column_stack', 'row_stack', 'dstack', 'array_split', 'split',
+ 'hsplit', 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims',
+ 'apply_along_axis', 'kron', 'tile', 'get_array_wrap'
+ ]
+
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
@@ -196,7 +201,8 @@ def apply_over_axes(func, a, axes):
if array(axes).ndim == 0:
axes = (axes,)
for axis in axes:
- if axis < 0: axis = N + axis
+ if axis < 0:
+ axis = N + axis
args = (val, axis)
res = func(*args)
if res.ndim == val.ndim:
@@ -368,7 +374,7 @@ def _replace_zero_by_x_arrays(sub_arys):
sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype)
return sub_arys
-def array_split(ary,indices_or_sections,axis = 0):
+def array_split(ary, indices_or_sections, axis=0):
"""
Split an array into multiple sub-arrays.
@@ -392,23 +398,26 @@ def array_split(ary,indices_or_sections,axis = 0):
Ntotal = ary.shape[axis]
except AttributeError:
Ntotal = len(ary)
- try: # handle scalar case.
+ try:
+ # handle scalar case.
Nsections = len(indices_or_sections) + 1
div_points = [0] + list(indices_or_sections) + [Ntotal]
- except TypeError: #indices_or_sections is a scalar, not an array.
+ except TypeError:
+ # indices_or_sections is a scalar, not an array.
Nsections = int(indices_or_sections)
if Nsections <= 0:
raise ValueError('number sections must be larger than 0.')
Neach_section, extras = divmod(Ntotal, Nsections)
- section_sizes = [0] + \
- extras * [Neach_section+1] + \
- (Nsections-extras) * [Neach_section]
+ section_sizes = ([0] +
+ extras * [Neach_section+1] +
+ (Nsections-extras) * [Neach_section])
div_points = _nx.array(section_sizes).cumsum()
sub_arys = []
sary = _nx.swapaxes(ary, axis, 0)
for i in range(Nsections):
- st = div_points[i]; end = div_points[i+1]
+ st = div_points[i]
+ end = div_points[i + 1]
sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0))
# This "kludge" was introduced here to replace arrays shaped (0, 10)
@@ -488,12 +497,14 @@ def split(ary,indices_or_sections,axis=0):
array([], dtype=float64)]
"""
- try: len(indices_or_sections)
+ try:
+ len(indices_or_sections)
except TypeError:
sections = indices_or_sections
N = ary.shape[axis]
if N % sections:
- raise ValueError('array split does not result in an equal division')
+ raise ValueError(
+ 'array split does not result in an equal division')
res = array_split(ary, indices_or_sections, axis)
return res
@@ -845,7 +856,7 @@ def tile(A, reps):
if (d < c.ndim):
tup = (1,)*(c.ndim-d) + tup
for i, nrep in enumerate(tup):
- if nrep!=1:
+ if nrep != 1:
c = c.reshape(-1, n).repeat(nrep, 0)
dim_in = shape[i]
dim_out = dim_in*nrep