summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py68
-rw-r--r--numpy/core/tests/test_numeric.py36
2 files changed, 82 insertions, 22 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 7f1206855..11a95fa7b 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function
-import sys
+import collections
+import itertools
import operator
+import sys
import warnings
-import collections
+
from numpy.core import multiarray
from . import umath
from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
@@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None):
----------
a : array_like
Input array.
- shift : int
- The number of places by which elements are shifted.
- axis : int, optional
- The axis along which elements are shifted. By default, the array
- is flattened before shifting, after which the original
+ shift : int or tuple of ints
+ The number of places by which elements are shifted. If a tuple,
+ then `axis` must be a tuple of the same size, and each of the
+ given axes is shifted by the corresponding number. If an int
+ while `axis` is a tuple of ints, then the same value is used for
+ all given axes.
+ axis : int or tuple of ints, optional
+ Axis or axes along which elements are shifted. By default, the
+ array is flattened before shifting, after which the original
shape is restored.
Returns
@@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None):
rollaxis : Roll the specified axis backwards, until it lies in a
given position.
+ Notes
+ -----
+ .. versionadded:: 1.12.0
+
+ Supports rolling over multiple dimensions simultaneously.
+
Examples
--------
>>> x = np.arange(10)
@@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None):
"""
a = asanyarray(a)
if axis is None:
- n = a.size
- reshape = True
+ return roll(a.ravel(), shift, 0).reshape(a.shape)
+
else:
- try:
- n = a.shape[axis]
- except IndexError:
- raise ValueError('axis must be >= 0 and < %d' % a.ndim)
- reshape = False
- if n == 0:
- return a
- shift %= n
- indexes = concatenate((arange(n - shift, n), arange(n - shift)))
- res = a.take(indexes, axis)
- if reshape:
- res = res.reshape(a.shape)
- return res
+ broadcasted = broadcast(shift, axis)
+ if len(broadcasted.shape) > 1:
+ raise ValueError(
+ "'shift' and 'axis' should be scalars or 1D sequences")
+ shifts = {ax: 0 for ax in range(a.ndim)}
+ for sh, ax in broadcasted:
+ if -a.ndim <= ax < a.ndim:
+ shifts[ax % a.ndim] += sh
+ else:
+ raise ValueError("'axis' entry is out of bounds")
+
+ rolls = [((slice(None), slice(None)),)] * a.ndim
+ for ax, offset in shifts.items():
+ offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters.
+ if offset:
+ # (original, result), (original, result)
+ rolls[ax] = ((slice(None, -offset), slice(offset, None)),
+ (slice(-offset, None), slice(None, offset)))
+
+ result = empty_like(a)
+ for indices in itertools.product(*rolls):
+ arr_index, res_index = zip(*indices)
+ result[res_index] = a[arr_index]
+
+ return result
def rollaxis(a, axis, start=0):
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 0040f3a25..dd9c83b25 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2145,6 +2145,42 @@ class TestRoll(TestCase):
x2r = np.roll(x2, 1, axis=1)
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
+ # Roll multiple axes at once.
+ x2r = np.roll(x2, 1, axis=(0, 1))
+ assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]]))
+
+ x2r = np.roll(x2, (1, 0), axis=(0, 1))
+ assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))
+
+ x2r = np.roll(x2, (-1, 0), axis=(0, 1))
+ assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))
+
+ x2r = np.roll(x2, (0, 1), axis=(0, 1))
+ assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
+
+ x2r = np.roll(x2, (0, -1), axis=(0, 1))
+ assert_equal(x2r, np.array([[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]))
+
+ x2r = np.roll(x2, (1, 1), axis=(0, 1))
+ assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]]))
+
+ x2r = np.roll(x2, (-1, -1), axis=(0, 1))
+ assert_equal(x2r, np.array([[6, 7, 8, 9, 5], [1, 2, 3, 4, 0]]))
+
+ # Roll the same axis multiple times.
+ x2r = np.roll(x2, 1, axis=(0, 0))
+ assert_equal(x2r, np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]))
+
+ x2r = np.roll(x2, 1, axis=(1, 1))
+ assert_equal(x2r, np.array([[3, 4, 0, 1, 2], [8, 9, 5, 6, 7]]))
+
+ # Roll more than one turn in either direction.
+ x2r = np.roll(x2, 6, axis=1)
+ assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
+
+ x2r = np.roll(x2, -4, axis=1)
+ assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
+
def test_roll_empty(self):
x = np.array([])
assert_equal(np.roll(x, 1), np.array([]))