diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/numerictypes.py | 70 | ||||
-rw-r--r-- | numpy/core/tests/test_numerictypes.py | 22 | ||||
-rw-r--r-- | numpy/lib/index_tricks.py | 22 | ||||
-rw-r--r-- | numpy/lib/tests/test_index_tricks.py | 4 |
4 files changed, 107 insertions, 11 deletions
diff --git a/numpy/core/numerictypes.py b/numpy/core/numerictypes.py index c9c0bc82e..63d30013b 100644 --- a/numpy/core/numerictypes.py +++ b/numpy/core/numerictypes.py @@ -78,7 +78,7 @@ $Id: numerictypes.py,v 1.17 2005/09/09 22:20:06 teoliphant Exp $ # we add more at the bottom __all__ = ['sctypeDict', 'sctypeNA', 'typeDict', 'typeNA', 'sctypes', 'ScalarType', 'obj2sctype', 'cast', 'nbytes', 'sctype2char', - 'maximum_sctype', 'issctype', 'typecodes'] + 'maximum_sctype', 'issctype', 'typecodes', 'find_common_type'] from numpy.core.multiarray import typeinfo, ndarray, array, empty, dtype import types as _types @@ -566,7 +566,7 @@ for key in allTypes: del key -typecodes = {'Character':'S1', +typecodes = {'Character':'c', 'Integer':'bhilqp', 'UnsignedInteger':'BHILQP', 'Float':'fdg', @@ -578,3 +578,69 @@ typecodes = {'Character':'S1', # backwards compatibility --- deprecated name typeDict = sctypeDict typeNA = sctypeNA + +_kind_list = ['b', 'u', 'i', 'f', 'c', 'S', 'U', 'V', 'O'] + +__test_types = typecodes['AllInteger'][:-2]+typecodes['AllFloat']+'O' +__len_test_types = len(__test_types) + +# Keep incrementing until a common type both can be coerced to +# is found. Otherwise, return None +def _find_common_coerce(a, b): + if a > b: + return a + try: + thisind = __test_types.index(a.char) + except ValueError: + return None + while thisind < __len_test_types: + newdtype = dtype(__test_types[thisind]) + if newdtype >= b and newdtype >= a: + return newdtype + thisind += 1 + return None + + +def find_common_type(array_types, scalar_types): + """Determine common type following standard coercion rules + + Parameters + ---------- + array_types : sequence + A list of dtype convertible objects representing arrays + scalar_types : sequence + A list of dtype convertible objects representing scalars + + Returns + ------- + datatype : dtype + The common data-type which is the maximum of the array_types + ignoring the scalar_types unless the maximum of the scalar_types + is of a different kind. + + If the kinds is not understood, then None is returned. + """ + array_types = [dtype(x) for x in array_types] + scalar_types = [dtype(x) for x in scalar_types] + + if len(scalar_types) == 0: + if len(array_types) == 0: + return None + else: + return max(array_types) + if len(array_types) == 0: + return max(scalar_types) + + maxa = max(array_types) + maxsc = max(scalar_types) + + try: + index_a = _kind_list.index(maxa.kind) + index_sc = _kind_list.index(maxsc.kind) + except ValueError: + return None + + if index_sc > index_a: + return _find_common_coerce(maxsc,maxa) + else: + return maxa diff --git a/numpy/core/tests/test_numerictypes.py b/numpy/core/tests/test_numerictypes.py index 527b89b53..f0533e062 100644 --- a/numpy/core/tests/test_numerictypes.py +++ b/numpy/core/tests/test_numerictypes.py @@ -338,5 +338,27 @@ class TestEmptyField(NumpyTestCase): assert(a['int'].shape == (5,0)) assert(a['float'].shape == (5,2)) +class TestCommonType(NumpyTestCase): + def check_scalar_loses1(self): + res = numpy.find_common_type(['f4','f4','i4'],['f8']) + assert(res == 'f4') + def check_scalar_loses2(self): + res = numpy.find_common_type(['f4','f4'],['i8']) + assert(res == 'f4') + def check_scalar_wins(self): + res = numpy.find_common_type(['f4','f4','i4'],['c8']) + assert(res == 'c8') + def check_scalar_wins2(self): + res = numpy.find_common_type(['u4','i4','i4'],['f4']) + assert(res == 'f8') + def check_scalar_wins3(self): # doesn't go up to 'f16' on purpose + res = numpy.find_common_type(['u8','i8','i8'],['f8']) + assert(res == 'f8') + + + + + + if __name__ == "__main__": NumpyTest().run() diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index c45148057..22b8ef388 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -7,7 +7,8 @@ __all__ = ['unravel_index', import sys import numpy.core.numeric as _nx -from numpy.core.numeric import asarray, ScalarType, array +from numpy.core.numeric import asarray, ScalarType, array, dtype +from numpy.core.numerictypes import find_common_type import math import function_base @@ -225,7 +226,8 @@ class AxisConcatenator(object): key = (key,) objs = [] scalars = [] - final_dtypedescr = None + arraytypes = [] + scalartypes = [] for k in range(len(key)): scalar = False if type(key[k]) is slice: @@ -272,6 +274,7 @@ class AxisConcatenator(object): newobj = array(key[k],ndmin=ndmin) scalars.append(k) scalar = True + scalartypes.append(newobj.dtype) else: newobj = key[k] if ndmin > 1: @@ -289,14 +292,15 @@ class AxisConcatenator(object): newobj = newobj.transpose(axes) del tempobj objs.append(newobj) - if isinstance(newobj, _nx.ndarray) and not scalar: - if final_dtypedescr is None: - final_dtypedescr = newobj.dtype - elif newobj.dtype > final_dtypedescr: - final_dtypedescr = newobj.dtype - if final_dtypedescr is not None: + if not scalar and isinstance(newobj, _nx.ndarray): + arraytypes.append(newobj.dtype) + + # Esure that scalars won't up-cast unless warranted + final_dtype = find_common_type(arraytypes, scalartypes) + if final_dtype is not None: for k in scalars: - objs[k] = objs[k].astype(final_dtypedescr) + objs[k] = objs[k].astype(final_dtype) + res = _nx.concatenate(tuple(objs),axis=self.axis) return self._retval(res) diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 8fc192202..38bbaae96 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -35,6 +35,10 @@ class TestConcatenator(NumpyTestCase): c = r_[b,0,0,b] assert_array_equal(c,[1,1,1,1,1,0,0,1,1,1,1,1]) + def check_mixed_type(self): + g = r_[10.1, 1:10] + assert(g.dtype == 'f8') + def check_2d(self): b = rand(5,5) c = rand(5,5) |