summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/arrayprint.py99
-rw-r--r--numpy/core/tests/test_arrayprint.py16
2 files changed, 66 insertions, 49 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py
index 7a84eb7c2..a9fcfcdaa 100644
--- a/numpy/core/arrayprint.py
+++ b/numpy/core/arrayprint.py
@@ -234,24 +234,7 @@ def _boolFormatter(x):
def repr_format(x):
return repr(x)
-def _get_format_function(data, precision, suppress_small, formatter):
- """
- find the right formatting function for the dtype_
- """
- dtype_ = data.dtype
- if dtype_.fields is not None:
- format_functions = []
- for descr in dtype_.descr:
- field_name = descr[0]
- field_values = data[field_name]
- if len(field_values.shape) <= 1:
- format_function = _get_format_function(
- field_values, precision, suppress_small, formatter)
- else:
- format_function = repr_format
- format_functions.append(format_function)
- return StructureFormat(format_functions)
-
+def _get_formatdict(data, precision, suppress_small, formatter):
formatdict = {'bool': _boolFormatter,
'int': IntegerFormat(data),
'float': FloatFormat(data, precision, suppress_small),
@@ -285,7 +268,26 @@ def _get_format_function(data, precision, suppress_small, formatter):
if key in fkeys:
formatdict[key] = formatter[key]
+ return formatdict
+
+def _get_format_function(data, precision, suppress_small, formatter):
+ """
+ find the right formatting function for the dtype_
+ """
+ dtype_ = data.dtype
+ if dtype_.fields is not None:
+ format_functions = []
+ for field_name in dtype_.names:
+ field_values = data[field_name]
+ format_function = _get_format_function(
+ ravel(field_values), precision, suppress_small, formatter)
+ if dtype_[field_name].shape != ():
+ format_function = SubArrayFormat(format_function)
+ format_functions.append(format_function)
+ return StructureFormat(format_functions)
+
dtypeobj = dtype_.type
+ formatdict = _get_formatdict(data, precision, suppress_small, formatter)
if issubclass(dtypeobj, _nt.bool_):
return formatdict['bool']
elif issubclass(dtypeobj, _nt.integer):
@@ -313,18 +315,6 @@ def _get_format_function(data, precision, suppress_small, formatter):
def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
prefix="", formatter=None):
- if max_line_width is None:
- max_line_width = _line_width
-
- if precision is None:
- precision = _float_output_precision
-
- if suppress_small is None:
- suppress_small = _float_output_suppress_small
-
- if formatter is None:
- formatter = _formatter
-
if a.size > _summaryThreshold:
summary_insert = "..., "
data = _leading_trailing(a)
@@ -346,17 +336,6 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
_summaryEdgeItems, summary_insert)[:-1]
return lst
-def _convert_arrays(obj):
- from . import numeric as _nc
- newtup = []
- for k in obj:
- if isinstance(k, _nc.ndarray):
- k = k.tolist()
- elif isinstance(k, tuple):
- k = _convert_arrays(k)
- newtup.append(k)
- return tuple(newtup)
-
def array2string(a, max_line_width=None, precision=None,
suppress_small=None, separator=' ', prefix="",
@@ -455,11 +434,27 @@ def array2string(a, max_line_width=None, precision=None,
"""
+ if max_line_width is None:
+ max_line_width = _line_width
+
+ if precision is None:
+ precision = _float_output_precision
+
+ if suppress_small is None:
+ suppress_small = _float_output_suppress_small
+
+ if formatter is None:
+ formatter = _formatter
+
if a.shape == ():
x = a.item()
- if isinstance(x, tuple):
- x = _convert_arrays(x)
- lst = style(x)
+ if a.dtype.fields is not None:
+ arr = array([x], dtype=a.dtype)
+ format_function = _get_format_function(
+ arr, precision, suppress_small, formatter)
+ lst = format_function(arr[0])
+ else:
+ lst = style(x)
elif reduce(product, a.shape) == 0:
# treat as a null array if any of shape elements == 0
lst = "[]"
@@ -468,6 +463,7 @@ def array2string(a, max_line_width=None, precision=None,
separator, prefix, formatter=formatter)
return lst
+
def _extendLine(s, line, word, max_line_len, next_line_prefix):
if len(line.rstrip()) + len(word.rstrip()) >= max_line_len:
s += line.rstrip() + "\n"
@@ -486,10 +482,7 @@ def _formatArray(a, format_function, rank, max_line_len,
"""
if rank == 0:
- obj = a.item()
- if isinstance(obj, tuple):
- obj = _convert_arrays(obj)
- return str(obj)
+ raise ValueError("rank shouldn't be zero.")
if summary_insert and 2*edge_items < len(a):
leading_items = edge_items
@@ -781,6 +774,16 @@ class TimedeltaFormat(object):
return self.format % x.astype('i8')
+class SubArrayFormat(object):
+ def __init__(self, format_function):
+ self.format_function = format_function
+
+ def __call__(self, arr):
+ if arr.ndim <= 1:
+ return "[" + ", ".join(self.format_function(a) for a in arr) + "]"
+ return "[" + ", ".join(self.__call__(a) for a in arr) + "]"
+
+
class StructureFormat(object):
def __init__(self, format_functions):
self.format_functions = format_functions
diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py
index 97b5420ca..6c804a3b7 100644
--- a/numpy/core/tests/test_arrayprint.py
+++ b/numpy/core/tests/test_arrayprint.py
@@ -117,7 +117,7 @@ class TestArray2String(TestCase):
dt = np.dtype([('name', np.str_, 16), ('grades', np.float64, (2,))])
x = np.array([('Sarah', (8.0, 7.0)), ('John', (6.0, 7.0))], dtype=dt)
assert_equal(np.array2string(x),
- "[('Sarah', array([ 8., 7.])) ('John', array([ 6., 7.]))]")
+ "[('Sarah', [ 8., 7.]) ('John', [ 6., 7.])]")
# for issue #5692
A = np.zeros(shape=10, dtype=[("A", "M8[s]")])
@@ -128,6 +128,20 @@ class TestArray2String(TestCase):
"('1970-01-01T00:00:00',) ('NaT',) ('NaT',)\n " +
"('NaT',) ('NaT',) ('NaT',)]")
+ # See #8160
+ struct_int = np.array([([1, -1],), ([123, 1],)], dtype=[('B', 'i4', 2)])
+ assert_equal(np.array2string(struct_int),
+ "[([ 1, -1],) ([123, 1],)]")
+ struct_2dint = np.array([([[0, 1], [2, 3]],), ([[12, 0], [0, 0]],)],
+ dtype=[('B', 'i4', (2, 2))])
+ assert_equal(np.array2string(struct_2dint),
+ "[([[ 0, 1], [ 2, 3]],) ([[12, 0], [ 0, 0]],)]")
+
+ # See #8172
+ array_scalar = np.array(
+ (1., 2.1234567890123456789, 3.), dtype=('f8,f8,f8'))
+ assert_equal(np.array2string(array_scalar), "( 1., 2.12345679, 3.)")
+
class TestPrintOptions:
"""Test getting and setting global print options."""