summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/fromnumeric.py2
-rw-r--r--numpy/core/numeric.py107
-rw-r--r--numpy/core/tests/test_numeric.py77
3 files changed, 175 insertions, 11 deletions
diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py
index a2937c5c5..67d2c5b48 100644
--- a/numpy/core/fromnumeric.py
+++ b/numpy/core/fromnumeric.py
@@ -518,7 +518,7 @@ def transpose(a, axes=None):
See Also
--------
- rollaxis
+ moveaxis
argsort
Notes
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 4f3d418e6..a18b38072 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function
import sys
+import operator
import warnings
import collections
from numpy.core import multiarray
@@ -15,8 +16,10 @@ from ._internal import TooHardError
if sys.version_info[0] >= 3:
import pickle
basestring = str
+ import builtins
else:
import cPickle as pickle
+ import __builtin__ as builtins
loads = pickle.loads
@@ -31,15 +34,15 @@ __all__ = [
'ascontiguousarray', 'asfortranarray', 'isfortran', 'empty_like',
'zeros_like', 'ones_like', 'correlate', 'convolve', 'inner', 'dot',
'einsum', 'outer', 'vdot', 'alterdot', 'restoredot', 'roll',
- 'rollaxis', 'cross', 'tensordot', 'array2string', 'get_printoptions',
- 'set_printoptions', 'array_repr', 'array_str', 'set_string_function',
- 'little_endian', 'require', 'fromiter', 'array_equal', 'array_equiv',
- 'indices', 'fromfunction', 'isclose', 'load', 'loads', 'isscalar',
- 'binary_repr', 'base_repr', 'ones', 'identity', 'allclose',
- 'compare_chararrays', 'putmask', 'seterr', 'geterr', 'setbufsize',
- 'getbufsize', 'seterrcall', 'geterrcall', 'errstate', 'flatnonzero',
- 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_', 'True_',
- 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
+ 'rollaxis', 'moveaxis', 'cross', 'tensordot', 'array2string',
+ 'get_printoptions', 'set_printoptions', 'array_repr', 'array_str',
+ 'set_string_function', 'little_endian', 'require', 'fromiter',
+ 'array_equal', 'array_equiv', 'indices', 'fromfunction', 'isclose', 'load',
+ 'loads', 'isscalar', 'binary_repr', 'base_repr', 'ones', 'identity',
+ 'allclose', 'compare_chararrays', 'putmask', 'seterr', 'geterr',
+ 'setbufsize', 'getbufsize', 'seterrcall', 'geterrcall', 'errstate',
+ 'flatnonzero', 'Inf', 'inf', 'infty', 'Infinity', 'nan', 'NaN', 'False_',
+ 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
'TooHardError',
@@ -1422,6 +1425,7 @@ def rollaxis(a, axis, start=0):
See Also
--------
+ moveaxis : Move array axes to new positions.
roll : Roll the elements of an array by a number of positions along a
given axis.
@@ -1457,6 +1461,91 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)
+def _validate_axis(axis, ndim, argname):
+ try:
+ axis = [operator.index(axis)]
+ except TypeError:
+ axis = list(axis)
+ axis = [a + ndim if a < 0 else a for a in axis]
+ if not builtins.all(0 <= a < ndim for a in axis):
+ raise ValueError('invalid axis for this array in `%s` argument' %
+ argname)
+ if len(set(axis)) != len(axis):
+ raise ValueError('repeated axis in `%s` argument' % argname)
+ return axis
+
+
+def moveaxis(a, source, destination):
+ """
+ Move axes of an array to new positions.
+
+ Other axes remain in their original order.
+
+ .. versionadded::1.11.0
+
+ Parameters
+ ----------
+ a : np.ndarray
+ The array whose axes should be reordered.
+ source : int or sequence of int
+ Original positions of the axes to move. These must be unique.
+ destination : int or sequence of int
+ Destination positions for each of the original axes. These must also be
+ unique.
+
+ Returns
+ -------
+ result : np.ndarray
+ Array with moved axes. This array is a view of the input array.
+
+ See Also
+ --------
+ transpose: Permute the dimensions of an array.
+ swapaxes: Interchange two axes of an array.
+
+ Examples
+ --------
+
+ >>> x = np.zeros((3, 4, 5))
+ >>> np.moveaxis(x, 0, -1).shape
+ (4, 5, 3)
+ >>> np.moveaxis(x, -1, 0).shape
+ (5, 3, 4)
+
+ These all achieve the same result:
+
+ >>> np.transpose(x).shape
+ (5, 4, 3)
+ >>> np.swapaxis(x, 0, -1).shape
+ (5, 4, 3)
+ >>> np.moveaxis(x, [0, 1], [-1, -2]).shape
+ (5, 4, 3)
+ >>> np.moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape
+ (5, 4, 3)
+
+ """
+ try:
+ # allow duck-array types if they define transpose
+ transpose = a.transpose
+ except AttributeError:
+ a = asarray(a)
+ transpose = a.transpose
+
+ source = _validate_axis(source, a.ndim, 'source')
+ destination = _validate_axis(destination, a.ndim, 'destination')
+ if len(source) != len(destination):
+ raise ValueError('`source` and `destination` arguments must have '
+ 'the same number of elements')
+
+ order = [n for n in range(a.ndim) if n not in source]
+
+ for dest, src in sorted(zip(destination, source)):
+ order.insert(dest, src)
+
+ result = transpose(order)
+ return result
+
+
# fix hack in scipy which imports this function
def _move_axis_to_0(a, axis):
return rollaxis(a, axis, 0)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index d63118080..a8ad4c763 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -11,7 +11,8 @@ from numpy.core import umath
from numpy.random import rand, randint, randn
from numpy.testing import (
TestCase, run_module_suite, assert_, assert_equal, assert_raises,
- assert_array_equal, assert_almost_equal, assert_array_almost_equal, dec
+ assert_raises_regex, assert_array_equal, assert_almost_equal,
+ assert_array_almost_equal, dec
)
@@ -2029,6 +2030,80 @@ class TestRollaxis(TestCase):
assert_(not res.flags['OWNDATA'])
+class TestMoveaxis(TestCase):
+ def test_move_to_end(self):
+ x = np.random.randn(5, 6, 7)
+ for source, expected in [(0, (6, 7, 5)),
+ (1, (5, 7, 6)),
+ (2, (5, 6, 7)),
+ (-1, (5, 6, 7))]:
+ actual = np.moveaxis(x, source, -1).shape
+ assert_(actual, expected)
+
+ def test_move_new_position(self):
+ x = np.random.randn(1, 2, 3, 4)
+ for source, destination, expected in [
+ (0, 1, (2, 1, 3, 4)),
+ (1, 2, (1, 3, 2, 4)),
+ (1, -1, (1, 3, 4, 2)),
+ ]:
+ actual = np.moveaxis(x, source, destination).shape
+ assert_(actual, expected)
+
+ def test_preserve_order(self):
+ x = np.zeros((1, 2, 3, 4))
+ for source, destination in [
+ (0, 0),
+ (3, -1),
+ (-1, 3),
+ ([0, -1], [0, -1]),
+ ([2, 0], [2, 0]),
+ (range(4), range(4)),
+ ]:
+ actual = np.moveaxis(x, source, destination).shape
+ assert_(actual, (1, 2, 3, 4))
+
+ def test_move_multiples(self):
+ x = np.zeros((0, 1, 2, 3))
+ for source, destination, expected in [
+ ([0, 1], [2, 3], (2, 3, 0, 1)),
+ ([2, 3], [0, 1], (2, 3, 0, 1)),
+ ([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)),
+ ([3, 0], [1, 0], (0, 3, 1, 2)),
+ ([0, 3], [0, 1], (0, 3, 1, 2)),
+ ]:
+ actual = np.moveaxis(x, source, destination).shape
+ assert_(actual, expected)
+
+ def test_errors(self):
+ x = np.random.randn(1, 2, 3)
+ assert_raises_regex(ValueError, 'invalid axis .* `source`',
+ np.moveaxis, x, 3, 0)
+ assert_raises_regex(ValueError, 'invalid axis .* `source`',
+ np.moveaxis, x, -4, 0)
+ assert_raises_regex(ValueError, 'invalid axis .* `destination`',
+ np.moveaxis, x, 0, 5)
+ assert_raises_regex(ValueError, 'repeated axis in `source`',
+ np.moveaxis, x, [0, 0], [0, 1])
+ assert_raises_regex(ValueError, 'repeated axis in `destination`',
+ np.moveaxis, x, [0, 1], [1, 1])
+ assert_raises_regex(ValueError, 'must have the same number',
+ np.moveaxis, x, 0, [0, 1])
+ assert_raises_regex(ValueError, 'must have the same number',
+ np.moveaxis, x, [0, 1], [0])
+
+ def test_array_likes(self):
+ x = np.ma.zeros((1, 2, 3))
+ result = np.moveaxis(x, 0, 0)
+ assert_(x.shape, result.shape)
+ assert_(isinstance(result, np.ma.MaskedArray))
+
+ x = [1, 2, 3]
+ result = np.moveaxis(x, 0, 0)
+ assert_(x, list(result))
+ assert_(isinstance(result, np.ndarray))
+
+
class TestCross(TestCase):
def test_2x2(self):
u = [1, 2]