summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/ctypeslib.py168
-rw-r--r--numpy/tests/test_ctypeslib.py92
2 files changed, 242 insertions, 18 deletions
diff --git a/numpy/ctypeslib.py b/numpy/ctypeslib.py
index 2e9781286..02c3bd211 100644
--- a/numpy/ctypeslib.py
+++ b/numpy/ctypeslib.py
@@ -346,27 +346,157 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
return klass
-def _get_typecodes():
- """ Return a dictionary mapping __array_interface__ formats to ctypes types """
- ct = ctypes
- simple_types = [
- ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong,
- ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong,
- ct.c_float, ct.c_double,
- ]
+if ctypes is not None:
+ def _ctype_ndarray(element_type, shape):
+ """ Create an ndarray of the given element type and shape """
+ for dim in shape[::-1]:
+ element_type = dim * element_type
+ # prevent the type name include np.ctypeslib
+ element_type.__module__ = None
+ return element_type
- return {_dtype(ctype).str: ctype for ctype in simple_types}
+ def _get_scalar_type_map():
+ """
+ Return a dictionary mapping native endian scalar dtype to ctypes types
+ """
+ ct = ctypes
+ simple_types = [
+ ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong,
+ ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong,
+ ct.c_float, ct.c_double,
+ ct.c_bool,
+ ]
+ return {_dtype(ctype): ctype for ctype in simple_types}
-def _ctype_ndarray(element_type, shape):
- """ Create an ndarray of the given element type and shape """
- for dim in shape[::-1]:
- element_type = element_type * dim
- return element_type
+ _scalar_type_map = _get_scalar_type_map()
+
+
+ def _ctype_from_dtype_scalar(dtype):
+ # swapping twice ensure that `=` is promoted to <, >, or |
+ dtype_with_endian = dtype.newbyteorder('S').newbyteorder('S')
+ dtype_native = dtype.newbyteorder('=')
+ try:
+ ctype = _scalar_type_map[dtype_native]
+ except KeyError:
+ raise NotImplementedError(
+ "Converting {!r} to a ctypes type".format(dtype)
+ )
+
+ if dtype_with_endian.byteorder == '>':
+ ctype = ctype.__ctype_be__
+ elif dtype_with_endian.byteorder == '<':
+ ctype = ctype.__ctype_le__
+
+ return ctype
+
+
+ def _ctype_from_dtype_subarray(dtype):
+ element_dtype, shape = dtype.subdtype
+ ctype = _ctype_from_dtype(element_dtype)
+ return _ctype_ndarray(ctype, shape)
+
+
+ def _ctype_from_dtype_structured(dtype):
+ # extract offsets of each field
+ field_data = []
+ for name in dtype.names:
+ field_dtype, offset = dtype.fields[name][:2]
+ field_data.append((offset, name, _ctype_from_dtype(field_dtype)))
+
+ # ctypes doesn't care about field order
+ field_data = sorted(field_data, key=lambda f: f[0])
+
+ if len(field_data) > 1 and all(offset == 0 for offset, name, ctype in field_data):
+ # union, if multiple fields all at address 0
+ size = 0
+ _fields_ = []
+ for offset, name, ctype in field_data:
+ _fields_.append((name, ctype))
+ size = max(size, ctypes.sizeof(ctype))
+
+ # pad to the right size
+ if dtype.itemsize != size:
+ _fields_.append(('', ctypes.c_char * dtype.itemsize))
+
+ # we inserted manual padding, so always `_pack_`
+ return type('union', (ctypes.Union,), dict(
+ _fields_=_fields_,
+ _pack_=1,
+ __module__=None,
+ ))
+ else:
+ last_offset = 0
+ _fields_ = []
+ for offset, name, ctype in field_data:
+ padding = offset - last_offset
+ if padding < 0:
+ raise NotImplementedError("Overlapping fields")
+ if padding > 0:
+ _fields_.append(('', ctypes.c_char * padding))
+
+ _fields_.append((name, ctype))
+ last_offset = offset + ctypes.sizeof(ctype)
+
+
+ padding = dtype.itemsize - last_offset
+ if padding > 0:
+ _fields_.append(('', ctypes.c_char * padding))
+
+ # we inserted manual padding, so always `_pack_`
+ return type('struct', (ctypes.Structure,), dict(
+ _fields_=_fields_,
+ _pack_=1,
+ __module__=None,
+ ))
+
+
+ def _ctype_from_dtype(dtype):
+ if dtype.fields is not None:
+ return _ctype_from_dtype_structured(dtype)
+ elif dtype.subdtype is not None:
+ return _ctype_from_dtype_subarray(dtype)
+ else:
+ return _ctype_from_dtype_scalar(dtype)
+
+
+ def as_ctypes_type(dtype):
+ """
+ Convert a dtype into a ctypes type.
+
+ Parameters
+ ----------
+ dtype : dtype
+ The dtype to convert
+
+ Returns
+ -------
+ ctypes
+ A ctype scalar, union, array, or struct
+
+ Raises
+ ------
+ NotImplementedError
+ If the conversion is not possible
+
+ Notes
+ -----
+ This function does not losslessly round-trip in either direction.
+
+ ``np.dtype(as_ctypes_type(dt))`` will:
+ - insert padding fields
+ - reorder fields to be sorted by offset
+ - discard field titles
+
+ ``as_ctypes_type(np.dtype(ctype))`` will:
+ - discard the class names of ``Structure``s and ``Union``s
+ - convert single-element ``Union``s into single-element ``Structure``s
+ - insert padding fields
+
+ """
+ return _ctype_from_dtype(_dtype(dtype))
-if ctypes is not None:
- _typecodes = _get_typecodes()
def as_array(obj, shape=None):
"""
@@ -388,6 +518,7 @@ if ctypes is not None:
return array(obj, copy=False)
+
def as_ctypes(obj):
"""Create and return a ctypes object from a numpy array. Actually
anything that exposes the __array_interface__ is accepted."""
@@ -399,7 +530,8 @@ if ctypes is not None:
addr, readonly = ai["data"]
if readonly:
raise TypeError("readonly arrays unsupported")
- tp = _ctype_ndarray(_typecodes[ai["typestr"]], ai["shape"])
- result = tp.from_address(addr)
+
+ dtype = _dtype((ai["typestr"], ai["shape"]))
+ result = as_ctypes_type(dtype).from_address(addr)
result.__keep = obj
return result
diff --git a/numpy/tests/test_ctypeslib.py b/numpy/tests/test_ctypeslib.py
index d389b37a8..521208c36 100644
--- a/numpy/tests/test_ctypeslib.py
+++ b/numpy/tests/test_ctypeslib.py
@@ -273,3 +273,95 @@ class TestAsArray(object):
# check we avoid the segfault
c_arr[0][0][0]
+
+
+@pytest.mark.skipif(ctypes is None,
+ reason="ctypes not available on this python installation")
+class TestAsCtypesType(object):
+ """ Test conversion from dtypes to ctypes types """
+ def test_scalar(self):
+ dt = np.dtype('<u2')
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_equal(ct, ctypes.c_uint16.__ctype_le__)
+
+ dt = np.dtype('>u2')
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_equal(ct, ctypes.c_uint16.__ctype_be__)
+
+ dt = np.dtype('u2')
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_equal(ct, ctypes.c_uint16)
+
+ def test_subarray(self):
+ dt = np.dtype((np.int32, (2, 3)))
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_equal(ct, 2 * (3 * ctypes.c_int32))
+
+ def test_structure(self):
+ dt = np.dtype([
+ ('a', np.uint16),
+ ('b', np.uint32),
+ ])
+
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_(issubclass(ct, ctypes.Structure))
+ assert_equal(ctypes.sizeof(ct), dt.itemsize)
+ assert_equal(ct._fields_, [
+ ('a', ctypes.c_uint16),
+ ('b', ctypes.c_uint32),
+ ])
+
+ def test_structure_aligned(self):
+ dt = np.dtype([
+ ('a', np.uint16),
+ ('b', np.uint32),
+ ], align=True)
+
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_(issubclass(ct, ctypes.Structure))
+ assert_equal(ctypes.sizeof(ct), dt.itemsize)
+ assert_equal(ct._fields_, [
+ ('a', ctypes.c_uint16),
+ ('', ctypes.c_char * 2), # padding
+ ('b', ctypes.c_uint32),
+ ])
+
+ def test_union(self):
+ dt = np.dtype(dict(
+ names=['a', 'b'],
+ offsets=[0, 0],
+ formats=[np.uint16, np.uint32]
+ ))
+
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_(issubclass(ct, ctypes.Union))
+ assert_equal(ctypes.sizeof(ct), dt.itemsize)
+ assert_equal(ct._fields_, [
+ ('a', ctypes.c_uint16),
+ ('b', ctypes.c_uint32),
+ ])
+
+ def test_padded_union(self):
+ dt = np.dtype(dict(
+ names=['a', 'b'],
+ offsets=[0, 0],
+ formats=[np.uint16, np.uint32],
+ itemsize=5,
+ ))
+
+ ct = np.ctypeslib.as_ctypes_type(dt)
+ assert_(issubclass(ct, ctypes.Union))
+ assert_equal(ctypes.sizeof(ct), dt.itemsize)
+ assert_equal(ct._fields_, [
+ ('a', ctypes.c_uint16),
+ ('b', ctypes.c_uint32),
+ ('', ctypes.c_char * 5), # padding
+ ])
+
+ def test_overlapping(self):
+ dt = np.dtype(dict(
+ names=['a', 'b'],
+ offsets=[0, 2],
+ formats=[np.uint32, np.uint32]
+ ))
+ assert_raises(NotImplementedError, np.ctypeslib.as_ctypes_type, dt)