summaryrefslogtreecommitdiff
path: root/numpy/lib/tests/test__iotools.py
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2010-09-13 12:34:37 +0000
committerpierregm <pierregm@localhost>2010-09-13 12:34:37 +0000
commit7213c5d804412b1ab6f23c6419ba836865af517a (patch)
treec8cfae516fffc5eeffbbbebdc4fb929a0483aec7 /numpy/lib/tests/test__iotools.py
parenta70f5718a6199919e5bddf4a7daa032ae80c6779 (diff)
downloadnumpy-7213c5d804412b1ab6f23c6419ba836865af517a.tar.gz
* fixed 'flatten_dtype' to support fields w/ titles (bug #1591). Thx to Stefan vdW for the fix.
* added a unittest for flatten_dtype
Diffstat (limited to 'numpy/lib/tests/test__iotools.py')
-rw-r--r--numpy/lib/tests/test__iotools.py44
1 files changed, 33 insertions, 11 deletions
diff --git a/numpy/lib/tests/test__iotools.py b/numpy/lib/tests/test__iotools.py
index 7c45b3527..544057e3a 100644
--- a/numpy/lib/tests/test__iotools.py
+++ b/numpy/lib/tests/test__iotools.py
@@ -10,8 +10,8 @@ from datetime import date
import time
import numpy as np
-from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter,\
- has_nested_fields, easy_dtype
+from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter, \
+ has_nested_fields, easy_dtype, flatten_dtype
from numpy.testing import *
from numpy.compat import asbytes, asbytes_nested
@@ -37,10 +37,10 @@ class TestLineSplitter(TestCase):
def test_tab_delimiter(self):
"Test tab delimiter"
- strg= asbytes(" 1\t 2\t 3\t 4\t 5 6")
+ strg = asbytes(" 1\t 2\t 3\t 4\t 5 6")
test = LineSplitter(asbytes('\t'))(strg)
assert_equal(test, asbytes_nested(['1', '2', '3', '4', '5 6']))
- strg= asbytes(" 1 2\t 3 4\t 5 6")
+ strg = asbytes(" 1 2\t 3 4\t 5 6")
test = LineSplitter(asbytes('\t'))(strg)
assert_equal(test, asbytes_nested(['1 2', '3 4', '5 6']))
@@ -70,11 +70,11 @@ class TestLineSplitter(TestCase):
def test_variable_fixed_width(self):
strg = asbytes(" 1 3 4 5 6# test")
- test = LineSplitter((3,6,6,3))(strg)
+ test = LineSplitter((3, 6, 6, 3))(strg)
assert_equal(test, asbytes_nested(['1', '3', '4 5', '6']))
#
strg = asbytes(" 1 3 4 5 6# test")
- test = LineSplitter((6,6,9))(strg)
+ test = LineSplitter((6, 6, 9))(strg)
assert_equal(test, asbytes_nested(['1', '3 4', '5 6']))
@@ -97,7 +97,7 @@ class TestNameValidator(TestCase):
def test_excludelist(self):
"Test excludelist"
names = ['dates', 'data', 'Other Data', 'mask']
- validator = NameValidator(excludelist = ['dates', 'data', 'mask'])
+ validator = NameValidator(excludelist=['dates', 'data', 'mask'])
test = validator.validate(names)
assert_equal(test, ['dates_', 'data_', 'Other_Data', 'mask_'])
#
@@ -117,7 +117,7 @@ class TestNameValidator(TestCase):
"Test validate nb names"
namelist = ('a', 'b', 'c')
validator = NameValidator()
- assert_equal(validator(namelist, nbfields=1), ('a', ))
+ assert_equal(validator(namelist, nbfields=1), ('a',))
assert_equal(validator(namelist, nbfields=5, defaultfmt="g%i"),
['a', 'b', 'c', 'g0', 'g1'])
#
@@ -159,7 +159,7 @@ class TestStringConverter(TestCase):
converter.upgrade(asbytes('0j'))
assert_equal(converter._status, 3)
converter.upgrade(asbytes('a'))
- assert_equal(converter._status, len(converter._mapper)-1)
+ assert_equal(converter._status, len(converter._mapper) - 1)
#
def test_missing(self):
"Tests the use of missing values."
@@ -178,7 +178,7 @@ class TestStringConverter(TestCase):
def test_upgrademapper(self):
"Tests updatemapper"
dateparser = _bytes_to_date
- StringConverter.upgrade_mapper(dateparser, date(2000,1,1))
+ StringConverter.upgrade_mapper(dateparser, date(2000, 1, 1))
convert = StringConverter(dateparser, date(2000, 1, 1))
test = convert(asbytes('2001-01-01'))
assert_equal(test, date(2001, 01, 01))
@@ -196,7 +196,7 @@ class TestStringConverter(TestCase):
def test_keep_default(self):
"Make sure we don't lose an explicit default"
converter = StringConverter(None, missing_values=asbytes(''),
- default=-999)
+ default= -999)
converter.upgrade(asbytes('3.14159265'))
assert_equal(converter.default, -999)
assert_equal(converter.type, np.dtype(float))
@@ -287,3 +287,25 @@ class TestMiscFunctions(TestCase):
assert_equal(easy_dtype(ndtype, names=['', '', ''], defaultfmt="f%02i"),
np.dtype([(_, float) for _ in ('f00', 'f01', 'f02')]))
+
+ def test_flatten_dtype(self):
+ "Testing flatten_dtype"
+ # Standard dtype
+ dt = np.dtype([("a", "f8"), ("b", "f8")])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [float, float])
+ # Recursive dtype
+ dt = np.dtype([("a", [("aa", '|S1'), ("ab", '|S2')]), ("b", int)])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [np.dtype('|S1'), np.dtype('|S2'), int])
+ # dtype with shaped fields
+ dt = np.dtype([("a", (float, 2)), ("b", (int, 3))])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [float, int])
+ dt_flat = flatten_dtype(dt, True)
+ assert_equal(dt_flat, [float] * 2 + [int] * 3)
+ # dtype w/ titles
+ dt = np.dtype([(("a", "A"), "f8"), (("b", "B"), "f8")])
+ dt_flat = flatten_dtype(dt)
+ assert_equal(dt_flat, [float, float])
+