diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-13 09:04:39 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-13 09:04:39 +0000 |
commit | eee00f8f7e15592a048c8b841aef9ea81faa0fda (patch) | |
tree | 74052bb733a01a84b83c3354593d2e7eb0243ef5 /numpy | |
parent | 8e24ef871ee8a58ae65d4d59d8ac916a48568c56 (diff) | |
download | numpy-eee00f8f7e15592a048c8b841aef9ea81faa0fda.tar.gz |
Remove _as_parameter_ attribute from arrays and add it to the ctypes object. Create an ndpointer class factory to return classes that check for specific array types. These can be used in argtypes list to ctypes functions.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_internal.py | 23 | ||||
-rw-r--r-- | numpy/core/src/arrayobject.c | 24 | ||||
-rw-r--r-- | numpy/lib/utils.py | 30 |
3 files changed, 54 insertions, 23 deletions
diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 6f64f40f4..d538743b6 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -208,21 +208,30 @@ def _getintp_ctype(): _getintp_ctype.cache = None # Used for .ctypes attribute of ndarray + +class _missing_ctypes(object): + def cast(self, num, obj): + return num + + def c_void_p(self, num): + return num + class _ctypes(object): - def __init__(self, array): + def __init__(self, array, ptr=None): try: import ctypes self._ctypes = ctypes except ImportError: - raise AttributeError, "ctypes not available" + self._ctypes = _missing_ctypes() self._arr = array + self._data = ptr if self._arr.ndim == 0: self._zerod = True else: self._zerod = False def data_as(self, obj): - return self._ctypes.cast(self._arr._as_parameter_, obj) + return self._ctypes.cast(self._data, obj) def shape_as(self, obj): if self._zerod: @@ -235,7 +244,7 @@ class _ctypes(object): return (obj*self._arr.ndim)(*self._arr.strides) def get_data(self): - return self._ctypes.c_void_p(self._arr._as_parameter_) + return self._ctypes.c_void_p(self._data) def get_shape(self): if self._zerod: @@ -246,7 +255,11 @@ class _ctypes(object): if self._zerod: return None return (_getintp_ctype()*self._arr.ndim)(*self._arr.strides) - + + def get_as_parameter(self): + return self._data + data = property(get_data, None, doc="c-types data") shape = property(get_shape, None, doc="c-types shape") strides = property(get_strides, None, doc="c-types strides") + _as_parameter_ = property(get_as_parameter, None, doc="_as parameter_") diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 1e155fd00..c4303bdd3 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -5829,13 +5829,8 @@ static PyObject * array_ctypes_get(PyArrayObject *self) { return PyObject_CallMethod(_numpy_internal, "_ctypes", - "O", self); -} - -static PyObject * -array_as_parameter_get(PyArrayObject *self) -{ - return PyLong_FromVoidPtr(self->data); + "ON", self, + PyLong_FromVoidPtr(self->data)); } static PyObject * @@ -6413,9 +6408,6 @@ static PyGetSetDef array_getsetlist[] = { {"ctypes", (getter)array_ctypes_get, NULL, NULL}, - {"_as_parameter_", - (getter)array_as_parameter_get, - NULL, NULL}, {"T", (getter)array_transpose_get, NULL, NULL}, @@ -11062,11 +11054,11 @@ static PyTypeObject PyArrayDescr_Type = { (reprfunc)arraydescr_repr, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ - &descr_as_mapping, /* tp_as_mapping */ - 0, /* tp_hash */ + &descr_as_mapping, /* tp_as_mapping */ + (hashfunc)_Py_HashPointer, /* tp_hash */ 0, /* tp_call */ - (reprfunc)arraydescr_str, /* tp_str */ - 0, /* tp_getattro */ + (reprfunc)arraydescr_str, /* tp_str */ + 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ @@ -11075,8 +11067,8 @@ static PyTypeObject PyArrayDescr_Type = { 0, /* tp_clear */ (richcmpfunc)arraydescr_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ + 0, /* tp_iter */ + 0, /* tp_iternext */ arraydescr_methods, /* tp_methods */ arraydescr_members, /* tp_members */ arraydescr_getsets, /* tp_getset */ diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index 8306b799c..14b0d8ea3 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -7,8 +7,8 @@ from numpy.core import product, ndarray __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype', 'issubdtype', 'deprecate', 'get_numarray_include', - 'get_include', 'ctypes_load_library', 'info', - 'source', 'who'] + 'get_include', 'ctypes_load_library', 'ndpointer', + 'info', 'source', 'who'] def issubclass_(arg1, arg2): try: @@ -76,6 +76,32 @@ def ctypes_load_library(libname, loader_path): libpath = os.path.join(libdir, libname) return ctypes.cdll[libpath] +class _ndptr(object): + def from_param(cls, obj): + if not isinstance(obj, ndarray): + raise TypeError("argument must be an ndarray") + if obj.dtype != cls._dtype_: + raise TypeError("array must have data type", cls._dtype_) + return obj.ctypes + from_param = classmethod(from_param) + +# Factory for a type-checking object with from_param defined +_pointer_type_cache = {} +def ndpointer(datatype): + datatype = dtype(datatype) + try: + return _pointer_type_cache[datatype] + except KeyError: + pass + if datatype.names: + name = str(id(datatype)) + else: + name = datatype.str + klass = type("ndpointer_%s"%name, (_ndptr,), + {"_dtype_": datatype}) + _pointer_type_cache[datatype] = klass + return klass + if sys.version_info < (2, 4): # Can't set __name__ in 2.3 |