diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-04-03 12:35:44 +0100 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2017-04-20 22:23:13 +0100 |
commit | 00b55fc7366a21d19518ff009917f2cd47d41562 (patch) | |
tree | 13b2ab28ecbc4cdda87f1f28b7ec4fd61f1bc950 /numpy/core/shape_base.py | |
parent | 7b64ca6a1278195ed109532d42c9ccb334a4438a (diff) | |
download | numpy-00b55fc7366a21d19518ff009917f2cd47d41562.tar.gz |
ENH: add support for nd inputs to block
This changes the API to not have any special cases for 2d arrays, and accept a
single input for clarity
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r-- | numpy/core/shape_base.py | 141 |
1 files changed, 99 insertions, 42 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index c85617e28..5f7fb9d57 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -361,7 +361,7 @@ def stack(arrays, axis=0): return _nx.concatenate(expanded_arrays, axis=axis) -def block(*arrays): +def block(arrays): """ Assemble an array from nested lists of blocks. @@ -370,21 +370,34 @@ def block(*arrays): Parameters ---------- - arrays : sequence of sequence of ndarrays - 1-D arrays are treated as row vectors. + arrays : nested list/tuple of ndarrays or scalars + lists and tuples are treated as sequence, everything else is treated + as an element to concatenate. + + Inputs are normalized to have uniform depth by wrapping elements in + extra layers of lists - for instance: + * ``[[[a, b], c], d]`` is normalized to ``[[[a, b], [c]], [[d]]]`` + * ``[[[a]], b]`` is normalized to ``[[[a]], [[b]]]`` + + After the above normalization, the innermost lists are `concatenate`d + along the last dimension, the second-innermost along the second-last + dimensions, etc. Returns ------- blocked : ndarray - The 2-D array assembled from the given blocks. + The array assembled from the given blocks. + The dimensionality of the output is determined by the dimensionality of + all the inputs, and the degree to which the input list is nested - + whichever is greatest. See Also -------- + concatenate : Join a sequence of arrays together. stack : Stack arrays in sequence along a new dimension. hstack : Stack arrays in sequence horizontally (column wise). vstack : Stack arrays in sequence vertically (row wise). dstack : Stack arrays in sequence depth wise (along third dimension). - concatenate : Join a sequence of arrays together. vsplit : Split array into a list of multiple sub-arrays vertically. Notes @@ -393,31 +406,41 @@ def block(*arrays): Examples -------- - Stacking in a row: - >>> A = np.array([[1, 2, 3]]) - >>> B = np.array([[2, 3, 4]]) + Stacking scalars in a row: + >>> block([1, 2, 3]) + array([1, 2, 3]) + + Stacking scalars with 1d arrays: + >>> a = np.array([2, 3]) + >>> block([1, a]) + np.array([1, 2, 3]) + + Stacking 1d arrays in a row: + >>> A = np.array([1, 2, 3]) + >>> B = np.array([2, 3, 4]) >>> block([A, B]) - array([[1, 2, 3, 2, 3, 4]]) + array([1, 2, 3, 2, 3, 4]) - Stacking in a column: + Stacking 2d row-vectors in a row: >>> A = np.array([[1, 2, 3]]) >>> B = np.array([[2, 3, 4]]) - >>> block(A, B) - array([[1, 2, 3], - [2, 3, 4]]) - - 1-D vectors are treated as row arrays - >>> a = np.array([1, 1]) - >>> b = np.array([2, 2]) - >>> block([a, b]) - array([[1, 1, 2, 2]]) + >>> block([A, B]) + array([[1, 2, 3, 2, 3, 4]]) + Stacking 1d arrays in a column: >>> a = np.array([1, 1]) >>> b = np.array([2, 2]) - >>> block(a, b) + >>> block([[a], [b]]) array([[1, 1], [2, 2]]) + Stacking 2d row-vectors in a column: + >>> A = np.array([[1, 2, 3]]) + >>> B = np.array([[2, 3, 4]]) + >>> block([[A], [B]]) + array([[1, 2, 3], + [2, 3, 4]]) + The tuple notation also works: >>> A = np.ones((2, 2)) >>> B = 2 * A @@ -426,18 +449,17 @@ def block(*arrays): [1, 1, 2, 2]]) Block array with arbitrary shaped elements - >>> One = np.array([[1, 1, 1]]) - >>> Two = np.array([[2, 2, 2]]) - >>> Three = np.array([[3, 3, 3, 3, 3, 3]]) + >>> one = np.array([[1, 1, 1]]) + >>> two = np.array([[2, 2, 2]]) + >>> three = np.array([[3, 3, 3, 3, 3, 3]]) >>> four = np.array([4, 4, 4, 4, 4, 4]) - >>> five = np.array([5]) >>> six = np.array([6, 6, 6, 6, 6]) - >>> Zeros = np.zeros((2, 6), dtype=int) - >>> block([One, Two], - ... Three, - ... four, - ... [five, six], - ... Zeros) + >>> zeros = np.zeros((2, 6), dtype=int) + >>> block([[one, two], + ... [three ], + ... [four ], + ... [5, six], + ... [zeros ]) array([[1, 1, 1, 2, 2, 2], [3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4], @@ -445,19 +467,54 @@ def block(*arrays): [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]) - """ - if len(arrays) < 1: - raise TypeError("need at least one array to create a block array") - result = [] - for row in arrays: - if isinstance(row, (list, tuple)): - result.append(hstack(row)) + def is_element(x): + return not isinstance(x, (list, tuple)) + + def recursive_map(x, base, aggregate=list): + """ + Iterate over the nested list, applying `base` to items, and `aggregate` + to iterables of mapped items + """ + def f(x): + if is_element(x): + return base(x) + else: + return aggregate(f(xi) for xi in x) + return f(x) + + def exactly_nd(x, ndim): + x = asanyarray(x) + shape = [1] * ndim + shape[ndim-x.ndim:] = x.shape + return x.reshape(shape) + + def max_or_0(xs): + """ Like max, but returns 0 on an empty iterable """ + xs = list(xs) + return max(xs) if xs else 0 + + # convert all the arrays to ndarrays + arrays = recursive_map(arrays, base=asanyarray) + + # determine the final number of dimensions + list_ndim = recursive_map(arrays, base=lambda x: 0, + aggregate=lambda xs: max_or_0(xs) + 1) + elem_ndim = recursive_map(arrays, base=lambda x: x.ndim, aggregate=max_or_0) + ndim = max(list_ndim, elem_ndim) + + # Make all the elements the same dimension + arrays = recursive_map(arrays, base=lambda x: exactly_nd(x, ndim)) + + # concate + def _concatenate_recursive(x, axis): + if is_element(x): + return x else: - result.append(row) + return _nx.concatenate([ + _concatenate_recursive(xi, axis=axis+1) + for xi in x + ], axis=axis) - if len(result) > 1: - return vstack(result) - else: - return atleast_2d(result[0]) + return _concatenate_recursive(arrays, -list_ndim)
\ No newline at end of file |