diff options
-rw-r--r-- | benchmarks/benchmarks/bench_shape_base.py | 18 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 33 |
2 files changed, 33 insertions, 18 deletions
diff --git a/benchmarks/benchmarks/bench_shape_base.py b/benchmarks/benchmarks/bench_shape_base.py index 6edad2ea3..e48ea0adb 100644 --- a/benchmarks/benchmarks/bench_shape_base.py +++ b/benchmarks/benchmarks/bench_shape_base.py @@ -69,6 +69,24 @@ class Block(Benchmark): np.block(np.eye(3 * n)) +class Block2D(Benchmark): + params = [[(16, 16), (32, 32), (64, 64), (128, 128), (256, 256), (512, 512), (1024, 1024)], + ['uint8', 'uint16', 'uint32', 'uint64'], + [(2, 2), (4, 4)]] + param_names = ['shape', 'dtype', 'n_chunks'] + + def setup(self, shape, dtype, n_chunks): + + self.block_list = [ + [np.full(shape=[s//n_chunk for s, n_chunk in zip(shape, n_chunks)], + fill_value=1, dtype=dtype) for _ in range(n_chunks[1])] + for _ in range(n_chunks[0]) + ] + + def time_block2d(self, shape, dtype, n_chunks): + np.block(self.block_list) + + class Block3D(Benchmark): params = [1, 10, 100] param_names = ['size'] diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 30919ed7e..52717abda 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -7,7 +7,6 @@ __all__ = ['atleast_1d', 'atleast_2d', 'atleast_3d', 'block', 'hstack', from . import numeric as _nx from .numeric import array, asanyarray, newaxis from .multiarray import normalize_axis_index -from ._internal import recursive def atleast_1d(*arys): """ @@ -438,7 +437,13 @@ def _block_check_depths_match(arrays, parent_index=[]): return parent_index, _nx.ndim(arrays) -def _block(arrays, max_depth, result_ndim): +def _atleast_nd(a, ndim): + # Ensures `a` has at least `ndim` dimensions by prepending + # ones to `a.shape` as necessary + return array(a, ndmin=ndim, copy=False, subok=True) + + +def _block(arrays, max_depth, result_ndim, depth=0): """ Internal implementation of block. `arrays` is the argument passed to block. `max_depth` is the depth of nested lists within `arrays` and @@ -446,22 +451,14 @@ def _block(arrays, max_depth, result_ndim): `arrays` and the depth of the lists in `arrays` (see block docstring for details). """ - def atleast_nd(a, ndim): - # Ensures `a` has at least `ndim` dimensions by prepending - # ones to `a.shape` as necessary - return array(a, ndmin=ndim, copy=False, subok=True) - - @recursive - def block_recursion(self, arrays, depth=0): - if depth < max_depth: - arrs = [self(arr, depth+1) for arr in arrays] - return _nx.concatenate(arrs, axis=-(max_depth-depth)) - else: - # We've 'bottomed out' - arrays is either a scalar or an array - # type(arrays) is not list - return atleast_nd(arrays, result_ndim) - - return block_recursion(arrays) + if depth < max_depth: + arrs = [_block(arr, max_depth, result_ndim, depth+1) + for arr in arrays] + return _nx.concatenate(arrs, axis=-(max_depth-depth)) + else: + # We've 'bottomed out' - arrays is either a scalar or an array + # type(arrays) is not list + return _atleast_nd(arrays, result_ndim) def block(arrays): |