diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2019-04-26 09:21:32 -0700 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2019-04-26 09:35:05 -0700 |
commit | 59a521ee07693cc5c58d68987691df4bcc9e48ff (patch) | |
tree | 6ba8883728a85a9615a862dd510c7e730d895d07 /numpy/lib | |
parent | dea1239b6dcaf072fc9b70e6af0c0a100cead69e (diff) | |
download | numpy-59a521ee07693cc5c58d68987691df4bcc9e48ff.tar.gz |
BUG: (py2 only) fix unicode support for savetxt fmt string
By now, all that is needed is to also allow unicode strings to
pass through. Adds a test for the support which already succeeds
on python3.
Closes gh-4053 (replaces the old PR)
Diffstat (limited to 'numpy/lib')
-rw-r--r-- | numpy/lib/npyio.py | 2 | ||||
-rw-r--r-- | numpy/lib/tests/test_io.py | 13 |
2 files changed, 14 insertions, 1 deletions
diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index beeba1334..3414ecd81 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -1390,7 +1390,7 @@ def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', if len(fmt) != ncol: raise AttributeError('fmt has wrong shape. %s' % str(fmt)) format = asstr(delimiter).join(map(asstr, fmt)) - elif isinstance(fmt, str): + elif isinstance(fmt, basestring): n_fmt_chars = fmt.count('%') error = ValueError('fmt has wrong number of %% formats: %s' % fmt) if n_fmt_chars == 1: diff --git a/numpy/lib/tests/test_io.py b/numpy/lib/tests/test_io.py index 835344429..f26f89f40 100644 --- a/numpy/lib/tests/test_io.py +++ b/numpy/lib/tests/test_io.py @@ -561,6 +561,19 @@ class TestSaveTxt(object): s.seek(0) assert_equal(s.read(), utf8 + '\n') + @pytest.mark.parametrize("fmt", [u"%f", b"%f"]) + @pytest.mark.parametrize("iotype", [StringIO, BytesIO]) + def test_unicode_and_bytes_fmt(self, fmt, iotype): + # string type of fmt should not matter, see also gh-4053 + a = np.array([1.]) + s = iotype() + np.savetxt(s, a, fmt=fmt) + s.seek(0) + if iotype is StringIO: + assert_equal(s.read(), u"%f\n" % 1.) + else: + assert_equal(s.read(), b"%f\n" % 1.) + class LoadTxtBase(object): def check_compressed(self, fopen, suffixes): |