summaryrefslogtreecommitdiff
path: root/numpy/core/shape_base.py
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-04-03 12:35:44 +0100
committerEric Wieser <wieser.eric@gmail.com>2017-04-20 22:23:13 +0100
commit00b55fc7366a21d19518ff009917f2cd47d41562 (patch)
tree13b2ab28ecbc4cdda87f1f28b7ec4fd61f1bc950 /numpy/core/shape_base.py
parent7b64ca6a1278195ed109532d42c9ccb334a4438a (diff)
downloadnumpy-00b55fc7366a21d19518ff009917f2cd47d41562.tar.gz
ENH: add support for nd inputs to block
This changes the API to not have any special cases for 2d arrays, and accept a single input for clarity
Diffstat (limited to 'numpy/core/shape_base.py')
-rw-r--r--numpy/core/shape_base.py141
1 files changed, 99 insertions, 42 deletions
diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py
index c85617e28..5f7fb9d57 100644
--- a/numpy/core/shape_base.py
+++ b/numpy/core/shape_base.py
@@ -361,7 +361,7 @@ def stack(arrays, axis=0):
return _nx.concatenate(expanded_arrays, axis=axis)
-def block(*arrays):
+def block(arrays):
"""
Assemble an array from nested lists of blocks.
@@ -370,21 +370,34 @@ def block(*arrays):
Parameters
----------
- arrays : sequence of sequence of ndarrays
- 1-D arrays are treated as row vectors.
+ arrays : nested list/tuple of ndarrays or scalars
+ lists and tuples are treated as sequence, everything else is treated
+ as an element to concatenate.
+
+ Inputs are normalized to have uniform depth by wrapping elements in
+ extra layers of lists - for instance:
+ * ``[[[a, b], c], d]`` is normalized to ``[[[a, b], [c]], [[d]]]``
+ * ``[[[a]], b]`` is normalized to ``[[[a]], [[b]]]``
+
+ After the above normalization, the innermost lists are `concatenate`d
+ along the last dimension, the second-innermost along the second-last
+ dimensions, etc.
Returns
-------
blocked : ndarray
- The 2-D array assembled from the given blocks.
+ The array assembled from the given blocks.
+ The dimensionality of the output is determined by the dimensionality of
+ all the inputs, and the degree to which the input list is nested -
+ whichever is greatest.
See Also
--------
+ concatenate : Join a sequence of arrays together.
stack : Stack arrays in sequence along a new dimension.
hstack : Stack arrays in sequence horizontally (column wise).
vstack : Stack arrays in sequence vertically (row wise).
dstack : Stack arrays in sequence depth wise (along third dimension).
- concatenate : Join a sequence of arrays together.
vsplit : Split array into a list of multiple sub-arrays vertically.
Notes
@@ -393,31 +406,41 @@ def block(*arrays):
Examples
--------
- Stacking in a row:
- >>> A = np.array([[1, 2, 3]])
- >>> B = np.array([[2, 3, 4]])
+ Stacking scalars in a row:
+ >>> block([1, 2, 3])
+ array([1, 2, 3])
+
+ Stacking scalars with 1d arrays:
+ >>> a = np.array([2, 3])
+ >>> block([1, a])
+ np.array([1, 2, 3])
+
+ Stacking 1d arrays in a row:
+ >>> A = np.array([1, 2, 3])
+ >>> B = np.array([2, 3, 4])
>>> block([A, B])
- array([[1, 2, 3, 2, 3, 4]])
+ array([1, 2, 3, 2, 3, 4])
- Stacking in a column:
+ Stacking 2d row-vectors in a row:
>>> A = np.array([[1, 2, 3]])
>>> B = np.array([[2, 3, 4]])
- >>> block(A, B)
- array([[1, 2, 3],
- [2, 3, 4]])
-
- 1-D vectors are treated as row arrays
- >>> a = np.array([1, 1])
- >>> b = np.array([2, 2])
- >>> block([a, b])
- array([[1, 1, 2, 2]])
+ >>> block([A, B])
+ array([[1, 2, 3, 2, 3, 4]])
+ Stacking 1d arrays in a column:
>>> a = np.array([1, 1])
>>> b = np.array([2, 2])
- >>> block(a, b)
+ >>> block([[a], [b]])
array([[1, 1],
[2, 2]])
+ Stacking 2d row-vectors in a column:
+ >>> A = np.array([[1, 2, 3]])
+ >>> B = np.array([[2, 3, 4]])
+ >>> block([[A], [B]])
+ array([[1, 2, 3],
+ [2, 3, 4]])
+
The tuple notation also works:
>>> A = np.ones((2, 2))
>>> B = 2 * A
@@ -426,18 +449,17 @@ def block(*arrays):
[1, 1, 2, 2]])
Block array with arbitrary shaped elements
- >>> One = np.array([[1, 1, 1]])
- >>> Two = np.array([[2, 2, 2]])
- >>> Three = np.array([[3, 3, 3, 3, 3, 3]])
+ >>> one = np.array([[1, 1, 1]])
+ >>> two = np.array([[2, 2, 2]])
+ >>> three = np.array([[3, 3, 3, 3, 3, 3]])
>>> four = np.array([4, 4, 4, 4, 4, 4])
- >>> five = np.array([5])
>>> six = np.array([6, 6, 6, 6, 6])
- >>> Zeros = np.zeros((2, 6), dtype=int)
- >>> block([One, Two],
- ... Three,
- ... four,
- ... [five, six],
- ... Zeros)
+ >>> zeros = np.zeros((2, 6), dtype=int)
+ >>> block([[one, two],
+ ... [three ],
+ ... [four ],
+ ... [5, six],
+ ... [zeros ])
array([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4],
@@ -445,19 +467,54 @@ def block(*arrays):
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]])
-
"""
- if len(arrays) < 1:
- raise TypeError("need at least one array to create a block array")
- result = []
- for row in arrays:
- if isinstance(row, (list, tuple)):
- result.append(hstack(row))
+ def is_element(x):
+ return not isinstance(x, (list, tuple))
+
+ def recursive_map(x, base, aggregate=list):
+ """
+ Iterate over the nested list, applying `base` to items, and `aggregate`
+ to iterables of mapped items
+ """
+ def f(x):
+ if is_element(x):
+ return base(x)
+ else:
+ return aggregate(f(xi) for xi in x)
+ return f(x)
+
+ def exactly_nd(x, ndim):
+ x = asanyarray(x)
+ shape = [1] * ndim
+ shape[ndim-x.ndim:] = x.shape
+ return x.reshape(shape)
+
+ def max_or_0(xs):
+ """ Like max, but returns 0 on an empty iterable """
+ xs = list(xs)
+ return max(xs) if xs else 0
+
+ # convert all the arrays to ndarrays
+ arrays = recursive_map(arrays, base=asanyarray)
+
+ # determine the final number of dimensions
+ list_ndim = recursive_map(arrays, base=lambda x: 0,
+ aggregate=lambda xs: max_or_0(xs) + 1)
+ elem_ndim = recursive_map(arrays, base=lambda x: x.ndim, aggregate=max_or_0)
+ ndim = max(list_ndim, elem_ndim)
+
+ # Make all the elements the same dimension
+ arrays = recursive_map(arrays, base=lambda x: exactly_nd(x, ndim))
+
+ # concate
+ def _concatenate_recursive(x, axis):
+ if is_element(x):
+ return x
else:
- result.append(row)
+ return _nx.concatenate([
+ _concatenate_recursive(xi, axis=axis+1)
+ for xi in x
+ ], axis=axis)
- if len(result) > 1:
- return vstack(result)
- else:
- return atleast_2d(result[0])
+ return _concatenate_recursive(arrays, -list_ndim) \ No newline at end of file