diff options
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/index_tricks.py | 22 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 4 |
2 files changed, 17 insertions, 9 deletions
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c45148057..22b8ef388 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,8 @@ __all__ = ['unravel_index', import sys import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType, array +from numpy.core.numeric import asarray, ScalarType, array, dtype +from numpy.core.numerictypes import find_common_type import math import function_base @@ -225,7 +226,8 @@ class AxisConcatenator(object): key = (key,) objs = [] scalars = [] - final_dtypedescr = None + arraytypes = [] + scalartypes = [] for k in range(len(key)): scalar = False if type(key[k]) is slice: @@ -272,6 +274,7 @@ class AxisConcatenator(object): newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True + scalartypes.append(newobj.dtype) else: newobj = key[k] if ndmin > 1: @@ -289,14 +292,15 @@ class AxisConcatenator(object): newobj = newobj.transpose(axes) del tempobj objs.append(newobj) - if isinstance(newobj, _nx.ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: + if not scalar and isinstance(newobj, _nx.ndarray): + arraytypes.append(newobj.dtype) + + # Esure that scalars won't up-cast unless warranted + final_dtype = find_common_type(arraytypes, scalartypes) + if final_dtype is not None: for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) + objs[k] = objs[k].astype(final_dtype) + res = _nx.concatenate(tuple(objs),axis=self.axis) return self._retval(res) diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 8fc192202..38bbaae96 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -35,6 +35,10 @@ class TestConcatenator(NumpyTestCase): c = r_[b,0,0,b] assert_array_equal(c,[1,1,1,1,1,0,0,1,1,1,1,1]) + def check_mixed_type(self): + g = r_[10.1, 1:10] + assert(g.dtype == 'f8') + def check_2d(self): b = rand(5,5) c = rand(5,5) |