diff options
author | Pauli Virtanen <pav@iki.fi> | 2010-02-20 18:17:14 +0000 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-02-20 18:17:14 +0000 |
commit | 0f2e7db0da927cc3007e37c88abd03c6be2dd255 (patch) | |
tree | b14967493c2ad2ae8dac72d8c7b26ccf46e18d7a /numpy/lib | |
parent | 0e9a08cd2f5faa2b17dc6d62d4d0014386842628 (diff) | |
download | numpy-0f2e7db0da927cc3007e37c88abd03c6be2dd255.tar.gz |
3K: lib: fix some bytes vs. str issues in _iotools.py and io.py -- mainly genfromtxt
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/_iotools.py | 48 | ||||
-rw-r--r-- | numpy/lib/function_base.py | 14 | ||||
-rw-r--r-- | numpy/lib/io.py | 28 | ||||
-rw-r--r-- | numpy/lib/tests/test__iotools.py | 115 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 26 |
5 files changed, 143 insertions, 88 deletions
diff --git a/numpy/lib/_iotools.py b/numpy/lib/_iotools.py index fc076e46c..a19852ac6 100644 --- a/numpy/lib/_iotools.py +++ b/numpy/lib/_iotools.py @@ -1,10 +1,21 @@ """A collection of functions designed to help I/O with ascii files.""" __docformat__ = "restructuredtext en" +import sys import numpy as np import numpy.core.numeric as nx from __builtin__ import bool, int, long, float, complex, object, unicode, str +from numpy.compat import asbytes, bytes + +if sys.version_info[0] >= 3: + def _bytes_to_complex(s): + return complex(s.decode('ascii')) + def _bytes_to_name(s): + return s.decode('ascii') +else: + _bytes_to_complex = complex + _bytes_to_name = str def _is_string_like(obj): """ @@ -16,6 +27,16 @@ def _is_string_like(obj): return False return True +def _is_bytes_like(obj): + """ + Check whether obj behaves like a bytes object. + """ + try: + obj + asbytes('') + except (TypeError, ValueError): + return False + return True + def _to_filehandle(fname, flag='r', return_opened=False): """ @@ -157,10 +178,12 @@ class LineSplitter: """ return lambda input: [_.strip() for _ in method(input)] # - def __init__(self, delimiter=None, comments='#', autostrip=True): + def __init__(self, delimiter=None, comments=asbytes('#'), autostrip=True): self.comments = comments # Delimiter is a character - if (delimiter is None) or _is_string_like(delimiter): + if isinstance(delimiter, unicode): + delimiter = delimiter.encode('ascii') + if (delimiter is None) or _is_bytes_like(delimiter): delimiter = delimiter or None _handyman = self._delimited_splitter # Delimiter is a list of field widths @@ -180,7 +203,7 @@ class LineSplitter: self._handyman = _handyman # def _delimited_splitter(self, line): - line = line.split(self.comments)[0].strip(" \r\n") + line = line.split(self.comments)[0].strip(asbytes(" \r\n")) if not line: return [] return line.split(self.delimiter) @@ -382,9 +405,9 @@ def str2bool(value): """ value = value.upper() - if value == 'TRUE': + if value == asbytes('TRUE'): return True - elif value == 'FALSE': + elif value == asbytes('FALSE'): return False else: raise ValueError("Invalid boolean") @@ -468,8 +491,8 @@ class StringConverter: _mapper = [(nx.bool_, str2bool, False), (nx.integer, int, -1), (nx.floating, float, nx.nan), - (complex, complex, nx.nan + 0j), - (nx.string_, str, '???')] + (complex, _bytes_to_complex, nx.nan + 0j), + (nx.string_, bytes, asbytes('???'))] (_defaulttype, _defaultfunc, _defaultfill) = zip(*_mapper) # @classmethod @@ -570,11 +593,11 @@ class StringConverter: self.func = lambda x : int(float(x)) # Store the list of strings corresponding to missing values. if missing_values is None: - self.missing_values = set(['']) + self.missing_values = set([asbytes('')]) else: if isinstance(missing_values, basestring): - missing_values = missing_values.split(",") - self.missing_values = set(list(missing_values) + ['']) + missing_values = missing_values.split(asbytes(",")) + self.missing_values = set(list(missing_values) + [asbytes('')]) # self._callingfunction = self._strict_call self.type = ttype @@ -672,7 +695,8 @@ class StringConverter: self._status = _status self.iterupgrade(value) - def update(self, func, default=None, missing_values='', locked=False): + def update(self, func, default=None, missing_values=asbytes(''), + locked=False): """ Set StringConverter attributes directly. @@ -711,7 +735,7 @@ class StringConverter: self.type = self._getsubdtype(tester) # Add the missing values to the existing set if missing_values is not None: - if _is_string_like(missing_values): + if _is_bytes_like(missing_values): self.missing_values.add(missing_values) elif hasattr(missing_values, '__iter__'): for val in missing_values: diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index a1f7102df..17b254d82 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -13,6 +13,7 @@ __all__ = ['select', 'piecewise', 'trim_zeros', import warnings import types +import sys import numpy.core.numeric as _nx from numpy.core import linspace from numpy.core.numeric import ones, zeros, arange, concatenate, array, \ @@ -1596,7 +1597,18 @@ import re def _get_nargs(obj): if not callable(obj): raise TypeError, "Object is not callable." - if hasattr(obj,'func_code'): + if sys.version_info[0] >= 3: + import inspect + spec = inspect.getargspec(obj) + nargs = len(spec.args) + if spec.defaults: + ndefaults = len(spec.defaults) + else: + ndefaults = 0 + if inspect.ismethod(obj): + nargs -= 1 + return nargs, ndefaults + elif hasattr(obj,'func_code'): fcode = obj.func_code nargs = fcode.co_argcount if obj.func_defaults is not None: diff --git a/numpy/lib/io.py b/numpy/lib/io.py index 83fd4d809..d8861907d 100644 --- a/numpy/lib/io.py +++ b/numpy/lib/io.py @@ -21,7 +21,7 @@ from _compiled_base import packbits, unpackbits from _iotools import LineSplitter, NameValidator, StringConverter, \ ConverterError, ConverterLockError, ConversionWarning, \ _is_string_like, has_nested_fields, flatten_dtype, \ - easy_dtype + easy_dtype, _bytes_to_name from numpy.compat import asbytes @@ -478,8 +478,8 @@ def _getconv(dtype): -def loadtxt(fname, dtype=float, comments='#', delimiter=None, converters=None, - skiprows=0, usecols=None, unpack=False): +def loadtxt(fname, dtype=float, comments=asbytes('#'), delimiter=None, + converters=None, skiprows=0, usecols=None, unpack=False): """ Load data from a text file. @@ -613,7 +613,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, converters=None, first_vals = None while not first_vals: first_line = fh.readline() - if first_line == '': # EOF reached + if not first_line: # EOF reached raise IOError('End-of-file reached before encountering data.') first_vals = split_line(first_line) N = len(usecols or first_vals) @@ -891,9 +891,9 @@ def fromregex(file, regexp, dtype): -def genfromtxt(fname, dtype=float, comments='#', delimiter=None, +def genfromtxt(fname, dtype=float, comments=asbytes('#'), delimiter=None, skiprows=0, skip_header=0, skip_footer=0, converters=None, - missing='', missing_values=None, filling_values=None, + missing=asbytes(''), missing_values=None, filling_values=None, usecols=None, names=None, excludelist=None, deletechars=None, autostrip=False, case_sensitive=True, defaultfmt="f%i", unpack=None, usemask=False, loose=True, invalid_raise=True): @@ -1065,11 +1065,11 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, first_values = None while not first_values: first_line = fhd.readline() - if first_line == '': + if not first_line: raise IOError('End-of-file reached before encountering data.') if names is True: if comments in first_line: - first_line = ''.join(first_line.split(comments)[1]) + first_line = asbytes('').join(first_line.split(comments)[1:]) first_values = split_line(first_line) # Should we take the first values as names ? if names is True: @@ -1090,8 +1090,9 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, # Check the names and overwrite the dtype.names if needed if names is True: - names = validate_names([_.strip() for _ in first_values]) - first_line = '' + names = validate_names([_bytes_to_name(_.strip()) + for _ in first_values]) + first_line = asbytes('') elif _is_string_like(names): names = validate_names([_.strip() for _ in names.split(',')]) elif names: @@ -1127,7 +1128,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, user_missing_values = missing_values or () # Define the list of missing_values (one column: one list) - missing_values = [list(['']) for _ in range(nbcols)] + missing_values = [list([asbytes('')]) for _ in range(nbcols)] # We have a dictionary: process it field by field if isinstance(user_missing_values, dict): @@ -1176,7 +1177,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, entry.extend([str(user_missing_values)]) # Process the deprecated `missing` - if missing != '': + if missing != asbytes(''): warnings.warn("The use of `missing` is deprecated.\n"\ "Please use `missing_values` instead.", DeprecationWarning) @@ -1451,7 +1452,8 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, names = output.dtype.names if usemask and names: for (name, conv) in zip(names or (), converters): - missing_values = [conv(_) for _ in conv.missing_values if _ != ''] + missing_values = [conv(_) for _ in conv.missing_values + if _ != asbytes('')] for mval in missing_values: outputmask[name] |= (output[name] == mval) # Construct the final array diff --git a/numpy/lib/tests/test__iotools.py b/numpy/lib/tests/test__iotools.py index aa067115e..d105cf835 100644 --- a/numpy/lib/tests/test__iotools.py +++ b/numpy/lib/tests/test__iotools.py @@ -1,71 +1,78 @@ - -import StringIO +import sys +if sys.version_info[0] >= 3: + from io import BytesIO + def StringIO(s=""): + return BytesIO(asbytes(s)) +else: + from StringIO import StringIO import numpy as np from numpy.lib._iotools import LineSplitter, NameValidator, StringConverter,\ has_nested_fields, easy_dtype from numpy.testing import * +from numpy.compat import asbytes, asbytes_nested + class TestLineSplitter(TestCase): "Tests the LineSplitter class." # def test_no_delimiter(self): "Test LineSplitter w/o delimiter" - strg = " 1 2 3 4 5 # test" + strg = asbytes(" 1 2 3 4 5 # test") test = LineSplitter()(strg) - assert_equal(test, ['1', '2', '3', '4', '5']) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '5'])) test = LineSplitter('')(strg) - assert_equal(test, ['1', '2', '3', '4', '5']) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '5'])) def test_space_delimiter(self): "Test space delimiter" - strg = " 1 2 3 4 5 # test" - test = LineSplitter(' ')(strg) - assert_equal(test, ['1', '2', '3', '4', '', '5']) - test = LineSplitter(' ')(strg) - assert_equal(test, ['1 2 3 4', '5']) + strg = asbytes(" 1 2 3 4 5 # test") + test = LineSplitter(asbytes(' '))(strg) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '', '5'])) + test = LineSplitter(asbytes(' '))(strg) + assert_equal(test, asbytes_nested(['1 2 3 4', '5'])) def test_tab_delimiter(self): "Test tab delimiter" - strg= " 1\t 2\t 3\t 4\t 5 6" - test = LineSplitter('\t')(strg) - assert_equal(test, ['1', '2', '3', '4', '5 6']) - strg= " 1 2\t 3 4\t 5 6" - test = LineSplitter('\t')(strg) - assert_equal(test, ['1 2', '3 4', '5 6']) + strg= asbytes(" 1\t 2\t 3\t 4\t 5 6") + test = LineSplitter(asbytes('\t'))(strg) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '5 6'])) + strg= asbytes(" 1 2\t 3 4\t 5 6") + test = LineSplitter(asbytes('\t'))(strg) + assert_equal(test, asbytes_nested(['1 2', '3 4', '5 6'])) def test_other_delimiter(self): "Test LineSplitter on delimiter" - strg = "1,2,3,4,,5" - test = LineSplitter(',')(strg) - assert_equal(test, ['1', '2', '3', '4', '', '5']) + strg = asbytes("1,2,3,4,,5") + test = LineSplitter(asbytes(','))(strg) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '', '5'])) # - strg = " 1,2,3,4,,5 # test" - test = LineSplitter(',')(strg) - assert_equal(test, ['1', '2', '3', '4', '', '5']) + strg = asbytes(" 1,2,3,4,,5 # test") + test = LineSplitter(asbytes(','))(strg) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '', '5'])) def test_constant_fixed_width(self): "Test LineSplitter w/ fixed-width fields" - strg = " 1 2 3 4 5 # test" + strg = asbytes(" 1 2 3 4 5 # test") test = LineSplitter(3)(strg) - assert_equal(test, ['1', '2', '3', '4', '', '5', '']) + assert_equal(test, asbytes_nested(['1', '2', '3', '4', '', '5', ''])) # - strg = " 1 3 4 5 6# test" + strg = asbytes(" 1 3 4 5 6# test") test = LineSplitter(20)(strg) - assert_equal(test, ['1 3 4 5 6']) + assert_equal(test, asbytes_nested(['1 3 4 5 6'])) # - strg = " 1 3 4 5 6# test" + strg = asbytes(" 1 3 4 5 6# test") test = LineSplitter(30)(strg) - assert_equal(test, ['1 3 4 5 6']) + assert_equal(test, asbytes_nested(['1 3 4 5 6'])) def test_variable_fixed_width(self): - strg = " 1 3 4 5 6# test" + strg = asbytes(" 1 3 4 5 6# test") test = LineSplitter((3,6,6,3))(strg) - assert_equal(test, ['1', '3', '4 5', '6']) + assert_equal(test, asbytes_nested(['1', '3', '4 5', '6'])) # - strg = " 1 3 4 5 6# test" + strg = asbytes(" 1 3 4 5 6# test") test = LineSplitter((6,6,9))(strg) - assert_equal(test, ['1', '3 4', '5 6']) + assert_equal(test, asbytes_nested(['1', '3 4', '5 6'])) #------------------------------------------------------------------------------- @@ -136,23 +143,24 @@ class TestStringConverter(TestCase): "Tests the upgrade method." converter = StringConverter() assert_equal(converter._status, 0) - converter.upgrade('0') + converter.upgrade(asbytes('0')) assert_equal(converter._status, 1) - converter.upgrade('0.') + converter.upgrade(asbytes('0.')) assert_equal(converter._status, 2) - converter.upgrade('0j') + converter.upgrade(asbytes('0j')) assert_equal(converter._status, 3) - converter.upgrade('a') + converter.upgrade(asbytes('a')) assert_equal(converter._status, len(converter._mapper)-1) # def test_missing(self): "Tests the use of missing values." - converter = StringConverter(missing_values=('missing','missed')) - converter.upgrade('0') - assert_equal(converter('0'), 0) - assert_equal(converter(''), converter.default) - assert_equal(converter('missing'), converter.default) - assert_equal(converter('missed'), converter.default) + converter = StringConverter(missing_values=(asbytes('missing'), + asbytes('missed'))) + converter.upgrade(asbytes('0')) + assert_equal(converter(asbytes('0')), 0) + assert_equal(converter(asbytes('')), converter.default) + assert_equal(converter(asbytes('missing')), converter.default) + assert_equal(converter(asbytes('missed')), converter.default) try: converter('miss') except ValueError: @@ -162,7 +170,11 @@ class TestStringConverter(TestCase): "Tests updatemapper" from datetime import date import time - dateparser = lambda s : date(*time.strptime(s, "%Y-%m-%d")[:3]) + if sys.version_info[0] >= 3: + dateparser = lambda s : date(*time.strptime(s.decode('latin1'), + "%Y-%m-%d")[:3]) + else: + dateparser = lambda s : date(*time.strptime(s, "%Y-%m-%d")[:3]) StringConverter.upgrade_mapper(dateparser, date(2000,1,1)) convert = StringConverter(dateparser, date(2000, 1, 1)) test = convert('2001-01-01') @@ -182,25 +194,28 @@ class TestStringConverter(TestCase): # def test_keep_default(self): "Make sure we don't lose an explicit default" - converter = StringConverter(None, missing_values='', default=-999) - converter.upgrade('3.14159265') + converter = StringConverter(None, missing_values=asbytes(''), + default=-999) + converter.upgrade(asbytes('3.14159265')) assert_equal(converter.default, -999) assert_equal(converter.type, np.dtype(float)) # - converter = StringConverter(None, missing_values='', default=0) - converter.upgrade('3.14159265') + converter = StringConverter(None, missing_values=asbytes(''), default=0) + converter.upgrade(asbytes('3.14159265')) assert_equal(converter.default, 0) assert_equal(converter.type, np.dtype(float)) # def test_keep_default_zero(self): "Check that we don't lose a default of 0" - converter = StringConverter(int, default=0, missing_values="N/A") + converter = StringConverter(int, default=0, + missing_values=asbytes("N/A")) assert_equal(converter.default, 0) # def test_keep_missing_values(self): "Check that we're not losing missing values" - converter = StringConverter(int, default=0, missing_values="N/A") - assert_equal(converter.missing_values, set(['', 'N/A'])) + converter = StringConverter(int, default=0, + missing_values=asbytes("N/A")) + assert_equal(converter.missing_values, set(asbytes_nested(['', 'N/A']))) #------------------------------------------------------------------------------- diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 61cedd603..dd1bfbad8 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -5,11 +5,6 @@ from numpy.testing import assert_warns import sys -if sys.version_info[0] >= 3: - from io import BytesIO as StringIO -else: - from StringIO import StringIO - import gzip import os import threading @@ -20,7 +15,14 @@ from datetime import datetime from numpy.lib._iotools import ConverterError, ConverterLockError, \ ConversionWarning +from numpy.compat import asbytes +if sys.version_info[0] >= 3: + from io import BytesIO + def StringIO(s=""): + return BytesIO(asbytes(s)) +else: + from StringIO import StringIO MAJVER, MINVER = sys.version_info[:2] @@ -193,7 +195,7 @@ class TestSaveTxt(TestCase): def test_delimiter(self): a = np.array([[1., 2.], [3., 4.]]) c = StringIO() - np.savetxt(c, a, delimiter=',', fmt='%d') + np.savetxt(c, a, delimiter=asbytes(','), fmt='%d') c.seek(0) assert_equal(c.readlines(), ['1,2\n', '3,4\n']) @@ -440,7 +442,7 @@ class TestFromTxt(TestCase): # def test_record(self): "Test w/ explicit dtype" - data = StringIO('1 2\n3 4') + data = StringIO(asbytes('1 2\n3 4')) # data.seek(0) test = np.ndfromtxt(data, dtype=[('x', np.int32), ('y', np.int32)]) control = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')]) @@ -476,7 +478,7 @@ class TestFromTxt(TestCase): assert_array_equal(test, control) # data = StringIO('1,2,3,4\n') - test = np.ndfromtxt(data, dtype=int, delimiter=',') + test = np.ndfromtxt(data, dtype=int, delimiter=asbytes(',')) assert_array_equal(test, control) def test_comments(self): @@ -484,17 +486,17 @@ class TestFromTxt(TestCase): control = np.array([1, 2, 3, 5], int) # Comment on its own line data = StringIO('# comment\n1,2,3,5\n') - test = np.ndfromtxt(data, dtype=int, delimiter=',', comments='#') + test = np.ndfromtxt(data, dtype=int, delimiter=asbytes(','), comments=asbytes('#')) assert_equal(test, control) # Comment at the end of a line data = StringIO('1,2,3,5# comment\n') - test = np.ndfromtxt(data, dtype=int, delimiter=',', comments='#') + test = np.ndfromtxt(data, dtype=int, delimiter=asbytes(','), comments=asbytes('#')) assert_equal(test, control) def test_skiprows(self): "Test row skipping" control = np.array([1, 2, 3, 5], int) - kwargs = dict(dtype=int, delimiter=',') + kwargs = dict(dtype=int, delimiter=asbytes(',')) # data = StringIO('comment\n1,2,3,5\n') test = np.ndfromtxt(data, skip_header=1, **kwargs) @@ -510,7 +512,7 @@ class TestFromTxt(TestCase): data.extend(["%i,%3.1f,%03s" % (i, i, i) for i in range(51)]) data[-1] = "99,99" kwargs = dict(delimiter=",", names=True, skip_header=5, skip_footer=10) - test = np.genfromtxt(StringIO("\n".join(data)), **kwargs) + test = np.genfromtxt(StringIO(asbytes("\n".join(data))), **kwargs) ctrl = np.array([("%f" % i, "%f" % i, "%f" % i) for i in range(40)], dtype=[(_, float) for _ in "ABC"]) assert_equal(test, ctrl) |