summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c8
-rw-r--r--numpy/core/tests/test_shape_base.py65
2 files changed, 71 insertions, 2 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index 07b7df726..3241a02c7 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -337,6 +337,14 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
if (axis < 0) {
axis += ndim;
}
+
+ if (ndim == 1 & axis != 0) {
+ PyErr_WarnEx(PyExc_FutureWarning,
+ "axis not 0 for ndim == 0; this will raise an error "
+ "in future versions of numpy", 2);
+ axis = 0;
+ }
+
if (axis < 0 || axis >= ndim) {
PyErr_Format(PyExc_IndexError,
"axis %d out of bounds [0, %d)", orig_axis, ndim);
diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py
index 2017ca7a3..0325d4a2c 100644
--- a/numpy/core/tests/test_shape_base.py
+++ b/numpy/core/tests/test_shape_base.py
@@ -1,7 +1,7 @@
import warnings
import numpy as np
-from numpy.testing import (TestCase, assert_, assert_raises, assert_equal,
- assert_array_equal, run_module_suite)
+from numpy.testing import (TestCase, assert_, assert_raises, assert_array_equal,
+ assert_equal, run_module_suite)
from numpy.core import (array, arange, atleast_1d, atleast_2d, atleast_3d,
vstack, hstack, newaxis, concatenate)
@@ -40,6 +40,7 @@ class TestAtleast1d(TestCase):
assert_(atleast_1d(3.0).shape == (1,))
assert_(atleast_1d([[2,3],[4,5]]).shape == (2,2))
+
class TestAtleast2d(TestCase):
def test_0D_array(self):
a = array(1); b = array(2);
@@ -100,6 +101,7 @@ class TestAtleast3d(TestCase):
desired = [a,b]
assert_array_equal(res,desired)
+
class TestHstack(TestCase):
def test_0D_array(self):
a = array(1); b = array(2);
@@ -119,6 +121,7 @@ class TestHstack(TestCase):
desired = array([[1,1],[2,2]])
assert_array_equal(res,desired)
+
class TestVstack(TestCase):
def test_0D_array(self):
a = array(1); b = array(2);
@@ -159,5 +162,63 @@ def test_concatenate_axis_None():
'0', '1', '2', 'x'])
assert_array_equal(r,d)
+
+def test_concatenate():
+ # Test concatenate function
+ # No arrays raise ValueError
+ assert_raises(ValueError, concatenate, ())
+ # Scalars cannot be concatenated
+ assert_raises(ValueError, concatenate, (0,))
+ assert_raises(ValueError, concatenate, (array(0),))
+ # One sequence returns unmodified (but as array)
+ r4 = list(range(4))
+ assert_array_equal(concatenate((r4,)), r4)
+ # Any sequence
+ assert_array_equal(concatenate((tuple(r4),)), r4)
+ assert_array_equal(concatenate((array(r4),)), r4)
+ # 1D default concatenation
+ r3 = list(range(3))
+ assert_array_equal(concatenate((r4, r3)), r4 + r3)
+ # Mixed sequence types
+ assert_array_equal(concatenate((tuple(r4), r3)), r4 + r3)
+ assert_array_equal(concatenate((array(r4), r3)), r4 + r3)
+ # Explicit axis specification
+ assert_array_equal(concatenate((r4, r3), 0), r4 + r3)
+ # Including negative
+ assert_array_equal(concatenate((r4, r3), -1), r4 + r3)
+ # 2D
+ a23 = array([[10, 11, 12], [13, 14, 15]])
+ a13 = array([[0, 1, 2]])
+ res = array([[10, 11, 12], [13, 14, 15], [0, 1, 2]])
+ assert_array_equal(concatenate((a23, a13)), res)
+ assert_array_equal(concatenate((a23, a13), 0), res)
+ assert_array_equal(concatenate((a23.T, a13.T), 1), res.T)
+ assert_array_equal(concatenate((a23.T, a13.T), -1), res.T)
+ # Arrays much match shape
+ assert_raises(ValueError, concatenate, (a23.T, a13.T), 0)
+ # 3D
+ res = arange(2 * 3 * 7).reshape((2, 3, 7))
+ a0 = res[..., :4]
+ a1 = res[..., 4:6]
+ a2 = res[..., 6:]
+ assert_array_equal(concatenate((a0, a1, a2), 2), res)
+ assert_array_equal(concatenate((a0, a1, a2), -1), res)
+ assert_array_equal(concatenate((a0.T, a1.T, a2.T), 0), res.T)
+
+
+def test_concatenate_sloppy0():
+ # Versions of numpy < 1.7.0 ignored axis argument value for 1D arrays. We
+ # allow this for now, but in due course we will raise an error
+ r4 = list(range(4))
+ r3 = list(range(3))
+ assert_array_equal(concatenate((r4, r3), 0), r4 + r3)
+ warnings.simplefilter('ignore', FutureWarning)
+ try:
+ assert_array_equal(concatenate((r4, r3), -10), r4 + r3)
+ assert_array_equal(concatenate((r4, r3), 10), r4 + r3)
+ finally:
+ warnings.filters.pop(0)
+
+
if __name__ == "__main__":
run_module_suite()