diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2017-05-17 10:35:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-17 10:35:46 +0100 |
commit | 692655e77b65a9186bda7a701062abd6b62d4ca9 (patch) | |
tree | 77a4c6e1fcb3eea3588c178eeb0fdf98b8ccfab1 /numpy | |
parent | d4eaa2c01801ca2ce46b0c8b345367a54c8dde4b (diff) | |
parent | d2b06fe879f5b2b14de3dad0f517561a0c815df0 (diff) | |
download | numpy-692655e77b65a9186bda7a701062abd6b62d4ca9.tar.gz |
Merge pull request #8983 from ahaldane/fix0d_array2string_s
ENH: str/repr fixed for 0d-arrays
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/arrayprint.py | 67 | ||||
-rw-r--r-- | numpy/core/numeric.py | 2 | ||||
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 28 | ||||
-rw-r--r-- | numpy/core/tests/test_arrayprint.py | 14 |
4 files changed, 72 insertions, 39 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index e54f4602a..795ceec6c 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -15,6 +15,13 @@ __docformat__ = 'restructuredtext' # and by Perry Greenfield 2000-4-1 for numarray # and by Travis Oliphant 2005-8-22 for numpy + +# Note: Both scalartypes.c.src and arrayprint.py implement strs for numpy +# scalars but for different purposes. scalartypes.c.src has str/reprs for when +# the scalar is printed on its own, while arrayprint.py has strs for when +# scalars are printed inside an ndarray. Only the latter strs are currently +# user-customizable. + import sys import functools if sys.version_info[0] >= 3: @@ -28,12 +35,14 @@ else: except ImportError: from dummy_thread import get_ident +import numpy as np from . import numerictypes as _nt from .umath import maximum, minimum, absolute, not_equal, isnan, isinf from .multiarray import (array, format_longfloat, datetime_as_string, datetime_data, dtype) from .fromnumeric import ravel from .numeric import asarray +import warnings if sys.version_info[0] >= 3: _MAXINT = sys.maxsize @@ -399,7 +408,7 @@ def _recursive_guard(fillvalue='...'): @_recursive_guard() def array2string(a, max_line_width=None, precision=None, suppress_small=None, separator=' ', prefix="", - style=repr, formatter=None): + style=np._NoValue, formatter=None): """ Return a string representation of an array. @@ -425,9 +434,10 @@ def array2string(a, max_line_width=None, precision=None, The length of the prefix string is used to align the output correctly. - style : function, optional - A function that accepts an ndarray and returns a string. Used only - when the shape of `a` is equal to ``()``, i.e. for 0-D arrays. + style : _NoValue, optional + Has no effect, do not use. + + .. deprecated:: 1.14.0 formatter : dict of callables, optional If not None, the keys should indicate the type(s) that the respective formatting function applies to. Callables should return a string. @@ -494,6 +504,11 @@ def array2string(a, max_line_width=None, precision=None, """ + # Deprecation 05-16-2017 v1.14 + if style is not np._NoValue: + warnings.warn("'style' argument is deprecated and no longer functional", + DeprecationWarning, stacklevel=3) + if max_line_width is None: max_line_width = _line_width @@ -506,16 +521,7 @@ def array2string(a, max_line_width=None, precision=None, if formatter is None: formatter = _formatter - if a.shape == (): - x = a.item() - if a.dtype.fields is not None: - arr = array([x], dtype=a.dtype) - format_function = _get_format_function( - arr, precision, suppress_small, formatter) - lst = format_function(arr[0]) - else: - lst = style(x) - elif functools.reduce(product, a.shape) == 0: + if a.size == 0: # treat as a null array if any of shape elements == 0 lst = "[]" else: @@ -542,7 +548,7 @@ def _formatArray(a, format_function, rank, max_line_len, """ if rank == 0: - raise ValueError("rank shouldn't be zero.") + return format_function(a[()]) + '\n' if summary_insert and 2*edge_items < len(a): leading_items = edge_items @@ -809,22 +815,21 @@ class DatetimeFormat(object): class TimedeltaFormat(object): def __init__(self, data): - if data.dtype.kind == 'm': - nat_value = array(['NaT'], dtype=data.dtype)[0] - int_dtype = dtype(data.dtype.byteorder + 'i8') - int_view = data.view(int_dtype) - v = int_view[not_equal(int_view, nat_value.view(int_dtype))] - if len(v) > 0: - # Max str length of non-NaT elements - max_str_len = max(len(str(maximum.reduce(v))), - len(str(minimum.reduce(v)))) - else: - max_str_len = 0 - if len(v) < len(data): - # data contains a NaT - max_str_len = max(max_str_len, 5) - self.format = '%' + str(max_str_len) + 'd' - self._nat = "'NaT'".rjust(max_str_len) + nat_value = array(['NaT'], dtype=data.dtype)[0] + int_dtype = dtype(data.dtype.byteorder + 'i8') + int_view = data.view(int_dtype) + v = int_view[not_equal(int_view, nat_value.view(int_dtype))] + if len(v) > 0: + # Max str length of non-NaT elements + max_str_len = max(len(str(maximum.reduce(v))), + len(str(minimum.reduce(v)))) + else: + max_str_len = 0 + if len(v) < len(data): + # data contains a NaT + max_str_len = max(max_str_len, 5) + self.format = '%' + str(max_str_len) + 'd' + self._nat = "'NaT'".rjust(max_str_len) def __call__(self, x): # TODO: After NAT == NAT deprecation should be simplified: diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 6b4a93ce0..1dde02400 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1936,7 +1936,7 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): '[0 1 2]' """ - return array2string(a, max_line_width, precision, suppress_small, ' ', "", str) + return array2string(a, max_line_width, precision, suppress_small, ' ', "") def set_string_function(f, repr=True): diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 02d9f5a31..a7ed4b49d 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -338,7 +338,6 @@ gentype_str(PyObject *self) return ret; } - static PyObject * gentype_repr(PyObject *self) { @@ -353,6 +352,20 @@ gentype_repr(PyObject *self) return ret; } +static PyObject * +genint_type_str(PyObject *self) +{ + PyObject *item, *item_str; + item = gentype_generic_method(self, NULL, NULL, "item"); + if (item == NULL) { + return NULL; + } + + item_str = PyObject_Str(item); + Py_DECREF(item); + return item_str; +} + /* * The __format__ method for PEP 3101. */ @@ -4185,6 +4198,19 @@ initialize_numeric_types(void) /**end repeat**/ + + /**begin repeat + * #Type = Bool, Byte, UByte, Short, UShort, Int, UInt, Long, + * ULong, LongLong, ULongLong# + */ + + /* both str/repr use genint_type_str to avoid trailing "L" of longs */ + Py@Type@ArrType_Type.tp_str = genint_type_str; + Py@Type@ArrType_Type.tp_repr = genint_type_str; + + /**end repeat**/ + + PyHalfArrType_Type.tp_print = halftype_print; PyFloatArrType_Type.tp_print = floattype_print; PyDoubleArrType_Type.tp_print = doubletype_print; diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py index e7ac0cdfd..b80c5d190 100644 --- a/numpy/core/tests/test_arrayprint.py +++ b/numpy/core/tests/test_arrayprint.py @@ -115,12 +115,6 @@ class TestArray2String(TestCase): assert_(np.array2string(a) == '[0 1 2]') assert_(np.array2string(a, max_line_width=4) == '[0 1\n 2]') - def test_style_keyword(self): - """This should only apply to 0-D arrays. See #1218.""" - stylestr = np.array2string(np.array(1.5), - style=lambda x: "Value in 0-D array: " + str(x)) - assert_(stylestr == 'Value in 0-D array: 1.5') - def test_format_function(self): """Test custom format function for each element in array.""" def _format_function(x): @@ -242,6 +236,14 @@ class TestPrintOptions: np.set_printoptions(formatter={'float_kind':None}) assert_equal(repr(x), "array([ 0., 1., 2.])") + def test_0d_arrays(self): + assert_equal(repr(np.datetime64('2005-02-25')[...]), + "array('2005-02-25', dtype='datetime64[D]')") + + x = np.array(1) + np.set_printoptions(formatter={'all':lambda x: "test"}) + assert_equal(repr(x), "array(test)") + def test_unicode_object_array(): import sys if sys.version_info[0] >= 3: |