summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorpierregm <pierregm@localhost>2009-11-10 19:55:19 +0000
committerpierregm <pierregm@localhost>2009-11-10 19:55:19 +0000
commit5aa2bb4fe50adbdcfbd1777a02c2bdd653319b43 (patch)
tree72e97f26a1f9716d866009b2e5b1b9c092e724d5 /numpy/core
parent0d1344ef88835faf34647bb702148767b8e23e76 (diff)
downloadnumpy-5aa2bb4fe50adbdcfbd1777a02c2bdd653319b43.tar.gz
* fixed rec.fromrecords for an explicit dtype with an object field (bug #1283)
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/records.py39
-rw-r--r--numpy/core/tests/test_records.py96
2 files changed, 75 insertions, 60 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py
index 8904674b1..a41a712b8 100644
--- a/numpy/core/records.py
+++ b/numpy/core/records.py
@@ -73,7 +73,7 @@ def find_duplicate(list):
"""Find duplication in a list, return a list of duplicated elements"""
dup = []
for i in range(len(list)):
- if (list[i] in list[i+1:]):
+ if (list[i] in list[i + 1:]):
if (list[i] not in dup):
dup.append(list[i])
return dup
@@ -196,7 +196,7 @@ class format_parser:
titles = []
if (self._nfields > len(titles)):
- self._titles += [None]*(self._nfields-len(titles))
+ self._titles += [None] * (self._nfields - len(titles))
def _createdescr(self, byteorder):
descr = sb.dtype({'names':self._names,
@@ -254,7 +254,7 @@ class record(nt.void):
if res:
return self.setfield(val, *res[:2])
else:
- if getattr(self,attr,None):
+ if getattr(self, attr, None):
return nt.void.__setattr__(self, attr, val)
else:
raise AttributeError, "'record' object has no "\
@@ -265,9 +265,9 @@ class record(nt.void):
names = self.dtype.names
maxlen = max([len(name) for name in names])
rows = []
- fmt = '%% %ds: %%s' %maxlen
+ fmt = '%% %ds: %%s' % maxlen
for name in names:
- rows.append(fmt%(name, getattr(self, name)))
+ rows.append(fmt % (name, getattr(self, name)))
return "\n".join(rows)
# The recarray is almost identical to a standard array (which supports
@@ -404,7 +404,7 @@ class recarray(ndarray):
return object.__getattribute__(self, attr)
except AttributeError: # attr must be a fieldname
pass
- fielddict = ndarray.__getattribute__(self,'dtype').fields
+ fielddict = ndarray.__getattribute__(self, 'dtype').fields
try:
res = fielddict[attr][:2]
except (TypeError, KeyError):
@@ -428,12 +428,12 @@ class recarray(ndarray):
try:
ret = object.__setattr__(self, attr, val)
except:
- fielddict = ndarray.__getattribute__(self,'dtype').fields or {}
+ fielddict = ndarray.__getattribute__(self, 'dtype').fields or {}
if attr not in fielddict:
exctype, value = sys.exc_info()[:2]
raise exctype, value
else:
- fielddict = ndarray.__getattribute__(self,'dtype').fields or {}
+ fielddict = ndarray.__getattribute__(self, 'dtype').fields or {}
if attr not in fielddict:
return ret
if newattr: # We just added this one
@@ -444,7 +444,7 @@ class recarray(ndarray):
return ret
try:
res = fielddict[attr][:2]
- except (TypeError,KeyError):
+ except (TypeError, KeyError):
raise AttributeError, "record array has no attribute %s" % attr
return self.setfield(val, *res)
@@ -460,10 +460,10 @@ class recarray(ndarray):
def field(self, attr, val=None):
if isinstance(attr, int):
- names = ndarray.__getattribute__(self,'dtype').names
+ names = ndarray.__getattribute__(self, 'dtype').names
attr = names[attr]
- fielddict = ndarray.__getattribute__(self,'dtype').fields
+ fielddict = ndarray.__getattribute__(self, 'dtype').fields
res = fielddict[attr][:2]
@@ -550,7 +550,7 @@ def fromarrays(arrayList, dtype=None, shape=None, formats=None,
for k, obj in enumerate(arrayList):
nn = len(descr[k].shape)
- testshape = obj.shape[:len(obj.shape)-nn]
+ testshape = obj.shape[:len(obj.shape) - nn]
if testshape != shape:
raise ValueError, "array-shape mismatch in array %d" % k
@@ -596,17 +596,17 @@ def fromrecords(recList, dtype=None, shape=None, formats=None, names=None,
nfields = len(recList[0])
if formats is None and dtype is None: # slower
obj = sb.array(recList, dtype=object)
- arrlist = [sb.array(obj[...,i].tolist()) for i in xrange(nfields)]
+ arrlist = [sb.array(obj[..., i].tolist()) for i in xrange(nfields)]
return fromarrays(arrlist, formats=formats, shape=shape, names=names,
titles=titles, aligned=aligned, byteorder=byteorder)
if dtype is not None:
- descr = sb.dtype(dtype)
+ descr = sb.dtype((record, dtype))
else:
descr = format_parser(formats, names, titles, aligned, byteorder)._descr
try:
- retval = sb.array(recList, dtype = descr)
+ retval = sb.array(recList, dtype=descr)
except TypeError: # list of lists instead of list of tuples
if (shape is None or shape == 0):
shape = len(recList)
@@ -624,7 +624,6 @@ def fromrecords(recList, dtype=None, shape=None, formats=None, names=None,
res = retval.view(recarray)
- res.dtype = sb.dtype((record, res.dtype))
return res
@@ -644,7 +643,7 @@ def fromstring(datastring, dtype=None, shape=None, offset=0, formats=None,
itemsize = descr.itemsize
if (shape is None or shape == 0 or shape == -1):
- shape = (len(datastring)-offset) / itemsize
+ shape = (len(datastring) - offset) / itemsize
_array = recarray(shape, descr, buf=datastring, offset=offset)
return _array
@@ -703,14 +702,14 @@ def fromfile(fd, dtype=None, shape=None, offset=0, formats=None,
itemsize = descr.itemsize
shapeprod = sb.array(shape).prod()
- shapesize = shapeprod*itemsize
+ shapesize = shapeprod * itemsize
if shapesize < 0:
shape = list(shape)
shape[ shape.index(-1) ] = size / -shapesize
shape = tuple(shape)
shapeprod = sb.array(shape).prod()
- nbytes = shapeprod*itemsize
+ nbytes = shapeprod * itemsize
if nbytes > size:
raise ValueError(
@@ -794,7 +793,7 @@ def array(obj, dtype=None, shape=None, offset=0, strides=None, formats=None,
obj = sb.array(obj)
if dtype is not None and (obj.dtype != dtype):
obj = obj.view(dtype)
- res = obj.view(recarray)
+ res = obj.view(recarray)
if issubclass(res.dtype.type, nt.void):
res.dtype = sb.dtype((record, res.dtype))
return res
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index 85306b1af..6b8f31f8b 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -4,38 +4,38 @@ from numpy.testing import *
class TestFromrecords(TestCase):
def test_fromrecords(self):
- r = np.rec.fromrecords([[456,'dbe',1.2],[2,'de',1.3]],
+ r = np.rec.fromrecords([[456, 'dbe', 1.2], [2, 'de', 1.3]],
names='col1,col2,col3')
assert_equal(r[0].item(), (456, 'dbe', 1.2))
def test_method_array(self):
- r = np.rec.array('abcdefg'*100,formats='i2,a3,i4',shape=3,byteorder='big')
+ r = np.rec.array('abcdefg' * 100, formats='i2,a3,i4', shape=3, byteorder='big')
assert_equal(r[1].item(), (25444, 'efg', 1633837924))
def test_method_array2(self):
- r = np.rec.array([(1,11,'a'),(2,22,'b'),(3,33,'c'),(4,44,'d'),(5,55,'ex'),
- (6,66,'f'),(7,77,'g')],formats='u1,f4,a1')
+ r = np.rec.array([(1, 11, 'a'), (2, 22, 'b'), (3, 33, 'c'), (4, 44, 'd'), (5, 55, 'ex'),
+ (6, 66, 'f'), (7, 77, 'g')], formats='u1,f4,a1')
assert_equal(r[1].item(), (2, 22.0, 'b'))
def test_recarray_slices(self):
- r = np.rec.array([(1,11,'a'),(2,22,'b'),(3,33,'c'),(4,44,'d'),(5,55,'ex'),
- (6,66,'f'),(7,77,'g')],formats='u1,f4,a1')
+ r = np.rec.array([(1, 11, 'a'), (2, 22, 'b'), (3, 33, 'c'), (4, 44, 'd'), (5, 55, 'ex'),
+ (6, 66, 'f'), (7, 77, 'g')], formats='u1,f4,a1')
assert_equal(r[1::2][1].item(), (4, 44.0, 'd'))
def test_recarray_fromarrays(self):
- x1 = np.array([1,2,3,4])
- x2 = np.array(['a','dd','xyz','12'])
- x3 = np.array([1.1,2,3,4])
- r = np.rec.fromarrays([x1,x2,x3],names='a,b,c')
- assert_equal(r[1].item(), (2,'dd',2.0))
+ x1 = np.array([1, 2, 3, 4])
+ x2 = np.array(['a', 'dd', 'xyz', '12'])
+ x3 = np.array([1.1, 2, 3, 4])
+ r = np.rec.fromarrays([x1, x2, x3], names='a,b,c')
+ assert_equal(r[1].item(), (2, 'dd', 2.0))
x1[1] = 34
- assert_equal(r.a, np.array([1,2,3,4]))
+ assert_equal(r.a, np.array([1, 2, 3, 4]))
def test_recarray_fromfile(self):
- data_dir = path.join(path.dirname(__file__),'data')
- filename = path.join(data_dir,'recarray_from_file.fits')
+ data_dir = path.join(path.dirname(__file__), 'data')
+ filename = path.join(data_dir, 'recarray_from_file.fits')
fd = open(filename)
- fd.seek(2880*2)
+ fd.seek(2880 * 2)
r = np.rec.fromfile(fd, formats='f8,i4,a5', shape=3, byteorder='big')
def test_recarray_from_obj(self):
@@ -44,16 +44,16 @@ class TestFromrecords(TestCase):
b = np.zeros(count, dtype='f8')
c = np.zeros(count, dtype='f8')
for i in range(len(a)):
- a[i] = range(1,10)
+ a[i] = range(1, 10)
- mine = np.rec.fromarrays([a,b,c], names='date,data1,data2')
+ mine = np.rec.fromarrays([a, b, c], names='date,data1,data2')
for i in range(len(a)):
- assert (mine.date[i] == range(1,10))
+ assert (mine.date[i] == range(1, 10))
assert (mine.data1[i] == 0.0)
assert (mine.data2[i] == 0.0)
def test_recarray_from_repr(self):
- x = np.rec.array([ (1, 2)],dtype=[('a', np.int8), ('b', np.int8)])
+ x = np.rec.array([ (1, 2)], dtype=[('a', np.int8), ('b', np.int8)])
y = eval("np." + repr(x))
assert isinstance(y, np.recarray)
assert_equal(y, x)
@@ -75,58 +75,74 @@ class TestFromrecords(TestCase):
assert ra[k].item() == pa[k].item()
def test_recarray_conflict_fields(self):
- ra = np.rec.array([(1,'abc',2.3),(2,'xyz',4.2),
- (3,'wrs',1.3)],
+ ra = np.rec.array([(1, 'abc', 2.3), (2, 'xyz', 4.2),
+ (3, 'wrs', 1.3)],
names='field, shape, mean')
- ra.mean = [1.1,2.2,3.3]
- assert_array_almost_equal(ra['mean'], [1.1,2.2,3.3])
+ ra.mean = [1.1, 2.2, 3.3]
+ assert_array_almost_equal(ra['mean'], [1.1, 2.2, 3.3])
assert type(ra.mean) is type(ra.var)
- ra.shape = (1,3)
- assert ra.shape == (1,3)
- ra.shape = ['A','B','C']
- assert_array_equal(ra['shape'], [['A','B','C']])
+ ra.shape = (1, 3)
+ assert ra.shape == (1, 3)
+ ra.shape = ['A', 'B', 'C']
+ assert_array_equal(ra['shape'], [['A', 'B', 'C']])
ra.field = 5
- assert_array_equal(ra['field'], [[5,5,5]])
+ assert_array_equal(ra['field'], [[5, 5, 5]])
assert callable(ra.field)
+ def test_fromrecords_with_explicit_dtype(self):
+ a = np.rec.fromrecords([(1, 'a'), (2, 'bbb')],
+ dtype=[('a', int), ('b', np.object)])
+ assert_equal(a.a, [1, 2])
+ assert_equal(a[0].a, 1)
+ assert_equal(a.b, ['a', 'bbb'])
+ assert_equal(a[-1].b, 'bbb')
+ #
+ ndtype = np.dtype([('a', int), ('b', np.object)])
+ a = np.rec.fromrecords([(1, 'a'), (2, 'bbb')], dtype=ndtype)
+ assert_equal(a.a, [1, 2])
+ assert_equal(a[0].a, 1)
+ assert_equal(a.b, ['a', 'bbb'])
+ assert_equal(a[-1].b, 'bbb')
+
+
class TestRecord(TestCase):
def setUp(self):
- self.data = np.rec.fromrecords([(1,2,3),(4,5,6)],
+ self.data = np.rec.fromrecords([(1, 2, 3), (4, 5, 6)],
dtype=[("col1", "<i4"),
("col2", "<i4"),
("col3", "<i4")])
def test_assignment1(self):
a = self.data
- assert_equal(a.col1[0],1)
+ assert_equal(a.col1[0], 1)
a[0].col1 = 0
- assert_equal(a.col1[0],0)
+ assert_equal(a.col1[0], 0)
def test_assignment2(self):
a = self.data
- assert_equal(a.col1[0],1)
+ assert_equal(a.col1[0], 1)
a.col1[0] = 0
- assert_equal(a.col1[0],0)
+ assert_equal(a.col1[0], 0)
def test_invalid_assignment(self):
a = self.data
def assign_invalid_column(x):
x[0].col5 = 1
- self.failUnlessRaises(AttributeError,assign_invalid_column,a)
+ self.failUnlessRaises(AttributeError, assign_invalid_column, a)
def test_find_duplicate():
- l1 = [1,2,3,4,5,6]
+ l1 = [1, 2, 3, 4, 5, 6]
assert np.rec.find_duplicate(l1) == []
- l2 = [1,2,1,4,5,6]
+ l2 = [1, 2, 1, 4, 5, 6]
assert np.rec.find_duplicate(l2) == [1]
- l3 = [1,2,1,4,1,6,2,3]
- assert np.rec.find_duplicate(l3) == [1,2]
+ l3 = [1, 2, 1, 4, 1, 6, 2, 3]
+ assert np.rec.find_duplicate(l3) == [1, 2]
- l3 = [2,2,1,4,1,6,2,3]
- assert np.rec.find_duplicate(l3) == [2,1]
+ l3 = [2, 2, 1, 4, 1, 6, 2, 3]
+ assert np.rec.find_duplicate(l3) == [2, 1]
if __name__ == "__main__":
run_module_suite()