summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/__init__.py2
-rw-r--r--numpy/lib/stride_tricks.py109
-rw-r--r--numpy/lib/tests/test_stride_tricks.py206
3 files changed, 317 insertions, 0 deletions
diff --git a/numpy/lib/__init__.py b/numpy/lib/__init__.py
index aeb43fafb..f1b9cd2cf 100644
--- a/numpy/lib/__init__.py
+++ b/numpy/lib/__init__.py
@@ -5,6 +5,7 @@ from type_check import *
from index_tricks import *
from function_base import *
from shape_base import *
+from stride_tricks import *
from twodim_base import *
from ufunclike import *
@@ -24,6 +25,7 @@ __all__ += type_check.__all__
__all__ += index_tricks.__all__
__all__ += function_base.__all__
__all__ += shape_base.__all__
+__all__ += stride_tricks.__all__
__all__ += twodim_base.__all__
__all__ += ufunclike.__all__
__all__ += polynomial.__all__
diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py
new file mode 100644
index 000000000..25987362f
--- /dev/null
+++ b/numpy/lib/stride_tricks.py
@@ -0,0 +1,109 @@
+""" Utilities that manipulate strides to achieve desirable effects.
+"""
+import numpy as np
+
+__all__ = ['broadcast_arrays']
+
+class DummyArray(object):
+ """ Dummy object that just exists to hang __array_interface__ dictionaries
+ and possibly keep alive a reference to a base array.
+ """
+ def __init__(self, interface, base=None):
+ self.__array_interface__ = interface
+ self.base = base
+
+def as_strided(x, shape=None, strides=None):
+ """ Make an ndarray from the given array with the given shape and strides.
+ """
+ interface = dict(x.__array_interface__)
+ if shape is not None:
+ interface['shape'] = tuple(shape)
+ if strides is not None:
+ interface['strides'] = tuple(strides)
+ return np.asarray(DummyArray(interface, base=x))
+
+def broadcast_arrays(*args):
+ """ Broadcast any number of arrays against each other.
+
+ Parameters
+ ----------
+ *args : arrays
+
+ Returns
+ -------
+ broadcasted : list of arrays
+ These arrays are views on the original arrays. They are typically not
+ contiguous. Furthermore, more than one element of a broadcasted array
+ may refer to a single memory location. If you need to write to the
+ arrays, make copies first.
+
+ Examples
+ --------
+ >>> x = np.array([[1,2,3]])
+ >>> y = np.array([[1],[2],[3]])
+ >>> np.broadcast_arrays(x, y)
+ [array([[1, 2, 3],
+ [1, 2, 3],
+ [1, 2, 3]]), array([[1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3]])]
+
+ Here is a useful idiom for getting contiguous copies instead of
+ non-contiguous views.
+
+ >>> map(np.array, np.broadcast_arrays(x, y))
+ [array([[1, 2, 3],
+ [1, 2, 3],
+ [1, 2, 3]]), array([[1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3]])]
+
+ """
+ args = map(np.asarray, args)
+ shapes = [x.shape for x in args]
+ if len(set(shapes)) == 1:
+ # Common case where nothing needs to be broadcasted.
+ return args
+ shapes = [list(s) for s in shapes]
+ strides = [list(x.strides) for x in args]
+ nds = [len(s) for s in shapes]
+ biggest = max(nds)
+ # Go through each array and prepend dimensions of length 1 to each of the
+ # shapes in order to make the number of dimensions equal.
+ for i in range(len(args)):
+ diff = biggest - nds[i]
+ if diff > 0:
+ shapes[i] = [1] * diff + shapes[i]
+ strides[i] = [0] * diff + strides[i]
+ # Chech each dimension for compatibility. A dimension length of 1 is
+ # accepted as compatible with any other length.
+ common_shape = []
+ for axis in range(biggest):
+ lengths = [s[axis] for s in shapes]
+ unique = set(lengths + [1])
+ if len(unique) > 2:
+ # There must be at least two non-1 lengths for this axis.
+ raise ValueError("shape mismatch: two or more arrays have "
+ "incompatible dimensions on axis %r." % (axis,))
+ elif len(unique) == 2:
+ # There is exactly one non-1 length. The common shape will take this
+ # value.
+ unique.remove(1)
+ new_length = unique.pop()
+ common_shape.append(new_length)
+ # For each array, if this axis is being broadcasted from a length of
+ # 1, then set its stride to 0 so that it repeats its data.
+ for i in range(len(args)):
+ if shapes[i][axis] == 1:
+ shapes[i][axis] = new_length
+ strides[i][axis] = 0
+ else:
+ # Every array has a length of 1 on this axis. Strides can be left
+ # alone as nothing is broadcasted.
+ common_shape.append(1)
+
+ # Construct the new arrays.
+ broadcasted = [as_strided(x, shape=sh, strides=st) for (x,sh,st) in
+ zip(args, shapes, strides)]
+ return broadcasted
+
diff --git a/numpy/lib/tests/test_stride_tricks.py b/numpy/lib/tests/test_stride_tricks.py
new file mode 100644
index 000000000..955a2cbc7
--- /dev/null
+++ b/numpy/lib/tests/test_stride_tricks.py
@@ -0,0 +1,206 @@
+from nose.tools import assert_raises
+import numpy as np
+from numpy.testing import assert_array_equal
+
+from numpy.lib.stride_tricks import broadcast_arrays
+
+
+def assert_shapes_correct(input_shapes, expected_shape):
+ """ Broadcast a list of arrays with the given input shapes and check the
+ common output shape.
+ """
+ inarrays = [np.zeros(s) for s in input_shapes]
+ outarrays = broadcast_arrays(*inarrays)
+ outshapes = [a.shape for a in outarrays]
+ expected = [expected_shape] * len(inarrays)
+ assert outshapes == expected
+
+def assert_incompatible_shapes_raise(input_shapes):
+ """ Broadcast a list of arrays with the given (incompatible) input shapes
+ and check that they raise a ValueError.
+ """
+ inarrays = [np.zeros(s) for s in input_shapes]
+ assert_raises(ValueError, broadcast_arrays, *inarrays)
+
+def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False):
+ """ Broadcast two shapes against each other and check that the data layout
+ is the same as if a ufunc did the broadcasting.
+ """
+ x0 = np.zeros(shape0, dtype=int)
+ # Note that multiply.reduce's identity element is 1.0, so when shape1==(),
+ # this gives the desired n==1.
+ n = int(np.multiply.reduce(shape1))
+ x1 = np.arange(n).reshape(shape1)
+ if transposed:
+ x0 = x0.T
+ x1 = x1.T
+ if flipped:
+ x0 = x0[::-1]
+ x1 = x1[::-1]
+ # Use the add ufunc to do the broadcasting. Since we're adding 0s to x1, the
+ # result should be exactly the same as the broadcasted view of x1.
+ y = x0 + x1
+ b0, b1 = broadcast_arrays(x0, x1)
+ assert_array_equal(y, b1)
+
+
+def test_same():
+ x = np.arange(10)
+ y = np.arange(10)
+ bx, by = broadcast_arrays(x, y)
+ assert_array_equal(x, bx)
+ assert_array_equal(y, by)
+
+def test_one_off():
+ x = np.array([[1,2,3]])
+ y = np.array([[1],[2],[3]])
+ bx, by = broadcast_arrays(x, y)
+ bx0 = np.array([[1,2,3],[1,2,3],[1,2,3]])
+ by0 = bx0.T
+ assert_array_equal(bx0, bx)
+ assert_array_equal(by0, by)
+
+def test_same_input_shapes():
+ """ Check that the final shape is just the input shape.
+ """
+ data = [
+ (),
+ (1,),
+ (3,),
+ (0,1),
+ (0,3),
+ (1,0),
+ (3,0),
+ (1,3),
+ (3,1),
+ (3,3),
+ ]
+ for shape in data:
+ input_shapes = [shape]
+ # Single input.
+ yield assert_shapes_correct, input_shapes, shape
+ # Double input.
+ input_shapes2 = [shape, shape]
+ yield assert_shapes_correct, input_shapes2, shape
+ # Triple input.
+ input_shapes3 = [shape, shape, shape]
+ yield assert_shapes_correct, input_shapes3, shape
+
+def test_two_compatible_by_ones_input_shapes():
+ """ Check that two different input shapes (of the same length but some have
+ 1s) broadcast to the correct shape.
+ """
+ data = [
+ [[(1,), (3,)], (3,)],
+ [[(1,3), (3,3)], (3,3)],
+ [[(3,1), (3,3)], (3,3)],
+ [[(1,3), (3,1)], (3,3)],
+ [[(1,1), (3,3)], (3,3)],
+ [[(1,1), (1,3)], (1,3)],
+ [[(1,1), (3,1)], (3,1)],
+ [[(1,0), (0,0)], (0,0)],
+ [[(0,1), (0,0)], (0,0)],
+ [[(1,0), (0,1)], (0,0)],
+ [[(1,1), (0,0)], (0,0)],
+ [[(1,1), (1,0)], (1,0)],
+ [[(1,1), (0,1)], (0,1)],
+ ]
+ for input_shapes, expected_shape in data:
+ yield assert_shapes_correct, input_shapes, expected_shape
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_shapes_correct, input_shapes[::-1], expected_shape
+
+def test_two_compatible_by_prepending_ones_input_shapes():
+ """ Check that two different input shapes (of different lengths) broadcast
+ to the correct shape.
+ """
+ data = [
+ [[(), (3,)], (3,)],
+ [[(3,), (3,3)], (3,3)],
+ [[(3,), (3,1)], (3,3)],
+ [[(1,), (3,3)], (3,3)],
+ [[(), (3,3)], (3,3)],
+ [[(1,1), (3,)], (1,3)],
+ [[(1,), (3,1)], (3,1)],
+ [[(1,), (1,3)], (1,3)],
+ [[(), (1,3)], (1,3)],
+ [[(), (3,1)], (3,1)],
+ [[(), (0,)], (0,)],
+ [[(0,), (0,0)], (0,0)],
+ [[(0,), (0,1)], (0,0)],
+ [[(1,), (0,0)], (0,0)],
+ [[(), (0,0)], (0,0)],
+ [[(1,1), (0,)], (1,0)],
+ [[(1,), (0,1)], (0,1)],
+ [[(1,), (1,0)], (1,0)],
+ [[(), (1,0)], (1,0)],
+ [[(), (0,1)], (0,1)],
+ ]
+ for input_shapes, expected_shape in data:
+ yield assert_shapes_correct, input_shapes, expected_shape
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_shapes_correct, input_shapes[::-1], expected_shape
+
+def test_incompatible_shapes_raise_valueerror():
+ """ Check that a ValueError is raised for incompatible shapes.
+ """
+ data = [
+ [(3,), (4,)],
+ [(2,3), (2,)],
+ [(3,), (3,), (4,)],
+ [(1,3,4), (2,3,3)],
+ ]
+ for input_shapes in data:
+ yield assert_incompatible_shapes_raise, input_shapes
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_incompatible_shapes_raise, input_shapes[::-1]
+
+def test_same_as_ufunc():
+ """ Check that the data layout is the same as if a ufunc did the operation.
+ """
+ data = [
+ [[(1,), (3,)], (3,)],
+ [[(1,3), (3,3)], (3,3)],
+ [[(3,1), (3,3)], (3,3)],
+ [[(1,3), (3,1)], (3,3)],
+ [[(1,1), (3,3)], (3,3)],
+ [[(1,1), (1,3)], (1,3)],
+ [[(1,1), (3,1)], (3,1)],
+ [[(1,0), (0,0)], (0,0)],
+ [[(0,1), (0,0)], (0,0)],
+ [[(1,0), (0,1)], (0,0)],
+ [[(1,1), (0,0)], (0,0)],
+ [[(1,1), (1,0)], (1,0)],
+ [[(1,1), (0,1)], (0,1)],
+ [[(), (3,)], (3,)],
+ [[(3,), (3,3)], (3,3)],
+ [[(3,), (3,1)], (3,3)],
+ [[(1,), (3,3)], (3,3)],
+ [[(), (3,3)], (3,3)],
+ [[(1,1), (3,)], (1,3)],
+ [[(1,), (3,1)], (3,1)],
+ [[(1,), (1,3)], (1,3)],
+ [[(), (1,3)], (1,3)],
+ [[(), (3,1)], (3,1)],
+ [[(), (0,)], (0,)],
+ [[(0,), (0,0)], (0,0)],
+ [[(0,), (0,1)], (0,0)],
+ [[(1,), (0,0)], (0,0)],
+ [[(), (0,0)], (0,0)],
+ [[(1,1), (0,)], (1,0)],
+ [[(1,), (0,1)], (0,1)],
+ [[(1,), (1,0)], (1,0)],
+ [[(), (1,0)], (1,0)],
+ [[(), (0,1)], (0,1)],
+ ]
+ for input_shapes, expected_shape in data:
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1]
+ # Reverse the input shapes since broadcasting should be symmetric.
+ yield assert_same_as_ufunc, input_shapes[1], input_shapes[0]
+ # Try them transposed, too.
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], True
+ # ... and flipped for non-rank-0 inputs in order to test negative
+ # strides.
+ if () not in input_shapes:
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], False, True
+ yield assert_same_as_ufunc, input_shapes[0], input_shapes[1], True, True