From 2a4dd999c82276d00ef96d0d5839ff8b1f8a8871 Mon Sep 17 00:00:00 2001 From: Shota Kawabuchi Date: Sat, 22 Oct 2016 12:26:46 +0900 Subject: BUG: Fix subarray format changed in #8160 Preserving structured array element format, this commit fixes subarray format changed in PR #8160. This commit also changes iterator for field name from dtype_.descr to dtype_.names (Related to #8174). --- numpy/core/arrayprint.py | 49 ++++++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 18 deletions(-) (limited to 'numpy/core/arrayprint.py') diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 7a84eb7c2..1d93a0c0b 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -234,24 +234,7 @@ def _boolFormatter(x): def repr_format(x): return repr(x) -def _get_format_function(data, precision, suppress_small, formatter): - """ - find the right formatting function for the dtype_ - """ - dtype_ = data.dtype - if dtype_.fields is not None: - format_functions = [] - for descr in dtype_.descr: - field_name = descr[0] - field_values = data[field_name] - if len(field_values.shape) <= 1: - format_function = _get_format_function( - field_values, precision, suppress_small, formatter) - else: - format_function = repr_format - format_functions.append(format_function) - return StructureFormat(format_functions) - +def _get_formatdict(data, precision, suppress_small, formatter): formatdict = {'bool': _boolFormatter, 'int': IntegerFormat(data), 'float': FloatFormat(data, precision, suppress_small), @@ -285,7 +268,27 @@ def _get_format_function(data, precision, suppress_small, formatter): if key in fkeys: formatdict[key] = formatter[key] + return formatdict + +def _get_format_function(data, precision, suppress_small, formatter): + """ + find the right formatting function for the dtype_ + """ + dtype_ = data.dtype + if dtype_.fields is not None: + format_functions = [] + for field_name in dtype_.names: + field_values = data[field_name] + is_array_field = 1 < field_values.ndim + format_function = _get_format_function( + ravel(field_values), precision, suppress_small, formatter) + if is_array_field: + format_function = SubArrayFormat(format_function) + format_functions.append(format_function) + return StructureFormat(format_functions) + dtypeobj = dtype_.type + formatdict = _get_formatdict(data, precision, suppress_small, formatter) if issubclass(dtypeobj, _nt.bool_): return formatdict['bool'] elif issubclass(dtypeobj, _nt.integer): @@ -781,6 +784,16 @@ class TimedeltaFormat(object): return self.format % x.astype('i8') +class SubArrayFormat(object): + def __init__(self, format_function): + self.format_function = format_function + + def __call__(self, arr): + if arr.ndim <= 1: + return "[" + ", ".join(self.format_function(a) for a in arr) + "]" + return "[" + ", ".join(self.__call__(a) for a in arr) + "]" + + class StructureFormat(object): def __init__(self, format_functions): self.format_functions = format_functions -- cgit v1.2.1 From a6c2184c0dbc70b9a57fce21e4f769d313021261 Mon Sep 17 00:00:00 2001 From: Shota Kawabuchi Date: Sat, 22 Oct 2016 17:08:28 +0900 Subject: BUG: Fix array2string for structured array scalars PR #8160 added format function for structured arrays. But it is not applied for structured array scalars. Closes #8172 --- numpy/core/arrayprint.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) (limited to 'numpy/core/arrayprint.py') diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index 1d93a0c0b..ce0c6244e 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -316,18 +316,6 @@ def _get_format_function(data, precision, suppress_small, formatter): def _array2string(a, max_line_width, precision, suppress_small, separator=' ', prefix="", formatter=None): - if max_line_width is None: - max_line_width = _line_width - - if precision is None: - precision = _float_output_precision - - if suppress_small is None: - suppress_small = _float_output_suppress_small - - if formatter is None: - formatter = _formatter - if a.size > _summaryThreshold: summary_insert = "..., " data = _leading_trailing(a) @@ -458,11 +446,27 @@ def array2string(a, max_line_width=None, precision=None, """ + if max_line_width is None: + max_line_width = _line_width + + if precision is None: + precision = _float_output_precision + + if suppress_small is None: + suppress_small = _float_output_suppress_small + + if formatter is None: + formatter = _formatter + if a.shape == (): x = a.item() - if isinstance(x, tuple): - x = _convert_arrays(x) - lst = style(x) + if a.dtype.fields is not None: + arr = asarray([x], dtype=a.dtype) + format_function = _get_format_function( + arr, precision, suppress_small, formatter) + lst = format_function(arr[0]) + else: + lst = style(x) elif reduce(product, a.shape) == 0: # treat as a null array if any of shape elements == 0 lst = "[]" @@ -471,6 +475,7 @@ def array2string(a, max_line_width=None, precision=None, separator, prefix, formatter=formatter) return lst + def _extendLine(s, line, word, max_line_len, next_line_prefix): if len(line.rstrip()) + len(word.rstrip()) >= max_line_len: s += line.rstrip() + "\n" -- cgit v1.2.1 From e1326c31526a607cc981b309a3a092b1cbbc9b9c Mon Sep 17 00:00:00 2001 From: Shota Kawabuchi Date: Tue, 1 Nov 2016 23:22:55 +0900 Subject: MAINT: Refactor numpy/core/arrayprint.py Related to PR #8200 --- numpy/core/arrayprint.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) (limited to 'numpy/core/arrayprint.py') diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index ce0c6244e..a9fcfcdaa 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -279,10 +279,9 @@ def _get_format_function(data, precision, suppress_small, formatter): format_functions = [] for field_name in dtype_.names: field_values = data[field_name] - is_array_field = 1 < field_values.ndim format_function = _get_format_function( ravel(field_values), precision, suppress_small, formatter) - if is_array_field: + if dtype_[field_name].shape != (): format_function = SubArrayFormat(format_function) format_functions.append(format_function) return StructureFormat(format_functions) @@ -337,17 +336,6 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', _summaryEdgeItems, summary_insert)[:-1] return lst -def _convert_arrays(obj): - from . import numeric as _nc - newtup = [] - for k in obj: - if isinstance(k, _nc.ndarray): - k = k.tolist() - elif isinstance(k, tuple): - k = _convert_arrays(k) - newtup.append(k) - return tuple(newtup) - def array2string(a, max_line_width=None, precision=None, suppress_small=None, separator=' ', prefix="", @@ -461,7 +449,7 @@ def array2string(a, max_line_width=None, precision=None, if a.shape == (): x = a.item() if a.dtype.fields is not None: - arr = asarray([x], dtype=a.dtype) + arr = array([x], dtype=a.dtype) format_function = _get_format_function( arr, precision, suppress_small, formatter) lst = format_function(arr[0]) @@ -494,10 +482,7 @@ def _formatArray(a, format_function, rank, max_line_len, """ if rank == 0: - obj = a.item() - if isinstance(obj, tuple): - obj = _convert_arrays(obj) - return str(obj) + raise ValueError("rank shouldn't be zero.") if summary_insert and 2*edge_items < len(a): leading_items = edge_items -- cgit v1.2.1