diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-02-26 12:48:14 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-26 12:48:14 -0500 |
commit | ad8afe82e7b7643607a348c0e02b45c9131c6a06 (patch) | |
tree | e75ca21f4a8bdda9de4d7465ace4b0f0bc9770de /numpy/core | |
parent | c47198b4102cf3974e7fee2aa78d822877e391c2 (diff) | |
parent | 3461ea868bc7520d4569f7d61397044cf4e593d0 (diff) | |
download | numpy-ad8afe82e7b7643607a348c0e02b45c9131c6a06.tar.gz |
Merge pull request #8663 from eric-wieser/subclass-repr
ENH: Fix alignment of repr for array subclasses
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/numeric.py | 20 | ||||
-rw-r--r-- | numpy/core/tests/test_arrayprint.py | 22 |
2 files changed, 32 insertions, 10 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 896ad7f6a..01dd46c3c 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1890,35 +1890,35 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None): 'array([ 0.000001, 0. , 2. , 3. ])' """ + if type(arr) is not ndarray: + class_name = type(arr).__name__ + else: + class_name = "array" + if arr.size > 0 or arr.shape == (0,): lst = array2string(arr, max_line_width, precision, suppress_small, - ', ', "array(") + ', ', class_name + "(") else: # show zero-length shape unless it is (0,) lst = "[], shape=%s" % (repr(arr.shape),) - if arr.__class__ is not ndarray: - cName = arr.__class__.__name__ - else: - cName = "array" - skipdtype = (arr.dtype.type in _typelessdata) and arr.size > 0 if skipdtype: - return "%s(%s)" % (cName, lst) + return "%s(%s)" % (class_name, lst) else: typename = arr.dtype.name # Quote typename in the output if it is "complex". if typename and not (typename[0].isalpha() and typename.isalnum()): typename = "'%s'" % typename - lf = '' + lf = ' ' if issubclass(arr.dtype.type, flexible): if arr.dtype.names: typename = "%s" % str(arr.dtype) else: typename = "'%s'" % str(arr.dtype) - lf = '\n'+' '*len("array(") - return cName + "(%s, %sdtype=%s)" % (lst, lf, typename) + lf = '\n'+' '*len(class_name + "(") + return "%s(%s,%sdtype=%s)" % (class_name, lst, lf, typename) def array_str(a, max_line_width=None, precision=None, suppress_small=None): diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py index 9aa7b2609..607fa7010 100644 --- a/numpy/core/tests/test_arrayprint.py +++ b/numpy/core/tests/test_arrayprint.py @@ -14,6 +14,28 @@ class TestArrayRepr(object): x = np.array([np.nan, np.inf]) assert_equal(repr(x), 'array([ nan, inf])') + def test_subclass(self): + class sub(np.ndarray): pass + + # one dimensional + x1d = np.array([1, 2]).view(sub) + assert_equal(repr(x1d), 'sub([1, 2])') + + # two dimensional + x2d = np.array([[1, 2], [3, 4]]).view(sub) + assert_equal(repr(x2d), + 'sub([[1, 2],\n' + ' [3, 4]])') + + # two dimensional with flexible dtype + xstruct = np.ones((2,2), dtype=[('a', 'i4')]).view(sub) + assert_equal(repr(xstruct), + "sub([[(1,), (1,)],\n" + " [(1,), (1,)]],\n" + " dtype=[('a', '<i4')])" + ) + + class TestComplexArray(TestCase): def test_str(self): rvals = [0, 1, -1, np.inf, -np.inf, np.nan] |