summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorNikita Titov <nekit94-08@mail.ru>2018-08-26 19:15:27 +0300
committerCharles Harris <charlesr.harris@gmail.com>2018-08-26 11:15:27 -0500
commit63a57c3f6a58c61763daf89a2d5495f6a855bf9b (patch)
treec46fc415a0bb045874399fb4d6f49af8a3fee489 /numpy
parentab30683fe1277a7fd471be63975a02ed8d766f94 (diff)
downloadnumpy-63a57c3f6a58c61763daf89a2d5495f6a855bf9b.tar.gz
BUG: fix array_split incorrect behavior with array size bigger MAX_INT32 (#11813)
Fixes #11809. * BUG: fix array_split incorrect behavior with array size bigger MAX_INT32 * TST: added test for array_split with array size greater MAX_INT32 * addressed review comments
Diffstat (limited to 'numpy')
-rw-r--r--numpy/lib/shape_base.py4
-rw-r--r--numpy/lib/tests/test_shape_base.py14
2 files changed, 16 insertions, 2 deletions
diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py
index d31d8a939..66f534734 100644
--- a/numpy/lib/shape_base.py
+++ b/numpy/lib/shape_base.py
@@ -688,7 +688,7 @@ def array_split(ary, indices_or_sections, axis=0):
except AttributeError:
Ntotal = len(ary)
try:
- # handle scalar case.
+ # handle array case.
Nsections = len(indices_or_sections) + 1
div_points = [0] + list(indices_or_sections) + [Ntotal]
except TypeError:
@@ -700,7 +700,7 @@ def array_split(ary, indices_or_sections, axis=0):
section_sizes = ([0] +
extras * [Neach_section+1] +
(Nsections-extras) * [Neach_section])
- div_points = _nx.array(section_sizes).cumsum()
+ div_points = _nx.array(section_sizes, dtype=_nx.intp).cumsum()
sub_arys = []
sary = _nx.swapaxes(ary, axis, 0)
diff --git a/numpy/lib/tests/test_shape_base.py b/numpy/lib/tests/test_shape_base.py
index 6d24dd624..6e4cd225d 100644
--- a/numpy/lib/tests/test_shape_base.py
+++ b/numpy/lib/tests/test_shape_base.py
@@ -3,6 +3,8 @@ from __future__ import division, absolute_import, print_function
import numpy as np
import warnings
import functools
+import sys
+import pytest
from numpy.lib.shape_base import (
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
@@ -14,6 +16,9 @@ from numpy.testing import (
)
+IS_64BIT = sys.maxsize > 2**32
+
+
def _add_keepdims(func):
""" hack in keepdims behavior into a function taking an axis """
@functools.wraps(func)
@@ -403,6 +408,15 @@ class TestArraySplit(object):
assert_(a.dtype.type is res[-1].dtype.type)
# perhaps should check higher dimensions
+ @pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform")
+ def test_integer_split_2D_rows_greater_max_int32(self):
+ a = np.broadcast_to([0], (1 << 32, 2))
+ res = array_split(a, 4)
+ chunk = np.broadcast_to([0], (1 << 30, 2))
+ tgt = [chunk] * 4
+ for i in range(len(tgt)):
+ assert_equal(res[i].shape, tgt[i].shape)
+
def test_index_split_simple(self):
a = np.arange(10)
indices = [1, 5, 7]