diff options
author | pierregm <pierregm@localhost> | 2009-11-10 19:55:19 +0000 |
---|---|---|
committer | pierregm <pierregm@localhost> | 2009-11-10 19:55:19 +0000 |
commit | 5aa2bb4fe50adbdcfbd1777a02c2bdd653319b43 (patch) | |
tree | 72e97f26a1f9716d866009b2e5b1b9c092e724d5 /numpy/core | |
parent | 0d1344ef88835faf34647bb702148767b8e23e76 (diff) | |
download | numpy-5aa2bb4fe50adbdcfbd1777a02c2bdd653319b43.tar.gz |
* fixed rec.fromrecords for an explicit dtype with an object field (bug #1283)
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/records.py | 39 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 96 |
2 files changed, 75 insertions, 60 deletions
diff --git a/numpy/core/records.py b/numpy/core/records.py index 8904674b1..a41a712b8 100644 --- a/numpy/core/records.py +++ b/numpy/core/records.py @@ -73,7 +73,7 @@ def find_duplicate(list): """Find duplication in a list, return a list of duplicated elements""" dup = [] for i in range(len(list)): - if (list[i] in list[i+1:]): + if (list[i] in list[i + 1:]): if (list[i] not in dup): dup.append(list[i]) return dup @@ -196,7 +196,7 @@ class format_parser: titles = [] if (self._nfields > len(titles)): - self._titles += [None]*(self._nfields-len(titles)) + self._titles += [None] * (self._nfields - len(titles)) def _createdescr(self, byteorder): descr = sb.dtype({'names':self._names, @@ -254,7 +254,7 @@ class record(nt.void): if res: return self.setfield(val, *res[:2]) else: - if getattr(self,attr,None): + if getattr(self, attr, None): return nt.void.__setattr__(self, attr, val) else: raise AttributeError, "'record' object has no "\ @@ -265,9 +265,9 @@ class record(nt.void): names = self.dtype.names maxlen = max([len(name) for name in names]) rows = [] - fmt = '%% %ds: %%s' %maxlen + fmt = '%% %ds: %%s' % maxlen for name in names: - rows.append(fmt%(name, getattr(self, name))) + rows.append(fmt % (name, getattr(self, name))) return "\n".join(rows) # The recarray is almost identical to a standard array (which supports @@ -404,7 +404,7 @@ class recarray(ndarray): return object.__getattribute__(self, attr) except AttributeError: # attr must be a fieldname pass - fielddict = ndarray.__getattribute__(self,'dtype').fields + fielddict = ndarray.__getattribute__(self, 'dtype').fields try: res = fielddict[attr][:2] except (TypeError, KeyError): @@ -428,12 +428,12 @@ class recarray(ndarray): try: ret = object.__setattr__(self, attr, val) except: - fielddict = ndarray.__getattribute__(self,'dtype').fields or {} + fielddict = ndarray.__getattribute__(self, 'dtype').fields or {} if attr not in fielddict: exctype, value = sys.exc_info()[:2] raise exctype, value else: - fielddict = ndarray.__getattribute__(self,'dtype').fields or {} + fielddict = ndarray.__getattribute__(self, 'dtype').fields or {} if attr not in fielddict: return ret if newattr: # We just added this one @@ -444,7 +444,7 @@ class recarray(ndarray): return ret try: res = fielddict[attr][:2] - except (TypeError,KeyError): + except (TypeError, KeyError): raise AttributeError, "record array has no attribute %s" % attr return self.setfield(val, *res) @@ -460,10 +460,10 @@ class recarray(ndarray): def field(self, attr, val=None): if isinstance(attr, int): - names = ndarray.__getattribute__(self,'dtype').names + names = ndarray.__getattribute__(self, 'dtype').names attr = names[attr] - fielddict = ndarray.__getattribute__(self,'dtype').fields + fielddict = ndarray.__getattribute__(self, 'dtype').fields res = fielddict[attr][:2] @@ -550,7 +550,7 @@ def fromarrays(arrayList, dtype=None, shape=None, formats=None, for k, obj in enumerate(arrayList): nn = len(descr[k].shape) - testshape = obj.shape[:len(obj.shape)-nn] + testshape = obj.shape[:len(obj.shape) - nn] if testshape != shape: raise ValueError, "array-shape mismatch in array %d" % k @@ -596,17 +596,17 @@ def fromrecords(recList, dtype=None, shape=None, formats=None, names=None, nfields = len(recList[0]) if formats is None and dtype is None: # slower obj = sb.array(recList, dtype=object) - arrlist = [sb.array(obj[...,i].tolist()) for i in xrange(nfields)] + arrlist = [sb.array(obj[..., i].tolist()) for i in xrange(nfields)] return fromarrays(arrlist, formats=formats, shape=shape, names=names, titles=titles, aligned=aligned, byteorder=byteorder) if dtype is not None: - descr = sb.dtype(dtype) + descr = sb.dtype((record, dtype)) else: descr = format_parser(formats, names, titles, aligned, byteorder)._descr try: - retval = sb.array(recList, dtype = descr) + retval = sb.array(recList, dtype=descr) except TypeError: # list of lists instead of list of tuples if (shape is None or shape == 0): shape = len(recList) @@ -624,7 +624,6 @@ def fromrecords(recList, dtype=None, shape=None, formats=None, names=None, res = retval.view(recarray) - res.dtype = sb.dtype((record, res.dtype)) return res @@ -644,7 +643,7 @@ def fromstring(datastring, dtype=None, shape=None, offset=0, formats=None, itemsize = descr.itemsize if (shape is None or shape == 0 or shape == -1): - shape = (len(datastring)-offset) / itemsize + shape = (len(datastring) - offset) / itemsize _array = recarray(shape, descr, buf=datastring, offset=offset) return _array @@ -703,14 +702,14 @@ def fromfile(fd, dtype=None, shape=None, offset=0, formats=None, itemsize = descr.itemsize shapeprod = sb.array(shape).prod() - shapesize = shapeprod*itemsize + shapesize = shapeprod * itemsize if shapesize < 0: shape = list(shape) shape[ shape.index(-1) ] = size / -shapesize shape = tuple(shape) shapeprod = sb.array(shape).prod() - nbytes = shapeprod*itemsize + nbytes = shapeprod * itemsize if nbytes > size: raise ValueError( @@ -794,7 +793,7 @@ def array(obj, dtype=None, shape=None, offset=0, strides=None, formats=None, obj = sb.array(obj) if dtype is not None and (obj.dtype != dtype): obj = obj.view(dtype) - res = obj.view(recarray) + res = obj.view(recarray) if issubclass(res.dtype.type, nt.void): res.dtype = sb.dtype((record, res.dtype)) return res diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index 85306b1af..6b8f31f8b 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -4,38 +4,38 @@ from numpy.testing import * class TestFromrecords(TestCase): def test_fromrecords(self): - r = np.rec.fromrecords([[456,'dbe',1.2],[2,'de',1.3]], + r = np.rec.fromrecords([[456, 'dbe', 1.2], [2, 'de', 1.3]], names='col1,col2,col3') assert_equal(r[0].item(), (456, 'dbe', 1.2)) def test_method_array(self): - r = np.rec.array('abcdefg'*100,formats='i2,a3,i4',shape=3,byteorder='big') + r = np.rec.array('abcdefg' * 100, formats='i2,a3,i4', shape=3, byteorder='big') assert_equal(r[1].item(), (25444, 'efg', 1633837924)) def test_method_array2(self): - r = np.rec.array([(1,11,'a'),(2,22,'b'),(3,33,'c'),(4,44,'d'),(5,55,'ex'), - (6,66,'f'),(7,77,'g')],formats='u1,f4,a1') + r = np.rec.array([(1, 11, 'a'), (2, 22, 'b'), (3, 33, 'c'), (4, 44, 'd'), (5, 55, 'ex'), + (6, 66, 'f'), (7, 77, 'g')], formats='u1,f4,a1') assert_equal(r[1].item(), (2, 22.0, 'b')) def test_recarray_slices(self): - r = np.rec.array([(1,11,'a'),(2,22,'b'),(3,33,'c'),(4,44,'d'),(5,55,'ex'), - (6,66,'f'),(7,77,'g')],formats='u1,f4,a1') + r = np.rec.array([(1, 11, 'a'), (2, 22, 'b'), (3, 33, 'c'), (4, 44, 'd'), (5, 55, 'ex'), + (6, 66, 'f'), (7, 77, 'g')], formats='u1,f4,a1') assert_equal(r[1::2][1].item(), (4, 44.0, 'd')) def test_recarray_fromarrays(self): - x1 = np.array([1,2,3,4]) - x2 = np.array(['a','dd','xyz','12']) - x3 = np.array([1.1,2,3,4]) - r = np.rec.fromarrays([x1,x2,x3],names='a,b,c') - assert_equal(r[1].item(), (2,'dd',2.0)) + x1 = np.array([1, 2, 3, 4]) + x2 = np.array(['a', 'dd', 'xyz', '12']) + x3 = np.array([1.1, 2, 3, 4]) + r = np.rec.fromarrays([x1, x2, x3], names='a,b,c') + assert_equal(r[1].item(), (2, 'dd', 2.0)) x1[1] = 34 - assert_equal(r.a, np.array([1,2,3,4])) + assert_equal(r.a, np.array([1, 2, 3, 4])) def test_recarray_fromfile(self): - data_dir = path.join(path.dirname(__file__),'data') - filename = path.join(data_dir,'recarray_from_file.fits') + data_dir = path.join(path.dirname(__file__), 'data') + filename = path.join(data_dir, 'recarray_from_file.fits') fd = open(filename) - fd.seek(2880*2) + fd.seek(2880 * 2) r = np.rec.fromfile(fd, formats='f8,i4,a5', shape=3, byteorder='big') def test_recarray_from_obj(self): @@ -44,16 +44,16 @@ class TestFromrecords(TestCase): b = np.zeros(count, dtype='f8') c = np.zeros(count, dtype='f8') for i in range(len(a)): - a[i] = range(1,10) + a[i] = range(1, 10) - mine = np.rec.fromarrays([a,b,c], names='date,data1,data2') + mine = np.rec.fromarrays([a, b, c], names='date,data1,data2') for i in range(len(a)): - assert (mine.date[i] == range(1,10)) + assert (mine.date[i] == range(1, 10)) assert (mine.data1[i] == 0.0) assert (mine.data2[i] == 0.0) def test_recarray_from_repr(self): - x = np.rec.array([ (1, 2)],dtype=[('a', np.int8), ('b', np.int8)]) + x = np.rec.array([ (1, 2)], dtype=[('a', np.int8), ('b', np.int8)]) y = eval("np." + repr(x)) assert isinstance(y, np.recarray) assert_equal(y, x) @@ -75,58 +75,74 @@ class TestFromrecords(TestCase): assert ra[k].item() == pa[k].item() def test_recarray_conflict_fields(self): - ra = np.rec.array([(1,'abc',2.3),(2,'xyz',4.2), - (3,'wrs',1.3)], + ra = np.rec.array([(1, 'abc', 2.3), (2, 'xyz', 4.2), + (3, 'wrs', 1.3)], names='field, shape, mean') - ra.mean = [1.1,2.2,3.3] - assert_array_almost_equal(ra['mean'], [1.1,2.2,3.3]) + ra.mean = [1.1, 2.2, 3.3] + assert_array_almost_equal(ra['mean'], [1.1, 2.2, 3.3]) assert type(ra.mean) is type(ra.var) - ra.shape = (1,3) - assert ra.shape == (1,3) - ra.shape = ['A','B','C'] - assert_array_equal(ra['shape'], [['A','B','C']]) + ra.shape = (1, 3) + assert ra.shape == (1, 3) + ra.shape = ['A', 'B', 'C'] + assert_array_equal(ra['shape'], [['A', 'B', 'C']]) ra.field = 5 - assert_array_equal(ra['field'], [[5,5,5]]) + assert_array_equal(ra['field'], [[5, 5, 5]]) assert callable(ra.field) + def test_fromrecords_with_explicit_dtype(self): + a = np.rec.fromrecords([(1, 'a'), (2, 'bbb')], + dtype=[('a', int), ('b', np.object)]) + assert_equal(a.a, [1, 2]) + assert_equal(a[0].a, 1) + assert_equal(a.b, ['a', 'bbb']) + assert_equal(a[-1].b, 'bbb') + # + ndtype = np.dtype([('a', int), ('b', np.object)]) + a = np.rec.fromrecords([(1, 'a'), (2, 'bbb')], dtype=ndtype) + assert_equal(a.a, [1, 2]) + assert_equal(a[0].a, 1) + assert_equal(a.b, ['a', 'bbb']) + assert_equal(a[-1].b, 'bbb') + + class TestRecord(TestCase): def setUp(self): - self.data = np.rec.fromrecords([(1,2,3),(4,5,6)], + self.data = np.rec.fromrecords([(1, 2, 3), (4, 5, 6)], dtype=[("col1", "<i4"), ("col2", "<i4"), ("col3", "<i4")]) def test_assignment1(self): a = self.data - assert_equal(a.col1[0],1) + assert_equal(a.col1[0], 1) a[0].col1 = 0 - assert_equal(a.col1[0],0) + assert_equal(a.col1[0], 0) def test_assignment2(self): a = self.data - assert_equal(a.col1[0],1) + assert_equal(a.col1[0], 1) a.col1[0] = 0 - assert_equal(a.col1[0],0) + assert_equal(a.col1[0], 0) def test_invalid_assignment(self): a = self.data def assign_invalid_column(x): x[0].col5 = 1 - self.failUnlessRaises(AttributeError,assign_invalid_column,a) + self.failUnlessRaises(AttributeError, assign_invalid_column, a) def test_find_duplicate(): - l1 = [1,2,3,4,5,6] + l1 = [1, 2, 3, 4, 5, 6] assert np.rec.find_duplicate(l1) == [] - l2 = [1,2,1,4,5,6] + l2 = [1, 2, 1, 4, 5, 6] assert np.rec.find_duplicate(l2) == [1] - l3 = [1,2,1,4,1,6,2,3] - assert np.rec.find_duplicate(l3) == [1,2] + l3 = [1, 2, 1, 4, 1, 6, 2, 3] + assert np.rec.find_duplicate(l3) == [1, 2] - l3 = [2,2,1,4,1,6,2,3] - assert np.rec.find_duplicate(l3) == [2,1] + l3 = [2, 2, 1, 4, 1, 6, 2, 3] + assert np.rec.find_duplicate(l3) == [2, 1] if __name__ == "__main__": run_module_suite() |