diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-08-14 20:13:33 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-08-14 20:13:33 +0000 |
commit | a06ddf3e22ee65fe4009c8f0f304d4b26143e600 (patch) | |
tree | 5425d16165803a5f8fead59e5828730af73b49f6 | |
parent | c70b3c6fe0e073fc70eb8b424c30ca6c5c01ea04 (diff) | |
download | numpy-a06ddf3e22ee65fe4009c8f0f304d4b26143e600.tar.gz |
Fix ndpointer and add tests from ticket #245
-rw-r--r-- | numpy/lib/tests/test_utils.py | 62 | ||||
-rw-r--r-- | numpy/lib/utils.py | 14 |
2 files changed, 72 insertions, 4 deletions
diff --git a/numpy/lib/tests/test_utils.py b/numpy/lib/tests/test_utils.py new file mode 100644 index 000000000..fc98a92b6 --- /dev/null +++ b/numpy/lib/tests/test_utils.py @@ -0,0 +1,62 @@ +from numpy.testing import * +set_package_path() +import numpy as N +restore_path() + +class test_ndpointer(NumpyTestCase): + def check_dtype(self): + dt = N.intc + p = N.ndpointer(dtype=dt) + self.assert_(p.from_param(N.array([1], dt))) + dt = '<i4' + p = N.ndpointer(dtype=dt) + self.assert_(p.from_param(N.array([1], dt))) + dt = N.dtype('>i4') + p = N.ndpointer(dtype=dt) + p.from_param(N.array([1], dt)) + self.assertRaises(TypeError, p.from_param, + N.array([1], dt.newbyteorder('swap'))) + dtnames = ['x', 'y'] + dtformats = [N.intc, N.float64] + dtdescr = {'names' : dtnames, 'formats' : dtformats} + dt = N.dtype(dtdescr) + p = N.ndpointer(dtype=dt) + self.assert_(p.from_param(N.zeros((10,), dt))) + samedt = N.dtype(dtdescr) + p = N.ndpointer(dtype=samedt) + self.assert_(p.from_param(N.zeros((10,), dt))) + dt2 = N.dtype(dtdescr, align=True) + if dt.itemsize != dt2.itemsize: + self.assertRaises(TypeError, p.from_param, N.zeros((10,), dt2)) + else: + self.assert_(p.from_param(N.zeros((10,), dt2))) + + def check_ndim(self): + p = N.ndpointer(ndim=0) + self.assert_(p.from_param(N.array(1))) + self.assertRaises(TypeError, p.from_param, N.array([1])) + p = N.ndpointer(ndim=1) + self.assertRaises(TypeError, p.from_param, N.array(1)) + self.assert_(p.from_param(N.array([1]))) + p = N.ndpointer(ndim=2) + self.assert_(p.from_param(N.array([[1]]))) + + def check_shape(self): + p = N.ndpointer(shape=(1,2)) + self.assert_(p.from_param(N.array([[1,2]]))) + self.assertRaises(TypeError, p.from_param, N.array([[1],[2]])) + p = N.ndpointer(shape=()) + self.assert_(p.from_param(N.array(1))) + + def check_flags(self): + x = N.array([[1,2,3]], order='F') + p = N.ndpointer(flags='FORTRAN') + self.assert_(p.from_param(x)) + p = N.ndpointer(flags='CONTIGUOUS') + self.assertRaises(TypeError, p.from_param, x) + p = N.ndpointer(flags=x.flags.num) + self.assert_(p.from_param(x)) + self.assertRaises(TypeError, p.from_param, N.array([[1,2,3]])) + +if __name__ == "__main__": + NumpyTest().run() diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py index d988b2d14..eae0ab7b7 100644 --- a/numpy/lib/utils.py +++ b/numpy/lib/utils.py @@ -2,7 +2,7 @@ import sys, os import inspect import types from numpy.core.numerictypes import obj2sctype, integer -from numpy.core.multiarray import dtype as _dtype, _flagdict +from numpy.core.multiarray import dtype as _dtype, _flagdict, flagsobj from numpy.core import product, ndarray __all__ = ['issubclass_', 'get_numpy_include', 'issubsctype', @@ -101,7 +101,7 @@ class _ndptr(object): raise TypeError, "array must have %d dimension(s)" % cls._ndim_ if cls._shape_ is not None \ and obj.shape != cls._shape_: - raise TypeError, "array must have shape %s" % cls._shape_ + raise TypeError, "array must have shape %s" % str(cls._shape_) if cls._flags_ is not None \ and ((obj.flags.num & cls._flags_) != cls._flags_): raise TypeError, "array must have flags %s" % \ @@ -121,9 +121,15 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None): flags = flags.split(',') elif isinstance(flags, (int, integer)): num = flags - flags = _flags_fromnum(flags) + flags = _flags_fromnum(num) + elif isinstance(flags, flagsobj): + num = flags.num + flags = _flags_fromnum(num) if num is None: - flags = [x.strip().upper() for x in flags] + try: + flags = [x.strip().upper() for x in flags] + except: + raise TypeError, "invalid flags specification" num = _num_fromflags(flags) try: return _pointer_type_cache[(dtype, ndim, shape, num)] |