summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/lib/tests/test_utils.py62
-rw-r--r--numpy/lib/utils.py14
2 files changed, 72 insertions, 4 deletions
diff --git a/numpy/lib/tests/test_utils.py b/numpy/lib/tests/test_utils.py
new file mode 100644
index 000000000..fc98a92b6
--- /dev/null
+++ b/numpy/lib/tests/test_utils.py
@@ -0,0 +1,62 @@
+from numpy.testing import *
+set_package_path()
+import numpy as N
+restore_path()
+
+class test_ndpointer(NumpyTestCase):
+ def check_dtype(self):
+ dt = N.intc
+ p = N.ndpointer(dtype=dt)
+ self.assert_(p.from_param(N.array([1], dt)))
+ dt = '<i4'
+ p = N.ndpointer(dtype=dt)
+ self.assert_(p.from_param(N.array([1], dt)))
+ dt = N.dtype('>i4')
+ p = N.ndpointer(dtype=dt)
+ p.from_param(N.array([1], dt))
+ self.assertRaises(TypeError, p.from_param,
+ N.array([1], dt.newbyteorder('swap')))
+ dtnames = ['x', 'y']
+ dtformats = [N.intc, N.float64]
+ dtdescr = {'names' : dtnames, 'formats' : dtformats}
+ dt = N.dtype(dtdescr)
+ p = N.ndpointer(dtype=dt)
+ self.assert_(p.from_param(N.zeros((10,), dt)))
+ samedt = N.dtype(dtdescr)
+ p = N.ndpointer(dtype=samedt)
+ self.assert_(p.from_param(N.zeros((10,), dt)))
+ dt2 = N.dtype(dtdescr, align=True)
+ if dt.itemsize != dt2.itemsize:
+ self.assertRaises(TypeError, p.from_param, N.zeros((10,), dt2))
+ else:
+ self.assert_(p.from_param(N.zeros((10,), dt2)))
+
+ def check_ndim(self):
+ p = N.ndpointer(ndim=0)
+ self.assert_(p.from_param(N.array(1)))
+ self.assertRaises(TypeError, p.from_param, N.array([1]))
+ p = N.ndpointer(ndim=1)
+ self.assertRaises(TypeError, p.from_param, N.array(1))
+ self.assert_(p.from_param(N.array([1])))
+ p = N.ndpointer(ndim=2)
+ self.assert_(p.from_param(N.array([[1]])))
+
+ def check_shape(self):
+ p = N.ndpointer(shape=(1,2))
+ self.assert_(p.from_param(N.array([[1,2]])))
+ self.assertRaises(TypeError, p.from_param, N.array([[1],[2]]))
+ p = N.ndpointer(shape=())
+ self.assert_(p.from_param(N.array(1)))
+
+ def check_flags(self):
+ x = N.array([[1,2,3]], order='F')
+ p = N.ndpointer(flags='FORTRAN')
+ self.assert_(p.from_param(x))
+ p = N.ndpointer(flags='CONTIGUOUS')
+ self.assertRaises(TypeError, p.from_param, x)
+ p = N.ndpointer(flags=x.flags.num)
+ self.assert_(p.from_param(x))
+ self.assertRaises(TypeError, p.from_param, N.array([[1,2,3]]))
+
+if __name__ == "__main__":
+ NumpyTest().run()
diff --git a/numpy/lib/utils.py b/numpy/lib/utils.py
index d988b2d14..eae0ab7b7 100644
--- a/numpy/lib/utils.py
+++ b/numpy/lib/utils.py
@@ -2,7 +2,7 @@ import sys, os
import inspect
import types
from numpy.core.numerictypes import obj2sctype, integer
-from numpy.core.multiarray import dtype as _dtype, _flagdict
+from numpy.core.multiarray import dtype as _dtype, _flagdict, flagsobj
from numpy.core import product, ndarray
__all__ = ['issubclass_', 'get_numpy_include', 'issubsctype',
@@ -101,7 +101,7 @@ class _ndptr(object):
raise TypeError, "array must have %d dimension(s)" % cls._ndim_
if cls._shape_ is not None \
and obj.shape != cls._shape_:
- raise TypeError, "array must have shape %s" % cls._shape_
+ raise TypeError, "array must have shape %s" % str(cls._shape_)
if cls._flags_ is not None \
and ((obj.flags.num & cls._flags_) != cls._flags_):
raise TypeError, "array must have flags %s" % \
@@ -121,9 +121,15 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
flags = flags.split(',')
elif isinstance(flags, (int, integer)):
num = flags
- flags = _flags_fromnum(flags)
+ flags = _flags_fromnum(num)
+ elif isinstance(flags, flagsobj):
+ num = flags.num
+ flags = _flags_fromnum(num)
if num is None:
- flags = [x.strip().upper() for x in flags]
+ try:
+ flags = [x.strip().upper() for x in flags]
+ except:
+ raise TypeError, "invalid flags specification"
num = _num_fromflags(flags)
try:
return _pointer_type_cache[(dtype, ndim, shape, num)]