summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/numerictypes.py70
-rw-r--r--numpy/core/tests/test_numerictypes.py22
-rw-r--r--numpy/lib/index_tricks.py22
-rw-r--r--numpy/lib/tests/test_index_tricks.py4
4 files changed, 107 insertions, 11 deletions
diff --git a/numpy/core/numerictypes.py b/numpy/core/numerictypes.py
index c9c0bc82e..63d30013b 100644
--- a/numpy/core/numerictypes.py
+++ b/numpy/core/numerictypes.py
@@ -78,7 +78,7 @@ $Id: numerictypes.py,v 1.17 2005/09/09 22:20:06 teoliphant Exp $
# we add more at the bottom
__all__ = ['sctypeDict', 'sctypeNA', 'typeDict', 'typeNA', 'sctypes',
'ScalarType', 'obj2sctype', 'cast', 'nbytes', 'sctype2char',
- 'maximum_sctype', 'issctype', 'typecodes']
+ 'maximum_sctype', 'issctype', 'typecodes', 'find_common_type']
from numpy.core.multiarray import typeinfo, ndarray, array, empty, dtype
import types as _types
@@ -566,7 +566,7 @@ for key in allTypes:
del key
-typecodes = {'Character':'S1',
+typecodes = {'Character':'c',
'Integer':'bhilqp',
'UnsignedInteger':'BHILQP',
'Float':'fdg',
@@ -578,3 +578,69 @@ typecodes = {'Character':'S1',
# backwards compatibility --- deprecated name
typeDict = sctypeDict
typeNA = sctypeNA
+
+_kind_list = ['b', 'u', 'i', 'f', 'c', 'S', 'U', 'V', 'O']
+
+__test_types = typecodes['AllInteger'][:-2]+typecodes['AllFloat']+'O'
+__len_test_types = len(__test_types)
+
+# Keep incrementing until a common type both can be coerced to
+# is found. Otherwise, return None
+def _find_common_coerce(a, b):
+ if a > b:
+ return a
+ try:
+ thisind = __test_types.index(a.char)
+ except ValueError:
+ return None
+ while thisind < __len_test_types:
+ newdtype = dtype(__test_types[thisind])
+ if newdtype >= b and newdtype >= a:
+ return newdtype
+ thisind += 1
+ return None
+
+
+def find_common_type(array_types, scalar_types):
+ """Determine common type following standard coercion rules
+
+ Parameters
+ ----------
+ array_types : sequence
+ A list of dtype convertible objects representing arrays
+ scalar_types : sequence
+ A list of dtype convertible objects representing scalars
+
+ Returns
+ -------
+ datatype : dtype
+ The common data-type which is the maximum of the array_types
+ ignoring the scalar_types unless the maximum of the scalar_types
+ is of a different kind.
+
+ If the kinds is not understood, then None is returned.
+ """
+ array_types = [dtype(x) for x in array_types]
+ scalar_types = [dtype(x) for x in scalar_types]
+
+ if len(scalar_types) == 0:
+ if len(array_types) == 0:
+ return None
+ else:
+ return max(array_types)
+ if len(array_types) == 0:
+ return max(scalar_types)
+
+ maxa = max(array_types)
+ maxsc = max(scalar_types)
+
+ try:
+ index_a = _kind_list.index(maxa.kind)
+ index_sc = _kind_list.index(maxsc.kind)
+ except ValueError:
+ return None
+
+ if index_sc > index_a:
+ return _find_common_coerce(maxsc,maxa)
+ else:
+ return maxa
diff --git a/numpy/core/tests/test_numerictypes.py b/numpy/core/tests/test_numerictypes.py
index 527b89b53..f0533e062 100644
--- a/numpy/core/tests/test_numerictypes.py
+++ b/numpy/core/tests/test_numerictypes.py
@@ -338,5 +338,27 @@ class TestEmptyField(NumpyTestCase):
assert(a['int'].shape == (5,0))
assert(a['float'].shape == (5,2))
+class TestCommonType(NumpyTestCase):
+ def check_scalar_loses1(self):
+ res = numpy.find_common_type(['f4','f4','i4'],['f8'])
+ assert(res == 'f4')
+ def check_scalar_loses2(self):
+ res = numpy.find_common_type(['f4','f4'],['i8'])
+ assert(res == 'f4')
+ def check_scalar_wins(self):
+ res = numpy.find_common_type(['f4','f4','i4'],['c8'])
+ assert(res == 'c8')
+ def check_scalar_wins2(self):
+ res = numpy.find_common_type(['u4','i4','i4'],['f4'])
+ assert(res == 'f8')
+ def check_scalar_wins3(self): # doesn't go up to 'f16' on purpose
+ res = numpy.find_common_type(['u8','i8','i8'],['f8'])
+ assert(res == 'f8')
+
+
+
+
+
+
if __name__ == "__main__":
NumpyTest().run()
diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py
index c45148057..22b8ef388 100644
--- a/numpy/lib/index_tricks.py
+++ b/numpy/lib/index_tricks.py
@@ -7,7 +7,8 @@ __all__ = ['unravel_index',
import sys
import numpy.core.numeric as _nx
-from numpy.core.numeric import asarray, ScalarType, array
+from numpy.core.numeric import asarray, ScalarType, array, dtype
+from numpy.core.numerictypes import find_common_type
import math
import function_base
@@ -225,7 +226,8 @@ class AxisConcatenator(object):
key = (key,)
objs = []
scalars = []
- final_dtypedescr = None
+ arraytypes = []
+ scalartypes = []
for k in range(len(key)):
scalar = False
if type(key[k]) is slice:
@@ -272,6 +274,7 @@ class AxisConcatenator(object):
newobj = array(key[k],ndmin=ndmin)
scalars.append(k)
scalar = True
+ scalartypes.append(newobj.dtype)
else:
newobj = key[k]
if ndmin > 1:
@@ -289,14 +292,15 @@ class AxisConcatenator(object):
newobj = newobj.transpose(axes)
del tempobj
objs.append(newobj)
- if isinstance(newobj, _nx.ndarray) and not scalar:
- if final_dtypedescr is None:
- final_dtypedescr = newobj.dtype
- elif newobj.dtype > final_dtypedescr:
- final_dtypedescr = newobj.dtype
- if final_dtypedescr is not None:
+ if not scalar and isinstance(newobj, _nx.ndarray):
+ arraytypes.append(newobj.dtype)
+
+ # Esure that scalars won't up-cast unless warranted
+ final_dtype = find_common_type(arraytypes, scalartypes)
+ if final_dtype is not None:
for k in scalars:
- objs[k] = objs[k].astype(final_dtypedescr)
+ objs[k] = objs[k].astype(final_dtype)
+
res = _nx.concatenate(tuple(objs),axis=self.axis)
return self._retval(res)
diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py
index 8fc192202..38bbaae96 100644
--- a/numpy/lib/tests/test_index_tricks.py
+++ b/numpy/lib/tests/test_index_tricks.py
@@ -35,6 +35,10 @@ class TestConcatenator(NumpyTestCase):
c = r_[b,0,0,b]
assert_array_equal(c,[1,1,1,1,1,0,0,1,1,1,1,1])
+ def check_mixed_type(self):
+ g = r_[10.1, 1:10]
+ assert(g.dtype == 'f8')
+
def check_2d(self):
b = rand(5,5)
c = rand(5,5)