diff options
Diffstat (limited to 'numpy/core/src/arrayobject.c')
-rw-r--r-- | numpy/core/src/arrayobject.c | 59 |
1 files changed, 35 insertions, 24 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index ba47ffd19..2e2f6021b 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -5265,6 +5265,32 @@ _array_small_type(PyArray_Descr *chktype, PyArray_Descr* mintype) return outtype; } +static PyArray_Descr * +_array_find_python_scalar_type(PyObject *op) +{ + if (PyFloat_Check(op)) { + return PyArray_DescrFromType(PyArray_DOUBLE); + } else if (PyComplex_Check(op)) { + return PyArray_DescrFromType(PyArray_CDOUBLE); + } else if (PyInt_Check(op)) { + /* bools are a subclass of int */ + if (PyBool_Check(op)) { + return PyArray_DescrFromType(PyArray_BOOL); + } else { + return PyArray_DescrFromType(PyArray_LONG); + } + } else if (PyLong_Check(op)) { + /* if integer can fit into a longlong then return that + */ + if ((PyLong_AsLongLong(op) == -1) && PyErr_Occurred()) { + PyErr_Clear(); + return PyArray_DescrFromType(PyArray_OBJECT); + } + return PyArray_DescrFromType(PyArray_LONGLONG); + } + return NULL; +} + /* op is an object to be converted to an ndarray. minitype is the minimum type-descriptor needed. @@ -5299,6 +5325,11 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) goto finish; } + chktype = _array_find_python_scalar_type(op); + if (chktype) { + goto finish; + } + if ((ip=PyObject_GetAttrString(op, "__array_typestr__"))!=NULL) { if (PyString_Check(ip)) { chktype =_array_typedescr_fromstr(PyString_AS_STRING(ip)); @@ -5390,29 +5421,6 @@ _array_find_type(PyObject *op, PyArray_Descr *minitype, int max) goto finish; } - if (PyBool_Check(op)) { - chktype = PyArray_DescrFromType(PyArray_BOOL); - goto finish; - } - else if (PyInt_Check(op)) { - chktype = PyArray_DescrFromType(PyArray_LONG); - goto finish; - } else if (PyLong_Check(op)) { - /* if integer can fit into a longlong then return that - */ - if ((PyLong_AsLongLong(op) == -1) && PyErr_Occurred()) { - PyErr_Clear(); - goto deflt; - } - chktype = PyArray_DescrFromType(PyArray_LONGLONG); - goto finish; - } else if (PyFloat_Check(op)) { - chktype = PyArray_DescrFromType(PyArray_DOUBLE); - goto finish; - } else if (PyComplex_Check(op)) { - chktype = PyArray_DescrFromType(PyArray_CDOUBLE); - goto finish; - } deflt: chktype = PyArray_DescrFromType(PyArray_OBJECT); @@ -6286,7 +6294,10 @@ PyArray_FromAny(PyObject *op, PyArray_Descr *newtype, int min_depth, r = PyArray_FromArray((PyArrayObject *)op, newtype, flags); else if (PyArray_IsScalar(op, Generic)) { r = PyArray_FromScalar(op, newtype); - } + } else if (newtype == NULL && + (newtype = _array_find_python_scalar_type(op))) { + r = Array_FromScalar(op, newtype); + } else if (((r = PyArray_FromStructInterface(op))!=Py_NotImplemented)|| \ ((r = PyArray_FromInterface(op)) != Py_NotImplemented) || \ ((r = PyArray_FromArrayAttr(op, newtype, context)) \ |