summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorStefan van der Walt <stefan@sun.ac.za>2008-05-01 19:24:51 +0000
committerStefan van der Walt <stefan@sun.ac.za>2008-05-01 19:24:51 +0000
commit3056ee6f48e54b93e074ab62d8a6fca148b2508d (patch)
tree09ef68a3008b080bef90e9bd5d34a179fe6d2663 /numpy/core
parentc402895302b40a13d200dab577d1b2bc108c4879 (diff)
downloadnumpy-3056ee6f48e54b93e074ab62d8a6fca148b2508d.tar.gz
Support for Python types in x.view.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/arraymethods.c59
-rw-r--r--numpy/core/tests/test_multiarray.py12
2 files changed, 60 insertions, 11 deletions
diff --git a/numpy/core/src/arraymethods.c b/numpy/core/src/arraymethods.c
index cfd912644..1af6a02b7 100644
--- a/numpy/core/src/arraymethods.c
+++ b/numpy/core/src/arraymethods.c
@@ -103,26 +103,63 @@ array_squeeze(PyArrayObject *self, PyObject *args)
static PyObject *
array_view(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *otype=NULL;
- PyArray_Descr *type=NULL;
+ PyObject *out_dtype_or_type=NULL;
+ PyObject *out_dtype=NULL;
+ PyObject *out_type=NULL;
+ PyArray_Descr *dtype=NULL;
- static char *kwlist[] = {"dtype", NULL};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, &otype))
+ static char *kwlist[] = {"dtype_or_type", "dtype", "type", NULL};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO", kwlist,
+ &out_dtype_or_type,
+ &out_dtype,
+ &out_type))
return NULL;
- if (otype) {
- if (PyType_Check(otype) && \
- PyType_IsSubtype((PyTypeObject *)otype,
+ /* If user specified a positional argument, guess whether it
+ represents a type or a dtype for backward compatibility. */
+ if (out_dtype_or_type) {
+
+ /* type specified? */
+ if (PyType_Check(out_dtype_or_type) &&
+ PyType_IsSubtype((PyTypeObject *)out_dtype_or_type,
&PyArray_Type)) {
- return PyArray_View(self, NULL,
- (PyTypeObject *)otype);
+ if (out_type) {
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot specify output type twice.");
+ return NULL;
+ }
+
+ out_type = out_dtype_or_type;
}
+
+ /* dtype specified */
else {
- if (PyArray_DescrConverter(otype, &type) == PY_FAIL)
+ if (out_dtype) {
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot specify output dtype twice.");
return NULL;
+ }
+
+ out_dtype = out_dtype_or_type;
}
}
- return PyArray_View(self, type, NULL);
+
+ if ((out_type) && (!PyType_Check(out_type) ||
+ !PyType_IsSubtype((PyTypeObject *)out_type,
+ &PyArray_Type))) {
+ PyErr_SetString(PyExc_ValueError,
+ "Type must be a Python type object");
+ return NULL;
+ }
+
+ if ((out_dtype) &&
+ (PyArray_DescrConverter(out_dtype, &dtype) == PY_FAIL)) {
+ PyErr_SetString(PyExc_ValueError,
+ "Dtype must be a numpy data-type");
+ return NULL;
+ }
+
+ return PyArray_View(self, dtype, (PyTypeObject*)out_type);
}
static PyObject *
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 9a7f8c9ff..ae6c43f10 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -842,6 +842,18 @@ class TestView(NumpyTestCase):
assert_array_equal(y,z)
assert_array_equal(y, [67305985, 134678021])
+ def test_type(self):
+ x = np.array([1,2,3])
+ assert(isinstance(x.view(np.matrix),np.matrix))
+
+ def test_keywords(self):
+ x = np.array([(1,2)],dtype=[('a',np.int8),('b',np.int8)])
+ y = x.view(dtype=np.int16, type=np.matrix)
+ assert_array_equal(y,[[513]])
+
+ assert(isinstance(y,np.matrix))
+ assert_equal(y.dtype,np.int16)
+
# Import tests without matching module names
set_local_path()